# Decompensation Prediction

This notebook showcases decompensation prediction on the [MIMIC-IV](https://physionet.org/content/mimiciv/2.0) dataset using CyclOps. The task is formulated as binary classification: every 24 hours, predict whether the patient will deteriorate, i.e., die within the next 3 days.

In [None]:
import os

import cycquery.ops as qo
import pandas as pd
from cycquery import MIMICIVQuerier

from cyclops.process.aggregate import RESTRICT_TIMESTAMP, Aggregator
from cyclops.process.clean import normalize_names
from cyclops.process.feature.feature import TemporalFeatures
from cyclops.utils.common import add_years_approximate

In [None]:
# Column name of the boolean outcome label (set from the hospital expire flag).
OUTCOME_DEATH: str = "outcome_death"

# Query

In [None]:
# Connection settings come from the standard libpq environment variables when
# they are set, falling back to local-development placeholders otherwise.
# Keep real credentials out of the notebook (export PGPASSWORD instead).
querier = MIMICIVQuerier(
    dbms="postgresql",
    port=int(os.environ.get("PGPORT", "5432")),
    host=os.environ.get("PGHOST", "localhost"),
    database=os.environ.get("PGDATABASE", "mimiciv-2.0"),
    user=os.environ.get("PGUSER", "postgres"),
    password=os.environ.get("PGPASSWORD", "pwd"),
)

## Patient encounters

In [None]:
patients = querier.patients()
# Admission columns that are not needed for this task.
drop_op = qo.Drop(
    ["insurance", "language", "marital_status", "edregtime", "edouttime"],
)
encounters = querier.mimiciv_hosp.admissions().ops(drop_op)
# Attach patient-level fields to every admission, then execute the query.
patient_encounters = patients.join(encounters, on="subject_id").run()

In [None]:
# Age at admission. `anchor_age` is the patient's age in `anchor_year`, so the
# age at admission is `anchor_age` plus the years elapsed between `anchor_year`
# and the admission year. Both years must be in the same calendar. This runs
# before `admittime` is shifted back below, and assumes `anchor_year` is still
# the de-identified (shifted) year. NOTE(review): confirm in cycquery's
# `patients()`. Adding `anchor_year_difference` to only one side of the
# subtraction would mix shifted and unshifted calendars.
patient_encounters["age"] = patient_encounters["anchor_age"] + (
    patient_encounters["admittime"].dt.year - patient_encounters["anchor_year"]
)

# Reverse the de-identification date shift on the encounter timestamps.
for col in ["admittime", "dischtime", "deathtime"]:
    patient_encounters[col] = add_years_approximate(
        patient_encounters[col],
        patient_encounters["anchor_year_difference"],
    )

# Keep only the columns of interest for this task.
patient_encounters = patient_encounters[
    [
        "hadm_id",
        "admittime",
        "dischtime",
        "deathtime",
        "anchor_age",
        "age",
        "gender",
        "anchor_year_difference",
        "admission_location",
        "hospital_expire_flag",
    ]
]

Create the death indicator.

Hospital expire flag (`hospital_expire_flag`):
 - 1 - Died in hospital
 - 0 - Survived to hospital discharge

In [None]:
# Drop encounters flagged as in-hospital deaths that have no death timestamp
# (presumably because the outcome then cannot be placed in time).
invalid = (patient_encounters["hospital_expire_flag"] == 1) & (
    patient_encounters["deathtime"].isna()
)
# `.copy()` detaches the filtered frame from its parent, so adding the outcome
# column below does not trigger pandas' SettingWithCopyWarning.
patient_encounters = patient_encounters.loc[~invalid].copy()

# After the filter above, every row with flag == 1 has a deathtime, so this is
# (died in hospital) & (death timestamp is defined).
patient_encounters[OUTCOME_DEATH] = patient_encounters["hospital_expire_flag"] == 1

In [None]:
# Outcome prevalence: the fraction of encounters ending in in-hospital death.
# The outcome column is boolean, so its mean is the share of True values.
patient_encounters[OUTCOME_DEATH].mean()

## Events

In [None]:
# Fetch lab events indexed by admission id; with batch_mode=True the result is
# iterated batch by batch in the aggregation cell below.
labevents = querier.labevents().run(
    batch_mode=True,
    index_col="hadm_id",
)

# Preprocess

This stage depends only on the query outputs (`patient_encounters` and `labevents`), so it can be rerun without re-querying the database.

In [None]:
def process_labevents(labevents, patient_encounters):
    """Prepare one batch of lab events for aggregation.

    Parameters
    ----------
    labevents : pandas.DataFrame
        Lab events with ``hadm_id`` (a column or index level; the query above
        uses ``index_col="hadm_id"``), ``charttime``, ``label`` and
        ``category``.
    patient_encounters : pandas.DataFrame
        Encounters with ``hadm_id`` and ``anchor_year_difference`` columns.

    Returns
    -------
    pandas.DataFrame
        Inner join of the two on ``hadm_id``. ``charttime`` is shifted by
        ``anchor_year_difference`` (that column is then dropped), and
        ``label`` / ``category`` are passed through ``normalize_names``.
    """
    year_shifts = patient_encounters[["hadm_id", "anchor_year_difference"]]
    # Default inner join: only lab events of known encounters are kept.
    merged = pd.merge(year_shifts, labevents, on="hadm_id")

    # Reverse the de-identification date shift on the chart timestamps.
    merged["charttime"] = add_years_approximate(
        merged["charttime"],
        merged["anchor_year_difference"],
    )
    merged = merged.drop("anchor_year_difference", axis=1)

    for name_col in ("label", "category"):
        merged[name_col] = normalize_names(merged[name_col])

    return merged

In [None]:
# Per-encounter window start: the (date-shifted) admission time, indexed by
# hadm_id and renamed to RESTRICT_TIMESTAMP, the column name passed to
# `aggregate(window_start_time=...)` below.
start_timestamps = patient_encounters.set_index("hadm_id")[["admittime"]].rename(
    columns={"admittime": RESTRICT_TIMESTAMP},
)

aggregator = Aggregator(
    aggfuncs={"valuenum": "mean"},  # mean lab value within each time step
    timestamp_col="charttime",
    time_by="hadm_id",
    agg_by=["hadm_id", "label"],
    # Presumably hours: 24 h steps over a 96 h window. Confirm the units.
    window_duration=96,
    timestep_size=24,
)

In [None]:
# Demonstrate the pipeline on the first batch of lab events only; iterate over
# `labevents` instead to process every batch.
labevents_batch = next(iter(labevents), None)
if labevents_batch is None:
    raise ValueError("The labevents query returned no batches.")

labevents_batch = process_labevents(labevents_batch, patient_encounters)
# Aggregate: discard the index left by the merge before building features.
labevents_batch = labevents_batch.reset_index(drop=True)
temporal_features = TemporalFeatures(
    labevents_batch,
    features="valuenum",
    by=["hadm_id", "label"],
    timestamp_col="charttime",
    aggregator=aggregator,
)
aggregated = temporal_features.aggregate(window_start_time=start_timestamps)
vectorized = aggregator.vectorize(aggregated)
vectorized

In [None]:
# Dimensions of the vectorized lab-feature array for the demo batch.
# NOTE(review): axis order is not shown here (presumably encounters x labels x
# timesteps). Confirm against Aggregator.vectorize.
vectorized.data.shape