Skip to content

Commit

Permalink
fix: don't process as batch if no specs to process
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff committed Dec 7, 2022
1 parent 01b3957 commit aba0b67
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
15 changes: 13 additions & 2 deletions src/timeseriesflattener/flattened_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class SpecCollection(PydanticBaseModel):
predictor_specs: list[TemporalSpec] = []
static_specs: list[AnySpec] = []

def __len__(self):
return (
len(self.outcome_specs) + len(self.predictor_specs) + len(self.static_specs)
)


class TimeseriesFlattener: # pylint: disable=too-many-instance-attributes
"""Turn a set of time-series into tabular prediction-time data."""
Expand Down Expand Up @@ -638,7 +643,8 @@ def _process_temporal_specs(self):
temporal_batch = self.unprocessed_specs.outcome_specs
temporal_batch += self.unprocessed_specs.predictor_specs

self._add_temporal_batch(temporal_batch=temporal_batch)
if len(temporal_batch) > 0:
self._add_temporal_batch(temporal_batch=temporal_batch)

def _process_predictor_specs(self):
"""Process predictor specs."""
Expand Down Expand Up @@ -743,6 +749,10 @@ def add_age_and_birth_year(

def compute(self):
"""Compute the flattened dataset."""
if len(self.unprocessed_specs) == 0:
log.warning("No unprocessed specs, skipping")
return

self._process_temporal_specs()
self._process_static_specs()

Expand All @@ -752,7 +762,8 @@ def get_df(self) -> DataFrame:
Returns:
DataFrame: Flattened dataframe.
"""
self.compute()
if len(self.unprocessed_specs) > 0:
self.compute()

# Process
return self._df
4 changes: 2 additions & 2 deletions src/timeseriesflattener/testing/utils_for_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def assert_flattened_data_as_expected(
n_workers=4,
)

flattened_ds._add_temporal_col_to_flattened_dataset( # pylint: disable=protected-access
output_spec=output_spec,
flattened_ds.add_spec( # pylint: disable=protected-access
spec=output_spec,
)

if expected_df:
Expand Down

0 comments on commit aba0b67

Please sign in to comment.