Skip to content

Commit

Permalink
Remove mongo_world parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Sep 9, 2020
1 parent b6f96e5 commit 3b2402d
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 79 deletions.
27 changes: 0 additions & 27 deletions tests/test_marshmallow.py
Expand Up @@ -145,11 +145,6 @@ class MyDoc(Document):

def test_as_marshmallow_schema_cache(self):
ma_schema_cls = self.User.schema.as_marshmallow_schema()

new_ma_schema_cls = self.User.schema.as_marshmallow_schema(
mongo_world=True)
assert new_ma_schema_cls != ma_schema_cls

new_ma_schema_cls = self.User.schema.as_marshmallow_schema()
assert new_ma_schema_cls == ma_schema_cls

Expand Down Expand Up @@ -208,16 +203,11 @@ class Dog(Document):

payload = {'name': 'Scruffy', 'age': 2}
ma_schema_cls = Dog.schema.as_marshmallow_schema()
ma_mongo_schema_cls = Dog.schema.as_marshmallow_schema(mongo_world=True)

ret = ma_schema_cls().load(payload)
assert ret == {'name': 'Scruffy', 'age': 2}
assert ma_schema_cls().dump(ret) == payload

ret = ma_mongo_schema_cls().load(payload)
assert ret == {'_id': 'Scruffy', 'age': 2}
assert ma_mongo_schema_cls().dump(ret) == payload

def test_i18n(self):
# i18n support should be kept, because it's pretty cool to have this !
def my_gettext(message):
Expand Down Expand Up @@ -306,15 +296,11 @@ class Bag(Document):
# (no ObjectId to str conversion needed for example)

ma_schema = Bag.schema.as_marshmallow_schema()()
ma_mongo_schema = Bag.schema.as_marshmallow_schema(mongo_world=True)()

bag = Bag(**data)
assert ma_schema.dump(bag) == data
assert ma_schema.load(data) == data

assert ma_mongo_schema.dump(bag.to_mongo()) == data
assert ma_mongo_schema.load(data) == bag.to_mongo()

def test_marshmallow_bonus_fields(self):
# Fields related to mongodb provided for marshmallow
@self.instance.register
Expand Down Expand Up @@ -343,9 +329,6 @@ class Doc(Document):
"gen_ref": {'cls': 'Doc', 'id': "57c1a71113adf27ab96b2c4f"}
}
doc = Doc(**oo_data)
mongo_data = doc.to_mongo()

# schema to OO world
ma_schema_cls = Doc.schema.as_marshmallow_schema()
ma_schema = ma_schema_cls()
# Dump uMongo object
Expand All @@ -355,16 +338,6 @@ class Doc(Document):
# Load serialized data
assert ma_schema.load(serialized) == oo_data

# schema to mongo world
ma_mongo_schema_cls = Doc.schema.as_marshmallow_schema(mongo_world=True)
ma_mongo_schema = ma_mongo_schema_cls()
assert ma_mongo_schema.dump(mongo_data) == serialized
assert ma_mongo_schema.load(serialized) == mongo_data
# Cannot load mongo form
with pytest.raises(ma.ValidationError) as excinfo:
ma_mongo_schema.load({"gen_ref": {'_cls': 'Doc', '_id': "57c1a71113adf27ab96b2c4f"}})
assert excinfo.value.messages == {'gen_ref': ['Generic reference must have `id` and `cls` fields.']}

def test_marshmallow_bonus_objectid_field(self):

class DocSchema(ma.Schema):
Expand Down
17 changes: 4 additions & 13 deletions umongo/abstract.py
Expand Up @@ -146,31 +146,22 @@ def _serialize_to_mongo(self, obj):
def _deserialize_from_mongo(self, value):
return value

def _extract_marshmallow_field_params(self, mongo_world):
def _extract_marshmallow_field_params(self):
params = {
attribute: getattr(self, attribute)
for attribute in (
'validate', 'required', 'allow_none',
'load_only', 'dump_only', 'error_messages'
)
}
if mongo_world and self.attribute:
params['attribute'] = self.attribute

# Override uMongo attributes with marshmallow_ prefixed attributes
params.update(self._ma_kwargs)

params.update(self.metadata)
return params

def as_marshmallow_field(self, *, mongo_world=False, **kwargs):
"""
Return a pure-marshmallow version of this field.
:param mongo_world: If True the field will work against the mongo world
instead of the OO world (default: False)
"""
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
def as_marshmallow_field(self):
"""Return a pure-marshmallow version of this field"""
field_kwargs = self._extract_marshmallow_field_params()
# Retrieve the marshmallow class we inherit from
for m_class in type(self).mro():
if (not issubclass(m_class, BaseField) and
Expand Down
33 changes: 15 additions & 18 deletions umongo/fields.py
Expand Up @@ -218,13 +218,12 @@ def _deserialize_from_mongo(self, value):
)
return Dict(self.key_field, self.value_field)

def as_marshmallow_field(self, mongo_world=False, **kwargs):
def as_marshmallow_field(self):
# Overwrite default `as_marshmallow_field` to handle deserialization
# difference (`_id` vs `id`)
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
field_kwargs = self._extract_marshmallow_field_params()
if self.value_field:
inner_ma_schema = self.value_field.as_marshmallow_field(
mongo_world=mongo_world, **kwargs)
inner_ma_schema = self.value_field.as_marshmallow_field()
else:
inner_ma_schema = None
return ma.fields.Dict(self.key_field, inner_ma_schema, **field_kwargs)
Expand Down Expand Up @@ -282,12 +281,11 @@ def map_to_field(self, mongo_path, path, func):
if hasattr(self.inner, 'map_to_field'):
self.inner.map_to_field(mongo_path, path, func)

def as_marshmallow_field(self, mongo_world=False, **kwargs):
def as_marshmallow_field(self):
# Overwrite default `as_marshmallow_field` to handle deserialization
# difference (`_id` vs `id`)
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
inner_ma_schema = self.inner.as_marshmallow_field(
mongo_world=mongo_world, **kwargs)
field_kwargs = self._extract_marshmallow_field_params()
inner_ma_schema = self.inner.as_marshmallow_field()
return ma.fields.List(inner_ma_schema, **field_kwargs)

def _required_validate(self, value):
Expand Down Expand Up @@ -376,11 +374,11 @@ def _serialize_to_mongo(self, obj):
def _deserialize_from_mongo(self, value):
return self.reference_cls(self.document_cls, value)

def as_marshmallow_field(self, mongo_world=False, **kwargs):
def as_marshmallow_field(self):
# Overwrite default `as_marshmallow_field` to handle deserialization
# difference (`_id` vs `id`)
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
return ma_bonus_fields.Reference(mongo_world=mongo_world, **field_kwargs)
field_kwargs = self._extract_marshmallow_field_params()
return ma_bonus_fields.Reference(**field_kwargs)


class GenericReferenceField(BaseField, ma_bonus_fields.GenericReference):
Expand Down Expand Up @@ -432,11 +430,11 @@ def _deserialize_from_mongo(self, value):
document_cls = self._document_cls(value['_cls'])
return self.reference_cls(document_cls, value['_id'])

def as_marshmallow_field(self, mongo_world=False, **kwargs):
def as_marshmallow_field(self):
# Overwrite default `as_marshmallow_field` to handle deserialization
# difference (`_id` vs `id`)
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
return ma_bonus_fields.GenericReference(mongo_world=mongo_world, **field_kwargs)
field_kwargs = self._extract_marshmallow_field_params()
return ma_bonus_fields.GenericReference(**field_kwargs)


class EmbeddedField(BaseField, ma.fields.Nested):
Expand Down Expand Up @@ -553,11 +551,10 @@ def map_to_field(self, mongo_path, path, func):
if hasattr(field, 'map_to_field'):
field.map_to_field(cur_mongo_path, cur_path, func)

def as_marshmallow_field(self, mongo_world=False, **kwargs):
def as_marshmallow_field(self):
# Overwrite default `as_marshmallow_field` to handle nesting
field_kwargs = self._extract_marshmallow_field_params(mongo_world)
nested_ma_schema = self._embedded_document_cls.schema.as_marshmallow_schema(
mongo_world=mongo_world)
field_kwargs = self._extract_marshmallow_field_params()
nested_ma_schema = self._embedded_document_cls.schema.as_marshmallow_schema()
return ma.fields.Nested(nested_ma_schema, **field_kwargs)

def _required_validate(self, value):
Expand Down
14 changes: 2 additions & 12 deletions umongo/marshmallow_bonus.py
Expand Up @@ -34,16 +34,12 @@ class Reference(ObjectId):
Marshmallow field for :class:`umongo.fields.ReferenceField`
"""

def __init__(self, *args, mongo_world=False, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mongo_world = mongo_world

def _serialize(self, value, attr, obj):
if value is None:
return None
if self.mongo_world:
# In mongo world, value is a regular ObjectId
return str(value)
# In OO world, value is a :class:`umongo.data_object.Reference`
# or an ObjectId before being loaded into a Document
if isinstance(value, bson.ObjectId):
Expand All @@ -56,16 +52,12 @@ class GenericReference(ma.fields.Field):
Marshmallow field for :class:`umongo.fields.GenericReferenceField`
"""

def __init__(self, *args, mongo_world=False, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mongo_world = mongo_world

def _serialize(self, value, attr, obj):
if value is None:
return None
if self.mongo_world:
# In mongo world, value a dict of cls and id
return {'id': str(value['_id']), 'cls': value['_cls']}
# In OO world, value is a :class:`umongo.data_object.Reference`
# or a dict before being loaded into a Document
if isinstance(value, dict):
Expand All @@ -81,6 +73,4 @@ def _deserialize(self, value, attr, data, **kwargs):
_id = bson.ObjectId(value['id'])
except ValueError:
raise ma.ValidationError(_("Invalid `id` field."))
if self.mongo_world:
return {'_cls': value['cls'], '_id': _id}
return {'cls': value['cls'], 'id': _id}
13 changes: 4 additions & 9 deletions umongo/schema.py
Expand Up @@ -26,21 +26,16 @@ class Schema(BaseSchema):

_marshmallow_schemas_cache = {}

def as_marshmallow_schema(self, *, mongo_world=False):
"""
Return a pure-marshmallow version of this schema class.
:param mongo_world: If True the schema will work against the mongo world
instead of the OO world (default: False).
"""
def as_marshmallow_schema(self):
"""Return a pure-marshmallow version of this schema class"""
# Use a cache to avoid generating several times the same schema
cache_key = (self.__class__, self.MA_BASE_SCHEMA_CLS, mongo_world)
cache_key = (self.__class__, self.MA_BASE_SCHEMA_CLS)
if cache_key in self._marshmallow_schemas_cache:
return self._marshmallow_schemas_cache[cache_key]

# Create schema if not found in cache
nmspc = {
name: field.as_marshmallow_field(mongo_world=mongo_world)
name: field.as_marshmallow_field()
for name, field in self.fields.items()
}
name = 'Marshmallow%s' % type(self).__name__
Expand Down

0 comments on commit 3b2402d

Please sign in to comment.