Skip to content

Commit

Permalink
Merge pull request #2387 from bagerard/fix_change_fields_inconsistencies
Browse files Browse the repository at this point in the history
fix inconsistencies in ._changed_fields computation
  • Loading branch information
bagerard committed Oct 29, 2020
2 parents 9f82a02 + aabc187 commit 65f50fd
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 8 deletions.
42 changes: 39 additions & 3 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,9 @@ def _clear_changed_fields(self):
"""Using _get_changed_fields iterate and remove any fields that
are marked as changed.
"""
ReferenceField = _import_class("ReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")

for changed in self._get_changed_fields():
parts = changed.split(".")
data = self
Expand All @@ -550,7 +553,8 @@ def _clear_changed_fields(self):
elif isinstance(data, dict):
data = data.get(part, None)
else:
data = getattr(data, part, None)
field_name = data._reverse_db_field_map.get(part, part)
data = getattr(data, field_name, None)

if not isinstance(data, LazyReference) and hasattr(
data, "_changed_fields"
Expand All @@ -559,10 +563,40 @@ def _clear_changed_fields(self):
continue

data._changed_fields = []
elif isinstance(data, (list, tuple, dict)):
if hasattr(data, "field") and isinstance(
data.field, (ReferenceField, GenericReferenceField)
):
continue
BaseDocument._nestable_types_clear_changed_fields(data)

self._changed_fields = []

def _nestable_types_changed_fields(self, changed_fields, base_key, data):
@staticmethod
def _nestable_types_clear_changed_fields(data):
"""Inspect nested data for changed fields
:param data: data to inspect for changes
"""
Document = _import_class("Document")

# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(data, "items"):
iterator = enumerate(data)
else:
iterator = data.items()

for index_or_key, value in iterator:
if hasattr(value, "_get_changed_fields") and not isinstance(
value, Document
): # don't follow references
value._clear_changed_fields()
elif isinstance(value, (list, tuple, dict)):
BaseDocument._nestable_types_clear_changed_fields(value)

@staticmethod
def _nestable_types_changed_fields(changed_fields, base_key, data):
"""Inspect nested data for changed fields
:param changed_fields: Previously collected changed fields
Expand All @@ -587,7 +621,9 @@ def _nestable_types_changed_fields(self, changed_fields, base_key, data):
changed = value._get_changed_fields()
changed_fields += ["{}{}".format(item_key, k) for k in changed if k]
elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields(changed_fields, item_key, value)
BaseDocument._nestable_types_changed_fields(
changed_fields, item_key, value
)

def _get_changed_fields(self):
"""Return a list of all fields that have explicitly been changed.
Expand Down
34 changes: 31 additions & 3 deletions tests/document/test_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ class Doc(DocClass):
{},
)
doc.save()
assert doc._get_changed_fields() == []
doc = doc.reload(10)

assert doc.embedded_field.list_field[0] == "1"
Expand Down Expand Up @@ -777,9 +778,7 @@ class MyDoc(Document):

MyDoc.drop_collection()

mydoc = MyDoc(
name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}}
).save()
MyDoc(name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}}).save()

mydoc = MyDoc.objects.first()
subdoc = mydoc.subs["a"]["b"]
Expand All @@ -791,6 +790,35 @@ class MyDoc(Document):
mydoc._clear_changed_fields()
assert mydoc._get_changed_fields() == []

def test_nested_nested_fields_db_field_set__gets_mark_as_changed_and_cleaned(self):
class EmbeddedDoc(EmbeddedDocument):
name = StringField(db_field="db_name")

class MyDoc(Document):
embed = EmbeddedDocumentField(EmbeddedDoc, db_field="db_embed")
name = StringField(db_field="db_name")

MyDoc.drop_collection()

MyDoc(name="testcase1", embed=EmbeddedDoc(name="foo")).save()

mydoc = MyDoc.objects.first()
mydoc.embed.name = "foo1"

assert mydoc.embed._get_changed_fields() == ["db_name"]
assert mydoc._get_changed_fields() == ["db_embed.db_name"]

mydoc = MyDoc.objects.first()
embed = EmbeddedDoc(name="foo2")
embed.name = "bar"
mydoc.embed = embed

assert embed._get_changed_fields() == ["db_name"]
assert mydoc._get_changed_fields() == ["db_embed"]

mydoc._clear_changed_fields()
assert mydoc._get_changed_fields() == []

def test_lower_level_mark_as_changed(self):
class EmbeddedDoc(EmbeddedDocument):
name = StringField()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ class SimpleList(Document):
assert Post.objects.all()[0].user_lists == [[u1, u2], [u3]]

def test_circular_reference(self):
"""Ensure you can handle circular references
"""
"""Ensure you can handle circular references"""

class Relation(EmbeddedDocument):
name = StringField()
Expand Down Expand Up @@ -426,6 +425,7 @@ def __repr__(self):

daughter.relations.append(mother)
daughter.relations.append(daughter)
assert daughter._get_changed_fields() == ["relations"]
daughter.save()

assert "[<Person: Mother>, <Person: Daughter>]" == "%s" % Person.objects()
Expand Down

0 comments on commit 65f50fd

Please sign in to comment.