Skip to content

Commit

Permalink
Improve Document inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
touilleMan committed Apr 13, 2016
1 parent ac38c03 commit 9116357
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 29 deletions.
5 changes: 3 additions & 2 deletions examples/inheritance/app.py
Expand Up @@ -52,8 +52,9 @@ def get_vehicle(self, *args):
vehicle = None
try:
vehicle = Vehicle.find_one({'_id': ObjectId(id)})
except (ValueError, TypeError):
pass
except Exception as exc:
print('Error: %s' % exc)
return
if vehicle:
print(vehicle)
else:
Expand Down
45 changes: 45 additions & 0 deletions tests/dal/test_motor_asyncio.py
Expand Up @@ -617,3 +617,48 @@ class Meta:
assert name_sorted(indexes) == name_sorted(expected_indexes)

loop.run_until_complete(do_test())

def test_inheritance_search(self, db, loop):

@asyncio.coroutine
def do_test():

class InheritanceSearchParent(Document):
pf = fields.IntField()

class Meta:
collection = db.inheritance_search
allow_inheritance = True

class InheritanceSearchChild1(InheritanceSearchParent):
c1f = fields.IntField()

class Meta:
allow_inheritance = True

class InheritanceSearchChild1Child(InheritanceSearchChild1):
sc1f = fields.IntField()

class InheritanceSearchChild2(InheritanceSearchParent):
c2f = fields.IntField(required=True)

yield from InheritanceSearchParent.collection.drop()

yield from InheritanceSearchParent(pf=0).commit()
yield from InheritanceSearchChild1(pf=1, c1f=1).commit()
yield from InheritanceSearchChild1Child(pf=1, sc1f=1).commit()
yield from InheritanceSearchChild2(pf=2, c2f=2).commit()

assert (yield from InheritanceSearchParent.find().count()) == 4
assert (yield from InheritanceSearchChild1.find().count()) == 2
assert (yield from InheritanceSearchChild1Child.find().count()) == 1
assert (yield from InheritanceSearchChild2.find().count()) == 1

res = yield from InheritanceSearchParent.find_one({'sc1f': 1})
assert isinstance(res, InheritanceSearchChild1Child)

cursor = InheritanceSearchParent.find({'pf': 1})
for r in (yield from cursor.to_list(length=100)):
assert isinstance(r, InheritanceSearchChild1)

loop.run_until_complete(do_test())
40 changes: 40 additions & 0 deletions tests/dal/test_pymongo.py
Expand Up @@ -465,3 +465,43 @@ class Meta:
UniqueIndexChildDoc.ensure_indexes()
indexes = [e for e in UniqueIndexChildDoc.collection.list_indexes()]
assert name_sorted(indexes) == name_sorted(expected_indexes)

def test_inheritance_search(self, db):

class InheritanceSearchParent(Document):
pf = fields.IntField()

class Meta:
collection = db.inheritance_search
allow_inheritance = True

class InheritanceSearchChild1(InheritanceSearchParent):
c1f = fields.IntField()

class Meta:
allow_inheritance = True

class InheritanceSearchChild1Child(InheritanceSearchChild1):
sc1f = fields.IntField()

class InheritanceSearchChild2(InheritanceSearchParent):
c2f = fields.IntField(required=True)

InheritanceSearchParent.collection.drop()

InheritanceSearchParent(pf=0).commit()
InheritanceSearchChild1(pf=1, c1f=1).commit()
InheritanceSearchChild1Child(pf=1, sc1f=1).commit()
InheritanceSearchChild2(pf=2, c2f=2).commit()

assert InheritanceSearchParent.find().count() == 4
assert InheritanceSearchChild1.find().count() == 2
assert InheritanceSearchChild1Child.find().count() == 1
assert InheritanceSearchChild2.find().count() == 1

res = InheritanceSearchParent.find_one({'sc1f': 1})
assert isinstance(res, InheritanceSearchChild1Child)

res = InheritanceSearchParent.find({'pf': 1})
for r in res:
assert isinstance(r, InheritanceSearchChild1)
45 changes: 45 additions & 0 deletions tests/dal/test_txmongo.py
Expand Up @@ -549,3 +549,48 @@ class Meta:
yield UniqueIndexChildDoc.ensure_indexes()
indexes = [e for e in con[TEST_DB].unique_index_inheritance_doc.list_indexes()]
assert name_sorted(indexes) == name_sorted(expected_indexes)

@pytest_inlineCallbacks
def test_inheritance_search(self, db):

class InheritanceSearchParent(Document):
pf = fields.IntField()

class Meta:
collection = db.inheritance_search
allow_inheritance = True

class InheritanceSearchChild1(InheritanceSearchParent):
c1f = fields.IntField()

class Meta:
allow_inheritance = True

class InheritanceSearchChild1Child(InheritanceSearchChild1):
sc1f = fields.IntField()

class InheritanceSearchChild2(InheritanceSearchParent):
c2f = fields.IntField(required=True)

yield InheritanceSearchParent.collection.drop()

yield InheritanceSearchParent(pf=0).commit()
yield InheritanceSearchChild1(pf=1, c1f=1).commit()
yield InheritanceSearchChild1Child(pf=1, sc1f=1).commit()
yield InheritanceSearchChild2(pf=2, c2f=2).commit()

res = yield InheritanceSearchParent.find()
assert len(res) == 4
res = yield InheritanceSearchChild1.find()
assert len(res) == 2
res = yield InheritanceSearchChild1Child.find()
assert len(res) == 1
res = yield InheritanceSearchChild2.find()
assert len(res) == 1

res = yield InheritanceSearchParent.find_one({'sc1f': 1})
assert isinstance(res, InheritanceSearchChild1Child)

res = yield InheritanceSearchParent.find({'pf': 1})
for r in res:
assert isinstance(r, InheritanceSearchChild1)
22 changes: 14 additions & 8 deletions umongo/dal/motor_asyncio.py
Expand Up @@ -8,6 +8,8 @@
from ..exceptions import NotCreatedError, UpdateError, ValidationError
from ..fields import ReferenceField, ListField, EmbeddedField

from .tools import cook_find_filter


class WrappedCursor(AsyncIOMotorCursor):

Expand All @@ -31,12 +33,12 @@ def clone(self):

def next_object(self):
raw = self.raw_cursor.next_object()
return self.document_cls.build_from_mongo(raw)
return self.document_cls.build_from_mongo(raw, use_cls=True)

def each(self, callback):
def wrapped_callback(result, error):
if not error and result is not None:
result = self.document_cls.build_from_mongo(result)
result = self.document_cls.build_from_mongo(result, use_cls=True)
return callback(result, error)
return self.raw_cursor.each(wrapped_callback)

Expand All @@ -46,7 +48,7 @@ def to_list(self, length, callback=None):
builder = self.document_cls.build_from_mongo

def on_raw_done(fut):
cooked_future.set_result([builder(e) for e in fut.result()])
cooked_future.set_result([builder(e, use_cls=True) for e in fut.result()])

raw_future.add_done_callback(on_raw_done)
return cooked_future
Expand Down Expand Up @@ -114,15 +116,19 @@ def io_validate(self, validate_all=False):
self.schema, self._data, partial=self._data.get_modified_fields())

@classmethod
def find_one(cls, *args, **kwargs):
ret = yield from cls.collection.find_one(*args, **kwargs)
def find_one(cls, spec_or_id=None, *args, **kwargs):
# In pymongo<3, `spec_or_id` is for filtering and `filter` is for sorting
spec_or_id = cook_find_filter(cls, spec_or_id)
ret = yield from cls.collection.find_one(*args, spec_or_id=spec_or_id, **kwargs)
if ret is not None:
ret = cls.build_from_mongo(ret)
ret = cls.build_from_mongo(ret, use_cls=True)
return ret

@classmethod
def find(cls, *args, **kwargs):
return WrappedCursor(cls, cls.collection.find(*args, **kwargs))
def find(cls, spec=None, *args, **kwargs):
# In pymongo<3, `spec` is for filtering and `filter` is for sorting
spec = cook_find_filter(cls, spec)
return WrappedCursor(cls, cls.collection.find(*args, spec=spec, **kwargs))

@classmethod
def ensure_indexes(cls):
Expand Down
18 changes: 11 additions & 7 deletions umongo/dal/pymongo.py
Expand Up @@ -8,6 +8,8 @@
from ..exceptions import NotCreatedError, UpdateError, DeleteError, ValidationError
from ..fields import ReferenceField, ListField, EmbeddedField

from .tools import cook_find_filter


class WrappedCursor(Cursor):

Expand All @@ -28,11 +30,11 @@ def __setattr__(self, name, value):

def __next__(self):
elem = next(self.raw_cursor)
return self.document_cls.build_from_mongo(elem)
return self.document_cls.build_from_mongo(elem, use_cls=True)

def __iter__(self):
for elem in self.raw_cursor:
yield self.document_cls.build_from_mongo(elem)
yield self.document_cls.build_from_mongo(elem, use_cls=True)


class PyMongoDal(AbstractDal):
Expand Down Expand Up @@ -99,15 +101,17 @@ def io_validate(self, validate_all=False):
self.schema, self._data, partial=self._data.get_modified_fields())

@classmethod
def find_one(cls, *args, **kwargs):
ret = cls.collection.find_one(*args, **kwargs)
def find_one(cls, filter=None, *args, **kwargs):
filter = cook_find_filter(cls, filter)
ret = cls.collection.find_one(*args, filter=filter, **kwargs)
if ret is not None:
ret = cls.build_from_mongo(ret)
ret = cls.build_from_mongo(ret, use_cls=True)
return ret

@classmethod
def find(cls, *args, **kwargs):
raw_cursor = cls.collection.find(*args, **kwargs)
def find(cls, filter=None, *args, **kwargs):
filter = cook_find_filter(cls, filter)
raw_cursor = cls.collection.find(*args, filter=filter, **kwargs)
return WrappedCursor(cls, raw_cursor)

@classmethod
Expand Down
13 changes: 13 additions & 0 deletions umongo/dal/tools.py
@@ -0,0 +1,13 @@

def cook_find_filter(doc_cls, filter):
if doc_cls.opts.is_child:
filter = filter or {}
# Current document shares the collection with a parent,
# we must use the _cls field to discriminate
if doc_cls.opts.children:
# Current document has itself children, we also have
# to search through them
filter['_cls'] = {'$in': list(doc_cls.opts.children) + [doc_cls.__name__]}
else:
filter['_cls'] = doc_cls.__name__
return filter
20 changes: 13 additions & 7 deletions umongo/dal/txmongo.py
Expand Up @@ -9,6 +9,8 @@
from ..exceptions import NotCreatedError, UpdateError, DeleteError, ValidationError
from ..fields import ReferenceField, ListField, EmbeddedField

from .tools import cook_find_filter


class TxMongoDal(AbstractDal):

Expand Down Expand Up @@ -76,27 +78,31 @@ def io_validate(self, validate_all=False):

@classmethod
@inlineCallbacks
def find_one(cls, *args, **kwargs):
ret = yield cls.collection.find_one(*args, **kwargs)
def find_one(cls, spec=None, *args, **kwargs):
# In txmongo, `spec` is for filtering and `filter` is for sorting
spec = cook_find_filter(cls, spec)
ret = yield cls.collection.find_one(*args, spec=spec, **kwargs)
if ret is not None:
ret = cls.build_from_mongo(ret)
ret = cls.build_from_mongo(ret, use_cls=True)
return ret

@classmethod
@inlineCallbacks
def find(cls, *args, **kwargs):
raw_cursor_or_list = yield cls.collection.find(*args, **kwargs)
def find(cls, spec=None, *args, **kwargs):
# In txmongo, `spec` is for filtering and `filter` is for sorting
spec = cook_find_filter(cls, spec)
raw_cursor_or_list = yield cls.collection.find(*args, spec=spec, **kwargs)
if isinstance(raw_cursor_or_list, tuple):

def wrap_raw_results(result):
cursor = result[1]
if cursor is not None:
cursor.addCallback(wrap_raw_results)
return ([cls.build_from_mongo(e) for e in result[0]], cursor)
return ([cls.build_from_mongo(e, use_cls=True) for e in result[0]], cursor)

return wrap_raw_results(raw_cursor_or_list)
else:
return [cls.build_from_mongo(e) for e in raw_cursor_or_list]
return [cls.build_from_mongo(e, use_cls=True) for e in raw_cursor_or_list]

@classmethod
@inlineCallbacks
Expand Down
8 changes: 6 additions & 2 deletions umongo/document.py
Expand Up @@ -2,12 +2,13 @@
from .exceptions import NotCreatedError, AbstractDocumentError
from .meta import MetaDocument, DocumentOpts
from .data_objects import Reference
from .registerer import retrieve_document

from bson import DBRef


def _base_opts():
opts = DocumentOpts({}, ())
opts = DocumentOpts('Document', {}, ())
opts.abstract = True
opts.allow_inheritance = True
opts.register_document = False
Expand Down Expand Up @@ -60,7 +61,10 @@ def dbref(self):
return DBRef(collection=self.collection.name, id=self.pk)

@classmethod
def build_from_mongo(cls, data, partial=False):
def build_from_mongo(cls, data, partial=False, use_cls=False):
# If a _cls is specified, we have to use this document class
if use_cls and '_cls' in data:
cls = retrieve_document(data['_cls'])
doc = cls()
doc.from_mongo(data, partial=partial)
return doc
Expand Down
10 changes: 7 additions & 3 deletions umongo/meta.py
Expand Up @@ -79,10 +79,11 @@ def __repr__(self):
'custom_indexes={self.custom_indexes}, '
'collection={self.collection}, '
'lazy_collection={self.lazy_collection}, '
'dal={self.dal})>'
'dal={self.dal},'
'children={self.children})>'
.format(ClassName=self.__class__.__name__, self=self))

def __init__(self, nmspc, bases):
def __init__(self, name, nmspc, bases):
meta = nmspc.get('Meta')
self.abstract = getattr(meta, 'abstract', False)
self.allow_inheritance = getattr(meta, 'allow_inheritance', self.abstract)
Expand All @@ -92,11 +93,14 @@ def __init__(self, nmspc, bases):
self.base_schema_cls = getattr(meta, 'base_schema_cls', Schema)
self.indexes, self.custom_indexes = _collect_indexes(nmspc, bases)
self.is_child = _is_child(bases)
self.children = set()
if self.abstract and not self.allow_inheritance:
raise DocumentDefinitionError("Abstract document cannot disable inheritance")
# Handle option inheritance and integrity checks
for base in bases:
popts = base.opts
# Notify the parent of it newborn !
popts.children.add(name)
if not popts.allow_inheritance:
raise DocumentDefinitionError("Document %r doesn't allow inheritance" % base)
if self.abstract and not popts.abstract:
Expand Down Expand Up @@ -173,7 +177,7 @@ def __new__(cls, name, bases, nmspc):
# Generic handling (i.e. for all other documents)
assert '_cls' not in nmspc, '`_cls` is a reserved attribute'
# Generate options from the Meta class and inheritance
opts = DocumentOpts(nmspc, bases)
opts = DocumentOpts(name, nmspc, bases)
# If Document is a child, _cls field must be added to the schema
if opts.is_child:
from .fields import StrField
Expand Down

0 comments on commit 9116357

Please sign in to comment.