Skip to content

Commit

Permalink
Merge pull request #2545 from bagerard/fix_embedded_instance_deepcopy
Browse files Browse the repository at this point in the history
Fix embedded instance deepcopy
  • Loading branch information
bagerard committed Aug 7, 2021
2 parents 3b10236 + 66978ae commit bb9ba73
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 27 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,4 @@ that much better:
* Stankiewicz Mateusz (https://github.com/mas15)
* Felix Schultheiß (https://github.com/felix-smashdocs)
* Jan Stein (https://github.com/janste63)
* Timothé Perez (https://github.com/AchilleAsh)
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Development
===========
- (Fill this out as you fix issues and develop your features).
- EnumField improvements: now `choices` limits the values of an enum to allow
- Fix deepcopy of EmbeddedDocument #2202

Changes in 0.23.1
===========
Expand Down
11 changes: 10 additions & 1 deletion mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

def __getstate__(self):
data = super().__getstate__()
data["_instance"] = None
return data

def __setstate__(self, state):
super().__setstate__(state)
self._instance = state["_instance"]

def to_mongo(self, *args, **kwargs):
data = super().to_mongo(*args, **kwargs)

Expand Down Expand Up @@ -126,7 +135,7 @@ class Document(BaseDocument, metaclass=TopLevelDocumentMetaclass):
create a specialised version of the document that will be stored in the
same collection. To facilitate this behaviour a `_cls`
field is added to documents (hidden though the MongoEngine interface).
To enable this behaviourset :attr:`allow_inheritance` to ``True`` in the
To enable this behaviour set :attr:`allow_inheritance` to ``True`` in the
:attr:`meta` dictionary.
A :class:`~mongoengine.Document` may use a **Capped Collection** by
Expand Down
30 changes: 15 additions & 15 deletions tests/document/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def tearDown(self):
for collection in list_collection_names(self.db):
self.db.drop_collection(collection)

def assertDbEqual(self, docs):
def _assert_db_equal(self, docs):
assert list(self.Person._get_collection().find().sort("id")) == sorted(
docs, key=lambda doc: doc["_id"]
)

def assertHasInstance(self, field, instance):
def _assert_has_instance(self, field, instance):
assert hasattr(field, "_instance")
assert field._instance is not None
if isinstance(field._instance, weakref.ProxyType):
Expand Down Expand Up @@ -740,11 +740,11 @@ class Doc(Document):
Doc.drop_collection()

doc = Doc(embedded_field=Embedded(string="Hi"))
self.assertHasInstance(doc.embedded_field, doc)
self._assert_has_instance(doc.embedded_field, doc)

doc.save()
doc = Doc.objects.get()
self.assertHasInstance(doc.embedded_field, doc)
self._assert_has_instance(doc.embedded_field, doc)

def test_embedded_document_complex_instance(self):
"""Ensure that embedded documents in complex fields can reference
Expand All @@ -759,11 +759,11 @@ class Doc(Document):

Doc.drop_collection()
doc = Doc(embedded_field=[Embedded(string="Hi")])
self.assertHasInstance(doc.embedded_field[0], doc)
self._assert_has_instance(doc.embedded_field[0], doc)

doc.save()
doc = Doc.objects.get()
self.assertHasInstance(doc.embedded_field[0], doc)
self._assert_has_instance(doc.embedded_field[0], doc)

def test_embedded_document_complex_instance_no_use_db_field(self):
"""Ensure that use_db_field is propagated to list of Emb Docs."""
Expand Down Expand Up @@ -792,11 +792,11 @@ class Account(Document):

acc = Account()
acc.email = Email(email="test@example.com")
self.assertHasInstance(acc._data["email"], acc)
self._assert_has_instance(acc._data["email"], acc)
acc.save()

acc1 = Account.objects.first()
self.assertHasInstance(acc1._data["email"], acc1)
self._assert_has_instance(acc1._data["email"], acc1)

def test_instance_is_set_on_setattr_on_embedded_document_list(self):
class Email(EmbeddedDocument):
Expand All @@ -808,11 +808,11 @@ class Account(Document):
Account.drop_collection()
acc = Account()
acc.emails = [Email(email="test@example.com")]
self.assertHasInstance(acc._data["emails"][0], acc)
self._assert_has_instance(acc._data["emails"][0], acc)
acc.save()

acc1 = Account.objects.first()
self.assertHasInstance(acc1._data["emails"][0], acc1)
self._assert_has_instance(acc1._data["emails"][0], acc1)

def test_save_checks_that_clean_is_called(self):
class CustomError(Exception):
Expand Down Expand Up @@ -921,7 +921,7 @@ def test_modify_empty(self):
with pytest.raises(InvalidDocumentError):
self.Person().modify(set__age=10)

self.assertDbEqual([dict(doc.to_mongo())])
self._assert_db_equal([dict(doc.to_mongo())])

def test_modify_invalid_query(self):
doc1 = self.Person(name="bob", age=10).save()
Expand All @@ -931,7 +931,7 @@ def test_modify_invalid_query(self):
with pytest.raises(InvalidQueryError):
doc1.modify({"id": doc2.id}, set__value=20)

self.assertDbEqual(docs)
self._assert_db_equal(docs)

def test_modify_match_another_document(self):
doc1 = self.Person(name="bob", age=10).save()
Expand All @@ -941,7 +941,7 @@ def test_modify_match_another_document(self):
n_modified = doc1.modify({"name": doc2.name}, set__age=100)
assert n_modified == 0

self.assertDbEqual(docs)
self._assert_db_equal(docs)

def test_modify_not_exists(self):
doc1 = self.Person(name="bob", age=10).save()
Expand All @@ -951,7 +951,7 @@ def test_modify_not_exists(self):
n_modified = doc2.modify({"name": doc2.name}, set__age=100)
assert n_modified == 0

self.assertDbEqual(docs)
self._assert_db_equal(docs)

def test_modify_update(self):
other_doc = self.Person(name="bob", age=10).save()
Expand All @@ -977,7 +977,7 @@ def test_modify_update(self):
assert doc.to_json() == doc_copy.to_json()
assert doc._get_changed_fields() == []

self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())])
self._assert_db_equal([dict(other_doc.to_mongo()), dict(doc.to_mongo())])

def test_modify_with_positional_push(self):
class Content(EmbeddedDocument):
Expand Down
31 changes: 31 additions & 0 deletions tests/fields/test_embedded_document_field.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from copy import deepcopy

import pytest
from bson import ObjectId

from mongoengine import (
Document,
Expand All @@ -9,6 +12,7 @@
InvalidQueryError,
ListField,
LookUpError,
MapField,
StringField,
ValidationError,
)
Expand Down Expand Up @@ -350,3 +354,30 @@ class Person(Document):
# Test existing attribute
assert Person.objects(settings__base_foo="basefoo").first().id == p.id
assert Person.objects(settings__sub_foo="subfoo").first().id == p.id

def test_deepcopy_set__instance(self):
"""Ensure that the _instance attribute on EmbeddedDocument exists after a deepcopy"""

class Wallet(EmbeddedDocument):
money = IntField()

class Person(Document):
wallet = EmbeddedDocumentField(Wallet)
wallet_map = MapField(EmbeddedDocumentField(Wallet))

# Test on fresh EmbeddedDoc
emb_doc = Wallet(money=1)
assert emb_doc._instance is None
copied_emb_doc = deepcopy(emb_doc)
assert copied_emb_doc._instance is None

# Test on attached EmbeddedDoc
doc = Person(
id=ObjectId(), wallet=Wallet(money=2), wallet_map={"test": Wallet(money=2)}
)
assert doc.wallet._instance == doc
copied_emb_doc = deepcopy(doc.wallet)
assert copied_emb_doc._instance is None

copied_map_emb_doc = deepcopy(doc.wallet_map)
assert copied_map_emb_doc["test"]._instance is None
22 changes: 11 additions & 11 deletions tests/queryset/test_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setUp(self):
connect(db="mongoenginetest")
Doc.drop_collection()

def assertDbEqual(self, docs):
def _assert_db_equal(self, docs):
assert list(Doc._collection.find().sort("id")) == docs

def test_modify(self):
Expand All @@ -28,7 +28,7 @@ def test_modify(self):

old_doc = Doc.objects(id=1).modify(set__value=-1)
assert old_doc.to_json() == doc.to_json()
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])

def test_modify_with_new(self):
Doc(id=0, value=0).save()
Expand All @@ -37,45 +37,45 @@ def test_modify_with_new(self):
new_doc = Doc.objects(id=1).modify(set__value=-1, new=True)
doc.value = -1
assert new_doc.to_json() == doc.to_json()
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])

def test_modify_not_existing(self):
Doc(id=0, value=0).save()
assert Doc.objects(id=1).modify(set__value=-1) is None
self.assertDbEqual([{"_id": 0, "value": 0}])
self._assert_db_equal([{"_id": 0, "value": 0}])

def test_modify_with_upsert(self):
Doc(id=0, value=0).save()
old_doc = Doc.objects(id=1).modify(set__value=1, upsert=True)
assert old_doc is None
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}])
self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}])

def test_modify_with_upsert_existing(self):
Doc(id=0, value=0).save()
doc = Doc(id=1, value=1).save()

old_doc = Doc.objects(id=1).modify(set__value=-1, upsert=True)
assert old_doc.to_json() == doc.to_json()
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])

def test_modify_with_upsert_with_new(self):
Doc(id=0, value=0).save()
new_doc = Doc.objects(id=1).modify(upsert=True, new=True, set__value=1)
assert new_doc.to_mongo() == {"_id": 1, "value": 1}
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}])
self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": 1}])

def test_modify_with_remove(self):
Doc(id=0, value=0).save()
doc = Doc(id=1, value=1).save()

old_doc = Doc.objects(id=1).modify(remove=True)
assert old_doc.to_json() == doc.to_json()
self.assertDbEqual([{"_id": 0, "value": 0}])
self._assert_db_equal([{"_id": 0, "value": 0}])

def test_find_and_modify_with_remove_not_existing(self):
Doc(id=0, value=0).save()
assert Doc.objects(id=1).modify(remove=True) is None
self.assertDbEqual([{"_id": 0, "value": 0}])
self._assert_db_equal([{"_id": 0, "value": 0}])

def test_modify_with_order_by(self):
Doc(id=0, value=3).save()
Expand All @@ -85,7 +85,7 @@ def test_modify_with_order_by(self):

old_doc = Doc.objects().order_by("-id").modify(set__value=-1)
assert old_doc.to_json() == doc.to_json()
self.assertDbEqual(
self._assert_db_equal(
[
{"_id": 0, "value": 3},
{"_id": 1, "value": 2},
Expand All @@ -100,7 +100,7 @@ def test_modify_with_fields(self):

old_doc = Doc.objects(id=1).only("id").modify(set__value=-1)
assert old_doc.to_mongo() == {"_id": 1}
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
self._assert_db_equal([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])

def test_modify_with_push(self):
class BlogPost(Document):
Expand Down

0 comments on commit bb9ba73

Please sign in to comment.