diff --git a/umongo/abstract.py b/umongo/abstract.py index 143a6e39..71848e84 100644 --- a/umongo/abstract.py +++ b/umongo/abstract.py @@ -114,6 +114,8 @@ class BaseField(ma_fields.Field): 'unique_compound': N_('Values of fields {fields} must be unique together.') } + MARSHMALLOW_ARGS_PREFIX = 'marshmallow_' + def __init__(self, *args, io_validate=None, unique=False, instance=None, **kwargs): if 'missing' in kwargs: raise RuntimeError("uMongo doesn't use `missing` argument, use `default` " @@ -125,21 +127,21 @@ def __init__(self, *args, io_validate=None, unique=False, instance=None, **kwarg # Store attributes prefixed with marshmallow_ to use them when # creating pure marshmallow Schema - for attribute in ( - 'data_key', 'attribute', 'validate', 'required', 'allow_none', - 'load_only', 'dump_only', 'error_messages' - ): - attribute = 'marshmallow_' + attribute - if attribute in kwargs: - setattr(self, attribute, kwargs.pop(attribute)) - - ma_missing = kwargs.pop('marshmallow_missing', None) - ma_default = kwargs.pop('marshmallow_default', None) + self._ma_kwargs = { + key[len(self.MARSHMALLOW_ARGS_PREFIX):]: val + for key, val in kwargs.items() + if key.startswith(self.MARSHMALLOW_ARGS_PREFIX) + } + kwargs = { + key: val + for key, val in kwargs.items() + if not key.startswith(self.MARSHMALLOW_ARGS_PREFIX) + } super().__init__(*args, **kwargs) - self.marshmallow_missing = ma_missing if ma_missing is not None else self.default - self.marshmallow_default = ma_default if ma_default is not None else self.default + self._ma_kwargs.setdefault('missing', self.default) + self._ma_kwargs.setdefault('default', self.default) # Overwrite error_messages to handle i18n translation self.error_messages = I18nErrorDict(self.error_messages) @@ -209,14 +211,7 @@ def _extract_marshmallow_field_params(self, mongo_world): params['attribute'] = self.attribute # Override uMongo attributes with marshmallow_ prefixed attributes - for attribute in ( - 'default', 'missing', 'data_key', 'attribute', - 'validate', 'required', 'allow_none', - 'load_only', 'dump_only', 'error_messages' - ): - ma_attribute = 'marshmallow_' + attribute - if hasattr(self, ma_attribute): - params[attribute] = getattr(self, ma_attribute) + params.update(self._ma_kwargs) params.update(self.metadata) return params