Skip to content

Commit

Permalink
fix: fix aggregations with text queries
Browse files Browse the repository at this point in the history
  • Loading branch information
MrEarle committed Oct 19, 2023
1 parent 0718894 commit 92181ee
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 8 deletions.
33 changes: 27 additions & 6 deletions beanie/odm/queries/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from beanie.odm.utils.dump import get_dict
from beanie.odm.utils.encoder import Encoder
from beanie.odm.utils.find import construct_lookup_queries
from beanie.odm.utils.find import construct_lookup_queries, split_text_query
from beanie.odm.utils.parsing import parse_obj
from beanie.odm.utils.projection import get_projection
from beanie.odm.utils.relations import convert_ids
Expand Down Expand Up @@ -603,12 +603,33 @@ def build_aggregation_pipeline(self, *extra_stages):
self.document_model
)
filter_query = self.get_filter_query()
if "$text" in filter_query:
text_query = filter_query["$text"]
aggregation_pipeline.insert(0, {"$match": {"$text": text_query}})
del filter_query["$text"]

if filter_query:
aggregation_pipeline.append({"$match": filter_query})
text_queries, non_text_queries = split_text_query(filter_query)

if text_queries:
aggregation_pipeline.insert(
0,
{
"$match": (
{"$and": text_queries}
if len(text_queries) > 1
else text_queries[0]
)
},
)

if non_text_queries:
aggregation_pipeline.append(
{
"$match": (
{"$and": non_text_queries}
if len(non_text_queries) > 1
else non_text_queries[0]
)
}
)

if extra_stages:
aggregation_pipeline.extend(extra_stages)
sort_pipeline = {"$sort": {i[0]: i[1] for i in self.sort_expressions}}
Expand Down
33 changes: 32 additions & 1 deletion beanie/odm/utils/find.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, List, Type
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type

from beanie.odm.fields import LinkInfo, LinkTypes

Expand Down Expand Up @@ -325,3 +325,34 @@ def construct_query(
queries.append(lookup_step)

return queries


def split_text_query(
query: Dict[str, Any]
) -> Tuple[list[Dict[str, Any]], list[Dict[str, Any]]]:
"""Divide query into text and non-text matches
:param query: Dict[str, Any] - query dict
:return: Tuple[Dict[str, Any], Dict[str, Any]] - text and non-text queries,
respectively
"""

root_text_query_args: Dict[str, Any] = query.get("$text", None)
root_non_text_queries: Dict[str, Any] = {
k: v for k, v in query.items() if k not in {"$text", "$and"}
}

text_queries: list[Dict[str, Any]] = (
[{"$text": root_text_query_args}] if root_text_query_args else []
)
non_text_queries: list[Dict[str, Any]] = (
[root_non_text_queries] if root_non_text_queries else []
)

for match_case in query.get("$and", []):
if "$text" in match_case:
text_queries.append(match_case)
else:
non_text_queries.append(match_case)

return text_queries, non_text_queries
52 changes: 51 additions & 1 deletion tests/odm/query/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from pymongo.errors import OperationFailure

from beanie.odm.enums import SortDirection
from tests.odm.models import Sample
from beanie.odm.utils.find import construct_lookup_queries
from tests.odm.models import DocumentWithTextIndexAndLink, Sample


async def test_aggregate(preset_documents):
Expand Down Expand Up @@ -138,3 +139,52 @@ async def test_clone(preset_documents):
{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}},
{"a": "b"},
]


@pytest.mark.parametrize("text_query_count", [0, 1, 2])
@pytest.mark.parametrize("non_text_query_count", [0, 1, 2])
async def test_with_text_queries(
text_query_count: int, non_text_query_count: int
):
text_query = {"$text": {"$search": "text_search"}}
non_text_query = {"s": "test_string"}
aggregation_pipeline = [{"$count": "count"}]
queries = []

if text_query_count:
queries.append(text_query)
if text_query_count > 1:
queries.append(text_query)

if non_text_query_count:
queries.append(non_text_query)
if non_text_query_count > 1:
queries.append(non_text_query)

query = DocumentWithTextIndexAndLink.find(*queries, fetch_links=True)

expected_aggregation_pipeline = []
if text_query_count:
expected_aggregation_pipeline.append(
{"$match": text_query}
if text_query_count == 1
else {"$match": {"$and": [text_query, text_query]}}
)

expected_aggregation_pipeline.extend(
construct_lookup_queries(query.document_model)
)

if non_text_query_count:
expected_aggregation_pipeline.append(
{"$match": non_text_query}
if non_text_query_count == 1
else {"$match": {"$and": [non_text_query, non_text_query]}}
)

expected_aggregation_pipeline.extend(aggregation_pipeline)

assert (
query.build_aggregation_pipeline(*aggregation_pipeline)
== expected_aggregation_pipeline
)
29 changes: 29 additions & 0 deletions tests/odm/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,35 @@ async def test_with_chaining_aggregation(self):

assert addresses_count[0] == {"count": 10}

async def test_with_chaining_aggregation_and_text_search(self):
linked_document = LinkDocumentForTextSeacrh(i=1)
await linked_document.insert()

for i in range(10):
await DocumentWithTextIndexAndLink(
s="lower" if i < 5 else "UPPER", link=linked_document
).insert()

linked_document_2 = LinkDocumentForTextSeacrh(i=2)
await linked_document_2.insert()

for i in range(10):
await DocumentWithTextIndexAndLink(
s="lower" if i < 5 else "UPPER", link=linked_document_2
).insert()

document_count = (
await DocumentWithTextIndexAndLink.find(
{"$text": {"$search": "lower"}},
DocumentWithTextIndexAndLink.link.i == 1,
fetch_links=True,
)
.aggregate([{"$count": "count"}])
.to_list()
)

assert document_count[0] == {"count": 5}

async def test_with_extra_allow(self, houses):
res = await House.find(fetch_links=True).to_list()
assert get_model_fields(res[0]).keys() == {
Expand Down

0 comments on commit 92181ee

Please sign in to comment.