Skip to content

Commit

Permalink
Preserve sort/skip/limit for aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
gsakkis committed Sep 16, 2023
1 parent 60dc39c commit ce7bfe6
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
22 changes: 8 additions & 14 deletions beanie/odm/queries/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,18 +556,11 @@ def aggregate(
:return:[AggregationQuery](query.md#aggregationquery)
"""
self.set_session(session=session)
find_query = self.get_filter_query()
if self.fetch_links:
find_aggregation_pipeline = self.build_aggregation_pipeline()
aggregation_pipeline = (
find_aggregation_pipeline + aggregation_pipeline
)
find_query = {}
return self.AggregationQueryType(
aggregation_pipeline=aggregation_pipeline,
document_model=self.document_model,
self.document_model,
self.build_aggregation_pipeline(*aggregation_pipeline),
find_query={},
projection_model=projection_model,
find_query=find_query,
ignore_cache=ignore_cache,
**pymongo_kwargs,
).set_session(session=self.session)
Expand Down Expand Up @@ -605,7 +598,7 @@ def _set_cache(self, data):
self._cache_key, data
)

def build_aggregation_pipeline(self):
def build_aggregation_pipeline(self, *extra_stages):
aggregation_pipeline: List[Dict[str, Any]] = construct_lookup_queries(
self.document_model
)
Expand All @@ -614,9 +607,10 @@ def build_aggregation_pipeline(self):
text_query = filter_query["$text"]
aggregation_pipeline.insert(0, {"$match": {"$text": text_query}})
del filter_query["$text"]

aggregation_pipeline.append({"$match": filter_query})

if filter_query:
aggregation_pipeline.append({"$match": filter_query})
if extra_stages:
aggregation_pipeline.extend(extra_stages)
sort_pipeline = {"$sort": {i[0]: i[1] for i in self.sort_expressions}}
if sort_pipeline["$sort"]:
aggregation_pipeline.append(sort_pipeline)
Expand Down
4 changes: 0 additions & 4 deletions beanie/odm/utils/find.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Type

from beanie.exceptions import NotSupported
from beanie.odm.fields import LinkInfo, LinkTypes
from beanie.odm.interfaces.detector import ModelType

if TYPE_CHECKING:
from beanie import Document
Expand All @@ -13,8 +11,6 @@


def construct_lookup_queries(cls: Type["Document"]) -> List[Dict[str, Any]]:
if cls.get_model_type() == ModelType.UnionDoc:
raise NotSupported("UnionDoc doesn't support link fetching")
queries: List = []
link_fields = cls.get_link_fields()
if link_fields is not None:
Expand Down
31 changes: 31 additions & 0 deletions tests/odm/query/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pydantic import Field
from pydantic.main import BaseModel

from beanie.odm.enums import SortDirection
from tests.odm.models import Sample


Expand Down Expand Up @@ -35,6 +36,36 @@ async def test_aggregate_with_filter(preset_documents):
assert {"_id": "test_3", "total": 3} in result


async def test_aggregate_with_sort_skip(preset_documents):
    """sort/skip given to find() must be preserved as trailing aggregation stages."""
    group_stage = {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}
    query = Sample.find(sort="_id", skip=2).aggregate([group_stage])
    # The $sort and $skip stages come after the user-supplied $group stage.
    expected_pipeline = [
        group_stage,
        {"$sort": {"_id": SortDirection.ASCENDING}},
        {"$skip": 2},
    ]
    assert query.get_aggregation_pipeline() == expected_pipeline
    results = await query.to_list()
    assert results == [
        {"_id": "test_2", "total": 6},
        {"_id": "test_3", "total": 3},
    ]


async def test_aggregate_with_sort_limit(preset_documents):
    """sort/limit given to find() must be preserved as trailing aggregation stages."""
    group_stage = {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}
    query = Sample.find(sort="_id", limit=2).aggregate([group_stage])
    # The $sort and $limit stages come after the user-supplied $group stage.
    expected_pipeline = [
        group_stage,
        {"$sort": {"_id": SortDirection.ASCENDING}},
        {"$limit": 2},
    ]
    assert query.get_aggregation_pipeline() == expected_pipeline
    results = await query.to_list()
    assert results == [
        {"_id": "test_0", "total": 0},
        {"_id": "test_1", "total": 3},
    ]


async def test_aggregate_with_projection_model(preset_documents):
class OutputItem(BaseModel):
id: str = Field(None, alias="_id")
Expand Down

0 comments on commit ce7bfe6

Please sign in to comment.