Skip to content

Commit

Permalink
Improve LazyReferenceField and GenericLazyReferenceField with nested …
Browse files Browse the repository at this point in the history
…fields
  • Loading branch information
touilleMan committed Nov 22, 2017
1 parent 47c7cb9 commit e74f659
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 22 deletions.
3 changes: 2 additions & 1 deletion mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mongoengine.base.common import get_document
from mongoengine.base.datastructures import (BaseDict, BaseList,
EmbeddedDocumentList,
LazyReference,
StrictDict)
from mongoengine.base.fields import ComplexBaseField
from mongoengine.common import _import_class
Expand Down Expand Up @@ -488,7 +489,7 @@ def _clear_changed_fields(self):
else:
data = getattr(data, part, None)

if hasattr(data, '_changed_fields'):
if not isinstance(data, LazyReference) and hasattr(data, '_changed_fields'):
if getattr(data, '_is_document', False):
continue

Expand Down
9 changes: 8 additions & 1 deletion mongoengine/dereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
TopLevelDocumentMetaclass, get_document)
from mongoengine.base.datastructures import LazyReference
from mongoengine.connection import get_db
from mongoengine.document import Document, EmbeddedDocument
from mongoengine.fields import DictField, ListField, MapField, ReferenceField
Expand Down Expand Up @@ -99,7 +100,10 @@ def _find_references(self, items, depth=0):
if isinstance(item, (Document, EmbeddedDocument)):
for field_name, field in item._fields.iteritems():
v = item._data.get(field_name, None)
if isinstance(v, DBRef):
if isinstance(v, LazyReference):
# LazyReference inherits DBRef but should not be dereferenced here !
continue
elif isinstance(v, DBRef):
reference_map.setdefault(field.document_type, set()).add(v.id)
elif isinstance(v, (dict, SON)) and '_ref' in v:
reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id)
Expand All @@ -110,6 +114,9 @@ def _find_references(self, items, depth=0):
if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
key = field_cls
reference_map.setdefault(key, set()).update(refs)
elif isinstance(item, LazyReference):
# LazyReference inherits DBRef but should not be dereferenced here !
continue
elif isinstance(item, DBRef):
reference_map.setdefault(item.collection, set()).add(item.id)
elif isinstance(item, (dict, SON)) and '_ref' in item:
Expand Down
56 changes: 37 additions & 19 deletions mongoengine/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from mongoengine.base import (BaseDocument, BaseField, ComplexBaseField,
GeoJsonBaseField, LazyReference, ObjectIdField,
get_document)
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.document import Document, EmbeddedDocument
from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
Expand Down Expand Up @@ -789,6 +790,17 @@ def __init__(self, field=None, **kwargs):
kwargs.setdefault('default', lambda: [])
super(ListField, self).__init__(**kwargs)

def __get__(self, instance, owner):
if instance is None:
# Document class being used rather than a document object
return self
value = instance._data.get(self.name)
LazyReferenceField = _import_class('LazyReferenceField')
GenericLazyReferenceField = _import_class('GenericLazyReferenceField')
if isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField)) and value:
instance._data[self.name] = [self.field.build_lazyref(x) for x in value]
return super(ListField, self).__get__(instance, owner)

def validate(self, value):
"""Make sure that a list of valid fields is being used."""
if (not isinstance(value, (list, tuple, QuerySet)) or
Expand Down Expand Up @@ -2211,17 +2223,10 @@ def document_type(self):
self.document_type_obj = get_document(self.document_type_obj)
return self.document_type_obj

def __get__(self, instance, owner):
"""Descriptor to allow lazy dereferencing."""
if instance is None:
# Document class being used rather than a document object
return self

value = instance._data.get(self.name)
def build_lazyref(self, value):
if isinstance(value, LazyReference):
if value.passthrough != self.passthrough:
instance._data[self.name] = LazyReference(
value.document_type, value.pk, passthrough=self.passthrough)
value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough)
elif value is not None:
if isinstance(value, self.document_type):
value = LazyReference(self.document_type, value.pk, passthrough=self.passthrough)
Expand All @@ -2230,6 +2235,16 @@ def __get__(self, instance, owner):
else:
# value is the primary key of the referenced document
value = LazyReference(self.document_type, value, passthrough=self.passthrough)
return value

def __get__(self, instance, owner):
"""Descriptor to allow lazy dereferencing."""
if instance is None:
# Document class being used rather than a document object
return self

value = self.build_lazyref(instance._data.get(self.name))
if value:
instance._data[self.name] = value

return super(LazyReferenceField, self).__get__(instance, owner)
Expand All @@ -2254,7 +2269,7 @@ def to_mongo(self, value):

def validate(self, value):
if isinstance(value, LazyReference):
if not issubclass(value.document_type, self.document_type):
if value.collection != self.document_type._get_collection_name():
self.error('Reference must be on a `%s` document.' % self.document_type)
pk = value.pk
elif isinstance(value, self.document_type):
Expand Down Expand Up @@ -2314,23 +2329,26 @@ def __init__(self, *args, **kwargs):

def _validate_choices(self, value):
if isinstance(value, LazyReference):
value = value.document_type
value = value.document_type._class_name
super(GenericLazyReferenceField, self)._validate_choices(value)

def __get__(self, instance, owner):
if instance is None:
return self

value = instance._data.get(self.name)
def build_lazyref(self, value):
if isinstance(value, LazyReference):
if value.passthrough != self.passthrough:
instance._data[self.name] = LazyReference(
value.document_type, value.pk, passthrough=self.passthrough)
value = LazyReference(value.document_type, value.pk, passthrough=self.passthrough)
elif value is not None:
if isinstance(value, (dict, SON)):
value = LazyReference(get_document(value['_cls']), value['_ref'].id, passthrough=self.passthrough)
elif isinstance(value, Document):
value = LazyReference(type(value), value.pk, passthrough=self.passthrough)
return value

def __get__(self, instance, owner):
if instance is None:
return self

value = self.build_lazyref(instance._data.get(self.name))
if value:
instance._data[self.name] = value

return super(GenericLazyReferenceField, self).__get__(instance, owner)
Expand All @@ -2348,7 +2366,7 @@ def to_mongo(self, document):
if isinstance(document, LazyReference):
return SON((
('_cls', document.document_type._class_name),
('_ref', document)
('_ref', DBRef(document.document_type._get_collection_name(), document.pk))
))
else:
return super(GenericLazyReferenceField, self).to_mongo(document)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[nosetests]
verbosity=2
detailed-errors=1
tests=tests
#tests=tests
cover-package=mongoengine

[flake8]
Expand Down
86 changes: 86 additions & 0 deletions tests/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4871,6 +4871,48 @@ class Animal(Document):
self.assertNotEqual(animal, other_animalref)
self.assertNotEqual(other_animalref, animal)

def test_lazy_reference_embedded(self):
class Animal(Document):
name = StringField()
tag = StringField()

class EmbeddedOcurrence(EmbeddedDocument):
in_list = ListField(LazyReferenceField(Animal))
direct = LazyReferenceField(Animal)

class Ocurrence(Document):
in_list = ListField(LazyReferenceField(Animal))
in_embedded = EmbeddedDocumentField(EmbeddedOcurrence)
direct = LazyReferenceField(Animal)

Animal.drop_collection()
Ocurrence.drop_collection()

animal1 = Animal('doggo').save()
animal2 = Animal('cheeta').save()

def check_fields_type(occ):
self.assertIsInstance(occ.direct, LazyReference)
for elem in occ.in_list:
self.assertIsInstance(elem, LazyReference)
self.assertIsInstance(occ.in_embedded.direct, LazyReference)
for elem in occ.in_embedded.in_list:
self.assertIsInstance(elem, LazyReference)

occ = Ocurrence(
in_list=[animal1, animal2],
in_embedded={'in_list': [animal1, animal2], 'direct': animal1},
direct=animal1
).save()
check_fields_type(occ)
occ.reload()
check_fields_type(occ)
occ.direct = animal1.id
occ.in_list = [animal1.id, animal2.id]
occ.in_embedded.direct = animal1.id
occ.in_embedded.in_list = [animal1.id, animal2.id]
check_fields_type(occ)


class GenericLazyReferenceFieldTest(MongoDBTestCase):
def test_generic_lazy_reference_simple(self):
Expand Down Expand Up @@ -5051,6 +5093,50 @@ class Ocurrence(Document):
p = Ocurrence.objects.get()
self.assertIs(p.animal, None)

def test_generic_lazy_reference_embedded(self):
class Animal(Document):
name = StringField()
tag = StringField()

class EmbeddedOcurrence(EmbeddedDocument):
in_list = ListField(GenericLazyReferenceField())
direct = GenericLazyReferenceField()

class Ocurrence(Document):
in_list = ListField(GenericLazyReferenceField())
in_embedded = EmbeddedDocumentField(EmbeddedOcurrence)
direct = GenericLazyReferenceField()

Animal.drop_collection()
Ocurrence.drop_collection()

animal1 = Animal('doggo').save()
animal2 = Animal('cheeta').save()

def check_fields_type(occ):
self.assertIsInstance(occ.direct, LazyReference)
for elem in occ.in_list:
self.assertIsInstance(elem, LazyReference)
self.assertIsInstance(occ.in_embedded.direct, LazyReference)
for elem in occ.in_embedded.in_list:
self.assertIsInstance(elem, LazyReference)

occ = Ocurrence(
in_list=[animal1, animal2],
in_embedded={'in_list': [animal1, animal2], 'direct': animal1},
direct=animal1
).save()
check_fields_type(occ)
occ.reload()
check_fields_type(occ)
animal1_ref = {'_cls': 'Animal', '_ref': DBRef(animal1._get_collection_name(), animal1.pk)}
animal2_ref = {'_cls': 'Animal', '_ref': DBRef(animal2._get_collection_name(), animal2.pk)}
occ.direct = animal1_ref
occ.in_list = [animal1_ref, animal2_ref]
occ.in_embedded.direct = animal1_ref
occ.in_embedded.in_list = [animal1_ref, animal2_ref]
check_fields_type(occ)


if __name__ == '__main__':
unittest.main()

0 comments on commit e74f659

Please sign in to comment.