Skip to content

Commit

Permalink
Merge pull request #263 from Scille/base_schema_cls
Browse files Browse the repository at this point in the history
Set MA_BASE_SCHEMA_CLS in Document / EmbeddedDocument
  • Loading branch information
lafrech committed Apr 29, 2020
2 parents 8a0aaf9 + 32bd9b8 commit e0305e8
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 210 deletions.
209 changes: 64 additions & 145 deletions tests/test_marshmallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,111 +41,57 @@ def test_by_schema(self):
assert issubclass(ma_schema_cls, marshmallow.Schema)
assert not issubclass(ma_schema_cls, BaseSchema)

def test_custom_base_schema(self):
def test_custom_ma_base_schema_cls(self):

class MyBaseSchema(marshmallow.Schema):
name = marshmallow.fields.Int()
age = marshmallow.fields.Int()
# Define custom marshmallow schema base class
class ExcludeBaseSchema(marshmallow.Schema):
class Meta:
unknown = marshmallow.EXCLUDE

ma_schema_cls = self.User.schema.as_marshmallow_schema(base_schema_cls=MyBaseSchema)
assert issubclass(ma_schema_cls, MyBaseSchema)
# Typically, we'll use it in all our schemas, so let's define base
# Document and EmbeddedDocument classes using this base schema class
@self.instance.register
class MyDocument(Document):
MA_BASE_SCHEMA_CLS = ExcludeBaseSchema

schema = ma_schema_cls()
assert schema.dump({'name': "42", 'age': 42, 'dummy': False}) == {'name': "42", 'age': 42}
with pytest.raises(marshmallow.ValidationError) as excinfo:
schema.load({'name': "42", 'age': 42, 'dummy': False})
assert excinfo.value.messages == {'dummy': ['Unknown field.']}
assert schema.load({'name': "42", 'age': 42}) == {'name': "42", 'age': 42}
class Meta:
allow_inheritance = True

def test_customize_params(self):
ma_field = self.User.schema.fields['name'].as_marshmallow_field(params={'load_only': True})
assert ma_field.load_only is True
@self.instance.register
class MyEmbeddedDocument(EmbeddedDocument):
MA_BASE_SCHEMA_CLS = ExcludeBaseSchema

ma_schema_cls = self.User.schema.as_marshmallow_schema(
params={'name': {'load_only': True, 'dump_only': True}})
schema = ma_schema_cls()
ret = schema.dump({'name': "42", 'birthday': datetime(1990, 10, 23), 'dummy': False})
assert ret == {'birthday': '1990-10-23T00:00:00'}
with pytest.raises(marshmallow.ValidationError) as excinfo:
schema.load({'name': "42", 'birthday': '1990-10-23T00:00:00', 'dummy': False})
assert excinfo.value.messages == {'name': ['Unknown field.'], 'dummy': ['Unknown field.']}
ret = schema.load({'birthday': '1990-10-23T00:00:00'})
assert ret == {'birthday': datetime(1990, 10, 23)}
class Meta:
allow_inheritance = True

def test_customize_nested_and_inner_params(self):
# Now, all our objects will generate "exclude" marshmallow schemas
@self.instance.register
class Accessory(EmbeddedDocument):
brief = fields.StrField(attribute='id', required=True)
class Accessory(MyEmbeddedDocument):
brief = fields.StrField()
value = fields.IntField()

@self.instance.register
class Bag(Document):
id = fields.EmbeddedField(Accessory, attribute='_id', required=True)
names = fields.ListField(fields.StringField)
class Bag(MyDocument):
item = fields.EmbeddedField(Accessory)
content = fields.ListField(fields.EmbeddedField(Accessory))
relations = fields.DictField(fields.StringField, fields.StringField)
inventory = fields.DictField(fields.StringField, fields.EmbeddedField(Accessory))

ma_field = Bag.schema.fields['id'].as_marshmallow_field(params={
'load_only': True,
'params': {'value': {'dump_only': True}}})
assert ma_field.load_only is True
assert ma_field.nested._declared_fields['value'].dump_only
ma_field = Bag.schema.fields['names'].as_marshmallow_field(params={
'load_only': True,
'params': {'dump_only': True}})
assert ma_field.load_only is True
assert ma_field.inner.dump_only is True
ma_field = Bag.schema.fields['content'].as_marshmallow_field(params={
'load_only': True,
'params': {'required': True, 'params': {'value': {'dump_only': True}}}})
assert ma_field.load_only is True
assert ma_field.inner.required is True
assert ma_field.inner.nested._declared_fields['value'].dump_only
ma_field = Bag.schema.fields['relations'].as_marshmallow_field(params={
'load_only': True,
'params': {'dump_only': True}})
assert ma_field.load_only is True
assert ma_field.value_field.dump_only is True
ma_field = Bag.schema.fields['inventory'].as_marshmallow_field(params={
'load_only': True,
'params': {'required': True, 'params': {'value': {'dump_only': True}}}})
assert ma_field.load_only is True
assert ma_field.value_field.required is True
assert ma_field.value_field.nested._declared_fields['value'].dump_only

def test_pass_meta_attributes(self):
@self.instance.register
class Accessory(EmbeddedDocument):
brief = fields.StrField(attribute='id', required=True)
value = fields.IntField()

@self.instance.register
class Bag(Document):
id = fields.EmbeddedField(Accessory, attribute='_id', required=True)
content = fields.ListField(fields.EmbeddedField(Accessory))
inventory = fields.DictField(fields.StringField, fields.EmbeddedField(Accessory))

ma_schema = Bag.schema.as_marshmallow_schema(meta={'exclude': ('id',)})
assert ma_schema.Meta.exclude == ('id',)
ma_schema = Bag.schema.as_marshmallow_schema(params={
'id': {'meta': {'exclude': ('value',)}}})
assert ma_schema._declared_fields['id'].nested.Meta.exclude == ('value',)
ma_schema = Bag.schema.as_marshmallow_schema(params={
'content': {'params': {'meta': {'exclude': ('value',)}}}})
assert ma_schema._declared_fields['content'].inner.nested.Meta.exclude == ('value',)
ma_schema = Bag.schema.as_marshmallow_schema(params={
'inventory': {'params': {'meta': {'exclude': ('value',)}}}})
assert ma_schema._declared_fields['inventory'].value_field.nested.Meta.exclude == ('value',)

class DumpOnlyIdSchema(marshmallow.Schema):
class Meta:
dump_only = ('id',)
data = {
'item': {'brief': 'sportbag', 'value': 100, 'name': 'Unknown'},
'content': [
{'brief': 'cellphone', 'value': 500, 'name': 'Unknown'},
{'brief': 'lighter', 'value': 2, 'name': 'Unknown'}
],
'name': 'Unknown',
}
excl_data = {
'item': {'brief': 'sportbag', 'value': 100},
'content': [
{'brief': 'cellphone', 'value': 500},
{'brief': 'lighter', 'value': 2}]
}

ma_custom_base_schema = Bag.schema.as_marshmallow_schema(
base_schema_cls=DumpOnlyIdSchema, meta={'exclude': ('content',)})
assert ma_custom_base_schema.Meta.exclude == ('content',)
assert ma_custom_base_schema.Meta.dump_only == ('id',)
ma_schema = Bag.schema.as_marshmallow_schema()
assert ma_schema().load(data) == excl_data

def test_as_marshmallow_field_pass_params(self):
@self.instance.register
Expand Down Expand Up @@ -202,29 +148,10 @@ class MyDoc(Document):
def test_as_marshmallow_schema_cache(self):
ma_schema_cls = self.User.schema.as_marshmallow_schema()

new_ma_schema_cls = self.User.schema.as_marshmallow_schema(
params={'name': {'load_only': True}})
assert new_ma_schema_cls != ma_schema_cls

new_ma_schema_cls = self.User.schema.as_marshmallow_schema(
meta={'exclude': ('name',)})
assert new_ma_schema_cls != ma_schema_cls

new_ma_schema_cls = self.User.schema.as_marshmallow_schema(
check_unknown_fields=False)
assert new_ma_schema_cls != ma_schema_cls

new_ma_schema_cls = self.User.schema.as_marshmallow_schema(
mongo_world=True)
assert new_ma_schema_cls != ma_schema_cls

class MyBaseSchema(marshmallow.Schema):
pass

new_ma_schema_cls = self.User.schema.as_marshmallow_schema(
base_schema_cls=MyBaseSchema)
assert new_ma_schema_cls != ma_schema_cls

new_ma_schema_cls = self.User.schema.as_marshmallow_schema()
assert new_ma_schema_cls == ma_schema_cls

Expand Down Expand Up @@ -307,14 +234,31 @@ def my_gettext(message):
'birthday': ['OMG !!! Not a valid datetime.'],
'dummy_field': ['OMG !!! Unknown field.']}

def test_unknow_fields_check(self):
ma_schema_cls = self.User.schema.as_marshmallow_schema()
def test_unknown_fields(self):

class ExcludeBaseSchema(marshmallow.Schema):
class Meta:
unknown = marshmallow.EXCLUDE

@self.instance.register
class ExcludeUser(self.User):
MA_BASE_SCHEMA_CLS = ExcludeBaseSchema

user_ma_schema_cls = self.User.schema.as_marshmallow_schema()
assert issubclass(user_ma_schema_cls, marshmallow.Schema)
exclude_user_ma_schema_cls = ExcludeUser.schema.as_marshmallow_schema()
assert issubclass(exclude_user_ma_schema_cls, ExcludeBaseSchema)

data = {'name': 'John', 'dummy': 'dummy'}
excl_data = {'name': 'John'}

# By default, marshmallow schemas raise on unknown fields
with pytest.raises(marshmallow.ValidationError) as excinfo:
ma_schema_cls().load({'name': 'John', 'dummy_field': 'dummy'})
assert excinfo.value.messages == {'dummy_field': ['Unknown field.']}
user_ma_schema_cls().load(data)
assert excinfo.value.messages == {'dummy': ['Unknown field.']}

ma_schema_cls = self.User.schema.as_marshmallow_schema(check_unknown_fields=False)
assert ma_schema_cls().load({'name': 'John', 'dummy_field': 'dummy'}) == {'name': 'John'}
# With custom schema, exclude unknown fields
assert exclude_user_ma_schema_cls().load(data) == excl_data

def test_missing_accessor(self):

Expand Down Expand Up @@ -359,8 +303,9 @@ class Bag(Document):
'id': {'brief': 'sportbag', 'value': 100},
'content': [{'brief': 'cellphone', 'value': 500}, {'brief': 'lighter', 'value': 2}]
}
# Here data is the same in both OO world and user world (no
# ObjectId to str conversion needed for example)

# Here data is the same in both OO world and user world
# (no ObjectId to str conversion needed for example)

ma_schema = Bag.schema.as_marshmallow_schema()()
ma_mongo_schema = Bag.schema.as_marshmallow_schema(mongo_world=True)()
Expand All @@ -372,32 +317,6 @@ class Bag(Document):
assert ma_mongo_schema.dump(bag.to_mongo()) == data
assert ma_mongo_schema.load(data) == bag.to_mongo()

# Check as_marshmallow_schema params (check_unknown_felds, base_schema_cls)
# are passed to nested schemas
data = {
'id': {'brief': 'sportbag', 'value': 100, 'name': 'Unknown'},
'content': [
{'brief': 'cellphone', 'value': 500, 'name': 'Unknown'},
{'brief': 'lighter', 'value': 2, 'name': 'Unknown'}]
}
with pytest.raises(marshmallow.ValidationError) as excinfo:
ma_schema.load(data)
assert excinfo.value.messages == {
'id': {'name': ['Unknown field.']},
'content': {
0: {'name': ['Unknown field.']},
1: {'name': ['Unknown field.']},
}}

ma_no_check_unknown_schema = Bag.schema.as_marshmallow_schema(check_unknown_fields=False)()
ma_no_check_unknown_schema.load(data)

class WithNameSchema(marshmallow.Schema):
name = marshmallow.fields.Str()

ma_custom_base_schema = Bag.schema.as_marshmallow_schema(base_schema_cls=WithNameSchema)()
ma_custom_base_schema.load(data)

def test_marshmallow_bonus_fields(self):
# Fields related to mongodb provided for marshmallow
@self.instance.register
Expand Down
45 changes: 12 additions & 33 deletions umongo/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from marshmallow import (Schema as MaSchema, fields as ma_fields,
validate as ma_validate, missing, EXCLUDE)
validate as ma_validate, missing)

from .i18n import gettext as _, N_
from .marshmallow_bonus import schema_from_umongo_get_attribute
Expand All @@ -18,6 +18,7 @@ class BaseSchema(MaSchema):
"""
All schema used in umongo should inherit from this base schema
"""
MA_BASE_SCHEMA_CLS = MaSchema

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -38,53 +39,35 @@ def map_to_field(self, func):
if hasattr(field, 'map_to_field'):
field.map_to_field(mongo_path, name, func)

def as_marshmallow_schema(self, params=None, base_schema_cls=MaSchema,
check_unknown_fields=True, mongo_world=False, meta=None):
def as_marshmallow_schema(self, *, mongo_world=False):
"""
Return a pure-marshmallow version of this schema class.
:param params: Per-field dict to pass parameters to their field creation.
:param base_schema_cls: Class the schema will inherit from (
default: :class:`marshmallow.Schema`).
:param check_unknown_fields: Unknown fields are considered as errors (default: True).
:param mongo_world: If True the schema will work against the mongo world
instead of the OO world (default: False).
:param meta: Optional dict with attributes for the schema's Meta class.
"""
params = params or {}
meta = meta or {}
# Use hashable parameters as cache dict key and dict parameters for manual comparison
cache_key = (self.__class__, base_schema_cls, check_unknown_fields, mongo_world)
cache_modifiers = (params, meta)
# Use a cache to avoid generating several times the same schema
cache_key = (self.__class__, self.MA_BASE_SCHEMA_CLS, mongo_world)
if cache_key in self._marshmallow_schemas_cache:
for modifiers, ma_schema in self._marshmallow_schemas_cache[cache_key]:
if modifiers == cache_modifiers:
return ma_schema
return self._marshmallow_schemas_cache[cache_key]

# Create schema if not found in cache
nmspc = {
name: field.as_marshmallow_field(
params=params.get(name),
base_schema_cls=base_schema_cls,
check_unknown_fields=check_unknown_fields,
mongo_world=mongo_world)
name: field.as_marshmallow_field(mongo_world=mongo_world)
for name, field in self.fields.items()
}
name = 'Marshmallow%s' % type(self).__name__
if not check_unknown_fields:
meta.setdefault('unknown', EXCLUDE)
# By default OO world returns `missing` fields as `None`,
# disable this behavior here to let marshmallow deal with it
if not mongo_world:
nmspc['get_attribute'] = schema_from_umongo_get_attribute
if meta:
nmspc['Meta'] = type('Meta', (base_schema_cls.Meta,), meta)
m_schema = type(name, (base_schema_cls, ), nmspc)
m_schema = type(name, (self.MA_BASE_SCHEMA_CLS, ), nmspc)
# Add i18n support to the schema
# We can't use I18nErrorDict here because __getitem__ is not called
# when error_messages is updated with _default_error_messages.
m_schema._default_error_messages = {
k: _(v) for k, v in m_schema._default_error_messages.items()}
self._marshmallow_schemas_cache.setdefault(cache_key, []).append(
(cache_modifiers, m_schema))
self._marshmallow_schemas_cache[cache_key] = m_schema
return m_schema


Expand Down Expand Up @@ -214,18 +197,14 @@ def _extract_marshmallow_field_params(self, mongo_world):
params.update(self.metadata)
return params

def as_marshmallow_field(self, params=None, mongo_world=False, **kwargs):
def as_marshmallow_field(self, *, mongo_world=False, **kwargs):
"""
Return a pure-marshmallow version of this field.
:param params: Additional parameters passed to the marshmallow field
class constructor.
:param mongo_world: If True the field will work against the mongo world
instead of the OO world (default: False)
"""
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
if params:
field_kwargs.update(params)
# Retrieve the marshmallow class we inherit from
for m_class in type(self).mro():
if (not issubclass(m_class, BaseField) and
Expand Down
1 change: 1 addition & 0 deletions umongo/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def _build_schema(self, template, schema_bases, schema_fields, schema_non_fields
schema_nmspc = {}
schema_nmspc.update(schema_fields)
schema_nmspc.update(schema_non_fields)
schema_nmspc['MA_BASE_SCHEMA_CLS'] = template.MA_BASE_SCHEMA_CLS
return type('%sSchema' % template.__name__, schema_bases, schema_nmspc)

def build_document_from_template(self, template):
Expand Down
6 changes: 5 additions & 1 deletion umongo/document.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from copy import deepcopy

from bson import DBRef
from marshmallow import pre_load, post_load, pre_dump, post_dump, validates_schema # republishing
from marshmallow import (
pre_load, post_load, pre_dump, post_dump, validates_schema, # republishing
Schema as MaSchema
)

from .abstract import BaseDataObject
from .data_proxy import missing
Expand Down Expand Up @@ -38,6 +41,7 @@ class DocumentTemplate(Template):
or `marshmallow.post_dump`) to this class that will be passed
to the marshmallow schema internally used for this document.
"""
MA_BASE_SCHEMA_CLS = MaSchema


Document = DocumentTemplate
Expand Down
Loading

0 comments on commit e0305e8

Please sign in to comment.