Skip to content

Commit

Permalink
Wrap cursor results into a list subclass that users can extend
Browse files Browse the repository at this point in the history
It allows funny stuff such as:

    class FooList(ThingyList):
        def find_bars(self):
            bar_ids = self.distinct("bar_id")
            return Bar.find({"_id": {"$in": bar_ids}})

    class Foo(Thingy):
        _result_cls = FooList

    foos = Foo.find().to_list()
    bars = foos.find_bars().to_list()
  • Loading branch information
ramnes committed Dec 1, 2022
1 parent eba0dc3 commit 31df715
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 4 deletions.
12 changes: 12 additions & 0 deletions mongo_thingy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@
AsyncIOMotorClient = None


class ThingyList(list):
def distinct(self, key):
def __get_value(document):
if hasattr(document, "__dict__"):
document = document.__dict__
return document.get(key)

values = set(__get_value(result) for result in self)
return list(values)


class BaseThingy(DatabaseThingy):
"""Represents a document in a collection"""

Expand All @@ -27,6 +38,7 @@ class BaseThingy(DatabaseThingy):
_collection = None
_collection_name = None
_cursor_cls = None
_result_cls = ThingyList

@classproperty
def _table(cls):
Expand Down
5 changes: 3 additions & 2 deletions mongo_thingy/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __call__(self, cursor):
async def wrapper(*args, **kwargs):
result = await method(*args, **kwargs)
if isinstance(result, list):
return [cursor.bind(r) for r in result]
return cursor.result_cls(cursor.bind(r) for r in result)
return cursor.bind(result)

return wrapper
Expand All @@ -57,6 +57,7 @@ class BaseCursor:
def __init__(self, delegate, thingy_cls=None, view=None):
self.delegate = delegate
self.thingy_cls = thingy_cls
self.result_cls = getattr(thingy_cls, "_result_cls", list)

if isinstance(view, str):
view = self.get_view(view)
Expand Down Expand Up @@ -100,7 +101,7 @@ def __getitem__(self, index):

def to_list(self, length):
self.limit(length)
return list(self)
return self.result_cls(self)

def delete(self):
ids = self.distinct("_id")
Expand Down
33 changes: 32 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,38 @@
import pytest
from bson import ObjectId

from mongo_thingy import connect, create_indexes, disconnect, registry
from mongo_thingy import (
Thingy,
ThingyList,
connect,
create_indexes,
disconnect,
registry,
)


async def test_thingy_list_distinct_thingies():
foos = ThingyList()
foos.append(Thingy())
foos.append(Thingy())
foos.append(Thingy(bar="baz"))
foos.append(Thingy(bar="qux"))

distinct = foos.distinct("bar")
assert distinct.count(None) == 1
assert set(distinct) == {None, "baz", "qux"}


async def test_thingy_list_distinct_dicts():
foos = ThingyList()
foos.append({})
foos.append({})
foos.append({"bar": "baz"})
foos.append({"bar": "qux"})

distinct = foos.distinct("bar")
assert distinct.count(None) == 1
assert set(distinct) == {None, "baz", "qux"}


@pytest.mark.all_backends
Expand Down
23 changes: 22 additions & 1 deletion tests/cursor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from mongo_thingy import Thingy
from mongo_thingy import Thingy, ThingyList
from mongo_thingy.cursor import (
AsyncCursor,
Cursor,
Expand Down Expand Up @@ -68,6 +68,23 @@ class FooCursor(Cursor):
assert foo.bar == "baz"


def test_cursor_result_cls():
cursor = Cursor(None)
assert cursor.result_cls == list

class Foo(Thingy):
pass

cursor = Cursor(None, thingy_cls=Foo)
assert cursor.result_cls == ThingyList

class Foo(Thingy):
_result_cls = set

cursor = Cursor(None, thingy_cls=Foo)
assert cursor.result_cls == set


def test_cursor_bind():
cursor = Cursor(None)
result = cursor.bind({"foo": "bar"})
Expand Down Expand Up @@ -230,6 +247,8 @@ class Foo(thingy_cls):
cursor = Cursor(collection.find(), thingy_cls=Foo)

results = cursor.to_list(length=10)
assert isinstance(results, ThingyList)

assert isinstance(results[0], Foo)
assert results[0].bar == "baz"

Expand All @@ -245,6 +264,8 @@ class Foo(thingy_cls):
cursor = AsyncCursor(collection.find(), thingy_cls=Foo)

results = await cursor.to_list(length=10)
assert isinstance(results, ThingyList)

assert isinstance(results[0], Foo)
assert results[0].bar == "baz"

Expand Down

0 comments on commit 31df715

Please sign in to comment.