Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple fixes for cascade=True save issues #2615

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,5 @@ that much better:
* Timothé Perez (https://github.com/AchilleAsh)
* oleksandr-l5 (https://github.com/oleksandr-l5)
* Ido Shraga (https://github.com/idoshr)
* Nick Freville (https://github.com/nickfrev)
* Terence Honles (https://github.com/terencehonles)
1 change: 1 addition & 0 deletions mongoengine/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"ComplexBaseField",
"ObjectIdField",
"GeoJsonBaseField",
"SaveableBaseField",
# metaclasses
"DocumentMetaclass",
"TopLevelDocumentMetaclass",
Expand Down
21 changes: 19 additions & 2 deletions mongoengine/base/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mongoengine.common import _import_class
from mongoengine.errors import DeprecatedError, ValidationError

__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
__all__ = ("BaseField", "SaveableBaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")


class BaseField:
Expand Down Expand Up @@ -259,7 +259,14 @@ def owner_document(self, owner_document):
self._set_owner_document(owner_document)


class ComplexBaseField(BaseField):
class SaveableBaseField(BaseField):
"""A base class that dictates a field has the ability to save.
"""
def save():
pass


class ComplexBaseField(SaveableBaseField):
"""Handles complex fields, such as lists / dictionaries.

Allows for nesting of embedded documents inside complex types.
Expand Down Expand Up @@ -483,6 +490,16 @@ def validate(self, value):
if self.required and not value:
self.error("Field is required and cannot be empty")

def save(self, instance, **kwargs):
Document = _import_class("Document")
value = instance._data.get(self.name)

for ref in value:
if isinstance(ref, SaveableBaseField):
ref.save(self, **kwargs)
elif isinstance(ref, Document):
ref.save(**kwargs)

def prepare_query_value(self, op, value):
return self.to_mongo(value)

Expand Down
154 changes: 86 additions & 68 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import re

import pymongo
from bson import SON
from bson.dbref import DBRef
from pymongo.read_preferences import ReadPreference

from mongoengine import signals
from mongoengine.base import (
BaseDict,
BaseDocument,
SaveableBaseField,
BaseList,
DocumentMetaclass,
EmbeddedDocumentList,
Expand Down Expand Up @@ -385,44 +387,34 @@ def save(
the cascade save using cascade_kwargs which overwrites the
existing kwargs with custom values.
"""
signal_kwargs = signal_kwargs or {}

if self._meta.get("abstract"):
raise InvalidDocumentError("Cannot save an abstract document.")

signals.pre_save.send(self.__class__, document=self, **signal_kwargs)

if validate:
self.validate(clean=clean)

if write_concern is None:
write_concern = {}
# Used to avoid saving a document that is already saving (infinite loops)
# this can be caused by the cascade save and circular references
if getattr(self, "_is_saving", False):
return
self._is_saving = True

doc_id = self.to_mongo(fields=[self._meta["id_field"]])
created = "_id" not in doc_id or self._created or force_insert
try:
signal_kwargs = signal_kwargs or {}

signals.pre_save_post_validation.send(
self.__class__, document=self, created=created, **signal_kwargs
)
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
doc = self.to_mongo()
if write_concern is None:
write_concern = {}

if self._meta.get("auto_create_index", True):
self.ensure_indexes()

try:
# Save a new document or update an existing one
if created:
object_id = self._save_create(doc, force_insert, write_concern)
else:
object_id, created = self._save_update(
doc, save_condition, write_concern
)
if self._meta.get("abstract"):
raise InvalidDocumentError("Cannot save an abstract document.")

# Cascade save before validation to avoid child not existing errors
if cascade is None:
cascade = self._meta.get("cascade", False) or cascade_kwargs is not None

has_placeholder_saved = False

if cascade:
# If a cascade will occur save a placeholder version of this document to
# avoid issues with cyclic saves if this doc has not been created yet
if self.id is None:
self._save_place_holder(force_insert, write_concern)
has_placeholder_saved = True

kwargs = {
"force_insert": force_insert,
"validate": validate,
Expand All @@ -434,31 +426,74 @@ def save(
kwargs["_refs"] = _refs
self.cascade_save(**kwargs)

except pymongo.errors.DuplicateKeyError as err:
message = "Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % err)
except pymongo.errors.OperationFailure as err:
message = "Could not save document (%s)"
if re.match("^E1100[01] duplicate key", str(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
# update force_insert to reflect that we might have already run the insert for
# the placeholder
force_insert = force_insert and not has_placeholder_saved

signals.pre_save.send(self.__class__, document=self, **signal_kwargs)

if validate:
self.validate(clean=clean)

doc_id = self.to_mongo(fields=[self._meta["id_field"]])
created = "_id" not in doc_id or self._created or force_insert

signals.pre_save_post_validation.send(
self.__class__, document=self, created=created, **signal_kwargs
)
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
doc = self.to_mongo()

if self._meta.get("auto_create_index", True):
self.ensure_indexes()

try:
# Save a new document or update an existing one
if created:
object_id = self._save_create(doc, force_insert, write_concern)
else:
object_id, created = self._save_update(
doc, save_condition, write_concern
)
except pymongo.errors.DuplicateKeyError as err:
message = "Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % err)
raise OperationError(message % err)
except pymongo.errors.OperationFailure as err:
message = "Could not save document (%s)"
if re.match("^E1100[01] duplicate key", str(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
message = "Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % err)
raise OperationError(message % err)

# Make sure we store the PK on this document now that it's saved
id_field = self._meta["id_field"]
if created or id_field not in self._meta.get("shard_key", []):
self[id_field] = self._fields[id_field].to_python(object_id)

signals.post_save.send(
self.__class__, document=self, created=created, **signal_kwargs
)

# Make sure we store the PK on this document now that it's saved
id_field = self._meta["id_field"]
if created or id_field not in self._meta.get("shard_key", []):
self[id_field] = self._fields[id_field].to_python(object_id)
self._clear_changed_fields()
self._created = False
except Exception as e:
raise e
finally:
self._is_saving = False

signals.post_save.send(
self.__class__, document=self, created=created, **signal_kwargs
)
return self

self._clear_changed_fields()
self._created = False
def _save_place_holder(self, force_insert, write_concern):
"""Save a temp placeholder to the db with nothing but the ID.
"""
data = SON()

return self
object_id = self._save_create(data, force_insert, write_concern)

id_field = self._meta["id_field"]
self[id_field] = self._fields[id_field].to_python(object_id)

def _save_create(self, doc, force_insert, write_concern):
"""Save a new document.
Expand Down Expand Up @@ -556,28 +591,11 @@ def cascade_save(self, **kwargs):
"""Recursively save any references and generic references on the
document.
"""
_refs = kwargs.get("_refs") or []

ReferenceField = _import_class("ReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")

for name, cls in self._fields.items():
if not isinstance(cls, (ReferenceField, GenericReferenceField)):
continue

ref = self._data.get(name)
if not ref or isinstance(ref, DBRef):
if not isinstance(cls, SaveableBaseField):
continue

if not getattr(ref, "_changed_fields", True):
continue

ref_id = f"{ref.__class__.__name__},{str(ref._data)}"
if ref and ref_id not in _refs:
_refs.append(ref_id)
kwargs["_refs"] = _refs
ref.save(**kwargs)
ref._changed_fields = []
cls.save(self, **kwargs)

@property
def _qs(self):
Expand Down
25 changes: 23 additions & 2 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from mongoengine.base import (
BaseDocument,
BaseField,
SaveableBaseField,
ComplexBaseField,
GeoJsonBaseField,
LazyReference,
Expand Down Expand Up @@ -1123,7 +1124,7 @@ def __init__(self, field=None, *args, **kwargs):
super().__init__(field=field, *args, **kwargs)


class ReferenceField(BaseField):
class ReferenceField(SaveableBaseField):
"""A reference to a document that will be automatically dereferenced on
access (lazily).

Expand Down Expand Up @@ -1295,6 +1296,16 @@ def validate(self, value):
"saved to the database"
)

def save(self, instance, **kwargs):
ref = instance._data.get(self.name)
if not ref or isinstance(ref, DBRef):
return

if not getattr(self, "_changed_fields", True):
return

ref.save(**kwargs)

def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)

Expand Down Expand Up @@ -1464,7 +1475,7 @@ def sync_all(self):
self.owner_document.objects(**filter_kwargs).update(**update_kwargs)


class GenericReferenceField(BaseField):
class GenericReferenceField(SaveableBaseField):
"""A reference to *any* :class:`~mongoengine.document.Document` subclass
that will be automatically dereferenced on access (lazily).

Expand Down Expand Up @@ -1546,6 +1557,16 @@ def validate(self, value):
" saved to the database"
)

def save(self, instance, **kwargs):
ref = instance._data.get(self.name)
if not ref or isinstance(ref, DBRef):
return

if not getattr(ref, "_changed_fields", True):
return

ref.save(**kwargs)

def to_mongo(self, document):
if document is None:
return None
Expand Down
Loading