Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,40 @@ def update_one(self, upsert=False, write_concern=None, **update):
"""
return self.update(
upsert=upsert, multi=False, write_concern=write_concern, **update)

def andModify(self, upsert=False, sort=None, full_response=False, **update):
"""Perform an atomic update on the fields matched by the query. Returns
a document. Essentially a wrapper for PyMongo's find_and_modify; itself
a wrapper for MongoDB's findAndModify.

:param upsert: Any existing document with that "_id" is overwritten.
:param sort: a list of (key, direction) pairs specifying the sort order
for this query.
:param full_response: return the entire response object from the server.
:param update: Django-style update keyword arguments

.. versionadded:: TBC
"""
if not update and not upsert:
raise OperationError("No update parameters, must either update or remove")

queryset = self.clone()
query = queryset._query
update = transform.update(queryset._document, **update)

try:
result = queryset._collection.find_and_modify(query, update, upsert=upsert, sort=sort, full_response=full_response)
if full_response:
if not result['value'] is None:
result['value'] = self._document._from_son(result['value'])
else:
if not result is None:
result = self._document._from_son(result)
return result

except pymongo.errors.OperationFailure, err:
raise OperationError(u'findAndModify failed (%s)' % unicode(err))

def with_id(self, object_id):
"""Retrieve the object matching the id provided. Uses `object_id` only
and raises InvalidQueryError if a filter has been applied. Returns
Expand Down Expand Up @@ -1476,4 +1509,4 @@ def _ensure_indexes(self):
msg = ("Doc.objects()._ensure_indexes() is deprecated. "
"Use Doc.ensure_indexes() instead.")
warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_indexes()
self._document.__class__.ensure_indexes()
96 changes: 95 additions & 1 deletion tests/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def test_upsert(self):
bob = self.Person.objects.first()
self.assertEqual("Bob", bob.name)
self.assertEqual(30, bob.age)

def test_upsert_one(self):
self.Person.drop_collection()

Expand All @@ -576,6 +576,46 @@ def test_upsert_one(self):
self.assertEqual("Bob", bob.name)
self.assertEqual(30, bob.age)

def test_andModify_upsert(self):
"""Ensure that andModify can add a new document
"""

self.Person.drop_collection()

result = self.Person.objects(name="Bob").andModify(full_response=True, set__age=30)
self.assertEqual(result['value'], None)
bob = self.Person.objects(name="Bob").first()
self.assertEqual(bob, None)

result = self.Person.objects(name="Bob").andModify(upsert=True, full_response=True, set__age=30)
self.assertEqual(result['value'], None)
self.assertEqual(result['lastErrorObject']['updatedExisting'], False)
self.assertTrue(isinstance(result['lastErrorObject']['upserted'], ObjectId))

bob = self.Person.objects(name="Bob").first()
self.assertEqual("Bob", bob.name)
self.assertEqual(30, bob.age)

def test_andModify_sort(self):
"""Ensure sort can be used to select the record to find_and_modify
"""

self.Person.drop_collection()

bob = self.Person(name="Bob", age=30); bob.save()
betty = self.Person(name="Betty", age=30); betty.save()

result = self.Person.objects(age=30).andModify(sort=[('name', 1)], set__age=31)
self.assertEqual("Betty", result.name)
result = self.Person.objects(age=30).andModify(sort=[('name', 1)], set__age=31)
self.assertEqual("Bob", result.name)

result = self.Person.objects(age=31).andModify(sort=[('name', -1)], set__age=32)
self.assertEqual("Bob", result.name)

result = self.Person.objects.andModify(sort=[('age', 1)], set__age=32)
self.assertEqual("Betty", result.name)

def test_set_on_insert(self):
self.Person.drop_collection()

Expand Down Expand Up @@ -1445,6 +1485,60 @@ class BlogPost(Document):

BlogPost.drop_collection()

def test_andModify(self):
"""Ensure that andModify updates a record atomically and returns the
old record.
"""

class BlogPost(Document):
title = StringField()
hits = IntField()
tags = ListField(StringField())

BlogPost.drop_collection()

post = BlogPost(title="Test Post", hits=5, tags=['test'])
post.save()

result = BlogPost.objects(title="Test Post").andModify(set__hits=10)
self.assertEqual(result.hits, 5)

result = BlogPost.objects(title="Test Post").andModify(inc__hits=1)
self.assertEqual(result.hits, 10)

result = BlogPost.objects(title="Test Post").andModify(dec__hits=1)
self.assertEqual(result.hits, 11)

result = BlogPost.objects(title="Test Post").andModify(push__tags='mongo')
self.assertEqual(result.hits, 10)

result = BlogPost.objects(title="Test Post").andModify(push_all__tags=['db', 'nosql'])
self.assertTrue('mongo' in result.tags)

post.reload()
self.assertTrue('db' in post.tags and 'nosql' in post.tags)

tags = post.tags
result = BlogPost.objects(title="Test Post").andModify(pop__tags=1)
self.assertEqual(result.tags, tags)
post.reload()
self.assertEqual(post.tags, tags[:-1])

result = BlogPost.objects(title="Test Post").andModify(add_to_set__tags='unique')
self.assertEqual(result.tags.count('unique'), 0)
result = BlogPost.objects(title="Test Post").andModify(add_to_set__tags='unique')
self.assertEqual(result.tags.count('unique'), 1)
post.reload()
self.assertEqual(post.tags.count('unique'), 1)

self.assertNotEqual(post.hits, None)
result = BlogPost.objects(title="Test Post").andModify(unset__hits=1)
self.assertEqual(result.hits, 10)
post.reload()
self.assertEqual(post.hits, None)

BlogPost.drop_collection()

def test_update_push_and_pull_add_to_set(self):
"""Ensure that the 'pull' update operation works correctly.
"""
Expand Down