Skip to content

Commit

Permalink
Merge pull request #826 from seglberg/mmelliso/fix#503
Browse files Browse the repository at this point in the history
EmbeddedDocumentListField (Resolves #503)
  • Loading branch information
DavidBord committed Feb 20, 2015
2 parents f42ab95 + 4272162 commit 47a4d58
Show file tree
Hide file tree
Showing 10 changed files with 745 additions and 29 deletions.
16 changes: 16 additions & 0 deletions docs/apireference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Fields
.. autoclass:: mongoengine.fields.GenericEmbeddedDocumentField
.. autoclass:: mongoengine.fields.DynamicField
.. autoclass:: mongoengine.fields.ListField
.. autoclass:: mongoengine.fields.EmbeddedDocumentListField
.. autoclass:: mongoengine.fields.SortedListField
.. autoclass:: mongoengine.fields.DictField
.. autoclass:: mongoengine.fields.MapField
Expand All @@ -103,6 +104,21 @@ Fields
.. autoclass:: mongoengine.fields.ImageGridFsProxy
.. autoclass:: mongoengine.fields.ImproperlyConfigured

Embedded Document Querying
==========================

.. versionadded:: 0.9

Additional queries for Embedded Documents are available when using the
:class:`~mongoengine.EmbeddedDocumentListField` to store a list of embedded
documents.

A list of embedded documents is returned as a special list with the
following methods:

.. autoclass:: mongoengine.base.datastructures.EmbeddedDocumentList
:members:

Misc
====

Expand Down
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Changelog

Changes in 0.9.X - DEV
======================
- Added `EmbeddedDocumentListField` for Lists of Embedded Documents. #826
- ComplexDateTimeField should fall back to None when null=True #864
- Request Support for $min, $max Field update operators #863
- `BaseDict` does not follow `setdefault` #866
Expand Down
166 changes: 164 additions & 2 deletions mongoengine/base/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import functools
import itertools
from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned

__all__ = ("BaseDict", "BaseList")
__all__ = ("BaseDict", "BaseList", "EmbeddedDocumentList")


class BaseDict(dict):
Expand Down Expand Up @@ -106,7 +107,7 @@ def __init__(self, list_items, instance, name):
if isinstance(instance, (Document, EmbeddedDocument)):
self._instance = weakref.proxy(instance)
self._name = name
return super(BaseList, self).__init__(list_items)
super(BaseList, self).__init__(list_items)

def __getitem__(self, key, *args, **kwargs):
value = super(BaseList, self).__getitem__(key)
Expand Down Expand Up @@ -191,6 +192,167 @@ def _mark_as_changed(self, key=None):
self._instance._mark_as_changed(self._name)


class EmbeddedDocumentList(BaseList):

@classmethod
def __match_all(cls, i, kwargs):
items = kwargs.items()
return all([
getattr(i, k) == v or str(getattr(i, k)) == v for k, v in items
])

@classmethod
def __only_matches(cls, obj, kwargs):
if not kwargs:
return obj
return filter(lambda i: cls.__match_all(i, kwargs), obj)

def __init__(self, list_items, instance, name):
super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
self._instance = instance

def filter(self, **kwargs):
"""
Filters the list by only including embedded documents with the
given keyword arguments.
:param kwargs: The keyword arguments corresponding to the fields to
filter on. *Multiple arguments are treated as if they are ANDed
together.*
:return: A new ``EmbeddedDocumentList`` containing the matching
embedded documents.
Raises ``AttributeError`` if a given keyword is not a valid field for
the embedded document class.
"""
values = self.__only_matches(self, kwargs)
return EmbeddedDocumentList(values, self._instance, self._name)

def exclude(self, **kwargs):
"""
Filters the list by excluding embedded documents with the given
keyword arguments.
:param kwargs: The keyword arguments corresponding to the fields to
exclude on. *Multiple arguments are treated as if they are ANDed
together.*
:return: A new ``EmbeddedDocumentList`` containing the non-matching
embedded documents.
Raises ``AttributeError`` if a given keyword is not a valid field for
the embedded document class.
"""
exclude = self.__only_matches(self, kwargs)
values = [item for item in self if item not in exclude]
return EmbeddedDocumentList(values, self._instance, self._name)

def count(self):
"""
The number of embedded documents in the list.
:return: The length of the list, equivalent to the result of ``len()``.
"""
return len(self)

def get(self, **kwargs):
"""
Retrieves an embedded document determined by the given keyword
arguments.
:param kwargs: The keyword arguments corresponding to the fields to
search on. *Multiple arguments are treated as if they are ANDed
together.*
:return: The embedded document matched by the given keyword arguments.
Raises ``DoesNotExist`` if the arguments used to query an embedded
document returns no results. ``MultipleObjectsReturned`` if more
than one result is returned.
"""
values = self.__only_matches(self, kwargs)
if len(values) == 0:
raise DoesNotExist(
"%s matching query does not exist." % self._name
)
elif len(values) > 1:
raise MultipleObjectsReturned(
"%d items returned, instead of 1" % len(values)
)

return values[0]

def first(self):
"""
Returns the first embedded document in the list, or ``None`` if empty.
"""
if len(self) > 0:
return self[0]

def create(self, **values):
"""
Creates a new embedded document and saves it to the database.
.. note::
The embedded document changes are not automatically saved
to the database after calling this method.
:param values: A dictionary of values for the embedded document.
:return: The new embedded document instance.
"""
name = self._name
EmbeddedClass = self._instance._fields[name].field.document_type_obj
self._instance[self._name].append(EmbeddedClass(**values))

return self._instance[self._name][-1]

def save(self, *args, **kwargs):
"""
Saves the ancestor document.
:param args: Arguments passed up to the ancestor Document's save
method.
:param kwargs: Keyword arguments passed up to the ancestor Document's
save method.
"""
self._instance.save(*args, **kwargs)

def delete(self):
"""
Deletes the embedded documents from the database.
.. note::
The embedded document changes are not automatically saved
to the database after calling this method.
:return: The number of entries deleted.
"""
values = list(self)
for item in values:
self._instance[self._name].remove(item)

return len(values)

def update(self, **update):
"""
Updates the embedded documents with the given update values.
.. note::
The embedded document changes are not automatically saved
to the database after calling this method.
:param update: A dictionary of update values to apply to each
embedded document.
:return: The number of entries updated.
"""
if len(update) == 0:
return 0
values = list(self)
for item in values:
for k, v in update.items():
setattr(item, k, v)

return len(values)


class StrictDict(object):
__slots__ = ()
_special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
Expand Down
15 changes: 13 additions & 2 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from mongoengine.python_support import PY3, txt_type

from mongoengine.base.common import get_document, ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict
from mongoengine.base.datastructures import (
BaseDict,
BaseList,
EmbeddedDocumentList,
StrictDict,
SemiStrictDict
)
from mongoengine.base.fields import ComplexBaseField

__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
Expand Down Expand Up @@ -419,6 +425,8 @@ def __expand_dynamic_values(self, name, value):
if not isinstance(value, (dict, list, tuple)):
return value

EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')

is_list = False
if not hasattr(value, 'items'):
is_list = True
Expand All @@ -442,7 +450,10 @@ def __expand_dynamic_values(self, name, value):
# Convert lists / values so we can watch for any changes on them
if (isinstance(value, (list, tuple)) and
not isinstance(value, BaseList)):
value = BaseList(value, self, name)
if issubclass(type(self), EmbeddedDocumentListField):
value = EmbeddedDocumentList(value, self, name)
else:
value = BaseList(value, self, name)
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, self, name)

Expand Down
14 changes: 10 additions & 4 deletions mongoengine/base/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from mongoengine.errors import ValidationError

from mongoengine.base.common import ALLOW_INHERITANCE
from mongoengine.base.datastructures import BaseDict, BaseList
from mongoengine.base.datastructures import (
BaseDict, BaseList, EmbeddedDocumentList
)

__all__ = ("BaseField", "ComplexBaseField",
"ObjectIdField", "GeoJsonBaseField")
Expand Down Expand Up @@ -210,6 +212,7 @@ def __get__(self, instance, owner):

ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
dereference = (self._auto_dereference and
(self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField))))
Expand All @@ -226,9 +229,12 @@ def __get__(self, instance, owner):
value = super(ComplexBaseField, self).__get__(instance, owner)

# Convert lists / values so we can watch for any changes on them
if (isinstance(value, (list, tuple)) and
not isinstance(value, BaseList)):
value = BaseList(value, instance, self.name)
if isinstance(value, (list, tuple)):
if (issubclass(type(self), EmbeddedDocumentListField) and
not isinstance(value, EmbeddedDocumentList)):
value = EmbeddedDocumentList(value, instance, self.name)
elif not isinstance(value, BaseList):
value = BaseList(value, instance, self.name)
instance._data[self.name] = value
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, instance, self.name)
Expand Down
18 changes: 11 additions & 7 deletions mongoengine/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
_class_registry_cache = {}
_field_list_cache = []


def _import_class(cls_name):
Expand All @@ -20,13 +21,16 @@ class from the :data:`mongoengine.common._class_registry_cache`.

doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument',
'MapReduceDocument')
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'FileField', 'GenericReferenceField',
'GenericEmbeddedDocumentField', 'GeoPointField',
'PointField', 'LineStringField', 'ListField',
'PolygonField', 'ReferenceField', 'StringField',
'CachedReferenceField',
'ComplexBaseField', 'GeoJsonBaseField')

# Field Classes
if not _field_list_cache:
from mongoengine.fields import __all__ as fields
_field_list_cache.extend(fields)
from mongoengine.base.fields import __all__ as fields
_field_list_cache.extend(fields)

field_classes = _field_list_cache

queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)

Expand Down
10 changes: 8 additions & 2 deletions mongoengine/dereference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from bson import DBRef, SON

from base import (BaseDict, BaseList, TopLevelDocumentMetaclass, get_document)
from base import (
BaseDict, BaseList, EmbeddedDocumentList,
TopLevelDocumentMetaclass, get_document
)
from fields import (ReferenceField, ListField, DictField, MapField)
from connection import get_db
from queryset import QuerySet
Expand Down Expand Up @@ -189,6 +192,9 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):

if not hasattr(items, 'items'):
is_list = True
list_type = BaseList
if isinstance(items, EmbeddedDocumentList):
list_type = EmbeddedDocumentList
as_tuple = isinstance(items, tuple)
iterator = enumerate(items)
data = []
Expand Down Expand Up @@ -225,7 +231,7 @@ def _attach_objects(self, items, depth=0, instance=None, name=None):

if instance and name:
if is_list:
return tuple(data) if as_tuple else BaseList(data, instance, name)
return tuple(data) if as_tuple else list_type(data, instance, name)
return BaseDict(data, instance, name)
depth += 1
return data
22 changes: 19 additions & 3 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@
from bson.dbref import DBRef
from mongoengine import signals
from mongoengine.common import _import_class
from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
BaseDocument, BaseDict, BaseList,
ALLOW_INHERITANCE, get_document)
from mongoengine.base import (
DocumentMetaclass,
TopLevelDocumentMetaclass,
BaseDocument,
BaseDict,
BaseList,
EmbeddedDocumentList,
ALLOW_INHERITANCE,
get_document
)
from mongoengine.errors import ValidationError, InvalidQueryError, InvalidDocumentError
from mongoengine.queryset import (OperationError, NotUniqueError,
QuerySet, transform)
Expand Down Expand Up @@ -76,6 +83,12 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

def save(self, *args, **kwargs):
self._instance.save(*args, **kwargs)

def reload(self, *args, **kwargs):
self._instance.reload(*args, **kwargs)


class Document(BaseDocument):

Expand Down Expand Up @@ -560,6 +573,9 @@ def _reload(self, key, value):
if isinstance(value, BaseDict):
value = [(k, self._reload(k, v)) for k, v in value.items()]
value = BaseDict(value, self, key)
elif isinstance(value, EmbeddedDocumentList):
value = [self._reload(key, v) for v in value]
value = EmbeddedDocumentList(value, self, key)
elif isinstance(value, BaseList):
value = [self._reload(key, v) for v in value]
value = BaseList(value, self, key)
Expand Down

0 comments on commit 47a4d58

Please sign in to comment.