From efca40d2844eef80a357111830856c4e268ce158 Mon Sep 17 00:00:00 2001 From: Nicholas Freville Date: Tue, 14 Dec 2021 12:00:46 -0500 Subject: [PATCH 1/8] Issue #1236 Save the children documents first to avoid the issue where a parent cannot save due to having new children documents. --- mongoengine/document.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index e56a9e7c2..126f626b1 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -411,14 +411,6 @@ def save( self.ensure_indexes() try: - # Save a new document or update an existing one - if created: - object_id = self._save_create(doc, force_insert, write_concern) - else: - object_id, created = self._save_update( - doc, save_condition, write_concern - ) - if cascade is None: cascade = self._meta.get("cascade", False) or cascade_kwargs is not None @@ -434,6 +426,14 @@ def save( kwargs["_refs"] = _refs self.cascade_save(**kwargs) + # Save a new document or update an existing one + if created: + object_id = self._save_create(doc, force_insert, write_concern) + else: + object_id, created = self._save_update( + doc, save_condition, write_concern + ) + except pymongo.errors.DuplicateKeyError as err: message = "Tried to save duplicate unique keys (%s)" raise NotUniqueError(message % err) From bb6f7d84e86a20e6de61e7a881b76bf6d7e3f211 Mon Sep 17 00:00:00 2001 From: Nicholas Freville Date: Tue, 14 Dec 2021 13:05:11 -0500 Subject: [PATCH 2/8] Added Testing for saving new ReferenceField with cascade --- tests/document/test_class_methods.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index f82808ba0..5fe2f240a 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -344,6 +344,25 @@ class Person(Document): Person.drop_collection() + def test_save_with_cascade_on_new_referencefield(self): + """Ensure that a new and unsaved ReferenceField is saved before + the parent Document is saved to avoid validation issues. + """ + + class Job(Document): + employee = ReferenceField(self.Person) + + person = self.Person(name="Test User") + job = Job(employee=person) + job.save(cascade=True) + + employee_obj = self.Person.objects[0] + assert employee_obj["name"] == "Test User" + + job_obj = Job.objects[0] + assert job_obj.employee == job.employee + + if __name__ == "__main__": unittest.main() From e9a9dba5bf561055a0a9b9c932e8d05309a01a66 Mon Sep 17 00:00:00 2001 From: Nicholas Freville Date: Tue, 14 Dec 2021 13:05:58 -0500 Subject: [PATCH 3/8] Moved cascade save handling higher up to avoid validation errors. --- mongoengine/document.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index 126f626b1..3071a3b7e 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -390,6 +390,22 @@ def save( if self._meta.get("abstract"): raise InvalidDocumentError("Cannot save an abstract document.") + # Cascade save first before saving document + if cascade is None: + cascade = self._meta.get("cascade", False) or cascade_kwargs is not None + + if cascade: + kwargs = { + "force_insert": force_insert, + "validate": validate, + "write_concern": write_concern, + "cascade": cascade, + } + if cascade_kwargs: # Allow granular control over cascades + kwargs.update(cascade_kwargs) + kwargs["_refs"] = _refs + self.cascade_save(**kwargs) + signals.pre_save.send(self.__class__, document=self, **signal_kwargs) if validate: @@ -411,21 +427,6 @@ def save( self.ensure_indexes() try: - if cascade is None: - cascade = self._meta.get("cascade", False) or cascade_kwargs is not None - - if cascade: - kwargs = { - "force_insert": force_insert, - "validate": validate, - "write_concern": write_concern, - "cascade": cascade, - } - if cascade_kwargs: # Allow granular control over cascades - kwargs.update(cascade_kwargs) - kwargs["_refs"] = _refs - self.cascade_save(**kwargs) - # Save a new document or update an existing one if created: object_id = self._save_create(doc, force_insert, write_concern) @@ -433,7 +434,6 @@ def save( object_id, created = self._save_update( doc, save_condition, write_concern ) - except pymongo.errors.DuplicateKeyError as err: message = "Tried to save duplicate unique keys (%s)" raise NotUniqueError(message % err) From e0cc779b3f3a53c205e82c52dfdbbeb5f4dfcdfa Mon Sep 17 00:00:00 2001 From: Nicholas Freville Date: Tue, 14 Dec 2021 13:14:37 -0500 Subject: [PATCH 4/8] Updated Authors --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index 60663940a..f6d2ba2d9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -263,3 +263,4 @@ that much better: * Timothé Perez (https://github.com/AchilleAsh) * oleksandr-l5 (https://github.com/oleksandr-l5) * Ido Shraga (https://github.com/idoshr) + * Nick Freville (https://github.com/nickfrev) From 979e4918c79c978d0ee24a161230db1c6b513d9c Mon Sep 17 00:00:00 2001 From: Nicholas Freville Date: Tue, 14 Dec 2021 13:50:25 -0500 Subject: [PATCH 5/8] Removed unnecessary newlines at end of test. --- tests/document/test_class_methods.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index 5fe2f240a..ea6d51bc1 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -362,7 +362,5 @@ class Job(Document): job_obj = Job.objects[0] assert job_obj.employee == job.employee - - if __name__ == "__main__": unittest.main() From 7fbdf8b76c20cc89e746eb29a6408a7a4caf937f Mon Sep 17 00:00:00 2001 From: Nicholas Freville Date: Mon, 20 Dec 2021 22:29:57 -0500 Subject: [PATCH 6/8] Modularize the cascade save event - Created a new base field (SaveableBaseField) which allows a field to marked as savable during a cascade save. - Each SaveableBaseField defines a save method which describes how it will deal with a cascade save call this allows lists, dicts, and maps to be effected during a cascade save. - Added an _is_saving flag during Document save to avoid saving a document that is already in the process of being saved. (Caused if there is a circular reference.) --- mongoengine/base/__init__.py | 1 + mongoengine/base/fields.py | 21 +++- mongoengine/document.py | 155 +++++++++++++-------------- mongoengine/fields.py | 25 ++++- tests/document/test_class_methods.py | 19 ++++ 5 files changed, 137 insertions(+), 84 deletions(-) diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py index dca0c4bb7..f31759ece 100644 --- a/mongoengine/base/__init__.py +++ b/mongoengine/base/__init__.py @@ -27,6 +27,7 @@ "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField", + "SaveableBaseField", # metaclasses "DocumentMetaclass", "TopLevelDocumentMetaclass", diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index a68035274..a5ba3f58f 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -13,7 +13,7 @@ from mongoengine.common import _import_class from mongoengine.errors import DeprecatedError, ValidationError -__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") +__all__ = ("BaseField", "SaveableBaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") class BaseField: @@ -259,7 +259,14 @@ def owner_document(self, owner_document): self._set_owner_document(owner_document) -class ComplexBaseField(BaseField): +class SaveableBaseField(BaseField): + """A base class that dictates a field has the ability to save. + """ + def save(): + pass + + +class ComplexBaseField(SaveableBaseField): """Handles complex fields, such as lists / dictionaries. Allows for nesting of embedded documents inside complex types. @@ -483,6 +490,16 @@ def validate(self, value): if self.required and not value: self.error("Field is required and cannot be empty") + def save(self, instance, **kwargs): + Document = _import_class("Document") + value = instance._data.get(self.name) + + for ref in value: + if isinstance(ref, SaveableBaseField): + ref.save(self, **kwargs) + elif isinstance(ref, Document): + ref.save(**kwargs) + def prepare_query_value(self, op, value): return self.to_mongo(value) diff --git a/mongoengine/document.py b/mongoengine/document.py index 3071a3b7e..dc6b30766 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -8,6 +8,7 @@ from mongoengine.base import ( BaseDict, BaseDocument, + SaveableBaseField, BaseList, DocumentMetaclass, EmbeddedDocumentList, @@ -385,78 +386,89 @@ def save( the cascade save using cascade_kwargs which overwrites the existing kwargs with custom values. """ - signal_kwargs = signal_kwargs or {} - - if self._meta.get("abstract"): - raise InvalidDocumentError("Cannot save an abstract document.") - - # Cascade save first before saving document - if cascade is None: - cascade = self._meta.get("cascade", False) or cascade_kwargs is not None + # Used to avoid saving a document that is already saving (infinite loops) + # this can be caused by the cascade save and circular references + if getattr(self, "_is_saving", False): + return + self._is_saving = True - if cascade: - kwargs = { - "force_insert": force_insert, - "validate": validate, - "write_concern": write_concern, - "cascade": cascade, - } - if cascade_kwargs: # Allow granular control over cascades - kwargs.update(cascade_kwargs) - kwargs["_refs"] = _refs - self.cascade_save(**kwargs) + try: + signal_kwargs = signal_kwargs or {} + + if self._meta.get("abstract"): + raise InvalidDocumentError("Cannot save an abstract document.") + + # Cascade save first before saving document + if cascade is None: + cascade = self._meta.get("cascade", False) or cascade_kwargs is not None + + if cascade: + kwargs = { + "force_insert": force_insert, + "validate": validate, + "write_concern": write_concern, + "cascade": cascade, + } + if cascade_kwargs: # Allow granular control over cascades + kwargs.update(cascade_kwargs) + kwargs["_refs"] = _refs + self.cascade_save(**kwargs) - signals.pre_save.send(self.__class__, document=self, **signal_kwargs) + signals.pre_save.send(self.__class__, document=self, **signal_kwargs) - if validate: - self.validate(clean=clean) + if validate: + self.validate(clean=clean) - if write_concern is None: - write_concern = {} + if write_concern is None: + write_concern = {} - doc_id = self.to_mongo(fields=[self._meta["id_field"]]) - created = "_id" not in doc_id or self._created or force_insert + doc_id = self.to_mongo(fields=[self._meta["id_field"]]) + created = "_id" not in doc_id or self._created or force_insert - signals.pre_save_post_validation.send( - self.__class__, document=self, created=created, **signal_kwargs - ) - # it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation - doc = self.to_mongo() + signals.pre_save_post_validation.send( + self.__class__, document=self, created=created, **signal_kwargs + ) + # it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation + doc = self.to_mongo() - if self._meta.get("auto_create_index", True): - self.ensure_indexes() + if self._meta.get("auto_create_index", True): + self.ensure_indexes() - try: - # Save a new document or update an existing one - if created: - object_id = self._save_create(doc, force_insert, write_concern) - else: - object_id, created = self._save_update( - doc, save_condition, write_concern - ) - except pymongo.errors.DuplicateKeyError as err: - message = "Tried to save duplicate unique keys (%s)" - raise NotUniqueError(message % err) - except pymongo.errors.OperationFailure as err: - message = "Could not save document (%s)" - if re.match("^E1100[01] duplicate key", str(err)): - # E11000 - duplicate key error index - # E11001 - duplicate key on update + try: + # Save a new document or update an existing one + if created: + object_id = self._save_create(doc, force_insert, write_concern) + else: + object_id, created = self._save_update( + doc, save_condition, write_concern + ) + except pymongo.errors.DuplicateKeyError as err: message = "Tried to save duplicate unique keys (%s)" raise NotUniqueError(message % err) - raise OperationError(message % err) - - # Make sure we store the PK on this document now that it's saved - id_field = self._meta["id_field"] - if created or id_field not in self._meta.get("shard_key", []): - self[id_field] = self._fields[id_field].to_python(object_id) - - signals.post_save.send( - self.__class__, document=self, created=created, **signal_kwargs - ) + except pymongo.errors.OperationFailure as err: + message = "Could not save document (%s)" + if re.match("^E1100[01] duplicate key", str(err)): + # E11000 - duplicate key error index + # E11001 - duplicate key on update + message = "Tried to save duplicate unique keys (%s)" + raise NotUniqueError(message % err) + raise OperationError(message % err) + + # Make sure we store the PK on this document now that it's saved + id_field = self._meta["id_field"] + if created or id_field not in self._meta.get("shard_key", []): + self[id_field] = self._fields[id_field].to_python(object_id) + + signals.post_save.send( + self.__class__, document=self, created=created, **signal_kwargs + ) - self._clear_changed_fields() - self._created = False + self._clear_changed_fields() + self._created = False + except Exception as e: + raise e + finally: + self._is_saving = False return self @@ -556,28 +568,11 @@ def cascade_save(self, **kwargs): """Recursively save any references and generic references on the document. """ - _refs = kwargs.get("_refs") or [] - - ReferenceField = _import_class("ReferenceField") - GenericReferenceField = _import_class("GenericReferenceField") for name, cls in self._fields.items(): - if not isinstance(cls, (ReferenceField, GenericReferenceField)): + if not isinstance(cls, SaveableBaseField): continue - - ref = self._data.get(name) - if not ref or isinstance(ref, DBRef): - continue - - if not getattr(ref, "_changed_fields", True): - continue - - ref_id = f"{ref.__class__.__name__},{str(ref._data)}" - if ref and ref_id not in _refs: - _refs.append(ref_id) - kwargs["_refs"] = _refs - ref.save(**kwargs) - ref._changed_fields = [] + cls.save(self, **kwargs) @property def _qs(self): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index a2ccc7aea..8f351369c 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -25,6 +25,7 @@ from mongoengine.base import ( BaseDocument, BaseField, + SaveableBaseField, ComplexBaseField, GeoJsonBaseField, LazyReference, @@ -1123,7 +1124,7 @@ def __init__(self, field=None, *args, **kwargs): super().__init__(field=field, *args, **kwargs) -class ReferenceField(BaseField): +class ReferenceField(SaveableBaseField): """A reference to a document that will be automatically dereferenced on access (lazily). @@ -1295,6 +1296,16 @@ def validate(self, value): "saved to the database" ) + def save(self, instance, **kwargs): + ref = instance._data.get(self.name) + if not ref or isinstance(ref, DBRef): + return + + if not getattr(self, "_changed_fields", True): + return + + ref.save(**kwargs) + def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -1464,7 +1475,7 @@ def sync_all(self): self.owner_document.objects(**filter_kwargs).update(**update_kwargs) -class GenericReferenceField(BaseField): +class GenericReferenceField(SaveableBaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). @@ -1546,6 +1557,16 @@ def validate(self, value): " saved to the database" ) + def save(self, instance, **kwargs): + ref = instance._data.get(self.name) + if not ref or isinstance(ref, DBRef): + return + + if not getattr(ref, "_changed_fields", True): + return + + ref.save(**kwargs) + def to_mongo(self, document): if document is None: return None diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index ea6d51bc1..ce6c0998e 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -362,5 +362,24 @@ class Job(Document): job_obj = Job.objects[0] assert job_obj.employee == job.employee + def test_cascade_save_nested_referencefields(self): + """Ensure that nested ReferenceFields are saved during a cascade_save. + """ + + class Job(Document): + employee = ReferenceField(self.Person) + + class Company(Document): + job_list = ListField(ReferenceField(Job)) + + person = self.Person(name="Test User") + job = Job(employee=person) + company = Company(job_list=[job]).save(cascade=True) + + company_obj = Company.objects.first() + assert company_obj.job_list[0] == job + + assert company_obj.job_list[0].employee["name"] == "Test User" + if __name__ == "__main__": unittest.main() From 7672f2c4af7374695ba20c4f55154848d07f43ad Mon Sep 17 00:00:00 2001 From: Nicholas Freville Date: Wed, 22 Dec 2021 14:13:50 -0500 Subject: [PATCH 7/8] Allowed ReferenceField cycles of unsaved documents - Allows for users to cascade save unsaved documents even if a cycle exists. --- mongoengine/document.py | 31 +++++++++++++++++--- tests/document/test_class_methods.py | 42 ++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index dc6b30766..ae391a63f 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,6 +1,7 @@ import re import pymongo +from bson import SON from bson.dbref import DBRef from pymongo.read_preferences import ReadPreference @@ -395,14 +396,25 @@ def save( try: signal_kwargs = signal_kwargs or {} + if write_concern is None: + write_concern = {} + if self._meta.get("abstract"): raise InvalidDocumentError("Cannot save an abstract document.") - # Cascade save first before saving document + # Cascade save before validation to avoid child not existing errors if cascade is None: cascade = self._meta.get("cascade", False) or cascade_kwargs is not None + has_placeholder_saved = False + if cascade: + # If a cascade will occur save a placeholder version of this document to + # avoid issues with cyclic saves if this doc has not been created yet + if self.id is None: + self._save_place_holder(force_insert, write_concern) + has_placeholder_saved = True + kwargs = { "force_insert": force_insert, "validate": validate, @@ -414,14 +426,15 @@ def save( kwargs["_refs"] = _refs self.cascade_save(**kwargs) + # update force_insert to reflect that we might have already run the insert for + # the placeholder + force_insert = force_insert and not has_placeholder_saved + signals.pre_save.send(self.__class__, document=self, **signal_kwargs) if validate: self.validate(clean=clean) - if write_concern is None: - write_concern = {} - doc_id = self.to_mongo(fields=[self._meta["id_field"]]) created = "_id" not in doc_id or self._created or force_insert @@ -472,6 +485,16 @@ def save( return self + def _save_place_holder(self, force_insert, write_concern): + """Save a temp placeholder to the db with nothing but the ID. + """ + data = SON() + + object_id = self._save_create(data, force_insert, write_concern) + + id_field = self._meta["id_field"] + self[id_field] = self._fields[id_field].to_python(object_id) + def _save_create(self, doc, force_insert, write_concern): """Save a new document. diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index ce6c0998e..0cd99eb4a 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -381,5 +381,47 @@ class Company(Document): assert company_obj.job_list[0].employee["name"] == "Test User" + def test_cascade_save_with_cycles(self): + """Ensure that cyclic references do not break cascade saves. + """ + + class Object1(Document): + name = StringField() + oject2_reference = ReferenceField('Object2') + oject2_list = ListField(ReferenceField('Object2')) + + class Object2(Document): + name = StringField() + oject1_reference = ReferenceField(Object1) + oject1_list = ListField(ReferenceField(Object1)) + + # TOFIX: is there a way to make it so the objects do not need to be saved + # beforhand? + obj_1_name = "Test Object 1" + obj_1 = Object1(name=obj_1_name) + obj_2_name = "Test Object 2" + obj_2 = Object2(name="Has not been saved") + + # Create a cyclic reference nightmare + obj_2.oject1_reference = obj_1 + obj_2.oject1_list = [obj_1] + + obj_1.oject2_reference = obj_2 + obj_1.oject2_list = [obj_2] + + + obj_2.name = obj_2_name + obj_1.save(cascade=True) + + test_1 = Object1.objects.first() + assert test_1.name == obj_1_name + assert test_1.oject2_reference.name == obj_2_name + assert test_1.oject2_list[0].name == obj_2_name + + test_2 = Object2.objects.first() + assert test_2.name == obj_2_name + assert test_2.oject1_reference.name == obj_1_name + assert test_2.oject1_list[0].name == obj_1_name + if __name__ == "__main__": unittest.main() From ba400b7c7884932f9c5f48523282e18228fd2ada Mon Sep 17 00:00:00 2001 From: Nicholas Freville Date: Wed, 22 Dec 2021 14:57:26 -0500 Subject: [PATCH 8/8] Removed TOFIX comment that I fixed. --- tests/document/test_class_methods.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/document/test_class_methods.py b/tests/document/test_class_methods.py index 0cd99eb4a..cd3cfd05a 100644 --- a/tests/document/test_class_methods.py +++ b/tests/document/test_class_methods.py @@ -395,8 +395,6 @@ class Object2(Document): oject1_reference = ReferenceField(Object1) oject1_list = ListField(ReferenceField(Object1)) - # TOFIX: is there a way to make it so the objects do not need to be saved - # beforhand? obj_1_name = "Test Object 1" obj_1 = Object1(name=obj_1_name) obj_2_name = "Test Object 2"