diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index ff5afddfd..4b2e8b9bd 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -359,7 +359,8 @@ def __new__(cls, name, bases, attrs): new_class.id = field # Set primary key if not defined by the document - new_class._auto_id_field = False + new_class._auto_id_field = getattr(parent_doc_cls, + '_auto_id_field', False) if not new_class._meta.get('id_field'): new_class._auto_id_field = True new_class._meta['id_field'] = 'id' diff --git a/mongoengine/document.py b/mongoengine/document.py index 114778eba..f870a0783 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -13,7 +13,8 @@ BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) from mongoengine.errors import ValidationError -from mongoengine.queryset import OperationError, NotUniqueError, QuerySet +from mongoengine.queryset import (OperationError, NotUniqueError, + QuerySet, transform) from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME from mongoengine.context_managers import switch_db, switch_collection @@ -180,7 +181,7 @@ def _get_collection(cls): def save(self, force_insert=False, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, **kwargs): + _refs=None, save_condition=None, **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -203,7 +204,8 @@ def save(self, force_insert=False, validate=True, clean=True, :param cascade_kwargs: (optional) kwargs dictionary to be passed throw to cascading saves. Implies ``cascade=True``. :param _refs: A list of processed references used in cascading saves - + :param save_condition: only perform save if matching record in db + satisfies condition(s) (e.g., version number) .. versionchanged:: 0.5 In existing documents it only saves changed fields using set / unset. Saves are cascaded and any @@ -217,6 +219,9 @@ def save(self, force_insert=False, validate=True, clean=True, meta['cascade'] = True. Also you can pass different kwargs to the cascade save using cascade_kwargs which overwrites the existing kwargs with custom values. + .. versionchanged:: 0.8.5 + Optional save_condition that only overwrites existing documents + if the condition is satisfied in the current db record. """ signals.pre_save.send(self.__class__, document=self) @@ -230,7 +235,8 @@ def save(self, force_insert=False, validate=True, clean=True, created = ('_id' not in doc or self._created or force_insert) - signals.pre_save_post_validation.send(self.__class__, document=self, created=created) + signals.pre_save_post_validation.send(self.__class__, document=self, + created=created) try: collection = self._get_collection() @@ -243,7 +249,12 @@ def save(self, force_insert=False, validate=True, clean=True, object_id = doc['_id'] updates, removals = self._delta() # Need to add shard key to query, or you get an error - select_dict = {'_id': object_id} + if save_condition is not None: + select_dict = transform.query(self.__class__, + **save_condition) + else: + select_dict = {} + select_dict['_id'] = object_id shard_key = self.__class__._meta.get('shard_key', tuple()) for k in shard_key: actual_key = self._db_field_map.get(k, k) @@ -263,10 +274,12 @@ def is_new_object(last_error): if removals: update_query["$unset"] = removals if updates or removals: + upsert = save_condition is None last_error = collection.update(select_dict, update_query, - upsert=True, **write_concern) + upsert=upsert, **write_concern) created = is_new_object(last_error) + if cascade is None: cascade = self._meta.get('cascade', False) or cascade_kwargs is not None diff --git a/tests/document/instance.py b/tests/document/instance.py index 07db85a08..acb26c6d1 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -820,6 +820,80 @@ class Person(Document): p1.reload() self.assertEqual(p1.name, p.parent.name) + def test_save_atomicity_condition(self): + + class Widget(Document): + toggle = BooleanField(default=False) + count = IntField(default=0) + save_id = UUIDField() + + def flip(widget): + widget.toggle = not widget.toggle + widget.count += 1 + + def UUID(i): + return uuid.UUID(int=i) + + Widget.drop_collection() + + w1 = Widget(toggle=False, save_id=UUID(1)) + + # ignore save_condition on new record creation + w1.save(save_condition={'save_id':UUID(42)}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.save_id, UUID(1)) + self.assertEqual(w1.count, 0) + + # mismatch in save_condition prevents save + flip(w1) + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + w1.save(save_condition={'save_id':UUID(42)}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.count, 0) + + # matched save_condition allows save + flip(w1) + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + w1.save(save_condition={'save_id':UUID(1)}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 1) + + # save_condition can be used to ensure atomic read & updates + # i.e., prevent interleaved reads and writes from separate contexts + w2 = Widget.objects.get() + self.assertEqual(w1, w2) + old_id = w1.save_id + + flip(w1) + w1.save_id = UUID(2) + w1.save(save_condition={'save_id':old_id}) + w1.reload() + self.assertFalse(w1.toggle) + self.assertEqual(w1.count, 2) + flip(w2) + flip(w2) + w2.save(save_condition={'save_id':old_id}) + w2.reload() + self.assertFalse(w2.toggle) + self.assertEqual(w2.count, 2) + + # save_condition uses mongoengine-style operator syntax + flip(w1) + w1.save(save_condition={'count__lt':w1.count}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 3) + flip(w1) + w1.save(save_condition={'count__gte':w1.count}) + w1.reload() + self.assertTrue(w1.toggle) + self.assertEqual(w1.count, 3) + def test_update(self): """Ensure that an existing document is updated instead of be overwritten."""