Skip to content

Commit

Permalink
fix(#440): type of AggregatedDataframe.df should always be LazyFrame
Browse files Browse the repository at this point in the history
Fixes #440
  • Loading branch information
MartinBernstorff committed Feb 19, 2024
1 parent 0febfae commit ab9be64
Show file tree
Hide file tree
Showing 9 changed files with 15 additions and 14 deletions.
1 change: 0 additions & 1 deletion requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ pytz==2024.1
rich==13.7.0
# via timeseriesflattener
ruff==0.2.2
# via timeseriesflattener
scikit-learn==1.4.1.post1
# via timeseriesflattener
scipy==1.12.0
Expand Down
2 changes: 0 additions & 2 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ pytz==2024.1
# via pandas
rich==13.7.0
# via timeseriesflattener
ruff==0.2.2
# via timeseriesflattener
scikit-learn==1.4.1.post1
# via timeseriesflattener
scipy==1.12.0
Expand Down
10 changes: 7 additions & 3 deletions src/timeseriesflattenerv2/_intermediary_frames.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING, Sequence

import polars as pl
Expand All @@ -8,6 +8,7 @@
default_pred_time_uuid_col_name,
default_timestamp_col_name,
)
from .frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe

if TYPE_CHECKING:
from .feature_specs.meta import ValueType
Expand Down Expand Up @@ -84,13 +85,16 @@ def collect(self) -> pl.DataFrame:
return self.df.collect()


@dataclass(frozen=True)
@dataclass
class AggregatedFrame:
df: pl.LazyFrame
init_df: InitVar[pl.LazyFrame]
entity_id_col_name: str
timestamp_col_name: str
pred_time_uuid_col_name: str

def __post_init__(self, init_df: pl.LazyFrame):
self.df = _anyframe_to_lazyframe(init_df)

def collect(self) -> pl.DataFrame:
if isinstance(self.df, pl.DataFrame):
return self.df
Expand Down
8 changes: 4 additions & 4 deletions src/timeseriesflattenerv2/feature_specs/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from timeseriesflattenerv2.feature_specs.default_column_names import default_entity_id_col_name

from .._frame_validator import _validate_col_name_columns_exist
from ..frame_utilities._anyframe_to_lazyframe import _anyframe_to_lazyframe
from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe

if TYPE_CHECKING:
from timeseriesflattenerv2.feature_specs.outcome import BooleanOutcomeSpec, OutcomeSpec
Expand All @@ -21,9 +21,9 @@
InitDF_T = pl.LazyFrame | pl.DataFrame | pd.DataFrame


ValueSpecification: TypeAlias = (
"Union[PredictorSpec, OutcomeSpec, BooleanOutcomeSpec, TimeDeltaSpec, StaticSpec]"
)
ValueSpecification: (
TypeAlias
) = "Union[PredictorSpec, OutcomeSpec, BooleanOutcomeSpec, TimeDeltaSpec, StaticSpec]"
LookDistance: TypeAlias = dt.timedelta


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import polars as pl

from .._frame_validator import _validate_col_name_columns_exist
from ..frame_utilities._anyframe_to_lazyframe import _anyframe_to_lazyframe
from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe
from .default_column_names import (
default_entity_id_col_name,
default_pred_time_col_name,
Expand Down
2 changes: 1 addition & 1 deletion src/timeseriesflattenerv2/feature_specs/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import polars as pl

from .._frame_validator import _validate_col_name_columns_exist
from ..frame_utilities._anyframe_to_lazyframe import _anyframe_to_lazyframe
from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe
from .default_column_names import default_entity_id_col_name
from .meta import ValueType

Expand Down
2 changes: 1 addition & 1 deletion src/timeseriesflattenerv2/feature_specs/timestamp_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import polars as pl

from .._frame_validator import _validate_col_name_columns_exist
from ..frame_utilities._anyframe_to_lazyframe import _anyframe_to_lazyframe
from ..frame_utilities.anyframe_to_lazyframe import _anyframe_to_lazyframe
from .default_column_names import default_entity_id_col_name

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion src/timeseriesflattenerv2/flattener.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def aggregate_timeseries(self, specs: Sequence["ValueSpecification"]) -> Aggrega
)

return AggregatedFrame(
df=horizontally_concatenate_dfs(
init_df=horizontally_concatenate_dfs(
[self.predictiontime_frame.df, feature_dfs], # type: ignore
pred_time_uuid_col_name=self.predictiontime_frame.pred_time_uuid_col_name,
),
Expand Down

0 comments on commit ab9be64

Please sign in to comment.