From ce7bfe6d6bd5c26643c08a15b2fb2fd449709ec8 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Sun, 17 Sep 2023 01:28:35 +0300 Subject: [PATCH] Preserve sort/skip/limit for aggregations --- beanie/odm/queries/find.py | 22 ++++++++-------------- beanie/odm/utils/find.py | 4 ---- tests/odm/query/test_aggregate.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/beanie/odm/queries/find.py b/beanie/odm/queries/find.py index 5e3e4cd5..287ec6aa 100644 --- a/beanie/odm/queries/find.py +++ b/beanie/odm/queries/find.py @@ -556,18 +556,11 @@ def aggregate( :return:[AggregationQuery](query.md#aggregationquery) """ self.set_session(session=session) - find_query = self.get_filter_query() - if self.fetch_links: - find_aggregation_pipeline = self.build_aggregation_pipeline() - aggregation_pipeline = ( - find_aggregation_pipeline + aggregation_pipeline - ) - find_query = {} return self.AggregationQueryType( - aggregation_pipeline=aggregation_pipeline, - document_model=self.document_model, + self.document_model, + self.build_aggregation_pipeline(*aggregation_pipeline), + find_query={}, projection_model=projection_model, - find_query=find_query, ignore_cache=ignore_cache, **pymongo_kwargs, ).set_session(session=self.session) @@ -605,7 +598,7 @@ def _set_cache(self, data): self._cache_key, data ) - def build_aggregation_pipeline(self): + def build_aggregation_pipeline(self, *extra_stages): aggregation_pipeline: List[Dict[str, Any]] = construct_lookup_queries( self.document_model ) @@ -614,9 +607,10 @@ def build_aggregation_pipeline(self): text_query = filter_query["$text"] aggregation_pipeline.insert(0, {"$match": {"$text": text_query}}) del filter_query["$text"] - - aggregation_pipeline.append({"$match": filter_query}) - + if filter_query: + aggregation_pipeline.append({"$match": filter_query}) + if extra_stages: + aggregation_pipeline.extend(extra_stages) sort_pipeline = {"$sort": {i[0]: i[1] for i in self.sort_expressions}} if sort_pipeline["$sort"]: aggregation_pipeline.append(sort_pipeline) diff --git a/beanie/odm/utils/find.py b/beanie/odm/utils/find.py index 45edc2ed..08fd9c95 100644 --- a/beanie/odm/utils/find.py +++ b/beanie/odm/utils/find.py @@ -1,8 +1,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Type -from beanie.exceptions import NotSupported from beanie.odm.fields import LinkInfo, LinkTypes -from beanie.odm.interfaces.detector import ModelType if TYPE_CHECKING: from beanie import Document @@ -13,8 +11,6 @@ def construct_lookup_queries(cls: Type["Document"]) -> List[Dict[str, Any]]: - if cls.get_model_type() == ModelType.UnionDoc: - raise NotSupported("UnionDoc doesn't support link fetching") queries: List = [] link_fields = cls.get_link_fields() if link_fields is not None: diff --git a/tests/odm/query/test_aggregate.py b/tests/odm/query/test_aggregate.py index 54d5d873..64a34a9d 100644 --- a/tests/odm/query/test_aggregate.py +++ b/tests/odm/query/test_aggregate.py @@ -2,6 +2,7 @@ from pydantic import Field from pydantic.main import BaseModel +from beanie.odm.enums import SortDirection from tests.odm.models import Sample @@ -35,6 +36,36 @@ async def test_aggregate_with_filter(preset_documents): assert {"_id": "test_3", "total": 3} in result +async def test_aggregate_with_sort_skip(preset_documents): + q = Sample.find(sort="_id", skip=2).aggregate( + [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}] + ) + assert q.get_aggregation_pipeline() == [ + {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}, + {"$sort": {"_id": SortDirection.ASCENDING}}, + {"$skip": 2}, + ] + assert await q.to_list() == [ + {"_id": "test_2", "total": 6}, + {"_id": "test_3", "total": 3}, + ] + + +async def test_aggregate_with_sort_limit(preset_documents): + q = Sample.find(sort="_id", limit=2).aggregate( + [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}] + ) + assert q.get_aggregation_pipeline() == [ + {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}, + {"$sort": {"_id": SortDirection.ASCENDING}}, + {"$limit": 2}, + ] + assert await q.to_list() == [ + {"_id": "test_0", "total": 0}, + {"_id": "test_1", "total": 3}, + ] + + async def test_aggregate_with_projection_model(preset_documents): class OutputItem(BaseModel): id: str = Field(None, alias="_id")