Skip to content

Commit

Permalink
Merge pull request #429 from MongoEngine/rewrite-connection-module
Browse files Browse the repository at this point in the history
Rewrite connection module
  • Loading branch information
insspb committed Jul 8, 2022
2 parents 89eb76e + fb83fd6 commit 6b18011
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 135 deletions.
4 changes: 2 additions & 2 deletions flask_mongoengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def current_mongoengine_instance():
return k


class MongoEngine(object):
class MongoEngine:
"""Main class used for initialization of Flask-MongoEngine."""

def __init__(self, app=None, config=None):
Expand Down Expand Up @@ -110,7 +110,7 @@ def init_app(self, app, config=None):
app.extensions["mongoengine"][self] = s

@property
def connection(self):
def connection(self) -> dict:
"""
Return MongoDB connection(s) associated with this MongoEngine
instance.
Expand Down
179 changes: 79 additions & 100 deletions flask_mongoengine/connection.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,117 @@
from typing import List

import mongoengine
from pymongo import ReadPreference, uri_parser

__all__ = (
"create_connections",
"get_connection_settings",
"InvalidSettingsError",
)


MONGODB_CONF_VARS = (
"MONGODB_ALIAS",
"MONGODB_DB",
"MONGODB_HOST",
"MONGODB_IS_MOCK",
"MONGODB_PASSWORD",
"MONGODB_PORT",
"MONGODB_USERNAME",
"MONGODB_CONNECT",
"MONGODB_TZ_AWARE",
)

def _get_name(setting_name: str) -> str:
"""
Return known pymongo setting name, or lower case name for unknown.
class InvalidSettingsError(Exception):
pass
This problem discovered in issue #451. As mentioned there pymongo settings are not
case-sensitive, but mongoengine use exact name of some settings for matching,
overwriting pymongo behaviour.
This function address this issue, and potentially address cases when pymongo will
become case-sensitive in some settings by same reasons as mongoengine done.
def _sanitize_settings(settings):
"""Given a dict of connection settings, sanitize the keys and fall
back to some sane defaults.
Based on pymongo 4.1.1 settings.
"""
# Remove the "MONGODB_" prefix and make all settings keys lower case.
KNOWN_CAMEL_CASE_SETTINGS = {
"directconnection": "directConnection",
"maxpoolsize": "maxPoolSize",
"minpoolsize": "minPoolSize",
"maxidletimems": "maxIdleTimeMS",
"maxconnecting": "maxConnecting",
"sockettimeoutms": "socketTimeoutMS",
"connecttimeoutms": "connectTimeoutMS",
"serverselectiontimeoutms": "serverSelectionTimeoutMS",
"waitqueuetimeoutms": "waitQueueTimeoutMS",
"heartbeatfrequencyms": "heartbeatFrequencyMS",
"retrywrites": "retryWrites",
"retryreads": "retryReads",
"zlibcompressionlevel": "zlibCompressionLevel",
"uuidrepresentation": "uuidRepresentation",
"srvservicename": "srvServiceName",
"wtimeoutms": "wTimeoutMS",
"replicaset": "replicaSet",
"readpreference": "readPreference",
"readpreferencetags": "readPreferenceTags",
"maxstalenessseconds": "maxStalenessSeconds",
"authsource": "authSource",
"authmechanism": "authMechanism",
"authmechanismproperties": "authMechanismProperties",
"tlsinsecure": "tlsInsecure",
"tlsallowinvalidcertificates": "tlsAllowInvalidCertificates",
"tlsallowinvalidhostnames": "tlsAllowInvalidHostnames",
"tlscafile": "tlsCAFile",
"tlscertificatekeyfile": "tlsCertificateKeyFile",
"tlscrlfile": "tlsCRLFile",
"tlscertificatekeyfilepassword": "tlsCertificateKeyFilePassword",
"tlsdisableocspendpointcheck": "tlsDisableOCSPEndpointCheck",
"readconcernlevel": "readConcernLevel",
}
_setting_name = KNOWN_CAMEL_CASE_SETTINGS.get(setting_name.lower())
return setting_name.lower() if _setting_name is None else _setting_name


def _sanitize_settings(settings: dict) -> dict:
"""Remove MONGODB_ prefix from dict values, to correct bypass to mongoengine."""
resolved_settings = {}
for k, v in settings.items():
if k.startswith("MONGODB_"):
k = k[len("MONGODB_") :]
k = k.lower()
resolved_settings[k] = v

# Handle uri style connections
if "://" in resolved_settings.get("host", ""):
# this section pulls the database name from the URI
# PyMongo requires URI to start with mongodb:// to parse
# this workaround allows mongomock to work
uri_to_check = resolved_settings["host"]

if uri_to_check.startswith("mongomock://"):
uri_to_check = uri_to_check.replace("mongomock://", "mongodb://")

uri_dict = uri_parser.parse_uri(uri_to_check)
resolved_settings["db"] = uri_dict["database"]

# Add a default name param or use the "db" key if exists
if resolved_settings.get("db"):
resolved_settings["name"] = resolved_settings.pop("db")
else:
resolved_settings["name"] = "test"

# Add various default values.
resolved_settings["alias"] = resolved_settings.get(
"alias", mongoengine.DEFAULT_CONNECTION_NAME
)
# TODO do we have to specify it here? MongoEngine should take care of that
resolved_settings["host"] = resolved_settings.get("host", "localhost")
# TODO this is the default host in pymongo.mongo_client.MongoClient, we may
# not need to explicitly set a default here
resolved_settings["port"] = resolved_settings.get("port", 27017)
# TODO this is the default port in pymongo.mongo_client.MongoClient, we may
# not need to explicitly set a default here

# Default to ReadPreference.PRIMARY if no read_preference is supplied
resolved_settings["read_preference"] = resolved_settings.get(
"read_preference", ReadPreference.PRIMARY
)

# Clean up empty values
for k, v in list(resolved_settings.items()):
if v is None:
del resolved_settings[k]
# Replace with k.lower().removeprefix("mongodb_") when python 3.8 support ends.
key = _get_name(k[8:]) if k.lower().startswith("mongodb_") else _get_name(k)
resolved_settings[key] = v

return resolved_settings


def get_connection_settings(config):
def get_connection_settings(config: dict) -> List[dict]:
"""
Given a config dict, return a sanitized dict of MongoDB connection
settings that we can then use to establish connections. For new
applications, settings should exist in a ``MONGODB_SETTINGS`` key, but
for backward compatibility we also support several config keys
prefixed by ``MONGODB_``, e.g. ``MONGODB_HOST``, ``MONGODB_PORT``, etc.
"""

# If no "MONGODB_SETTINGS", sanitize the "MONGODB_" keys as single connection.
if "MONGODB_SETTINGS" not in config:
config = {k: v for k, v in config.items() if k.lower().startswith("mongodb_")}
return [_sanitize_settings(config)]

# Sanitize all the settings living under a "MONGODB_SETTINGS" config var
if "MONGODB_SETTINGS" in config:
settings = config["MONGODB_SETTINGS"]
settings = config["MONGODB_SETTINGS"]

# If MONGODB_SETTINGS is a list of settings dicts, sanitize each
# dict separately.
if isinstance(settings, list):
return [_sanitize_settings(setting) for setting in settings]
else:
return _sanitize_settings(settings)
# If MONGODB_SETTINGS is a list of settings dicts, sanitize each dict separately.
if isinstance(settings, list):
return [_sanitize_settings(settings_dict) for settings_dict in settings]

else:
config = {k: v for k, v in config.items() if k in MONGODB_CONF_VARS}
return _sanitize_settings(config)
# Otherwise, it should be a single dict describing a single connection.
return [_sanitize_settings(settings)]


def create_connections(config):
def create_connections(config: dict):
"""
Given Flask application's config dict, extract relevant config vars
out of it and establish MongoEngine connection(s) based on them.
"""
# Validate that the config is a dict
if config is None or not isinstance(config, dict):
raise InvalidSettingsError("Invalid application configuration")
# Validate that the config is a dict and dict is not empty
if not config or not isinstance(config, dict):
raise TypeError(f"Config dictionary expected, but {type(config)} received.")

# Get sanitized connection settings based on the config
conn_settings = get_connection_settings(config)

# If conn_settings is a list, set up each item as a separate connection
# and return a dict of connection aliases and their connections.
if isinstance(conn_settings, list):
connections = {}
for each in conn_settings:
alias = each["alias"]
connections[alias] = _connect(each)
return connections
connection_settings = get_connection_settings(config)

# Otherwise, return a single connection
return _connect(conn_settings)
connections = {}
for connection_setting in connection_settings:
alias = connection_setting.setdefault(
"alias", mongoengine.DEFAULT_CONNECTION_NAME
)
connections[alias] = mongoengine.connect(**connection_setting)


def _connect(conn_settings):
"""Given a dict of connection settings, create a connection to
MongoDB by calling {func}`mongoengine.connect` and return its result.
"""
db_name = conn_settings.pop("name")
return mongoengine.connect(db_name, **conn_settings)
return connections
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,22 @@ def app():
def db(app):
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
test_db = MongoEngine(app)
db_name = test_db.connection.get_database("flask_mongoengine_test_db").name
db_name = (
test_db.connection["default"].get_database("flask_mongoengine_test_db").name
)

if not db_name.endswith("_test_db"):
raise RuntimeError(
f"DATABASE_URL must point to testing db, not to master db ({db_name})"
)

# Clear database before tests, for cases when some test failed before.
test_db.connection.drop_database(db_name)
test_db.connection["default"].drop_database(db_name)

yield test_db

# Clear database after tests, for graceful exit.
test_db.connection.drop_database(db_name)
test_db.connection["default"].drop_database(db_name)


@pytest.fixture()
Expand Down
44 changes: 17 additions & 27 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import mongoengine
import pymongo
import pytest
from mongoengine.connection import ConnectionFailure
from mongoengine.context_managers import switch_db
Expand All @@ -11,6 +10,14 @@
from flask_mongoengine import MongoEngine, current_mongoengine_instance


def is_mongo_mock_installed() -> bool:
try:
import mongomock.__version__ # noqa
except ImportError:
return False
return True


def test_connection__should_use_defaults__if_no_settings_provided(app):
"""Make sure a simple connection to a standalone MongoDB works."""
db = MongoEngine()
Expand Down Expand Up @@ -129,6 +136,9 @@ def test_connection__should_parse_host_uri__if_host_formatted_as_uri(
assert connection.PORT == 27017


@pytest.mark.skipif(
is_mongo_mock_installed(), reason="This test require mongomock not exist"
)
@pytest.mark.parametrize(
("config_extension"),
[
Expand Down Expand Up @@ -281,46 +291,26 @@ class Todo(db.Document):
assert doc is not None


def test_ignored_mongodb_prefix_config(app):
"""Config starting by MONGODB_ but not used by flask-mongoengine
should be ignored.
"""
def test_incorrect_value_with_mongodb_prefix__should_trigger_mongoengine_raise(app):
db = MongoEngine()
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
# Invalid host, should trigger exception if used
app.config["MONGODB_TEST_HOST"] = "dummy://localhost:27017/test"
db.init_app(app)

connection = mongoengine.get_connection()
mongo_engine_db = mongoengine.get_db()
assert isinstance(mongo_engine_db, Database)
assert isinstance(connection, MongoClient)
assert mongo_engine_db.name == "flask_mongoengine_test_db"
assert connection.HOST == "localhost"
assert connection.PORT == 27017
with pytest.raises(ConnectionFailure):
db.init_app(app)


def test_connection_kwargs(app):
"""Make sure additional connection kwargs work."""

# Figure out whether to use "MAX_POOL_SIZE" or "MAXPOOLSIZE" based
# on PyMongo version (former was changed to the latter as described
# in https://jira.mongodb.org/browse/PYTHON-854)
# TODO remove once PyMongo < 3.0 support is dropped
if pymongo.version_tuple[0] >= 3:
MAX_POOL_SIZE_KEY = "MAXPOOLSIZE"
else:
MAX_POOL_SIZE_KEY = "MAX_POOL_SIZE"

app.config["MONGODB_SETTINGS"] = {
"ALIAS": "tz_aware_true",
"DB": "flask_mongoengine_test_db",
"TZ_AWARE": True,
"READ_PREFERENCE": ReadPreference.SECONDARY,
MAX_POOL_SIZE_KEY: 10,
"MAXPOOLSIZE": 10,
}
db = MongoEngine(app)

assert db.connection.codec_options.tz_aware
# assert db.connection.max_pool_size == 10
assert db.connection.read_preference == ReadPreference.SECONDARY
assert db.connection["tz_aware_true"].codec_options.tz_aware
assert db.connection["tz_aware_true"].read_preference == ReadPreference.SECONDARY
8 changes: 5 additions & 3 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,22 @@ def extended_db(app):
app.json_encoder = DummyEncoder
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
test_db = MongoEngine(app)
db_name = test_db.connection.get_database("flask_mongoengine_test_db").name
db_name = (
test_db.connection["default"].get_database("flask_mongoengine_test_db").name
)

if not db_name.endswith("_test_db"):
raise RuntimeError(
f"DATABASE_URL must point to testing db, not to master db ({db_name})"
)

# Clear database before tests, for cases when some test failed before.
test_db.connection.drop_database(db_name)
test_db.connection["default"].drop_database(db_name)

yield test_db

# Clear database after tests, for graceful exit.
test_db.connection.drop_database(db_name)
test_db.connection["default"].drop_database(db_name)


class DummyEncoder(flask.json.JSONEncoder):
Expand Down

0 comments on commit 6b18011

Please sign in to comment.