Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mongoengine/base/metaclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ def __new__(cls, name, bases, attrs):
new_class.id = field

# Set primary key if not defined by the document
new_class._auto_id_field = False
new_class._auto_id_field = getattr(parent_doc_cls,
'_auto_id_field', False)
if not new_class._meta.get('id_field'):
new_class._auto_id_field = True
new_class._meta['id_field'] = 'id'
Expand Down
25 changes: 19 additions & 6 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document)
from mongoengine.errors import ValidationError
from mongoengine.queryset import OperationError, NotUniqueError, QuerySet
from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform)
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
from mongoengine.context_managers import switch_db, switch_collection

Expand Down Expand Up @@ -180,7 +181,7 @@ def _get_collection(cls):

def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, **kwargs):
_refs=None, save_condition=None, **kwargs):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
Expand All @@ -203,7 +204,8 @@ def save(self, force_insert=False, validate=True, clean=True,
:param cascade_kwargs: (optional) kwargs dictionary to be passed throw
to cascading saves. Implies ``cascade=True``.
:param _refs: A list of processed references used in cascading saves

:param save_condition: only perform save if matching record in db
satisfies condition(s) (e.g., version number)
.. versionchanged:: 0.5
In existing documents it only saves changed fields using
set / unset. Saves are cascaded and any
Expand All @@ -217,6 +219,9 @@ def save(self, force_insert=False, validate=True, clean=True,
meta['cascade'] = True. Also you can pass different kwargs to
the cascade save using cascade_kwargs which overwrites the
existing kwargs with custom values.
.. versionchanged:: 0.8.5
Optional save_condition that only overwrites existing documents
if the condition is satisfied in the current db record.
"""
signals.pre_save.send(self.__class__, document=self)

Expand All @@ -230,7 +235,8 @@ def save(self, force_insert=False, validate=True, clean=True,

created = ('_id' not in doc or self._created or force_insert)

signals.pre_save_post_validation.send(self.__class__, document=self, created=created)
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created)

try:
collection = self._get_collection()
Expand All @@ -243,7 +249,12 @@ def save(self, force_insert=False, validate=True, clean=True,
object_id = doc['_id']
updates, removals = self._delta()
# Need to add shard key to query, or you get an error
select_dict = {'_id': object_id}
if save_condition is not None:
select_dict = transform.query(self.__class__,
**save_condition)
else:
select_dict = {}
select_dict['_id'] = object_id
shard_key = self.__class__._meta.get('shard_key', tuple())
for k in shard_key:
actual_key = self._db_field_map.get(k, k)
Expand All @@ -263,10 +274,12 @@ def is_new_object(last_error):
if removals:
update_query["$unset"] = removals
if updates or removals:
upsert = save_condition is None
last_error = collection.update(select_dict, update_query,
upsert=True, **write_concern)
upsert=upsert, **write_concern)
created = is_new_object(last_error)


if cascade is None:
cascade = self._meta.get('cascade', False) or cascade_kwargs is not None

Expand Down
74 changes: 74 additions & 0 deletions tests/document/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,80 @@ class Person(Document):
p1.reload()
self.assertEqual(p1.name, p.parent.name)

def test_save_atomicity_condition(self):

class Widget(Document):
toggle = BooleanField(default=False)
count = IntField(default=0)
save_id = UUIDField()

def flip(widget):
widget.toggle = not widget.toggle
widget.count += 1

def UUID(i):
return uuid.UUID(int=i)

Widget.drop_collection()

w1 = Widget(toggle=False, save_id=UUID(1))

# ignore save_condition on new record creation
w1.save(save_condition={'save_id':UUID(42)})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.save_id, UUID(1))
self.assertEqual(w1.count, 0)

# mismatch in save_condition prevents save
flip(w1)
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
w1.save(save_condition={'save_id':UUID(42)})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.count, 0)

# matched save_condition allows save
flip(w1)
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)
w1.save(save_condition={'save_id':UUID(1)})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 1)

# save_condition can be used to ensure atomic read & updates
# i.e., prevent interleaved reads and writes from separate contexts
w2 = Widget.objects.get()
self.assertEqual(w1, w2)
old_id = w1.save_id

flip(w1)
w1.save_id = UUID(2)
w1.save(save_condition={'save_id':old_id})
w1.reload()
self.assertFalse(w1.toggle)
self.assertEqual(w1.count, 2)
flip(w2)
flip(w2)
w2.save(save_condition={'save_id':old_id})
w2.reload()
self.assertFalse(w2.toggle)
self.assertEqual(w2.count, 2)

# save_condition uses mongoengine-style operator syntax
flip(w1)
w1.save(save_condition={'count__lt':w1.count})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3)
flip(w1)
w1.save(save_condition={'count__gte':w1.count})
w1.reload()
self.assertTrue(w1.toggle)
self.assertEqual(w1.count, 3)

def test_update(self):
"""Ensure that an existing document is updated instead of be
overwritten."""
Expand Down