Skip to content

Commit

Permalink
Merge pull request #472 from MongoEngine/replace-ugly-patching
Browse files Browse the repository at this point in the history
Ugly silent mongoengine patching replaced with explicit inheritance, mixins and attributes bypass(for backward compatibility)
  • Loading branch information
insspb committed Jul 4, 2022
2 parents fdeb2c1 + aabad12 commit 6de8374
Show file tree
Hide file tree
Showing 7 changed files with 538 additions and 113 deletions.
6 changes: 3 additions & 3 deletions docs/api/wtf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ WTF module API

This is the flask_mongoengine.wtf modules API documentation.

flask_mongoengine.wtf.base module
---------------------------------
flask_mongoengine.wtf.db_fields module
--------------------------------------

.. automodule:: flask_mongoengine.wtf.base
.. automodule:: flask_mongoengine.wtf.db_fields
:members:
:undoc-members:
:show-inheritance:
Expand Down
127 changes: 60 additions & 67 deletions flask_mongoengine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import inspect

import mongoengine
from flask import Flask, abort, current_app
from mongoengine.base.fields import BaseField
from mongoengine.errors import DoesNotExist
from mongoengine.queryset import QuerySet

from .connection import *
from .json import override_json_encoder
from .pagination import *
from .sessions import *
from .wtf import WtfBaseField
from flask_mongoengine.connection import *
from flask_mongoengine.json import override_json_encoder
from flask_mongoengine.pagination import *
from flask_mongoengine.sessions import *
from flask_mongoengine.wtf import db_fields

VERSION = (1, 0, 0)

Expand All @@ -23,61 +20,6 @@ def get_version():
__version__ = get_version()


def _patch_base_field(obj, name):
"""
If the object submitted has a class whose base class is
mongoengine.base.fields.BaseField, then monkey patch to
replace it with flask_mongoengine.wtf.WtfBaseField.
@note: WtfBaseField is an instance of BaseField - but
gives us the flexibility to extend field parameters
and settings required of WTForm via model form generator.
@see: flask_mongoengine.wtf.base.WtfBaseField.
@see: model_form in flask_mongoengine.wtf.orm
@param obj: MongoEngine instance in which we should locate the class.
@param name: Name of an attribute which may or may not be a BaseField.
"""
# TODO is there a less hacky way to accomplish the same level of
# extensibility/control?

# get an attribute of the MongoEngine class and return if it's not
# a class
cls = getattr(obj, name)
if not inspect.isclass(cls):
return

# if it is a class, inspect all of its parent classes
cls_bases = list(cls.__bases__)

# if any of them is a BaseField, replace it with WtfBaseField
for index, base in enumerate(cls_bases):
if base == BaseField:
cls_bases[index] = WtfBaseField
cls.__bases__ = tuple(cls_bases)
break

# re-assign the class back to the MongoEngine instance
delattr(obj, name)
setattr(obj, name, cls)


def _include_mongoengine(obj):
"""
Copy all of the attributes from mongoengine and mongoengine.fields
onto obj (which should be an instance of the MongoEngine class).
"""
# TODO why do we need this? What's wrong with importing from the
# original modules?
for attr_name in mongoengine.__all__:
if not hasattr(obj, attr_name):
setattr(obj, attr_name, getattr(mongoengine, attr_name))

# patch BaseField if available
_patch_base_field(obj, attr_name)


def current_mongoengine_instance():
"""Return a MongoEngine instance associated with current Flask app."""
me = current_app.extensions.get("mongoengine", {})
Expand All @@ -90,10 +32,52 @@ class MongoEngine(object):
"""Main class used for initialization of Flask-MongoEngine."""

def __init__(self, app=None, config=None):
_include_mongoengine(self)

# Extended database fields
self.BinaryField = db_fields.BinaryField
self.BooleanField = db_fields.BooleanField
self.CachedReferenceField = db_fields.CachedReferenceField
self.ComplexDateTimeField = db_fields.ComplexDateTimeField
self.DateField = db_fields.DateField
self.DateTimeField = db_fields.DateTimeField
self.DecimalField = db_fields.DecimalField
self.DictField = db_fields.DictField
self.DynamicField = db_fields.DynamicField
self.EmailField = db_fields.EmailField
self.EmbeddedDocumentField = db_fields.EmbeddedDocumentField
self.EmbeddedDocumentListField = db_fields.EmbeddedDocumentListField
self.EnumField = db_fields.EnumField
self.FileField = db_fields.FileField
self.FloatField = db_fields.FloatField
self.GenericEmbeddedDocumentField = db_fields.GenericEmbeddedDocumentField
self.GenericLazyReferenceField = db_fields.GenericLazyReferenceField
self.GenericReferenceField = db_fields.GenericReferenceField
self.GeoJsonBaseField = db_fields.GeoJsonBaseField
self.GeoPointField = db_fields.GeoPointField
self.ImageField = db_fields.ImageField
self.IntField = db_fields.IntField
self.LazyReferenceField = db_fields.LazyReferenceField
self.LineStringField = db_fields.LineStringField
self.ListField = db_fields.ListField
self.LongField = db_fields.LongField
self.MapField = db_fields.MapField
self.MultiLineStringField = db_fields.MultiLineStringField
self.MultiPointField = db_fields.MultiPointField
self.MultiPolygonField = db_fields.MultiPolygonField
self.ObjectIdField = db_fields.ObjectIdField
self.PointField = db_fields.PointField
self.PolygonField = db_fields.PolygonField
self.ReferenceField = db_fields.ReferenceField
self.SequenceField = db_fields.SequenceField
self.SortedListField = db_fields.SortedListField
self.StringField = db_fields.StringField
self.URLField = db_fields.URLField
self.UUIDField = db_fields.UUIDField

# Flask related data
self.app = None
self.config = config

# Extended documents classes
self.Document = Document
self.DynamicDocument = DynamicDocument

Expand Down Expand Up @@ -143,6 +127,15 @@ def connection(self):
"""
return current_app.extensions["mongoengine"][self]["conn"]

def __getattr__(self, attr_name):
"""
Mongoengine backward compatibility handler.
Provide original :module:``mongoengine`` module methods/classes if they are not
modified by us, and not mapped directly.
"""
return getattr(mongoengine, attr_name)


class BaseQuerySet(QuerySet):
"""Extends :class:`~mongoengine.queryset.QuerySet` class with handly methods."""
Expand Down Expand Up @@ -194,7 +187,7 @@ def paginate_field(self, field_name, doc_id, page, per_page, total=None):
"""
# TODO this doesn't sound useful at all - remove in next release?
item = self.get(id=doc_id)
count = getattr(item, field_name + "_count", "")
count = getattr(item, f"{field_name}_count", "")
total = total or count or len(getattr(item, field_name))
return ListFieldPagination(
self, doc_id, field_name, page, per_page, total=total
Expand All @@ -209,7 +202,7 @@ class Document(mongoengine.Document):
def paginate_field(self, field_name, page, per_page, total=None):
"""Paginate items within a list field."""
# TODO this doesn't sound useful at all - remove in next release?
count = getattr(self, field_name + "_count", "")
count = getattr(self, f"{field_name}_count", "")
total = total or count or len(getattr(self, field_name))
return ListFieldPagination(
self.__class__.objects, self.pk, field_name, page, per_page, total=total
Expand Down
2 changes: 1 addition & 1 deletion flask_mongoengine/wtf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from flask_mongoengine.wtf.base import WtfBaseField # noqa
from flask_mongoengine.wtf.db_fields import * # noqa
from flask_mongoengine.wtf.orm import model_fields, model_form # noqa
40 changes: 0 additions & 40 deletions flask_mongoengine/wtf/base.py

This file was deleted.

0 comments on commit 6de8374

Please sign in to comment.