diff --git a/monggregate/stages/group.py b/monggregate/stages/group.py index 53ad66b..05ea3c9 100644 --- a/monggregate/stages/group.py +++ b/monggregate/stages/group.py @@ -66,7 +66,7 @@ from typing import Any from monggregate.base import pyd from monggregate.stages.stage import Stage -from monggregate.utils import validate_field_path +from monggregate.utils import validate_field_path, validate_field_paths class Group(Stage): """ @@ -91,6 +91,7 @@ class Group(Stage): # Validators # ------------------------------------------ _validate_by = pyd.validator("by", pre=True, always=True, allow_reuse=True)(validate_field_path) # re-used pyd.validator + _validate_iterable_by = pyd.validator("by", pre=True, always=True, allow_reuse=True)(validate_field_paths) # re-used pyd.validator @pyd.validator("query", always=True) @classmethod diff --git a/monggregate/utils.py b/monggregate/utils.py index 4ba419d..e38c882 100644 --- a/monggregate/utils.py +++ b/monggregate/utils.py @@ -66,7 +66,18 @@ def to_unique_list(keys:T)->list[str]|T: def validate_field_path(path:str|None)->str|None: """Validates field path""" - if path and not path.startswith("$"): + if isinstance(path, str) and not path.startswith("$"): path = "$" + path return path + + +def validate_field_paths(paths:list[str]|set[str])->list[str]: + """Validates field paths""" + + if isinstance(paths, list): + paths = [validate_field_path(path) for path in paths] + elif isinstance(paths, set): + paths = {validate_field_path(path) for path in paths} + + return paths diff --git a/test/test_stages.py b/test/test_stages.py index 6b52310..d1d371f 100644 --- a/test/test_stages.py +++ b/test/test_stages.py @@ -115,6 +115,7 @@ def test_count(self, state:State)->None: assert count state["count"] = count + def test_group(self, state:State)->None: """Tests the group stage""" @@ -141,6 +142,52 @@ def test_group(self, state:State)->None: } ) + # Test by as list + # ------------------------ + assert Group( + by=["name", "age"], + query = { + "output":{"$sum":"income"} + } + ) + + # Test by as set + # ------------------------ + assert Group( + by=set(["name", "age"]), + query = { + "output":{"$sum":"income"} + } + ) + + # Test by as constant + # ------------------------ + assert Group( + by=1, + query = { + "output":{"$sum":"income"} + } + ) + + # Test by as dict + # ------------------------ + assert Group( + by={"name":"$name"}, + query = { + "output":{"$sum":"income"} + } + ) + + # Test by as None + # ------------------------ + assert Group( + by=None, + query = { + "output":{"$sum":"income"} + } + ) + + def test_limit(self, state:State)->None: """Tests the limit stage"""