In [None]:

import os
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler

DATA_DIR = "/content"

PATIENTS_CSV = os.path.join(DATA_DIR, "PATIENTS_sorted.csv")
ADMISSIONS_CSV = os.path.join(DATA_DIR, "ADMISSIONS_sorted.csv")
ICUSTAYS_CSV = os.path.join(DATA_DIR, "ICUSTAYS_sorted.csv")
DIAG_CSV = os.path.join(DATA_DIR, "DIAGNOSES_ICD_sorted.csv")
CHARTEVENTS_CSV = os.path.join(DATA_DIR, "CHARTEVENTS.csv")
LABEVENTS_CSV = os.path.join(DATA_DIR, "/content/D_LABITEMS.csv")


for p in [PATIENTS_CSV, ADMISSIONS_CSV, ICUSTAYS_CSV, DIAG_CSV]:
    print(p, "->", "FOUND" if os.path.exists(p) else "MISSING")


/content/PATIENTS_sorted.csv -> FOUND
/content/ADMISSIONS_sorted.csv -> FOUND
/content/ICUSTAYS_sorted.csv -> FOUND
/content/DIAGNOSES_ICD_sorted.csv -> FOUND


In [None]:

patients = pd.read_csv(PATIENTS_CSV, usecols=["SUBJECT_ID","GENDER","DOB"], parse_dates=["DOB"], low_memory=False)
admissions = pd.read_csv(ADMISSIONS_CSV, usecols=["SUBJECT_ID","HADM_ID","ADMITTIME","DISCHTIME","HOSPITAL_EXPIRE_FLAG"], parse_dates=["ADMITTIME","DISCHTIME"], low_memory=False)
icustays = pd.read_csv(ICUSTAYS_CSV, usecols=["SUBJECT_ID","HADM_ID","ICUSTAY_ID","INTIME","OUTTIME"], parse_dates=["INTIME","OUTTIME"], low_memory=False)

print("patients:", patients.shape)
print("admissions:", admissions.shape)
print("icustays:", icustays.shape)


patients: (10000, 3)
admissions: (12911, 5)
icustays: (13436, 5)


In [None]:

diag = pd.read_csv(DIAG_CSV, dtype={"ICD9_CODE": str, "HADM_ID": object, "SUBJECT_ID": object}, low_memory=False)


diag['ICD9_CODE'] = diag['ICD9_CODE'].fillna("").astype(str).str.strip().str.replace(r"\.", "", regex=True)


ICD_PREFIXES = ['428']


mask = pd.Series(False, index=diag.index)
for pfx in ICD_PREFIXES:
    mask = mask | diag['ICD9_CODE'].str.startswith(pfx, na=False)

hf_hadm_ids = diag.loc[mask, 'HADM_ID'].dropna().unique()
hf_subject_ids = diag.loc[mask, 'SUBJECT_ID'].dropna().unique()

print("Matched HADM_IDs:", len(hf_hadm_ids))
print("Matched SUBJECT_IDs:", len(hf_subject_ids))


Matched HADM_IDs: 2967
Matched SUBJECT_IDs: 2163


In [None]:
import math

diag['HADM_ID'] = diag['HADM_ID'].astype(str)


admissions_raw = pd.read_csv(ADMISSIONS_CSV, usecols=["SUBJECT_ID","HADM_ID","ADMITTIME","DISCHTIME","HOSPITAL_EXPIRE_FLAG"],
                             dtype={"HADM_ID": str, "ADMITTIME": str, "DISCHTIME": str}, low_memory=False)


hf_adm_raw = admissions_raw[admissions_raw['HADM_ID'].isin([str(x) for x in hf_hadm_ids])].copy()
print("hf_adm_raw rows:", hf_adm_raw.shape[0])


patients_raw = pd.read_csv(PATIENTS_CSV, usecols=["SUBJECT_ID","GENDER","DOB"], dtype={"SUBJECT_ID": str, "DOB": str}, low_memory=False)

hf_adm_raw rows: 2967


In [None]:
hf_adm_raw['HADM_ID'] = hf_adm_raw['HADM_ID'].astype(str)
hf_adm_raw['SUBJECT_ID'] = hf_adm_raw['SUBJECT_ID'].astype(str)

patients_raw['SUBJECT_ID'] = patients_raw['SUBJECT_ID'].astype(str)
patients_raw['DOB'] = patients_raw['DOB'].astype(str)


hf_data = hf_adm_raw.merge(patients_raw, on="SUBJECT_ID", how="left")
print("Merged rows:", hf_data.shape[0])


def parse_ymd(date_str):
    if not isinstance(date_str, str) or date_str.strip() == "" or pd.isna(date_str):
        return None
    date_part = date_str.split(" ")[0]
    parts = date_part.split("-")
    if len(parts) < 3:
        return None
    try:
        y = int(parts[0]); m = int(parts[1]); d = int(parts[2])
        return (y, m, d)
    except Exception:
        return None

def safe_age(admit_str, dob_str):
    a = parse_ymd(admit_str)
    b = parse_ymd(dob_str)
    if (a is None) or (b is None):
        return float("nan")
    ay, am, ad = a
    by, bm, bd = b

    if (ay < by) or (ay == by and (am < bm or (am == bm and ad < bd))):
        return float("nan")
    age = ay - by
    if (am, ad) < (bm, bd):
        age -= 1
    if age > 89:
        return 90.0
    if age < 0:
        return float("nan")
    return float(age)


hf_data['AGE'] = hf_data.apply(lambda r: safe_age(r.get('ADMITTIME', None), r.get('DOB', None)), axis=1)


nan_mask = hf_data['AGE'].isna()
if nan_mask.any():
    def fallback_age(admit_str, dob_str):
        a = parse_ymd(admit_str); b = parse_ymd(dob_str)
        if (a is None) or (b is None):
            return float("nan")
        return float(max(0, a[0] - b[0]))
    hf_data.loc[nan_mask, 'AGE'] = hf_data.loc[nan_mask].apply(lambda r: fallback_age(r.get('ADMITTIME', None), r.get('DOB', None)), axis=1)


hf_data['AGE'] = hf_data['AGE'].clip(lower=0, upper=120)


print("AGE stats -> count:", hf_data['AGE'].count(), " mean:", round(hf_data['AGE'].mean(),2),
      " min:", hf_data['AGE'].min(), " max:", hf_data['AGE'].max())
display_columns = ['SUBJECT_ID','HADM_ID','ADMITTIME','DOB','AGE','GENDER','HOSPITAL_EXPIRE_FLAG']
print(hf_data[display_columns].head(50))


Merged rows: 2967
AGE stats -> count: 2967  mean: 70.3  min: 0.0  max: 90.0
   SUBJECT_ID HADM_ID            ADMITTIME                  DOB   AGE GENDER  \
0           3  145834  2101-10-20 19:08:00  2025-04-11 00:00:00  76.0      M   
1           9  150750  2149-11-09 13:06:00  2108-01-26 00:00:00  41.0      M   
2          21  109451  2134-09-11 12:17:00  2047-04-04 00:00:00  87.0      M   
3          26  197661  2126-05-06 15:16:00  2054-05-04 00:00:00  72.0      M   
4          30  104557  2172-10-14 14:17:00  1872-10-14 00:00:00  90.0      M   
5          34  144319  2191-02-23 05:23:00  1886-07-18 00:00:00  90.0      M   
6          34  115799  2186-07-18 16:46:00  1886-07-18 00:00:00  90.0      M   
7          37  188670  2183-08-21 16:48:00  2114-09-17 00:00:00  68.0      M   
8          38  185910  2166-08-10 00:28:00  2090-08-31 00:00:00  75.0      M   
9          42  119203  2116-04-26 18:58:00  2055-02-25 00:00:00  61.0      M   
10         49  190539  2186-11-21 07:15:00  

In [None]:

import os
import pandas as pd

CHARTEVENTS_CSV = "/content/CHARTEVENTS.csv"
LABEVENTS_CSV   = "/content/LABEVENTS_sorted.csv"
D_ITEMS_CSV     = "/content/D_ITEMS.csv"
D_LABITEMS_CSV  = "/content/D_LABITEMS.csv"
CHUNK_SIZE = 200000

# --- 1. Load dictionary files ---
d_items = pd.read_csv(D_ITEMS_CSV, low_memory=False)
d_labitems = pd.read_csv(D_LABITEMS_CSV, low_memory=False)

# normalize column names
d_items.columns = d_items.columns.str.lower()
d_labitems.columns = d_labitems.columns.str.lower()

def find_itemids(keyword, df, col="label"):
    """Return list of itemids matching keyword (case-insensitive)."""
    return df[df[col].str.contains(keyword, case=False, na=False)]['itemid'].tolist()

# --- 2. Find ITEMIDs ---
# Restrict ITEMIDs
HR_ITEMS = [220045]    # Heart Rate
SBP_ITEMS = [220179]   # Systolic BP
CR_ITEMS  = [50912]    # Serum Creatinine
      # from D_LABITEMS

print("HeartRate ITEMIDs:", HR_ITEMS)
print("SysBP ITEMIDs:", SBP_ITEMS)
print("Creatinine ITEMIDs:", CR_ITEMS)



# --- 2. Helper: extract latest values by subject/hadm/itemid ---
def extract_latest(path, itemids, source="chartevents"):
    if not os.path.exists(path) or not itemids:
        return pd.DataFrame()

    if source == "chartevents":
        usecols = ['subject_id','hadm_id','itemid','charttime','valuenum']
    else:  # labevents
        usecols = ['SUBJECT_ID','HADM_ID','ITEMID','CHARTTIME','VALUENUM','VALUEUOM','FLAG']

    # figure out which charttime spelling exists
    with open(path, "r") as f:
        header = f.readline().strip().split(",")
    charttime_col = "CHARTTIME" if "CHARTTIME" in header else "charttime"

    frames = []
    reader = pd.read_csv(
        path,
        usecols=usecols,
        parse_dates=[charttime_col],   # âœ… dynamic
        chunksize=CHUNK_SIZE,
        low_memory=False
    )

    for chunk in reader:
        # normalize column names
        chunk.columns = chunk.columns.str.lower()

        # ensure itemid is int
        chunk['itemid'] = chunk['itemid'].astype(int, errors='ignore')

        # filter
        chunk = chunk[chunk['itemid'].isin(itemids) & chunk['valuenum'].notna()]
        if chunk.empty:
            continue

        # take latest record per subject/hadm/itemid
        chunk = (chunk
                 .sort_values('charttime')
                 .groupby(['subject_id','hadm_id','itemid'], as_index=False)
                 .last())
        frames.append(chunk)

    if not frames:
        return pd.DataFrame(columns=[c.lower() for c in usecols])

    all_df = pd.concat(frames, ignore_index=True)
    all_df = (all_df
              .sort_values('charttime')
              .groupby(['subject_id','hadm_id','itemid'], as_index=False)
              .last())
    return all_df





# --- 3. Run extraction ---
print("\nExtracting CHARTEVENTS (HR)...")
chart_df = extract_latest(CHARTEVENTS_CSV, HR_ITEMS + SBP_ITEMS, source="chartevents")

print("chart_df rows:", chart_df.shape[0])

print("\nExtracting LABEVENTS (Creatinine)...")
lab_df = extract_latest(LABEVENTS_CSV, CR_ITEMS, source="labevents")
print("lab_df rows:", lab_df.shape[0])



if not chart_df.empty:
    ch_pivot = chart_df.pivot_table(
        index=['subject_id','hadm_id'], columns='itemid', values='valuenum', aggfunc='last'
    ).reset_index()
    if HR_ITEMS:
        ch_pivot = ch_pivot.rename(columns={HR_ITEMS[0]: 'heart_rate'})
    if SBP_ITEMS:
        ch_pivot = ch_pivot.rename(columns={SBP_ITEMS[0]: 'sys_bp'})
else:
    ch_pivot = pd.DataFrame(columns=['subject_id','hadm_id','heart_rate','sys_bp'])

if not lab_df.empty:
    lab_pivot = lab_df.pivot_table(
        index=['subject_id','hadm_id'], columns='itemid', values='valuenum', aggfunc='last'
    ).reset_index()
    if CR_ITEMS:
        lab_pivot = lab_pivot.rename(columns={CR_ITEMS[0]: 'creatinine'})
else:
    lab_pivot = pd.DataFrame(columns=['subject_id','hadm_id','creatinine'])

print("chart pivot shape:", ch_pivot.shape)
print("lab pivot shape:", lab_pivot.shape)
display(ch_pivot.head())
display(lab_pivot.head())


HeartRate ITEMIDs: [220045]
SysBP ITEMIDs: [220179]
Creatinine ITEMIDs: [50912]

Extracting CHARTEVENTS (HR)...
chart_df rows: 140

Extracting LABEVENTS (Creatinine)...
lab_df rows: 10391
chart pivot shape: (70, 4)
lab pivot shape: (10391, 3)


itemid,subject_id,hadm_id,heart_rate,sys_bp
0,40124,126179,80.0,149.0
1,40124,146893,84.0,139.0
2,40177,198480,50.0,90.0
3,40204,175237,90.0,159.0
4,40277,127703,71.0,112.0


itemid,subject_id,hadm_id,creatinine
0,3,145834.0,1.5
1,4,185777.0,0.5
2,6,107064.0,0.9
3,9,150750.0,2.0
4,11,194540.0,0.6


In [None]:

CHARTEVENTS_CSV = "/content/CHARTEVENTS.csv"
LABEVENTS_CSV   = "/content/LABEVENTS_sorted.csv"
D_ITEMS_CSV     = "/content/D_ITEMS.csv"
D_LABITEMS_CSV  = "/content/D_LABITEMS.csv"
CHUNK_SIZE = 200_000


d_items = pd.read_csv(D_ITEMS_CSV, low_memory=False)
d_labitems = pd.read_csv(D_LABITEMS_CSV, low_memory=False)
d_items.columns = d_items.columns.str.lower()
d_labitems.columns = d_labitems.columns.str.lower()

def find_itemids(keyword, df, col="label"):
    return df[df[col].str.contains(keyword, case=False, na=False)]['itemid'].unique().tolist()

HR_ITEMS  = find_itemids("heart rate", d_items) or [220045]
SBP_ITEMS = find_itemids("systolic", d_items)   or [220179]
CR_ITEMS  = find_itemids("creatinine", d_labitems) or [50912]

print("Detected HR ITEMIDs:", HR_ITEMS)
print("Detected SBP ITEMIDs:", SBP_ITEMS)
print("Detected Creatinine ITEMIDs:", CR_ITEMS)

def extract_latest(path, itemids, source="chartevents"):
    """Robustly extract latest valuenum rows for given itemids.
       Handles uppercase/lowercase header names and compares itemid as string."""
    if (not os.path.exists(path)) or (not itemids):
        return pd.DataFrame(columns=['subject_id','hadm_id','itemid','charttime','valuenum'])


    if source == "chartevents":
        want_lower = ['subject_id','hadm_id','itemid','charttime','valuenum','value']
    else:
        want_lower = ['subject_id','hadm_id','itemid','charttime','valuenum','value','valueuom','flag']


    header = pd.read_csv(path, nrows=0).columns.tolist()

    header_map = {h.lower(): h for h in header}


    actual_usecols = [header_map[c] for c in want_lower if c in header_map]

    charttime_actual = header_map.get('charttime') or header_map.get('charttime'.upper()) or header_map.get('charttime'.lower())

    if not charttime_actual:

        for candidate in ['charttime','chart_time','chart_time','time','date']:
            if candidate in header_map:
                charttime_actual = header_map[candidate]
                break

    frames = []
    # reading in chunks using actual_usecols
    reader = pd.read_csv(path, usecols=actual_usecols, chunksize=CHUNK_SIZE, low_memory=False)
    itemid_str_set = set(str(i) for i in itemids)

    for chunk in reader:

        chunk.columns = chunk.columns.str.lower()
        chunk = chunk.copy()


        if 'itemid' not in chunk.columns or 'subject_id' not in chunk.columns:
            continue


        if 'charttime' in chunk.columns:
            chunk['charttime'] = pd.to_datetime(chunk['charttime'], errors='coerce')
        else:
            chunk['charttime'] = pd.NaT


        chunk['itemid_str'] = chunk['itemid'].astype(str).str.strip()

        # picking valuenum if present else fallback to value
        if 'valuenum' in chunk.columns:
            chunk['valuenum'] = pd.to_numeric(chunk['valuenum'], errors='coerce')
        elif 'value' in chunk.columns:
            chunk['valuenum'] = pd.to_numeric(chunk['value'], errors='coerce')
        else:
            chunk['valuenum'] = np.nan


        mask = chunk['itemid_str'].isin(itemid_str_set) & chunk['valuenum'].notna()
        chunk = chunk.loc[mask, ['subject_id','hadm_id','itemid','charttime','valuenum']]

        if chunk.empty:
            continue


        chunk['subject_id'] = pd.to_numeric(chunk['subject_id'], errors='coerce')
        chunk['hadm_id'] = pd.to_numeric(chunk['hadm_id'], errors='coerce')


        chunk = chunk.sort_values('charttime').groupby(['subject_id','hadm_id','itemid'], as_index=False).last()
        frames.append(chunk)

    if not frames:
        return pd.DataFrame(columns=['subject_id','hadm_id','itemid','charttime','valuenum'])

    df = pd.concat(frames, ignore_index=True)

    df = df.sort_values('charttime').groupby(['subject_id','hadm_id','itemid'], as_index=False).last()


    df['subject_id'] = pd.to_numeric(df['subject_id'], errors='coerce').astype('Int64')
    df['hadm_id'] = pd.to_numeric(df['hadm_id'], errors='coerce').astype('Int64')
    return df[['subject_id','hadm_id','itemid','charttime','valuenum']]


print("\nExtracting CHARTEVENTS (HR + SBP)...")
chart_df = extract_latest(CHARTEVENTS_CSV, HR_ITEMS + SBP_ITEMS, source="chartevents")
print("chart_df rows:", chart_df.shape[0])

print("\nExtracting LABEVENTS (Creatinine)...")
lab_df = extract_latest(LABEVENTS_CSV, CR_ITEMS, source="labevents")
print("lab_df rows:", lab_df.shape[0])

# Collapse multiple ITEMIDs into single measurement per admission
def collapse_itemids_to_measure(df, itemids, out_col):
    if df.empty:
        return pd.DataFrame(columns=['subject_id','hadm_id', out_col])
    df2 = df[df['itemid'].isin(itemids)].copy()
    if df2.empty:
        return pd.DataFrame(columns=['subject_id','hadm_id', out_col])

    df2['subject_id'] = df2['subject_id'].astype('Int64')
    df2['hadm_id'] = df2['hadm_id'].astype('Int64')

    pivot = df2.pivot(index=['subject_id','hadm_id'], columns='itemid', values='valuenum').reset_index()

    value_cols = [c for c in pivot.columns if c not in ('subject_id','hadm_id')]
    if value_cols:
        pivot[out_col] = pivot[value_cols].median(axis=1, skipna=True)
    else:
        pivot[out_col] = pd.NA
    return pivot[['subject_id','hadm_id', out_col]]

hr_pivot  = collapse_itemids_to_measure(chart_df, HR_ITEMS,  'heart_rate')
sbp_pivot = collapse_itemids_to_measure(chart_df, SBP_ITEMS, 'sys_bp')
cr_pivot  = collapse_itemids_to_measure(lab_df,  CR_ITEMS,  'creatinine')

# combining hr+sbp
if not hr_pivot.empty or not sbp_pivot.empty:
    ch_pivot = hr_pivot.merge(sbp_pivot, on=['subject_id','hadm_id'], how='outer')
else:
    ch_pivot = pd.DataFrame(columns=['subject_id','hadm_id','heart_rate','sys_bp'])


for df in (ch_pivot, cr_pivot):
    if not df.empty:
        df['subject_id'] = pd.to_numeric(df['subject_id'], errors='coerce').astype('Int64')
        df['hadm_id'] = pd.to_numeric(df['hadm_id'], errors='coerce').astype('Int64')

# admission-level merge first
merged = ch_pivot.merge(cr_pivot, on=['subject_id','hadm_id'], how='left')

# patient-level latest creatinine (fallback)
if not cr_pivot.empty:
    patient_latest = cr_pivot.sort_values(['subject_id','hadm_id']).groupby('subject_id', as_index=False)['creatinine'].last()
else:
    patient_latest = pd.DataFrame(columns=['subject_id','creatinine'])


merged = merged.merge(patient_latest, on='subject_id', how='left', suffixes=('','_patient'))
merged['creatinine'] = merged['creatinine'].fillna(merged['creatinine_patient'])
merged = merged.drop(columns=['creatinine_patient'], errors='ignore')

# optional fill with median
if merged['creatinine'].notna().any():
    median_cr = merged['creatinine'].median()
    merged['creatinine'] = merged['creatinine'].fillna(median_cr)

# Diagnostics
print("ch_pivot rows (admissions with HR/SBP):", len(ch_pivot))
print("cr_pivot rows (admissions with creatinine):", len(cr_pivot))
print("patient_latest rows (patients with at least one creatinine):", len(patient_latest))
print("Final merged rows:", merged.shape[0])
print("Admissions with creatinine (non-null):", merged['creatinine'].notna().sum())

display(merged.head(40))


hf_data = merged.copy()


Detected HR ITEMIDs: [211, 3494, 220045, 220046, 220047]
Detected SBP ITEMIDs: [6, 51, 442, 455, 480, 482, 484, 492, 666, 3313, 3315, 3317, 3319, 3321, 3323, 3325, 7643, 6701, 228152, 224167, 227243, 226850, 226852, 220050, 220059, 220179, 225309]
Detected Creatinine ITEMIDs: [50841, 50912, 51021, 51032, 51052, 51067, 51070, 51073, 51080, 51081, 51082, 51099, 51106]

Extracting CHARTEVENTS (HR + SBP)...
chart_df rows: 466

Extracting LABEVENTS (Creatinine)...
lab_df rows: 13865
ch_pivot rows (admissions with HR/SBP): 125
cr_pivot rows (admissions with creatinine): 10393
patient_latest rows (patients with at least one creatinine): 7654
Final merged rows: 125
Admissions with creatinine (non-null): 125


Unnamed: 0,subject_id,hadm_id,heart_rate,sys_bp,creatinine
0,10006,142345,66.0,150.0,3.7
1,10011,105331,72.0,41.0,0.9
2,10013,165520,0.0,45.0,25.85
3,10017,199207,76.0,121.0,0.4
4,10019,177759,89.0,47.5,7.5
5,10026,103770,66.0,145.0,0.5
6,10027,199395,111.0,34.0,1.4
7,10029,132349,84.0,152.0,3.0
8,10032,140372,108.0,92.0,59.85
9,10033,157235,86.0,136.0,0.8


In [None]:
TIME_WINDOW_HOURS = 48   # max allowed distance (hours) to consider a lab "near" the chart time; set to None to ignore window

# 1) Computing representative chart_time per admission
# (chart_df came from extract_latest and contains charttime per (subject,hadm,itemid))
if 'chart_df' in globals() and not chart_df.empty:
    ch_time = (chart_df
               .groupby(['subject_id','hadm_id'], as_index=False)['charttime']
               .max()
              )

    ch_time['subject_id'] = pd.to_numeric(ch_time['subject_id'], errors='coerce').astype('Int64')
    ch_time['hadm_id']    = pd.to_numeric(ch_time['hadm_id'], errors='coerce').astype('Int64')

    ch_pivot = ch_pivot.merge(ch_time, on=['subject_id','hadm_id'], how='left')
else:
    # if chart_df missing, creating a dummy charttime column (NaT)
    ch_pivot['charttime'] = pd.NaT

# 2) Preparing lab_df times and types
if not lab_df.empty:
    lab_df = lab_df.copy()
    lab_df['subject_id'] = pd.to_numeric(lab_df['subject_id'], errors='coerce').astype('Int64')
    lab_df['hadm_id']    = pd.to_numeric(lab_df['hadm_id'], errors='coerce').astype('Int64')
    lab_df['charttime']  = pd.to_datetime(lab_df['charttime'], errors='coerce')

    if 'valuenum' not in lab_df.columns and 'value' in lab_df.columns:
        lab_df['valuenum'] = pd.to_numeric(lab_df['value'], errors='coerce')

# 3) Helper function to find closest lab per subject for a single row
def find_closest_creatinine(subject_id, ref_time):
    """Return (valuenum, time_diff_hours) of closest creatinine for subject_id to ref_time.
       Returns (np.nan, np.nan) if none found within TIME_WINDOW_HOURS (if set) or no labs exist."""
    if pd.isna(subject_id) or lab_df.empty:
        return (np.nan, np.nan)
    sub_labs = lab_df[lab_df['subject_id'] == subject_id]
    if sub_labs.empty:
        return (np.nan, np.nan)

    if pd.isna(ref_time):

        latest = sub_labs.sort_values('charttime').iloc[-1]
        return (latest['valuenum'], np.nan)
    sub_labs = sub_labs.dropna(subset=['charttime'])
    if sub_labs.empty:
        return (np.nan, np.nan)
    diffs = (sub_labs['charttime'] - ref_time).abs().dt.total_seconds() / 3600.0  # hours

    if TIME_WINDOW_HOURS is not None:
        valid = diffs <= float(TIME_WINDOW_HOURS)
        if not valid.any():
            return (np.nan, np.nan)
        idx = diffs[valid].idxmin()
    else:
        idx = diffs.idxmin()
    row = sub_labs.loc[idx]
    return (row['valuenum'], diffs.loc[idx])

# 4) Filling admission-level missing creatinine with closest-in-time labs
target_df = merged


mask_missing = target_df['creatinine'].isna()
if mask_missing.any():

    fill_vals = []
    for i, row in target_df.loc[mask_missing, ['subject_id','charttime']].iterrows():
        subj = row['subject_id']
        ref_time = row['charttime'] if 'charttime' in row.index else pd.NaT
        val, diff_hrs = find_closest_creatinine(subj, ref_time)
        fill_vals.append((i, val, diff_hrs))

    for i, val, diff in fill_vals:
        if not (pd.isna(val)):
            target_df.at[i, 'creatinine'] = val

            target_df.at[i, 'creatinine_time_diff_hrs'] = diff

# 5) Final fallback: patient_latest then median
if 'patient_latest' not in globals() or patient_latest is None:

    if not cr_pivot.empty:
        patient_latest = cr_pivot.sort_values(['subject_id','hadm_id']).groupby('subject_id', as_index=False)['creatinine'].last()
    else:
        patient_latest = pd.DataFrame(columns=['subject_id','creatinine'])

# Filling remaining NaNs with patient_latest then median
still_missing = target_df['creatinine'].isna()
if still_missing.any() and not patient_latest.empty:
    target_df = target_df.merge(patient_latest, on='subject_id', how='left', suffixes=('','_patient'))
    target_df['creatinine'] = target_df['creatinine'].fillna(target_df['creatinine_patient'])
    target_df = target_df.drop(columns=['creatinine_patient'], errors='ignore')

# optional median fill
if target_df['creatinine'].notna().any():
    median_cr = target_df['creatinine'].median()
    target_df['creatinine'] = target_df['creatinine'].fillna(median_cr)


print("Admissions with creatinine (non-null) after time-based fill:", target_df['creatinine'].notna().sum())
display(target_df.head(30))


hf_data = target_df.copy()


Admissions with creatinine (non-null) after time-based fill: 125


Unnamed: 0,subject_id,hadm_id,heart_rate,sys_bp,creatinine
0,10006,142345,66.0,150.0,3.7
1,10011,105331,72.0,41.0,0.9
2,10013,165520,0.0,45.0,25.85
3,10017,199207,76.0,121.0,0.4
4,10019,177759,89.0,47.5,7.5
5,10026,103770,66.0,145.0,0.5
6,10027,199395,111.0,34.0,1.4
7,10029,132349,84.0,152.0,3.0
8,10032,140372,108.0,92.0,59.85
9,10033,157235,86.0,136.0,0.8


In [None]:
import pandas as pd
import numpy as np

CHARTEVENTS_CSV = "/content/CHARTEVENTS.csv"
LABEVENTS_CSV   = "/content/LABEVENTS_sorted.csv"


HR_ITEMS  = globals().get('HR_ITEMS', [220045])
SBP_ITEMS = globals().get('SBP_ITEMS', [220179])
CR_ITEMS  = globals().get('CR_ITEMS', [50912])

CHUNK_SIZE = 200_000

def inspect_admission(subject_id, hadm_id,
                      charpath=CHARTEVENTS_CSV, labpath=LABEVENTS_CSV,
                      hr_items=HR_ITEMS, sbp_items=SBP_ITEMS, cr_items=CR_ITEMS,
                      chunksize=CHUNK_SIZE):
    sid = str(subject_id)
    hid = str(hadm_id)


    char_frames = []
    for chunk in pd.read_csv(charpath, chunksize=chunksize, low_memory=False):
        c = chunk.copy()
        c.columns = c.columns.str.lower()

        if not {'subject_id','hadm_id','itemid'}.issubset(c.columns):
            continue

        c['subject_id'] = c['subject_id'].astype(str)
        c['hadm_id']   = c['hadm_id'].astype(str)

        if 'valuenum' in c.columns:
            c['valuenum'] = pd.to_numeric(c['valuenum'], errors='coerce')

        mask = (c['subject_id']==sid) & (c['hadm_id']==hid) & (c['itemid'].astype(str).isin([str(x) for x in (hr_items+sbp_items)]))
        if mask.any():

            cols_keep = [col for col in ['subject_id','hadm_id','itemid','charttime','value','valuenum','valueuom','flag'] if col in c.columns]
            char_frames.append(c.loc[mask, cols_keep])
    char_rows = pd.concat(char_frames, ignore_index=True) if char_frames else pd.DataFrame()


    lab_frames = []
    for chunk in pd.read_csv(labpath, chunksize=chunksize, low_memory=False):
        l = chunk.copy()
        l.columns = l.columns.str.lower()
        if not {'subject_id','hadm_id','itemid'}.issubset(l.columns):
            continue
        l['subject_id'] = l['subject_id'].astype(str)
        l['hadm_id']   = l['hadm_id'].astype(str)
        if 'valuenum' in l.columns:
            l['valuenum'] = pd.to_numeric(l['valuenum'], errors='coerce')
        mask = (l['subject_id']==sid) & (l['hadm_id']==hid) & (l['itemid'].astype(str).isin([str(x) for x in cr_items]))
        if mask.any():
            cols_keep = [col for col in ['subject_id','hadm_id','itemid','charttime','value','valuenum','valueuom','flag'] if col in l.columns]
            lab_frames.append(l.loc[mask, cols_keep])
    lab_rows = pd.concat(lab_frames, ignore_index=True) if lab_frames else pd.DataFrame()


    print(f"\nCHARTEVENTS rows found: {len(char_rows)}")
    if not char_rows.empty:

        display_cols = [c for c in ['charttime','itemid','valuenum','value','valueuom','flag'] if c in char_rows.columns]
        print("\n--- Last CHARTEVENTS entries (vitals) ---")
        display(char_rows.sort_values('charttime').tail(30)[display_cols])
        # count zeros
        zcount = int((char_rows['valuenum']==0).sum()) if 'valuenum' in char_rows.columns else 0
        print(f"\nCount of valuenum == 0 (vitals) for this admission: {zcount}")
    else:
        print("No CHARTEVENTS rows for this admission (HR/SBP itemids)")

    print(f"\nLABEVENTS rows found: {len(lab_rows)}")
    if not lab_rows.empty:
        cols_keep = [c for c in ['charttime','itemid','valuenum','value','valueuom','flag'] if c in lab_rows.columns]
        print("\n--- LABEVENTS entries (creatinine) ---")
        display(lab_rows.sort_values('charttime').tail(30)[cols_keep])


        unit_df = (lab_rows
                   .assign(valueuom_clean = lab_rows['valueuom'].fillna('UNKNOWN').astype(str).str.strip().str.lower())
                   .groupby(['itemid','valueuom_clean'], dropna=False)
                   .agg(count=('valuenum','count'), median=('valuenum','median'), max=('valuenum','max'))
                   .reset_index()
                   .sort_values('count', ascending=False))
        print("\n--- Creatinine unit/itemid summary for this admission ---")
        display(unit_df)
    else:
        print("No LABEVENTS rows for this admission (creatinine itemids)")

    return char_rows, lab_rows, unit_df if 'unit_df' in locals() else pd.DataFrame()


char_rows, lab_rows, unit_summary = inspect_admission(10013, 165520)



CHARTEVENTS rows found: 236

--- Last CHARTEVENTS entries (vitals) ---


Unnamed: 0,charttime,itemid,valuenum,value,valueuom
221,2125-10-06 21:00:00,51,122.0,122,mmHg
67,2125-10-06 21:00:00,211,91.0,91,BPM
68,2125-10-06 21:30:00,211,89.0,89,BPM
222,2125-10-06 21:30:00,51,97.0,97,mmHg
223,2125-10-06 22:00:00,51,95.0,95,mmHg
69,2125-10-06 22:00:00,211,88.0,88,BPM
224,2125-10-06 23:00:00,51,87.0,87,mmHg
70,2125-10-06 23:00:00,211,86.0,86,BPM
225,2125-10-07 00:00:00,51,106.0,106,mmHg
71,2125-10-07 00:00:00,211,87.0,87,BPM



Count of valuenum == 0 (vitals) for this admission: 3

LABEVENTS rows found: 0
No LABEVENTS rows for this admission (creatinine itemids)


In [None]:
import numpy as np
import pandas as pd


subject_id = 10013
hadm_id    = 165520


if 'lab_df' in globals() and not lab_df.empty:
    cr_pivot = (lab_df[lab_df['itemid'].astype(str).isin([str(x) for x in CR_ITEMS])]
                .groupby(['subject_id','hadm_id'], as_index=False)
                .agg(creatinine=('valuenum', 'last')))
else:
    cr_pivot = pd.DataFrame(columns=['subject_id','hadm_id','creatinine'])

patient_latest = (cr_pivot.sort_values(['subject_id','hadm_id'])
                  .groupby('subject_id', as_index=False)['creatinine']
                  .last())


if 'ch_pivot' not in globals() or ch_pivot.empty:
    print("ch_pivot missing or empty â€” please run the chart extraction step first.")
else:
    merged_check = ch_pivot.merge(cr_pivot, on=['subject_id','hadm_id'], how='left')
    merged_check = merged_check.merge(patient_latest, on='subject_id', how='left', suffixes=('','_patient'))


    median_cr = merged_check['creatinine'].median() if merged_check['creatinine'].notna().any() else np.nan


    row_mask = (merged_check['subject_id'].astype(int) == int(subject_id)) & (merged_check['hadm_id'].astype(int) == int(hadm_id))
    if not row_mask.any():
        print("Admission not present in ch_pivot. Check subject/hadm ids or types.")
    else:
        row = merged_check.loc[row_mask].iloc[0]
        source = None
        value  = None
        if not (pd.isna(row.get('creatinine'))):
            source = 'admission_lab'
            value  = row.get('creatinine')
        elif not (pd.isna(row.get('creatinine_patient'))):
            source = 'patient_latest'
            value  = row.get('creatinine_patient')
        else:
            source = 'median_fallback'
            value  = median_cr
        print(f"Admission ({subject_id},{hadm_id}) -> creatinine value used = {value} , source = {source}")


        if 'lab_df' in globals() and not lab_df.empty:
            print("\nAll creatinine lab rows for this SUBJECT (any hadm):")
            display(lab_df[lab_df['subject_id'].astype(int)==int(subject_id)].sort_values('charttime').tail(30))
        else:
            print("\nNo lab_df loaded in memory to inspect raw lab rows.")


Admission (10013,165520) -> creatinine value used = 50.0 , source = admission_lab

All creatinine lab rows for this SUBJECT (any hadm):


Unnamed: 0,subject_id,hadm_id,itemid,charttime,valuenum
13022,10013,165520,51082,2125-10-05 03:07:00,50.0
13021,10013,165520,50912,2125-10-05 05:30:00,1.7


In [None]:

D_LABITEMS_CSV = "/content/D_LABITEMS.csv"

import pandas as pd
d_labitems = pd.read_csv(D_LABITEMS_CSV, low_memory=False)
d_labitems.columns = d_labitems.columns.str.lower()


itemids = ['50912', '51082']
print("D_LABITEMS rows for these itemids:")
display(d_labitems[d_labitems['itemid'].astype(str).isin(itemids)][['itemid','label','category','loinc_code']].drop_duplicates())


D_LABITEMS rows for these itemids:


Unnamed: 0,itemid,label,category,loinc_code
239,50912,Creatinine,Chemistry,2160-0
408,51082,"Creatinine, Urine",Chemistry,2161-8


In [None]:

HR_MIN, HR_MAX = 30, 220
SBP_MIN, SBP_MAX = 30, 300
TIME_WINDOW_HOURS = None
MEDIAN_FALLBACK = True
PRINT_N = 8



if 'lab_df' not in globals():
    raise AssertionError("lab_df not found: load LABEVENTS into `lab_df` first.")
if 'd_labitems' not in globals():
    raise AssertionError("d_labitems not found: load D_LABITEMS into `d_labitems` first.")

have_ch_pivot = 'ch_pivot' in globals() and isinstance(globals()['ch_pivot'], pd.DataFrame)
have_chart_df = 'chart_df' in globals() and isinstance(globals()['chart_df'], pd.DataFrame)
if not (have_ch_pivot or have_chart_df):
    raise AssertionError("Provide either `ch_pivot` or `chart_df` (chartevents) in the session.")


HR_ITEMS = globals().get('HR_ITEMS', [220045, 211, 220046, 220047])
SBP_ITEMS = globals().get('SBP_ITEMS', [220179, 6, 51, 455, 3313])


def build_ch_pivot_from_chart_df(chart_df):
    df = chart_df.copy()
    df.columns = df.columns.str.lower()

    for c in ['subject_id','hadm_id','itemid','charttime','valuenum','value']:
        if c not in df.columns:
            df[c] = np.nan
    df['valuenum'] = pd.to_numeric(df['valuenum'], errors='coerce')
    df['charttime'] = pd.to_datetime(df['charttime'], errors='coerce')
    df['itemid'] = df['itemid'].astype(str)

    hr_mask = df['itemid'].isin([str(x) for x in HR_ITEMS])
    sbp_mask = df['itemid'].isin([str(x) for x in SBP_ITEMS])

    df.loc[hr_mask & ((df['valuenum'] == 0) | (df['valuenum'] < HR_MIN) | (df['valuenum'] > HR_MAX)), 'valuenum'] = np.nan
    df.loc[sbp_mask & ((df['valuenum'] == 0) | (df['valuenum'] < SBP_MIN) | (df['valuenum'] > SBP_MAX)), 'valuenum'] = np.nan

    if 'charttime' in df.columns:
        df = df.sort_values('charttime')
    pivot = (df.dropna(subset=['valuenum'])
               .pivot_table(index=['subject_id','hadm_id'], columns='itemid', values='valuenum', aggfunc='last')
               .reset_index())

    hr_cols = [c for c in pivot.columns if str(c) in [str(x) for x in HR_ITEMS]]
    sbp_cols = [c for c in pivot.columns if str(c) in [str(x) for x in SBP_ITEMS]]
    if hr_cols:
        pivot['heart_rate'] = pivot[hr_cols].median(axis=1, skipna=True)
    else:
        pivot['heart_rate'] = np.nan
    if sbp_cols:
        pivot['sys_bp'] = pivot[sbp_cols].median(axis=1, skipna=True)
    else:
        pivot['sys_bp'] = np.nan

    if 'charttime' in df.columns:
        ch_time = (df.groupby(['subject_id','hadm_id'], as_index=False)['charttime'].max())
        pivot = pivot.merge(ch_time, on=['subject_id','hadm_id'], how='left')
    pivot['subject_id'] = pivot['subject_id'].astype(str)
    pivot['hadm_id'] = pivot['hadm_id'].astype(str)
    return pivot

if have_ch_pivot:
    ch = ch_pivot.copy()
    ch.columns = ch.columns.str.lower()

    if 'heart_rate' not in ch.columns or 'sys_bp' not in ch.columns:
        if have_chart_df:
            ch = build_ch_pivot_from_chart_df(chart_df)
        else:

            ch['heart_rate'] = ch.get('heart_rate', np.nan)
            ch['sys_bp'] = ch.get('sys_bp', np.nan)
else:
    ch = build_ch_pivot_from_chart_df(chart_df)


lab = lab_df.copy()
lab.columns = lab.columns.str.lower()
lab['itemid'] = lab['itemid'].astype(str)
lab['subject_id'] = lab['subject_id'].astype(str)
lab['hadm_id'] = lab['hadm_id'].astype(str)
lab['charttime'] = pd.to_datetime(lab.get('charttime', pd.NaT), errors='coerce')


dl = d_labitems.copy()
dl.columns = dl.columns.str.lower()
dl['label'] = dl.get('label', dl.get('name', pd.Series(['']*len(dl)))).astype(str).str.lower()


serum_items = dl[dl['label'].str.contains('creatinine', na=False) & ~dl['label'].str.contains('urine', na=False)]
SERUM_CR_ITEMIDS = serum_items['itemid'].astype(str).unique().tolist()

urine_items = dl[dl['label'].str.contains('creatinine', na=False) & dl['label'].str.contains('urine', na=False)]
URINE_CR_ITEMIDS = urine_items['itemid'].astype(str).unique().tolist()


serum_lab = lab[lab['itemid'].isin(SERUM_CR_ITEMIDS)].copy()
print("Serum creatinine rows (pre-clean):", len(serum_lab))


serum_lab['valuenum'] = pd.to_numeric(serum_lab.get('valuenum', np.nan), errors='coerce')

serum_lab['valueuom'] = serum_lab.get('valueuom', pd.Series(['']*len(serum_lab))).fillna('').astype(str).str.strip().str.lower()


vals = serum_lab['valuenum'].to_numpy(dtype=float)
uoms = serum_lab['valueuom'].to_numpy(dtype=object)

is_umol = np.array([('umol' in str(u)) or ('Âµmol' in str(u)) for u in uoms])
is_mgdl = np.array([('mg/dl' in str(u)) or ('mg per dl' in str(u)) for u in uoms])
unknown_mask = (~is_umol) & (~is_mgdl) & ~np.isnan(vals)

creat = np.full_like(vals, np.nan, dtype=float)

creat[is_umol & ~np.isnan(vals)] = vals[is_umol & ~np.isnan(vals)] / 88.4
creat[is_mgdl & ~np.isnan(vals)] = vals[is_mgdl & ~np.isnan(vals)]

creat[unknown_mask & (vals > 20)] = vals[unknown_mask & (vals > 20)] / 88.4
creat[unknown_mask & (vals <= 20)] = vals[unknown_mask & (vals <= 20)]

serum_lab['creat_mgdl'] = creat


serum_lab_sorted = serum_lab.sort_values(['subject_id','hadm_id','charttime'])
cr_pivot = (serum_lab_sorted.dropna(subset=['creat_mgdl'])
            .groupby(['subject_id','hadm_id'], as_index=False)
            .last()[['subject_id','hadm_id','creat_mgdl']]).rename(columns={'creat_mgdl':'creatinine'})

print("Admission-level serum creatinine rows (after convert):", len(cr_pivot))


patient_latest = (serum_lab_sorted.dropna(subset=['creat_mgdl'])
                  .groupby('subject_id', as_index=False)
                  .last()[['subject_id','creat_mgdl']].rename(columns={'creat_mgdl':'creatinine_patient'}))


if TIME_WINDOW_HOURS is not None and 'charttime' in ch.columns:

    ch_time = ch[['subject_id','hadm_id','charttime']].copy()
    ch_time['charttime'] = pd.to_datetime(ch_time['charttime'], errors='coerce')

    last_lab_time = (serum_lab_sorted.dropna(subset=['creat_mgdl'])
                     .groupby(['subject_id','hadm_id'], as_index=False)['charttime'].last().rename(columns={'charttime':'lab_charttime'}))
    cr_pivot = cr_pivot.merge(last_lab_time, on=['subject_id','hadm_id'], how='left')
    cr_pivot = cr_pivot.merge(ch_time, on=['subject_id','hadm_id'], how='left')

    cr_pivot['time_diff_hrs'] = (cr_pivot['lab_charttime'] - cr_pivot['charttime']).abs().dt.total_seconds() / 3600.0

    cr_pivot = cr_pivot[cr_pivot['time_diff_hrs'].le(float(TIME_WINDOW_HOURS)) | cr_pivot['time_diff_hrs'].isna()]

    cr_pivot = cr_pivot[['subject_id','hadm_id','creatinine']]


cr_pivot['subject_id'] = cr_pivot['subject_id'].astype(str)
cr_pivot['hadm_id']    = cr_pivot['hadm_id'].astype(str)
patient_latest['subject_id'] = patient_latest['subject_id'].astype(str)


ch = ch.copy()
ch['subject_id'] = ch['subject_id'].astype(str)
ch['hadm_id'] = ch['hadm_id'].astype(str)

merged = ch.merge(cr_pivot, on=['subject_id','hadm_id'], how='left')
merged = merged.merge(patient_latest, on='subject_id', how='left')


merged['creatinine_final'] = merged['creatinine']
merged['creatinine_source'] = pd.Series([None]*len(merged), index=merged.index, dtype=object)
merged.loc[merged['creatinine_final'].notna(), 'creatinine_source'] = 'admission_serum'

mask_patient = merged['creatinine_final'].isna() & merged['creatinine_patient'].notna()
merged.loc[mask_patient, 'creatinine_final'] = merged.loc[mask_patient, 'creatinine_patient']
merged.loc[mask_patient, 'creatinine_source'] = 'patient_latest_serum'


if MEDIAN_FALLBACK:
    median_cr = merged['creatinine_final'].median() if merged['creatinine_final'].notna().any() else np.nan
    still_missing = merged['creatinine_final'].isna()
    merged.loc[still_missing, 'creatinine_final'] = median_cr
    merged.loc[still_missing & merged['creatinine_source'].isna(), 'creatinine_source'] = 'median_fallback'


merged = merged.rename(columns={'creatinine_final':'creatinine'})


merged['heart_rate'] = pd.to_numeric(merged.get('heart_rate', np.nan), errors='coerce')
merged['sys_bp'] = pd.to_numeric(merged.get('sys_bp', np.nan), errors='coerce')
merged['hr_missing'] = merged['heart_rate'].isna()
merged['sbp_missing'] = merged['sys_bp'].isna()
merged['both_vitals_missing'] = merged['hr_missing'] & merged['sbp_missing']


merged['creatinine_flag'] = merged['creatinine_source']
merged.loc[merged['both_vitals_missing'], 'creatinine_flag'] = merged.loc[merged['both_vitals_missing'], 'creatinine_flag'].fillna('vitals_missing')


merged_clean = merged.copy()


print(" Diagnostics ")
print("Serum ITEMIDs (whitelist):", SERUM_CR_ITEMIDS[:20])
print("Urine ITEMIDs (excluded):", URINE_CR_ITEMIDS[:20])
print("Serum creatinine admissions (cr_pivot rows):", len(cr_pivot))
print("Final merged rows:", merged_clean.shape)
print("\nCreatinine source counts (top):")
print(merged_clean['creatinine_source'].value_counts(dropna=False).to_string())

print("\nSample rows (first {}):".format(PRINT_N))
display(merged_clean.sort_values(['subject_id','hadm_id']).head(PRINT_N))


if 'serum_lab' in globals() and not serum_lab.empty:
    suspicious = serum_lab[(serum_lab['valuenum'] > 20) & (serum_lab['creat_mgdl'] < 10)].copy().sort_values('charttime').head(10)
    if not suspicious.empty:
        print("\nSample suspicious conversions (likely Âµmol/L -> mg/dL applied):")
        display(suspicious[['subject_id','hadm_id','itemid','charttime','valuenum','valueuom','creat_mgdl']].head(10))




Serum creatinine rows (pre-clean): 11094
Admission-level serum creatinine rows (after convert): 10392
 Diagnostics 
Serum ITEMIDs (whitelist): ['50841', '50912', '51021', '51032', '51052', '51067', '51080', '51081', '51099']
Urine ITEMIDs (excluded): ['51070', '51073', '51082', '51106']
Serum creatinine admissions (cr_pivot rows): 10392
Final merged rows: (125, 13)

Creatinine source counts (top):
creatinine_source
median_fallback    70
admission_serum    55

Sample rows (first 8):


Unnamed: 0,subject_id,hadm_id,heart_rate,sys_bp,charttime,creatinine,creatinine_patient,creatinine.1,creatinine_source,hr_missing,sbp_missing,both_vitals_missing,creatinine_flag
0,10006,142345,66.0,150.0,2164-10-25 08:30:00,3.7,3.7,3.7,admission_serum,False,False,False,admission_serum
1,10011,105331,72.0,41.0,2126-08-28 16:10:00,0.9,0.9,0.9,admission_serum,False,False,False,admission_serum
2,10013,165520,0.0,45.0,2125-10-07 12:12:00,1.7,1.7,1.7,admission_serum,False,False,False,admission_serum
3,10017,199207,76.0,121.0,2149-05-31 20:00:00,0.4,0.4,0.4,admission_serum,False,False,False,admission_serum
4,10019,177759,89.0,47.5,2163-05-15 21:00:00,4.0,4.0,4.0,admission_serum,False,False,False,admission_serum
5,10026,103770,66.0,145.0,2195-05-19 16:00:00,0.5,0.5,0.5,admission_serum,False,False,False,admission_serum
6,10027,199395,111.0,34.0,2190-07-20 15:00:00,1.4,1.4,1.4,admission_serum,False,False,False,admission_serum
7,10029,132349,84.0,152.0,2139-09-25 16:00:00,3.0,3.0,3.0,admission_serum,False,False,False,admission_serum



Sample suspicious conversions (likely Âµmol/L -> mg/dL applied):


Unnamed: 0,subject_id,hadm_id,itemid,charttime,valuenum,valueuom,creat_mgdl
12583,9714,164688,51080,2101-05-27 14:15:00,58.0,,0.656109
6963,5394,107229,51067,2105-03-27 22:30:00,836.0,,9.457014
1542,1113,128609,51067,2107-02-20 07:19:00,616.0,,6.968326
7460,5760,196904,51080,2108-12-07 13:00:00,23.0,,0.260181
6824,5282,132174,51032,2109-08-10 18:30:00,33.0,,0.373303
6251,4871,137235,51032,2109-08-29 06:01:00,82.5,,0.933258
6249,4871,137235,50841,2109-09-03 08:33:00,41.7,,0.471719
8831,6824,105359,51067,2114-07-29 12:52:00,803.0,,9.08371
3022,2318,168835,51067,2115-08-30 21:19:00,407.0,,4.604072
5543,4292,196125,51080,2116-07-28 22:29:00,41.0,,0.463801


In [None]:

from pathlib import Path

OUT_DIR = Path("/content")
OUT_DIR.mkdir(parents=True, exist_ok=True)


assert 'merged_clean' in globals(), "merged_clean not found â€” run the cleaning pipeline first."


df = merged_clean.copy()
df.columns = df.columns.str.lower()


recommended = [
    'subject_id','hadm_id',
    'age','gender',
    'heart_rate','sys_bp','heart_rate_imp','sys_bp_imp',
    'hr_missing','sbp_missing','hr_was_missing','sbp_was_missing','both_vitals_missing',
    'creatinine','creatinine_source','creatinine_flag',
    'charttime','hospital_expire_flag'
]


keep_cols = [c for c in recommended if c in df.columns]


for extra in ['subject_id','hadm_id','charttime']:
    if extra in df.columns and extra not in keep_cols:
        keep_cols.insert(0, extra)


tidy = df[keep_cols].copy()


tidy_path = OUT_DIR / "merged_tidy.csv"
tidy.to_csv(tidy_path, index=False)
print(f"Saved tidy snapshot -> {tidy_path}")


if 'heart_rate_imp' not in tidy.columns:
    tidy['heart_rate_imp'] = tidy['heart_rate'].fillna(tidy['heart_rate'].median())
if 'sys_bp_imp' not in tidy.columns:
    tidy['sys_bp_imp'] = tidy['sys_bp'].fillna(tidy['sys_bp'].median())

df_keep = tidy.copy()
df_keep_path = OUT_DIR / "df_keep.csv"
df_keep.to_csv(df_keep_path, index=False)
print(f"Saved df_keep (keep+impute) -> {df_keep_path}")


mask_both_missing = df_keep.get('both_vitals_missing', pd.Series(False, index=df_keep.index)).fillna(False).astype(bool)
df_drop = df_keep.loc[~mask_both_missing].reset_index(drop=True)
df_drop_path = OUT_DIR / "df_drop_no_both_vitals_missing.csv"
df_drop.to_csv(df_drop_path, index=False)
print(f"Saved df_drop (dropped both_vitals_missing) -> {df_drop_path}")


dropped_ids = df_keep.loc[mask_both_missing, ['subject_id','hadm_id']].drop_duplicates()
print(f"\nRows dropped by Option B (count = {len(dropped_ids)}). Example IDs:")
display(dropped_ids.head(25))


df_serum_adm = df_keep.loc[df_keep.get('creatinine_source','').astype(str) == 'admission_serum'].reset_index(drop=True)
df_serum_path = OUT_DIR / "df_serum_admission_only.csv"
df_serum_adm.to_csv(df_serum_path, index=False)
print(f"Saved df_serum_adm (admission_serum only) -> {df_serum_path}")


print("\n Quick summary ")
print("Total rows (merged_clean):", len(df))
print("Tidy columns kept:", tidy.columns.tolist())
print("\nCreatinine source counts (in tidy):")
print(tidy['creatinine_source'].value_counts(dropna=False).to_string())

print(f"\nDataset sizes: df_keep={len(df_keep)}, df_drop={len(df_drop)}, df_serum_adm={len(df_serum_adm)}")


print("\nSample tidy rows (first 8):")
display(tidy.sort_values(['subject_id','hadm_id']).head(8))




Saved tidy snapshot -> /content/merged_tidy.csv
Saved df_keep (keep+impute) -> /content/df_keep.csv
Saved df_drop (dropped both_vitals_missing) -> /content/df_drop_no_both_vitals_missing.csv

Rows dropped by Option B (count = 0). Example IDs:


Unnamed: 0,subject_id,hadm_id


Saved df_serum_adm (admission_serum only) -> /content/df_serum_admission_only.csv

 Quick summary 
Total rows (merged_clean): 125
Tidy columns kept: ['subject_id', 'hadm_id', 'heart_rate', 'sys_bp', 'hr_missing', 'sbp_missing', 'both_vitals_missing', 'creatinine', 'creatinine', 'creatinine_source', 'creatinine_flag', 'charttime', 'heart_rate_imp', 'sys_bp_imp']

Creatinine source counts (in tidy):
creatinine_source
median_fallback    70
admission_serum    55

Dataset sizes: df_keep=125, df_drop=125, df_serum_adm=55

Sample tidy rows (first 8):


Unnamed: 0,subject_id,hadm_id,heart_rate,sys_bp,hr_missing,sbp_missing,both_vitals_missing,creatinine,creatinine.1,creatinine_source,creatinine_flag,charttime,heart_rate_imp,sys_bp_imp
0,10006,142345,66.0,150.0,False,False,False,3.7,3.7,admission_serum,admission_serum,2164-10-25 08:30:00,66.0,150.0
1,10011,105331,72.0,41.0,False,False,False,0.9,0.9,admission_serum,admission_serum,2126-08-28 16:10:00,72.0,41.0
2,10013,165520,0.0,45.0,False,False,False,1.7,1.7,admission_serum,admission_serum,2125-10-07 12:12:00,0.0,45.0
3,10017,199207,76.0,121.0,False,False,False,0.4,0.4,admission_serum,admission_serum,2149-05-31 20:00:00,76.0,121.0
4,10019,177759,89.0,47.5,False,False,False,4.0,4.0,admission_serum,admission_serum,2163-05-15 21:00:00,89.0,47.5
5,10026,103770,66.0,145.0,False,False,False,0.5,0.5,admission_serum,admission_serum,2195-05-19 16:00:00,66.0,145.0
6,10027,199395,111.0,34.0,False,False,False,1.4,1.4,admission_serum,admission_serum,2190-07-20 15:00:00,111.0,34.0
7,10029,132349,84.0,152.0,False,False,False,3.0,3.0,admission_serum,admission_serum,2139-09-25 16:00:00,84.0,152.0


In [None]:

from pathlib import Path

IN_PATH  = Path("/content/merged_tidy.csv")
OUT_PATH = Path("/content/merged_tidy_pruned.csv")

# load
tidy = pd.read_csv(IN_PATH)

print("Columns in loaded tidy:", tidy.columns.tolist())


if 'heart_rate_imp' not in tidy.columns:
    if 'heart_rate' in tidy.columns:
        med_hr = pd.to_numeric(tidy['heart_rate'], errors='coerce').median(skipna=True)
        tidy['heart_rate_imp'] = pd.to_numeric(tidy['heart_rate'], errors='coerce').fillna(med_hr)
        print(f"Created heart_rate_imp via median imputation (median={med_hr})")
    else:
        tidy['heart_rate_imp'] = np.nan
        print("No heart_rate column found; created heart_rate_imp as NaN")

if 'sys_bp_imp' not in tidy.columns:
    if 'sys_bp' in tidy.columns:
        med_sbp = pd.to_numeric(tidy['sys_bp'], errors='coerce').median(skipna=True)
        tidy['sys_bp_imp'] = pd.to_numeric(tidy['sys_bp'], errors='coerce').fillna(med_sbp)
        print(f"Created sys_bp_imp via median imputation (median={med_sbp})")
    else:
        tidy['sys_bp_imp'] = np.nan
        print("No sys_bp column found; created sys_bp_imp as NaN")


cre_cols = [c for c in tidy.columns if 'creatinine' in c.lower()]
print("Creatinine-like columns found:", cre_cols)

if not cre_cols:

    tidy['creatinine'] = np.nan
    print("No creatinine-like column found; created empty 'creatinine' column.")
else:

    if 'creatinine' in tidy.columns:
        chosen = 'creatinine'
    else:
        chosen = cre_cols[-1]

        tidy['creatinine'] = tidy[chosen]
        print(f"Standardized creatinine by copying column '{chosen}' -> 'creatinine'")


if 'creatinine_source' not in tidy.columns:
    tidy['creatinine_source'] = 'unknown'
    print("Added missing 'creatinine_source' column with value 'unknown'")


keep_cols = [
    "subject_id", "hadm_id",
    "heart_rate", "sys_bp",
    "creatinine", "creatinine_source",
    "heart_rate_imp", "sys_bp_imp"
]


keep_actual = [c for c in keep_cols if c in tidy.columns]
print("Columns that will be kept:", keep_actual)


tidy_pruned = tidy[keep_actual].copy()


tidy_pruned.to_csv(OUT_PATH, index=False)
print(f"Pruned columns saved -> {OUT_PATH}")
print("Rows:", len(tidy_pruned))


Columns in loaded tidy: ['subject_id', 'hadm_id', 'heart_rate', 'sys_bp', 'hr_missing', 'sbp_missing', 'both_vitals_missing', 'creatinine', 'creatinine.1', 'creatinine_source', 'creatinine_flag', 'charttime']
Created heart_rate_imp via median imputation (median=79.0)
Created sys_bp_imp via median imputation (median=113.0)
Creatinine-like columns found: ['creatinine', 'creatinine.1', 'creatinine_source', 'creatinine_flag']
Columns that will be kept: ['subject_id', 'hadm_id', 'heart_rate', 'sys_bp', 'creatinine', 'creatinine_source', 'heart_rate_imp', 'sys_bp_imp']
Pruned columns saved -> /content/merged_tidy_pruned.csv
Rows: 125


In [None]:
import pandas as pd
pd.read_csv("/content/merged_tidy_pruned.csv").head(12)


Unnamed: 0,subject_id,hadm_id,heart_rate,sys_bp,creatinine,creatinine_source,heart_rate_imp,sys_bp_imp
0,10006,142345,66.0,150.0,3.7,admission_serum,66.0,150.0
1,10011,105331,72.0,41.0,0.9,admission_serum,72.0,41.0
2,10013,165520,0.0,45.0,1.7,admission_serum,0.0,45.0
3,10017,199207,76.0,121.0,0.4,admission_serum,76.0,121.0
4,10019,177759,89.0,47.5,4.0,admission_serum,89.0,47.5
5,10026,103770,66.0,145.0,0.5,admission_serum,66.0,145.0
6,10027,199395,111.0,34.0,1.4,admission_serum,111.0,34.0
7,10029,132349,84.0,152.0,3.0,admission_serum,84.0,152.0
8,10032,140372,108.0,92.0,0.7,admission_serum,108.0,92.0
9,10033,157235,86.0,136.0,0.8,admission_serum,86.0,136.0


In [None]:
import os
import pandas as pd
import numpy as np


files = ["PRESCRIPTIONS_sorted.csv","INPUTEVENTS_MV_sorted.csv","INPUTEVENTS_CV_sorted.csv","PROCEDUREEVENTS_MV_sorted.csv","ADMISSIONS_sorted.csv"]
for f in files:
    print(f, "exists?" , os.path.exists(f))


df_adm = pd.DataFrame({
    'subject_id':[10006,10011,10013,10017,10019,10026,10027,10029,10032,10033,10035,10036],
    'hadm_id':[142345,105331,165520,199207,177759,103770,199395,132349,140372,157235,110244,189483],
    'heart_rate':[66.0,72.0,0.0,76.0,89.0,66.0,111.0,84.0,108.0,86.0,87.0,0.0],
    'sys_bp':[150.0,41.0,45.0,121.0,47.5,145.0,34.0,152.0,92.0,136.0,133.5,91.0],
    'creatinine':[3.7,0.9,1.7,0.4,4.0,0.5,1.4,3.0,0.7,0.8,1.2,0.6],
    'creatinine_source':['admission_serum']*12,
    'heart_rate_imp':[66.0,72.0,0.0,76.0,89.0,66.0,111.0,84.0,108.0,86.0,87.0,0.0],
    'sys_bp_imp':[150.0,41.0,45.0,121.0,47.5,145.0,34.0,152.0,92.0,136.0,133.5,91.0]
})



vital_cols = ['heart_rate','heart_rate_imp','sys_bp','sys_bp_imp']
for c in vital_cols:
    # flag zeros
    n_zero = (df_adm[c]==0).sum()
    if n_zero>0:
        print(f"Column {c} has {n_zero} zero values -> converting to NaN (treat as missing).")
        df_adm.loc[df_adm[c]==0, c] = np.nan


for c in ['heart_rate','heart_rate_imp','sys_bp','sys_bp_imp']:
    if df_adm[c].isna().any():
        med = df_adm[c].median(skipna=True)
        print(f"Imputing {c} NaNs with cohort median = {med}")
        df_adm[c] = df_adm[c].fillna(med)


TIMESTEP_HOURS = 6
NUM_STEPS = 48 // TIMESTEP_HOURS
rows = []
for _, r in df_adm.iterrows():
    for t in range(NUM_STEPS):
        rows.append({
            'subject_id': r['subject_id'],
            'hadm_id': r['hadm_id'],
            'timestep': t,
            'time_since_admit_hours': t*TIMESTEP_HOURS,
            'heart_rate': r['heart_rate_imp'],
            'sys_bp': r['sys_bp_imp'],
            'creatinine': r['creatinine'],
            'action': 'no_action'
        })
traj_df = pd.DataFrame(rows)
print("traj_df sample:")
print(traj_df.groupby('hadm_id').size().head())
print(traj_df.head(6))


PRESCRIPTIONS_sorted.csv exists? True
INPUTEVENTS_MV_sorted.csv exists? True
INPUTEVENTS_CV_sorted.csv exists? True
PROCEDUREEVENTS_MV_sorted.csv exists? True
ADMISSIONS_sorted.csv exists? True
Column heart_rate has 2 zero values -> converting to NaN (treat as missing).
Column heart_rate_imp has 2 zero values -> converting to NaN (treat as missing).
Imputing heart_rate NaNs with cohort median = 85.0
Imputing heart_rate_imp NaNs with cohort median = 85.0
traj_df sample:
hadm_id
103770    8
105331    8
110244    8
132349    8
140372    8
dtype: int64
   subject_id  hadm_id  timestep  time_since_admit_hours  heart_rate  sys_bp  \
0       10006   142345         0                       0        66.0   150.0   
1       10006   142345         1                       6        66.0   150.0   
2       10006   142345         2                      12        66.0   150.0   
3       10006   142345         3                      18        66.0   150.0   
4       10006   142345         4             

In [None]:
import pandas as pd
import numpy as np
from datetime import timedelta

TIMESTEP_HOURS = 6
example_hadm = 142345
example_timestep = 0

admissions = pd.read_csv("ADMISSIONS_sorted.csv", parse_dates=['ADMITTIME'], low_memory=False)
admissions = admissions.rename(columns={'ADMITTIME':'admit_time','HADM_ID':'hadm_id'})
ad_row = admissions.loc[admissions['hadm_id']==example_hadm]

if ad_row.empty:
    print(f"Admission {example_hadm} not found in ADMISSIONS_sorted.csv")
else:
    admit_time = ad_row['admit_time'].iloc[0]
    print(ad_row[['hadm_id','SUBJECT_ID','admit_time','DISCHTIME','DEATHTIME']].head(1))

    window_start = admit_time + pd.Timedelta(hours=example_timestep * TIMESTEP_HOURS)
    window_end   = admit_time + pd.Timedelta(hours=(example_timestep+1) * TIMESTEP_HOURS)
    print(window_start, "->", window_end)
    print("-"*60)

    try:
        presc = pd.read_csv("PRESCRIPTIONS_sorted.csv", low_memory=False, parse_dates=['STARTDATE','ENDDATE'])
        presc = presc.rename(columns={c:c.upper() for c in presc.columns})
        psub = presc[(presc['HADM_ID']==example_hadm) & (presc['STARTDATE'] >= window_start) & (presc['STARTDATE'] < window_end)]
        print(f"PRESCRIPTIONS rows in window: {len(psub)}")
        if not psub.empty:
            display_cols = ['ROW_ID','SUBJECT_ID','HADM_ID','STARTDATE','ENDDATE','DRUG','DRUG_NAME_POE','DRUG_NAME_GENERIC','DOSE_VAL_RX','DOSE_UNIT_RX','ROUTE']
            print(psub[display_cols].to_string(index=False))
        else:
            print("  (none)")
    except FileNotFoundError:
        print("PRESCRIPTIONS_sorted.csv not found")

    print("-"*60)

    try:
        iev_mv = pd.read_csv("INPUTEVENTS_MV_sorted.csv", low_memory=False)
        iev_mv.columns = [c.upper() for c in iev_mv.columns]
        iev_mv['STARTTIME'] = pd.to_datetime(iev_mv['STARTTIME'], errors='coerce')
        iev_mv['ENDTIME'] = pd.to_datetime(iev_mv['ENDTIME'], errors='coerce')
        cond_mv = (iev_mv['HADM_ID']==example_hadm) & ((iev_mv['STARTTIME'] < window_end) & ((iev_mv['ENDTIME'].notna() & (iev_mv['ENDTIME'] >= window_start)) | (iev_mv['ENDTIME'].isna() & (iev_mv['STARTTIME'] >= window_start))))
        mv_sub = iev_mv.loc[cond_mv]
        print(f"INPUTEVENTS_MV rows overlapping window: {len(mv_sub)}")
        if not mv_sub.empty:
            display_cols = ['ROW_ID','SUBJECT_ID','HADM_ID','STARTTIME','ENDTIME','ITEMID','AMOUNT','AMOUNTUOM','RATE','RATEUOM','ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME']
            print(mv_sub[display_cols].to_string(index=False))
        else:
            print("  (none)")
    except FileNotFoundError:
        print("INPUTEVENTS_MV_sorted.csv not found")

    print("-"*60)

    try:
        iev_cv = pd.read_csv("INPUTEVENTS_CV_sorted.csv", low_memory=False)
        iev_cv.columns = [c.upper() for c in iev_cv.columns]
        iev_cv['CHARTTIME'] = pd.to_datetime(iev_cv['CHARTTIME'], errors='coerce')
        iev_cv = iev_cv.dropna(subset=['CHARTTIME'])
        cv_sub = iev_cv[(iev_cv['HADM_ID']==example_hadm) & (iev_cv['CHARTTIME'] >= window_start) & (iev_cv['CHARTTIME'] < window_end)]
        print(f"INPUTEVENTS_CV rows in window: {len(cv_sub)}")
        if not cv_sub.empty:
            display_cols = ['ROW_ID','SUBJECT_ID','HADM_ID','CHARTTIME','ITEMID','AMOUNT','AMOUNTUOM','RATE','RATEUOM','ORIGINALROUTE','ORIGINALRATEUOM','ORDERID']
            print(cv_sub[display_cols].to_string(index=False))
        else:
            print("  (none)")
    except FileNotFoundError:
        print("INPUTEVENTS_CV_sorted.csv not found")

    print("-"*60)

    try:
        proc = pd.read_csv("PROCEDUREEVENTS_MV_sorted.csv", low_memory=False)
        proc.columns = [c.upper() for c in proc.columns]
        proc['STARTTIME'] = pd.to_datetime(proc['STARTTIME'], errors='coerce')
        proc['ENDTIME'] = pd.to_datetime(proc['ENDTIME'], errors='coerce')
        proc_sub = proc[(proc['HADM_ID']==example_hadm) & (proc['STARTTIME'] >= window_start) & (proc['STARTTIME'] < window_end)]
        print(f"PROCEDUREEVENTS_MV rows in window: {len(proc_sub)}")
        if not proc_sub.empty:
            display_cols = ['ROW_ID','SUBJECT_ID','HADM_ID','STARTTIME','ENDTIME','ITEMID','VALUE','VALUEUOM','ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME']
            print(proc_sub[display_cols].to_string(index=False))
        else:
            print("  (none)")
    except FileNotFoundError:
        print("PROCEDUREEVENTS_MV_sorted.csv not found")

    print("-"*60)

    import os
    if os.path.exists("traj_with_mapped_actions.csv"):
        traj = pd.read_csv("traj_with_mapped_actions.csv")
        row = traj[(traj['hadm_id']==example_hadm) & (traj['timestep']==example_timestep)]
        if not row.empty:
            print(row[['hadm_id','timestep','mapped_action','action_code']].to_string(index=False))
        else:
            print("No row found in traj_with_mapped_actions.csv for that hadm/timestep.")
    else:
        print("traj_with_mapped_actions.csv not found")


       hadm_id  SUBJECT_ID          admit_time            DISCHTIME DEATHTIME
12257   142345       10006 2164-10-23 21:09:00  2164-11-01 17:15:00       NaN
2164-10-23 21:09:00 -> 2164-10-24 03:09:00
------------------------------------------------------------
PRESCRIPTIONS rows in window: 4
 ROW_ID  SUBJECT_ID  HADM_ID  STARTDATE    ENDDATE              DRUG     DRUG_NAME_POE DRUG_NAME_GENERIC DOSE_VAL_RX DOSE_UNIT_RX ROUTE
 299233       10006   142345 2164-10-24 2164-11-01             Senna             Senna             Senna           1          TAB    PO
 299232       10006   142345 2164-10-24 2164-11-01   Docusate Sodium   Docusate Sodium   Docusate Sodium         100           mg    PO
 299234       10006   142345 2164-10-24 2164-10-25 Magnesium Sulfate Magnesium Sulfate Magnesium Sulfate           2           gm    IV
 299231       10006   142345 2164-10-24 2164-10-25   Magnesium Oxide   Magnesium Oxide   Magnesium Oxide         140           mg    PO
----------------------------

In [None]:
# === Updated mapping pipeline (drop-in replacement) ===
import os, re
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional

# ---------- Config ----------
TIMESTEP_HOURS = globals().get('TIMESTEP_HOURS', 6)
TOTAL_HOURS = globals().get('TOTAL_HOURS', 48)
NUM_STEPS = globals().get('NUM_STEPS', TOTAL_HOURS // TIMESTEP_HOURS)
CHUNKSIZE = globals().get('CHUNKSIZE', 50000)
ACTION_LABELS = globals().get('ACTION_LABELS', ['no_action','vasopressor','fluid_bolus','diuretic','antibiotic','insulin','other'])
ACTION_PRIORITY = globals().get('ACTION_PRIORITY', {'no_action':0,'other':1,'antibiotic':2,'diuretic':3,'fluid_bolus':4,'insulin':5,'vasopressor':6})

PRESC_MAP_SAFE = {
    'antibiotic': [r'\bvancomycin\b', r'\bampicillin\b', r'\bclindamycin\b', r'\bpiperacillin[- ]?tazobactam\b',
                   r'\bmeropenem\b', r'\bciprofloxacin\b', r'\blevofloxacin\b', r'\bazithro\w*', r'\bgentamicin\b',
                   r'\bceftriaxone\b', r'\bcefazolin\b', r'\bmetronidazole\b'],
    'diuretic': [r'\bfurosemide\b', r'\bbumetanide\b', r'\btorsemide\b', r'\bhydrochlorothiazide\b', r'\bspironolactone\b'],
    'vasopressor': [r'\bnorepinephrine\b', r'\bphenylephrine\b', r'\bepinephrine\b', r'\bvasopressin\b', r'\bdopamine\b'],
    'insulin': [r'\binsulin glargine\b', r'\binsulin lispro\b', r'\binsulin aspart\b', r'\binsulin\b']
}
PRESC_REGEX = {k: re.compile("|".join(v), flags=re.I) for k, v in PRESC_MAP_SAFE.items()}
FLUID_RE = re.compile(r'intravenous|intravenous push|intravenous drip|fluid|bolus|drip|crystalloid|colloid', flags=re.I)

# ---------- Helpers ----------
def hours_to_timestep(hours_series: pd.Series) -> pd.Series:
    """Convert hours -> nullable Int64 timestep index. Return pd.NA for out-of-window."""
    idx_float = np.floor(hours_series / TIMESTEP_HOURS)
    idx = pd.Series(idx_float, index=hours_series.index).astype('Int64')
    mask_outside = (idx < 0) | (idx >= NUM_STEPS)
    idx.loc[mask_outside] = pd.NA
    return idx

def explode_ranges_vectorized(hadm_arr: np.ndarray, start_arr: np.ndarray, end_arr: np.ndarray, label_arr: np.ndarray) -> Optional[pd.DataFrame]:
    """Explode ranges into (hadm_id,timestep,action_label). Arguments must be numpy ints (no NA)."""
    if len(start_arr) == 0:
        return None
    valid_mask = (start_arr <= end_arr)
    if not valid_mask.any():
        return None
    starts = start_arr[valid_mask]; ends = end_arr[valid_mask]
    hadms = hadm_arr[valid_mask]; labels = label_arr[valid_mask]
    ranges = [np.arange(s, e+1, dtype='int64') for s, e in zip(starts, ends)]
    lengths = np.array([len(r) for r in ranges], dtype='int64')
    if lengths.sum() == 0:
        return None
    all_timesteps = np.concatenate(ranges)
    all_hadms = np.repeat(hadms, lengths)
    all_labels = np.repeat(labels, lengths)
    df = pd.DataFrame({'hadm_id': all_hadms, 'timestep': all_timesteps, 'action_label': all_labels})
    df = df[(df['timestep'] >= 0) & (df['timestep'] < NUM_STEPS)].copy()
    return df if not df.empty else None

def prepare_for_explode(df: pd.DataFrame, hadm_col: str, start_col: str, end_col: str, label_col: str,
                        treat_end_as_start: bool = True, cap_end_to_last: bool = True) -> Optional[Tuple[np.ndarray,np.ndarray,np.ndarray,np.ndarray]]:
    tmp = df[[hadm_col, start_col, end_col, label_col]].copy()
    if treat_end_as_start:
        tmp[end_col] = tmp[end_col].where(tmp[end_col].notna(), tmp[start_col])
    if cap_end_to_last:
        mask_non_na = tmp[end_col].notna()
        tmp.loc[mask_non_na, end_col] = tmp.loc[mask_non_na, end_col].clip(upper=NUM_STEPS-1)
    valid = tmp[hadm_col].notna() & tmp[start_col].notna() & tmp[end_col].notna()
    if not valid.any():
        return None
    hadm_arr  = tmp.loc[valid, hadm_col].astype('int64').to_numpy()
    start_arr = tmp.loc[valid, start_col].astype('int64').to_numpy()
    end_arr   = tmp.loc[valid, end_col].astype('int64').to_numpy()
    label_arr = tmp.loc[valid, label_col].to_numpy()
    return hadm_arr, start_arr, end_arr, label_arr

# ---------- Ensure inputs ----------
if 'traj_df' not in globals():
    raise RuntimeError("traj_df not present. Create your skeleton 'traj_df' before running this mapping.")
if 'admissions' not in globals():
    if os.path.exists("ADMISSIONS_sorted.csv"):
        admissions = pd.read_csv("ADMISSIONS_sorted.csv", parse_dates=['ADMITTIME'], low_memory=False).rename(columns={'ADMITTIME':'admit_time','HADM_ID':'hadm_id'})
    else:
        raise RuntimeError("admissions not found in namespace nor ADMISSIONS_sorted.csv on disk.")
# normalize admissions
if 'admit_time' not in admissions.columns and 'ADMITTIME' in admissions.columns:
    admissions = admissions.rename(columns={'ADMITTIME':'admit_time'})
if 'hadm_id' not in admissions.columns and 'HADM_ID' in admissions.columns:
    admissions = admissions.rename(columns={'HADM_ID':'hadm_id'})
admissions['hadm_id'] = pd.to_numeric(admissions['hadm_id'], errors='coerce').astype('Int64')
admissions['admit_time'] = pd.to_datetime(admissions['admit_time'], errors='coerce')
admissions_small = admissions[['hadm_id','admit_time']].copy()

hadm_keep = set(traj_df['hadm_id'].unique())
print(f"Processing only {len(hadm_keep)} admissions from traj_df.")

# ---------- Collect action_frames ----------
action_frames: List[pd.DataFrame] = []

# PRESCRIPTIONS
if os.path.exists("PRESCRIPTIONS_sorted.csv"):
    for chunk in pd.read_csv("PRESCRIPTIONS_sorted.csv", chunksize=CHUNKSIZE, dtype=str, low_memory=False):
        chunk.columns = [c.upper() for c in chunk.columns]
        if 'HADM_ID' not in chunk.columns:
            continue
        chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
        chunk = chunk[chunk['HADM_ID'].isin(hadm_keep)]
        if chunk.empty:
            continue
        chunk['STARTDATE_PARSED'] = pd.to_datetime(chunk.get('STARTDATE','').astype(str).str.strip(), errors='coerce')
        chunk['ENDDATE_PARSED']   = pd.to_datetime(chunk.get('ENDDATE','').astype(str).str.strip(), errors='coerce').fillna(chunk['STARTDATE_PARSED'])
        chunk = chunk.merge(admissions_small, left_on='HADM_ID', right_on='hadm_id', how='left')
        dt_start = (chunk['STARTDATE_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        dt_end   = (chunk['ENDDATE_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        chunk['start_idx'] = hours_to_timestep(dt_start).astype('Int64')
        chunk['end_idx']   = hours_to_timestep(dt_end).astype('Int64')
        chunk = chunk[chunk['start_idx'].notna()]
        text_cols = [c for c in ['DRUG','DRUG_NAME_POE','DRUG_NAME_GENERIC'] if c in chunk.columns]
        chunk['combined_text'] = chunk[text_cols].fillna('').agg(' '.join, axis=1).str.lower() if text_cols else ''
        chunk['action_label'] = 'other'
        for lbl, rx in PRESC_REGEX.items():
            mask = chunk['combined_text'].str.contains(rx, na=False)
            chunk.loc[mask, 'action_label'] = lbl
        prepared = prepare_for_explode(chunk, hadm_col='HADM_ID', start_col='start_idx', end_col='end_idx', label_col='action_label', treat_end_as_start=True)
        if prepared is not None:
            hadm_arr, start_arr, end_arr, label_arr = prepared
            df_expl = explode_ranges_vectorized(hadm_arr, start_arr, end_arr, label_arr)
            if df_expl is not None:
                df_expl['source'] = 'prescriptions'
                action_frames.append(df_expl)
    print("PRESCRIPTIONS processed.")
else:
    print("PRESCRIPTIONS_sorted.csv not found -> skipped.")

# INPUTEVENTS_MV (main, chunked)
MV_PATH = "INPUTEVENTS_MV_sorted.csv"
if os.path.exists(MV_PATH):
    usecols = ['ROW_ID','SUBJECT_ID','HADM_ID','STARTTIME','ENDTIME','AMOUNTUOM','RATEUOM','ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION','AMOUNT','RATE','ITEMID']
    for chunk in pd.read_csv(MV_PATH, usecols=lambda c: c.upper() in [u.upper() for u in usecols], chunksize=CHUNKSIZE, dtype=str, low_memory=False):
        chunk.columns = [c.upper() for c in chunk.columns]
        chunk['STARTTIME_parsed'] = pd.to_datetime(chunk.get('STARTTIME','').astype(str).str.strip(), errors='coerce')
        chunk['ENDTIME_parsed']   = pd.to_datetime(chunk.get('ENDTIME','').astype(str).str.strip(), errors='coerce').fillna(chunk['STARTTIME_parsed'])
        # fast path: HADM_ID present, restrict to hadm_keep
        if 'HADM_ID' in chunk.columns:
            chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
            hv = chunk[chunk['HADM_ID'].isin(hadm_keep)].copy()
            if not hv.empty:
                hv = hv.merge(admissions_small, left_on='HADM_ID', right_on='hadm_id', how='left')
                dt_start = (hv['STARTTIME_parsed'] - hv['admit_time']).dt.total_seconds() / 3600.0
                dt_end   = (hv['ENDTIME_parsed']   - hv['admit_time']).dt.total_seconds() / 3600.0
                hv['start_idx'] = hours_to_timestep(dt_start).astype('Int64')
                hv['end_idx']   = hours_to_timestep(dt_end).astype('Int64')
                hv = hv[hv['start_idx'].notna()]
                if not hv.empty:
                    # tighter heuristics for fluid vs infusion
                    meta_cols = [c for c in ['ORIGINALROUTE','ORIGINALRATEUOM','ORDERID','AMOUNTUOM','RATEUOM','ORDERCATEGORYNAME','ITEMID'] if c in chunk.columns]
                    meta_text = chunk[meta_cols].fillna('').agg(' '.join, axis=1).str.lower() if meta_cols else pd.Series('', index=chunk.index)
                    amt = pd.to_numeric(chunk.get('AMOUNT','').astype(str).str.replace(',',''), errors='coerce')
                    amt_uom = chunk.get('AMOUNTUOM', pd.Series('', index=chunk.index)).fillna('').str.lower()
                    rate_uom = chunk.get('RATEUOM', pd.Series('', index=chunk.index)).fillna('').str.lower()
                    is_bolus_text = meta_text.str.contains(r'\bbolus\b|\bpush\b|\bstat\b', na=False)
                    is_infusion_text = meta_text.str.contains(r'\bdrip\b|\binfusion\b|\bcontinuous\b|\bmaintenance\b', na=False)
                    is_single_ml_bolus = (amt.notna() & amt.between(1,2000)) & amt_uom.isin(['ml','l']) & (~rate_uom.str.contains(r'/hr|/h|per hour', na=False))
                    cond_fluid = (is_bolus_text | is_single_ml_bolus) & (~is_infusion_text)
                    cond_press = meta_text.str.contains(r'press|pressor|inotrope|norepinephrine|vasopressin|epinephrine|dopamine', na=False) | rate_uom.str.contains(r'mcgkgmin|mcg/kg/min', na=False)
                    cond_ab = ord_text.str.contains(r'antibiotic|antibiotics', na=False)
                    hv['action_label'] = 'other'
                    hv.loc[cond_ab, 'action_label'] = 'antibiotic'
                    hv.loc[cond_press, 'action_label'] = 'vasopressor'
                    hv.loc[cond_fluid, 'action_label'] = 'fluid_bolus'
                    hv.loc[ord_text.str.contains('insulin', na=False), 'action_label'] = 'insulin'
                    # reduce duplicates: group identical hadm/start/end/label combos
                    hv['start_idx_i'] = hv['start_idx'].astype('Int64')
                    hv['end_idx_i']   = hv['end_idx'].astype('Int64')
                    grouped = hv.groupby(['HADM_ID','start_idx_i','end_idx_i','action_label'], dropna=False, as_index=False).agg({'AMOUNT':'first','RATE':'first'})
                    grouped = grouped.rename(columns={'start_idx_i':'start_idx','end_idx_i':'end_idx'})
                    prepared = prepare_for_explode(grouped, hadm_col='HADM_ID', start_col='start_idx', end_col='end_idx', label_col='action_label', treat_end_as_start=True)
                    if prepared is not None:
                        hadm_arr, start_arr, end_arr, label_arr = prepared
                        df_expl = explode_ranges_vectorized(hadm_arr, start_arr, end_arr, label_arr)
                        if df_expl is not None:
                            df_expl['source'] = 'inputevents_mv'
                            action_frames.append(df_expl)
        # subject-based matching path (if present)
        if 'SUBJECT_ID' in chunk.columns:
            # build subject->admissions map once (cheap relative to full file)
            subject_to_admissions = {sid: subdf.drop(columns=['subject_id']).reset_index(drop=True) for sid, subdf in admissions.groupby('subject_id')} if 'subject_id' in admissions.columns else {}
            chunk['SUBJECT_ID'] = pd.to_numeric(chunk['SUBJECT_ID'], errors='coerce').astype('Int64')
            present_subjects = set(chunk['SUBJECT_ID'].dropna().astype(int).unique()) & set(subject_to_admissions.keys())
            if present_subjects:
                sub_chunk = chunk[chunk['SUBJECT_ID'].isin(present_subjects)].copy()
                for subj, subgrp in sub_chunk.groupby('SUBJECT_ID'):
                    adm_rows = subject_to_admissions.get(subj)
                    if adm_rows is None or adm_rows.empty:
                        continue
                    subgrp = subgrp.reset_index(drop=True); adm = adm_rows.copy().reset_index(drop=True)
                    subgrp['_tmp']=1; adm['_tmp']=1
                    merge = subgrp.merge(adm, on='_tmp', suffixes=('','_adm')).drop(columns=['_tmp'])
                    cond_admit_ge = merge['STARTTIME_parsed'] >= merge['admit_time']
                    if 'disch_time' in merge.columns:
                        cond_before_disch = (merge['STARTTIME_parsed'] <= merge['disch_time']) | merge['disch_time'].isna()
                        mask = cond_admit_ge & cond_before_disch
                    else:
                        mask = cond_admit_ge
                    merge = merge.loc[mask].copy()
                    if merge.empty:
                        continue
                    dt_start = (merge['STARTTIME_parsed'] - merge['admit_time']).dt.total_seconds() / 3600.0
                    dt_end   = (merge['ENDTIME_parsed'] - merge['admit_time']).dt.total_seconds() / 3600.0
                    merge['start_idx'] = hours_to_timestep(dt_start).astype('Int64')
                    merge['end_idx']   = hours_to_timestep(dt_end).astype('Int64')
                    merge = merge[merge['start_idx'].notna()]
                    if merge.empty:
                        continue
                    ord_cols = [c for c in ['ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION'] if c in merge.columns]
                    ord_text = merge[ord_cols].fillna('').agg(' '.join, axis=1).str.lower() if ord_cols else pd.Series('', index=merge.index)
                    amt_num = pd.to_numeric(merge.get('AMOUNT','').astype(str).str.replace(',',''), errors='coerce')
                    amt_uom = merge.get('AMOUNTUOM', pd.Series('', index=merge.index)).fillna('').str.lower()
                    rate_uom = merge.get('RATEUOM', pd.Series('', index=merge.index)).fillna('').str.lower()
                    is_bolus_text = ord_text.str.contains(r'\bbolus\b|\bpush\b|\bstat\b', na=False)
                    is_infusion_text = ord_text.str.contains(r'\bdrip\b|\binfusion\b|\bcontinuous\b', na=False)
                    is_single_ml_bolus = (amt_num.notna() & amt_num.between(1, 2000)) & amt_uom.isin(['ml','l']) & (~rate_uom.str.contains(r'/hr|/h|per hour', na=False))
                    cond_fluid = (is_bolus_text | is_single_ml_bolus) & (~is_infusion_text)
                    cond_press = ord_text.str.contains(r'press|pressor|inotrope|norepinephrine|epinephrine|vasopressin|dopamine', na=False) | rate_uom.str.contains(r'mcgkgmin|mcg/kg/min', na=False)
                    cond_ab = ord_text.str.contains(r'antibiotic|antibiotics', na=False)
                    merge['action_label'] = 'other'
                    merge.loc[cond_ab, 'action_label'] = 'antibiotic'
                    merge.loc[cond_press, 'action_label'] = 'vasopressor'
                    merge.loc[cond_fluid, 'action_label'] = 'fluid_bolus'
                    merge.loc[ord_text.str.contains('insulin', na=False), 'action_label'] = 'insulin'
                    grouped = merge.groupby(['hadm_id','start_idx','end_idx','action_label'], dropna=False, as_index=False).agg({'AMOUNT':'first','RATE':'first'})
                    prepared = prepare_for_explode(grouped, hadm_col='hadm_id', start_col='start_idx', end_col='end_idx', label_col='action_label', treat_end_as_start=True)
                    if prepared is not None:
                        hadm_arr, start_arr, end_arr, label_arr = prepared
                        df_expl = explode_ranges_vectorized(hadm_arr, start_arr, end_arr, label_arr)
                        if df_expl is not None:
                            df_expl['source'] = 'inputevents_mv_subject'
                            action_frames.append(df_expl)
    print("INPUTEVENTS_MV processed.")
else:
    print("INPUTEVENTS_MV_sorted.csv not found -> skipped.")

# INPUTEVENTS_CV (charted events)
if os.path.exists("INPUTEVENTS_CV_sorted.csv"):
    for chunk in pd.read_csv("INPUTEVENTS_CV_sorted.csv", chunksize=CHUNKSIZE, dtype=str, low_memory=False):
        chunk.columns = [c.upper() for c in chunk.columns]
        if 'HADM_ID' not in chunk.columns:
            continue
        chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
        chunk['CHARTTIME_PARSED'] = pd.to_datetime(chunk.get('CHARTTIME','').astype(str).str.strip(), errors='coerce')
        chunk = chunk.merge(admissions_small, left_on='HADM_ID', right_on='hadm_id', how='left')
        dt_chart = (chunk['CHARTTIME_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        chunk['timestep'] = hours_to_timestep(dt_chart).astype('Int64')
        chunk = chunk[chunk['timestep'].notna()]
        if chunk.empty:
            continue
        meta_cols = [c for c in ['ORIGINALROUTE','ORIGINALRATEUOM','ORDERID','AMOUNTUOM','RATEUOM'] if c in chunk.columns]
        meta_text = chunk[meta_cols].fillna('').agg(' '.join, axis=1).str.lower() if meta_cols else pd.Series('', index=chunk.index)
        amt_uom = chunk.get('AMOUNTUOM', pd.Series('', index=chunk.index)).fillna('').str.lower()
        rate_uom = chunk.get('RATEUOM', pd.Series('', index=chunk.index)).fillna('').str.lower()
        # apply same stricter checks (single-timepoint)
        is_bolus = meta_text.str.contains(r'\bbolus\b|\bpush\b|\bstat\b', na=False)
        is_infusion = meta_text.str.contains(r'\bdrip\b|\binfusion\b|\bcontinuous\b', na=False)
        cond_fluid = is_bolus & (~is_infusion)
        cond_press = rate_uom.str.contains(r'mcgkgmin|mcg/kg/min', na=False) | meta_text.str.contains(r'press|pressor|inotrope', na=False)
        cond_ab = meta_text.str.contains(r'antibiotic', na=False)
        chunk['action_label'] = 'other'
        chunk.loc[cond_ab, 'action_label'] = 'antibiotic'
        chunk.loc[cond_press, 'action_label'] = 'vasopressor'
        chunk.loc[cond_fluid, 'action_label'] = 'fluid_bolus'
        # keep and append
        valid_mask = chunk['HADM_ID'].notna() & chunk['timestep'].notna()
        if valid_mask.any():
            df_cv = chunk.loc[valid_mask, ['HADM_ID','timestep','action_label']].copy()
            df_cv.columns = ['hadm_id','timestep','action_label']
            df_cv['source'] = 'inputevents_cv'
            df_cv['hadm_id'] = df_cv['hadm_id'].astype('int64')
            df_cv['timestep'] = df_cv['timestep'].astype('int64')
            action_frames.append(df_cv)
    print("INPUTEVENTS_CV processed.")
else:
    print("INPUTEVENTS_CV_sorted.csv not found -> skipped.")

# PROCEDUREEVENTS_MV (optional)
if os.path.exists("PROCEDUREEVENTS_MV_sorted.csv"):
    for chunk in pd.read_csv("PROCEDUREEVENTS_MV_sorted.csv", chunksize=CHUNKSIZE, dtype=str, low_memory=False):
        chunk.columns = [c.upper() for c in chunk.columns]
        if 'HADM_ID' not in chunk.columns:
            continue
        chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
        chunk = chunk[chunk['HADM_ID'].isin(hadm_keep)]
        if chunk.empty:
            continue
        chunk['STARTTIME_PARSED'] = pd.to_datetime(chunk.get('STARTTIME','').astype(str).str.strip(), errors='coerce')
        chunk['ENDTIME_PARSED'] = pd.to_datetime(chunk.get('ENDTIME','').astype(str).str.strip(), errors='coerce').fillna(chunk['STARTTIME_PARSED'])
        chunk = chunk.merge(admissions_small, left_on='HADM_ID', right_on='hadm_id', how='left')
        dt_start = (chunk['STARTTIME_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        dt_end   = (chunk['ENDTIME_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        chunk['start_idx'] = hours_to_timestep(dt_start).astype('Int64')
        chunk['end_idx']   = hours_to_timestep(dt_end).astype('Int64')
        chunk = chunk[chunk['start_idx'].notna()]
        if chunk.empty:
            continue
        chunk['action_label'] = 'other'
        prepared = prepare_for_explode(chunk, hadm_col='HADM_ID', start_col='start_idx', end_col='end_idx', label_col='action_label', treat_end_as_start=True)
        if prepared is not None:
            hadm_arr, start_arr, end_arr, label_arr = prepared
            df_expl = explode_ranges_vectorized(hadm_arr, start_arr, end_arr, label_arr)
            if df_expl is not None:
                df_expl['source'] = 'procedureevents_mv'
                action_frames.append(df_expl)
    print("PROCEDUREEVENTS_MV processed.")
else:
    print("PROCEDUREEVENTS_MV_sorted.csv not found -> skipped.")

# ---------- Post-process actions frames ----------
if not action_frames:
    raise RuntimeError("No actions collected (action_frames empty). Nothing to merge.")
actions_df = pd.concat(action_frames, ignore_index=True)

# Save raw for audit
actions_raw_path = "actions_raw.csv"
actions_df.to_csv(actions_raw_path, index=False)
print(f"Wrote raw concatenated actions to {actions_raw_path} (rows: {len(actions_df)})")

# Diagnostics
print("\nPre-dedupe counts by source/action:")
print(actions_df.groupby(['source','action_label']).size().unstack(fill_value=0))
conflict_counts = actions_df.groupby(['hadm_id','timestep']).size()
print("\nConflicts summary (how many hadm,timestep combos had >1 raw row):", (conflict_counts>1).sum())
print("Top conflicts (sample):")
print(conflict_counts.sort_values(ascending=False).head(20))

# Clean & dedupe (priority)
actions_df = actions_df[actions_df['hadm_id'].notna() & actions_df['timestep'].notna()].copy()
actions_df['hadm_id'] = actions_df['hadm_id'].astype('int64')
actions_df['timestep'] = actions_df['timestep'].astype('int64')
actions_df['action_label'] = actions_df['action_label'].astype(str).str.strip().str.lower()
actions_df['priority'] = actions_df['action_label'].map(lambda x: ACTION_PRIORITY.get(x, 1))
actions_sorted = actions_df.sort_values(['hadm_id','timestep','priority'], ascending=[True, True, False])
actions_top = actions_sorted.drop_duplicates(subset=['hadm_id','timestep'], keep='first').reset_index(drop=True)

# Save top actions
actions_top_path = "actions_top.csv"
actions_top.to_csv(actions_top_path, index=False)
print(f"Saved deduplicated actions to {actions_top_path} (unique hadm,timestep: {len(actions_top)})")
print("\nPost-dedupe action counts:")
print(actions_top['action_label'].value_counts())

# ---------- Merge into traj_df skeleton (final) ----------
traj_out = traj_df.merge(actions_top[['hadm_id','timestep','action_label','source']], on=['hadm_id','timestep'], how='left')
traj_out['mapped_action'] = traj_out['action_label'].fillna('no_action')
action_to_code = {lbl:i for i,lbl in enumerate(ACTION_LABELS)}
traj_out['action_code'] = traj_out['mapped_action'].map(lambda x: action_to_code.get(x, action_to_code.get('other',1))).astype(int)
traj_out.to_csv("traj_with_mapped_actions.csv", index=False)
print(f"Wrote final traj_with_mapped_actions.csv ({len(traj_out)} rows).")
print(traj_out['mapped_action'].value_counts())


Processing only 12 admissions from traj_df.
PRESCRIPTIONS processed.
INPUTEVENTS_MV processed.
INPUTEVENTS_CV processed.
PROCEDUREEVENTS_MV processed.
Wrote raw concatenated actions to actions_raw.csv (rows: 1263228)

Pre-dedupe counts by source/action:
action_label    antibiotic  diuretic  fluid_bolus  insulin   other  \
source                                                               
inputevents_cv           0         0       723326        0  332970   
prescriptions           25        17            0       20     347   

action_label    vasopressor  
source                       
inputevents_cv       206503  
prescriptions            20  

Conflicts summary (how many hadm,timestep combos had >1 raw row): 52960
Top conflicts (sample):
hadm_id  timestep
124271   4           329
154448   3           303
134650   3           297
199727   1           294
193335   0           280
105709   5           265
110691   5           262
134650   4           261
145134   3           258
11069

In [None]:
import pandas as pd
import numpy as np
from collections import Counter

# paths (change if needed)
CV_PATH = "INPUTEVENTS_CV_sorted.csv"
ACTIONS_TOP = "actions_top.csv"   # produced by your pipeline
TRAJ = "traj_with_mapped_actions.csv"

# load small references
traj = pd.read_csv(TRAJ, usecols=['hadm_id']).drop_duplicates()
hadm_set = set(traj['hadm_id'].astype(int).unique())
print("Hadm IDs in traj:", len(hadm_set))

# load deduped actions (fast)
actions_top = pd.read_csv(ACTIONS_TOP)
# restrict to CV-labeled fluid/pressor timesteps
focus = actions_top[actions_top['action_label'].isin(['fluid_bolus','vasopressor'])]
focus = focus[focus['hadm_id'].isin(hadm_set)]
print("Focused action rows (post-dedupe) to inspect:", len(focus))

if len(focus)==0:
    print("No fluid/vasopressor rows in actions_top for these hadms â€” nothing to inspect.")
else:
    # read CV source but only rows for hadm in hadm_set
    usecols = None  # read all (or set a smaller list if you want)
    # if the CV file is huge, add nrows=200000 to the read_csv call for quick run
    cv = pd.read_csv(CV_PATH, low_memory=False)
    cv.columns = [c.upper() for c in cv.columns]

    # ensure HADM_ID numeric and CHARTTIME parsed
    cv['HADM_ID'] = pd.to_numeric(cv.get('HADM_ID'), errors='coerce').astype('Int64')
    cv = cv[cv['HADM_ID'].isin(hadm_set)].copy()
    if cv.empty:
        print("No CV rows for these hadms found in the file.")
    else:
        cv['CHARTTIME'] = pd.to_datetime(cv.get('CHARTTIME'), errors='coerce')
        # need admit times to compute timesteps
        adm = pd.read_csv("ADMISSIONS_sorted.csv", parse_dates=['ADMITTIME']).rename(columns={'HADM_ID':'hadm_id','ADMITTIME':'admit_time'})
        cv = cv.merge(adm[['hadm_id','admit_time']], left_on='HADM_ID', right_on='hadm_id', how='left')
        cv['dt_hours'] = (cv['CHARTTIME'] - cv['admit_time']).dt.total_seconds() / 3600.0
        TIMESTEP_HOURS = 6
        cv['timestep'] = (cv['dt_hours'] // TIMESTEP_HOURS).astype('Int64')
        cv = cv[cv['timestep'].notna()]

        # join to focus rows
        merged = cv.merge(focus, left_on=['HADM_ID','timestep'], right_on=['hadm_id','timestep'], how='inner', suffixes=('','_act'))
        print("CV rows overlapping focus actions (merged):", len(merged))

        if merged.empty:
            print("No overlapping CV rows found for the focused (hadm,timestep).")
        else:
            # examine top ITEMIDs and ORDERTEXT fields driving mappings
            cols_of_interest = []
            for c in ['ITEMID','ORIGINALROUTE','ORIGINALRATEUOM','AMOUNTUOM','RATEUOM','ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION']:
                if c in merged.columns:
                    cols_of_interest.append(c)

            print("Cols we will inspect:", cols_of_interest)

            def top_counts(col, n=30):
                s = merged[col].fillna('').astype(str)
                c = s.value_counts().head(n)
                return c

            for col in cols_of_interest:
                print("\n--- Top values for", col, "(sample) ---")
                print(top_counts(col, n=20))

            # summary by action_label
            print("\nCounts by action_label in merged CV rows:")
            print(merged['action_label'].value_counts())

            # if you want to save a small sample for manual inspection:
            merged[['HADM_ID','CHARTTIME','timestep','action_label'] + cols_of_interest].head(200).to_csv("cv_merged_sample.csv", index=False)
            print("Saved cv_merged_sample.csv (first 200 rows) for manual inspection.")


Hadm IDs in traj: 12
Focused action rows (post-dedupe) to inspect: 55
CV rows overlapping focus actions (merged): 2191
Cols we will inspect: ['ITEMID', 'ORIGINALROUTE', 'ORIGINALRATEUOM', 'AMOUNTUOM', 'RATEUOM']

--- Top values for ITEMID (sample) ---
ITEMID
30013    476
30018    241
30120    161
30045    153
30025    119
30043    107
30128     99
30051     90
30125     75
30015     68
30126     65
30124     60
30118     60
30133     57
30114     51
30131     42
30050     40
30121     35
30178     27
30026     20
Name: count, dtype: int64

--- Top values for ORIGINALROUTE (sample) ---
ORIGINALROUTE
Intravenous Push        1262
IV Drip                  786
Intravenous Infusion     119
Oral                      15
Gastric/Feeding Tube       8
Nasogastric                1
Name: count, dtype: int64

--- Top values for ORIGINALRATEUOM (sample) ---
ORIGINALRATEUOM
         1697
ml/hr     494
Name: count, dtype: int64

--- Top values for AMOUNTUOM (sample) ---
AMOUNTUOM
ml     844
       812


In [None]:

import os, numpy as np, pandas as pd
from tqdm import tqdm
import re

MV_PATH = "INPUTEVENTS_MV_sorted.csv"
CHUNKSIZE = 150_000
TIMESTEP_HOURS = globals().get('TIMESTEP_HOURS', 6)
NUM_STEPS = globals().get('NUM_STEPS', 48 // TIMESTEP_HOURS)
FLUID_RE = re.compile(r'intravenous|intravenous push|intravenous drip|fluid|bolus|drip|crystalloid|colloid', flags=re.I)


if 'admissions' not in globals():
    raise RuntimeError("Load ADMISSIONS_sorted.csv into 'admissions' DataFrame first.")
ad = admissions.copy()
if 'ADMITTIME' in ad.columns and 'admit_time' not in ad.columns:
    ad = ad.rename(columns={'ADMITTIME':'admit_time'})
ad['admit_time'] = pd.to_datetime(ad['admit_time'], errors='coerce')
if 'DISCHTIME' in ad.columns:
    ad['disch_time'] = pd.to_datetime(ad['DISCHTIME'], errors='coerce')

ad['hadm_id'] = pd.to_numeric(ad.get('HADM_ID', ad.get('hadm_id')), errors='coerce').astype('Int64')
ad['subject_id'] = pd.to_numeric(ad.get('SUBJECT_ID', ad.get('subject_id')), errors='coerce').astype('Int64')
ad = ad[ad['admit_time'].notna()].copy()


hadm_to_admit = ad.set_index('hadm_id')['admit_time'].to_dict()
hadm_to_disch = ad.set_index('hadm_id').get('disch_time', pd.Series()).to_dict() if 'disch_time' in ad.columns else {}

subject_to_admissions = {sid: subdf.drop(columns=['subject_id']).reset_index(drop=True)
                         for sid, subdf in ad.groupby('subject_id')}
admissions_subjects = set(subject_to_admissions.keys())


if 'action_frames' not in globals():
    action_frames = []
mapped_count = 0
mapped_hadms = set()
processed_chunks = 0


usecols = [
    'ROW_ID','SUBJECT_ID','HADM_ID','STARTTIME','ENDTIME',
    'AMOUNTUOM','RATEUOM','ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION',
    'AMOUNT','RATE'
]

if not os.path.exists(MV_PATH):
    print(f"{MV_PATH} not found â€” skipping MV mapping.")
else:

    it = pd.read_csv(MV_PATH, usecols=lambda c: c.upper() in [u.upper() for u in usecols],
                     chunksize=CHUNKSIZE, dtype=str, low_memory=False)
    for chunk in tqdm(it, desc="Processing INPUTEVENTS_MV chunks"):
        processed_chunks += 1
        chunk.columns = [c.upper() for c in chunk.columns]  # normalize


        chunk['STARTTIME_parsed'] = pd.to_datetime(chunk.get('STARTTIME','').astype(str).str.strip(), errors='coerce')
        chunk['ENDTIME_parsed'] = pd.to_datetime(chunk.get('ENDTIME','').astype(str).str.strip(), errors='coerce').fillna(chunk['STARTTIME_parsed'])


        if 'SUBJECT_ID' in chunk.columns:
            chunk['SUBJECT_ID'] = pd.to_numeric(chunk['SUBJECT_ID'], errors='coerce').astype('Int64')

            present_subjects = list(set(chunk['SUBJECT_ID'].dropna().astype(int).unique()) & admissions_subjects)
            if present_subjects:

                sub_chunk = chunk[chunk['SUBJECT_ID'].isin(present_subjects)].copy()

                for subj, subgrp in sub_chunk.groupby('SUBJECT_ID'):
                    adm_rows = subject_to_admissions.get(subj)
                    if adm_rows is None or adm_rows.empty:
                        continue

                    subgrp = subgrp.reset_index(drop=True)
                    adm = adm_rows.copy().reset_index(drop=True)
                    subgrp['_tmp'] = 1
                    adm['_tmp'] = 1
                    merge = subgrp.merge(adm, on='_tmp', suffixes=('','_adm')).drop(columns=['_tmp'])

                    cond_admit_ge = merge['STARTTIME_parsed'] >= merge['admit_time']
                    if 'disch_time' in merge.columns:
                        cond_before_disch = (merge['STARTTIME_parsed'] <= merge['disch_time']) | merge['disch_time'].isna()
                        mask = cond_admit_ge & cond_before_disch
                    else:
                        mask = cond_admit_ge
                    merge = merge.loc[mask].copy()
                    if merge.empty:
                        continue

                    dt_start = (merge['STARTTIME_parsed'] - merge['admit_time']).dt.total_seconds() / 3600.0
                    dt_end   = (merge['ENDTIME_parsed'] - merge['admit_time']).dt.total_seconds() / 3600.0
                    merge['start_idx'] = (np.floor(dt_start / TIMESTEP_HOURS)).astype('Int64')
                    merge['end_idx']   = (np.floor(dt_end / TIMESTEP_HOURS)).astype('Int64')
                    merge = merge[merge['start_idx'].notna()]
                    if merge.empty:
                        continue

                    ord_cols = [c for c in ['ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION'] if c in merge.columns]
                    ord_text = merge[ord_cols].fillna('').agg(' '.join, axis=1).str.lower() if ord_cols else pd.Series('', index=merge.index)
                    amt_uom = merge.get('AMOUNTUOM', pd.Series('', index=merge.index)).fillna('').str.lower()
                    rate_uom = merge.get('RATEUOM', pd.Series('', index=merge.index)).fillna('').str.lower()
                    cond_antibiotic = ord_text.str.contains('antibiotic|antibiotics', na=False)
                    cond_press = ord_text.str.contains('press|pressor|inotrope', na=False) | rate_uom.str.contains('mcgkgmin|mcg/kg/min', na=False)
                    cond_fluid = ord_text.str.contains(FLUID_RE, na=False) | amt_uom.isin(['ml','l','ml/hr','ml/hour'])
                    merge['action_label'] = 'other'
                    merge.loc[cond_antibiotic, 'action_label'] = 'antibiotic'
                    merge.loc[cond_press, 'action_label'] = 'vasopressor'
                    merge.loc[cond_fluid, 'action_label'] = 'fluid_bolus'
                    merge.loc[ord_text.str.contains('insulin', na=False), 'action_label'] = 'insulin'

                    hadm_arr = merge['hadm_id'].fillna(-1).astype('int64').to_numpy()
                    start_arr = merge['start_idx'].astype('int64').to_numpy()
                    end_arr = merge['end_idx'].astype('int64').to_numpy()
                    label_arr = merge['action_label'].to_numpy()
                    df_expl = explode_ranges_vectorized(hadm_arr, start_arr, end_arr, label_arr)
                    if df_expl is not None:
                        df_expl['hadm_id'] = df_expl['hadm_id'].replace({-1: pd.NA}).astype('Int64')
                        action_frames.append(df_expl)
                        mapped_count += len(df_expl)
                        mapped_hadms.update(df_expl['hadm_id'].dropna().astype(int).unique())


        if 'HADM_ID' in chunk.columns:

            chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
            hv = chunk[chunk['STARTTIME_parsed'].notna() & chunk['HADM_ID'].notna()].copy()
            if not hv.empty:

                hv = hv[hv['HADM_ID'].isin(hadm_to_admit.keys())]
                if not hv.empty:

                    hv['admit_time_lookup'] = hv['HADM_ID'].map(hadm_to_admit)
                    hv['disch_time_lookup'] = hv['HADM_ID'].map(hadm_to_disch) if hadm_to_disch else pd.NaT

                    cond_admit_ge = hv['STARTTIME_parsed'] >= hv['admit_time_lookup']
                    if hadm_to_disch:
                        cond_before_disch = (hv['STARTTIME_parsed'] <= hv['disch_time_lookup']) | hv['disch_time_lookup'].isna()
                        hv = hv[cond_admit_ge & cond_before_disch]
                    else:
                        hv = hv[cond_admit_ge]
                    if not hv.empty:
                        dt_start = (hv['STARTTIME_parsed'] - hv['admit_time_lookup']).dt.total_seconds() / 3600.0
                        dt_end   = (hv['ENDTIME_parsed'] - hv['admit_time_lookup']).dt.total_seconds() / 3600.0
                        hv['start_idx'] = (np.floor(dt_start / TIMESTEP_HOURS)).astype('Int64')
                        hv['end_idx'] = (np.floor(dt_end / TIMESTEP_HOURS)).astype('Int64')
                        hv = hv[hv['start_idx'].notna()]
                        if not hv.empty:

                            ord_cols = [c for c in ['ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION'] if c in hv.columns]
                            ord_text = hv[ord_cols].fillna('').agg(' '.join, axis=1).str.lower() if ord_cols else pd.Series('', index=hv.index)
                            amt_uom = hv.get('AMOUNTUOM', pd.Series('', index=hv.index)).fillna('').str.lower()
                            rate_uom = hv.get('RATEUOM', pd.Series('', index=hv.index)).fillna('').str.lower()
                            cond_antibiotic = ord_text.str.contains('antibiotic|antibiotics', na=False)
                            cond_press = ord_text.str.contains('press|pressor|inotrope', na=False) | rate_uom.str.contains('mcgkgmin|mcg/kg/min', na=False)
                            cond_fluid = ord_text.str.contains(FLUID_RE, na=False) | amt_uom.isin(['ml','l','ml/hr','ml/hour'])
                            hv['action_label'] = 'other'
                            hv.loc[cond_antibiotic, 'action_label'] = 'antibiotic'
                            hv.loc[cond_press, 'action_label'] = 'vasopressor'
                            hv.loc[cond_fluid, 'action_label'] = 'fluid_bolus'
                            hv.loc[ord_text.str.contains('insulin', na=False), 'action_label'] = 'insulin'

                            hadm_arr = hv['HADM_ID'].astype('int64').to_numpy()
                            start_arr = hv['start_idx'].astype('int64').to_numpy()
                            end_arr = hv['end_idx'].astype('int64').to_numpy()
                            label_arr = hv['action_label'].to_numpy()
                            df_expl = explode_ranges_vectorized(hadm_arr, start_arr, end_arr, label_arr)
                            if df_expl is not None:
                                action_frames.append(df_expl)
                                mapped_count += len(df_expl)
                                mapped_hadms.update(df_expl['hadm_id'].dropna().astype(int).unique())

    # summary
    print(f"\nFinished processing MV file in {processed_chunks} chunks.")
    print(f"Total MV-derived timestep rows produced (approx): {mapped_count:,}")
    print(f"Unique hadm_ids discovered in MV mapping (count): {len([h for h in mapped_hadms if pd.notna(h)])}")
    if len(mapped_hadms) > 0:
        print("Example hadm_ids (first 20):", list([h for h in mapped_hadms if pd.notna(h)])[:20])


Processing INPUTEVENTS_MV chunks: 14it [00:33,  2.42s/it]


Finished processing MV file in 14 chunks.
Total MV-derived timestep rows produced (approx): 86,466
Unique hadm_ids discovered in MV mapping (count): 663
Example hadm_ids (first 20): [np.int64(114690), np.int64(112643), np.int64(147462), np.int64(167945), np.int64(147469), np.int64(106510), np.int64(174095), np.int64(184338), np.int64(114707), np.int64(122900), np.int64(135186), np.int64(184345), np.int64(145440), np.int64(137250), np.int64(153637), np.int64(114726), np.int64(170024), np.int64(178216), np.int64(149546), np.int64(192557)]





In [None]:

import pandas as pd
import numpy as np
import os


TIMESTEP_HOURS = globals().get('TIMESTEP_HOURS', 6)
NUM_STEPS = globals().get('NUM_STEPS', 48 // TIMESTEP_HOURS)
PRIORITY = globals().get('PRIORITY', {'no_action':0,'other':1,'antibiotic':2,'diuretic':3,'fluid_bolus':4,'insulin':5,'vasopressor':6})
ACTION_LABELS = globals().get('ACTION_LABELS', ['no_action','vasopressor','fluid_bolus','diuretic','antibiotic','insulin','other'])


if 'mapped_hadms' not in globals():
    if 'action_frames' in globals() and len(action_frames) > 0:
        af_all = pd.concat(action_frames, ignore_index=True)
        mapped_hadms_set = set(af_all['hadm_id'].dropna().astype(int).unique())
    else:
        raise RuntimeError("mapped_hadms not available and action_frames is empty or missing.")
else:
    mapped_hadms_set = set([int(x) for x in mapped_hadms if pd.notna(x)])

print(f"Mapped hadm count to build trajectories for: {len(mapped_hadms_set)}")


if 'admissions' not in globals():
    raise RuntimeError("admissions DataFrame not found â€” load ADMISSIONS_sorted.csv into variable 'admissions' first.")

ad = admissions.copy()

if 'HADM_ID' in ad.columns and 'hadm_id' not in ad.columns:
    ad = ad.rename(columns={'HADM_ID':'hadm_id'})
ad['hadm_id'] = pd.to_numeric(ad['hadm_id'], errors='coerce').astype('Int64')
if 'ADMITTIME' in ad.columns and 'admit_time' not in ad.columns:
    ad = ad.rename(columns={'ADMITTIME':'admit_time'})
ad['admit_time'] = pd.to_datetime(ad['admit_time'], errors='coerce')

ad_sub = ad[ad['hadm_id'].isin(mapped_hadms_set)].copy()
print(f"Admissions available for selected hadms: {len(ad_sub)}")


rows = []
for _, r in ad_sub.iterrows():
    hadm = int(r['hadm_id'])
    subj = int(r['SUBJECT_ID']) if 'SUBJECT_ID' in r and pd.notna(r['SUBJECT_ID']) else np.nan
    admit_time = r['admit_time']

    hr = r.get('heart_rate', np.nan)
    sbp = r.get('sys_bp', np.nan)
    creat = r.get('creatinine', np.nan)
    for t in range(NUM_STEPS):
        rows.append({
            'subject_id': subj,
            'hadm_id': hadm,
            'timestep': t,
            'time_since_admit_hours': t * TIMESTEP_HOURS,
            'admit_time': admit_time,
            'heart_rate': hr,
            'sys_bp': sbp,
            'creatinine': creat,
        })

traj_mv_df = pd.DataFrame(rows)
print("Built base traj_mv_df rows:", len(traj_mv_df))


if 'action_frames' in globals() and len(action_frames) > 0:
    actions_df = pd.concat(action_frames, ignore_index=True)

    actions_df = actions_df[actions_df['hadm_id'].notna() & actions_df['timestep'].notna()].copy()

    actions_df['hadm_id'] = actions_df['hadm_id'].astype(int)
    actions_df['timestep'] = actions_df['timestep'].astype(int)

    if 'action_label' not in actions_df.columns and 'action' in actions_df.columns:
        actions_df = actions_df.rename(columns={'action':'action_label'})
else:
    actions_df = pd.DataFrame(columns=['hadm_id','timestep','action_label'])

print("Total action rows available to merge:", len(actions_df))


if not actions_df.empty:

    actions_df['priority'] = actions_df['action_label'].map(lambda x: PRIORITY.get(x, 1))

    actions_df_sorted = actions_df.sort_values(['hadm_id','timestep','priority'], ascending=[True, True, False])
    actions_top = actions_df_sorted.drop_duplicates(subset=['hadm_id','timestep'], keep='first').copy()
    actions_top = actions_top[['hadm_id','timestep','action_label']].reset_index(drop=True)
else:
    actions_top = pd.DataFrame(columns=['hadm_id','timestep','action_label'])

print("Unique (hadm,timestep) with actions after priority reduction:", len(actions_top))


traj_mv_df = traj_mv_df.merge(actions_top, on=['hadm_id','timestep'], how='left')
traj_mv_df['mapped_action'] = traj_mv_df['action_label'].fillna('no_action')


action_to_code = {lbl:i for i,lbl in enumerate(ACTION_LABELS)}
traj_mv_df['action_code'] = traj_mv_df['mapped_action'].map(lambda x: action_to_code.get(x, action_to_code.get('other', 1))).astype(int)


out_path = "traj_mv_with_actions.csv"
traj_mv_df.to_csv(out_path, index=False)
print("Saved", out_path, "rows:", len(traj_mv_df))


counts = traj_mv_df['mapped_action'].value_counts().to_dict()
print("Action distribution (sample):", counts)


Mapped hadm count to build trajectories for: 663
Admissions available for selected hadms: 663
Built base traj_mv_df rows: 5304
Total action rows available to merge: 1349694
Unique (hadm,timestep) with actions after priority reduction: 59650
Saved traj_mv_with_actions.csv rows: 5304
Action distribution (sample): {'fluid_bolus': 3203, 'no_action': 1357, 'insulin': 720, 'other': 21, 'vasopressor': 2, 'antibiotic': 1}


In [None]:
"""
Publication-quality mapping pipeline (v2): map prescriptions / inputevents / procedures -> timestep actions


Key safety & quality features including:
 - Robust time parsing and timezone-awareness option
 - Consistent ACTION_LABELS + explicit ACTION_PRIORITY
 - Chunked CSV processing (memory-friendly)
 - Safe handling of missing values using pandas nullable integers
 - Central helper to safely filter and explode ranges (avoids NA->int errors)
 - Deterministic priority-based deduplication
 - Diagnostics and a small synthetic unit test

"""

from __future__ import annotations
import argparse
import logging
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import re

# -------------------------- Config --------------------------
ACTION_LABELS = ['no_action', 'vasopressor', 'fluid_bolus', 'diuretic', 'antibiotic', 'insulin', 'other']
ACTION_PRIORITY = {'no_action': 0, 'other': 1, 'antibiotic': 2, 'diuretic': 3, 'fluid_bolus': 4, 'insulin': 5, 'vasopressor': 6}
DEFAULT_USECOLS_MV = [
    'ROW_ID','SUBJECT_ID','HADM_ID','STARTTIME','ENDTIME','AMOUNTUOM','RATEUOM',
    'ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION','AMOUNT','RATE','ITEMID'
]

PRESC_MAP_SAFE = {
    'antibiotic': [r'\bvancomycin\b', r'\bampicillin\b', r'\bclindamycin\b', r'\bpiperacillin[- ]?tazobactam\b',
                   r'\bmeropenem\b', r'\bciprofloxacin\b', r'\blevofloxacin\b', r'\bazithro\w*', r'\bgentamicin\b',
                   r'\bceftriaxone\b', r'\bcefazolin\b', r'\bmetronidazole\b'],
    'diuretic': [r'\bfurosemide\b', r'\bbumetanide\b', r'\btorsemide\b', r'\bhydrochlorothiazide\b', r'\bspironolactone\b'],
    'vasopressor': [r'\bnorepinephrine\b', r'\bphenylephrine\b', r'\bepinephrine\b', r'\bvasopressin\b', r'\bdopamine\b'],
    'insulin': [r'\binsulin glargine\b', r'\binsulin lispro\b', r'\binsulin aspart\b', r'\binsulin\b']
}
PRESC_REGEX = {k: re.compile('|'.join(v), flags=re.I) for k, v in PRESC_MAP_SAFE.items()}
FLUID_RE = re.compile(r'intravenous|intravenous push|intravenous drip|fluid|bolus|drip|crystalloid|colloid', flags=re.I)

# -------------------------- Helpers --------------------------

def setup_logging(level=logging.INFO):
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=level)


def hours_to_timestep(hours_series: pd.Series, timestep_hours: int, num_steps: int) -> pd.Series:
    idx_float = np.floor(hours_series / float(timestep_hours))
    idx = pd.Series(idx_float, index=hours_series.index).astype('Int64')
    mask_outside = (idx < 0) | (idx >= num_steps)
    idx.loc[mask_outside] = pd.NA
    return idx


def explode_ranges_vectorized(hadm_arr: np.ndarray, start_arr: np.ndarray, end_arr: np.ndarray, label_arr: np.ndarray, num_steps: int) -> Optional[pd.DataFrame]:
    if len(start_arr) == 0:
        return None
    valid_mask = (start_arr <= end_arr)
    if not valid_mask.any():
        return None
    starts = start_arr[valid_mask]; ends = end_arr[valid_mask]
    hadms = hadm_arr[valid_mask]; labels = label_arr[valid_mask]
    ranges = [np.arange(s, e + 1, dtype='int64') for s, e in zip(starts, ends)]
    lengths = np.array([r.size for r in ranges], dtype='int64')
    if lengths.sum() == 0:
        return None
    all_timesteps = np.concatenate(ranges)
    all_hadms = np.repeat(hadms, lengths)
    all_labels = np.repeat(labels, lengths)
    df = pd.DataFrame({'hadm_id': all_hadms, 'timestep': all_timesteps, 'action_label': all_labels})
    df = df[(df['timestep'] >= 0) & (df['timestep'] < num_steps)].copy()
    return df if not df.empty else None


def safe_filter_and_explode(df_in: pd.DataFrame, hadm_col: str, start_col: str, end_col: str, label_col: str, num_steps: int) -> Optional[pd.DataFrame]:
    for c in (hadm_col, start_col, end_col, label_col):
        if c not in df_in.columns:
            return None
    mask_good = df_in[hadm_col].notna() & df_in[start_col].notna() & df_in[end_col].notna()
    df_good = df_in.loc[mask_good].copy()
    if df_good.empty:
        return None
    hadm_arr = df_good[hadm_col].astype('int64').to_numpy()
    start_arr = df_good[start_col].astype('int64').to_numpy()
    end_arr = df_good[end_col].astype('int64').to_numpy()
    label_arr = df_good[label_col].to_numpy()
    return explode_ranges_vectorized(hadm_arr, start_arr, end_arr, label_arr, num_steps)

# -------------------------- Source processors --------------------------

def process_prescriptions(path: Path, hadm_keep: set, admissions_small: pd.DataFrame, timestep_hours: int, num_steps: int, chunksize: int = 50000) -> List[pd.DataFrame]:
    frames = []
    if not path.exists():
        logging.info('PRESCRIPTIONS missing â€” skipping')
        return frames
    for chunk in pd.read_csv(path, chunksize=chunksize, dtype=str, low_memory=False):
        chunk.columns = [c.upper() for c in chunk.columns]
        if 'HADM_ID' not in chunk.columns:
            continue
        chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
        chunk = chunk[chunk['HADM_ID'].isin(hadm_keep)]
        if chunk.empty:
            continue
        chunk['STARTDATE_PARSED'] = pd.to_datetime(chunk.get('STARTDATE','').astype(str).str.strip(), errors='coerce')
        chunk['ENDDATE_PARSED'] = pd.to_datetime(chunk.get('ENDDATE','').astype(str).str.strip(), errors='coerce').fillna(chunk['STARTDATE_PARSED'])
        chunk = chunk.merge(admissions_small, left_on='HADM_ID', right_on='hadm_id', how='left')
        dt_start = (chunk['STARTDATE_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        dt_end = (chunk['ENDDATE_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        chunk['start_idx'] = hours_to_timestep(dt_start, timestep_hours, num_steps).astype('Int64')
        chunk['end_idx'] = hours_to_timestep(dt_end, timestep_hours, num_steps).astype('Int64')
        chunk = chunk[chunk['start_idx'].notna()]
        text_cols = [c for c in ['DRUG','DRUG_NAME_POE','DRUG_NAME_GENERIC'] if c in chunk.columns]
        chunk['combined_text'] = chunk[text_cols].fillna('').agg(' '.join, axis=1).str.lower() if text_cols else ''
        chunk['action_label'] = 'other'
        for lbl, rx in PRESC_REGEX.items():
            chunk.loc[chunk['combined_text'].str.contains(rx, na=False), 'action_label'] = lbl
        df_expl = safe_filter_and_explode(chunk, hadm_col='HADM_ID', start_col='start_idx', end_col='end_idx', label_col='action_label', num_steps=num_steps)
        if df_expl is not None:
            df_expl['source'] = 'prescriptions'
            frames.append(df_expl)
    return frames


def process_inputevents_mv(path: Path, hadm_keep: set, admissions: pd.DataFrame, timestep_hours: int, num_steps: int, chunksize: int = 150000) -> List[pd.DataFrame]:
    frames = []
    if not path.exists():
        logging.info('INPUTEVENTS_MV missing â€” skipping')
        return frames
    hadm_to_admit = admissions.set_index('hadm_id')['admit_time'].to_dict()
    hadm_to_disch = admissions.set_index('hadm_id').get('disch_time', pd.Series()).to_dict() if 'disch_time' in admissions.columns else {}
    subject_to_admissions = {sid: subdf.drop(columns=['subject_id']).reset_index(drop=True) for sid, subdf in admissions.groupby('subject_id')} if 'subject_id' in admissions.columns else {}
    usecols_upper = [c.upper() for c in DEFAULT_USECOLS_MV]
    for chunk in pd.read_csv(path, usecols=lambda c: c.upper() in usecols_upper, chunksize=chunksize, dtype=str, low_memory=False):
        chunk.columns = [c.upper() for c in chunk.columns]
        chunk['STARTTIME_parsed'] = pd.to_datetime(chunk.get('STARTTIME','').astype(str).str.strip(), errors='coerce')
        chunk['ENDTIME_parsed'] = pd.to_datetime(chunk.get('ENDTIME','').astype(str).str.strip(), errors='coerce').fillna(chunk['STARTTIME_parsed'])
        if 'HADM_ID' in chunk.columns:
            chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
            hv = chunk[chunk['STARTTIME_parsed'].notna() & chunk['HADM_ID'].notna()].copy()
            if not hv.empty:
                hv = hv[hv['HADM_ID'].isin(hadm_to_admit.keys())]
                if not hv.empty:
                    hv['admit_time_lookup'] = hv['HADM_ID'].map(hadm_to_admit)
                    hv['disch_time_lookup'] = hv['HADM_ID'].map(hadm_to_disch) if hadm_to_disch else pd.NaT
                    cond_admit_ge = hv['STARTTIME_parsed'] >= hv['admit_time_lookup']
                    if hadm_to_disch:
                        hv = hv[cond_admit_ge & ((hv['STARTTIME_parsed'] <= hv['disch_time_lookup']) | hv['disch_time_lookup'].isna())]
                    else:
                        hv = hv[cond_admit_ge]
                    if not hv.empty:
                        dt_start = (hv['STARTTIME_parsed'] - hv['admit_time_lookup']).dt.total_seconds() / 3600.0
                        dt_end = (hv['ENDTIME_parsed'] - hv['admit_time_lookup']).dt.total_seconds() / 3600.0
                        hv['start_idx'] = hours_to_timestep(dt_start, timestep_hours, num_steps).astype('Int64')
                        hv['end_idx'] = hours_to_timestep(dt_end, timestep_hours, num_steps).astype('Int64')
                        hv = hv[hv['start_idx'].notna()].copy()
                        if not hv.empty:
                            ord_cols = [c for c in ['ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION'] if c in hv.columns]
                            ord_text = hv[ord_cols].fillna('').agg(' '.join, axis=1).str.lower() if ord_cols else pd.Series('', index=hv.index)
                            amt_uom = hv.get('AMOUNTUOM', pd.Series('', index=hv.index)).fillna('').str.lower()
                            rate_uom = hv.get('RATEUOM', pd.Series('', index=hv.index)).fillna('').str.lower()
                            cond_antibiotic = ord_text.str.contains(r'antibiotic|antibiotics', na=False)
                            cond_press = ord_text.str.contains(r'press|pressor|inotrope', na=False) | rate_uom.str.contains(r'mcgkgmin|mcg/kg/min', na=False)
                            cond_fluid = ord_text.str.contains(FLUID_RE, na=False) | amt_uom.isin(['ml','l','ml/hr','ml/hour'])
                            hv['action_label'] = 'other'
                            hv.loc[cond_antibiotic, 'action_label'] = 'antibiotic'
                            hv.loc[cond_press, 'action_label'] = 'vasopressor'
                            hv.loc[cond_fluid, 'action_label'] = 'fluid_bolus'
                            hv.loc[ord_text.str.contains('insulin', na=False), 'action_label'] = 'insulin'
                            df_expl = safe_filter_and_explode(hv, hadm_col='HADM_ID', start_col='start_idx', end_col='end_idx', label_col='action_label', num_steps=num_steps)
                            if df_expl is not None:
                                df_expl['source'] = 'inputevents_mv'
                                frames.append(df_expl)
        if 'SUBJECT_ID' in chunk.columns and subject_to_admissions:
            chunk['SUBJECT_ID'] = pd.to_numeric(chunk['SUBJECT_ID'], errors='coerce').astype('Int64')
            present_subjects = set(chunk['SUBJECT_ID'].dropna().astype(int).unique()) & set(subject_to_admissions.keys())
            if present_subjects:
                sub_chunk = chunk[chunk['SUBJECT_ID'].isin(present_subjects)].copy()
                for subj, subgrp in sub_chunk.groupby('SUBJECT_ID'):
                    adm_rows = subject_to_admissions.get(subj)
                    if adm_rows is None or adm_rows.empty:
                        continue
                    subgrp = subgrp.reset_index(drop=True)
                    adm = adm_rows.copy().reset_index(drop=True)
                    subgrp['_tmp'] = 1
                    adm['_tmp'] = 1
                    merge = subgrp.merge(adm, on='_tmp', suffixes=('','_adm')).drop(columns=['_tmp'])
                    cond_admit_ge = merge['STARTTIME_parsed'] >= merge['admit_time']
                    if 'disch_time' in merge.columns:
                        cond_before_disch = (merge['STARTTIME_parsed'] <= merge['disch_time']) | merge['disch_time'].isna()
                        mask = cond_admit_ge & cond_before_disch
                    else:
                        mask = cond_admit_ge
                    merge = merge.loc[mask].copy()
                    if merge.empty:
                        continue
                    dt_start = (merge['STARTTIME_parsed'] - merge['admit_time']).dt.total_seconds() / 3600.0
                    dt_end = (merge['ENDTIME_parsed'] - merge['admit_time']).dt.total_seconds() / 3600.0
                    merge['start_idx'] = hours_to_timestep(dt_start, timestep_hours, num_steps).astype('Int64')
                    merge['end_idx'] = hours_to_timestep(dt_end, timestep_hours, num_steps).astype('Int64')
                    merge = merge[merge['start_idx'].notna()]
                    if merge.empty:
                        continue
                    ord_cols = [c for c in ['ORDERCATEGORYNAME','SECONDARYORDERCATEGORYNAME','ORDERCATEGORYDESCRIPTION'] if c in merge.columns]
                    ord_text = merge[ord_cols].fillna('').agg(' '.join, axis=1).str.lower() if ord_cols else pd.Series('', index=merge.index)
                    amt_uom = merge.get('AMOUNTUOM', pd.Series('', index=merge.index)).fillna('').str.lower()
                    rate_uom = merge.get('RATEUOM', pd.Series('', index=merge.index)).fillna('').str.lower()
                    cond_antibiotic = ord_text.str.contains(r'antibiotic|antibiotics', na=False)
                    cond_press = ord_text.str.contains(r'press|pressor|inotrope', na=False) | rate_uom.str.contains(r'mcgkgmin|mcg/kg/min', na=False)
                    cond_fluid = ord_text.str.contains(FLUID_RE, na=False) | amt_uom.isin(['ml','l','ml/hr','ml/hour'])
                    merge['action_label'] = 'other'
                    merge.loc[cond_antibiotic, 'action_label'] = 'antibiotic'
                    merge.loc[cond_press, 'action_label'] = 'vasopressor'
                    merge.loc[cond_fluid, 'action_label'] = 'fluid_bolus'
                    merge.loc[ord_text.str.contains('insulin', na=False), 'action_label'] = 'insulin'
                    grouped = merge.groupby(['hadm_id','start_idx','end_idx','action_label'], dropna=False, as_index=False).agg({'AMOUNT':'first','RATE':'first'})
                    df_expl = safe_filter_and_explode(grouped, hadm_col='hadm_id', start_col='start_idx', end_col='end_idx', label_col='action_label', num_steps=num_steps)
                    if df_expl is not None:
                        df_expl['source'] = 'inputevents_mv_subject'
                        frames.append(df_expl)
    return frames


def process_inputevents_cv(path: Path, hadm_keep: set, admissions_small: pd.DataFrame, timestep_hours: int, num_steps: int, chunksize: int = 150000) -> List[pd.DataFrame]:
    frames = []
    if not path.exists():
        logging.info('INPUTEVENTS_CV missing â€” skipping')
        return frames
    for chunk in pd.read_csv(path, chunksize=chunksize, dtype=str, low_memory=False):
        chunk.columns = [c.upper() for c in chunk.columns]
        if 'HADM_ID' not in chunk.columns:
            continue
        chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
        chunk['CHARTTIME_PARSED'] = pd.to_datetime(chunk.get('CHARTTIME','').astype(str).str.strip(), errors='coerce')
        chunk = chunk.merge(admissions_small, left_on='HADM_ID', right_on='hadm_id', how='left')
        dt_chart = (chunk['CHARTTIME_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        chunk['timestep'] = hours_to_timestep(dt_chart, timestep_hours, num_steps).astype('Int64')
        chunk = chunk[chunk['timestep'].notna()]
        if chunk.empty:
            continue
        meta_cols = [c for c in ['ORIGINALROUTE','ORIGINALRATEUOM','ORDERID','AMOUNTUOM','RATEUOM'] if c in chunk.columns]
        meta_text = chunk[meta_cols].fillna('').agg(' '.join, axis=1).str.lower() if meta_cols else pd.Series('', index=chunk.index)
        amt_uom = chunk.get('AMOUNTUOM', pd.Series('', index=chunk.index)).fillna('').str.lower()
        rate_uom = chunk.get('RATEUOM', pd.Series('', index=chunk.index)).fillna('').str.lower()
        is_bolus = meta_text.str.contains(r'\bbolus\b|\bpush\b|\bstat\b', na=False)
        is_infusion = meta_text.str.contains(r'\bdrip\b|\binfusion\b|\bcontinuous\b', na=False)
        cond_fluid = is_bolus & (~is_infusion)
        cond_press = rate_uom.str.contains(r'mcgkgmin|mcg/kg/min', na=False) | meta_text.str.contains(r'press|pressor|inotrope', na=False)
        cond_ab = meta_text.str.contains(r'antibiotic', na=False)
        chunk['action_label'] = 'other'
        chunk.loc[cond_ab, 'action_label'] = 'antibiotic'
        chunk.loc[cond_press, 'action_label'] = 'vasopressor'
        chunk.loc[cond_fluid, 'action_label'] = 'fluid_bolus'
        valid_mask = chunk['HADM_ID'].notna() & chunk['timestep'].notna()
        if valid_mask.any():
            df_cv = chunk.loc[valid_mask, ['HADM_ID','timestep','action_label']].copy()
            df_cv.columns = ['hadm_id','timestep','action_label']
            df_cv['source'] = 'inputevents_cv'
            df_cv['hadm_id'] = df_cv['hadm_id'].astype('int64')
            df_cv['timestep'] = df_cv['timestep'].astype('int64')
            frames.append(df_cv)
    return frames


def process_procedureevents_mv(path: Path, hadm_keep: set, admissions_small: pd.DataFrame, timestep_hours: int, num_steps: int, chunksize: int = 50000) -> List[pd.DataFrame]:
    frames = []
    if not path.exists():
        logging.info('PROCEDUREEVENTS_MV missing â€” skipping')
        return frames
    for chunk in pd.read_csv(path, chunksize=chunksize, dtype=str, low_memory=False):
        chunk.columns = [c.upper() for c in chunk.columns]
        if 'HADM_ID' not in chunk.columns:
            continue
        chunk['HADM_ID'] = pd.to_numeric(chunk['HADM_ID'], errors='coerce').astype('Int64')
        chunk = chunk[chunk['HADM_ID'].isin(hadm_keep)]
        if chunk.empty:
            continue
        chunk['STARTTIME_PARSED'] = pd.to_datetime(chunk.get('STARTTIME','').astype(str).str.strip(), errors='coerce')
        chunk['ENDTIME_PARSED'] = pd.to_datetime(chunk.get('ENDTIME','').astype(str).str.strip(), errors='coerce').fillna(chunk['STARTTIME_PARSED'])
        chunk = chunk.merge(admissions_small, left_on='HADM_ID', right_on='hadm_id', how='left')
        dt_start = (chunk['STARTTIME_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        dt_end = (chunk['ENDTIME_PARSED'] - chunk['admit_time']).dt.total_seconds() / 3600.0
        chunk['start_idx'] = hours_to_timestep(dt_start, timestep_hours, num_steps).astype('Int64')
        chunk['end_idx'] = hours_to_timestep(dt_end, timestep_hours, num_steps).astype('Int64')
        chunk = chunk[chunk['start_idx'].notna()]
        if chunk.empty:
            continue
        chunk['action_label'] = 'other'
        df_expl = safe_filter_and_explode(chunk, hadm_col='HADM_ID', start_col='start_idx', end_col='end_idx', label_col='action_label', num_steps=num_steps)
        if df_expl is not None:
            df_expl['source'] = 'procedureevents_mv'
            frames.append(df_expl)
    return frames


def dedupe_actions(action_frames: List[pd.DataFrame]) -> pd.DataFrame:
    if not action_frames:
        return pd.DataFrame(columns=['hadm_id','timestep','action_label','source'])
    actions_df = pd.concat(action_frames, ignore_index=True)
    actions_df['hadm_id'] = pd.to_numeric(actions_df.get('hadm_id'), errors='coerce').astype('Int64')
    actions_df['timestep'] = pd.to_numeric(actions_df.get('timestep'), errors='coerce').astype('Int64')
    actions_df = actions_df.dropna(subset=['hadm_id','timestep']).copy()
    actions_df['hadm_id'] = actions_df['hadm_id'].astype('int64')
    actions_df['timestep'] = actions_df['timestep'].astype('int64')
    actions_df['action_label'] = actions_df['action_label'].astype(str).str.strip().str.lower()
    actions_df['priority'] = actions_df['action_label'].map(lambda x: ACTION_PRIORITY.get(x, 1))
    actions_sorted = actions_df.sort_values(['hadm_id','timestep','priority'], ascending=[True, True, False])
    actions_top = actions_sorted.drop_duplicates(subset=['hadm_id','timestep'], keep='first').reset_index(drop=True)
    return actions_top[['hadm_id','timestep','action_label','source']]


def merge_into_traj(traj_df: pd.DataFrame, actions_top: pd.DataFrame, output_path: Path) -> pd.DataFrame:
    traj_out = traj_df.merge(actions_top[['hadm_id','timestep','action_label','source']], on=['hadm_id','timestep'], how='left')
    traj_out['mapped_action'] = traj_out['action_label'].fillna('no_action')
    action_to_code = {lbl: int(i) for i, lbl in enumerate(ACTION_LABELS)}
    traj_out['mapped_action'] = traj_out['mapped_action'].fillna('no_action').astype(str)
    traj_out['action_code'] = traj_out['mapped_action'].map(lambda x: action_to_code.get(str(x), action_to_code.get('other', 1))).astype('Int64')
    traj_out.to_csv(output_path, index=False)
    return traj_out


def diagnostics_and_save(actions_df: pd.DataFrame, output_dir: Path) -> None:
    out_raw = output_dir / 'actions_raw.csv'
    out_top = output_dir / 'actions_top.csv'
    if not actions_df.empty:
        actions_df.to_csv(out_raw, index=False)
        actions_top = dedupe_actions([actions_df])
        actions_top.to_csv(out_top, index=False)
    summary = {
        'total_raw_rows': len(actions_df) if not actions_df.empty else 0,
        'unique_hadms': int(actions_df['hadm_id'].nunique()) if not actions_df.empty else 0,
        'unique_hadm_timestep_pairs': int(actions_df.groupby(['hadm_id','timestep']).ngroups) if not actions_df.empty else 0,
    }
    pd.DataFrame([summary]).to_csv(output_dir / 'traj_diag_summary.csv', index=False)


def build_base_traj_for_hadm_list(admissions: pd.DataFrame, mapped_hadms: set, timestep_hours: int, num_steps: int) -> pd.DataFrame:
    rows = []
    ad_sub = admissions[admissions['hadm_id'].isin(mapped_hadms)].copy()
    for _, r in ad_sub.iterrows():
        hadm = int(r['hadm_id'])
        subj = int(r['SUBJECT_ID']) if 'SUBJECT_ID' in r and pd.notna(r['SUBJECT_ID']) else np.nan
        admit_time = r['admit_time'] if 'admit_time' in r else pd.NaT
        for t in range(num_steps):
            rows.append({'subject_id': subj, 'hadm_id': hadm, 'timestep': t, 'time_since_admit_hours': t * timestep_hours, 'admit_time': admit_time})
    traj_mv_df = pd.DataFrame(rows)
    return traj_mv_df


def synthetic_unit_test(tmp_dir: Path) -> None:
    admissions = pd.DataFrame([
        {'hadm_id': 1, 'subject_id': 100, 'admit_time': pd.to_datetime('2020-01-01 00:00:00')}
    ])
    admissions_small = admissions[['hadm_id','admit_time']].copy()
    presc = pd.DataFrame([
        {'HADM_ID': 1, 'STARTDATE': '2020-01-01 02:00:00', 'ENDDATE': '2020-01-01 02:00:00', 'DRUG': 'vancomycin'}
    ])
    presc_path = tmp_dir / 'PRESCRIPTIONS_synth.csv'
    presc.to_csv(presc_path, index=False)
    frames = process_prescriptions(presc_path, {1}, admissions_small, timestep_hours=6, num_steps=8, chunksize=10)
    assert frames and len(frames) > 0, 'Prescription processing failed in synthetic test.'
    logging.info('Synthetic prescription test passed.')


def main(data_dir: str, timestep_hours: int, num_steps: int, chunksize: int, debug: bool = False) -> None:
    setup_logging(logging.DEBUG if debug else logging.INFO)
    data_dir = Path(data_dir)
    paths = {
        'ADMISSIONS': data_dir / 'ADMISSIONS_sorted.csv',
        'PRESCRIPTIONS': data_dir / 'PRESCRIPTIONS_sorted.csv',
        'INPUTEVENTS_MV': data_dir / 'INPUTEVENTS_MV_sorted.csv',
        'INPUTEVENTS_CV': data_dir / 'INPUTEVENTS_CV_sorted.csv',
        'PROCEDUREEVENTS_MV': data_dir / 'PROCEDUREEVENTS_MV_sorted.csv',
    }
    if not paths['ADMISSIONS'].exists():
        raise RuntimeError(f'ADMISSIONS_sorted.csv not found in {data_dir}')
    admissions = pd.read_csv(paths['ADMISSIONS'], parse_dates=['ADMITTIME'], low_memory=False)
    if 'ADMITTIME' in admissions.columns and 'admit_time' not in admissions.columns:
        admissions = admissions.rename(columns={'ADMITTIME': 'admit_time'})
    if 'HADM_ID' in admissions.columns and 'hadm_id' not in admissions.columns:
        admissions = admissions.rename(columns={'HADM_ID': 'hadm_id'})
    admissions['hadm_id'] = pd.to_numeric(admissions['hadm_id'], errors='coerce').astype('Int64')
    admissions['admit_time'] = pd.to_datetime(admissions['admit_time'], errors='coerce')
    if 'DISCHTIME' in admissions.columns:
        admissions['disch_time'] = pd.to_datetime(admissions['DISCHTIME'], errors='coerce')
    admissions_small = admissions[['hadm_id','admit_time']].copy()
    traj_skeleton_path = data_dir / 'traj_df_skeleton.csv'
    if traj_skeleton_path.exists():
        traj_df = pd.read_csv(traj_skeleton_path)
        hadm_keep = set(pd.to_numeric(traj_df['hadm_id'], errors='coerce').dropna().astype(int).unique())
    else:
        hadm_keep = set(admissions['hadm_id'].dropna().astype(int).unique())
    action_frames: List[pd.DataFrame] = []
    action_frames += process_prescriptions(paths['PRESCRIPTIONS'], hadm_keep, admissions_small, timestep_hours, num_steps, chunksize=chunksize)
    action_frames += process_inputevents_mv(paths['INPUTEVENTS_MV'], hadm_keep, admissions, timestep_hours, num_steps, chunksize=chunksize)
    action_frames += process_inputevents_cv(paths['INPUTEVENTS_CV'], hadm_keep, admissions_small, timestep_hours, num_steps, chunksize=chunksize)
    action_frames += process_procedureevents_mv(paths['PROCEDUREEVENTS_MV'], hadm_keep, admissions_small, timestep_hours, num_steps, chunksize=chunksize)
    actions_raw = pd.concat(action_frames, ignore_index=True) if action_frames else pd.DataFrame(columns=['hadm_id','timestep','action_label','source'])
    output_dir = Path(data_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    actions_raw.to_csv(output_dir / 'actions_raw.csv', index=False)
    actions_top = dedupe_actions([actions_raw])
    actions_top.to_csv(output_dir / 'actions_top.csv', index=False)
    mapped_hadms = set(actions_top['hadm_id'].dropna().astype(int).unique())
    if traj_skeleton_path.exists():
        traj_df = pd.read_csv(traj_skeleton_path)
    else:
        traj_df = build_base_traj_for_hadm_list(admissions, mapped_hadms, timestep_hours, num_steps)
    traj_out = merge_into_traj(traj_df, actions_top, output_dir / 'traj_with_mapped_actions.csv')
    diagnostics_and_save(actions_raw, output_dir)


if __name__ == '__main__':
    import sys
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='.')
    parser.add_argument('--timestep-hours', type=int, default=6)
    parser.add_argument('--num-steps', type=int, default=8)
    parser.add_argument('--chunksize', type=int, default=150000)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--run-synthetic-test', action='store_true')
    args, unknown = parser.parse_known_args()
    if unknown:
        logging.debug(f'Ignoring unknown args: {unknown}')
    if args.run_synthetic_test:
        tmp = Path('./tmp_synth')
        tmp.mkdir(exist_ok=True)
        synthetic_unit_test(tmp)
    main(data_dir=args.data_dir, timestep_hours=args.timestep_hours, num_steps=args.num_steps, chunksize=args.chunksize, debug=args.debug)

In [None]:
import pandas as pd
traj = pd.read_csv("traj_with_mapped_actions.csv")
print("âœ… Trajectory shape:", traj.shape)
print(traj.head())

# check distinct hadm_id and timesteps
print("Unique hadm_ids:", traj['hadm_id'].nunique())
print("Unique timesteps:", traj['timestep'].nunique())


âœ… Trajectory shape: (89288, 9)
   subject_id  hadm_id  timestep  time_since_admit_hours           admit_time  \
0           3   145834         0                       0  2101-10-20 19:08:00   
1           3   145834         1                       6  2101-10-20 19:08:00   
2           3   145834         2                      12  2101-10-20 19:08:00   
3           3   145834         3                      18  2101-10-20 19:08:00   
4           3   145834         4                      24  2101-10-20 19:08:00   

  action_label          source mapped_action  action_code  
0  vasopressor  inputevents_cv   vasopressor            1  
1  vasopressor  inputevents_cv   vasopressor            1  
2  vasopressor  inputevents_cv   vasopressor            1  
3  vasopressor  inputevents_cv   vasopressor            1  
4  vasopressor  inputevents_cv   vasopressor            1  
Unique hadm_ids: 11161
Unique timesteps: 8
