Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion flask_mongoengine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import inspect

from flask import abort, current_app

Expand All @@ -13,6 +14,45 @@
from .sessions import *
from .pagination import *
from .json import overide_json_encoder
from .wtf import WtfBaseField


def _patch_base_field(object, name):
"""
If the object submitted has a class whose base class is
mongoengine.base.fields.BaseField, then monkey patch to
replace it with flask_mongoengine.wtf.WtfBaseField.

@note: WtfBaseField is an instance of BaseField - but
gives us the flexibility to extend field parameters
and settings required of WTForm via model form generator.

@see: flask_mongoengine.wtf.base.WtfBaseField.
@see: model_form in flask_mongoengine.wtf.orm

@param object: The object whose footprint to locate the class.
@param name: Name of the class to locate.

"""
# locate class
cls = getattr(object, name)
if not inspect.isclass(cls):
return

# fetch class base classes
cls_bases = list(cls.__bases__)

# replace BaseField with WtfBaseField
for index, base in enumerate(cls_bases):
if base == mongoengine.base.fields.BaseField:
cls_bases[index] = WtfBaseField
cls.__bases__ = tuple(cls_bases)
break

# re-assign class back to
# object footprint
delattr(object, name)
setattr(object, name, cls)


def _include_mongoengine(obj):
Expand All @@ -21,6 +61,9 @@ def _include_mongoengine(obj):
if not hasattr(obj, key):
setattr(obj, key, getattr(module, key))

# patch BaseField if available
_patch_base_field(obj, key)


def _create_connection(conn_settings):

Expand All @@ -44,10 +87,10 @@ def _create_connection(conn_settings):
return mongoengine.connect(conn.pop('db', 'test'), **conn)



class MongoEngine(object):

def __init__(self, app=None, config=None):

_include_mongoengine(self)

self.Document = Document
Expand Down
1 change: 1 addition & 0 deletions flask_mongoengine/wtf/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from flask.ext.mongoengine.wtf.orm import model_fields, model_form
from flask.ext.mongoengine.wtf.base import WtfBaseField
41 changes: 41 additions & 0 deletions flask_mongoengine/wtf/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from mongoengine.base import BaseField

__all__ = ('WtfBaseField')

class WtfBaseField(BaseField):
"""
Extension wrapper class for mongoengine BaseField.

This enables flask-mongoengine wtf to extend the
number of field parameters, and settings on behalf
of document model form generator for WTForm.

@param validators: wtf model form field validators.
@param filters: wtf model form field filters.
"""

def __init__(self, validators=None, filters=None, **kwargs):

self.validators =\
self._ensure_callable_or_list(validators, 'validators')
self.filters = self._ensure_callable_or_list(filters, 'filters')

BaseField.__init__(self, **kwargs)


def _ensure_callable_or_list(self, field, msg_flag):
"""
Ensure the value submitted via field is either
a callable object to convert to list or it is
in fact a valid list value.

"""
if field is not None:
if callable(field):
field = [field]
else:
msg = "Argument '%s' must be a list value" % msg_flag
if not isinstance(field, list):
raise TypeError(msg)

return field
4 changes: 2 additions & 2 deletions flask_mongoengine/wtf/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def convert(self, model, field, field_args):
kwargs = {
'label': getattr(field, 'verbose_name', field.name),
'description': field.help_text or '',
'validators': [],
'filters': [],
'validators': [] if not field.validators else field.validators,
'filters': [] if not field.filters else field.filters,
'default': field.default,
}
if field_args:
Expand Down