In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [12]:
# import data
df_cleaned = pd.read_csv("../data/MIMIC-ED/event_level_with_lac_wbc.csv")

# count NaNs
nan_counts = df_cleaned.isna().sum().sum()
print("Number of NaNs in the DataFrame:", nan_counts)

# analyze data types of each column
print(df_cleaned.dtypes)
df_cleaned.head()

KeyboardInterrupt: 

In [None]:
# locate the nans
nan_columns = df_cleaned.isna().sum()
nan_columns = nan_columns[nan_columns > 0]
print("Columns with NaNs and their counts:")
print(nan_columns)

Columns with NaNs and their counts:
pain_stay             171131
chiefcomplaint           101
icd_title              15935
hadm_id              5649620
sepsis_onset_time    4537570
dtype: int64


In [None]:
df_cleaned = df_cleaned.drop(columns=nan_columns.index)
print(df_cleaned.isna().sum().sum())

0


In [None]:
# group patients by "stay_id" and take 2000 random patients
patient_ids = df_cleaned["stay_id"].unique()
# make sure there are 1000 septic and 1000 nonseptic patients
#extract stay_ids with sepsis_dx_any == 1
septic_patient_ids = df_cleaned[df_cleaned["sepsis_dx_any"] == 1]["stay_id"].unique()
# count of septic patients
count_septic = len(septic_patient_ids)
print("Number of septic patients:", count_septic)
nonseptic_patient_ids = df_cleaned[df_cleaned["sepsis_dx_any"] == 0]["stay_id"].unique()
# take 2 times septic number of random patients from the nonseptic patients
random_nonseptic_patient_ids = np.random.choice(nonseptic_patient_ids, size=count_septic, replace=False)
random_patient_ids = np.concatenate([septic_patient_ids, random_nonseptic_patient_ids])
#shuffle the ids and filter df to only include these patients
np.random.shuffle(random_patient_ids)
df_small = df_cleaned[df_cleaned["stay_id"].isin(random_patient_ids)]
print("Number of patients in small df:", df_small["stay_id"].nunique())

Number of septic patients: 2339
Number of patients in small df: 4678


In [6]:
# drop all other labels
df_small_new = df_small.drop(columns=["sepsis_dx_any", "sepsis_dx", "sirs_count", "sirs_ge2", "is_sepsis_onset"])
count_sepsis = df_small_new["is_sepsis"].sum()
print("Number of sepsis cases in small df:", count_sepsis)

Number of sepsis cases in small df: 46702


In [8]:
df_small_new = df_small_new.drop(columns=["subject_id", "outtime", "disposition", 'temperature_stay', 'heartrate_stay', 'resprate_stay','o2sat_stay', 'sbp_stay', 'dbp_stay', 'acuity'])
print(df_small_new.columns)

Index(['stay_id', 'charttime', 'temperature', 'heartrate', 'resprate', 'o2sat',
       'sbp', 'dbp', 'pain', 'rhythm_flag', 'med_rn', 'gsn_rn', 'gsn',
       'is_antibiotic', 'ndc', 'etc_rn', 'etccode', 'hadm_id_x', 'intime',
       'race', 'is_white', 'is_black', 'is_asian', 'is_hispanic',
       'is_other_race', 'gender_F', 'gender_M', 'arrival_transport_AMBULANCE',
       'arrival_transport_HELICOPTER', 'arrival_transport_OTHER',
       'arrival_transport_UNKNOWN', 'arrival_transport_WALK IN', 'lactate',
       'wbc', 'is_sepsis'],
      dtype='object')


In [9]:
# drop columns with nan times
print(df_small_new.isna().sum())
df_small_new["time_since_adm"] = (pd.to_datetime(df_small_new["charttime"]) - pd.to_datetime(df_small_new["intime"])).dt.total_seconds() / 3600.0
df_small_new = df_small_new.drop(columns = ["intime", "charttime"])
print(df_small_new.columns)

stay_id                         0
charttime                       0
temperature                     0
heartrate                       0
resprate                        0
o2sat                           0
sbp                             0
dbp                             0
pain                            0
rhythm_flag                     0
med_rn                          0
gsn_rn                          0
gsn                             0
is_antibiotic                   0
ndc                             0
etc_rn                          0
etccode                         0
hadm_id_x                       0
intime                          0
race                            0
is_white                        0
is_black                        0
is_asian                        0
is_hispanic                     0
is_other_race                   0
gender_F                        0
gender_M                        0
arrival_transport_AMBULANCE     0
arrival_transport_HELICOPTER    0
arrival_transp

In [10]:
# get most common gsns for septic patients
most_common_gsns = df_small_new['gsn'].value_counts().nlargest(5).index.tolist()
print("Most common gsn values:", most_common_gsns)

# one hot encode gsn
for gsn in most_common_gsns:
    df_small_new[f'gsn_{gsn}'] = (df_small_new['gsn'] == gsn).astype(int)

Most common gsn values: [16599.0, 43952.0, 4490.0, 66419.0, 61716.0]


In [11]:
df_small_new = df_small_new.drop(columns=['race', 'gsn', 'gsn_rn', 'med_rn', 'is_antibiotic', 'ndc', 'etc_rn', 'etccode'])
df_small_new.columns

Index(['stay_id', 'temperature', 'heartrate', 'resprate', 'o2sat', 'sbp',
       'dbp', 'pain', 'rhythm_flag', 'hadm_id_x', 'is_white', 'is_black',
       'is_asian', 'is_hispanic', 'is_other_race', 'gender_F', 'gender_M',
       'arrival_transport_AMBULANCE', 'arrival_transport_HELICOPTER',
       'arrival_transport_OTHER', 'arrival_transport_UNKNOWN',
       'arrival_transport_WALK IN', 'lactate', 'wbc', 'is_sepsis',
       'time_since_adm', 'gsn_16599.0', 'gsn_43952.0', 'gsn_4490.0',
       'gsn_66419.0', 'gsn_61716.0'],
      dtype='object')

In [17]:
# change all columns to numeric
df_small_new_num = df_small_new.copy()
for col in df_small_new_num.columns:
    df_small_new_num[col] = pd.to_numeric(df_small_new_num[col], errors='coerce')

# change all booleans to int
bool_cols = df_small_new_num.select_dtypes(include=['bool']).columns
df_small_new_num[bool_cols] = df_small_new_num[bool_cols].astype(int)
print(df_small_new_num.dtypes)

stay_id                           int64
temperature                     float64
heartrate                       float64
resprate                        float64
o2sat                           float64
sbp                             float64
dbp                             float64
pain                            float64
rhythm_flag                       int64
hadm_id_x                       float64
is_white                          int64
is_black                          int64
is_asian                          int64
is_hispanic                       int64
is_other_race                     int64
gender_F                          int64
gender_M                          int64
arrival_transport_AMBULANCE       int64
arrival_transport_HELICOPTER      int64
arrival_transport_OTHER           int64
arrival_transport_UNKNOWN         int64
arrival_transport_WALK IN         int64
lactate                         float64
wbc                             float64
is_sepsis                         int64


In [18]:
# reorder columns to have is_sepsis at the end
cols = list(df_small_new_num.columns)
cols.remove('is_sepsis')
cols.append('is_sepsis')
df_small_new_num = df_small_new_num[cols]
df_small_new_num.columns

Index(['stay_id', 'temperature', 'heartrate', 'resprate', 'o2sat', 'sbp',
       'dbp', 'pain', 'rhythm_flag', 'hadm_id_x', 'is_white', 'is_black',
       'is_asian', 'is_hispanic', 'is_other_race', 'gender_F', 'gender_M',
       'arrival_transport_AMBULANCE', 'arrival_transport_HELICOPTER',
       'arrival_transport_OTHER', 'arrival_transport_UNKNOWN',
       'arrival_transport_WALK IN', 'lactate', 'wbc', 'time_since_adm',
       'gsn_16599.0', 'gsn_43952.0', 'gsn_4490.0', 'gsn_66419.0',
       'gsn_61716.0', 'is_sepsis'],
      dtype='object')

In [19]:
# drop columns that could cause data leakage
df_small_new_num = df_small_new_num.drop(columns=['hadm_id_x'])

In [20]:
df_small_new_num

Unnamed: 0,stay_id,temperature,heartrate,resprate,o2sat,sbp,dbp,pain,rhythm_flag,is_white,...,arrival_transport_WALK IN,lactate,wbc,time_since_adm,gsn_16599.0,gsn_43952.0,gsn_4490.0,gsn_66419.0,gsn_61716.0,is_sepsis
626,37459204,98.1,79.0,18.0,98.0,126.0,72.0,0.0,0,1,...,0,1.6,7.8,0.016667,1,0,0,0,0,0
627,37459204,102.9,92.0,18.0,94.0,123.0,53.0,0.0,0,1,...,0,1.6,7.8,0.083333,1,0,0,0,0,1
628,37459204,102.9,92.0,18.0,94.0,123.0,53.0,0.0,0,1,...,0,1.6,7.8,0.150000,0,0,0,0,0,1
629,37459204,104.2,83.0,22.0,93.0,111.0,48.0,0.0,0,1,...,0,3.2,7.8,0.450000,0,0,0,0,0,1
630,37459204,104.2,83.0,22.0,93.0,111.0,48.0,0.0,0,1,...,0,3.2,15.6,1.083333,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6136674,39316677,98.5,118.0,23.0,97.0,117.0,82.0,0.0,0,0,...,1,1.8,7.7,10.650000,0,0,0,0,0,1
6136675,39316677,98.5,118.0,23.0,97.0,117.0,82.0,0.0,0,0,...,1,1.8,7.7,11.033333,0,0,0,0,0,1
6136676,39316677,98.5,130.0,20.0,99.0,132.0,100.0,0.0,0,0,...,1,1.8,7.7,11.133333,0,0,0,0,0,1
6136677,39316677,98.0,124.0,19.0,100.0,121.0,73.0,0.0,0,0,...,1,1.8,7.7,11.516667,0,0,0,0,0,1


In [21]:
df_small_new_num.to_csv("../data/MIMIC-ED/event_level_training_data.csv", index=False)

# Survival Analysis

In [3]:

# import data
df_cleaned = pd.read_csv("../data/MIMIC-ED/event_level_with_lac_wbc.csv")

# count NaNs
nan_counts = df_cleaned.isna().sum().sum()
print("Number of NaNs in the DataFrame:", nan_counts)

# analyze data types of each column
print(df_cleaned.dtypes)
df_cleaned.head()

Number of NaNs in the DataFrame: 10374357
subject_id                        int64
stay_id                           int64
charttime                        object
temperature                     float64
heartrate                       float64
resprate                        float64
o2sat                           float64
sbp                             float64
dbp                             float64
pain                            float64
rhythm_flag                       int64
med_rn                          float64
gsn_rn                          float64
gsn                             float64
is_antibiotic                     int64
ndc                             float64
etc_rn                          float64
etccode                         float64
hadm_id_x                       float64
intime                           object
outtime                          object
race                             object
disposition                      object
temperature_stay                float6

Unnamed: 0,subject_id,stay_id,charttime,temperature,heartrate,resprate,o2sat,sbp,dbp,pain,...,lactate,hadm_id,wbc,sepsis_dx_any,sepsis_dx,sirs_count,sirs_ge2,sepsis_onset_time,is_sepsis_onset,is_sepsis
0,13238787,35341790,2110-01-11 01:49:00,98.4,77.0,16.0,100.0,149.0,104.0,8.0,...,1.6,,7.8,0,0,0,0,,0,0
1,15350437,39042378,2110-01-11 03:45:00,97.1,71.0,16.0,100.0,117.0,79.0,0.0,...,1.6,,7.8,0,0,0,0,,0,0
2,13238787,35341790,2110-01-11 04:02:00,98.0,78.0,18.0,99.0,138.0,92.0,0.0,...,1.6,,7.8,0,0,0,0,,0,0
3,13238787,35341790,2110-01-11 05:21:00,98.0,78.0,18.0,99.0,138.0,92.0,0.0,...,1.6,,7.8,0,0,0,0,,0,0
4,13238787,35341790,2110-01-11 05:21:00,98.0,78.0,18.0,99.0,138.0,92.0,0.0,...,1.6,,7.8,0,0,0,0,,0,0


In [15]:
# group patients by "stay_id" and take 2000 random patients
patient_ids = df_cleaned["stay_id"].unique()
# make sure there are 1000 septic and 1000 nonseptic patients
#extract stay_ids with sepsis_dx_any == 1
septic_patient_ids = df_cleaned[df_cleaned["sepsis_dx_any"] == 1]["stay_id"].unique()
# count of septic patients
count_septic = len(septic_patient_ids)
print("Number of septic patients:", count_septic)
nonseptic_patient_ids = df_cleaned[df_cleaned["sepsis_dx_any"] == 0]["stay_id"].unique()
# take 2 times septic number of random patients from the nonseptic patients
random_nonseptic_patient_ids = np.random.choice(nonseptic_patient_ids, size=count_septic, replace=False)
random_patient_ids = np.concatenate([septic_patient_ids, random_nonseptic_patient_ids])
#shuffle the ids and filter df to only include these patients
np.random.shuffle(random_patient_ids)
df_small = df_cleaned[df_cleaned["stay_id"].isin(random_patient_ids)]
print("Number of patients in small df:", df_small["stay_id"].nunique())

Number of septic patients: 2339
Number of patients in small df: 4678


In [16]:
print(df_small.isna().sum())
df_small["time_since_adm"] = (pd.to_datetime(df_small["charttime"]) - pd.to_datetime(df_small["intime"])).dt.total_seconds() / 3600.0
print(df_small.columns)

subject_id                          0
stay_id                             0
charttime                           0
temperature                         0
heartrate                           0
resprate                            0
o2sat                               0
sbp                                 0
dbp                                 0
pain                                0
rhythm_flag                         0
med_rn                              0
gsn_rn                              0
gsn                                 0
is_antibiotic                       0
ndc                                 0
etc_rn                              0
etccode                             0
hadm_id_x                           0
intime                              0
outtime                             0
race                                0
disposition                         0
temperature_stay                    0
heartrate_stay                      0
resprate_stay                       0
o2sat_stay  

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_small["time_since_adm"] = (pd.to_datetime(df_small["charttime"]) - pd.to_datetime(df_small["intime"])).dt.total_seconds() / 3600.0


In [17]:
df_small.isna().sum()

subject_id                          0
stay_id                             0
charttime                           0
temperature                         0
heartrate                           0
resprate                            0
o2sat                               0
sbp                                 0
dbp                                 0
pain                                0
rhythm_flag                         0
med_rn                              0
gsn_rn                              0
gsn                                 0
is_antibiotic                       0
ndc                                 0
etc_rn                              0
etccode                             0
hadm_id_x                           0
intime                              0
outtime                             0
race                                0
disposition                         0
temperature_stay                    0
heartrate_stay                      0
resprate_stay                       0
o2sat_stay  

In [18]:
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Optional, Tuple

# ----------------------------
# Config (adjust paths as needed)
# ----------------------------
OUTPUT_PATH_STATIC = "../data/MIMIC-ED/cox_static_random_train.csv"
OUTPUT_PATH_TVC    = "../data/MIMIC-ED/cox_timevarying_train.csv"
OUTPUT_PATH_STATIC_STACKED = "../data/MIMIC-ED/cox_static_landmark_stacked_train.csv"
RNG_SEED = 42

# New schema
ID_COLS   = ["subject_id", "hadm_id", 'stay_id', 'hadm_id_x']
TIME_COLS = ["charttime", "intime", "outtime", "sepsis_onset_time"]

# Labels / leak-prone columns to exclude from covariates
LABEL_COLS = [
    "sepsis_dx_any", "sepsis_dx",
    "sirs_count", "sirs_ge2",
    "is_sepsis_onset", "is_sepsis",
    # Keep both event indicators & dx out of covariates
]

# Survival target columns
DURATION_COL = "duration"  # hours
EVENT_COL    = "event"     # 1 if event observed, else 0

# ----------------------------
# Helpers
# ----------------------------
def _to_dt(s):
    """Parse to pandas datetime with coercion."""
    return pd.to_datetime(s, errors="coerce")

def hours_between(a, b):
    """Return positive hours (b - a) or NaN if any is NaT."""
    if pd.isna(a) or pd.isna(b):
        return np.nan
    return max(0.0, (b - a).total_seconds() / 3600.0)

def _feature_columns(df: pd.DataFrame) -> list[str]:
    """Numeric covariates only, excluding IDs/time/labels."""
    drop = set([c for c in ID_COLS if c in df.columns] +
               [c for c in TIME_COLS if c in df.columns] +
               [c for c in LABEL_COLS if c in df.columns])
    print("is in drop set?", "time_since_adm" in drop)
    num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
    return [c for c in num_cols if c not in drop]

# def _stay_event_info(g: pd.DataFrame):
#     """
#     Return (event_flag:int, event_time:Timestamp|NaT, censor_time:Timestamp|NaT).
#       - event_flag = int(g['is_sepsis'].max() >= 1)
#       - event_time = sepsis_onset_time if event_flag==1 else NaT
#       - censor_time = outtime if censored (and available) else last observed charttime
#     """
#     # Prefer a stable, stay-level interpretation
#     is_event = int((g.get("is_sepsis", pd.Series([0])).astype(int).max()) >= 1)

#     # Choose a representative row for stay-level timestamps if repeated
#     # Use first non-null across the group
#     ev_times = _to_dt(g.get("sepsis_onset_time"))
#     event_time = ev_times.dropna().min() if is_event == 1 else pd.NaT

#     out_times = _to_dt(g.get("outtime"))
#     out_time = out_times.dropna().min() if len(out_times.dropna()) else pd.NaT

#     last_obs = _to_dt(g["charttime"]).max()

#     if is_event == 1:
#         censor_time = pd.NaT  # not used for event stays
#     else:
#         censor_time = out_time if pd.notna(out_time) else last_obs

#     return is_event, event_time, censor_time
def _stay_event_info(g: pd.DataFrame) -> Tuple[int, Optional[pd.Timestamp], Optional[pd.Timestamp]]:
    """
    Return (event_flag, event_time, censor_time) where event_time/censor_time
    are either a pandas Timestamp or None.
    """
    is_event = int((g.get("is_sepsis", pd.Series([0])).astype(int).max()) >= 1)

    ev_times = _to_dt(g.get("sepsis_onset_time"))
    event_time = ev_times.dropna().min() if is_event == 1 else None
    if pd.isna(event_time):
        event_time = None

    out_times = _to_dt(g.get("outtime"))
    out_time = out_times.dropna().min() if len(out_times.dropna()) else None

    last_obs = _to_dt(g["charttime"]).max()
    if pd.isna(last_obs):
        last_obs = None

    if is_event == 1:
        censor_time = None
    else:
        censor_time = out_time if out_time is not None else last_obs

    return is_event, event_time, censor_time


# def _choose_random_baseline(g: pd.DataFrame, terminal_time: pd.Timestamp | pd.NaT, rng: np.random.Generator):
#     """
#     For event stays: choose a random row with charttime < event_time.
#     For censored stays: choose a random row with charttime < censor_time (to avoid zero duration).
#     Returns the chosen row (Series) or None if not possible.
#     """
#     ct = _to_dt(g["charttime"])
#     if pd.notna(terminal_time):
#         candidates = g.loc[ct < terminal_time]
#     else:
#         # If terminal_time is NaT (shouldn't happen for censored), fallback to any row
#         candidates = g

#     if candidates.empty:
#         return None
#     sel = candidates.iloc[rng.integers(0, len(candidates))]
#     return sel

def _choose_random_baseline(g: pd.DataFrame, terminal_time: Optional[pd.Timestamp], rng: np.random.Generator):
    """
    For event stays: choose a random row with charttime < event_time.
    For censored stays: choose a random row with charttime < censor_time.
    """
    ct = _to_dt(g["charttime"])
    if terminal_time is not None:
        candidates = g.loc[ct < terminal_time]
    else:
        # If terminal_time is missing, there is nothing to measure duration to.
        return None

    if candidates.empty:
        return None
    sel = candidates.iloc[rng.integers(0, len(candidates))]
    return sel


# ----------------------------
# Static Cox (one random row per stay)
# ----------------------------
def make_static_random_snapshot(
    df: pd.DataFrame,
    *,
    id_col: str = "stay_id",
    duration_col: str = DURATION_COL,
    event_col: str = EVENT_COL,
    seed: int = RNG_SEED,
) -> pd.DataFrame:
    """
    Leakage-safe CoxPH table using new schema:
      - Event stays: one random pre-event snapshot (charttime < sepsis_onset_time)
      - Censored stays: one random snapshot strictly before censor_time (outtime if present, else last charttime)
      - duration = hours from snapshot to event or censoring
      - event = 1 if event observed, else 0
    """
    assert id_col in df.columns, f"Missing id column: {id_col}"

    # Ensure datetimes & sort
    for c in ["charttime", "intime", "outtime", "sepsis_onset_time"]:
        if c in df.columns:
            df[c] = _to_dt(df[c])
    df = df.dropna(subset=["charttime"])
    df = df.sort_values([id_col, "charttime"], kind="mergesort")

    rng = np.random.default_rng(seed)
    feats = _feature_columns(df)

    rows = []
    dropped_no_candidate = 0
    dropped_nonpos_duration = 0

    for sid, g in df.groupby(id_col, sort=False):
        g = g.reset_index(drop=True)
        event_flag, event_time, censor_time = _stay_event_info(g)

        terminal_time = event_time if event_flag == 1 else censor_time
        if pd.isna(terminal_time):
            # No valid terminal time and/or no observations
            dropped_no_candidate += 1
            continue

        sel = _choose_random_baseline(g, terminal_time, rng)
        if sel is None:
            dropped_no_candidate += 1
            continue

        t0 = sel["charttime"]
        dur = hours_between(t0, terminal_time)
        if not np.isfinite(dur) or dur <= 0:
            dropped_nonpos_duration += 1
            continue

        row = {duration_col: float(dur), event_col: int(event_flag)}
        for c in feats:
            row[c] = sel[c]
        rows.append(row)

    out = pd.DataFrame(rows)

    if not out.empty:
        covs = [c for c in out.columns if c not in [duration_col, event_col]]
        out = out[[duration_col, event_col] + covs]
        out[duration_col] = out[duration_col].astype(float)
        out[event_col] = out[event_col].astype(int)

        if not set(out[event_col].unique()).issubset({0, 1}):
            raise ValueError("event must be binary (0/1).")

    print(f"[OK] Built static Cox dataset: n={len(out)} "
          f"(dropped_no_candidate={dropped_no_candidate}, dropped_nonpos_duration={dropped_nonpos_duration})")
    return out

# ----------------------------
# Survival-stacked (landmark) Cox
# ----------------------------
def make_static_landmark_stack(
    df: pd.DataFrame,
    *,
    id_col: str = "stay_id",
    duration_col: str = DURATION_COL,
    event_col: str = EVENT_COL,
) -> pd.DataFrame:
    """
    For each stay, create one row for EVERY landmark row (leakage-safe):
      - Event stays: all rows with charttime < sepsis_onset_time
      - Censored stays: all rows with charttime < censor_time (outtime else last charttime)
      - duration = hours from landmark charttime to terminal time
      - event = 1 if event observed, else 0
    """
    assert id_col in df.columns, f"Missing id column: {id_col}"

    # Ensure datetimes & sort
    for c in ["charttime", "intime", "outtime", "sepsis_onset_time"]:
        if c in df.columns:
            df[c] = _to_dt(df[c])
    df = df.dropna(subset=["charttime"])
    df = df.sort_values([id_col, "charttime"], kind="mergesort")

    feats = _feature_columns(df)

    rows = []
    n_dropped_nonpos = 0
    n_no_candidates = 0

    for sid, g in df.groupby(id_col, sort=False):
        g = g.reset_index(drop=True)
        event_flag, event_time, censor_time = _stay_event_info(g)
        terminal_time = event_time if event_flag == 1 else censor_time

        if pd.isna(terminal_time):
            n_no_candidates += 1
            continue

        ct = g["charttime"]
        candidates = g.loc[ct < terminal_time]

        if candidates.empty:
            n_no_candidates += 1
            continue

        for _, sel in candidates.iterrows():
            dur = hours_between(sel["charttime"], terminal_time)
            if not np.isfinite(dur) or dur <= 0:
                n_dropped_nonpos += 1
                continue

            row = {duration_col: float(dur), event_col: int(event_flag)}
            row[id_col] = sid
            for c in feats:
                row[c] = sel[c]
            rows.append(row)

    out = pd.DataFrame(rows)

    out = pd.DataFrame(rows)

    if not out.empty:
        covs = [c for c in out.columns if c not in [id_col, duration_col, event_col]]  # ✅ include id_col here
        out = out[[id_col, duration_col, event_col] + covs]                            # ✅ put id first
        out[duration_col] = out[duration_col].astype(float)
        out[event_col] = out[event_col].astype(int)

        if not set(out[event_col].unique()).issubset({0, 1}):
            raise ValueError("event must be binary (0/1).")

    print("[OK] Built survival-stacked (landmark) Cox dataset: "
          f"n={len(out)} | dropped_nonpos={n_dropped_nonpos} | no_candidates={n_no_candidates}")
    return out

# ----------------------------
# Time-varying (Andersen–Gill) long format
# ----------------------------
def make_time_varying_long(
    df: pd.DataFrame,
    *,
    id_col: str = "stay_id",
    start_col: str = "start",
    stop_col: str = "stop",
    event_col: str = "event",
    epsilon_seconds: int = 1,
) -> pd.DataFrame:
    """
    lifelines.CoxTimeVaryingFitter-ready:
      - Intervals are [start, stop) in HOURS since the first observed charttime in the stay.
      - Covariates are taken from the row at 'start'.
      - Event assignment uses right-closed rule: start <= event_time < stop.
      - If event_time equals/after the last observed charttime, create a tiny terminal interval
        [last_time, last_time + epsilon] with event=1. For censored stays, no terminal event row.
    """
    required = ["charttime"]
    for c in required:
        if c not in df.columns:
            raise ValueError(f"Missing required column: {c}")

    df = df.copy()
    for c in ["charttime", "intime", "outtime", "sepsis_onset_time"]:
        if c in df.columns:
            df[c] = _to_dt(df[c])
    df = df.dropna(subset=["charttime"])
    df = df.sort_values([id_col, "charttime"], kind="mergesort")

    feats = _feature_columns(df)

    rows = []
    n_skipped_short = 0
    n_eps_used = 0

    for sid, g in df.groupby(id_col, sort=False):
        g = g.drop_duplicates(subset=["charttime"], keep="last").reset_index(drop=True)
        if len(g) == 0:
            continue

        event_flag, event_time, censor_time = _stay_event_info(g)
        base_t = g.loc[0, "charttime"]
        times = g["charttime"].tolist()

        # Build intervals between consecutive unique times
        fired = False
        for i in range(len(times) - 1):
            t_start = times[i]
            t_stop  = times[i + 1]

            start_hr = hours_between(base_t, t_start)
            stop_hr  = hours_between(base_t, t_stop)
            if not np.isfinite(start_hr) or not np.isfinite(stop_hr) or stop_hr <= start_hr:
                # Defensive guard
                continue

            ev = 0
            if pd.notna(event_time) and (t_start <= event_time < t_stop):
                ev = 1
                fired = True

            row = {
                id_col: sid,
                start_col: start_hr,
                stop_col:  stop_hr,
                event_col: ev,
            }
            for c in feats:
                row[c] = g.loc[i, c]
            rows.append(row)

            if ev == 1:
                break  # stop after first event

        # No event inside intervals but event exists at/after last observation
        if (not fired) and pd.notna(event_time):
            last_time = times[-1]
            if event_time >= last_time:
                # Build terminal epsilon interval to carry the event
                t_stop = last_time + pd.Timedelta(seconds=epsilon_seconds)
                n_eps_used += 1

                start_hr = hours_between(base_t, last_time)
                stop_hr  = hours_between(base_t, t_stop)
                if stop_hr > start_hr:
                    row = {
                        id_col: sid,
                        start_col: start_hr,
                        stop_col:  stop_hr,
                        event_col: 1 if (last_time <= event_time < t_stop) else 0,
                    }
                    for c in feats:
                        row[c] = g.iloc[-1][c]
                    rows.append(row)
                else:
                    n_skipped_short += 1
        # Censored stays: nothing extra; already wrote all non-event intervals

    out = pd.DataFrame(rows)

    if not out.empty:
        bad = out[out[stop_col] <= out[start_col]]
        if len(bad):
            print(f"[WARN] Dropping {len(bad)} non-positive intervals after construction.")
            out = out[out[stop_col] > out[start_col]]

        if not set(out[event_col].unique()).issubset({0, 1}):
            raise ValueError("event must be binary in time-varying data.")

    print(f"[OK] Time-varying dataset: rows={len(out)} | used_epsilon={n_eps_used} | skipped_short={n_skipped_short}")
    return out

# ----------------------------
# Main
# ----------------------------
def main():
    # Create output directory
    Path(OUTPUT_PATH_STATIC).parent.mkdir(parents=True, exist_ok=True)

    df = df_small

    # Ensure required columns exist
    need = {"stay_id", "charttime", "is_sepsis"}
    missing = need - set(df.columns)
    if missing:
        raise FileNotFoundError(f"Input missing required columns: {missing}")

    # Build STATIC Cox table (leakage-safe)
    cox_df = make_static_random_snapshot(df, id_col="stay_id")
    cox_df.to_csv(OUTPUT_PATH_STATIC, index=False)
    print(f"✅ Wrote {OUTPUT_PATH_STATIC} (cols: {cox_df.shape[1]}, rows: {cox_df.shape[0]})")

    # TIME-VARYING table for CoxTimeVaryingFitter
    tv_df = make_time_varying_long(df, id_col="stay_id")
    tv_df.to_csv(OUTPUT_PATH_TVC, index=False)
    print(f"✅ Wrote {OUTPUT_PATH_TVC} (cols: {tv_df.shape[1]}, rows: {tv_df.shape[0]})")

    # SURVIVAL-STACKED (landmark) Cox table
    cox_stack = make_static_landmark_stack(df, id_col="stay_id")
    cox_stack.to_csv(OUTPUT_PATH_STATIC_STACKED, index=False)
    print(f"✅ Wrote {OUTPUT_PATH_STATIC_STACKED} (cols: {cox_stack.shape[1]}, rows: {cox_stack.shape[0]})")

if __name__ == "__main__":
    main()


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[c] = _to_dt(df[c])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[c] = _to_dt(df[c])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[c] = _to_dt(df[c])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats

is in drop set? False
[OK] Built static Cox dataset: n=4016 (dropped_no_candidate=662, dropped_nonpos_duration=0)
✅ Wrote ../data/MIMIC-ED/cox_static_random_train.csv (cols: 39, rows: 4016)
is in drop set? False
[OK] Time-varying dataset: rows=23296 | used_epsilon=68 | skipped_short=0
✅ Wrote ../data/MIMIC-ED/cox_timevarying_train.csv (cols: 41, rows: 23296)
is in drop set? False


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[c] = _to_dt(df[c])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[c] = _to_dt(df[c])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[c] = _to_dt(df[c])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats

[OK] Built survival-stacked (landmark) Cox dataset: n=49199 | dropped_nonpos=0 | no_candidates=662
✅ Wrote ../data/MIMIC-ED/cox_static_landmark_stacked_train.csv (cols: 40, rows: 49199)


In [19]:
df_cox_random = pd.read_csv("../data/MIMIC-ED/cox_static_random_train.csv")
df_cox_timevarying = pd.read_csv("../data/MIMIC-ED/cox_timevarying_train.csv")
df_cox_stacked = pd.read_csv("../data/MIMIC-ED/cox_static_landmark_stacked_train.csv")

In [20]:
df_cox_random.columns

Index(['duration', 'event', 'temperature', 'heartrate', 'resprate', 'o2sat',
       'sbp', 'dbp', 'pain', 'rhythm_flag', 'med_rn', 'gsn_rn', 'gsn',
       'is_antibiotic', 'ndc', 'etc_rn', 'etccode', 'temperature_stay',
       'heartrate_stay', 'resprate_stay', 'o2sat_stay', 'sbp_stay', 'dbp_stay',
       'acuity', 'is_white', 'is_black', 'is_asian', 'is_hispanic',
       'is_other_race', 'gender_F', 'gender_M', 'arrival_transport_AMBULANCE',
       'arrival_transport_HELICOPTER', 'arrival_transport_OTHER',
       'arrival_transport_UNKNOWN', 'arrival_transport_WALK IN', 'lactate',
       'wbc', 'time_since_adm'],
      dtype='object')

In [21]:
# for boolean columns, change them to int
bool_cols_random = df_cox_random.select_dtypes(include=['bool']).columns
df_cox_random[bool_cols_random] = df_cox_random[bool_cols_random].astype(int)
bool_cols_timevarying = df_cox_timevarying.select_dtypes(include=['bool']).columns
df_cox_timevarying[bool_cols_timevarying] = df_cox_timevarying[bool_cols_timevarying].astype(int)
bool_cols_stacked = df_cox_stacked.select_dtypes(include=['bool']).columns
df_cox_stacked[bool_cols_stacked] = df_cox_stacked[bool_cols_stacked].astype(int).astype(int)

In [22]:
# save these back
df_cox_random.to_csv("../data/MIMIC-ED/cox_static_random_train.csv", index=False)
df_cox_timevarying.to_csv("../data/MIMIC-ED/cox_timevarying_train.csv", index=False)
df_cox_stacked.to_csv("../data/MIMIC-ED/cox_static_landmark_stacked_train.csv", index=False)

In [23]:
# get most common gsns for septic patients
most_common_gsns = [16599.0, 4490.0, 61716.0, 43952.0, 66419.0]

# one hot encode gsn
for gsn in most_common_gsns:
    df_cox_random[f'gsn_{gsn}'] = (df_cox_random['gsn'] == gsn).astype(int)
    df_cox_timevarying[f'gsn_{gsn}'] = (df_cox_timevarying['gsn'] == gsn).astype(int)
    df_cox_stacked[f'gsn_{gsn}'] = (df_cox_stacked['gsn'] == gsn).astype(int)


df_cox_random = df_cox_random.drop(columns=['gsn', 'gsn_rn', 'med_rn', 'is_antibiotic', 'ndc', 'etc_rn', 'etccode'])
df_cox_timevarying = df_cox_timevarying.drop(columns=['gsn', 'gsn_rn', 'med_rn', 'is_antibiotic', 'ndc', 'etc_rn', 'etccode'])
df_cox_stacked = df_cox_stacked.drop(columns=['gsn', 'gsn_rn', 'med_rn', 'is_antibiotic', 'ndc', 'etc_rn', 'etccode'])


cols_to_drop = [col for col in df_cox_random.columns if col.endswith('_stay')]
df_cox_random = df_cox_random.drop(columns=cols_to_drop)
cols_to_drop = [col for col in df_cox_timevarying.columns if col.endswith('_stay')]
df_cox_timevarying = df_cox_timevarying.drop(columns=cols_to_drop)
cols_to_drop = [col for col in df_cox_stacked.columns if col.endswith('_stay')]
df_cox_stacked = df_cox_stacked.drop(columns=cols_to_drop)

# for col in df_cox_random.columns:
#     if col.endswith('_stay'):
#         # drop column
# for col in df_cox_timevarying.columns:
#     if col.endswith('_stay'):
#         new_col = col[:-5]
#         df_cox_timevarying.rename(columns={col: new_col}, inplace=True)
# for col in df_cox_stacked.columns:
#     if col.endswith('_stay'):
#         new_col = col[:-5]
#         df_cox_stacked.rename(columns={col: new_col}, inplace=True)


# reorder columns
required_cols = ['temperature', 'heartrate', 'resprate', 'o2sat', 'sbp', 'dbp', 'pain', 'rhythm_flag', 'is_white', 'is_black', 'is_asian', 'is_hispanic', 'is_other_race', 'gender_F', 'gender_M', 'arrival_transport_AMBULANCE', 'arrival_transport_HELICOPTER', 'arrival_transport_OTHER', 'arrival_transport_UNKNOWN', 'arrival_transport_WALK IN', 'lactate', 'wbc', 'time_since_adm', 'gsn_16599.0', 'gsn_43952.0', 'gsn_4490.0', 'gsn_66419.0', 'gsn_61716.0']
random_order = ["duration", "event"] + required_cols
df_cox_random = df_cox_random[random_order]
time_varying_order = ["stay_id", "start", "stop", "event"] + required_cols
df_cox_timevarying = df_cox_timevarying[time_varying_order]
stacked_order = ["stay_id", "duration", "event"] + required_cols
df_cox_stacked = df_cox_stacked[stacked_order]

In [24]:
# reorder columns
required_cols = ['temperature', 'heartrate', 'resprate', 'o2sat', 'sbp', 'dbp', 'pain', 'rhythm_flag', 'is_white', 'is_black', 'is_asian', 'is_hispanic', 'is_other_race', 'gender_F', 'gender_M', 'arrival_transport_AMBULANCE', 'arrival_transport_HELICOPTER', 'arrival_transport_OTHER', 'arrival_transport_UNKNOWN', 'arrival_transport_WALK IN', 'lactate', 'wbc', 'time_since_adm', 'gsn_16599.0', 'gsn_43952.0', 'gsn_4490.0', 'gsn_66419.0', 'gsn_61716.0']
order = ["duration", "event"] + required_cols
df_cox_random = df_cox_random[order]
print(df_cox_random.columns)

Index(['duration', 'event', 'temperature', 'heartrate', 'resprate', 'o2sat',
       'sbp', 'dbp', 'pain', 'rhythm_flag', 'is_white', 'is_black', 'is_asian',
       'is_hispanic', 'is_other_race', 'gender_F', 'gender_M',
       'arrival_transport_AMBULANCE', 'arrival_transport_HELICOPTER',
       'arrival_transport_OTHER', 'arrival_transport_UNKNOWN',
       'arrival_transport_WALK IN', 'lactate', 'wbc', 'time_since_adm',
       'gsn_16599.0', 'gsn_43952.0', 'gsn_4490.0', 'gsn_66419.0',
       'gsn_61716.0'],
      dtype='object')


In [25]:
# save these back
df_cox_random.to_csv("../data/MIMIC-ED/cox_static_random_train.csv", index=False)
df_cox_timevarying.to_csv("../data/MIMIC-ED/cox_timevarying_train.csv", index=False)
df_cox_stacked.to_csv("../data/MIMIC-ED/cox_static_landmark_stacked_train.csv", index=False)

In [42]:
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Optional, Tuple

# ----------------------------------
# Config (adjust paths as needed)
# ----------------------------------
OUTPUT_PATH_DT = "../data/MIMIC-ED/discrete_time_30min_train.csv"
RNG_SEED = 42  # for any optional randomness you may add later

# New schema
ID_COLS   = ["subject_id", "hadm_id", "stay_id", "hadm_id_x"]
TIME_COLS = ["charttime", "intime", "outtime", "sepsis_onset_time"]

# Labels / leak-prone columns to exclude from covariates
LABEL_COLS = [
    "sepsis_dx_any", "sepsis_dx",
    "sirs_count", "sirs_ge2",
    "is_sepsis_onset", "is_sepsis",
    # Keep both event indicators & dx out of covariates
]

# Survival target columns
DURATION_COL = "duration"  # hours (not written in output here, but kept for parity)
EVENT_COL    = "event"     # 1 if event observed in this discrete bin, else 0

# ----------------------------------
# Helpers (matching your Cox utils)
# ----------------------------------
def _to_dt(s):
    """Parse to pandas datetime with coercion."""
    return pd.to_datetime(s, errors="coerce")

def hours_between(a, b):
    """Return positive hours (b - a) or NaN if any is NaT."""
    if pd.isna(a) or pd.isna(b):
        return np.nan
    return max(0.0, (b - a).total_seconds() / 3600.0)

def _feature_columns(df: pd.DataFrame) -> list[str]:
    """Numeric/boolean covariates only, excluding IDs/time/labels."""
    drop = set([c for c in ID_COLS if c in df.columns] +
               [c for c in TIME_COLS if c in df.columns] +
               [c for c in LABEL_COLS if c in df.columns])
    num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
    return [c for c in num_cols if c not in drop]

def _stay_event_info(g: pd.DataFrame) -> Tuple[int, Optional[pd.Timestamp], Optional[pd.Timestamp]]:
    """
    Return (event_flag, event_time, censor_time) where event_time/censor_time
    are either a pandas Timestamp or None.
    """
    is_event = int((g.get("is_sepsis", pd.Series([0])).astype(int).max()) >= 1)

    ev_times = _to_dt(g.get("sepsis_onset_time"))
    event_time = ev_times.dropna().min() if is_event == 1 else None
    if pd.isna(event_time):
        event_time = None

    out_times = _to_dt(g.get("outtime"))
    out_time = out_times.dropna().min() if len(out_times.dropna()) else None

    last_obs = _to_dt(g["charttime"]).max()
    if pd.isna(last_obs):
        last_obs = None

    if is_event == 1:
        censor_time = None
    else:
        censor_time = out_time if out_time is not None else last_obs

    return is_event, event_time, censor_time

# ----------------------------------
# Discrete-time person-period builder
# ----------------------------------
def make_discrete_time_person_period(
    df: pd.DataFrame,
    *,
    id_col: str = "stay_id",
    bin_minutes: int = 30,
    hmax_hours: Optional[float] = None,
    include_landmark_cols: bool = True,
) -> pd.DataFrame:
    """
    Build a leakage-safe discrete-time survival (person-period) dataset.

    For each stay and each landmark row (a row in the raw table):
      - Determine terminal_time = event_time (if event stay) else censor_time.
      - Keep ONLY landmarks with charttime < terminal_time (pre-terminal, leakage-safe).
      - Compute remaining time rem_hours = terminal_time - charttime.
      - Discretize into fixed-width bins of `bin_minutes` (default 30 min).
        * Event stays: create rows for k = 1..min(ceil(rem/bin), ceil(hmax/bin) if hmax)
          and set event=1 only for the bin k == ceil(rem/bin) if the event happens
          within the capped horizon; otherwise event=0 for all emitted bins.
        * Censored stays: create rows for k = 1..min(floor(rem/bin), floor(hmax/bin) if hmax),
          with event=0.
      - Covariates for each emitted row are exactly the landmark's covariates
        (NO carry-over or change-rate features).

    Output columns:
      - id_col
      - t_bin (1-based bin index from the landmark)
      - event (0/1, event happens inside this bin)
      - [all feature columns from _feature_columns(df)]
      - (optional) landmark meta: landmark_charttime, hours_since_intime

    Notes:
      - Requires columns: id_col, charttime, is_sepsis, and (optionally) intime/outtime/sepsis_onset_time.
      - If `hmax_hours` is provided, the horizon is capped to that many hours.

    """
    required = {id_col, "charttime"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # Prepare datetimes & sort deterministically
    work = df.copy()
    for c in ["charttime", "intime", "outtime", "sepsis_onset_time"]:
        if c in work.columns:
            work[c] = _to_dt(work[c])

    work = work.dropna(subset=["charttime"])
    work = work.sort_values([id_col, "charttime"], kind="mergesort")

    feats = _feature_columns(work)

    bin_hours = float(bin_minutes) / 60.0
    rows = []

    for sid, g in work.groupby(id_col, sort=False):
        g = g.reset_index(drop=True)
        # Terminal info (event_time or censor_time)
        event_flag, event_time, censor_time = _stay_event_info(g)
        terminal_time = event_time if event_flag == 1 else censor_time

        if terminal_time is None:
            # No terminal reference => cannot form durations
            continue

        # Candidate landmarks strictly before terminal_time
        ct = g["charttime"]
        candidates = g.loc[ct < terminal_time]
        if candidates.empty:
            continue

        # Base admission time (optional hours_since_intime calc)
        intime = g["intime"].iloc[0] if "intime" in g.columns else None

        for _, sel in candidates.iterrows():
            t0 = sel["charttime"]
            rem = hours_between(t0, terminal_time)
            if not np.isfinite(rem) or rem <= 0:
                continue

            # Horizon cap in hours
            if hmax_hours is not None:
                rem_capped = min(rem, hmax_hours)
            else:
                rem_capped = rem

            # Determine number of bins to emit and the "event bin" if applicable
            if event_flag == 1:
                # Event stays: event occurs in the bin where cumulative time crosses rem
                event_bin = int(np.ceil(rem / bin_hours))
                # Rows emitted up to the capped horizon
                last_bin = int(np.ceil(rem_capped / bin_hours))
                # If rem was capped below the true event bin, then this landmark contributes no event
                for k in range(1, max(1, last_bin) + 1):
                    y = 1 if (k == event_bin and (hmax_hours is None or k <= int(np.ceil(hmax_hours / bin_hours)))) else 0
                    row = {
                        id_col: sid,
                        "t_bin": k,
                        EVENT_COL: int(y),
                    }
                    # Attach covariates from the landmark row
                    for c in feats:
                        row[c] = sel.get(c, np.nan)

                    if include_landmark_cols:
                        row["landmark_charttime"] = t0
                        if intime is not None and pd.notna(intime):
                            row["hours_since_intime"] = hours_between(intime, t0)
                    rows.append(row)
            else:
                # Censored stays: no event rows; emit bins fully contained in remaining time
                last_bin = int(np.floor(rem_capped / bin_hours))
                for k in range(1, last_bin + 1):
                    row = {
                        id_col: sid,
                        "t_bin": k,
                        EVENT_COL: 0,
                    }
                    for c in feats:
                        row[c] = sel.get(c, np.nan)

                    if include_landmark_cols:
                        row["landmark_charttime"] = t0
                        if intime is not None and pd.notna(intime):
                            row["hours_since_intime"] = hours_between(intime, t0)
                    rows.append(row)

    out = pd.DataFrame(rows)

    # Final tidy-up / dtype hygiene
    if not out.empty:
        out[EVENT_COL] = out[EVENT_COL].astype(int)
        if "hours_since_intime" in out.columns:
            out["hours_since_intime"] = out["hours_since_intime"].astype(float)

        # sanity: event should be binary
        if not set(out[EVENT_COL].unique()).issubset({0, 1}):
            raise ValueError("event must be binary (0/1) in discrete-time data.")

        # Reorder for readability: id, landmark, t_bin, event, then features
        base_cols = [c for c in [id_col, "landmark_charttime", "hours_since_intime", "t_bin", EVENT_COL]
                     if c in out.columns]
        feat_cols = [c for c in _feature_columns(out) if c not in base_cols]
        out = out[base_cols + feat_cols]

    print(f"[OK] Discrete-time (bin={bin_minutes}m) dataset: rows={len(out)}")
    return out

def make_discrete_time_person_period_fast(
    df: pd.DataFrame,
    *,
    id_col: str = "stay_id",
    bin_minutes: int = 30,
    hmax_hours: float | None = None,
    include_landmark_cols: bool = True,
) -> pd.DataFrame:
    req = {id_col, "charttime", "is_sepsis"}
    miss = req - set(df.columns)
    if miss:
        raise ValueError(f"Missing required columns: {miss}")

    work = df.copy()
    for c in ["charttime", "intime", "outtime", "sepsis_onset_time"]:
        if c in work.columns:
            work[c] = pd.to_datetime(work[c], errors="coerce")
    work = work.dropna(subset=["charttime"])
    work = work.sort_values([id_col, "charttime"], kind="mergesort")

    feats = _feature_columns(work)
    bin_h = float(bin_minutes) / 60.0

    # ---------- stay-level terminal times (vectorized) ----------
    # event per stay
    is_event = (work.groupby(id_col)["is_sepsis"]
                .transform(lambda s: int(s.astype(int).max() >= 1)))
    work["_is_event"] = is_event

    # event_time = min non-null sepsis_onset_time per stay (if event)
    if "sepsis_onset_time" in work.columns:
        ev_time = (work.groupby(id_col)["sepsis_onset_time"]
                   .transform(lambda s: s.dropna().min() if s.notna().any() else pd.NaT))
    else:
        ev_time = pd.Series(pd.NaT, index=work.index)

    # out_time = min non-null outtime per stay
    if "outtime" in work.columns:
        out_time = (work.groupby(id_col)["outtime"]
                    .transform(lambda s: s.dropna().min() if s.notna().any() else pd.NaT))
    else:
        out_time = pd.Series(pd.NaT, index=work.index)

    # last observed time per stay
    last_obs = work.groupby(id_col)["charttime"].transform("max")

    # terminal time per row
    term_time = ev_time.where(is_event == 1, out_time.fillna(last_obs))
    work["_terminal_time"] = term_time

    # leakage-safe landmarks
    mask_ok = work["_terminal_time"].notna() & (work["charttime"] < work["_terminal_time"])
    L = work.loc[mask_ok, [id_col, "charttime", "_terminal_time", "_is_event"] + feats].copy()

    if L.empty:
        return pd.DataFrame(columns=[id_col, "t_bin", EVENT_COL] + feats)

    # ---------- remaining time & bin counts (vectorized) ----------
    rem_h = (L["_terminal_time"] - L["charttime"]).dt.total_seconds() / 3600.0
    if hmax_hours is not None:
        rem_cap = np.minimum(rem_h, hmax_hours)
    else:
        rem_cap = rem_h

    # event and censor bin counts
    event_bin = np.ceil(rem_h / bin_h).astype("int64")          # where event exists
    last_bin_ev = np.ceil(rem_cap / bin_h).astype("int64")      # emitted for event stays
    last_bin_ce = np.floor(rem_cap / bin_h).astype("int64")     # emitted for censored stays

    is_ev = L["_is_event"].to_numpy().astype(bool)
    # number of bins to emit per landmark row
    n_bins = np.where(is_ev, np.maximum(1, last_bin_ev), np.maximum(0, last_bin_ce))

    # drop rows emitting zero bins (censored too close to terminal)
    keep = n_bins > 0
    if not np.any(keep):
        return pd.DataFrame(columns=[id_col, "t_bin", EVENT_COL] + feats)
    L = L.loc[keep].reset_index(drop=True)
    n_bins = n_bins[keep]
    is_ev = is_ev[keep]
    event_bin = event_bin[keep]
    last_bin_ev = last_bin_ev[keep]
    last_bin_ce = last_bin_ce[keep]

    # ---------- explode bins in NumPy ----------
    # build t_bin per row: 1..n_bins[i]
    tbin = np.concatenate([np.arange(1, nb + 1, dtype=np.int32) for nb in n_bins])

    # repeat metadata/covariates
    reps = n_bins.astype(int)
    ids_rep = np.repeat(L[id_col].to_numpy(), reps)
    feat_mat = np.column_stack([np.repeat(L[c].to_numpy(), reps) for c in feats]) if feats else None

    # event flag per emitted row
    # event rows: k == event_bin and k <= last_bin_ev
    ev_idx = np.repeat(is_ev, reps)
    eb_rep  = np.repeat(event_bin, reps)
    lbev_rep = np.repeat(last_bin_ev, reps)
    lbc_rep = np.repeat(last_bin_ce, reps)  # for completeness; not used directly

    event_vec = np.zeros_like(tbin, dtype=np.int8)
    # event only possible on event stays
    mask_e = ev_idx & (tbin == eb_rep) & (tbin <= lbev_rep)
    event_vec[mask_e] = 1

    out_cols = [id_col, "t_bin", EVENT_COL] + feats
    data = {
        id_col: ids_rep,
        "t_bin": tbin,
        EVENT_COL: event_vec,
    }
    for j, c in enumerate(feats):
        data[c] = feat_mat[:, j] if feat_mat is not None else np.array([], dtype=float)

    out = pd.DataFrame(data)

    if include_landmark_cols:
        # optional metadata to help analysis/debug; kept aligned via repeat
        lm_time = np.repeat(L["charttime"].to_numpy(), reps)
        data_meta = {"landmark_charttime": lm_time}
        if "intime" in work.columns:
            # hours since intime if available
            intime_stay = work.groupby(id_col)["intime"].transform("first")
            L_intime = intime_stay.loc[L.index].to_numpy()
            hs = (lm_time - np.repeat(L_intime, reps)).astype("timedelta64[s]").astype("float64") / 3600.0
            data_meta["hours_since_intime"] = hs
        out = pd.concat([pd.DataFrame(data_meta), out], axis=1)

    # tidy dtypes / order
    out[EVENT_COL] = out[EVENT_COL].astype(int)
    base = [c for c in ["landmark_charttime", "hours_since_intime"] if c in out.columns]
    out = out[base + [id_col, "t_bin", EVENT_COL] + feats]

    return out


# ----------------------------------
# Main
# ----------------------------------
def main():
    # Example expects a dataframe named df with required columns present.
    # Replace this with your actual load pipeline.
    # df = pd.read_csv("path/to/your/source.csv", parse_dates=["charttime","intime","outtime","sepsis_onset_time"])
    # For illustration only, we assume `df` is already in scope.
    df = df_small

    # Ensure required columns exist
    need = {"stay_id", "charttime", "is_sepsis"}
    missing = need - set(df.columns)
    if missing:
        raise FileNotFoundError(f"Input missing required columns: {missing}")

    # Build discrete-time table with 30-minute bins; optionally cap horizon (e.g., hmax_hours=6.0)
    dt_df = make_discrete_time_person_period(
        df,
        id_col="stay_id",
        bin_minutes=30,
        hmax_hours=None,            # set e.g. 6.0 to cap at 6 hours
        include_landmark_cols=True, # keep helpful landmark metadata
    )

    Path(OUTPUT_PATH_DT).parent.mkdir(parents=True, exist_ok=True)
    dt_df.to_csv(OUTPUT_PATH_DT, index=False)
    print(f"✅ Wrote {OUTPUT_PATH_DT} (cols: {dt_df.shape[1]}, rows: {dt_df.shape[0]})")

if __name__ == "__main__":
    main()


[OK] Discrete-time (bin=30m) dataset: rows=406271
✅ Wrote ../data/MIMIC-ED/discrete_time_30min_train.csv (cols: 42, rows: 406271)


In [48]:
df_discrete = pd.read_csv("../data/MIMIC-ED/discrete_time_30min_train.csv")

In [44]:
df_discrete.columns

Index(['stay_id', 'landmark_charttime', 'hours_since_intime', 't_bin', 'event',
       'temperature', 'heartrate', 'resprate', 'o2sat', 'sbp', 'dbp', 'pain',
       'rhythm_flag', 'med_rn', 'gsn_rn', 'gsn', 'is_antibiotic', 'ndc',
       'etc_rn', 'etccode', 'temperature_stay', 'heartrate_stay',
       'resprate_stay', 'o2sat_stay', 'sbp_stay', 'dbp_stay', 'acuity',
       'is_white', 'is_black', 'is_asian', 'is_hispanic', 'is_other_race',
       'gender_F', 'gender_M', 'arrival_transport_AMBULANCE',
       'arrival_transport_HELICOPTER', 'arrival_transport_OTHER',
       'arrival_transport_UNKNOWN', 'arrival_transport_WALK IN', 'lactate',
       'wbc', 'time_since_adm'],
      dtype='object')

In [45]:
# for boolean columns, change them to int
bool_cols_random = df_discrete.select_dtypes(include=['bool']).columns
df_discrete[bool_cols_random] = df_discrete[bool_cols_random].astype(int)

In [46]:
# save these back
df_discrete.to_csv("../data/MIMIC-ED/discrete_time_30min_train.csv", index=False)

In [49]:
# get most common gsns for septic patients
most_common_gsns = [16599.0, 4490.0, 61716.0, 43952.0, 66419.0]

# one hot encode gsn
for gsn in most_common_gsns:
    df_discrete[f'gsn_{gsn}'] = (df_discrete['gsn'] == gsn).astype(int)


df_discrete = df_discrete.drop(columns=['gsn', 'gsn_rn', 'med_rn', 'is_antibiotic', 'ndc', 'etc_rn', 'etccode'])



cols_to_drop = [col for col in df_discrete.columns if col.endswith('_stay')]
df_discrete = df_discrete.drop(columns=cols_to_drop)

# reorder columns
required_cols = ['temperature', 'heartrate', 'resprate', 'o2sat', 'sbp', 'dbp', 'pain', 'rhythm_flag', 'is_white', 'is_black', 'is_asian', 'is_hispanic', 'is_other_race', 'gender_F', 'gender_M', 'arrival_transport_AMBULANCE', 'arrival_transport_HELICOPTER', 'arrival_transport_OTHER', 'arrival_transport_UNKNOWN', 'arrival_transport_WALK IN', 'lactate', 'wbc', 'time_since_adm', 'gsn_16599.0', 'gsn_43952.0', 'gsn_4490.0', 'gsn_66419.0', 'gsn_61716.0']
discrete_order = ['stay_id', 'landmark_charttime', 'hours_since_intime', 't_bin', 'event'] + required_cols
df_discrete = df_discrete[discrete_order]

In [51]:
# save these back
df_discrete.to_csv("../data/MIMIC-ED/discrete_time_30min_train.csv", index=False)