Skip to content

Commit

Permalink
fix: ensure all prediction times are kept when slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff committed Feb 13, 2024
1 parent bee91b0 commit 9e5cef4
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 16 deletions.
34 changes: 25 additions & 9 deletions src/timeseriesflattenerv2/_process_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def _normalise_lookdistances(spec: ValueSpecification) -> Sequence[LookDistance]
return lookdistances


def _null_values_outside_lookwindow(
    df: pl.LazyFrame, lookwindow_predicate: pl.Expr, cols_to_null: Sequence[str]
) -> pl.LazyFrame:
    """Null out *cols_to_null* on every row where *lookwindow_predicate* is False.

    Unlike a ``filter``, this keeps every row — and therefore every prediction
    time — in the frame; only the cell values outside the lookwindow become
    null, so downstream aggregation can apply the spec's fallback for those
    prediction times.

    Args:
        df: Frame to modify lazily.
        lookwindow_predicate: Boolean expression that is True for rows inside
            the lookwindow.
        cols_to_null: Names of the columns whose values are nulled outside the
            window. The when/then expression inherits its output name from the
            ``then`` branch, so each column is replaced in place.

    Returns:
        The frame with the selected columns nulled outside the lookwindow.
    """
    # A single with_columns call lets polars evaluate all column expressions in
    # one projection, instead of stacking one projection per column in a loop.
    return df.with_columns(
        [
            pl.when(lookwindow_predicate).then(pl.col(col)).otherwise(None)
            for col in cols_to_null
        ]
    )


def _slice_frame(
timedelta_frame: TimedeltaFrame,
lookdistance: LookDistance,
Expand All @@ -87,22 +97,28 @@ def _slice_frame(

timedelta_col = pl.col(timedelta_frame.timedelta_col_name)

lookbehind = lookdistance < dt.timedelta(0)
no_predictor_value = timedelta_col.is_null()
is_lookbehind = lookdistance < dt.timedelta(0)

# The predictor case
if lookbehind:
if is_lookbehind:
after_lookbehind_start = lookdistance <= timedelta_col
before_pred_time = timedelta_col <= dt.timedelta(0)
sliced_frame = timedelta_frame.df.filter(
(after_lookbehind_start).and_(before_pred_time).or_(no_predictor_value)
before_prediction_time = timedelta_col <= dt.timedelta(0)

within_lookbehind = after_lookbehind_start.and_(before_prediction_time)
sliced_frame = _null_values_outside_lookwindow(
df=timedelta_frame.df,
lookwindow_predicate=within_lookbehind,
cols_to_null=[timedelta_frame.value_col_name, timedelta_frame.timedelta_col_name],
)
# The outcome case
else:
after_pred_time = dt.timedelta(0) <= timedelta_col
after_prediction_time = dt.timedelta(0) <= timedelta_col
before_lookahead_end = timedelta_col <= lookdistance
sliced_frame = timedelta_frame.df.filter(
(after_pred_time).and_(before_lookahead_end).or_(no_predictor_value)
within_lookahead = after_prediction_time.and_(before_lookahead_end)
sliced_frame = _null_values_outside_lookwindow(
df=timedelta_frame.df,
lookwindow_predicate=within_lookahead,
cols_to_null=[timedelta_frame.value_col_name, timedelta_frame.timedelta_col_name],
)

return SlicedFrame(
Expand Down
3 changes: 3 additions & 0 deletions src/timeseriesflattenerv2/feature_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ class SlicedFrame:
def df(self) -> pl.LazyFrame:
return self.init_df

def collect(self) -> pl.DataFrame:
    """Materialise the underlying lazy frame into an eager ``pl.DataFrame``."""
    return self.init_df.collect()


@dataclass
class AggregatedValueFrame:
Expand Down
68 changes: 61 additions & 7 deletions src/timeseriesflattenerv2/test_process_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import polars as pl
from timeseriesflattener.testing.utils_for_testing import str_to_pl_df

import timeseriesflattenerv2._process_spec
import timeseriesflattenerv2._process_spec as process_spec

from ._horisontally_concat import horizontally_concatenate_dfs
from .aggregators import MaxAggregator, MeanAggregator
from .feature_specs import PredictionTimeFrame, SlicedFrame, ValueFrame
from .feature_specs import PredictionTimeFrame, SlicedFrame, TimedeltaFrame, ValueFrame
from .test_flattener import assert_frame_equal


Expand All @@ -19,7 +19,7 @@ def test_aggregate_over_fallback():
value_col_name="value",
)

aggregated_values = timeseriesflattenerv2._process_spec._aggregate_within_slice(
aggregated_values = process_spec._aggregate_within_slice(
sliced_frame=sliced_frame, aggregators=[MeanAggregator()], fallback=0
)

Expand All @@ -31,6 +31,26 @@ def test_aggregate_over_fallback():
assert_frame_equal(aggregated_values[0].df.collect(), expected)


def test_aggregate_with_null():
    """A null value in the slice must not contribute to the mean — only the
    non-null value for the prediction time should be aggregated."""
    input_frame = SlicedFrame(
        init_df=pl.LazyFrame(
            {"pred_time_uuid": ["1-2021-01-03", "1-2021-01-03"], "value": [1, None]}
        ),
        value_col_name="value",
    )

    result = process_spec._aggregate_within_slice(
        sliced_frame=input_frame, aggregators=[MeanAggregator()], fallback=0
    )

    expected = str_to_pl_df(
        """pred_time_uuid,value_mean_fallback_0
1-2021-01-03,1"""
    )

    assert_frame_equal(result[0].df.collect(), expected)


def test_aggregate_within_slice():
sliced_frame = SlicedFrame(
init_df=str_to_pl_df(
Expand All @@ -43,7 +63,7 @@ def test_aggregate_within_slice():
value_col_name="value",
)

aggregated_values = timeseriesflattenerv2._process_spec._aggregate_within_slice(
aggregated_values = process_spec._aggregate_within_slice(
sliced_frame=sliced_frame, aggregators=[MeanAggregator()], fallback=0
)

Expand Down Expand Up @@ -71,15 +91,49 @@ def test_get_timedelta_frame():

expected_timedeltas = [dt.timedelta(days=-2), dt.timedelta(days=-1), dt.timedelta(days=0)]

result = timeseriesflattenerv2._process_spec._get_timedelta_frame(
result = process_spec._get_timedelta_frame(
predictiontime_frame=PredictionTimeFrame(init_df=pred_frame.lazy()),
value_frame=ValueFrame(init_df=value_frame.lazy(), value_col_name="value"),
)

assert result.get_timedeltas() == expected_timedeltas


def test_multiple_aggregatrs():
def test_slice_without_any_within_window():
    """Slicing must keep every prediction time: rows outside the lookbehind get
    their value nulled rather than being dropped from the frame.

    The "is_null" column is pre-built to mirror which rows fall outside the
    lookbehind, so the sliced value column should match it exactly.
    """
    # Hoisted from mid-function: imports belong at the top of their scope.
    from polars.testing import assert_series_equal

    timedelta_frame = TimedeltaFrame(
        df=pl.LazyFrame(
            {
                "pred_time_uuid": [1, 1, 2, 2],
                "time_from_prediction_to_value": [
                    dt.timedelta(days=1),  # Outside the lookbehind
                    dt.timedelta(days=-1),  # Inside the lookbehind
                    dt.timedelta(days=-2.1),  # Outside the lookbehind
                    dt.timedelta(days=-2.1),  # Outside the lookbehind
                ],
                "is_null": [None, 0, None, None],
            }
        ),
        value_col_name="is_null",
    )

    result = process_spec._slice_frame(
        timedelta_frame=timedelta_frame,
        lookdistance=dt.timedelta(days=-2),
        column_prefix="pred",
        value_col_name="value",
    ).collect()

    # All four rows survive; only the out-of-window values were nulled.
    assert_series_equal(
        result.get_column("pred_value_within_2_days"),
        timedelta_frame.df.collect().get_column("is_null"),
        check_names=False,
        check_dtype=False,
    )


def test_multiple_aggregators():
sliced_frame = SlicedFrame(
init_df=str_to_pl_df(
"""pred_time_uuid,value
Expand All @@ -91,7 +145,7 @@ def test_multiple_aggregatrs():
value_col_name="value",
)

aggregated_values = timeseriesflattenerv2._process_spec._aggregate_within_slice(
aggregated_values = process_spec._aggregate_within_slice(
sliced_frame=sliced_frame, aggregators=[MeanAggregator(), MaxAggregator()], fallback=0
)

Expand Down

0 comments on commit 9e5cef4

Please sign in to comment.