Skip to content

Commit

Permalink
Remove params in as_marshmallow_schema/field.
Browse files Browse the repository at this point in the history
Use a base schema class instead.
  • Loading branch information
lafrech committed Apr 29, 2020
1 parent 3335197 commit 32bd9b8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 107 deletions.
61 changes: 0 additions & 61 deletions tests/test_marshmallow.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,63 +93,6 @@ class Bag(MyDocument):
ma_schema = Bag.schema.as_marshmallow_schema()
assert ma_schema().load(data) == excl_data

def test_customize_params(self):
ma_field = self.User.schema.fields['name'].as_marshmallow_field(params={'load_only': True})
assert ma_field.load_only is True

ma_schema_cls = self.User.schema.as_marshmallow_schema(
params={'name': {'load_only': True, 'dump_only': True}})
schema = ma_schema_cls()
ret = schema.dump({'name': "42", 'birthday': datetime(1990, 10, 23), 'dummy': False})
assert ret == {'birthday': '1990-10-23T00:00:00'}
with pytest.raises(marshmallow.ValidationError) as excinfo:
schema.load({'name': "42", 'birthday': '1990-10-23T00:00:00', 'dummy': False})
assert excinfo.value.messages == {'name': ['Unknown field.'], 'dummy': ['Unknown field.']}
ret = schema.load({'birthday': '1990-10-23T00:00:00'})
assert ret == {'birthday': datetime(1990, 10, 23)}

def test_customize_nested_and_inner_params(self):
@self.instance.register
class Accessory(EmbeddedDocument):
brief = fields.StrField(attribute='id', required=True)
value = fields.IntField()

@self.instance.register
class Bag(Document):
id = fields.EmbeddedField(Accessory, attribute='_id', required=True)
names = fields.ListField(fields.StringField)
content = fields.ListField(fields.EmbeddedField(Accessory))
relations = fields.DictField(fields.StringField, fields.StringField)
inventory = fields.DictField(fields.StringField, fields.EmbeddedField(Accessory))

ma_field = Bag.schema.fields['id'].as_marshmallow_field(params={
'load_only': True,
'params': {'value': {'dump_only': True}}})
assert ma_field.load_only is True
assert ma_field.nested._declared_fields['value'].dump_only
ma_field = Bag.schema.fields['names'].as_marshmallow_field(params={
'load_only': True,
'params': {'dump_only': True}})
assert ma_field.load_only is True
assert ma_field.inner.dump_only is True
ma_field = Bag.schema.fields['content'].as_marshmallow_field(params={
'load_only': True,
'params': {'required': True, 'params': {'value': {'dump_only': True}}}})
assert ma_field.load_only is True
assert ma_field.inner.required is True
assert ma_field.inner.nested._declared_fields['value'].dump_only
ma_field = Bag.schema.fields['relations'].as_marshmallow_field(params={
'load_only': True,
'params': {'dump_only': True}})
assert ma_field.load_only is True
assert ma_field.value_field.dump_only is True
ma_field = Bag.schema.fields['inventory'].as_marshmallow_field(params={
'load_only': True,
'params': {'required': True, 'params': {'value': {'dump_only': True}}}})
assert ma_field.load_only is True
assert ma_field.value_field.required is True
assert ma_field.value_field.nested._declared_fields['value'].dump_only

def test_as_marshmallow_field_pass_params(self):
@self.instance.register
class MyDoc(Document):
Expand Down Expand Up @@ -205,10 +148,6 @@ class MyDoc(Document):
def test_as_marshmallow_schema_cache(self):
ma_schema_cls = self.User.schema.as_marshmallow_schema()

new_ma_schema_cls = self.User.schema.as_marshmallow_schema(
params={'name': {'load_only': True}})
assert new_ma_schema_cls != ma_schema_cls

new_ma_schema_cls = self.User.schema.as_marshmallow_schema(
mongo_world=True)
assert new_ma_schema_cls != ma_schema_cls
Expand Down
27 changes: 8 additions & 19 deletions umongo/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,21 @@ def map_to_field(self, func):
if hasattr(field, 'map_to_field'):
field.map_to_field(mongo_path, name, func)

def as_marshmallow_schema(self, *, params=None, mongo_world=False):
def as_marshmallow_schema(self, *, mongo_world=False):
"""
Return a pure-marshmallow version of this schema class.
:param params: Per-field dict to pass parameters to their field creation.
:param mongo_world: If True the schema will work against the mongo world
instead of the OO world (default: False).
"""
params = params or {}
# Use hashable parameters as cache dict key and dict parameters for manual comparison
# Use a cache to avoid generating several times the same schema
cache_key = (self.__class__, self.MA_BASE_SCHEMA_CLS, mongo_world)
cache_modifiers = (params, )
if cache_key in self._marshmallow_schemas_cache:
for modifiers, ma_schema in self._marshmallow_schemas_cache[cache_key]:
if modifiers == cache_modifiers:
return ma_schema
return self._marshmallow_schemas_cache[cache_key]

# Create schema if not found in cache
nmspc = {
name: field.as_marshmallow_field(
params=params.get(name),
mongo_world=mongo_world,
)
name: field.as_marshmallow_field(mongo_world=mongo_world)
for name, field in self.fields.items()
}
name = 'Marshmallow%s' % type(self).__name__
Expand All @@ -73,8 +67,7 @@ def as_marshmallow_schema(self, *, params=None, mongo_world=False):
# when error_messages is updated with _default_error_messages.
m_schema._default_error_messages = {
k: _(v) for k, v in m_schema._default_error_messages.items()}
self._marshmallow_schemas_cache.setdefault(cache_key, []).append(
(cache_modifiers, m_schema))
self._marshmallow_schemas_cache[cache_key] = m_schema
return m_schema


Expand Down Expand Up @@ -204,18 +197,14 @@ def _extract_marshmallow_field_params(self, mongo_world):
params.update(self.metadata)
return params

def as_marshmallow_field(self, *, params=None, mongo_world=False, **kwargs):
def as_marshmallow_field(self, *, mongo_world=False, **kwargs):
"""
Return a pure-marshmallow version of this field.
:param params: Additional parameters passed to the marshmallow field
class constructor.
:param mongo_world: If True the field will work against the mongo world
instead of the OO world (default: False)
"""
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
if params:
field_kwargs.update(params)
# Retrieve the marshmallow class we inherit from
for m_class in type(self).mro():
if (not issubclass(m_class, BaseField) and
Expand Down
35 changes: 8 additions & 27 deletions umongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,13 @@ def _deserialize_from_mongo(self, value):
)
return Dict(self.key_field, self.value_field)

def as_marshmallow_field(self, params=None, mongo_world=False, **kwargs):
def as_marshmallow_field(self, mongo_world=False, **kwargs):
# Overwrite default `as_marshmallow_field` to handle deserialization
# difference (`_id` vs `id`)
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
if params:
inner_params = params.pop('params', None)
field_kwargs.update(params)
else:
inner_params = None
if self.value_field:
inner_ma_schema = self.value_field.as_marshmallow_field(
mongo_world=mongo_world, params=inner_params, **kwargs)
mongo_world=mongo_world, **kwargs)
else:
inner_ma_schema = None
return ma_fields.Dict(self.key_field, inner_ma_schema, **field_kwargs)
Expand Down Expand Up @@ -287,17 +282,12 @@ def map_to_field(self, mongo_path, path, func):
if hasattr(self.inner, 'map_to_field'):
self.inner.map_to_field(mongo_path, path, func)

def as_marshmallow_field(self, params=None, mongo_world=False, **kwargs):
def as_marshmallow_field(self, mongo_world=False, **kwargs):
# Overwrite default `as_marshmallow_field` to handle deserialization
# difference (`_id` vs `id`)
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
if params:
inner_params = params.pop('params', None)
field_kwargs.update(params)
else:
inner_params = None
inner_ma_schema = self.inner.as_marshmallow_field(
mongo_world=mongo_world, params=inner_params, **kwargs)
mongo_world=mongo_world, **kwargs)
return ma_fields.List(inner_ma_schema, **field_kwargs)

def _required_validate(self, value):
Expand Down Expand Up @@ -388,12 +378,10 @@ def _serialize_to_mongo(self, obj):
def _deserialize_from_mongo(self, value):
return self.reference_cls(self.document_cls, value)

def as_marshmallow_field(self, params=None, mongo_world=False, **kwargs):
def as_marshmallow_field(self, mongo_world=False, **kwargs):
# Overwrite default `as_marshmallow_field` to handle deserialization
# difference (`_id` vs `id`)
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
if params:
field_kwargs.update(params)
return ma_bonus_fields.Reference(mongo_world=mongo_world, **field_kwargs)


Expand Down Expand Up @@ -446,12 +434,10 @@ def _deserialize_from_mongo(self, value):
document_cls = self._document_cls(value['_cls'])
return self.reference_cls(document_cls, value['_id'])

def as_marshmallow_field(self, params=None, mongo_world=False, **kwargs):
def as_marshmallow_field(self, mongo_world=False, **kwargs):
# Overwrite default `as_marshmallow_field` to handle deserialization
# difference (`_id` vs `id`)
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
if params:
field_kwargs.update(params)
return ma_bonus_fields.GenericReference(mongo_world=mongo_world, **field_kwargs)


Expand Down Expand Up @@ -566,16 +552,11 @@ def map_to_field(self, mongo_path, path, func):
if hasattr(field, 'map_to_field'):
field.map_to_field(cur_mongo_path, cur_path, func)

def as_marshmallow_field(self, params=None, mongo_world=False, **kwargs):
def as_marshmallow_field(self, mongo_world=False, **kwargs):
# Overwrite default `as_marshmallow_field` to handle nesting
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
if params:
nested_params = params.pop('params', None)
field_kwargs.update(params)
else:
nested_params = None
nested_ma_schema = self._embedded_document_cls.schema.as_marshmallow_schema(
params=nested_params, mongo_world=mongo_world)
mongo_world=mongo_world)
return ma_fields.Nested(nested_ma_schema, **field_kwargs)

def _required_validate(self, value):
Expand Down

0 comments on commit 32bd9b8

Please sign in to comment.