diff --git a/mongo_thingy/__init__.py b/mongo_thingy/__init__.py index 782fae6..b486f0b 100644 --- a/mongo_thingy/__init__.py +++ b/mongo_thingy/__init__.py @@ -107,13 +107,13 @@ def count(cls, filter=None, *args, **kwargs): return cls.count_documents(filter=filter, *args, **kwargs) @classmethod - def connect(cls, *args, client_cls=None, **kwargs): + def connect(cls, *args, client_cls=None, database_name=None, **kwargs): if not client_cls: client_cls = cls._client_cls cls._client = client_cls(*args, **kwargs) try: - cls._database = cls._client.get_database() + cls._database = cls._client.get_database(database_name) except (ConfigurationError, TypeError): cls._database = cls._client["test"] diff --git a/tests/__init__.py b/tests/__init__.py index e959a5e..5e927de 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -156,6 +156,13 @@ def test_connect_disconnect(thingy_cls, client_cls): disconnect() assert thingy_cls._client is None + connect(client_cls=client_cls, database_name="database") + assert isinstance(thingy_cls.client, client_cls) + assert thingy_cls._database.name == "database" + + disconnect() + assert thingy_cls._client is None + thingy_cls._client_cls = client_cls connect() assert isinstance(thingy_cls.client, client_cls)