diff --git a/mongo_thingy/__init__.py b/mongo_thingy/__init__.py index b486f0b..bc962b1 100644 --- a/mongo_thingy/__init__.py +++ b/mongo_thingy/__init__.py @@ -19,6 +19,17 @@ AsyncIOMotorClient = None +class ThingyList(list): + def distinct(self, key): + def __get_value(document): + if hasattr(document, "__dict__"): + document = document.__dict__ + return document.get(key) + + values = set(__get_value(result) for result in self) + return list(values) + + class BaseThingy(DatabaseThingy): """Represents a document in a collection""" @@ -27,6 +38,7 @@ class BaseThingy(DatabaseThingy): _collection = None _collection_name = None _cursor_cls = None + _result_cls = ThingyList @classproperty def _table(cls): diff --git a/mongo_thingy/cursor.py b/mongo_thingy/cursor.py index 65d5f6e..298931d 100644 --- a/mongo_thingy/cursor.py +++ b/mongo_thingy/cursor.py @@ -41,7 +41,7 @@ def __call__(self, cursor): async def wrapper(*args, **kwargs): result = await method(*args, **kwargs) if isinstance(result, list): - return [cursor.bind(r) for r in result] + return cursor.result_cls(cursor.bind(r) for r in result) return cursor.bind(result) return wrapper @@ -57,6 +57,7 @@ class BaseCursor: def __init__(self, delegate, thingy_cls=None, view=None): self.delegate = delegate self.thingy_cls = thingy_cls + self.result_cls = getattr(thingy_cls, "_result_cls", list) if isinstance(view, str): view = self.get_view(view) @@ -100,7 +101,7 @@ def __getitem__(self, index): def to_list(self, length): self.limit(length) - return list(self) + return self.result_cls(self) def delete(self): ids = self.distinct("_id") diff --git a/tests/__init__.py b/tests/__init__.py index 08a082e..1ac6ef6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,7 +3,38 @@ import pytest from bson import ObjectId -from mongo_thingy import connect, create_indexes, disconnect, registry +from mongo_thingy import ( + Thingy, + ThingyList, + connect, + create_indexes, + disconnect, + registry, +) + + +async def test_thingy_list_distinct_thingies(): + foos = ThingyList() + foos.append(Thingy()) + foos.append(Thingy()) + foos.append(Thingy(bar="baz")) + foos.append(Thingy(bar="qux")) + + distinct = foos.distinct("bar") + assert distinct.count(None) == 1 + assert set(distinct) == {None, "baz", "qux"} + + +async def test_thingy_list_distinct_dicts(): + foos = ThingyList() + foos.append({}) + foos.append({}) + foos.append({"bar": "baz"}) + foos.append({"bar": "qux"}) + + distinct = foos.distinct("bar") + assert distinct.count(None) == 1 + assert set(distinct) == {None, "baz", "qux"} @pytest.mark.all_backends diff --git a/tests/cursor.py b/tests/cursor.py index 3c9d146..54572de 100644 --- a/tests/cursor.py +++ b/tests/cursor.py @@ -1,6 +1,6 @@ import pytest -from mongo_thingy import Thingy +from mongo_thingy import Thingy, ThingyList from mongo_thingy.cursor import ( AsyncCursor, Cursor, @@ -68,6 +68,23 @@ class FooCursor(Cursor): assert foo.bar == "baz" +def test_cursor_result_cls(): + cursor = Cursor(None) + assert cursor.result_cls == list + + class Foo(Thingy): + pass + + cursor = Cursor(None, thingy_cls=Foo) + assert cursor.result_cls == ThingyList + + class Foo(Thingy): + _result_cls = set + + cursor = Cursor(None, thingy_cls=Foo) + assert cursor.result_cls == set + + def test_cursor_bind(): cursor = Cursor(None) result = cursor.bind({"foo": "bar"}) @@ -230,6 +247,8 @@ class Foo(thingy_cls): cursor = Cursor(collection.find(), thingy_cls=Foo) results = cursor.to_list(length=10) + assert isinstance(results, ThingyList) + assert isinstance(results[0], Foo) assert results[0].bar == "baz" @@ -245,6 +264,8 @@ class Foo(thingy_cls): cursor = AsyncCursor(collection.find(), thingy_cls=Foo) results = await cursor.to_list(length=10) + assert isinstance(results, ThingyList) + assert isinstance(results[0], Foo) assert results[0].bar == "baz"