Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions flask_mongoengine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import

from flask import abort
from flask import abort, current_app

import mongoengine

Expand All @@ -23,6 +23,14 @@ def _include_mongoengine(obj):


def _create_connection(conn_settings):

# Handle multiple connections recursively
if isinstance(conn_settings, list):
connections = {}
for conn in conn_settings:
connections[conn.get('alias')] = _create_connection(conn)
return connections

conn = dict([(k.lower(), v) for k, v in conn_settings.items() if v])

if 'replicaset' in conn:
Expand All @@ -38,40 +46,60 @@ def _create_connection(conn_settings):

class MongoEngine(object):

def __init__(self, app=None):
def __init__(self, app=None, config=None):

_include_mongoengine(self)

self.Document = Document
self.DynamicDocument = DynamicDocument

if app is not None:
self.init_app(app)
self.init_app(app, config)

def init_app(self, app):
def init_app(self, app, config=None):

conn_settings = app.config.get('MONGODB_SETTINGS', None)
app.extensions = getattr(app, 'extensions', {})

if not conn_settings:
conn_settings = {
'db': app.config.get('MONGODB_DB', None),
'username': app.config.get('MONGODB_USERNAME', None),
'password': app.config.get('MONGODB_PASSWORD', None),
'host': app.config.get('MONGODB_HOST', None),
'port': int(app.config.get('MONGODB_PORT', 0)) or None
}
# Make documents JSON serializable
overide_json_encoder(app)

if isinstance(conn_settings, list):
self.connection = {}
for conn in conn_settings:
self.connection[conn.get('alias')] = _create_connection(conn)
else:
self.connection = _create_connection(conn_settings)
if not 'mongoengine' in app.extensions:
app.extensions['mongoengine'] = {}

app.extensions = getattr(app, 'extensions', {})
app.extensions['mongoengine'] = self
self.app = app
overide_json_encoder(app)
if self in app.extensions['mongoengine']:
# Raise an exception if extension already initialized as
# potentially new configuration would not be loaded.
raise Exception('Extension already initialized')

if config:
# If passed an explicit config then we must make sure to ignore
# anything set in the application config.
connection = _create_connection(config)
else:
# Set default config
config = {}
config.setdefault('db', app.config.get('MONGODB_DB', None))
config.setdefault('host', app.config.get('MONGODB_HOST', None))
config.setdefault('port', app.config.get('MONGODB_PORT', None))
config.setdefault('username',
app.config.get('MONGODB_USERNAME', None))
config.setdefault('password',
app.config.get('MONGODB_PASSWORD', None))

# Before using default config we check for MONGODB_SETTINGS
if 'MONGODB_SETTINGS' in app.config:
connection = _create_connection(app.config['MONGODB_SETTINGS'])
else:
connection = _create_connection(config)

# Store objects in application instance so that multiple apps do
# not end up accessing the same objects.
app.extensions['mongoengine'] = {self: {'app': app,
'conn': connection}}

@property
def connection(self):
return current_app.extensions['mongoengine'][self]['conn']


class BaseQuerySet(QuerySet):
Expand Down
15 changes: 15 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import flask
import unittest

class FlaskMongoEngineTestCase(unittest.TestCase):
"""Parent class of all test cases"""

def setUp(self):
self.app = flask.Flask(__name__)
self.app.config['MONGODB_DB'] = 'testing'
self.app.config['TESTING'] = True
self.ctx = self.app.app_context()
self.ctx.push()

def tearDown(self):
self.ctx.pop()
54 changes: 26 additions & 28 deletions tests/test_basic_app.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import sys
sys.path[0:0] = [""]

import unittest
import datetime
import flask

from flask.ext.mongoengine import MongoEngine
from . import FlaskMongoEngineTestCase


class BasicAppTestCase(unittest.TestCase):
class BasicAppTestCase(FlaskMongoEngineTestCase):

def setUp(self):
app = flask.Flask(__name__)
app.config['MONGODB_DB'] = 'testing'
app.config['TESTING'] = True
super(BasicAppTestCase, self).setUp()
db = MongoEngine()

class Todo(db.Document):
Expand All @@ -22,53 +20,54 @@ class Todo(db.Document):
done = db.BooleanField(default=False)
pub_date = db.DateTimeField(default=datetime.datetime.now)

db.init_app(app)
db.init_app(self.app)

Todo.drop_collection()
self.Todo = Todo

@app.route('/')
@self.app.route('/')
def index():
return '\n'.join(x.title for x in self.Todo.objects)

@app.route('/add', methods=['POST'])
@self.app.route('/add', methods=['POST'])
def add():
form = flask.request.form
todo = self.Todo(title=form['title'],
text=form['text'])
todo.save()
return 'added'

@app.route('/show/<id>/')
@self.app.route('/show/<id>/')
def show(id):
todo = self.Todo.objects.get_or_404(id=id)
return '\n'.join([todo.title, todo.text])

self.app = app
self.db = db

def test_connection_kwargs(self):
app = flask.Flask(__name__)
app.config['MONGODB_SETTINGS'] = {
self.app.config['MONGODB_SETTINGS'] = {
'DB': 'testing_tz_aware',
'alias': 'tz_aware_true',
'ALIAS': 'tz_aware_true',
'TZ_AWARE': True
}
app.config['TESTING'] = True
self.app.config['TESTING'] = True
db = MongoEngine()
db.init_app(app)
db.init_app(self.app)
self.assertTrue(db.connection.tz_aware)

app.config['MONGODB_SETTINGS'] = {
# PyMongo defaults to tz_aware = True so we have to explicitly turn
# it off.
self.app.config['MONGODB_SETTINGS'] = {
'DB': 'testing',
'alias': 'tz_aware_false',
'ALIAS': 'tz_aware_false',
'TZ_AWARE': False
}
db.init_app(app)
db = MongoEngine()
db.init_app(self.app)
self.assertFalse(db.connection.tz_aware)

def test_connection_kwargs_as_list(self):
app = flask.Flask(__name__)
app.config['MONGODB_SETTINGS'] = [{
self.app.config['MONGODB_SETTINGS'] = [{
'DB': 'testing_tz_aware',
'alias': 'tz_aware_true',
'TZ_AWARE': True
Expand All @@ -77,23 +76,22 @@ def test_connection_kwargs_as_list(self):
'alias': 'tz_aware_false',
'TZ_AWARE': False
}]
app.config['TESTING'] = True
self.app.config['TESTING'] = True
db = MongoEngine()
db.init_app(app)
db.init_app(self.app)
self.assertTrue(db.connection['tz_aware_true'].tz_aware)
self.assertFalse(db.connection['tz_aware_false'].tz_aware)

def test_connection_default(self):
app = flask.Flask(__name__)
app.config['MONGODB_SETTINGS'] = {}
app.config['TESTING'] = True
self.app.config['MONGODB_SETTINGS'] = {}
self.app.config['TESTING'] = True

db = MongoEngine()
db.init_app(app)
db.init_app(self.app)

app.config['TESTING'] = True
self.app.config['TESTING'] = True
db = MongoEngine()
db.init_app(app)
db.init_app(self.app)

def test_with_id(self):
c = self.app.test_client()
Expand Down
17 changes: 8 additions & 9 deletions tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,22 @@
from flask.ext.mongoengine.wtf import model_form

from mongoengine import queryset_manager
from . import FlaskMongoEngineTestCase


class WTFormsAppTestCase(unittest.TestCase):
class WTFormsAppTestCase(FlaskMongoEngineTestCase):

def setUp(self):
super(WTFormsAppTestCase, self).setUp()
self.db_name = 'testing'

app = flask.Flask(__name__)
app.config['MONGODB_DB'] = self.db_name
app.config['TESTING'] = True
self.app.config['MONGODB_DB'] = self.db_name
self.app.config['TESTING'] = True
# For Flask-WTF < 0.9
app.config['CSRF_ENABLED'] = False
self.app.config['CSRF_ENABLED'] = False
# For Flask-WTF >= 0.9
app.config['WTF_CSRF_ENABLED'] = False
self.app = app
self.app.config['WTF_CSRF_ENABLED'] = False
self.db = MongoEngine()
self.db.init_app(app)
self.db.init_app(self.app)

def tearDown(self):
self.db.connection.drop_database(self.db_name)
Expand Down
15 changes: 7 additions & 8 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from flask.ext.mongoengine import MongoEngine
from flask.ext.mongoengine.json import MongoEngineJSONEncoder
from . import FlaskMongoEngineTestCase


class DummyEncoder(flask.json.JSONEncoder):
Expand All @@ -17,7 +18,7 @@ class DummyEncoder(flask.json.JSONEncoder):
'''


class JSONAppTestCase(unittest.TestCase):
class JSONAppTestCase(FlaskMongoEngineTestCase):

def dictContains(self,superset,subset):
for k,v in subset.items():
Expand All @@ -29,14 +30,12 @@ def assertDictContains(self,superset,subset):
return self.assertTrue(self.dictContains(superset,subset))

def setUp(self):
app = flask.Flask(__name__)
app.config['MONGODB_DB'] = 'testing'
app.config['TESTING'] = True
app.json_encoder = DummyEncoder
super(JSONAppTestCase, self).setUp()
self.app.config['MONGODB_DB'] = 'testing'
self.app.config['TESTING'] = True
self.app.json_encoder = DummyEncoder
db = MongoEngine()
db.init_app(app)

self.app = app
db.init_app(self.app)
self.db = db

def test_inheritance(self):
Expand Down
Loading