Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite connection module #429

Merged
merged 6 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions flask_mongoengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def current_mongoengine_instance():
return k


class MongoEngine(object):
class MongoEngine:
"""Main class used for initialization of Flask-MongoEngine."""

def __init__(self, app=None, config=None):
Expand Down Expand Up @@ -110,7 +110,7 @@ def init_app(self, app, config=None):
app.extensions["mongoengine"][self] = s

@property
def connection(self):
def connection(self) -> dict:
"""
Return MongoDB connection(s) associated with this MongoEngine
instance.
Expand Down
179 changes: 79 additions & 100 deletions flask_mongoengine/connection.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,117 @@
from typing import List

import mongoengine
from pymongo import ReadPreference, uri_parser

__all__ = (
"create_connections",
"get_connection_settings",
"InvalidSettingsError",
)


MONGODB_CONF_VARS = (
"MONGODB_ALIAS",
"MONGODB_DB",
"MONGODB_HOST",
"MONGODB_IS_MOCK",
"MONGODB_PASSWORD",
"MONGODB_PORT",
"MONGODB_USERNAME",
"MONGODB_CONNECT",
"MONGODB_TZ_AWARE",
)

def _get_name(setting_name: str) -> str:
"""
Return known pymongo setting name, or lower case name for unknown.

class InvalidSettingsError(Exception):
pass
This problem discovered in issue #451. As mentioned there pymongo settings are not
case-sensitive, but mongoengine use exact name of some settings for matching,
overwriting pymongo behaviour.

This function address this issue, and potentially address cases when pymongo will
become case-sensitive in some settings by same reasons as mongoengine done.

def _sanitize_settings(settings):
"""Given a dict of connection settings, sanitize the keys and fall
back to some sane defaults.
Based on pymongo 4.1.1 settings.
"""
# Remove the "MONGODB_" prefix and make all settings keys lower case.
KNOWN_CAMEL_CASE_SETTINGS = {
"directconnection": "directConnection",
"maxpoolsize": "maxPoolSize",
"minpoolsize": "minPoolSize",
"maxidletimems": "maxIdleTimeMS",
"maxconnecting": "maxConnecting",
"sockettimeoutms": "socketTimeoutMS",
"connecttimeoutms": "connectTimeoutMS",
"serverselectiontimeoutms": "serverSelectionTimeoutMS",
"waitqueuetimeoutms": "waitQueueTimeoutMS",
"heartbeatfrequencyms": "heartbeatFrequencyMS",
"retrywrites": "retryWrites",
"retryreads": "retryReads",
"zlibcompressionlevel": "zlibCompressionLevel",
"uuidrepresentation": "uuidRepresentation",
"srvservicename": "srvServiceName",
"wtimeoutms": "wTimeoutMS",
"replicaset": "replicaSet",
"readpreference": "readPreference",
"readpreferencetags": "readPreferenceTags",
"maxstalenessseconds": "maxStalenessSeconds",
"authsource": "authSource",
"authmechanism": "authMechanism",
"authmechanismproperties": "authMechanismProperties",
"tlsinsecure": "tlsInsecure",
"tlsallowinvalidcertificates": "tlsAllowInvalidCertificates",
"tlsallowinvalidhostnames": "tlsAllowInvalidHostnames",
"tlscafile": "tlsCAFile",
"tlscertificatekeyfile": "tlsCertificateKeyFile",
"tlscrlfile": "tlsCRLFile",
"tlscertificatekeyfilepassword": "tlsCertificateKeyFilePassword",
"tlsdisableocspendpointcheck": "tlsDisableOCSPEndpointCheck",
"readconcernlevel": "readConcernLevel",
}
_setting_name = KNOWN_CAMEL_CASE_SETTINGS.get(setting_name.lower())
return setting_name.lower() if _setting_name is None else _setting_name


def _sanitize_settings(settings: dict) -> dict:
"""Remove MONGODB_ prefix from dict values, to correct bypass to mongoengine."""
resolved_settings = {}
for k, v in settings.items():
if k.startswith("MONGODB_"):
k = k[len("MONGODB_") :]
k = k.lower()
resolved_settings[k] = v

# Handle uri style connections
if "://" in resolved_settings.get("host", ""):
# this section pulls the database name from the URI
# PyMongo requires URI to start with mongodb:// to parse
# this workaround allows mongomock to work
uri_to_check = resolved_settings["host"]

if uri_to_check.startswith("mongomock://"):
uri_to_check = uri_to_check.replace("mongomock://", "mongodb://")

uri_dict = uri_parser.parse_uri(uri_to_check)
resolved_settings["db"] = uri_dict["database"]

# Add a default name param or use the "db" key if exists
if resolved_settings.get("db"):
resolved_settings["name"] = resolved_settings.pop("db")
else:
resolved_settings["name"] = "test"

# Add various default values.
resolved_settings["alias"] = resolved_settings.get(
"alias", mongoengine.DEFAULT_CONNECTION_NAME
)
# TODO do we have to specify it here? MongoEngine should take care of that
resolved_settings["host"] = resolved_settings.get("host", "localhost")
# TODO this is the default host in pymongo.mongo_client.MongoClient, we may
# not need to explicitly set a default here
resolved_settings["port"] = resolved_settings.get("port", 27017)
# TODO this is the default port in pymongo.mongo_client.MongoClient, we may
# not need to explicitly set a default here

# Default to ReadPreference.PRIMARY if no read_preference is supplied
resolved_settings["read_preference"] = resolved_settings.get(
"read_preference", ReadPreference.PRIMARY
)

# Clean up empty values
for k, v in list(resolved_settings.items()):
if v is None:
del resolved_settings[k]
# Replace with k.lower().removeprefix("mongodb_") when python 3.8 support ends.
key = _get_name(k[8:]) if k.lower().startswith("mongodb_") else _get_name(k)
resolved_settings[key] = v

return resolved_settings


def get_connection_settings(config):
def get_connection_settings(config: dict) -> List[dict]:
"""
Given a config dict, return a sanitized dict of MongoDB connection
settings that we can then use to establish connections. For new
applications, settings should exist in a ``MONGODB_SETTINGS`` key, but
for backward compatibility we also support several config keys
prefixed by ``MONGODB_``, e.g. ``MONGODB_HOST``, ``MONGODB_PORT``, etc.
"""

# If no "MONGODB_SETTINGS", sanitize the "MONGODB_" keys as single connection.
if "MONGODB_SETTINGS" not in config:
config = {k: v for k, v in config.items() if k.lower().startswith("mongodb_")}
return [_sanitize_settings(config)]

# Sanitize all the settings living under a "MONGODB_SETTINGS" config var
if "MONGODB_SETTINGS" in config:
settings = config["MONGODB_SETTINGS"]
settings = config["MONGODB_SETTINGS"]

# If MONGODB_SETTINGS is a list of settings dicts, sanitize each
# dict separately.
if isinstance(settings, list):
return [_sanitize_settings(setting) for setting in settings]
else:
return _sanitize_settings(settings)
# If MONGODB_SETTINGS is a list of settings dicts, sanitize each dict separately.
if isinstance(settings, list):
return [_sanitize_settings(settings_dict) for settings_dict in settings]

else:
config = {k: v for k, v in config.items() if k in MONGODB_CONF_VARS}
return _sanitize_settings(config)
# Otherwise, it should be a single dict describing a single connection.
return [_sanitize_settings(settings)]


def create_connections(config):
def create_connections(config: dict):
"""
Given Flask application's config dict, extract relevant config vars
out of it and establish MongoEngine connection(s) based on them.
"""
# Validate that the config is a dict
if config is None or not isinstance(config, dict):
raise InvalidSettingsError("Invalid application configuration")
# Validate that the config is a dict and dict is not empty
if not config or not isinstance(config, dict):
raise TypeError(f"Config dictionary expected, but {type(config)} received.")

# Get sanitized connection settings based on the config
conn_settings = get_connection_settings(config)

# If conn_settings is a list, set up each item as a separate connection
# and return a dict of connection aliases and their connections.
if isinstance(conn_settings, list):
connections = {}
for each in conn_settings:
alias = each["alias"]
connections[alias] = _connect(each)
return connections
connection_settings = get_connection_settings(config)

# Otherwise, return a single connection
return _connect(conn_settings)
connections = {}
for connection_setting in connection_settings:
alias = connection_setting.setdefault(
"alias", mongoengine.DEFAULT_CONNECTION_NAME
)
connections[alias] = mongoengine.connect(**connection_setting)


def _connect(conn_settings):
"""Given a dict of connection settings, create a connection to
MongoDB by calling {func}`mongoengine.connect` and return its result.
"""
db_name = conn_settings.pop("name")
return mongoengine.connect(db_name, **conn_settings)
return connections
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,22 @@ def app():
def db(app):
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
test_db = MongoEngine(app)
db_name = test_db.connection.get_database("flask_mongoengine_test_db").name
db_name = (
test_db.connection["default"].get_database("flask_mongoengine_test_db").name
)

if not db_name.endswith("_test_db"):
raise RuntimeError(
f"DATABASE_URL must point to testing db, not to master db ({db_name})"
)

# Clear database before tests, for cases when some test failed before.
test_db.connection.drop_database(db_name)
test_db.connection["default"].drop_database(db_name)

yield test_db

# Clear database after tests, for graceful exit.
test_db.connection.drop_database(db_name)
test_db.connection["default"].drop_database(db_name)


@pytest.fixture()
Expand Down
44 changes: 17 additions & 27 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import mongoengine
import pymongo
import pytest
from mongoengine.connection import ConnectionFailure
from mongoengine.context_managers import switch_db
Expand All @@ -11,6 +10,14 @@
from flask_mongoengine import MongoEngine, current_mongoengine_instance


def is_mongo_mock_installed() -> bool:
try:
import mongomock.__version__ # noqa
except ImportError:
return False
return True


def test_connection__should_use_defaults__if_no_settings_provided(app):
"""Make sure a simple connection to a standalone MongoDB works."""
db = MongoEngine()
Expand Down Expand Up @@ -129,6 +136,9 @@ def test_connection__should_parse_host_uri__if_host_formatted_as_uri(
assert connection.PORT == 27017


@pytest.mark.skipif(
is_mongo_mock_installed(), reason="This test require mongomock not exist"
)
@pytest.mark.parametrize(
("config_extension"),
[
Expand Down Expand Up @@ -281,46 +291,26 @@ class Todo(db.Document):
assert doc is not None


def test_ignored_mongodb_prefix_config(app):
"""Config starting by MONGODB_ but not used by flask-mongoengine
should be ignored.
"""
def test_incorrect_value_with_mongodb_prefix__should_trigger_mongoengine_raise(app):
db = MongoEngine()
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
# Invalid host, should trigger exception if used
app.config["MONGODB_TEST_HOST"] = "dummy://localhost:27017/test"
db.init_app(app)

connection = mongoengine.get_connection()
mongo_engine_db = mongoengine.get_db()
assert isinstance(mongo_engine_db, Database)
assert isinstance(connection, MongoClient)
assert mongo_engine_db.name == "flask_mongoengine_test_db"
assert connection.HOST == "localhost"
assert connection.PORT == 27017
with pytest.raises(ConnectionFailure):
db.init_app(app)


def test_connection_kwargs(app):
"""Make sure additional connection kwargs work."""

# Figure out whether to use "MAX_POOL_SIZE" or "MAXPOOLSIZE" based
# on PyMongo version (former was changed to the latter as described
# in https://jira.mongodb.org/browse/PYTHON-854)
# TODO remove once PyMongo < 3.0 support is dropped
if pymongo.version_tuple[0] >= 3:
MAX_POOL_SIZE_KEY = "MAXPOOLSIZE"
else:
MAX_POOL_SIZE_KEY = "MAX_POOL_SIZE"

app.config["MONGODB_SETTINGS"] = {
"ALIAS": "tz_aware_true",
"DB": "flask_mongoengine_test_db",
"TZ_AWARE": True,
"READ_PREFERENCE": ReadPreference.SECONDARY,
MAX_POOL_SIZE_KEY: 10,
"MAXPOOLSIZE": 10,
}
db = MongoEngine(app)

assert db.connection.codec_options.tz_aware
# assert db.connection.max_pool_size == 10
assert db.connection.read_preference == ReadPreference.SECONDARY
assert db.connection["tz_aware_true"].codec_options.tz_aware
assert db.connection["tz_aware_true"].read_preference == ReadPreference.SECONDARY
8 changes: 5 additions & 3 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,22 @@ def extended_db(app):
app.json_encoder = DummyEncoder
app.config["MONGODB_HOST"] = "mongodb://localhost:27017/flask_mongoengine_test_db"
test_db = MongoEngine(app)
db_name = test_db.connection.get_database("flask_mongoengine_test_db").name
db_name = (
test_db.connection["default"].get_database("flask_mongoengine_test_db").name
)

if not db_name.endswith("_test_db"):
raise RuntimeError(
f"DATABASE_URL must point to testing db, not to master db ({db_name})"
)

# Clear database before tests, for cases when some test failed before.
test_db.connection.drop_database(db_name)
test_db.connection["default"].drop_database(db_name)

yield test_db

# Clear database after tests, for graceful exit.
test_db.connection.drop_database(db_name)
test_db.connection["default"].drop_database(db_name)


class DummyEncoder(flask.json.JSONEncoder):
Expand Down