-
Notifications
You must be signed in to change notification settings - Fork 254
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #429 from MongoEngine/rewrite-connection-module
Rewrite connection module
- Loading branch information
Showing
5 changed files
with
108 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,138 +1,117 @@ | ||
from typing import List | ||
|
||
import mongoengine | ||
from pymongo import ReadPreference, uri_parser | ||
|
||
__all__ = ( | ||
"create_connections", | ||
"get_connection_settings", | ||
"InvalidSettingsError", | ||
) | ||
|
||
|
||
MONGODB_CONF_VARS = ( | ||
"MONGODB_ALIAS", | ||
"MONGODB_DB", | ||
"MONGODB_HOST", | ||
"MONGODB_IS_MOCK", | ||
"MONGODB_PASSWORD", | ||
"MONGODB_PORT", | ||
"MONGODB_USERNAME", | ||
"MONGODB_CONNECT", | ||
"MONGODB_TZ_AWARE", | ||
) | ||
|
||
def _get_name(setting_name: str) -> str: | ||
""" | ||
Return known pymongo setting name, or lower case name for unknown. | ||
class InvalidSettingsError(Exception): | ||
pass | ||
This problem discovered in issue #451. As mentioned there pymongo settings are not | ||
case-sensitive, but mongoengine use exact name of some settings for matching, | ||
overwriting pymongo behaviour. | ||
This function address this issue, and potentially address cases when pymongo will | ||
become case-sensitive in some settings by same reasons as mongoengine done. | ||
def _sanitize_settings(settings): | ||
"""Given a dict of connection settings, sanitize the keys and fall | ||
back to some sane defaults. | ||
Based on pymongo 4.1.1 settings. | ||
""" | ||
# Remove the "MONGODB_" prefix and make all settings keys lower case. | ||
KNOWN_CAMEL_CASE_SETTINGS = { | ||
"directconnection": "directConnection", | ||
"maxpoolsize": "maxPoolSize", | ||
"minpoolsize": "minPoolSize", | ||
"maxidletimems": "maxIdleTimeMS", | ||
"maxconnecting": "maxConnecting", | ||
"sockettimeoutms": "socketTimeoutMS", | ||
"connecttimeoutms": "connectTimeoutMS", | ||
"serverselectiontimeoutms": "serverSelectionTimeoutMS", | ||
"waitqueuetimeoutms": "waitQueueTimeoutMS", | ||
"heartbeatfrequencyms": "heartbeatFrequencyMS", | ||
"retrywrites": "retryWrites", | ||
"retryreads": "retryReads", | ||
"zlibcompressionlevel": "zlibCompressionLevel", | ||
"uuidrepresentation": "uuidRepresentation", | ||
"srvservicename": "srvServiceName", | ||
"wtimeoutms": "wTimeoutMS", | ||
"replicaset": "replicaSet", | ||
"readpreference": "readPreference", | ||
"readpreferencetags": "readPreferenceTags", | ||
"maxstalenessseconds": "maxStalenessSeconds", | ||
"authsource": "authSource", | ||
"authmechanism": "authMechanism", | ||
"authmechanismproperties": "authMechanismProperties", | ||
"tlsinsecure": "tlsInsecure", | ||
"tlsallowinvalidcertificates": "tlsAllowInvalidCertificates", | ||
"tlsallowinvalidhostnames": "tlsAllowInvalidHostnames", | ||
"tlscafile": "tlsCAFile", | ||
"tlscertificatekeyfile": "tlsCertificateKeyFile", | ||
"tlscrlfile": "tlsCRLFile", | ||
"tlscertificatekeyfilepassword": "tlsCertificateKeyFilePassword", | ||
"tlsdisableocspendpointcheck": "tlsDisableOCSPEndpointCheck", | ||
"readconcernlevel": "readConcernLevel", | ||
} | ||
_setting_name = KNOWN_CAMEL_CASE_SETTINGS.get(setting_name.lower()) | ||
return setting_name.lower() if _setting_name is None else _setting_name | ||
|
||
|
||
def _sanitize_settings(settings: dict) -> dict: | ||
"""Remove MONGODB_ prefix from dict values, to correct bypass to mongoengine.""" | ||
resolved_settings = {} | ||
for k, v in settings.items(): | ||
if k.startswith("MONGODB_"): | ||
k = k[len("MONGODB_") :] | ||
k = k.lower() | ||
resolved_settings[k] = v | ||
|
||
# Handle uri style connections | ||
if "://" in resolved_settings.get("host", ""): | ||
# this section pulls the database name from the URI | ||
# PyMongo requires URI to start with mongodb:// to parse | ||
# this workaround allows mongomock to work | ||
uri_to_check = resolved_settings["host"] | ||
|
||
if uri_to_check.startswith("mongomock://"): | ||
uri_to_check = uri_to_check.replace("mongomock://", "mongodb://") | ||
|
||
uri_dict = uri_parser.parse_uri(uri_to_check) | ||
resolved_settings["db"] = uri_dict["database"] | ||
|
||
# Add a default name param or use the "db" key if exists | ||
if resolved_settings.get("db"): | ||
resolved_settings["name"] = resolved_settings.pop("db") | ||
else: | ||
resolved_settings["name"] = "test" | ||
|
||
# Add various default values. | ||
resolved_settings["alias"] = resolved_settings.get( | ||
"alias", mongoengine.DEFAULT_CONNECTION_NAME | ||
) | ||
# TODO do we have to specify it here? MongoEngine should take care of that | ||
resolved_settings["host"] = resolved_settings.get("host", "localhost") | ||
# TODO this is the default host in pymongo.mongo_client.MongoClient, we may | ||
# not need to explicitly set a default here | ||
resolved_settings["port"] = resolved_settings.get("port", 27017) | ||
# TODO this is the default port in pymongo.mongo_client.MongoClient, we may | ||
# not need to explicitly set a default here | ||
|
||
# Default to ReadPreference.PRIMARY if no read_preference is supplied | ||
resolved_settings["read_preference"] = resolved_settings.get( | ||
"read_preference", ReadPreference.PRIMARY | ||
) | ||
|
||
# Clean up empty values | ||
for k, v in list(resolved_settings.items()): | ||
if v is None: | ||
del resolved_settings[k] | ||
# Replace with k.lower().removeprefix("mongodb_") when python 3.8 support ends. | ||
key = _get_name(k[8:]) if k.lower().startswith("mongodb_") else _get_name(k) | ||
resolved_settings[key] = v | ||
|
||
return resolved_settings | ||
|
||
|
||
def get_connection_settings(config): | ||
def get_connection_settings(config: dict) -> List[dict]: | ||
""" | ||
Given a config dict, return a sanitized dict of MongoDB connection | ||
settings that we can then use to establish connections. For new | ||
applications, settings should exist in a ``MONGODB_SETTINGS`` key, but | ||
for backward compatibility we also support several config keys | ||
prefixed by ``MONGODB_``, e.g. ``MONGODB_HOST``, ``MONGODB_PORT``, etc. | ||
""" | ||
|
||
# If no "MONGODB_SETTINGS", sanitize the "MONGODB_" keys as single connection. | ||
if "MONGODB_SETTINGS" not in config: | ||
config = {k: v for k, v in config.items() if k.lower().startswith("mongodb_")} | ||
return [_sanitize_settings(config)] | ||
|
||
# Sanitize all the settings living under a "MONGODB_SETTINGS" config var | ||
if "MONGODB_SETTINGS" in config: | ||
settings = config["MONGODB_SETTINGS"] | ||
settings = config["MONGODB_SETTINGS"] | ||
|
||
# If MONGODB_SETTINGS is a list of settings dicts, sanitize each | ||
# dict separately. | ||
if isinstance(settings, list): | ||
return [_sanitize_settings(setting) for setting in settings] | ||
else: | ||
return _sanitize_settings(settings) | ||
# If MONGODB_SETTINGS is a list of settings dicts, sanitize each dict separately. | ||
if isinstance(settings, list): | ||
return [_sanitize_settings(settings_dict) for settings_dict in settings] | ||
|
||
else: | ||
config = {k: v for k, v in config.items() if k in MONGODB_CONF_VARS} | ||
return _sanitize_settings(config) | ||
# Otherwise, it should be a single dict describing a single connection. | ||
return [_sanitize_settings(settings)] | ||
|
||
|
||
def create_connections(config): | ||
def create_connections(config: dict): | ||
""" | ||
Given Flask application's config dict, extract relevant config vars | ||
out of it and establish MongoEngine connection(s) based on them. | ||
""" | ||
# Validate that the config is a dict | ||
if config is None or not isinstance(config, dict): | ||
raise InvalidSettingsError("Invalid application configuration") | ||
# Validate that the config is a dict and dict is not empty | ||
if not config or not isinstance(config, dict): | ||
raise TypeError(f"Config dictionary expected, but {type(config)} received.") | ||
|
||
# Get sanitized connection settings based on the config | ||
conn_settings = get_connection_settings(config) | ||
|
||
# If conn_settings is a list, set up each item as a separate connection | ||
# and return a dict of connection aliases and their connections. | ||
if isinstance(conn_settings, list): | ||
connections = {} | ||
for each in conn_settings: | ||
alias = each["alias"] | ||
connections[alias] = _connect(each) | ||
return connections | ||
connection_settings = get_connection_settings(config) | ||
|
||
# Otherwise, return a single connection | ||
return _connect(conn_settings) | ||
connections = {} | ||
for connection_setting in connection_settings: | ||
alias = connection_setting.setdefault( | ||
"alias", mongoengine.DEFAULT_CONNECTION_NAME | ||
) | ||
connections[alias] = mongoengine.connect(**connection_setting) | ||
|
||
|
||
def _connect(conn_settings): | ||
"""Given a dict of connection settings, create a connection to | ||
MongoDB by calling {func}`mongoengine.connect` and return its result. | ||
""" | ||
db_name = conn_settings.pop("name") | ||
return mongoengine.connect(db_name, **conn_settings) | ||
return connections |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters