Skip to content

Commit

Permalink
fix: change to type instead of instantiated type
Browse files Browse the repository at this point in the history
  • Loading branch information
HLasse committed May 23, 2024
1 parent 46c6707 commit e327665
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/timeseriesflattener/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def _validate_compatible_fallback_type_for_aggregator(
aggregator: Aggregator, fallback: str | int | float | None
) -> None:
try:
pl.Series([aggregator.output_type]).fill_null(fallback)
pl.Series([aggregator.output_type()]).fill_null(fallback)
except:
raise ValueError(
f"Invalid fallback value {fallback} for aggregator {aggregator.__class__.__name__}. Fallback of type {type(fallback)} is not compatible with the aggregator's output type of {type(aggregator.output_type)}."
Expand All @@ -20,7 +20,7 @@ def _validate_compatible_fallback_type_for_aggregator(

class Aggregator(ABC):
name: str
output_type: float | int | bool
output_type: type[float | int | bool]

@abstractmethod
def __call__(self, column_name: str) -> pl.Expr:
Expand All @@ -34,7 +34,7 @@ class MinAggregator(Aggregator):
"""Returns the minimum value in the look window."""

name: str = "min"
output_type = float()
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).min().alias(self.new_col_name(column_name))
Expand All @@ -44,7 +44,7 @@ class MaxAggregator(Aggregator):
"""Returns the maximum value in the look window."""

name: str = "max"
output_type = float()
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).max().alias(self.new_col_name(column_name))
Expand All @@ -54,7 +54,7 @@ class MeanAggregator(Aggregator):
"""Returns the mean value in the look window."""

name: str = "mean"
output_type = float()
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).mean().alias(self.new_col_name(column_name))
Expand All @@ -64,7 +64,7 @@ class CountAggregator(Aggregator):
"""Returns the count of non-null values in the look window."""

name: str = "count"
output_type = int()
output_type = int

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).count().alias(self.new_col_name(column_name))
Expand All @@ -76,7 +76,7 @@ class EarliestAggregator(Aggregator):

timestamp_col_name: str
name: str = "earliest"
output_type = float()
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return (
Expand All @@ -93,7 +93,7 @@ class LatestAggregator(Aggregator):

timestamp_col_name: str
name: str = "latest"
output_type = float()
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return (
Expand All @@ -108,7 +108,7 @@ class SumAggregator(Aggregator):
"""Returns the sum of all values in the look window."""

name: str = "sum"
output_type = float()
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).sum().alias(self.new_col_name(column_name))
Expand All @@ -118,7 +118,7 @@ class VarianceAggregator(Aggregator):
"""Returns the variance of the values in the look window"""

name: str = "var"
output_type = float()
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
return pl.col(column_name).var().alias(self.new_col_name(column_name))
Expand All @@ -128,7 +128,7 @@ class HasValuesAggregator(Aggregator):
"""Examines whether any values exist in the column. If so, returns True, else False."""

name: str = "bool"
output_type = bool()
output_type = bool

def __call__(self, column_name: str) -> pl.Expr:
return (
Expand All @@ -147,7 +147,7 @@ class SlopeAggregator(Aggregator):

timestamp_col_name: str
name: str = "slope"
output_type = float()
output_type = float

def __call__(self, column_name: str) -> pl.Expr:
# Convert to days for the slope. Arbitrarily chosen to be the number of days since 1970-01-01.
Expand Down

0 comments on commit e327665

Please sign in to comment.