Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor aggregator use fit, transform methods #526

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 77 additions & 89 deletions cyclops/data/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin

from cyclops.data.clean import dropna_rows
from cyclops.data.constants import ALL, FIRST, LAST, MEAN, MEDIAN
Expand All @@ -30,7 +31,7 @@
TIMESTEP = "timestep"


class Aggregator:
class Aggregator(TransformerMixin): # type: ignore
"""Equal-spaced aggregation, or binning, of time-series data.

Computing aggregation metadata is expensive and should be done sparingly.
Expand All @@ -50,13 +51,20 @@
timestep_size: float
Time in hours for a single timestep, or bin.
window_duration: float or None
Time duration to consider after the start of a timestep.
Time in hours for the aggregation window. If None, the latest timestamp
for each time_by group is used as the window stop time.
window_start_time: pd.DataFrame or None
An optionally provided window start time for each time_by group.
window_stop_time: pd.DataFrame or None
An optionally provided window stop time for each time_by group.
agg_meta_for: list of str or None
Columns for which to compute aggregation metadata.
window_times: pd.DataFrame or None
The start/stop time windows used to aggregate the data.
imputer: AggregatedImputer or None
An imputer to perform aggregation.
num_timesteps: int or None
The number of timesteps in the aggregation window.

"""

Expand All @@ -68,6 +76,8 @@
agg_by: Union[str, List[str]],
timestep_size: Optional[int] = None,
window_duration: Optional[int] = None,
window_start_time: Optional[pd.DataFrame] = None,
window_stop_time: Optional[pd.DataFrame] = None,
imputer: Optional[AggregatedImputer] = None,
agg_meta_for: Optional[List[str]] = None,
):
Expand All @@ -81,6 +91,8 @@
self.agg_meta_for = to_list_optional(agg_meta_for)
self.timestep_size = timestep_size
self.window_duration = window_duration
self.window_start_time = window_start_time
self.window_stop_time = window_stop_time
self.window_times = pd.DataFrame() # Calculated when given the data
self.imputer = imputer
# Parameter checking
Expand All @@ -92,8 +104,13 @@
)
if window_duration is not None and timestep_size is not None:
divided = window_duration / timestep_size
self.num_timesteps = int(divided)
if divided != int(divided):
raise ValueError("Window duration must be divisible by bucket size.")
elif timestep_size is not None:
self.num_timesteps = None # type: ignore
else:
self.num_timesteps = 1

Check warning on line 113 in cyclops/data/aggregate.py

View check run for this annotation

Codecov / codecov/patch

cyclops/data/aggregate.py#L113

Added line #L113 was not covered by tests

def _process_aggfuncs(
self,
Expand Down Expand Up @@ -315,19 +332,13 @@
def _compute_window_times(
self,
data: pd.DataFrame,
window_start_time: Optional[pd.DataFrame] = None,
window_stop_time: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
"""Compute the start/stop timestamps for each time_by window.

Parameters
----------
data: pandas.DataFrame
Data before aggregation.
window_start_time: pd.DataFrame, optional
An optionally provided window start time.
window_stop_time: pd.DataFrame, optional
An optionally provided window stop time.

Returns
-------
Expand All @@ -338,13 +349,13 @@
# Compute window start time
window_start_time = self._compute_window_start(
data,
window_start_time=window_start_time,
window_start_time=self.window_start_time,
)
# Compute window stop time
window_stop_time = self._compute_window_stop(
data,
window_start_time,
window_stop_time=window_stop_time,
window_stop_time=self.window_stop_time,
)
# Combine and compute additional information
window_start_time = window_start_time.rename(
Expand Down Expand Up @@ -485,67 +496,6 @@

return aggregated.set_index(self.agg_by + [TIMESTEP])

@time_function
def __call__(
self,
data: pd.DataFrame,
window_start_time: Optional[pd.DataFrame] = None,
window_stop_time: Optional[pd.DataFrame] = None,
include_timestep_start: bool = True,
) -> pd.DataFrame:
"""Aggregate.

The window start and stop times can be used to cut short the timeseries.

By default, the start time of a time_by group will be the earliest
recorded timestamp in said group. Otherwise, a window_start_time
can be provided by the user to override this default.

The end time of a time_by group works similarly, but with the additional
option of specifying a window_duration.

Parameters
----------
data: pandas.DataFrame
Input data.
window_start_time: pd.DataFrame, optional
An optionally provided window start time.
window_stop_time: pd.DataFrame, optional
An optionally provided window stop time. This cannot be provided if
window_duration was set.
include_timestep_start: bool, default = True
Whether to include the window start timestamps for each timestep.

Returns
-------
pandas.DataFrame
The aggregated data.

"""
# Parameter checking
if not isinstance(data, pd.DataFrame):
raise ValueError("Data to aggregate must be a pandas.DataFrame.")
has_columns(
data,
list(set([self.timestamp_col] + self.time_by + self.agg_by)),
raise_error=True,
)
if has_columns(data, TIMESTEP):
raise ValueError(f"Input data cannot have a column called {TIMESTEP}.")
# Ensure the timestamp column is a timestamp. Drop null times (NaT).
is_timestamp_series(data[self.timestamp_col], raise_error=True)
data = dropna_rows(data, self.timestamp_col)
# Compute start/stop timestamps
self.window_times = self._compute_window_times(
data,
window_start_time=window_start_time,
window_stop_time=window_stop_time,
)
# Restrict the data according to the start/stop
data = self._restrict_by_timestamp(data)

return self._aggregate(data, include_timestep_start=include_timestep_start)

@time_function
def vectorize(self, aggregated: pd.DataFrame) -> Vectorized:
"""Vectorize aggregated data.
Expand Down Expand Up @@ -602,26 +552,41 @@
axis_names=["aggfuncs"] + self.agg_by + [TIMESTEP],
)

def aggregate_values(
def fit(
self,
data: pd.DataFrame,
window_start_time: Optional[pd.DataFrame] = None,
window_stop_time: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
"""Aggregate temporal values.
) -> None:
"""Fit the aggregator.

Parameters
----------
data: pandas.DataFrame
Input data.

"""
# Parameter checking
if not isinstance(data, pd.DataFrame):
raise ValueError("Data to aggregate must be a DataFrame.")

Check warning on line 569 in cyclops/data/aggregate.py

View check run for this annotation

Codecov / codecov/patch

cyclops/data/aggregate.py#L569

Added line #L569 was not covered by tests
self.window_times = self._compute_window_times(
data,
)

The temporal values are restricted by start/stop and then aggregated.
No timestep is created.
def transform(
self,
data: pd.DataFrame,
y: None = None,
include_timestep_start: bool = True,
) -> pd.DataFrame:
"""Transform the data by aggregating.

Parameters
----------
data: pandas.DataFrame
Input data.
window_start_time: pd.DataFrame, optional
An optionally provided window start time.
window_stop_time: pd.DataFrame, optional
An optionally provided window stop time. This cannot be provided if
window_duration was set.
y: None
Placeholder for sklearn compatibility.
include_timestep_start: bool, default = True
Whether to include the window start timestamps for each timestep.

Returns
-------
Expand All @@ -634,19 +599,42 @@
list(set([self.timestamp_col] + self.time_by + self.agg_by)),
raise_error=True,
)
if has_columns(data, TIMESTEP):
raise ValueError(f"Input data cannot have a column called {TIMESTEP}.")

Check warning on line 603 in cyclops/data/aggregate.py

View check run for this annotation

Codecov / codecov/patch

cyclops/data/aggregate.py#L603

Added line #L603 was not covered by tests
# Ensure the timestamp column is a timestamp. Drop null times (NaT).
is_timestamp_series(data[self.timestamp_col], raise_error=True)
data = dropna_rows(data, self.timestamp_col)
self.window_times = self._compute_window_times(
data,
window_start_time=window_start_time,
window_stop_time=window_stop_time,
)
# Restrict the data according to the start/stop
data = self._restrict_by_timestamp(data)
grouped = data.groupby(self.agg_by, sort=False)

return grouped.agg(self.aggfuncs)
if self.num_timesteps == 1:
return grouped.agg(self.aggfuncs)

Check warning on line 612 in cyclops/data/aggregate.py

View check run for this annotation

Codecov / codecov/patch

cyclops/data/aggregate.py#L612

Added line #L612 was not covered by tests
if self.num_timesteps is None or self.num_timesteps > 1:
return self._aggregate(data, include_timestep_start=include_timestep_start)

raise ValueError("num_timesteps must be greater than 0.")

Check warning on line 616 in cyclops/data/aggregate.py

View check run for this annotation

Codecov / codecov/patch

cyclops/data/aggregate.py#L616

Added line #L616 was not covered by tests

def fit_transform(
self,
data: pd.DataFrame,
) -> pd.DataFrame:
"""Fit the aggregator and transform the data by aggregating.

Parameters
----------
data: pandas.DataFrame
Input data.

Returns
-------
pandas.DataFrame
The aggregated data.

"""
self.fit(data)

return self.transform(data)


def tabular_as_aggregated(
Expand Down
8 changes: 5 additions & 3 deletions docs/source/tutorials/mimiciv/mortality_prediction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
" patients = querier.patients()\n",
" encounters = querier.mimiciv_hosp.admissions()\n",
" drop_op = qo.Drop(\n",
" [\"insurance\", \"language\", \"marital_status\", \"edregtime\", \"edouttime\"],\n",
" [\"language\", \"marital_status\", \"edregtime\", \"edouttime\"],\n",
" )\n",
" encounters = encounters.ops(drop_op)\n",
" patient_encounters = patients.join(encounters, on=\"subject_id\")\n",
Expand All @@ -168,6 +168,8 @@
" \"gender\",\n",
" \"anchor_year_difference\",\n",
" \"admission_location\",\n",
" \"admission_type\",\n",
" \"insurance\",\n",
" \"hospital_expire_flag\",\n",
" ]\n",
" ]\n",
Expand Down Expand Up @@ -240,6 +242,7 @@
" \"valuenum\": \"mean\",\n",
" },\n",
" window_duration=M * 24,\n",
" window_start_time=start_timestamps,\n",
" timestamp_col=\"charttime\",\n",
" time_by=\"hadm_id\",\n",
" agg_by=[\"hadm_id\", \"label\"],\n",
Expand All @@ -250,9 +253,8 @@
" labevents_batch,\n",
" patient_encounters,\n",
" )\n",
" means = mean_aggregator.aggregate_values(\n",
" means = mean_aggregator.fit_transform(\n",
" labevents_batch,\n",
" window_start_time=start_timestamps,\n",
" )\n",
" means = means.reset_index()\n",
" means = means.pivot(index=\"hadm_id\", columns=\"label\", values=\"valuenum\")\n",
Expand Down
20 changes: 10 additions & 10 deletions tests/cyclops/data/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_aggregate_events(
timestep_size=1,
agg_meta_for=EVENT_VALUE,
)
res = aggregator(data)
res = aggregator.fit_transform(data)

assert res.index.names == [ENCOUNTER_ID, EVENT_NAME, TIMESTEP]
assert res.loc[(2, "eventA", 1)][EVENT_VALUE] == 19
Expand All @@ -116,8 +116,8 @@ def test_aggregate_window_duration(
timestep_size=1,
window_duration=12,
)
res = aggregator.fit_transform(data)

res = aggregator(data)
res = res.reset_index()
assert (res[TIMESTEP] < 2).all()

Expand All @@ -134,13 +134,10 @@ def test_aggregate_start_stop_windows(
time_by=ENCOUNTER_ID,
agg_by=[ENCOUNTER_ID, EVENT_NAME],
timestep_size=1,
)

res = aggregator(
data,
window_start_time=window_start_time,
window_stop_time=window_stop_time,
)
res = aggregator.fit_transform(data)

assert res.loc[(2, "eventA", 0)][START_TIMESTEP] == DATE2

Expand All @@ -154,9 +151,10 @@ def test_aggregate_start_stop_windows(
agg_by=[ENCOUNTER_ID, EVENT_NAME],
timestep_size=1,
window_duration=10,
window_stop_time=window_stop_time,
)
try:
res = aggregator(data, window_stop_time=window_stop_time)
res = aggregator.fit_transform(data)
raise ValueError(
"""Should have raised an error that window_duration cannot be set when
window_stop_time is specified.""",
Expand Down Expand Up @@ -190,7 +188,9 @@ def test_aggregate_strings(
window_duration=20,
)

assert aggregator_str(data).equals(aggregator_fn(data))
assert aggregator_str.fit_transform(data).equals(
aggregator_fn.fit_transform(data),
)

with contextlib.suppress(ValueError):
aggregator_str = Aggregator(
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_aggregate_multiple(
window_duration=20,
)

res = aggregator(data)
res = aggregator.fit_transform(data)

res = res.reset_index()
assert res["event_value2"].equals(res[EVENT_VALUE] * 2)
Expand Down Expand Up @@ -318,7 +318,7 @@ def test_vectorization(
window_duration=15,
)

aggregated = aggregator(data)
aggregated = aggregator.fit_transform(data)

vectorized_obj = aggregator.vectorize(aggregated)
vectorized, indexes = vectorized_obj.data, vectorized_obj.indexes
Expand Down
Loading