From bb19dd42313560000ba075a7ae082341ea62b674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Sat, 18 Apr 2020 20:38:36 +0200 Subject: [PATCH] Various pylint fixes --- umongo/abstract.py | 6 ++- umongo/builder.py | 10 +++-- umongo/data_objects.py | 6 +-- umongo/data_proxy.py | 60 ++++++++++++++++-------------- umongo/document.py | 17 ++------- umongo/embedded_document.py | 4 +- umongo/fields.py | 27 ++++++-------- umongo/frameworks/__init__.py | 1 - umongo/frameworks/motor_asyncio.py | 32 ++++++++-------- umongo/frameworks/pymongo.py | 27 +++++++------- umongo/frameworks/txmongo.py | 41 ++++++++++---------- umongo/indexes.py | 13 +++---- umongo/marshmallow_bonus.py | 28 ++++++-------- umongo/query_mapper.py | 5 +-- umongo/schema.py | 30 ++++++--------- umongo/template.py | 8 ++-- 16 files changed, 147 insertions(+), 168 deletions(-) diff --git a/umongo/abstract.py b/umongo/abstract.py index b347fab9..3eeb69f7 100644 --- a/umongo/abstract.py +++ b/umongo/abstract.py @@ -118,10 +118,12 @@ class BaseField(ma_fields.Field): def __init__(self, *args, io_validate=None, unique=False, instance=None, **kwargs): if 'missing' in kwargs: - raise RuntimeError("uMongo doesn't use `missing` argument, use `default` " + raise RuntimeError( + "uMongo doesn't use `missing` argument, use `default` " "instead and `marshmallow_missing`/`marshmallow_default` " "to tell `as_marshmallow_field` to use a custom value when " - "generating pure Marshmallow field.") + "generating pure Marshmallow field." + ) if 'default' in kwargs: kwargs['missing'] = kwargs['default'] diff --git a/umongo/builder.py b/umongo/builder.py index 954258ff..eadb3214 100644 --- a/umongo/builder.py +++ b/umongo/builder.py @@ -16,8 +16,8 @@ def camel_to_snake(name): - s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + tmp_str = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', tmp_str).lower() def _is_child(bases): @@ -72,8 +72,10 @@ def _collect_indexes(meta, schema_nmspc, bases): # Then get our own custom indexes if is_child: - custom_indexes = [parse_index(x, base_compound_field='_cls') - for x in getattr(meta, 'indexes', ())] + custom_indexes = [ + parse_index(x, base_compound_field='_cls') + for x in getattr(meta, 'indexes', ()) + ] else: custom_indexes = [parse_index(x) for x in getattr(meta, 'indexes', ())] indexes += custom_indexes diff --git a/umongo/data_objects.py b/umongo/data_objects.py index c011a593..e7155982 100644 --- a/umongo/data_objects.py +++ b/umongo/data_objects.py @@ -78,7 +78,7 @@ def set_modified(self): def clear_modified(self): self._modified = False - if len(self) and isinstance(self[0], BaseDataObject): + if self and isinstance(self[0], BaseDataObject): for obj in self: obj.clear_modified() @@ -179,8 +179,8 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, self.document_cls): return other.pk == self.pk - elif isinstance(other, Reference): + if isinstance(other, Reference): return self.pk == other.pk and self.document_cls == other.document_cls - elif isinstance(other, DBRef): + if isinstance(other, DBRef): return self.pk == other.id and self.document_cls.collection.name == other.collection return NotImplemented diff --git a/umongo/data_proxy.py b/umongo/data_proxy.py index a9373eb7..ac3ec57e 100644 --- a/umongo/data_proxy.py +++ b/umongo/data_proxy.py @@ -19,6 +19,7 @@ def __init__(self, data=None): self.not_loaded_fields = set() # Inside data proxy, data are stored in mongo world representation self._modified_data = set() + self._data = {} self.load(data or {}) @property @@ -29,16 +30,15 @@ def partial(self): def to_mongo(self, update=False): if update: return self._to_mongo_update() - else: - return self._to_mongo() + return self._to_mongo() def _to_mongo(self): mongo_data = {} - for k, v in self._data.items(): - field = self._fields_from_mongo_key[k] - v = field.serialize_to_mongo(v) - if v is not missing: - mongo_data[k] = v + for key, val in self._data.items(): + field = self._fields_from_mongo_key[key] + val = field.serialize_to_mongo(val) + if val is not missing: + mongo_data[key] = val return mongo_data def _to_mongo_update(self): @@ -48,11 +48,11 @@ def _to_mongo_update(self): for name in self.get_modified_fields(): field = self._fields[name] name = field.attribute or name - v = field.serialize_to_mongo(self._data[name]) - if v is missing: + val = field.serialize_to_mongo(self._data[name]) + if val is missing: unset_data.append(name) else: - set_data[name] = v + set_data[name] = val if set_data: mongo_data['$set'] = set_data if unset_data: @@ -61,14 +61,15 @@ def _to_mongo_update(self): def from_mongo(self, data, partial=False): self._data = {} - for k, v in data.items(): + for key, val in data.items(): try: - field = self._fields_from_mongo_key[k] + field = self._fields_from_mongo_key[key] except KeyError: - raise UnknownFieldInDBError( - _('{cls}: unknown "{key}" field found in DB.' - .format(key=k, cls=self.__class__.__name__))) - self._data[k] = field.deserialize_from_mongo(v) + raise UnknownFieldInDBError(_( + '{cls}: unknown "{key}" field found in DB.' + .format(key=key, cls=self.__class__.__name__) + )) + self._data[key] = field.deserialize_from_mongo(val) if partial: self._collect_partial_fields(data.keys(), as_mongo_fields=True) else: @@ -158,7 +159,7 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, dict): return self._data == other - elif hasattr(other, '_data'): + if hasattr(other, '_data'): return self._data == other._data return NotImplemented @@ -177,14 +178,16 @@ def get_modified_fields(self): def clear_modified(self): self._modified_data.clear() - for v in self._data.values(): - if isinstance(v, BaseDataObject): - v.clear_modified() + for val in self._data.values(): + if isinstance(val, BaseDataObject): + val.clear_modified() def is_modified(self): - return (bool(self._modified_data) or + return ( + bool(self._modified_data) or any(isinstance(v, BaseDataObject) and v.is_modified() - for v in self._data.values())) + for v in self._data.values()) + ) def _collect_partial_fields(self, loaded_fields, as_mongo_fields=False): if as_mongo_fields: @@ -222,8 +225,9 @@ def required_validate(self): # Standards iterators providing oo and mongo worlds views def items(self): - return ((key, self._data[field.attribute or key]) - for key, field in self._fields.items()) + return ( + (key, self._data[field.attribute or key]) for key, field in self._fields.items() + ) def items_by_mongo_name(self): return self._data.items() @@ -257,13 +261,13 @@ def _to_mongo(self): def from_mongo(self, data, partial=False): self._data = {} - for k, v in data.items(): + for key, val in data.items(): try: - field = self._fields_from_mongo_key[k] + field = self._fields_from_mongo_key[key] except KeyError: - self._additional_data[k] = v + self._additional_data[key] = val else: - self._data[k] = field.deserialize_from_mongo(v) + self._data[key] = field.deserialize_from_mongo(val) if partial: self._collect_partial_fields(data.keys(), as_mongo_fields=True) else: diff --git a/umongo/document.py b/umongo/document.py index 315550c9..09f52bf7 100644 --- a/umongo/document.py +++ b/umongo/document.py @@ -8,6 +8,7 @@ from .exceptions import (NotCreatedError, NoDBDefinedError, AbstractDocumentError, DocumentDefinitionError) from .template import Implementation, Template, MetaImplementation +from .data_objects import Reference __all__ = ( @@ -37,7 +38,6 @@ class DocumentTemplate(Template): or `marshmallow.post_dump`) to this class that will be passed to the marshmallow schema internally used for this document. """ - pass Document = DocumentTemplate @@ -76,9 +76,7 @@ class Meta: indexes yes List of custom indexes offspring no List of Documents inheriting this one ==================== ====================== =========== - """ - def __repr__(self): return ('<{ClassName}(' 'instance={self.instance}, ' @@ -148,14 +146,13 @@ def __repr__(self): self.__module__, self.__class__.__name__, dict(self._data.items())) def __eq__(self, other): - from .data_objects import Reference if self.pk is None: return self is other - elif isinstance(other, self.__class__) and other.pk is not None: + if isinstance(other, self.__class__) and other.pk is not None: return self.pk == other.pk - elif isinstance(other, DBRef): + if isinstance(other, DBRef): return other.collection == self.collection.name and other.id == self.pk - elif isinstance(other, Reference): + if isinstance(other, Reference): return isinstance(self, other.document_cls) and self.pk == other.pk return NotImplemented @@ -316,7 +313,6 @@ def pre_insert(self): .. note:: If you use an async driver, this callback can be asynchronous. """ - pass def pre_update(self): """ @@ -326,7 +322,6 @@ def pre_update(self): .. note:: If you use an async driver, this callback can be asynchronous. """ - pass def pre_delete(self): """ @@ -336,7 +331,6 @@ def pre_delete(self): .. note:: If you use an async driver, this callback can be asynchronous. """ - pass def post_insert(self, ret): """ @@ -345,7 +339,6 @@ def post_insert(self, ret): .. note:: If you use an async driver, this callback can be asynchronous. """ - pass def post_update(self, ret): """ @@ -354,7 +347,6 @@ def post_update(self, ret): .. note:: If you use an async driver, this callback can be asynchronous. """ - pass def post_delete(self, ret): """ @@ -363,4 +355,3 @@ def post_delete(self, ret): .. note:: If you use an async driver, this callback can be asynchronous. """ - pass diff --git a/umongo/embedded_document.py b/umongo/embedded_document.py index 3e152c5f..896ab9bb 100644 --- a/umongo/embedded_document.py +++ b/umongo/embedded_document.py @@ -21,7 +21,6 @@ class EmbeddedDocumentTemplate(Template): :class:`umongo.instance.BaseInstance` to obtain it corresponding :class:`umongo.embedded_document.EmbeddedDocumentImplementation`. """ - pass EmbeddedDocument = EmbeddedDocumentTemplate @@ -58,7 +57,6 @@ class Meta: offspring no List of EmbeddedDocuments inheriting this one ==================== ====================== =========== """ - def __repr__(self): return ('<{ClassName}(' 'instance={self.instance}, ' @@ -106,7 +104,7 @@ def __repr__(self): def __eq__(self, other): if isinstance(other, dict): return self._data == other - elif hasattr(other, '_data'): + if hasattr(other, '_data'): return self._data == other._data return NotImplemented diff --git a/umongo/fields.py b/umongo/fields.py index fec5a1f4..74ef79c6 100644 --- a/umongo/fields.py +++ b/umongo/fields.py @@ -6,6 +6,7 @@ from marshmallow import fields as ma_fields # from .registerer import retrieve_document +from .document import DocumentImplementation from .exceptions import NotRegisteredDocumentError from .template import get_template from .data_objects import Reference, List, Dict @@ -367,7 +368,7 @@ def _deserialize(self, value, attr, data, **kwargs): if value.document_cls != self.document_cls: raise ValidationError(_("`{document}` reference expected.").format( document=self.document_cls.__name__)) - if type(value) is not self.reference_cls: + if not isinstance(value, self.reference_cls): value = self.reference_cls(value.document_cls, value.pk) return value elif isinstance(value, self.document_cls): @@ -401,8 +402,6 @@ class GenericReferenceField(BaseField, ma_bonus_fields.GenericReference): def __init__(self, *args, reference_cls=Reference, **kwargs): super().__init__(*args, **kwargs) self.reference_cls = reference_cls - # Avoid importing multiple times - from .document import DocumentImplementation self._document_implementation_cls = DocumentImplementation def _document_cls(self, class_name): @@ -421,15 +420,15 @@ def _deserialize(self, value, attr, data, **kwargs): if value is None: return None if isinstance(value, Reference): - if type(value) is not self.reference_cls: + if not isinstance(value, self.reference_cls): value = self.reference_cls(value.document_cls, value.pk) return value - elif isinstance(value, self._document_implementation_cls): + if isinstance(value, self._document_implementation_cls): if not value.is_created: raise ValidationError( _("Cannot reference a document that has not been created yet.")) return self.reference_cls(value.__class__, value.pk) - elif isinstance(value, dict): + if isinstance(value, dict): if value.keys() != {'cls', 'id'}: raise ValidationError(_("Generic reference must have `id` and `cls` fields.")) try: @@ -438,8 +437,7 @@ def _deserialize(self, value, attr, data, **kwargs): raise ValidationError(_("Invalid `id` field.")) document_cls = self._document_cls(value['cls']) return self.reference_cls(document_cls, _id) - else: - raise ValidationError(_("Invalid value for generic reference field.")) + raise ValidationError(_("Invalid value for generic reference field.")) def _serialize_to_mongo(self, obj): return {'_id': obj.pk, '_cls': obj.document_cls.__name__} @@ -516,11 +514,10 @@ def _deserialize(self, value, attr, data, **kwargs): try: to_use_cls = embedded_document_cls.opts.instance.retrieve_embedded_document( to_use_cls_name) - except NotRegisteredDocumentError as e: - raise ValidationError(str(e)) + except NotRegisteredDocumentError as exc: + raise ValidationError(str(exc)) return to_use_cls(**value) - else: - return embedded_document_cls(**value) + return embedded_document_cls(**value) def _serialize_to_mongo(self, obj): return obj.to_mongo() @@ -533,7 +530,7 @@ def _validate_missing(self, value): super()._validate_missing(value) errors = {} if value is missing: - def get_sub_value(key): + def get_sub_value(_): return missing elif isinstance(value, dict): # value is a dict for deserialization @@ -555,8 +552,8 @@ def get_sub_value(key): continue try: field._validate_missing(sub_value) - except ValidationError as ve: - errors[name] = ve.messages + except ValidationError as exc: + errors[name] = exc.messages if errors: raise ValidationError(errors) diff --git a/umongo/frameworks/__init__.py b/umongo/frameworks/__init__.py index 7d3b16c3..a514f3bf 100644 --- a/umongo/frameworks/__init__.py +++ b/umongo/frameworks/__init__.py @@ -1,7 +1,6 @@ """ Frameworks ========== - """ from importlib import import_module diff --git a/umongo/frameworks/motor_asyncio.py b/umongo/frameworks/motor_asyncio.py index dce6328e..5780b336 100644 --- a/umongo/frameworks/motor_asyncio.py +++ b/umongo/frameworks/motor_asyncio.py @@ -1,9 +1,9 @@ +from inspect import iscoroutine import asyncio from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorCursor from motor import version_tuple as MOTOR_VERSION from pymongo.errors import DuplicateKeyError -from inspect import iscoroutine from ..builder import BaseBuilder from ..document import DocumentImplementation @@ -176,12 +176,13 @@ async def commit(self, io_validate_all=False, conditions=None): key = tuple(keys)[0] msg = self.schema.fields[key].error_messages['unique'] raise ValidationError({key: msg}) - else: - fields = self.schema.fields - # Compound index (sort value to make testing easier) - keys = sorted(keys) - raise ValidationError({k: fields[k].error_messages[ - 'unique_compound'].format(fields=keys) for k in keys}) + fields = self.schema.fields + # Compound index (sort value to make testing easier) + keys = sorted(keys) + raise ValidationError({ + k: fields[k].error_messages['unique_compound'].format(fields=keys) + for k in keys + }) # Unknown index, cannot wrap the error so just reraise it raise self._data.clear_modified() @@ -232,9 +233,8 @@ async def io_validate(self, validate_all=False): """ if validate_all: return await _io_validate_data_proxy(self.schema, self._data) - else: - return await _io_validate_data_proxy( - self.schema, self._data, partial=self._data.get_modified_fields()) + return await _io_validate_data_proxy( + self.schema, self._data, partial=self._data.get_modified_fields()) @classmethod async def find_one(cls, filter=None, *args, **kwargs): @@ -276,7 +276,7 @@ async def ensure_indexes(cls): """ for index in cls.opts.indexes: kwargs = index.document.copy() - keys = [(k, d) for k, d in kwargs.pop('key').items()] + keys = kwargs.pop('key').items() await cls.collection.create_index(keys, **kwargs) @@ -311,8 +311,8 @@ async def _io_validate_data_proxy(schema, data_proxy, partial=None): if field.io_validate: tasks.append(_run_validators(field.io_validate, field, value)) tasks_field_name.append(name) - except ValidationError as ve: - errors[name] = ve.messages + except ValidationError as exc: + errors[name] = exc.messages results = await asyncio.gather(*tasks, return_exceptions=True) for i, res in enumerate(results): if isinstance(res, ValidationError): @@ -383,8 +383,10 @@ def _patch_field(self, field): validators = list(validators) else: validators = [validators] - field.io_validate = [v if asyncio.iscoroutinefunction(v) else asyncio.coroutine(v) - for v in validators] + field.io_validate = [ + v if asyncio.iscoroutinefunction(v) else asyncio.coroutine(v) + for v in validators + ] if isinstance(field, ListField): field.io_validate_recursive = _list_io_validate if isinstance(field, ReferenceField): diff --git a/umongo/frameworks/pymongo.py b/umongo/frameworks/pymongo.py index b6040fea..a7c83113 100644 --- a/umongo/frameworks/pymongo.py +++ b/umongo/frameworks/pymongo.py @@ -130,12 +130,11 @@ def commit(self, io_validate_all=False, conditions=None): if len(keys) == 1: msg = self.schema.fields[keys[0]].error_messages['unique'] raise ValidationError({keys[0]: msg}) - else: - fields = self.schema.fields - # Compound index (sort value to make testing easier) - keys = sorted(keys) - raise ValidationError({k: fields[k].error_messages[ - 'unique_compound'].format(fields=keys) for k in keys}) + fields = self.schema.fields + # Compound index (sort value to make testing easier) + keys = sorted(keys) + raise ValidationError({k: fields[k].error_messages[ + 'unique_compound'].format(fields=keys) for k in keys}) # Unknown index, cannot wrap the error so just reraise it raise self._data.clear_modified() @@ -235,8 +234,8 @@ def _run_validators(validators, field, value): for validator in validators: try: validator(field, value) - except ValidationError as ve: - errors.extend(ve.messages) + except ValidationError as exc: + errors.extend(exc.messages) if errors: raise ValidationError(errors) @@ -255,8 +254,8 @@ def _io_validate_data_proxy(schema, data_proxy, partial=None): field.io_validate_recursive(field, value) if field.io_validate: _run_validators(field.io_validate, field, value) - except ValidationError as ve: - errors[name] = ve.messages + except ValidationError as exc: + errors[name] = exc.messages if errors: raise ValidationError(errors) @@ -270,11 +269,11 @@ def _list_io_validate(field, value): validators = field.inner.io_validate if not validators: return - for i, e in enumerate(value): + for idx, val in enumerate(value): try: - _run_validators(validators, field.inner, e) - except ValidationError as ev: - errors[i] = ev.messages + _run_validators(validators, field.inner, val) + except ValidationError as exc: + errors[idx] = exc.messages if errors: raise ValidationError(errors) diff --git a/umongo/frameworks/txmongo.py b/umongo/frameworks/txmongo.py index 374dc1ac..0a82b787 100644 --- a/umongo/frameworks/txmongo.py +++ b/umongo/frameworks/txmongo.py @@ -1,7 +1,7 @@ -from txmongo.database import Database from twisted.internet.defer import ( inlineCallbacks, Deferred, DeferredList, returnValue, maybeDeferred) from txmongo import filter as qf +from txmongo.database import Database from pymongo.errors import DuplicateKeyError from ..builder import BaseBuilder @@ -93,12 +93,15 @@ def commit(self, io_validate_all=False, conditions=None): if len(keys) == 1: msg = self.schema.fields[keys[0]].error_messages['unique'] raise ValidationError({keys[0]: msg}) - else: - fields = self.schema.fields - # Compound index (sort value to make testing easier) - keys = sorted(keys) - raise ValidationError({k: fields[k].error_messages[ - 'unique_compound'].format(fields=keys) for k in keys}) + fields = self.schema.fields + # Compound index (sort value to make testing easier) + keys = sorted(keys) + raise ValidationError( + { + k: fields[k].error_messages['unique_compound'].format(fields=keys) + for k in keys + } + ) self._data.clear_modified() return ret @@ -142,9 +145,8 @@ def io_validate(self, validate_all=False): """ if validate_all: return _io_validate_data_proxy(self.schema, self._data) - else: - return _io_validate_data_proxy( - self.schema, self._data, partial=self._data.get_modified_fields()) + return _io_validate_data_proxy( + self.schema, self._data, partial=self._data.get_modified_fields()) @classmethod @inlineCallbacks @@ -179,8 +181,7 @@ def wrap_raw_results(result): return ([cls.build_from_mongo(e, use_cls=True) for e in result[0]], cursor) return wrap_raw_results(raw_cursor_or_list) - else: - return [cls.build_from_mongo(e, use_cls=True) for e in raw_cursor_or_list] + return [cls.build_from_mongo(e, use_cls=True) for e in raw_cursor_or_list] @classmethod def count(cls, spec=None, **kwargs): @@ -200,7 +201,7 @@ def ensure_indexes(cls): for index in cls.opts.indexes: kwargs = index.document.copy() keys = kwargs.pop('key') - index = qf.sort([(k, d) for k, d in keys.items()]) + index = qf.sort(keys.items()) yield cls.collection.create_index(index, **kwargs) @@ -226,8 +227,8 @@ def _run_validators(validators, field, value): for validator in validators: try: defer = validator(field, value) - except ValidationError as ve: - errors.extend(ve.messages) + except ValidationError as exc: + errors.extend(exc.messages) else: assert isinstance(defer, Deferred), 'io_validate functions must return a Deferred' defer.addErrback(_errback_factory(errors)) @@ -255,8 +256,8 @@ def _io_validate_data_proxy(schema, data_proxy, partial=None): defer = _run_validators(field.io_validate, field, value) defer.addErrback(_errback_factory(errors, name)) defers.append(defer) - except ValidationError as ve: - errors[name] = ve.messages + except ValidationError as exc: + errors[name] = exc.messages yield DeferredList(defers) if errors: raise ValidationError(errors) @@ -273,9 +274,9 @@ def _list_io_validate(field, value): return errors = {} defers = [] - for i, e in enumerate(value): - defer = _run_validators(validators, field.inner, e) - defer.addErrback(_errback_factory(errors, i)) + for idx, exc in enumerate(value): + defer = _run_validators(validators, field.inner, exc) + defer.addErrback(_errback_factory(errors, idx)) defers.append(defer) yield DeferredList(defers) if errors: diff --git a/umongo/indexes.py b/umongo/indexes.py index b7667e70..749cca15 100644 --- a/umongo/indexes.py +++ b/umongo/indexes.py @@ -5,23 +5,22 @@ def explicit_key(index): if isinstance(index, (list, tuple)): assert len(index) == 2, 'Must be a (`key`, `direction`) tuple' return index - elif index.startswith('+'): + if index.startswith('+'): return (index[1:], ASCENDING) - elif index.startswith('-'): + if index.startswith('-'): return (index[1:], DESCENDING) - elif index.startswith('$'): + if index.startswith('$'): return (index[1:], TEXT) - elif index.startswith('#'): + if index.startswith('#'): return (index[1:], HASHED) - else: - return (index, ASCENDING) + return (index, ASCENDING) def parse_index(index, base_compound_field=None): keys = None args = {} if isinstance(index, IndexModel): - keys = [(k, d) for k, d in index.document['key'].items()] + keys = index.document['key'].items() args = {k: v for k, v in index.document.items() if k != 'key'} elif isinstance(index, (tuple, list)): # Compound indexes diff --git a/umongo/marshmallow_bonus.py b/umongo/marshmallow_bonus.py index 70412776..0759ee0b 100644 --- a/umongo/marshmallow_bonus.py +++ b/umongo/marshmallow_bonus.py @@ -35,8 +35,7 @@ class MySchema(marshsmallow.Schema): if ret is None and ret is not default and attr in obj.schema.fields: raw_ret = obj._data.get(attr) return default if raw_ret is missing else raw_ret - else: - return ret + return ret class SchemaFromUmongo(MaSchema): @@ -84,12 +83,11 @@ def _serialize(self, value, attr, obj): if self.mongo_world: # In mongo world, value is a regular ObjectId return str(value) - else: - # In OO world, value is a :class:`umongo.data_object.Reference` - # or an ObjectId before being loaded into a Document - if isinstance(value, bson.ObjectId): - return str(value) - return str(value.pk) + # In OO world, value is a :class:`umongo.data_object.Reference` + # or an ObjectId before being loaded into a Document + if isinstance(value, bson.ObjectId): + return str(value) + return str(value.pk) class GenericReference(ma_fields.Field): @@ -107,12 +105,11 @@ def _serialize(self, value, attr, obj): if self.mongo_world: # In mongo world, value a dict of cls and id return {'id': str(value['_id']), 'cls': value['_cls']} - else: - # In OO world, value is a :class:`umongo.data_object.Reference` - # or a dict before being loaded into a Document - if isinstance(value, dict): - return {'id': str(value['id']), 'cls': value['cls']} - return {'id': str(value.pk), 'cls': value.document_cls.__name__} + # In OO world, value is a :class:`umongo.data_object.Reference` + # or a dict before being loaded into a Document + if isinstance(value, dict): + return {'id': str(value['id']), 'cls': value['cls']} + return {'id': str(value.pk), 'cls': value.document_cls.__name__} def _deserialize(self, value, attr, data, **kwargs): if not isinstance(value, dict): @@ -125,5 +122,4 @@ def _deserialize(self, value, attr, data, **kwargs): raise ValidationError(_("Invalid `id` field.")) if self.mongo_world: return {'_cls': value['cls'], '_id': _id} - else: - return {'cls': value['cls'], 'id': _id} + return {'cls': value['cls'], 'id': _id} diff --git a/umongo/query_mapper.py b/umongo/query_mapper.py index 1cba21e6..c20224b3 100644 --- a/umongo/query_mapper.py +++ b/umongo/query_mapper.py @@ -42,7 +42,6 @@ def map_query(query, fields): mapped_entry, entry_fields = map_entry_with_dots(entry, fields) mapped_query[mapped_entry] = map_query(entry_query, entry_fields) return mapped_query - elif isinstance(query, (list, tuple)): + if isinstance(query, (list, tuple)): return [map_query(x, fields) for x in query] - else: - return query + return query diff --git a/umongo/schema.py b/umongo/schema.py index 27684a8f..c20074ad 100644 --- a/umongo/schema.py +++ b/umongo/schema.py @@ -1,24 +1,26 @@ from marshmallow.fields import Field +from . import fields from .abstract import BaseSchema __all__ = ('Schema', 'EmbeddedSchema', 'on_need_add_id_field', 'add_child_field') -def on_need_add_id_field(bases, fields): +def on_need_add_id_field(bases, fields_dict): """ If the given fields make no reference to `_id`, add an `id` field (type ObjectId, dump_only=True, attribute=`_id`) to handle it """ - def find_id_field(fields): - for name, field in fields.items(): + def find_id_field(fields_dict): + for name, field in fields_dict.items(): # Skip fake fields present in schema (e.g. `post_load` decorated function) if not isinstance(field, Field): continue if (name == '_id' and not field.attribute) or field.attribute == '_id': return name, field + return None # Search among parents for the id field for base in bases: @@ -27,28 +29,18 @@ def find_id_field(fields): return # Search amongo our own fields - if not find_id_field(fields): + if not find_id_field(fields_dict): # No id field found, add a default one - from .fields import ObjectIdField - fields['id'] = ObjectIdField(attribute='_id', dump_only=True) + fields_dict['id'] = fields.ObjectIdField(attribute='_id', dump_only=True) -def add_child_field(name, fields): - from .fields import StrField - fields['cls'] = StrField(attribute='_cls', default=name, dump_only=True) +def add_child_field(name, fields_dict): + fields_dict['cls'] = fields.StringField(attribute='_cls', default=name, dump_only=True) class Schema(BaseSchema): - """ - Base schema class used by :class:`umongo.Document` - """ - - pass + """Base schema class used by :class:`umongo.Document`""" class EmbeddedSchema(BaseSchema): - """ - Base schema class used by :class:`umongo.EmbeddedDocument` - """ - - pass + """Base schema class used by :class:`umongo.EmbeddedDocument`""" diff --git a/umongo/template.py b/umongo/template.py index 290a0219..a6079c4f 100644 --- a/umongo/template.py +++ b/umongo/template.py @@ -32,8 +32,7 @@ def __new__(cls, name, bases, nmspc): if 'opts' not in nmspc: # Inheritance to avoid metaclass conflicts return super().__new__(cls, name, bases, nmspc) - else: - return type.__new__(cls, name, bases, nmspc) + return type.__new__(cls, name, bases, nmspc) def __repr__(cls): return "" % (cls.__module__, cls.__name__) @@ -52,6 +51,5 @@ def opts(self): def get_template(template_or_implementation): if issubclass(template_or_implementation, Implementation): return template_or_implementation.opts.template - else: - assert issubclass(template_or_implementation, Template) - return template_or_implementation + assert issubclass(template_or_implementation, Template) + return template_or_implementation