-
Notifications
You must be signed in to change notification settings - Fork 256
/
__init__.py
202 lines (147 loc) · 6.42 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import inspect
from flask import abort, current_app
import mongoengine
from distutils.version import StrictVersion
from mongoengine.base.fields import BaseField
from mongoengine.queryset import MultipleObjectsReturned, DoesNotExist, QuerySet
from mongoengine.base import ValidationError
from pymongo import uri_parser
from .sessions import *
from .pagination import *
from .metadata import *
from .json import override_json_encoder
from .wtf import WtfBaseField
def _patch_base_field(object, name):
"""
If the object submitted has a class whose base class is
mongoengine.base.fields.BaseField, then monkey patch to
replace it with flask_mongoengine.wtf.WtfBaseField.
@note: WtfBaseField is an instance of BaseField - but
gives us the flexibility to extend field parameters
and settings required of WTForm via model form generator.
@see: flask_mongoengine.wtf.base.WtfBaseField.
@see: model_form in flask_mongoengine.wtf.orm
@param object: The object whose footprint to locate the class.
@param name: Name of the class to locate.
"""
# locate class
cls = getattr(object, name)
if not inspect.isclass(cls):
return
# fetch class base classes
cls_bases = list(cls.__bases__)
# replace BaseField with WtfBaseField
for index, base in enumerate(cls_bases):
if base == BaseField:
cls_bases[index] = WtfBaseField
cls.__bases__ = tuple(cls_bases)
break
# re-assign class back to
# object footprint
delattr(object, name)
setattr(object, name, cls)
def _include_mongoengine(obj):
    """
    Copy the public names of mongoengine onto ``obj``.

    Every name listed in ``__all__`` of ``mongoengine`` and of
    ``mongoengine.fields`` is set on ``obj`` unless ``obj`` already has an
    attribute of that name. Each newly copied name is then handed to
    _patch_base_field.

    @param obj: The object that receives the attributes.
    """
    for source in (mongoengine, mongoengine.fields):
        for attr_name in source.__all__:
            # Never overwrite something the object already provides.
            if hasattr(obj, attr_name):
                continue

            setattr(obj, attr_name, getattr(source, attr_name))

            # patch BaseField if available
            _patch_base_field(obj, attr_name)
def _create_connection(conn_settings):
# Handle multiple connections recursively
if isinstance(conn_settings, list):
connections = {}
for conn in conn_settings:
connections[conn.get('alias')] = _create_connection(conn)
return connections
# Ugly dict comprehention in order to support python 2.6
conn = dict((k.lower(), v) for k, v in conn_settings.items() if v is not None)
if 'replicaset' in conn:
conn['replicaSet'] = conn.pop('replicaset')
if (StrictVersion(mongoengine.__version__) >= StrictVersion('0.10.6') and
current_app.config['TESTING'] == True and
conn.get('host', '').startswith('mongomock://')):
pass
# Handle uri style connections
elif "://" in conn.get('host', ''):
uri_dict = uri_parser.parse_uri(conn['host'])
conn['db'] = uri_dict['database']
return mongoengine.connect(conn.pop('db', 'test'), **conn)
class MongoEngine(object):
    """
    Flask extension object tying a MongoDB connection to an application.

    On construction the public names of mongoengine are copied onto the
    instance, and ``Document`` / ``DynamicDocument`` are set to the
    classes of the same name defined in this module.
    """

    def __init__(self, app=None, config=None):
        """
        @param app: Optional application; when given, init_app is called
                    with it straight away.
        @param config: Optional configuration forwarded to init_app.
        """
        _include_mongoengine(self)

        self.Document = Document
        self.DynamicDocument = DynamicDocument

        if app is None:
            return
        self.init_app(app, config)

    def init_app(self, app, config=None):
        """
        Connect to MongoDB and register this extension on ``app``.

        @param app: The application to initialise.
        @param config: Optional mapping of settings; when empty or
                       omitted, ``app.config`` is used instead.
        @raise Exception: If this instance is already registered on
                          ``app``.
        """
        app.extensions = getattr(app, 'extensions', {})

        # Make documents JSON serializable
        override_json_encoder(app)

        registry = app.extensions.setdefault('mongoengine', {})

        if self in registry:
            # Initialising twice is refused because a potentially new
            # configuration would not be loaded.
            raise Exception('Extension already initialized')

        # Without an explicit config the connection settings come from
        # the app config.
        config = config or app.config

        if 'MONGODB_SETTINGS' in config:
            # Connection settings provided as a dictionary.
            settings = config['MONGODB_SETTINGS']
        else:
            # Connection settings provided in standard format.
            settings = {
                'alias': config.get('MONGODB_ALIAS', None),
                'db': config.get('MONGODB_DB', None),
                'host': config.get('MONGODB_HOST', None),
                'password': config.get('MONGODB_PASSWORD', None),
                'port': config.get('MONGODB_PORT', None),
                'username': config.get('MONGODB_USERNAME', None),
            }

        # Objects are kept on the application instance so that multiple
        # apps do not end up accessing the same objects.
        registry[self] = {'app': app,
                          'conn': _create_connection(settings)}

    @property
    def connection(self):
        """The connection stored for this extension on the current app."""
        state = current_app.extensions['mongoengine'][self]
        return state['conn']
class BaseQuerySet(QuerySet):
    """
    Queryset extended with 404 shortcuts and pagination helpers.
    """

    def get_or_404(self, *args, **kwargs):
        """
        Return ``self.get(*args, **kwargs)``, aborting with a 404 when it
        raises MultipleObjectsReturned, DoesNotExist or ValidationError.
        """
        try:
            found = self.get(*args, **kwargs)
        except (MultipleObjectsReturned, DoesNotExist, ValidationError):
            abort(404)
        else:
            return found

    def first_or_404(self):
        """Return ``self.first()``, aborting with a 404 when it is None."""
        result = self.first()
        if result is not None:
            return result
        abort(404)

    def paginate(self, page, per_page, error_out=True):
        """
        Wrap this queryset in a Pagination object.

        @param page: Forwarded to Pagination.
        @param per_page: Forwarded to Pagination.
        @param error_out: Accepted but not used by this method.
        """
        return Pagination(self, page, per_page)

    def paginate_field(self, field_name, doc_id, page, per_page,
                       total=None):
        """
        Paginate the field ``field_name`` of the document with id
        ``doc_id``.

        The total is the first truthy value among: the ``total``
        argument, the document's ``<field_name>_count`` attribute (when
        present), and the length of the field itself.
        """
        document = self.get(id=doc_id)
        stored_count = getattr(document, field_name + "_count", '')
        if not total:
            total = stored_count or len(getattr(document, field_name))
        return ListFieldPagination(self, doc_id, field_name, page, per_page,
                                   total=total)
class Document(mongoengine.Document):
    """Abstract document whose queryset class is BaseQuerySet."""

    meta = {
        'abstract': True,
        'queryset_class': BaseQuerySet,
    }

    def paginate_field(self, field_name, page, per_page, total=None):
        """
        Paginate the field ``field_name`` of this document.

        The total is the first truthy value among: the ``total``
        argument, this document's ``<field_name>_count`` attribute (when
        present), and the length of the field itself.
        """
        stored_count = getattr(self, field_name + "_count", '')
        if not total:
            total = stored_count or len(getattr(self, field_name))
        queryset = self.__class__.objects
        return ListFieldPagination(queryset, self.pk, field_name,
                                   page, per_page, total=total)
class DynamicDocument(mongoengine.DynamicDocument):
    """Abstract dynamic document whose queryset class is BaseQuerySet."""

    meta = {
        'abstract': True,
        'queryset_class': BaseQuerySet,
    }