diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index d3bb4c4b3..23f6db594 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -467,7 +467,40 @@ def update_one(self, upsert=False, write_concern=None, **update): """ return self.update( upsert=upsert, multi=False, write_concern=write_concern, **update) + + def andModify(self, upsert=False, sort=None, full_response=False, **update): + """Perform an atomic update on the fields matched by the query. Returns + a document. Essentially a wrapper for PyMongo's find_and_modify; itself + a wrapper for MongoDB's findAndModify. + :param upsert: Any existing document with that "_id" is overwritten. + :param sort: a list of (key, direction) pairs specifying the sort order + for this query. + :param full_response: return the entire response object from the server. + :param update: Django-style update keyword arguments + + .. versionadded:: TBC + """ + if not update and not upsert: + raise OperationError("No update parameters, must either update or remove") + + queryset = self.clone() + query = queryset._query + update = transform.update(queryset._document, **update) + + try: + result = queryset._collection.find_and_modify(query, update, upsert=upsert, sort=sort, full_response=full_response) + if full_response: + if not result['value'] is None: + result['value'] = self._document._from_son(result['value']) + else: + if not result is None: + result = self._document._from_son(result) + return result + + except pymongo.errors.OperationFailure, err: + raise OperationError(u'findAndModify failed (%s)' % unicode(err)) + def with_id(self, object_id): """Retrieve the object matching the id provided. Uses `object_id` only and raises InvalidQueryError if a filter has been applied. Returns @@ -1476,4 +1509,4 @@ def _ensure_indexes(self): msg = ("Doc.objects()._ensure_indexes() is deprecated. " "Use Doc.ensure_indexes() instead.") warnings.warn(msg, DeprecationWarning) - self._document.__class__.ensure_indexes() \ No newline at end of file + self._document.__class__.ensure_indexes() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index c56b31eb7..845b18f2d 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -566,7 +566,7 @@ def test_upsert(self): bob = self.Person.objects.first() self.assertEqual("Bob", bob.name) self.assertEqual(30, bob.age) - + def test_upsert_one(self): self.Person.drop_collection() @@ -576,6 +576,46 @@ def test_upsert_one(self): self.assertEqual("Bob", bob.name) self.assertEqual(30, bob.age) + def test_andModify_upsert(self): + """Ensure that andModify can add a new document + """ + + self.Person.drop_collection() + + result = self.Person.objects(name="Bob").andModify(full_response=True, set__age=30) + self.assertEqual(result['value'], None) + bob = self.Person.objects(name="Bob").first() + self.assertEqual(bob, None) + + result = self.Person.objects(name="Bob").andModify(upsert=True, full_response=True, set__age=30) + self.assertEqual(result['value'], None) + self.assertEqual(result['lastErrorObject']['updatedExisting'], False) + self.assertTrue(isinstance(result['lastErrorObject']['upserted'], ObjectId)) + + bob = self.Person.objects(name="Bob").first() + self.assertEqual("Bob", bob.name) + self.assertEqual(30, bob.age) + + def test_andModify_sort(self): + """Ensure sort can be used to select the record to find_and_modify + """ + + self.Person.drop_collection() + + bob = self.Person(name="Bob", age=30); bob.save() + betty = self.Person(name="Betty", age=30); betty.save() + + result = self.Person.objects(age=30).andModify(sort=[('name', 1)], set__age=31) + self.assertEqual("Betty", result.name) + result = self.Person.objects(age=30).andModify(sort=[('name', 1)], set__age=31) + self.assertEqual("Bob", result.name) + + result = self.Person.objects(age=31).andModify(sort=[('name', -1)], set__age=32) + self.assertEqual("Bob", result.name) + + result = self.Person.objects.andModify(sort=[('age', 1)], set__age=32) + self.assertEqual("Betty", result.name) + def test_set_on_insert(self): self.Person.drop_collection() @@ -1445,6 +1485,60 @@ class BlogPost(Document): BlogPost.drop_collection() + def test_andModify(self): + """Ensure that andModify updates a record atomically and returns the + old record. + """ + + class BlogPost(Document): + title = StringField() + hits = IntField() + tags = ListField(StringField()) + + BlogPost.drop_collection() + + post = BlogPost(title="Test Post", hits=5, tags=['test']) + post.save() + + result = BlogPost.objects(title="Test Post").andModify(set__hits=10) + self.assertEqual(result.hits, 5) + + result = BlogPost.objects(title="Test Post").andModify(inc__hits=1) + self.assertEqual(result.hits, 10) + + result = BlogPost.objects(title="Test Post").andModify(dec__hits=1) + self.assertEqual(result.hits, 11) + + result = BlogPost.objects(title="Test Post").andModify(push__tags='mongo') + self.assertEqual(result.hits, 10) + + result = BlogPost.objects(title="Test Post").andModify(push_all__tags=['db', 'nosql']) + self.assertTrue('mongo' in result.tags) + + post.reload() + self.assertTrue('db' in post.tags and 'nosql' in post.tags) + + tags = post.tags + result = BlogPost.objects(title="Test Post").andModify(pop__tags=1) + self.assertEqual(result.tags, tags) + post.reload() + self.assertEqual(post.tags, tags[:-1]) + + result = BlogPost.objects(title="Test Post").andModify(add_to_set__tags='unique') + self.assertEqual(result.tags.count('unique'), 0) + result = BlogPost.objects(title="Test Post").andModify(add_to_set__tags='unique') + self.assertEqual(result.tags.count('unique'), 1) + post.reload() + self.assertEqual(post.tags.count('unique'), 1) + + self.assertNotEqual(post.hits, None) + result = BlogPost.objects(title="Test Post").andModify(unset__hits=1) + self.assertEqual(result.hits, 10) + post.reload() + self.assertEqual(post.hits, None) + + BlogPost.drop_collection() + def test_update_push_and_pull_add_to_set(self): """Ensure that the 'pull' update operation works correctly. """