diff --git a/tests/test_marshmallow.py b/tests/test_marshmallow.py index 26df886b..c0a5d089 100644 --- a/tests/test_marshmallow.py +++ b/tests/test_marshmallow.py @@ -145,11 +145,6 @@ class MyDoc(Document): def test_as_marshmallow_schema_cache(self): ma_schema_cls = self.User.schema.as_marshmallow_schema() - - new_ma_schema_cls = self.User.schema.as_marshmallow_schema( - mongo_world=True) - assert new_ma_schema_cls != ma_schema_cls - new_ma_schema_cls = self.User.schema.as_marshmallow_schema() assert new_ma_schema_cls == ma_schema_cls @@ -208,16 +203,11 @@ class Dog(Document): payload = {'name': 'Scruffy', 'age': 2} ma_schema_cls = Dog.schema.as_marshmallow_schema() - ma_mongo_schema_cls = Dog.schema.as_marshmallow_schema(mongo_world=True) ret = ma_schema_cls().load(payload) assert ret == {'name': 'Scruffy', 'age': 2} assert ma_schema_cls().dump(ret) == payload - ret = ma_mongo_schema_cls().load(payload) - assert ret == {'_id': 'Scruffy', 'age': 2} - assert ma_mongo_schema_cls().dump(ret) == payload - def test_i18n(self): # i18n support should be kept, because it's pretty cool to have this ! def my_gettext(message): @@ -306,15 +296,11 @@ class Bag(Document): # (no ObjectId to str conversion needed for example) ma_schema = Bag.schema.as_marshmallow_schema()() - ma_mongo_schema = Bag.schema.as_marshmallow_schema(mongo_world=True)() bag = Bag(**data) assert ma_schema.dump(bag) == data assert ma_schema.load(data) == data - assert ma_mongo_schema.dump(bag.to_mongo()) == data - assert ma_mongo_schema.load(data) == bag.to_mongo() - def test_marshmallow_bonus_fields(self): # Fields related to mongodb provided for marshmallow @self.instance.register @@ -343,9 +329,6 @@ class Doc(Document): "gen_ref": {'cls': 'Doc', 'id': "57c1a71113adf27ab96b2c4f"} } doc = Doc(**oo_data) - mongo_data = doc.to_mongo() - - # schema to OO world ma_schema_cls = Doc.schema.as_marshmallow_schema() ma_schema = ma_schema_cls() # Dump uMongo object @@ -355,16 +338,6 @@ class Doc(Document): # Load serialized data assert ma_schema.load(serialized) == oo_data - # schema to mongo world - ma_mongo_schema_cls = Doc.schema.as_marshmallow_schema(mongo_world=True) - ma_mongo_schema = ma_mongo_schema_cls() - assert ma_mongo_schema.dump(mongo_data) == serialized - assert ma_mongo_schema.load(serialized) == mongo_data - # Cannot load mongo form - with pytest.raises(ma.ValidationError) as excinfo: - ma_mongo_schema.load({"gen_ref": {'_cls': 'Doc', '_id': "57c1a71113adf27ab96b2c4f"}}) - assert excinfo.value.messages == {'gen_ref': ['Generic reference must have `id` and `cls` fields.']} - def test_marshmallow_bonus_objectid_field(self): class DocSchema(ma.Schema): diff --git a/umongo/abstract.py b/umongo/abstract.py index 2343f612..6fd3c663 100644 --- a/umongo/abstract.py +++ b/umongo/abstract.py @@ -146,7 +146,7 @@ def _serialize_to_mongo(self, obj): def _deserialize_from_mongo(self, value): return value - def _extract_marshmallow_field_params(self, mongo_world): + def _extract_marshmallow_field_params(self): params = { attribute: getattr(self, attribute) for attribute in ( @@ -154,23 +154,14 @@ def _extract_marshmallow_field_params(self, mongo_world): 'load_only', 'dump_only', 'error_messages' ) } - if mongo_world and self.attribute: - params['attribute'] = self.attribute - # Override uMongo attributes with marshmallow_ prefixed attributes params.update(self._ma_kwargs) - params.update(self.metadata) return params - def as_marshmallow_field(self, *, mongo_world=False, **kwargs): - """ - Return a pure-marshmallow version of this field. - - :param mongo_world: If True the field will work against the mongo world - instead of the OO world (default: False) - """ - field_kwargs = self._extract_marshmallow_field_params(mongo_world) + def as_marshmallow_field(self): + """Return a pure-marshmallow version of this field""" + field_kwargs = self._extract_marshmallow_field_params() # Retrieve the marshmallow class we inherit from for m_class in type(self).mro(): if (not issubclass(m_class, BaseField) and diff --git a/umongo/fields.py b/umongo/fields.py index 54fd2a7f..f9ce0c7a 100644 --- a/umongo/fields.py +++ b/umongo/fields.py @@ -218,13 +218,12 @@ def _deserialize_from_mongo(self, value): ) return Dict(self.key_field, self.value_field) - def as_marshmallow_field(self, mongo_world=False, **kwargs): + def as_marshmallow_field(self): # Overwrite default `as_marshmallow_field` to handle deserialization # difference (`_id` vs `id`) - field_kwargs = self._extract_marshmallow_field_params(mongo_world) + field_kwargs = self._extract_marshmallow_field_params() if self.value_field: - inner_ma_schema = self.value_field.as_marshmallow_field( - mongo_world=mongo_world, **kwargs) + inner_ma_schema = self.value_field.as_marshmallow_field() else: inner_ma_schema = None return ma.fields.Dict(self.key_field, inner_ma_schema, **field_kwargs) @@ -282,12 +281,11 @@ def map_to_field(self, mongo_path, path, func): if hasattr(self.inner, 'map_to_field'): self.inner.map_to_field(mongo_path, path, func) - def as_marshmallow_field(self, mongo_world=False, **kwargs): + def as_marshmallow_field(self): # Overwrite default `as_marshmallow_field` to handle deserialization # difference (`_id` vs `id`) - field_kwargs = self._extract_marshmallow_field_params(mongo_world) - inner_ma_schema = self.inner.as_marshmallow_field( - mongo_world=mongo_world, **kwargs) + field_kwargs = self._extract_marshmallow_field_params() + inner_ma_schema = self.inner.as_marshmallow_field() return ma.fields.List(inner_ma_schema, **field_kwargs) def _required_validate(self, value): @@ -376,11 +374,11 @@ def _serialize_to_mongo(self, obj): def _deserialize_from_mongo(self, value): return self.reference_cls(self.document_cls, value) - def as_marshmallow_field(self, mongo_world=False, **kwargs): + def as_marshmallow_field(self): # Overwrite default `as_marshmallow_field` to handle deserialization # difference (`_id` vs `id`) - field_kwargs = self._extract_marshmallow_field_params(mongo_world) - return ma_bonus_fields.Reference(mongo_world=mongo_world, **field_kwargs) + field_kwargs = self._extract_marshmallow_field_params() + return ma_bonus_fields.Reference(**field_kwargs) class GenericReferenceField(BaseField, ma_bonus_fields.GenericReference): @@ -432,11 +430,11 @@ def _deserialize_from_mongo(self, value): document_cls = self._document_cls(value['_cls']) return self.reference_cls(document_cls, value['_id']) - def as_marshmallow_field(self, mongo_world=False, **kwargs): + def as_marshmallow_field(self): # Overwrite default `as_marshmallow_field` to handle deserialization # difference (`_id` vs `id`) - field_kwargs = self._extract_marshmallow_field_params(mongo_world) - return ma_bonus_fields.GenericReference(mongo_world=mongo_world, **field_kwargs) + field_kwargs = self._extract_marshmallow_field_params() + return ma_bonus_fields.GenericReference(**field_kwargs) class EmbeddedField(BaseField, ma.fields.Nested): @@ -553,11 +551,10 @@ def map_to_field(self, mongo_path, path, func): if hasattr(field, 'map_to_field'): field.map_to_field(cur_mongo_path, cur_path, func) - def as_marshmallow_field(self, mongo_world=False, **kwargs): + def as_marshmallow_field(self): # Overwrite default `as_marshmallow_field` to handle nesting - field_kwargs = self._extract_marshmallow_field_params(mongo_world) - nested_ma_schema = self._embedded_document_cls.schema.as_marshmallow_schema( - mongo_world=mongo_world) + field_kwargs = self._extract_marshmallow_field_params() + nested_ma_schema = self._embedded_document_cls.schema.as_marshmallow_schema() return ma.fields.Nested(nested_ma_schema, **field_kwargs) def _required_validate(self, value): diff --git a/umongo/marshmallow_bonus.py b/umongo/marshmallow_bonus.py index d8990bf1..e9a3a711 100644 --- a/umongo/marshmallow_bonus.py +++ b/umongo/marshmallow_bonus.py @@ -34,16 +34,12 @@ class Reference(ObjectId): Marshmallow field for :class:`umongo.fields.ReferenceField` """ - def __init__(self, *args, mongo_world=False, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.mongo_world = mongo_world def _serialize(self, value, attr, obj): if value is None: return None - if self.mongo_world: - # In mongo world, value is a regular ObjectId - return str(value) # In OO world, value is a :class:`umongo.data_object.Reference` # or an ObjectId before being loaded into a Document if isinstance(value, bson.ObjectId): @@ -56,16 +52,12 @@ class GenericReference(ma.fields.Field): Marshmallow field for :class:`umongo.fields.GenericReferenceField` """ - def __init__(self, *args, mongo_world=False, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.mongo_world = mongo_world def _serialize(self, value, attr, obj): if value is None: return None - if self.mongo_world: - # In mongo world, value a dict of cls and id - return {'id': str(value['_id']), 'cls': value['_cls']} # In OO world, value is a :class:`umongo.data_object.Reference` # or a dict before being loaded into a Document if isinstance(value, dict): @@ -81,6 +73,4 @@ def _deserialize(self, value, attr, data, **kwargs): _id = bson.ObjectId(value['id']) except ValueError: raise ma.ValidationError(_("Invalid `id` field.")) - if self.mongo_world: - return {'_cls': value['cls'], '_id': _id} return {'cls': value['cls'], 'id': _id} diff --git a/umongo/schema.py b/umongo/schema.py index f53a4ea7..6d7f08c6 100644 --- a/umongo/schema.py +++ b/umongo/schema.py @@ -26,21 +26,16 @@ class Schema(BaseSchema): _marshmallow_schemas_cache = {} - def as_marshmallow_schema(self, *, mongo_world=False): - """ - Return a pure-marshmallow version of this schema class. - - :param mongo_world: If True the schema will work against the mongo world - instead of the OO world (default: False). - """ + def as_marshmallow_schema(self): + """Return a pure-marshmallow version of this schema class""" # Use a cache to avoid generating several times the same schema - cache_key = (self.__class__, self.MA_BASE_SCHEMA_CLS, mongo_world) + cache_key = (self.__class__, self.MA_BASE_SCHEMA_CLS) if cache_key in self._marshmallow_schemas_cache: return self._marshmallow_schemas_cache[cache_key] # Create schema if not found in cache nmspc = { - name: field.as_marshmallow_field(mongo_world=mongo_world) + name: field.as_marshmallow_field() for name, field in self.fields.items() } name = 'Marshmallow%s' % type(self).__name__