diff --git a/flask_mongoengine/__init__.py b/flask_mongoengine/__init__.py index 1d1167cd..47683937 100644 --- a/flask_mongoengine/__init__.py +++ b/flask_mongoengine/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import -from flask import abort +from flask import abort, current_app import mongoengine @@ -23,6 +23,14 @@ def _include_mongoengine(obj): def _create_connection(conn_settings): + + # Handle multiple connections recursively + if isinstance(conn_settings, list): + connections = {} + for conn in conn_settings: + connections[conn.get('alias')] = _create_connection(conn) + return connections + conn = dict([(k.lower(), v) for k, v in conn_settings.items() if v]) if 'replicaset' in conn: @@ -38,7 +46,7 @@ def _create_connection(conn_settings): class MongoEngine(object): - def __init__(self, app=None): + def __init__(self, app=None, config=None): _include_mongoengine(self) @@ -46,32 +54,52 @@ def __init__(self, app=None): self.DynamicDocument = DynamicDocument if app is not None: - self.init_app(app) + self.init_app(app, config) - def init_app(self, app): + def init_app(self, app, config=None): - conn_settings = app.config.get('MONGODB_SETTINGS', None) + app.extensions = getattr(app, 'extensions', {}) - if not conn_settings: - conn_settings = { - 'db': app.config.get('MONGODB_DB', None), - 'username': app.config.get('MONGODB_USERNAME', None), - 'password': app.config.get('MONGODB_PASSWORD', None), - 'host': app.config.get('MONGODB_HOST', None), - 'port': int(app.config.get('MONGODB_PORT', 0)) or None - } + # Make documents JSON serializable + overide_json_encoder(app) - if isinstance(conn_settings, list): - self.connection = {} - for conn in conn_settings: - self.connection[conn.get('alias')] = _create_connection(conn) - else: - self.connection = _create_connection(conn_settings) + if not 'mongoengine' in app.extensions: + app.extensions['mongoengine'] = {} - app.extensions = getattr(app, 'extensions', {}) - app.extensions['mongoengine'] = self - self.app = app - overide_json_encoder(app) + if self in app.extensions['mongoengine']: + # Raise an exception if extension already initialized as + # potentially new configuration would not be loaded. + raise Exception('Extension already initialized') + + if config: + # If passed an explicit config then we must make sure to ignore + # anything set in the application config. + connection = _create_connection(config) + else: + # Set default config + config = {} + config.setdefault('db', app.config.get('MONGODB_DB', None)) + config.setdefault('host', app.config.get('MONGODB_HOST', None)) + config.setdefault('port', app.config.get('MONGODB_PORT', None)) + config.setdefault('username', + app.config.get('MONGODB_USERNAME', None)) + config.setdefault('password', + app.config.get('MONGODB_PASSWORD', None)) + + # Before using default config we check for MONGODB_SETTINGS + if 'MONGODB_SETTINGS' in app.config: + connection = _create_connection(app.config['MONGODB_SETTINGS']) + else: + connection = _create_connection(config) + + # Store objects in application instance so that multiple apps do + # not end up accessing the same objects. + app.extensions['mongoengine'] = {self: {'app': app, + 'conn': connection}} + + @property + def connection(self): + return current_app.extensions['mongoengine'][self]['conn'] class BaseQuerySet(QuerySet): diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..105e493f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,15 @@ +import flask +import unittest + +class FlaskMongoEngineTestCase(unittest.TestCase): + """Parent class of all test cases""" + + def setUp(self): + self.app = flask.Flask(__name__) + self.app.config['MONGODB_DB'] = 'testing' + self.app.config['TESTING'] = True + self.ctx = self.app.app_context() + self.ctx.push() + + def tearDown(self): + self.ctx.pop() diff --git a/tests/test_basic_app.py b/tests/test_basic_app.py index 42241edf..40111a2d 100644 --- a/tests/test_basic_app.py +++ b/tests/test_basic_app.py @@ -1,19 +1,17 @@ import sys -sys.path[0:0] = [""] import unittest import datetime import flask from flask.ext.mongoengine import MongoEngine +from . import FlaskMongoEngineTestCase -class BasicAppTestCase(unittest.TestCase): +class BasicAppTestCase(FlaskMongoEngineTestCase): def setUp(self): - app = flask.Flask(__name__) - app.config['MONGODB_DB'] = 'testing' - app.config['TESTING'] = True + super(BasicAppTestCase, self).setUp() db = MongoEngine() class Todo(db.Document): @@ -22,16 +20,16 @@ class Todo(db.Document): done = db.BooleanField(default=False) pub_date = db.DateTimeField(default=datetime.datetime.now) - db.init_app(app) + db.init_app(self.app) Todo.drop_collection() self.Todo = Todo - @app.route('/') + @self.app.route('/') def index(): return '\n'.join(x.title for x in self.Todo.objects) - @app.route('/add', methods=['POST']) + @self.app.route('/add', methods=['POST']) def add(): form = flask.request.form todo = self.Todo(title=form['title'], @@ -39,36 +37,37 @@ def add(): todo.save() return 'added' - @app.route('/show//') + @self.app.route('/show//') def show(id): todo = self.Todo.objects.get_or_404(id=id) return '\n'.join([todo.title, todo.text]) - self.app = app self.db = db def test_connection_kwargs(self): - app = flask.Flask(__name__) - app.config['MONGODB_SETTINGS'] = { + self.app.config['MONGODB_SETTINGS'] = { 'DB': 'testing_tz_aware', - 'alias': 'tz_aware_true', + 'ALIAS': 'tz_aware_true', 'TZ_AWARE': True } - app.config['TESTING'] = True + self.app.config['TESTING'] = True db = MongoEngine() - db.init_app(app) + db.init_app(self.app) self.assertTrue(db.connection.tz_aware) - app.config['MONGODB_SETTINGS'] = { + # PyMongo defaults to tz_aware = True so we have to explicitly turn + # it off. + self.app.config['MONGODB_SETTINGS'] = { 'DB': 'testing', - 'alias': 'tz_aware_false', + 'ALIAS': 'tz_aware_false', + 'TZ_AWARE': False } - db.init_app(app) + db = MongoEngine() + db.init_app(self.app) self.assertFalse(db.connection.tz_aware) def test_connection_kwargs_as_list(self): - app = flask.Flask(__name__) - app.config['MONGODB_SETTINGS'] = [{ + self.app.config['MONGODB_SETTINGS'] = [{ 'DB': 'testing_tz_aware', 'alias': 'tz_aware_true', 'TZ_AWARE': True @@ -77,23 +76,22 @@ def test_connection_kwargs_as_list(self): 'alias': 'tz_aware_false', 'TZ_AWARE': False }] - app.config['TESTING'] = True + self.app.config['TESTING'] = True db = MongoEngine() - db.init_app(app) + db.init_app(self.app) self.assertTrue(db.connection['tz_aware_true'].tz_aware) self.assertFalse(db.connection['tz_aware_false'].tz_aware) def test_connection_default(self): - app = flask.Flask(__name__) - app.config['MONGODB_SETTINGS'] = {} - app.config['TESTING'] = True + self.app.config['MONGODB_SETTINGS'] = {} + self.app.config['TESTING'] = True db = MongoEngine() - db.init_app(app) + db.init_app(self.app) - app.config['TESTING'] = True + self.app.config['TESTING'] = True db = MongoEngine() - db.init_app(app) + db.init_app(self.app) def test_with_id(self): c = self.app.test_client() diff --git a/tests/test_forms.py b/tests/test_forms.py index a34bc92d..1865f8be 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -14,23 +14,22 @@ from flask.ext.mongoengine.wtf import model_form from mongoengine import queryset_manager +from . import FlaskMongoEngineTestCase -class WTFormsAppTestCase(unittest.TestCase): +class WTFormsAppTestCase(FlaskMongoEngineTestCase): def setUp(self): + super(WTFormsAppTestCase, self).setUp() self.db_name = 'testing' - - app = flask.Flask(__name__) - app.config['MONGODB_DB'] = self.db_name - app.config['TESTING'] = True + self.app.config['MONGODB_DB'] = self.db_name + self.app.config['TESTING'] = True # For Flask-WTF < 0.9 - app.config['CSRF_ENABLED'] = False + self.app.config['CSRF_ENABLED'] = False # For Flask-WTF >= 0.9 - app.config['WTF_CSRF_ENABLED'] = False - self.app = app + self.app.config['WTF_CSRF_ENABLED'] = False self.db = MongoEngine() - self.db.init_app(app) + self.db.init_app(self.app) def tearDown(self): self.db.connection.drop_database(self.db_name) diff --git a/tests/test_json.py b/tests/test_json.py index 7210d85a..a1d68e4f 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -7,6 +7,7 @@ from flask.ext.mongoengine import MongoEngine from flask.ext.mongoengine.json import MongoEngineJSONEncoder +from . import FlaskMongoEngineTestCase class DummyEncoder(flask.json.JSONEncoder): @@ -17,7 +18,7 @@ class DummyEncoder(flask.json.JSONEncoder): ''' -class JSONAppTestCase(unittest.TestCase): +class JSONAppTestCase(FlaskMongoEngineTestCase): def dictContains(self,superset,subset): for k,v in subset.items(): @@ -29,14 +30,12 @@ def assertDictContains(self,superset,subset): return self.assertTrue(self.dictContains(superset,subset)) def setUp(self): - app = flask.Flask(__name__) - app.config['MONGODB_DB'] = 'testing' - app.config['TESTING'] = True - app.json_encoder = DummyEncoder + super(JSONAppTestCase, self).setUp() + self.app.config['MONGODB_DB'] = 'testing' + self.app.config['TESTING'] = True + self.app.json_encoder = DummyEncoder db = MongoEngine() - db.init_app(app) - - self.app = app + db.init_app(self.app) self.db = db def test_inheritance(self): diff --git a/tests/test_json_app.py b/tests/test_json_app.py index 58063a1d..aad3a89e 100644 --- a/tests/test_json_app.py +++ b/tests/test_json_app.py @@ -6,8 +6,9 @@ import flask from flask.ext.mongoengine import MongoEngine +from . import FlaskMongoEngineTestCase -class JSONAppTestCase(unittest.TestCase): +class JSONAppTestCase(FlaskMongoEngineTestCase): def dictContains(self,superset,subset): for k,v in subset.items(): @@ -19,9 +20,9 @@ def assertDictContains(self,superset,subset): return self.assertTrue(self.dictContains(superset,subset)) def setUp(self): - app = flask.Flask(__name__) - app.config['MONGODB_DB'] = 'testing' - app.config['TESTING'] = True + super(JSONAppTestCase, self).setUp() + self.app.config['MONGODB_DB'] = 'testing' + self.app.config['TESTING'] = True db = MongoEngine() class Todo(db.Document): @@ -30,16 +31,16 @@ class Todo(db.Document): done = db.BooleanField(default=False) pub_date = db.DateTimeField(default=datetime.datetime.now) - db.init_app(app) + db.init_app(self.app) Todo.drop_collection() self.Todo = Todo - @app.route('/') + @self.app.route('/') def index(): return flask.jsonify(result=self.Todo.objects()) - @app.route('/add', methods=['POST']) + @self.app.route('/add', methods=['POST']) def add(): form = flask.request.form todo = self.Todo(title=form['title'], @@ -47,32 +48,30 @@ def add(): todo.save() return flask.jsonify(result=todo) - @app.route('/show//') + @self.app.route('/show//') def show(id): return flask.jsonify(result=self.Todo.objects.get_or_404(id=id)) - - self.app = app self.db = db def test_connection_kwargs(self): - app = flask.Flask(__name__) - app.config['MONGODB_SETTINGS'] = { + self.app.config['MONGODB_SETTINGS'] = { 'DB': 'testing_tz_aware', - 'alias': 'tz_aware_true', + 'ALIAS': 'tz_aware_true', 'TZ_AWARE': True, } - app.config['TESTING'] = True + self.app.config['TESTING'] = True db = MongoEngine() - db.init_app(app) + db.init_app(self.app) self.assertTrue(db.connection.tz_aware) - app.config['MONGODB_SETTINGS'] = { + db = MongoEngine() + self.app.config['MONGODB_SETTINGS'] = { 'DB': 'testing', 'alias': 'tz_aware_false', } - db.init_app(app) + db.init_app(self.app) self.assertFalse(db.connection.tz_aware) def test_with_id(self): diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 25337601..0402e817 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -6,20 +6,19 @@ from werkzeug.exceptions import NotFound from flask.ext.mongoengine import MongoEngine, Pagination, ListFieldPagination +from . import FlaskMongoEngineTestCase -class PaginationTestCase(unittest.TestCase): +class PaginationTestCase(FlaskMongoEngineTestCase): def setUp(self): + super(PaginationTestCase, self).setUp() self.db_name = 'testing' - - app = flask.Flask(__name__) - app.config['MONGODB_DB'] = self.db_name - app.config['TESTING'] = True - app.config['CSRF_ENABLED'] = False - self.app = app + self.app.config['MONGODB_DB'] = self.db_name + self.app.config['TESTING'] = True + self.app.config['CSRF_ENABLED'] = False self.db = MongoEngine() - self.db.init_app(app) + self.db.init_app(self.app) def tearDown(self): self.db.connection.drop_database(self.db_name) diff --git a/tests/test_session.py b/tests/test_session.py index a3361763..60a78c38 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,34 +6,33 @@ from flask import session from flask.ext.mongoengine import MongoEngine, MongoEngineSessionInterface +from . import FlaskMongoEngineTestCase -class BasicAppTestCase(unittest.TestCase): +class SessionTestCase(FlaskMongoEngineTestCase): def setUp(self): + super(SessionTestCase, self).setUp() self.db_name = 'testing' + self.app.config['MONGODB_DB'] = self.db_name + self.app.config['TESTING'] = True + db = MongoEngine(self.app) + self.app.session_interface = MongoEngineSessionInterface(db) - app = flask.Flask(__name__) - app.config['MONGODB_DB'] = self.db_name - app.config['TESTING'] = True - db = MongoEngine(app) - app.session_interface = MongoEngineSessionInterface(db) - - @app.route('/') + @self.app.route('/') def index(): session["a"] = "hello session" return session["a"] - @app.route('/check-session') + @self.app.route('/check-session') def check_session(): return "session: %s" % session["a"] - @app.route('/check-session-database') + @self.app.route('/check-session-database') def check_session_database(): sessions = self.app.session_interface.cls.objects.count() return "sessions: %s" % sessions - self.app = app self.db = db def tearDown(self):