diff --git a/mongo_thingy/__init__.py b/mongo_thingy/__init__.py index 042cd0b..84b9cac 100644 --- a/mongo_thingy/__init__.py +++ b/mongo_thingy/__init__.py @@ -36,15 +36,19 @@ def collection_name(cls): @classproperty def client(cls): - return cls._client + return cls.get_client() + + @classmethod + def _get_client(cls, database): + return database.client @classmethod def _get_database(cls, collection, name): if collection: return collection.database - if cls.client and name: - return cls.client[name] - raise AttributeError("Undefined client.") + if cls._client and name: + return cls._client[name] + raise AttributeError("Undefined database.") @classmethod def _get_table(cls, database, table_name): @@ -58,6 +62,12 @@ def _get_database_name(cls, database): def _get_table_name(cls, table): return table.name + @classmethod + def get_client(cls): + if cls._client: + return cls._client + return cls._get_client(cls.database) + @classmethod def get_collection(cls): return cls.get_table() diff --git a/tests/__init__.py b/tests/__init__.py index cfe3080..ed9383b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -10,6 +10,10 @@ def test_thingy_database(TestThingy, database): assert TestThingy.database == database +def test_thingy_client(TestThingy, client): + assert TestThingy.client == client + + def test_thingy_collection(TestThingy, collection): assert TestThingy.collection == collection @@ -54,6 +58,13 @@ class FooBar(Thingy): assert FooBar.collection_name == collection.name +def test_thingy_database_from_client(client): + class FooBar(Thingy): + _client = client + + assert FooBar.database == client.foo + + def test_thingy_database_from_collection(collection): class Foo(Thingy): _collection = collection @@ -61,6 +72,13 @@ class Foo(Thingy): assert Foo.database == collection.database +def test_thingy_client_from_database(database): + class Foo(Thingy): + _database = database + + assert Foo.client == database.client + + def test_thingy_collection_from_database(database): class Foo(Thingy): _database = database @@ -106,14 +124,15 @@ def test_thingy_count(TestThingy, collection): @pytest.mark.parametrize("connect", [connect, Thingy.connect]) @pytest.mark.parametrize("disconnect", [disconnect, Thingy.disconnect]) def test_thingy_connect_disconnect(connect, disconnect): - assert Thingy.client is None + with pytest.raises(AttributeError): + Thingy.client connect() assert isinstance(Thingy.client, MongoClient) assert Thingy._database is None disconnect() - assert Thingy.client is None + assert Thingy._client is None connect("mongodb://hostname/database") assert isinstance(Thingy.client, MongoClient) @@ -121,7 +140,10 @@ def test_thingy_connect_disconnect(connect, disconnect): assert Thingy.database.name == "database" disconnect() - assert Thingy.client is None + assert Thingy._client is None + with pytest.raises(AttributeError): + Thingy.client + assert Thingy._database is None with pytest.raises(AttributeError): Thingy.database