Skip to content

Commit

Permalink
Fixed bug with Group
Browse files Browse the repository at this point in the history
  • Loading branch information
VianneyMI committed Sep 15, 2023
1 parent b222bf8 commit bbb27a5
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 2 deletions.
3 changes: 2 additions & 1 deletion monggregate/stages/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from typing import Any
from monggregate.base import pyd
from monggregate.stages.stage import Stage
from monggregate.utils import validate_field_path
from monggregate.utils import validate_field_path, validate_field_paths

class Group(Stage):
"""
Expand All @@ -91,6 +91,7 @@ class Group(Stage):
# Validators
# ------------------------------------------
_validate_by = pyd.validator("by", pre=True, always=True, allow_reuse=True)(validate_field_path) # re-used pyd.validator
_validate_iterable_by = pyd.validator("by", pre=True, always=True, allow_reuse=True)(validate_field_paths) # re-used pyd.validator

@pyd.validator("query", always=True)
@classmethod
Expand Down
13 changes: 12 additions & 1 deletion monggregate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,18 @@ def to_unique_list(keys:T)->list[str]|T:
def validate_field_path(path:str|None)->str|None:
"""Validates field path"""

if path and not path.startswith("$"):
if isinstance(path, str) and not path.startswith("$"):
path = "$" + path

return path


def validate_field_paths(paths:list[str]|set[str])->list[str]:
"""Validates field paths"""

if isinstance(paths, list):
paths = [validate_field_path(path) for path in paths]
elif isinstance(paths, set):
paths = {validate_field_path(path) for path in paths}

return paths
47 changes: 47 additions & 0 deletions test/test_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def test_count(self, state:State)->None:
assert count
state["count"] = count


def test_group(self, state:State)->None:
"""Tests the group stage"""

Expand All @@ -141,6 +142,52 @@ def test_group(self, state:State)->None:
}
)

# Test by as list
# ------------------------
assert Group(
by=["name", "age"],
query = {
"output":{"$sum":"income"}
}
)

# Test by as set
# ------------------------
assert Group(
by=set(["name", "age"]),
query = {
"output":{"$sum":"income"}
}
)

# Test by as constant
# ------------------------
assert Group(
by=1,
query = {
"output":{"$sum":"income"}
}
)

# Test by as dict
# ------------------------
assert Group(
by={"name":"$name"},
query = {
"output":{"$sum":"income"}
}
)

# Test by as None
# ------------------------
assert Group(
by=None,
query = {
"output":{"$sum":"income"}
}
)



def test_limit(self, state:State)->None:
"""Tests the limit stage"""
Expand Down

0 comments on commit bbb27a5

Please sign in to comment.