In [1]:
# --- Connect to Google BigQuery ---
from google.cloud import bigquery
# If you have set up authentication (e.g., GOOGLE_APPLICATION_CREDENTIALS env variable), this will work.
# Otherwise, you may need to authenticate manually or use a service account key.
# Replace 'your-gcp-project-id' with your actual GCP project ID.
bq_client = bigquery.Client(project='ml-for-healthcare-2025')
print('Connected to BigQuery!')


Connected to BigQuery!


In [2]:
import os
import pandas as pd
from typing import List, Optional

# Load the initial cohort subject IDs
cohort_path = 'data/initial_cohort.csv'  # relative to the project folder
if not os.path.exists(cohort_path):
    # fallback for alternate execution contexts
    alt_path = '../project/data/initial_cohort.csv'
    if os.path.exists(alt_path):
        cohort_path = alt_path

cohort_df = pd.read_csv(cohort_path)
if 'subject_id' not in cohort_df.columns:
    raise ValueError("initial_cohort.csv must contain a 'subject_id' column")
subject_ids: List[int] = cohort_df['subject_id'].dropna().astype(int).tolist()
print(f"Loaded {len(subject_ids)} subject IDs for the cohort from {cohort_path}.")


Loaded 32513 subject IDs for the cohort from data/initial_cohort.csv.


In [3]:
# --- Tiny BigQuery sanity check (use a public dataset to avoid permissions issues) ---
# This verifies your client works without touching restricted datasets like MIMIC-III.
try:
    test_query = """
    SELECT year, COUNT(1) AS n
    FROM `bigquery-public-data.samples.natality`
    WHERE year BETWEEN 2000 AND 2002
    GROUP BY year
    ORDER BY year
    LIMIT 3
    """
    _ = bq_client.query(test_query).to_dataframe()
    print("BigQuery client sanity check passed on public dataset.")
except Exception as e:
    print("BigQuery sanity check failed:", repr(e))


BigQuery client sanity check passed on public dataset.


In [4]:
# --- Helpers for robust BigQuery querying and parameter passing ---
from google.api_core import exceptions as gexc

def safe_bq_to_df(sql: str, job_config: Optional[bigquery.QueryJobConfig] = None) -> pd.DataFrame:
    try:
        return bq_client.query(sql, job_config=job_config).to_dataframe()
    except (gexc.Forbidden, gexc.NotFound, gexc.BadRequest) as e:
        print(f"Query failed: {type(e).__name__}: {e}")
        return pd.DataFrame()
    except Exception as e:
        print(f"Unexpected error: {type(e).__name__}: {e}")
        return pd.DataFrame()


In [5]:
# --- Admissions: restrict to first hospital admission per subject ---
from google.cloud import bigquery as bq

def get_first_admissions(subject_ids: List[int]) -> pd.DataFrame:
    sql = """
    SELECT subject_id, hadm_id, admittime, dischtime, deathtime, admission_type,
           admission_location, discharge_location, diagnosis, insurance, language,
           marital_status, ethnicity
    FROM `physionet-data.mimiciii_clinical.admissions`
    WHERE subject_id IN UNNEST(@subject_ids)
    ORDER BY subject_id, admittime
    """
    cfg = bq.QueryJobConfig(
        query_parameters=[bq.ArrayQueryParameter("subject_ids", "INT64", subject_ids)]
    )
    df = safe_bq_to_df(sql, job_config=cfg)
    if df.empty:
        return df
    first = (
        df.sort_values(["subject_id", "admittime"])\
          .groupby("subject_id", as_index=False)\
          .first()
    )
    return first

first_admissions_df = get_first_admissions(subject_ids)
if first_admissions_df.empty:
    print("No admissions returned (check MIMIC access).")
else:
    print(f"First admissions loaded: {len(first_admissions_df)} rows")

# Collect first-admission hadm_ids
hadm_ids: List[int] = first_admissions_df.get('hadm_id', pd.Series([], dtype='int')).dropna().astype(int).tolist()


First admissions loaded: 32513 rows


In [6]:
# --- Demographics from patients table ---

def get_demographics(subject_ids: List[int]) -> pd.DataFrame:
    sql = """
    SELECT subject_id, gender, dob, dod, expire_flag
    FROM `physionet-data.mimiciii_clinical.patients`
    WHERE subject_id IN UNNEST(@subject_ids)
    """
    cfg = bq.QueryJobConfig(
        query_parameters=[bq.ArrayQueryParameter("subject_ids", "INT64", subject_ids)]
    )
    return safe_bq_to_df(sql, job_config=cfg)

demographics_df = get_demographics(subject_ids) if subject_ids else pd.DataFrame()
print(f"Demographics rows: {len(demographics_df)}")


Demographics rows: 32513


In [7]:
# --- Vitals in first 48h of first admission (Chartevents) ---
# We filter by common vital labels via d_items to avoid itemid hard-coding across systems.

def get_vitals_48h(hadm_ids: List[int]) -> pd.DataFrame:
    if not hadm_ids:
        return pd.DataFrame()
    sql = """
    WITH first_adm AS (
      SELECT hadm_id, admittime
      FROM `physionet-data.mimiciii_clinical.admissions`
      WHERE hadm_id IN UNNEST(@hadm_ids)
    ), vitals AS (
      SELECT ce.subject_id, ce.hadm_id, ce.icustay_id, ce.charttime,
             di.label AS item_label, ce.valuenum, ce.valueuom
      FROM `physionet-data.mimiciii_clinical.chartevents` ce
      JOIN first_adm fa USING (hadm_id)
      JOIN `physionet-data.mimiciii_clinical.d_items` di ON di.itemid = ce.itemid
      WHERE ce.hadm_id IN UNNEST(@hadm_ids)
        AND ce.valuenum IS NOT NULL
        AND TIMESTAMP_DIFF(ce.charttime, fa.admittime, HOUR) BETWEEN 0 AND 48
        AND (
          REGEXP_CONTAINS(LOWER(di.label), r"heart rate|hr") OR
          REGEXP_CONTAINS(LOWER(di.label), r"respiratory rate|rr") OR
          REGEXP_CONTAINS(LOWER(di.label), r"temperature") OR
          REGEXP_CONTAINS(LOWER(di.label), r"(non?invasive )?systolic|sysbp|sbp") OR
          REGEXP_CONTAINS(LOWER(di.label), r"(non?invasive )?diastolic|diasbp|dbp") OR
          REGEXP_CONTAINS(LOWER(di.label), r"mean arterial|map") OR
          REGEXP_CONTAINS(LOWER(di.label), r"spo2|o2 saturation|oxygen saturation")
        )
    )
    SELECT * FROM vitals
    ORDER BY subject_id, charttime
    """
    cfg = bq.QueryJobConfig(
        query_parameters=[bq.ArrayQueryParameter("hadm_ids", "INT64", hadm_ids)]
    )
    return safe_bq_to_df(sql, job_config=cfg)

vitals_df = get_vitals_48h(hadm_ids)
print(f"Vitals events (<=48h): {len(vitals_df)}")


Vitals events (<=48h): 7919202


In [8]:
# --- Labs in first 48h of first admission (Labevents only) ---
# Filter to a common set via d_labitems labels

def get_labs_48h(hadm_ids: List[int]) -> pd.DataFrame:
    if not hadm_ids:
        return pd.DataFrame()
    sql = """
    WITH first_adm AS (
      SELECT hadm_id, admittime
      FROM `physionet-data.mimiciii_clinical.admissions`
      WHERE hadm_id IN UNNEST(@hadm_ids)
    ), labs AS (
      SELECT le.subject_id, le.hadm_id, le.charttime,
             dl.label AS item_label, le.valuenum, le.value AS value_text,
             le.valueuom, le.flag
      FROM `physionet-data.mimiciii_clinical.labevents` le
      JOIN first_adm fa USING (hadm_id)
      JOIN `physionet-data.mimiciii_clinical.d_labitems` dl ON dl.itemid = le.itemid
      WHERE le.hadm_id IN UNNEST(@hadm_ids)
        AND le.charttime IS NOT NULL
        AND TIMESTAMP_DIFF(le.charttime, fa.admittime, HOUR) BETWEEN 0 AND 48
        AND (
          REGEXP_CONTAINS(LOWER(dl.label), r"wbc|white blood") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"hemoglobin|hgb") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"hematocrit|hct") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"platelet") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"sodium|na\\b") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"potassium|k\\b") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"chloride|cl\\b") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"bicarbonate|hco3") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"bun|urea") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"creatinine") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"glucose") OR
          REGEXP_CONTAINS(LOWER(dl.label), r"lactate")
        )
    )
    SELECT * FROM labs
    ORDER BY subject_id, charttime
    """
    cfg = bq.QueryJobConfig(
        query_parameters=[bq.ArrayQueryParameter("hadm_ids", "INT64", hadm_ids)]
    )
    return safe_bq_to_df(sql, job_config=cfg)

labs_df = get_labs_48h(hadm_ids)
print(f"Lab events (<=48h): {len(labs_df)}")


Lab events (<=48h): 1239058


In [9]:
# --- Additional modality #1: Prescriptions in first 48h ---

def get_prescriptions_48h(hadm_ids: List[int]) -> pd.DataFrame:
    if not hadm_ids:
        return pd.DataFrame()
    sql = """
    WITH first_adm AS (
      SELECT hadm_id, admittime
      FROM `physionet-data.mimiciii_clinical.admissions`
      WHERE hadm_id IN UNNEST(@hadm_ids)
    )
    SELECT pr.subject_id, pr.hadm_id, pr.startdate, pr.enddate,
           pr.drug, pr.drug_type, pr.formulary_drug_cd, pr.route
    FROM `physionet-data.mimiciii_clinical.prescriptions` pr
    JOIN first_adm fa USING (hadm_id)
    WHERE pr.hadm_id IN UNNEST(@hadm_ids)
      AND pr.startdate IS NOT NULL
      AND TIMESTAMP_DIFF(pr.startdate, fa.admittime, HOUR) BETWEEN 0 AND 48
    ORDER BY subject_id, startdate
    """
    cfg = bq.QueryJobConfig(
        query_parameters=[bq.ArrayQueryParameter("hadm_ids", "INT64", hadm_ids)]
    )
    return safe_bq_to_df(sql, job_config=cfg)

prescriptions_df = get_prescriptions_48h(hadm_ids)
print(f"Prescriptions (<=48h): {len(prescriptions_df)}")


Prescriptions (<=48h): 619232


In [10]:
# --- Additional modality #2: Procedures in first 48h ---
# Use procedureevents_mv which contains timestamped procedures (MV-era). Filter to start within 48h of admittime.

def get_procedures_48h(hadm_ids: List[int]) -> pd.DataFrame:
    if not hadm_ids:
        return pd.DataFrame()
    sql = """
    WITH first_adm AS (
      SELECT hadm_id, admittime
      FROM `physionet-data.mimiciii_clinical.admissions`
      WHERE hadm_id IN UNNEST(@hadm_ids)
    )
    SELECT pe.subject_id, pe.hadm_id, pe.icustay_id,
           pe.starttime, pe.endtime,
           pe.itemid, di.label AS item_label,
           pe.ordercategoryname, pe.ordercategorydescription, pe.location
    FROM `physionet-data.mimiciii_clinical.procedureevents_mv` pe
    JOIN first_adm fa USING (hadm_id)
    LEFT JOIN `physionet-data.mimiciii_clinical.d_items` di ON di.itemid = pe.itemid
    WHERE pe.hadm_id IN UNNEST(@hadm_ids)
      AND pe.starttime IS NOT NULL
      AND TIMESTAMP_DIFF(pe.starttime, fa.admittime, HOUR) BETWEEN 0 AND 48
    ORDER BY subject_id, starttime
    """
    cfg = bq.QueryJobConfig(
        query_parameters=[bq.ArrayQueryParameter("hadm_ids", "INT64", hadm_ids)]
    )
    return safe_bq_to_df(sql, job_config=cfg)

procedures_df = get_procedures_48h(hadm_ids)
print(f"Procedures (<=48h): {len(procedures_df)}")


Procedures (<=48h): 71675


In [11]:
# --- Optional: Cache extracted raw tables to disk for reproducibility ---
cache_dir = os.path.join(os.path.dirname(cohort_path), "extracted_cache")
os.makedirs(cache_dir, exist_ok=True)

if not first_admissions_df.empty:
    first_admissions_df.to_parquet(os.path.join(cache_dir, "first_admissions.parquet"), index=False)
if not demographics_df.empty:
    demographics_df.to_parquet(os.path.join(cache_dir, "demographics.parquet"), index=False)
if not vitals_df.empty:
    vitals_df.to_parquet(os.path.join(cache_dir, "vitals_48h.parquet"), index=False)
if not labs_df.empty:
    labs_df.to_parquet(os.path.join(cache_dir, "labs_48h.parquet"), index=False)
if not prescriptions_df.empty:
    prescriptions_df.to_parquet(os.path.join(cache_dir, "prescriptions_48h.parquet"), index=False)
if not procedures_df.empty:
    procedures_df.to_parquet(os.path.join(cache_dir, "procedures_48h.parquet"), index=False)

print(f"Cached available extracts under: {cache_dir}")


Cached available extracts under: data\extracted_cache


In [12]:
# --- Quick previews ---
for name, df in {
    'first_admissions': first_admissions_df,
    'demographics': demographics_df,
    'vitals_48h': vitals_df,
    'labs_48h': labs_df,
    'prescriptions_48h': prescriptions_df,
    'procedures_48h': procedures_df,
}.items():
    print(f"\n{name}: {len(df)} rows")
    if not df.empty:
        display(df.head(3))



first_admissions: 32513 rows


Unnamed: 0,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admission_location,discharge_location,diagnosis,insurance,language,marital_status,ethnicity
0,2,163353,2138-07-17 19:04:00,2138-07-21 15:48:00,NaT,NEWBORN,PHYS REFERRAL/NORMAL DELI,HOME,NEWBORN,Private,,,ASIAN
1,3,145834,2101-10-20 19:08:00,2101-10-31 13:58:00,NaT,EMERGENCY,EMERGENCY ROOM ADMIT,SNF,HYPOTENSION,Medicare,,MARRIED,WHITE
2,4,185777,2191-03-16 00:28:00,2191-03-23 18:41:00,NaT,EMERGENCY,EMERGENCY ROOM ADMIT,HOME WITH HOME IV PROVIDR,"FEVER,DEHYDRATION,FAILURE TO THRIVE",Private,,SINGLE,WHITE



demographics: 32513 rows


Unnamed: 0,subject_id,gender,dob,dod,expire_flag
0,18848,F,2042-08-21,2128-01-08,1
1,61056,F,2067-04-11,2152-01-08,1
2,26889,F,2115-11-04,2164-01-08,1



vitals_48h: 7919202 rows


Unnamed: 0,subject_id,hadm_id,icustay_id,charttime,item_label,valuenum,valueuom
0,2,163353,243653,2138-07-17 20:20:00,Heart Rate,148.0,bpm
1,2,163353,243653,2138-07-17 20:30:00,BP Cuff [Diastolic],35.0,cc/min
2,2,163353,243653,2138-07-17 20:30:00,BP Cuff [Systolic],72.0,cc/min



labs_48h: 1239058 rows


Unnamed: 0,subject_id,hadm_id,charttime,item_label,valuenum,value_text,valueuom,flag
0,2,163353,2138-07-17 20:48:00,Hemoglobin,0.0,0,g/dL,abnormal
1,2,163353,2138-07-17 20:48:00,Hematocrit,0.0,0,%,abnormal
2,2,163353,2138-07-17 20:48:00,Platelet Count,5.0,5,K/uL,abnormal



prescriptions_48h: 619232 rows


Unnamed: 0,subject_id,hadm_id,startdate,enddate,drug,drug_type,formulary_drug_cd,route
0,2,163353,2138-07-18,2138-07-20,Syringe (Neonatal) *D5W*,BASE,NEOSYRD5W,IV
1,2,163353,2138-07-18,2138-07-21,Send 500mg Vial,BASE,AMPVL,IV
2,2,163353,2138-07-18,2138-07-21,Ampicillin Sodium,MAIN,AMP500I,IV



procedures_48h: 71675 rows


Unnamed: 0,subject_id,hadm_id,icustay_id,starttime,endtime,itemid,item_label,ordercategoryname,ordercategorydescription,location
0,266,186251,293876,2168-07-10 08:30:00,2168-07-10 18:41:00,225792,Invasive Ventilation,Ventilation,Task,
1,266,186251,293876,2168-07-10 08:30:00,2168-07-10 08:31:00,224385,Intubation,Intubation/Extubation,Electrolytes,
2,266,186251,293876,2168-07-10 09:27:00,2168-07-11 17:40:00,224275,20 Gauge,Peripheral Lines,Task,


In [13]:
# --- EDA: Load cached extracts (for reproducibility, independent of BigQuery access) ---
import numpy as np

cache_dir = os.path.join(os.path.dirname(cohort_path), "extracted_cache")

def load_cached(name: str) -> pd.DataFrame:
    path = os.path.join(cache_dir, f"{name}.parquet")
    if os.path.exists(path):
        try:
            return pd.read_parquet(path)
        except Exception as e:
            print(f"Failed to read {path}: {e}")
            return pd.DataFrame()
    else:
        return pd.DataFrame()

eda_first_adm = load_cached("first_admissions")
eda_demo = load_cached("demographics")
eda_vitals = load_cached("vitals_48h")
eda_labs = load_cached("labs_48h")
eda_rx = load_cached("prescriptions_48h")
eda_proc = load_cached("procedures_48h")
# Backward-compat: if procedures not present but microbiology exists, load it for EDA
eda_micro = pd.DataFrame()
if eda_proc.empty:
    eda_micro = load_cached("microbiology_48h")

print("Cached tables loaded for EDA:")
for n, df in {
    'first_admissions': eda_first_adm,
    'demographics': eda_demo,
    'vitals_48h': eda_vitals,
    'labs_48h': eda_labs,
    'prescriptions_48h': eda_rx,
    'procedures_48h': eda_proc,
    'microbiology_48h (fallback)': eda_micro,
}.items():
    print(f"- {n}: {len(df)} rows")


Cached tables loaded for EDA:
- first_admissions: 32513 rows
- demographics: 32513 rows
- vitals_48h: 7919202 rows
- labs_48h: 1239058 rows
- prescriptions_48h: 619232 rows
- procedures_48h: 71675 rows
- microbiology_48h (fallback): 0 rows


In [14]:
# --- EDA: Cohort overview, age, gender, admission stats ---
from datetime import datetime

def hours_between(a, b) -> float:
    try:
        return (pd.to_datetime(a) - pd.to_datetime(b)).total_seconds() / 3600.0
    except Exception:
        return np.nan

insights = []

n_subjects = len(set(eda_first_adm['subject_id'])) if not eda_first_adm.empty else len(set(cohort_df['subject_id']))
insights.append(f"Cohort subjects (initial): {n_subjects}")

# Age at first admission
age_df = pd.DataFrame()
if not eda_first_adm.empty and not eda_demo.empty:
    tmp = eda_first_adm[['subject_id', 'admittime']].merge(
        eda_demo[['subject_id', 'dob', 'gender']], on='subject_id', how='left'
    )
    tmp['age_years'] = (pd.to_datetime(tmp['admittime']) - pd.to_datetime(tmp['dob'])).dt.days / 365.25
    # Cap ages >= 89 to 90 per common MIMIC-III practice
    tmp['age_years'] = tmp['age_years'].clip(lower=0)
    tmp.loc[tmp['age_years'] >= 89, 'age_years'] = 90
    age_df = tmp
    if not tmp['age_years'].dropna().empty:
        insights.append(
            "Age at admission (years): median {:.1f} [IQR {:.1f}-{:.1f}], >=65: {:.1f}%".format(
                tmp['age_years'].median(),
                tmp['age_years'].quantile(0.25),
                tmp['age_years'].quantile(0.75),
                100.0 * (tmp['age_years'] >= 65).mean(),
            )
        )
    if 'gender' in tmp and not tmp['gender'].dropna().empty:
        g = tmp['gender'].value_counts(normalize=True).mul(100).round(1)
        insights.append("Gender distribution (%): " + ", ".join(f"{k}: {v}%" for k, v in g.to_dict().items()))

# Admission types and LOS
if not eda_first_adm.empty:
    if 'admission_type' in eda_first_adm:
        at = eda_first_adm['admission_type'].value_counts(normalize=True).mul(100).round(1).to_dict()
        insights.append("Admission types (%): " + ", ".join(f"{k}: {v}%" for k, v in at.items()))
    los_hours = (pd.to_datetime(eda_first_adm['dischtime']) - pd.to_datetime(eda_first_adm['admittime'])).dt.total_seconds() / 3600.0
    los_hours = los_hours.clip(lower=0)
    if not los_hours.dropna().empty:
        insights.append(
            "Hospital LOS (hours): median {:.1f} [IQR {:.1f}-{:.1f}], >=54h coverage: {:.1f}%".format(
                los_hours.median(),
                los_hours.quantile(0.25),
                los_hours.quantile(0.75),
                100.0 * (los_hours >= 54).mean(),
            )
        )

print("\nKey cohort insights:")
for s in insights:
    print("-", s)



Key cohort insights:
- Cohort subjects (initial): 32513
- Age at admission (years): median 60.5 [IQR 38.2-75.5], >=65: 42.5%
- Gender distribution (%): M: 56.4%, F: 43.6%
- Admission types (%): EMERGENCY: 67.2%, NEWBORN: 16.9%, ELECTIVE: 13.5%, URGENT: 2.5%
- Hospital LOS (hours): median 152.9 [IQR 89.0-280.4], >=54h coverage: 87.6%


In [15]:
# --- EDA: Coverage and density per modality in first 48h ---
from collections import defaultdict

coverage = []
if not eda_first_adm.empty:
    subj_set = set(eda_first_adm['subject_id'])
    hadm_set = set(eda_first_adm['hadm_id'])
else:
    subj_set = set(cohort_df['subject_id'])
    hadm_set = set()

def pct(x):
    return round(100.0 * x, 1)

# Helper to compute per-subject coverage
def cov_by_subject(name: str, df: pd.DataFrame, id_col: str = 'subject_id'):
    if df.empty:
        print(f"- {name}: 0.0% subjects with records")
        return
    with_rec = len(set(df[id_col]) & subj_set) if subj_set else df[id_col].nunique()
    base = len(subj_set) if subj_set else n_subjects
    print(f"- {name}: {pct(with_rec / base if base else 0)}% subjects with records, total rows={len(df):,}")

print("\nCoverage within first 48h:")
cov_by_subject("Vitals", eda_vitals)
cov_by_subject("Labs", eda_labs)
cov_by_subject("Prescriptions", eda_rx)
if not eda_proc.empty:
    cov_by_subject("Procedures", eda_proc)
elif not eda_micro.empty:
    cov_by_subject("Microbiology", eda_micro)

# Event density per subject (median events/subject)
def density(name: str, df: pd.DataFrame):
    if df.empty:
        return
    cnt = df.groupby('subject_id').size()
    print(f"  {name} events per subject: median {cnt.median():.0f} [IQR {cnt.quantile(0.25):.0f}-{cnt.quantile(0.75):.0f}]")

density("Vitals", eda_vitals)
density("Labs", eda_labs)
density("Prescriptions", eda_rx)
if not eda_proc.empty:
    density("Procedures", eda_proc)
elif not eda_micro.empty:
    density("Microbiology", eda_micro)



Coverage within first 48h:
- Vitals: 85.2% subjects with records, total rows=7,919,202
- Labs: 95.4% subjects with records, total rows=1,239,058
- Prescriptions: 79.2% subjects with records, total rows=619,232
- Procedures: 29.4% subjects with records, total rows=71,675
  Vitals events per subject: median 257 [IQR 154-375]
  Labs events per subject: median 36 [IQR 23-53]
  Prescriptions events per subject: median 20 [IQR 10-33]
  Procedures events per subject: median 6 [IQR 4-11]


In [16]:
# --- EDA: Vitals and Labs distributions (first 48h) ---
# Summarize by item label
if not eda_vitals.empty and 'item_label' in eda_vitals:
    vstats = (
        eda_vitals.dropna(subset=['valuenum'])
        .groupby('item_label')['valuenum']
        .agg(['count', 'median', 'mean', 'std'])
        .sort_values('count', ascending=False)
        .head(15)
        .round(2)
    )
    print("\nTop 15 vitals by count with summary stats:")
    display(vstats)
else:
    print("No vitals available for distribution summary.")

if not eda_labs.empty and 'item_label' in eda_labs:
    lstats = (
        eda_labs.dropna(subset=['valuenum'])
        .groupby('item_label')['valuenum']
        .agg(['count', 'median', 'mean', 'std'])
        .sort_values('count', ascending=False)
        .head(15)
        .round(2)
    )
    print("\nTop 15 labs by count with summary stats:")
    display(lstats)
else:
    print("No labs available for distribution summary.")



Top 15 vitals by count with summary stats:


Unnamed: 0_level_0,count,median,mean,std
item_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Heart Rate,1120336,87.0,91.6,25.0
Respiratory Rate,974145,18.0,18.8,5.9
SpO2,531751,98.0,97.29,3.91
O2 saturation pulseoxymetry,428941,97.0,99.27,1498.03
HR Alarm [High],403301,120.0,143.55,37.57
HR Alarm [Low],403227,60.0,65.23,18.33
Arterial BP [Systolic],298445,117.0,118.86,25.77
Arterial BP [Diastolic],298317,59.0,60.07,13.69
SpO2 Alarm [Low],283158,90.0,89.99,8.62
SpO2 Alarm [High],281712,100.0,99.15,7.74



Top 15 labs by count with summary stats:


Unnamed: 0_level_0,count,median,mean,std
item_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Glucose,125003,127.0,142.56,76.66
Hemoglobin,107141,10.8,11.06,2.41
Hematocrit,104090,31.5,32.51,6.77
Potassium,87845,4.1,4.16,0.76
Platelet Count,84996,193.0,206.6,106.37
Sodium,83555,139.0,138.97,5.16
Chloride,83525,106.0,105.54,6.25
Bicarbonate,81618,24.0,23.77,4.61
Creatinine,80634,0.9,1.38,1.47
Urea Nitrogen,80238,18.0,24.93,20.75


In [17]:
# --- EDA: Prescriptions (first 48h) ---
if not eda_rx.empty:
    top_drugs = (
        eda_rx.assign(drug_low=eda_rx['drug'].astype(str).str.lower())
        .groupby('drug_low')
        .size()
        .sort_values(ascending=False)
        .head(20)
    )
    print("\nTop 20 prescribed drugs (by mentions) in first 48h:")
    display(top_drugs)

    # Simple flags for antibiotics and vasopressors in prescriptions
    rx_lower = eda_rx['drug'].astype(str).str.lower()
    abx_patterns = ['cillin', 'cef', 'ceph', 'penem', 'floxacin', 'vancomycin', 'metronidazole', 'piperacillin', 'tazobactam']
    vaso_patterns = ['norepinephrine', 'levophed', 'epinephrine', 'dopamine', 'dobutamine', 'phenylephrine', 'vasopressin']
    has_abx = rx_lower.str.contains('|'.join(abx_patterns), na=False)
    has_vaso = rx_lower.str.contains('|'.join(vaso_patterns), na=False)
    abx_subj = eda_rx.loc[has_abx, 'subject_id'].nunique()
    vaso_subj = eda_rx.loc[has_vaso, 'subject_id'].nunique()
    base = len(subj_set) if subj_set else n_subjects
    print(f"Antibiotics exposure (any mention): {pct(abx_subj / base if base else 0)}% of subjects")
    print(f"Vasopressor exposure (any mention): {pct(vaso_subj / base if base else 0)}% of subjects")
else:
    print("No prescriptions available for summary.")



Top 20 prescribed drugs (by mentions) in first 48h:


drug_low
potassium chloride             28964
0.9% sodium chloride           23917
ns                             21936
d5w                            20332
insulin                        17355
magnesium sulfate              15592
furosemide                     14961
sodium chloride 0.9%  flush    14810
acetaminophen                  13396
sw                             12009
iso-osmotic dextrose           11860
metoprolol                     11173
calcium gluconate              10563
morphine sulfate               10220
5% dextrose                     9868
metoprolol tartrate             9383
lorazepam                       8826
heparin                         8586
lr                              7763
docusate sodium                 7684
dtype: int64

Antibiotics exposure (any mention): 36.3% of subjects
Vasopressor exposure (any mention): 12.3% of subjects


In [18]:
# --- EDA: Procedures or Microbiology (first 48h) ---
if not eda_proc.empty:
    top_proc_cat = eda_proc['ordercategoryname'].value_counts().head(15)
    print("\nTop 15 procedure categories:")
    display(top_proc_cat)
    # Heuristic ventilatory support flag
    label_col = 'item_label' if 'item_label' in eda_proc.columns else 'ordercategorydescription'
    lbl = eda_proc[label_col].astype(str).str.lower()
    vent_patterns = ['vent', 'intubat', 'endotracheal', 'peep', 'tidal volume']
    renal_patterns = ['dialysis', 'crrt', 'hemodialysis']
    central_line_patterns = ['central line', 'cvc', 'subclavian', 'internal jugular', 'femoral line']
    base = len(subj_set) if subj_set else n_subjects
    def frac(patterns):
        m = lbl.str.contains('|'.join(patterns), na=False)
        return pct(eda_proc.loc[m, 'subject_id'].nunique() / base) if base else 0.0
    print(f"Ventilation-related procedure (heuristic): {frac(vent_patterns)}% of subjects")
    print(f"Renal replacement (heuristic): {frac(renal_patterns)}% of subjects")
    print(f"Central line (heuristic): {frac(central_line_patterns)}% of subjects")
elif not eda_micro.empty:
    print("\nMicrobiology summary (fallback):")
    top_spec = eda_micro['spec_type_desc'].value_counts().head(10)
    display(top_spec)
    pos = eda_micro['interpretation'].astype(str).str.upper().str.contains('S|R|I')
    base = len(subj_set) if subj_set else n_subjects
    any_culture = eda_micro['subject_id'].nunique()
    print(f"Any culture taken: {pct(any_culture / base if base else 0)}% of subjects")
else:
    print("No procedures or microbiology available for summary.")



Top 15 procedure categories:


ordercategoryname
Peripheral Lines         22732
Procedures               15702
Invasive Lines           10780
Imaging                   9601
Ventilation               4124
Intubation/Extubation     3520
Significant Events        2942
Communication             1970
Continuous Procedures      175
Dialysis                   112
Peritoneal Dialysis         15
CRRT Filter Change           2
Name: count, dtype: int64

Ventilation-related procedure (heuristic): 12.2% of subjects
Renal replacement (heuristic): 0.8% of subjects
Central line (heuristic): 0.0% of subjects


In [19]:
# --- EDA: Consolidated key insights summary ---
summary_lines = []
summary_lines.extend(insights)

# Add coverage headlines
def cov_headline(name, df):
    if df.empty:
        return f"{name}: 0% coverage"
    base = len(subj_set) if subj_set else n_subjects
    frac = df['subject_id'].nunique() / base if base else 0
    return f"{name}: {pct(frac)}% subjects with data"

summary_lines.append(cov_headline("Vitals (<=48h)", eda_vitals))
summary_lines.append(cov_headline("Labs (<=48h)", eda_labs))
summary_lines.append(cov_headline("Prescriptions (<=48h)", eda_rx))
if not eda_proc.empty:
    summary_lines.append(cov_headline("Procedures (<=48h)", eda_proc))
elif not eda_micro.empty:
    summary_lines.append(cov_headline("Microbiology (<=48h)", eda_micro))

# Save a machine-readable report
report_path = os.path.join(cache_dir, "eda_summary.json")
try:
    import json
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump({"insights": summary_lines}, f, indent=2)
    print(f"\nEDA summary saved to {report_path}")
except Exception as e:
    print(f"Failed to save EDA summary: {e}")

print("\nKey EDA insights:")
for line in summary_lines:
    print("-", line)



EDA summary saved to data\extracted_cache\eda_summary.json

Key EDA insights:
- Cohort subjects (initial): 32513
- Age at admission (years): median 60.5 [IQR 38.2-75.5], >=65: 42.5%
- Gender distribution (%): M: 56.4%, F: 43.6%
- Admission types (%): EMERGENCY: 67.2%, NEWBORN: 16.9%, ELECTIVE: 13.5%, URGENT: 2.5%
- Hospital LOS (hours): median 152.9 [IQR 89.0-280.4], >=54h coverage: 87.6%
- Vitals (<=48h): 85.2% subjects with data
- Labs (<=48h): 95.4% subjects with data
- Prescriptions (<=48h): 79.2% subjects with data
- Procedures (<=48h): 29.4% subjects with data


In [20]:
# --- Labels: Mortality, Prolonged LOS (>7d), 30-day Readmission ---
# Build labels for the first admission per subject; filter to LOS >= 54h (timeline requirement).
from google.cloud import bigquery as bq
import numpy as np

# Helper: fetch all admissions for cohort to compute readmission
def get_all_admissions(subject_ids: List[int]) -> pd.DataFrame:
    sql = """
    SELECT subject_id, hadm_id, admittime, dischtime, deathtime
    FROM `physionet-data.mimiciii_clinical.admissions`
    WHERE subject_id IN UNNEST(@subject_ids)
    ORDER BY subject_id, admittime
    """
    cfg = bq.QueryJobConfig(
        query_parameters=[bq.ArrayQueryParameter("subject_ids", "INT64", subject_ids)]
    )
    return safe_bq_to_df(sql, job_config=cfg)

all_adm_df = get_all_admissions(subject_ids) if subject_ids else pd.DataFrame()

# Compute LOS and filter >=54h in first_admissions_df
if not first_admissions_df.empty:
    fadm = first_admissions_df.copy()
    fadm['admittime'] = pd.to_datetime(fadm['admittime'])
    fadm['dischtime'] = pd.to_datetime(fadm['dischtime'])
    fadm['los_hours'] = (fadm['dischtime'] - fadm['admittime']).dt.total_seconds() / 3600.0
    fadm = fadm[fadm['los_hours'] >= 54].copy()
else:
    fadm = pd.DataFrame(columns=['subject_id','hadm_id','admittime','dischtime','los_hours'])

# Recompute hadm_ids after LOS filter
hadm_ids_filtered: List[int] = fadm.get('hadm_id', pd.Series([], dtype='int')).dropna().astype(int).tolist()
print(f"First admissions with LOS>=54h: {len(hadm_ids_filtered)}")

# Subset 48h extracts to filtered hadm_ids
if not vitals_df.empty:
    vitals_df = vitals_df[vitals_df['hadm_id'].isin(hadm_ids_filtered)]
if not labs_df.empty:
    labs_df = labs_df[labs_df['hadm_id'].isin(hadm_ids_filtered)]
if not prescriptions_df.empty:
    prescriptions_df = prescriptions_df[prescriptions_df['hadm_id'].isin(hadm_ids_filtered)]
if 'procedures_df' in globals() and not procedures_df.empty:
    procedures_df = procedures_df[procedures_df['hadm_id'].isin(hadm_ids_filtered)]

# Mortality: in-hospital death OR death <= 30 days after discharge
mortality_df = pd.DataFrame()
if not fadm.empty:
    mort = fadm[['subject_id','hadm_id','admittime','dischtime']].copy()
    # join death data
    demo_for_death = demographics_df[['subject_id','dod']].copy() if not demographics_df.empty else pd.DataFrame(columns=['subject_id','dod'])
    mort = mort.merge(demo_for_death, on='subject_id', how='left')
    mort['dod'] = pd.to_datetime(mort['dod'])
    # in-hospital death via admissions.deathtime if available (from all_adm_df which has deathtime)
    if not all_adm_df.empty:
        tmp_death = all_adm_df[['hadm_id','deathtime']].drop_duplicates()
        tmp_death['deathtime'] = pd.to_datetime(tmp_death['deathtime'])
        mort = mort.merge(tmp_death, on='hadm_id', how='left')
    else:
        mort['deathtime'] = pd.NaT
    death_in_hosp = mort['deathtime'].notna()
    death_within_30d = (mort['dod'].notna()) & (mort['dod'] <= (mort['dischtime'] + pd.Timedelta(days=30))) & (mort['dod'] >= mort['dischtime'])
    mort['mortality_label'] = (death_in_hosp | death_within_30d).astype(int)
    mortality_df = mort[['subject_id','hadm_id','mortality_label']]

# Prolonged LOS: > 7 days
prolonged_los_df = pd.DataFrame()
if not fadm.empty:
    pl = fadm[['subject_id','hadm_id','los_hours']].copy()
    pl['prolonged_los_label'] = (pl['los_hours'] > 7*24).astype(int)
    prolonged_los_df = pl[['subject_id','hadm_id','prolonged_los_label']]

# 30-day readmission: second hospital admission within 30 days of first discharge
readmit_df = pd.DataFrame()
if not fadm.empty and not all_adm_df.empty:
    fa = fadm[['subject_id','hadm_id','dischtime']].rename(columns={'hadm_id':'first_hadm_id','dischtime':'first_dischtime'})
    nxt = all_adm_df.sort_values(['subject_id','admittime']).copy()
    # compute next admission per subject after the first hadm
    nxt['is_after_first'] = False
    # merge to tag the first discharge
    nxt = nxt.merge(fa[['subject_id','first_hadm_id','first_dischtime']], on='subject_id', how='left')
    nxt['is_after_first'] = (nxt['admittime'] > nxt['first_dischtime'])
    # next admission within 30 days exists?
    within_30 = nxt[nxt['is_after_first']].copy()
    within_30['within_30d'] = within_30['admittime'] <= (within_30['first_dischtime'] + pd.Timedelta(days=30))
    rn = within_30.groupby('subject_id')['within_30d'].any().reset_index().rename(columns={'within_30d':'readmit_30d'})
    readmit_df = fa[['subject_id','first_hadm_id']].merge(rn, on='subject_id', how='left').fillna({'readmit_30d': False})
    readmit_df['readmission_label'] = readmit_df['readmit_30d'].astype(int)
    readmit_df = readmit_df.rename(columns={'first_hadm_id':'hadm_id'})[['subject_id','hadm_id','readmission_label']]

# Combine labels
labels_df = fadm[['subject_id','hadm_id']].copy()
labels_df = labels_df.merge(mortality_df, on=['subject_id','hadm_id'], how='left')
labels_df = labels_df.merge(prolonged_los_df, on=['subject_id','hadm_id'], how='left')
labels_df = labels_df.merge(readmit_df, on=['subject_id','hadm_id'], how='left')
labels_df = labels_df.fillna({'mortality_label':0, 'prolonged_los_label':0, 'readmission_label':0}).astype({
    'mortality_label': int,
    'prolonged_los_label': int,
    'readmission_label': int,
})
print(labels_df.head())


First admissions with LOS>=54h: 28473
   subject_id  hadm_id  mortality_label  prolonged_los_label  \
0           2   163353                0                    0   
1           3   145834                0                    1   
2           4   185777                0                    1   
3           5   178980                0                    0   
4           7   118037                0                    0   

   readmission_label  
0                  0  
1                  0  
2                  0  
3                  0  
4                  0  


  readmit_df = fa[['subject_id','first_hadm_id']].merge(rn, on='subject_id', how='left').fillna({'readmit_30d': False})


In [21]:
# --- Feature engineering: aggregate 0-48h features and build a subject-level matrix ---
from collections import defaultdict

# Demographics features
feat_demo = pd.DataFrame()
if not fadm.empty and not demographics_df.empty:
    # Use ethnicity from first admissions; patients table doesn't have it
    tmp = fadm[['subject_id','admittime','ethnicity']].merge(
        demographics_df[['subject_id','gender','dob']], on='subject_id', how='left'
    )
    # Age (cap at 90)
    tmp['age'] = (pd.to_datetime(tmp['admittime']) - pd.to_datetime(tmp['dob'])).dt.days / 365.25
    tmp['age'] = tmp['age'].clip(lower=0)
    tmp.loc[tmp['age'] >= 89, 'age'] = 90
    # Gender one-hot
    tmp['gender_M'] = (tmp['gender'].astype(str).str.upper() == 'M').astype(int)
    tmp['gender_F'] = (tmp['gender'].astype(str).str.upper() == 'F').astype(int)
    # Ethnicity buckets and one-hot (keep small number of columns)
    def _eth_bucket(x: str) -> str:
        s = str(x).lower()
        if 'white' in s:
            return 'WHITE'
        if 'black' in s:
            return 'BLACK'
        if 'asian' in s:
            return 'ASIAN'
        if 'hisp' in s or 'latino' in s or 'latina' in s:
            return 'HISPANIC'
        return 'OTHER'
    tmp['eth_bucket'] = tmp['ethnicity'].apply(_eth_bucket)
    eth_dummies = pd.get_dummies(tmp['eth_bucket'], prefix='eth', dtype=int)
    feat_demo = pd.concat([tmp[['subject_id','age','gender_M','gender_F']], eth_dummies], axis=1)
# Helper to aggregate events

def aggregate_events(df: pd.DataFrame, value_col: str, time_col: str, label_col: str) -> pd.DataFrame:
    if df.empty:
        return pd.DataFrame(columns=['subject_id'])
    d = df.dropna(subset=[value_col]).copy()
    if d.empty:
        return pd.DataFrame(columns=['subject_id'])
    # last value per item by time
    d[time_col] = pd.to_datetime(d[time_col])
    last_vals = d.sort_values(["subject_id", label_col, time_col]).groupby(['subject_id', label_col]).tail(1)
    agg = d.groupby(['subject_id', label_col])[value_col].agg(['mean','min','max'])
    agg = agg.reset_index()
    last = last_vals[['subject_id', label_col, value_col]].rename(columns={value_col:'last'})
    wide = agg.merge(last, on=['subject_id', label_col], how='left')
    # pivot
    wide_cols = []
    for stat in ['mean','min','max','last']:
        pivot = wide.pivot_table(index='subject_id', columns=label_col, values=stat)
        pivot.columns = [f"{str(c)}__{stat}" for c in pivot.columns]
        wide_cols.append(pivot)
    out = pd.concat(wide_cols, axis=1)
    out = out.reset_index()
    return out

feat_vitals = aggregate_events(vitals_df, value_col='valuenum', time_col='charttime', label_col='item_label') if 'vitals_df' in globals() else pd.DataFrame()
feat_labs = aggregate_events(labs_df, value_col='valuenum', time_col='charttime', label_col='item_label') if 'labs_df' in globals() else pd.DataFrame()

# Prescriptions: simple pharmacotherapy flags
feat_rx = pd.DataFrame()
if 'prescriptions_df' in globals() and not prescriptions_df.empty:
    rx = prescriptions_df.copy()
    rx['drug_low'] = rx['drug'].astype(str).str.lower()
    def any_pattern(series, patterns):
        return series.str.contains('|'.join(patterns), na=False)
    abx_patterns = ['cillin','cef','ceph','penem','floxacin','vancomycin','metronidazole','piperacillin','tazobactam']
    insulin_patterns = ['insulin']
    diuretic_patterns = ['furosemide','lasix','bumetanide','torsemide','hydrochlorothiazide','hctz','spironolactone']
    steroid_patterns = ['predni','methylpred','hydrocortisone','dexamethasone']
    grp = rx.groupby('subject_id')
    feat_rx = pd.DataFrame({
        'subject_id': grp.size().index,
        'rx_total_mentions': grp.size().values,
        'rx_unique_drugs': grp['drug_low'].nunique().values,
    })
    # flags
    rx_flags = rx[['subject_id','drug_low']].copy()
    flags = rx_flags.groupby('subject_id').agg({
        'drug_low': lambda s: pd.Series({
            'rx_any_abx': any_pattern(s, abx_patterns).any(),
            'rx_any_insulin': any_pattern(s, insulin_patterns).any(),
            'rx_any_diuretic': any_pattern(s, diuretic_patterns).any(),
            'rx_any_steroid': any_pattern(s, steroid_patterns).any(),
        })
    })
    flags = pd.DataFrame(list(flags['drug_low'].values), index=flags.index).reset_index().rename(columns={'index':'subject_id'})
    for c in flags.columns:
        if c != 'subject_id':
            flags[c] = flags[c].astype(int)
    feat_rx = feat_rx.merge(flags, on='subject_id', how='left')

# Procedures: heuristic flags
feat_proc = pd.DataFrame()
if 'procedures_df' in globals() and not procedures_df.empty:
    p = procedures_df.copy()
    label_col = 'item_label' if 'item_label' in p.columns else 'ordercategorydescription'
    p['lbl_low'] = p[label_col].astype(str).str.lower()
    def flag_frac(df, patterns):
        return df.groupby('subject_id').apply(lambda g: g['lbl_low'].str.contains('|'.join(patterns), na=False).any())
    vent = flag_frac(p, ['vent','intubat','endotracheal','peep','tidal volume'])
    rrt = flag_frac(p, ['dialysis','crrt','hemodialysis'])
    cl = flag_frac(p, ['central line','cvc','subclavian','internal jugular','femoral line'])
    feat_proc = pd.DataFrame({
        'subject_id': vent.index,
        'proc_vent_any': vent.astype(int).values,
        'proc_rrt_any': rrt.astype(int).values,
        'proc_central_line_any': cl.astype(int).values,
    })

# Merge all features on subject_id
features = feat_demo.copy()
for df in [feat_vitals, feat_labs, feat_rx, feat_proc]:
    if not df.empty:
        features = features.merge(df, on='subject_id', how='left')

features = features.drop_duplicates('subject_id')
features = features.set_index('subject_id')
features = features.apply(pd.to_numeric, errors='ignore')
# Align to labeled cohort
features = features.loc[features.index.intersection(labels_df['subject_id'])]
# Ensure all column names are strings (sklearn requires homogeneous string feature names)
features.columns = features.columns.map(str)
print(f"Feature matrix shape: {features.shape}")


  return df.groupby('subject_id').apply(lambda g: g['lbl_low'].str.contains('|'.join(patterns), na=False).any())
  return df.groupby('subject_id').apply(lambda g: g['lbl_low'].str.contains('|'.join(patterns), na=False).any())
  return df.groupby('subject_id').apply(lambda g: g['lbl_low'].str.contains('|'.join(patterns), na=False).any())
  features = features.apply(pd.to_numeric, errors='ignore')


Feature matrix shape: (28473, 1097)


In [22]:
# --- Train 3 separate calibrated models and save artifacts ---
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss
from sklearn.dummy import DummyClassifier
import joblib

# Prepare labels indexed by subject_id
labels_idx = labels_df[['subject_id','mortality_label','prolonged_los_label','readmission_label']].drop_duplicates().set_index('subject_id')
# Keep rows with features
common_idx = features.index.intersection(labels_idx.index)
X = features.loc[common_idx]
Y = labels_idx.loc[common_idx]
# Also ensure X columns are strings explicitly
X.columns = X.columns.map(str)

# Train/val/test split (consistent across targets) with robust stratification fallback
from collections import Counter as _Counter
_y_mort = Y['mortality_label']
_stratify_ok = (_y_mort.nunique() >= 2) and min(_Counter(_y_mort.values).values()) >= 2
if _stratify_ok:
    train_ids, test_ids = train_test_split(common_idx, test_size=0.2, random_state=42, stratify=_y_mort)
    _y_train = Y.loc[train_ids, 'mortality_label']
    _stratify_ok2 = (_y_train.nunique() >= 2) and min(_Counter(_y_train.values).values()) >= 2
    if _stratify_ok2:
        train_ids, val_ids = train_test_split(train_ids, test_size=0.2, random_state=42, stratify=Y.loc[train_ids, 'mortality_label'])
    else:
        print("Warning: insufficient class balance for stratified val split; using unstratified split.")
        train_ids, val_ids = train_test_split(train_ids, test_size=0.2, random_state=42)
else:
    print("Warning: insufficient class balance for stratified splits; using unstratified splits.")
    train_ids, test_ids = train_test_split(common_idx, test_size=0.2, random_state=42)
    train_ids, val_ids = train_test_split(train_ids, test_size=0.2, random_state=42)

X_train, X_val, X_test = X.loc[train_ids], X.loc[val_ids], X.loc[test_ids]

# Prepare preprocessor
preprocessor = Pipeline([
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler(with_mean=True, with_std=True)),
])

X_train_t = preprocessor.fit_transform(X_train)
X_val_t = preprocessor.transform(X_val)
X_test_t = preprocessor.transform(X_test)

# Helper to train and calibrate a model for a single target

def train_calibrated(X_tr, y_tr, X_val, y_val):
    import numpy as _np
    # Fallback if training labels are single-class
    if len(_np.unique(y_tr)) < 2:
        print("Warning: single-class training labels; using DummyClassifier.")
        dummy = DummyClassifier(strategy='constant', constant=int(_np.unique(y_tr)[0]))
        dummy.fit(X_tr, y_tr)
        return dummy
    base = LogisticRegression(max_iter=500, n_jobs=None)
    base.fit(X_tr, y_tr)
    # If validation has a single class, skip calibration to avoid errors
    if len(_np.unique(y_val)) < 2:
        return base
    # scikit-learn >=1.4 uses 'estimator' instead of 'base_estimator'
    calib = CalibratedClassifierCV(estimator=base, method='sigmoid', cv='prefit')
    calib.fit(X_val, y_val)
    return calib

models = {}
metrics = {}
for target, col in [('mortality','mortality_label'), ('prolonged_los','prolonged_los_label'), ('readmission','readmission_label')]:
    y_train = Y.loc[train_ids, col].values
    y_val = Y.loc[val_ids, col].values
    y_test = Y.loc[test_ids, col].values
    model = train_calibrated(X_train_t, y_train, X_val_t, y_val)
    models[target] = model
    # eval
    proba = model.predict_proba(X_test_t)[:,1]
    metrics[target] = {
        'roc_auc': float(roc_auc_score(y_test, proba)) if len(np.unique(y_test))>1 else np.nan,
        'pr_auc': float(average_precision_score(y_test, proba)) if len(np.unique(y_test))>1 else np.nan,
        'brier': float(brier_score_loss(y_test, proba)),
        'positives_test': int(y_test.sum()),
        'n_test': int(len(y_test)),
    }

print("\nTest metrics:")
for k, v in metrics.items():
    print(k, v)

# Save artifacts
models_dir = os.path.join(os.path.dirname(__file__) if '__file__' in globals() else os.getcwd(), 'models')
os.makedirs(models_dir, exist_ok=True)
feature_cols = list(X.columns)
joblib.dump(preprocessor, os.path.join(models_dir, 'preprocessor.joblib'))
joblib.dump(models['mortality'], os.path.join(models_dir, 'model_mortality.joblib'))
joblib.dump(models['prolonged_los'], os.path.join(models_dir, 'model_prolonged_los.joblib'))
joblib.dump(models['readmission'], os.path.join(models_dir, 'model_readmission.joblib'))

# Save feature columns for alignment in unseen evaluation
import json
with open(os.path.join(models_dir, 'feature_columns.json'), 'w', encoding='utf-8') as f:
    json.dump(feature_cols, f)
print(f"Saved models and artifacts to: {models_dir}")


 'Mucomyst mg/hr__mean' 'NIMBEX MG/KG/HR__mean'
 'Pantoprazole   mg/hr__mean' 'Protonix       mg/hr__mean'
 'Solumedrol  mg/kg/hr__mean' 'TPA MG/HR__mean' 'approtinin cc/hr__mean'
 'left radial MAP__mean' 'nicardipine mg/hr__mean' 'APOTININ CC/HR__min'
 'Amicar cc/hr__min' 'FEM ART MAP__min' 'Mucomyst mg/hr__min'
 'NIMBEX MG/KG/HR__min' 'Pantoprazole   mg/hr__min'
 'Protonix       mg/hr__min' 'Solumedrol  mg/kg/hr__min' 'TPA MG/HR__min'
 'approtinin cc/hr__min' 'left radial MAP__min' 'nicardipine mg/hr__min'
 'APOTININ CC/HR__max' 'Amicar cc/hr__max' 'FEM ART MAP__max'
 'Mucomyst mg/hr__max' 'NIMBEX MG/KG/HR__max' 'Pantoprazole   mg/hr__max'
 'Protonix       mg/hr__max' 'Solumedrol  mg/kg/hr__max' 'TPA MG/HR__max'
 'approtinin cc/hr__max' 'left radial MAP__max' 'nicardipine mg/hr__max'
 'APOTININ CC/HR__last' 'Amicar cc/hr__last' 'FEM ART MAP__last'
 'Mucomyst mg/hr__last' 'NIMBEX MG/KG/HR__last'
 'Pantoprazole   mg/hr__last' 'Protonix       mg/hr__last'
 'Solumedrol  mg/kg/hr__last' '


Test metrics:
mortality {'roc_auc': 0.8215776353344071, 'pr_auc': 0.379138522657292, 'brier': 0.08379785195207613, 'positives_test': 622, 'n_test': 5695}
prolonged_los {'roc_auc': 0.7291081045953189, 'pr_auc': 0.7160317495940125, 'brier': 0.24947432700972497, 'positives_test': 2967, 'n_test': 5695}
readmission {'roc_auc': 0.5566588510947464, 'pr_auc': 0.04830008962497373, 'brier': 0.03836180003795895, 'positives_test': 226, 'n_test': 5695}
Saved models and artifacts to: C:\Users\Almog Luz\Documents\GitHub\mlhc-final-project\project\models


