Refactor aggregator use fit, transform methods (#526)
amrit110 committed Dec 8, 2023
1 parent 4cd2eb2 commit 033de88
Showing 3 changed files with 92 additions and 102 deletions.
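In short: the call-style entry points (`__call__`, `aggregate_values`) are replaced with the scikit-learn `fit`/`transform` convention, and the per-group window bounds move from call arguments to constructor arguments. A minimal before/after sketch; the column names and values below are illustrative, not taken from the diff:

    import pandas as pd

    from cyclops.data.aggregate import Aggregator

    # Tiny illustrative frame; column names are assumptions.
    data = pd.DataFrame(
        {
            "encounter_id": [1, 1, 1],
            "event_name": ["eventA", "eventA", "eventA"],
            "event_value": [10.0, 20.0, 30.0],
            "event_timestamp": pd.to_datetime(
                ["2020-01-01 00:30", "2020-01-01 01:30", "2020-01-01 02:30"],
            ),
        },
    )

    aggregator = Aggregator(
        aggfuncs={"event_value": "mean"},  # column -> aggregation function
        timestamp_col="event_timestamp",
        time_by="encounter_id",
        agg_by=["encounter_id", "event_name"],
        timestep_size=1,  # 1 h bins
        window_duration=24,  # 24 h window
    )

    # Before this commit:
    #   aggregated = aggregator(data, window_start_time=..., window_stop_time=...)
    # After:
    aggregated = aggregator.fit_transform(data)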
166 changes: 77 additions & 89 deletions cyclops/data/aggregate.py
@@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin

from cyclops.data.clean import dropna_rows
from cyclops.data.constants import ALL, FIRST, LAST, MEAN, MEDIAN
@@ -30,7 +31,7 @@
TIMESTEP = "timestep"


class Aggregator:
class Aggregator(TransformerMixin): # type: ignore
"""Equal-spaced aggregation, or binning, of time-series data.
Computing aggregation metadata is expensive and should be done sparingly.
@@ -50,13 +51,20 @@ class Aggregator:
timestep_size: float
Time in hours for a single timestep, or bin.
window_duration: float or None
Time duration to consider after the start of a timestep.
Time in hours for the aggregation window. If None, the latest timestamp
for each time_by group is used as the window stop time.
window_start_time: pd.DataFrame or None
An optionally provided window start time for each time_by group.
window_stop_time: pd.DataFrame or None
An optionally provided window stop time for each time_by group.
agg_meta_for: list of str or None
Columns for which to compute aggregation metadata.
window_times: pd.DataFrame or None
The start/stop time windows used to aggregate the data.
imputer: AggregatedImputer or None
An imputer to perform aggregation.
num_timesteps: int or None
The number of timesteps in the aggregation window.
"""

@@ -68,6 +76,8 @@ def __init__(
agg_by: Union[str, List[str]],
timestep_size: Optional[int] = None,
window_duration: Optional[int] = None,
window_start_time: Optional[pd.DataFrame] = None,
window_stop_time: Optional[pd.DataFrame] = None,
imputer: Optional[AggregatedImputer] = None,
agg_meta_for: Optional[List[str]] = None,
):
@@ -81,6 +91,8 @@ def __init__(
self.agg_meta_for = to_list_optional(agg_meta_for)
self.timestep_size = timestep_size
self.window_duration = window_duration
self.window_start_time = window_start_time
self.window_stop_time = window_stop_time
self.window_times = pd.DataFrame() # Calculated when given the data
self.imputer = imputer
# Parameter checking
@@ -92,8 +104,13 @@
)
if window_duration is not None and timestep_size is not None:
divided = window_duration / timestep_size
self.num_timesteps = int(divided)
if divided != int(divided):
raise ValueError("Window duration be divisible by bucket size.")
elif timestep_size is not None:
self.num_timesteps = None # type: ignore
else:
self.num_timesteps = 1
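The divisibility check above means the window must split into a whole number of equal bins. A worked example under the docstring's hours convention:

    window_duration, timestep_size = 24, 1
    divided = window_duration / timestep_size  # 24.0 -> 24 timesteps
    assert divided == int(divided)  # divisible, so no ValueError

    # window_duration=24, timestep_size=7 -> 3.43 bins -> ValueError.
    # timestep_size set, window_duration=None -> num_timesteps is None (open-ended).
    # No timestep_size -> num_timesteps == 1 (a single bin per group).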

def _process_aggfuncs(
self,
@@ -315,19 +332,13 @@ def _compute_window_stop(
def _compute_window_times(
self,
data: pd.DataFrame,
window_start_time: Optional[pd.DataFrame] = None,
window_stop_time: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
"""Compute the start/stop timestamps for each time_by window.
Parameters
----------
data: pandas.DataFrame
Data before aggregation.
window_start_time: pd.DataFrame, optional
An optionally provided window start time.
window_stop_time: pd.DataFrame, optional
An optionally provided window stop time.
Returns
-------
@@ -338,13 +349,13 @@ def _compute_window_times(
# Compute window start time
window_start_time = self._compute_window_start(
data,
window_start_time=window_start_time,
window_start_time=self.window_start_time,
)
# Compute window stop time
window_stop_time = self._compute_window_stop(
data,
window_start_time,
window_stop_time=window_stop_time,
window_stop_time=self.window_stop_time,
)
# Combine and compute additional information
window_start_time = window_start_time.rename(
@@ -485,67 +496,6 @@ def _aggregate(

return aggregated.set_index(self.agg_by + [TIMESTEP])

@time_function
def __call__(
self,
data: pd.DataFrame,
window_start_time: Optional[pd.DataFrame] = None,
window_stop_time: Optional[pd.DataFrame] = None,
include_timestep_start: bool = True,
) -> pd.DataFrame:
"""Aggregate.
The window start and stop times can be used to truncate the time series.
By default, the start time of a time_by group will be the earliest
recorded timestamp in said group. Otherwise, a window_start_time
can be provided by the user to override this default.
The end time of a time_by group works similarly, but with the additional
option of specifying a window_duration.
Parameters
----------
data: pandas.DataFrame
Input data.
window_start_time: pd.DataFrame, optional
An optionally provided window start time.
window_stop_time: pd.DataFrame, optional
An optionally provided window stop time. This cannot be provided if
window_duration was set.
include_timestep_start: bool, default = True
Whether to include the window start timestamps for each timestep.
Returns
-------
pandas.DataFrame
The aggregated data.
"""
# Parameter checking
if not isinstance(data, pd.DataFrame):
raise ValueError("Data to aggregate must be a pandas.DataFrame.")
has_columns(
data,
list(set([self.timestamp_col] + self.time_by + self.agg_by)),
raise_error=True,
)
if has_columns(data, TIMESTEP):
raise ValueError(f"Input data cannot have a column called {TIMESTEP}.")
# Ensure the timestamp column is a timestamp. Drop null times (NaT).
is_timestamp_series(data[self.timestamp_col], raise_error=True)
data = dropna_rows(data, self.timestamp_col)
# Compute start/stop timestamps
self.window_times = self._compute_window_times(
data,
window_start_time=window_start_time,
window_stop_time=window_stop_time,
)
# Restrict the data according to the start/stop
data = self._restrict_by_timestamp(data)

return self._aggregate(data, include_timestep_start=include_timestep_start)

@time_function
def vectorize(self, aggregated: pd.DataFrame) -> Vectorized:
"""Vectorize aggregated data.
@@ -602,26 +552,41 @@ def vectorize(self, aggregated: pd.DataFrame) -> Vectorized:
axis_names=["aggfuncs"] + self.agg_by + [TIMESTEP],
)
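Vectorization itself is unchanged and still consumes the aggregated frame; this mirrors the call pattern in tests/cyclops/data/test_aggregate.py, reusing `aggregator` and `data` from the earlier sketch:

    aggregated = aggregator.fit_transform(data)
    vectorized_obj = aggregator.vectorize(aggregated)
    vectorized, indexes = vectorized_obj.data, vectorized_obj.indexes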

def aggregate_values(
def fit(
self,
data: pd.DataFrame,
window_start_time: Optional[pd.DataFrame] = None,
window_stop_time: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
"""Aggregate temporal values.
) -> None:
"""Fit the aggregator.
Parameters
----------
data: pandas.DataFrame
Input data.
"""
# Parameter checking
if not isinstance(data, pd.DataFrame):
raise ValueError("Data to aggregate must be a DataFrame.")
self.window_times = self._compute_window_times(
data,
)

The temporal values are restricted by start/stop and then aggregated.
No timestep is created.
def transform(
self,
data: pd.DataFrame,
y: None = None,
include_timestep_start: bool = True,
) -> pd.DataFrame:
"""Transform the data by aggregating.
Parameters
----------
data: pandas.DataFrame
Input data.
window_start_time: pd.DataFrame, optional
An optionally provided window start time.
window_stop_time: pd.DataFrame, optional
An optionally provided window stop time. This cannot be provided if
window_duration was set.
y: None
Placeholder for sklearn compatibility.
include_timestep_start: bool, default = True
Whether to include the window start timestamps for each timestep.
Returns
-------
@@ -634,19 +599,42 @@ def aggregate_values(
list(set([self.timestamp_col] + self.time_by + self.agg_by)),
raise_error=True,
)
if has_columns(data, TIMESTEP):
raise ValueError(f"Input data cannot have a column called {TIMESTEP}.")
# Ensure the timestamp column is a timestamp. Drop null times (NaT).
is_timestamp_series(data[self.timestamp_col], raise_error=True)
data = dropna_rows(data, self.timestamp_col)
self.window_times = self._compute_window_times(
data,
window_start_time=window_start_time,
window_stop_time=window_stop_time,
)
# Restrict the data according to the start/stop
data = self._restrict_by_timestamp(data)
grouped = data.groupby(self.agg_by, sort=False)

return grouped.agg(self.aggfuncs)
if self.num_timesteps == 1:
return grouped.agg(self.aggfuncs)
if self.num_timesteps is None or self.num_timesteps > 1:
return self._aggregate(data, include_timestep_start=include_timestep_start)

raise ValueError("num_timesteps must be greater than 0.")

def fit_transform(
self,
data: pd.DataFrame,
) -> pd.DataFrame:
"""Fit the aggregator and transform the data by aggregating.
Parameters
----------
data: pandas.DataFrame
Input data.
Returns
-------
pandas.DataFrame
The aggregated data.
"""
self.fit(data)

return self.transform(data)
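Because `fit_transform` simply chains the two steps, fitting separately also works when the window metadata is needed before aggregating:

    aggregator.fit(data)  # populates aggregator.window_times
    aggregated = aggregator.transform(data)  # same result as fit_transform(data)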


def tabular_as_aggregated(
8 changes: 5 additions & 3 deletions docs/source/tutorials/mimiciv/mortality_prediction.ipynb
@@ -141,7 +141,7 @@
" patients = querier.patients()\n",
" encounters = querier.mimiciv_hosp.admissions()\n",
" drop_op = qo.Drop(\n",
" [\"insurance\", \"language\", \"marital_status\", \"edregtime\", \"edouttime\"],\n",
" [\"language\", \"marital_status\", \"edregtime\", \"edouttime\"],\n",
" )\n",
" encounters = encounters.ops(drop_op)\n",
" patient_encounters = patients.join(encounters, on=\"subject_id\")\n",
@@ -168,6 +168,8 @@
" \"gender\",\n",
" \"anchor_year_difference\",\n",
" \"admission_location\",\n",
" \"admission_type\",\n",
" \"insurance\",\n",
" \"hospital_expire_flag\",\n",
" ]\n",
" ]\n",
@@ -240,6 +242,7 @@
" \"valuenum\": \"mean\",\n",
" },\n",
" window_duration=M * 24,\n",
" window_start_time=start_timestamps,\n",
" timestamp_col=\"charttime\",\n",
" time_by=\"hadm_id\",\n",
" agg_by=[\"hadm_id\", \"label\"],\n",
@@ -250,9 +253,8 @@
" labevents_batch,\n",
" patient_encounters,\n",
" )\n",
" means = mean_aggregator.aggregate_values(\n",
" means = mean_aggregator.fit_transform(\n",
" labevents_batch,\n",
" window_start_time=start_timestamps,\n",
" )\n",
" means = means.reset_index()\n",
" means = means.pivot(index=\"hadm_id\", columns=\"label\", values=\"valuenum\")\n",
20 changes: 10 additions & 10 deletions tests/cyclops/data/test_aggregate.py
@@ -91,7 +91,7 @@ def test_aggregate_events(
timestep_size=1,
agg_meta_for=EVENT_VALUE,
)
res = aggregator(data)
res = aggregator.fit_transform(data)

assert res.index.names == [ENCOUNTER_ID, EVENT_NAME, TIMESTEP]
assert res.loc[(2, "eventA", 1)][EVENT_VALUE] == 19
@@ -116,8 +116,8 @@ def test_aggregate_window_duration(
timestep_size=1,
window_duration=12,
)
res = aggregator.fit_transform(data)

res = aggregator(data)
res = res.reset_index()
assert (res[TIMESTEP] < 2).all()

@@ -134,13 +134,10 @@
time_by=ENCOUNTER_ID,
agg_by=[ENCOUNTER_ID, EVENT_NAME],
timestep_size=1,
)

res = aggregator(
data,
window_start_time=window_start_time,
window_stop_time=window_stop_time,
)
res = aggregator.fit_transform(data)

assert res.loc[(2, "eventA", 0)][START_TIMESTEP] == DATE2

@@ -154,9 +151,10 @@ def test_aggregate_start_stop_windows(
agg_by=[ENCOUNTER_ID, EVENT_NAME],
timestep_size=1,
window_duration=10,
window_stop_time=window_stop_time,
)
try:
res = aggregator(data, window_stop_time=window_stop_time)
res = aggregator.fit_transform(data)
raise ValueError(
"""Should have raised an error that window_duration cannot be set when
window_stop_time is specified.""",
@@ -190,7 +188,9 @@ def test_aggregate_strings(
window_duration=20,
)

assert aggregator_str(data).equals(aggregator_fn(data))
assert aggregator_str.fit_transform(data).equals(
aggregator_fn.fit_transform(data),
)

with contextlib.suppress(ValueError):
aggregator_str = Aggregator(
@@ -223,7 +223,7 @@ def test_aggregate_multiple(
window_duration=20,
)

res = aggregator(data)
res = aggregator.fit_transform(data)

res = res.reset_index()
assert res["event_value2"].equals(res[EVENT_VALUE] * 2)
@@ -318,7 +318,7 @@ def test_vectorization(
window_duration=15,
)

aggregated = aggregator(data)
aggregated = aggregator.fit_transform(data)

vectorized_obj = aggregator.vectorize(aggregated)
vectorized, indexes = vectorized_obj.data, vectorized_obj.indexes
