# 03 - Length Of Stay (Regression)


In [1]:
import os, re, json, time, warnings, pickle
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")
np.set_printoptions(suppress=True)
pd.set_option("display.max_columns", 120)

In [2]:
# ======================
# Config / Paths
# ======================
SEED = 42
rng = np.random.default_rng(SEED)
DATA_DIR = r"D:/HealthAI Project/data"  # <- change if needed
ART = Path("./Models/los_artifacts"); ART.mkdir(parents=True, exist_ok=True)
MODELS_DIR = Path("./Models/LOS_Model"); MODELS_DIR.mkdir(parents=True, exist_ok=True)

# ======================
# Helpers
# ======================
def find_one_any(base_dir, basenames):
    want = {b.lower() for b in basenames}
    for root, _, files in os.walk(base_dir):
        for f in files:
            if f.lower() in want:
                return os.path.join(root, f)
    return None

def read_csv_auto(path):
    if path is None: return None
    comp = "gzip" if str(path).lower().endswith(".gz") else None
    for enc in ["utf-8","ISO-8859-1","cp1252","latin1"]:
        try:
            return pd.read_csv(path, compression=comp, low_memory=False, encoding=enc)
        except Exception:
            pass
    return pd.read_csv(path, compression=comp, low_memory=False)

def print_header(title):
    print("\n" + "="*len(title)); print(title); print("="*len(title))

def to_int_id(series):
    """Coerce an ID column to pandas nullable Int64 (handles strings/NaN safely)."""
    return pd.to_numeric(series, errors="coerce").astype("Int64")

In [None]:
# ======================
# Locate CSVs
# ======================
admissions_p = find_one_any(DATA_DIR, ["admissions.csv"])
patients_p   = find_one_any(DATA_DIR, ["patients.csv"])
diagn_p      = find_one_any(DATA_DIR, ["diagnoses_icd.csv", "diagnoses.csv"])
vitals_p     = find_one_any(DATA_DIR, ["vitalsign.csv", "vitalsigns.csv"])

print_header("Files detected")
print({
    "admissions": admissions_p,
    "patients": patients_p,
    "diagnoses_icd": diagn_p,
    "vitalsign": vitals_p
})

if admissions_p is None:
    raise FileNotFoundError("admissions.csv not found under DATA_DIR — required for LOS labels.")

# ======================
# Load & normalize
# ======================
adm = read_csv_auto(admissions_p)
pat = read_csv_auto(patients_p) if patients_p else None
diag = read_csv_auto(diagn_p) if diagn_p else None
vit = read_csv_auto(vitals_p) if vitals_p else None

for df in [adm, pat, diag, vit]:
    if df is not None:
        df.columns = [c.lower() for c in df.columns]
        if "hadm_id" in df.columns: df["hadm_id"] = to_int_id(df["hadm_id"])
        if "subject_id" in df.columns: df["subject_id"] = to_int_id(df["subject_id"])

# LOS label in days (+ unit guard)
for c in ("admittime","dischtime","deathtime"):
    if c in adm.columns:
        adm[c] = pd.to_datetime(adm[c], errors="coerce")

if "admittime" not in adm.columns or "dischtime" not in adm.columns:
    raise RuntimeError("Expected admittime / dischtime in admissions.csv")

los_raw_days = (adm["dischtime"] - adm["admittime"]).dt.total_seconds()/86400.0
los_raw_days = pd.to_numeric(los_raw_days, errors="coerce")

# Auto unit guard: if median looks like hours (>> 60), convert to days
med = np.nanmedian(los_raw_days)
if med is not None and np.isfinite(med) and med > 60:
    los_raw_days = los_raw_days / 24.0

adm["los_days"] = los_raw_days
adm = adm.dropna(subset=["los_days"])
adm = adm[adm["los_days"] >= 0.0]

# Robust cap (handles extreme outliers)
cap_q = 0.95
clip_max = float(adm["los_days"].quantile(cap_q))
adm["los_days"] = adm["los_days"].clip(0, clip_max)

print_header("LOS label stats (days, after cap)")
print({
    "n": int(adm["los_days"].shape[0]),
    "mean": float(adm["los_days"].mean()),
    "std": float(adm["los_days"].std(ddof=0)),
    "median": float(adm["los_days"].median()),
    f"p{int(cap_q*100)}": clip_max,
    "max": float(adm["los_days"].max())
})

# Age/sex
def compute_age_at_admit(adm_df, pat_df):
    if pat_df is None or "subject_id" not in adm_df.columns:
        return pd.Series(np.nan, index=adm_df.index)
    pm = pat_df.set_index("subject_id")
    y = pd.to_datetime(adm_df["admittime"], errors="coerce").dt.year
    age = pd.Series(np.nan, index=adm_df.index, dtype=float)
    if {"anchor_year","anchor_age"}.issubset(pm.columns):
        ay = pd.to_numeric(adm_df["subject_id"].map(pm["anchor_year"]), errors="coerce")
        aa = pd.to_numeric(adm_df["subject_id"].map(pm["anchor_age"]),  errors="coerce")
        age = (aa + (y - ay)).astype(float)
    if "dob" in pm.columns:
        dob = pd.to_datetime(adm_df["subject_id"].map(pm["dob"]), errors="coerce")
        age_from_dob = (adm_df["admittime"] - dob).dt.total_seconds()/(365.25*24*3600)
        age = age.where(age.notna(), age_from_dob)
    return age.clip(lower=0, upper=120)

adm["age_at_admit"] = compute_age_at_admit(adm, pat)
if pat is not None:
    if "gender" in pat.columns:
        adm["sex"] = adm["subject_id"].map(pat.set_index("subject_id")["gender"])
    elif "sex" in pat.columns:
        adm["sex"] = adm["subject_id"].map(pat.set_index("subject_id")["sex"])
else:
    adm["sex"] = np.nan

# Simple comorb from diagnoses (optional)
def icd_to_flags(code, version):
    if not isinstance(code, str): code = str(code)
    c = code.strip().upper().replace(".","")
    out = {}
    def add(k): out[k]=1
    if pd.isna(version): version = 10 if re.match(r"^[A-Z]", c) else 9
    if version == 9 or (c[:1].isdigit() and len(c)>=3):
        if c.startswith("250"): add("comor_diabetes")
        try:
            p3 = float(c[:3])
            if 401 <= p3 < 406: add("comor_htn")
            if 410 <= p3 < 415: add("comor_cad")
        except Exception: pass
        if c.startswith("428"): add("comor_hf")
        if c.startswith("585"): add("comor_ckd")
        if c.startswith("493"): add("comor_asthma")
        if c.startswith("272"): add("comor_lipids")
        if c.startswith("278"): add("comor_obesity")
    else:
        if c.startswith(("E10","E11","E12","E13","E14")): add("comor_diabetes")
        if c.startswith(("I10","I11","I12","I13","I15","I16")): add("comor_htn")
        if c.startswith(("I20","I21","I22","I23","I24","I25")): add("comor_cad")
        if c.startswith("I50"): add("comor_hf")
        if c.startswith("N18"): add("comor_ckd")
        if c.startswith("J45"): add("comor_asthma")
        if c.startswith("E78"): add("comor_lipids")
        if c.startswith("E66"): add("comor_obesity")
    return out

comorb = None
if diag is not None and {"hadm_id","icd_code"}.issubset(diag.columns):
    if "icd_version" in diag.columns:
        diag["icd_version"] = pd.to_numeric(diag["icd_version"], errors="coerce")
    else:
        diag["icd_version"] = np.nan
    rows = []
    for hadm, g in diag.groupby("hadm_id", dropna=True):
        flags = {}
        for _, r in g.iterrows():
            flags.update(icd_to_flags(r["icd_code"], r.get("icd_version", np.nan)))
        if flags:
            flags["hadm_id"] = hadm
            flags["comor_count"] = sum(v for k,v in flags.items() if k.startswith("comor_"))
            rows.append(flags)
    if rows:
        comorb = pd.DataFrame(rows).fillna(0)
        comorb["hadm_id"] = to_int_id(comorb["hadm_id"])
        for c in [c for c in comorb.columns if c.startswith("comor_")]:
            comorb[c] = pd.to_numeric(comorb[c], errors="coerce").fillna(0).astype(int)

# ======================
# Assemble tabular dataset + features
# ======================
keep_cols = ["subject_id","hadm_id","admittime","los_days","age_at_admit","sex",
             "admission_type","admission_location","discharge_location","insurance","language"]
for c in keep_cols:
    if c not in adm.columns: adm[c] = np.nan
adm["subject_id"] = to_int_id(adm["subject_id"])
adm["hadm_id"]    = to_int_id(adm["hadm_id"])

base = adm[keep_cols].copy()

# Time features from admittime
base["dow"]       = pd.to_datetime(base["admittime"], errors="coerce").dt.dayofweek
base["month"]     = pd.to_datetime(base["admittime"], errors="coerce").dt.month
base["hour"]      = pd.to_datetime(base["admittime"], errors="coerce").dt.hour
base["is_weekend"]= base["dow"].isin([5,6]).astype(float)

# Emergency flag
def is_emerg(x):
    x = str(x).lower()
    return int(("emerg" in x) or (x in {"er","ed"}))
base["is_emergency"] = base["admission_type"].apply(is_emerg) if "admission_type" in base.columns else 0

if comorb is not None:
    comorb["hadm_id"] = to_int_id(comorb["hadm_id"])
    base = base.merge(comorb[["hadm_id","comor_count"]], on="hadm_id", how="left")
else:
    base["comor_count"] = 0

base["comor_count"] = pd.to_numeric(base["comor_count"], errors="coerce").fillna(0.0).astype(float)

# Categorical cleanup (preserve NaN -> imputer)
for c in ["admission_type","admission_location","discharge_location","insurance","language","sex"]:
    if c in base.columns:
        base[c] = base[c].astype("object")
        base[c] = base[c].where(~base[c].isna(), other=np.nan)

# Remove rows without labels/IDs
base = base.dropna(subset=["los_days","subject_id","hadm_id"]).reset_index(drop=True)

print_header("Dataset snapshot")
print(base.head(3))

# ======================
# Split by subject_id
# ======================
subjects = base["subject_id"].dropna().unique()
rng.shuffle(subjects)
n = len(subjects)
tr_ids = set(subjects[:int(0.7*n)])
va_ids = set(subjects[int(0.7*n):int(0.85*n)])
te_ids = set(subjects[int(0.85*n):])

def mask_ids(ids): return base["subject_id"].isin(ids)
train_df = base[mask_ids(tr_ids)].copy()
val_df   = base[mask_ids(va_ids)].copy()
test_df  = base[mask_ids(te_ids)].copy()

print_header("Split sizes")
print({k: len(v) for k,v in {"train":train_df, "val":val_df, "test":test_df}.items()})

In [None]:
# ======================
# Pipelines with IMPUTERS (fixes NaNs)
# ======================
from sklearn.compose import ColumnTransformer, TransformedTargetRegressor
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import Ridge, HuberRegressor, LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import joblib
import inspect

# --- version-safe RMSE helper (always returns RMSE) ---
def rmse_score(y_true, y_pred):
    y_true = np.asarray(y_true, float); y_pred = np.asarray(y_pred, float)
    try:
        if 'squared' in inspect.signature(mean_squared_error).parameters:
            return float(mean_squared_error(y_true, y_pred))
    except Exception:
        pass
    return float(np.sqrt(mean_squared_error(y_true, y_pred)))

NUM = ["age_at_admit","comor_count","dow","month","hour","is_weekend","is_emergency"]
CAT = ["sex","admission_type","admission_location","discharge_location","insurance","language"]

num_pipe = Pipeline([
    ("imputer", SimpleImputer(strategy="median")),
    ("scaler", StandardScaler())
])
cat_pipe = Pipeline([
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("ohe", OneHotEncoder(handle_unknown="ignore", min_frequency=20))
])

pre = ColumnTransformer([
    ("num", num_pipe, NUM),
    ("cat", cat_pipe, CAT)
], remainder="drop")

def baseline_report(df, train_mean):
    y = df["los_days"].values.astype(float)
    p = np.full_like(y, float(train_mean), dtype=float)
    rmse = float(np.sqrt(np.mean((p-y)**2)))
    mae  = float(np.mean(np.abs(p-y)))
    r2   = float(1 - np.sum((y-p)**2)/(np.var(y)*len(y)+1e-9))
    return {"RMSE": rmse, "MAE": mae, "R2": r2}

def sanity_rmse_r2(y, p):
    y = np.asarray(y, float); p = np.asarray(p, float)
    rmse = float(np.sqrt(np.mean((p-y)**2)))
    r2   = float(r2_score(y, p))
    rmse_base = float(np.sqrt(np.mean((y - y.mean())**2)))
    rmse_expected = float(((1 - r2) ** 0.5) * rmse_base)
    return {"rmse": rmse, "r2": r2, "rmse_baseline": rmse_base,
            "rmse_expected_from_r2": rmse_expected,
            "ratio_rmse_over_expected": (rmse / (rmse_expected + 1e-12))}

class CalibratedPipeline:
    """Wrap a Pipeline and learn a linear post-hoc calibration y ≈ a*pred+b on train.
       NOTE: We DO NOT pickle this class. We save only a pure sklearn pipeline later.
    """
    def __init__(self, pipe):
        self.pipe = pipe
        self.cal = LinearRegression()
        self.fitted = False
    def fit(self, X, y):
        self.pipe.fit(X, y)
        p = self.pipe.predict(X).reshape(-1, 1)
        self.cal.fit(p, y)
        self.fitted = True
        return self
    def predict(self, X):
        p = self.pipe.predict(X).reshape(-1, 1)
        return self.cal.predict(p)

def fit_and_eval(name, model, train_df, val_df, test_df, calibrate=True):
    base_pipe = Pipeline([("pre", pre), ("mdl", model)])
    pipe = CalibratedPipeline(base_pipe) if calibrate else base_pipe

    FEATS = NUM + CAT
    Xtr, ytr = train_df[FEATS], train_df["los_days"].astype(float)
    Xva, yva = val_df[FEATS],   val_df["los_days"].astype(float)
    Xte, yte = test_df[FEATS],  test_df["los_days"].astype(float)

    for y_name, yv in [("train", ytr), ("val", yva), ("test", yte)]:
        if not np.all(np.isfinite(yv)):
            raise ValueError(f"Non-finite y in {y_name} set.")

    t0 = time.time()
    pipe.fit(Xtr, ytr)
    tr_time = time.time()-t0

    def report(X, y):
        p = pipe.predict(X)
        rmse = rmse_score(y, p)
        mae  = float(mean_absolute_error(y, p))
        r2   = float(r2_score(y, p))
        san  = sanity_rmse_r2(y, p)
        return {"RMSE": rmse, "MAE": mae, "R2": r2, "sanity": san}

    rep = {"train": report(Xtr, ytr), "val": report(Xva, yva), "test": report(Xte, yte)}
    print_header(f"{name} metrics")
    print("Train :", {k:v for k,v in rep["train"].items() if k!='sanity'})
    print("Val   :", {k:v for k,v in rep["val"].items() if k!='sanity'})
    print("Test  :", {k:v for k,v in rep["test"].items() if k!='sanity'})

    rat = rep["val"]["sanity"]["ratio_rmse_over_expected"]
    if rat > 2.0:
        print(f"[WARN] RMSE appears inflated vs R² on val (ratio={rat:.2f}). "
              f"Units or a few extreme outliers may still be present.")

    return pipe, rep, tr_time

In [8]:
# ======================
# Baseline
# ======================
baseline_mean = float(train_df["los_days"].mean())
print_header("Baseline (train-mean) metrics")
print("Val :", baseline_report(val_df, baseline_mean))
print("Test:", baseline_report(test_df, baseline_mean))

# ======================
# Models
# ======================
ridge, ridge_rep, _ = fit_and_eval("Ridge(alpha=3.0)", Ridge(alpha=3.0), train_df, val_df, test_df)

# ---- FIXED: log-target wraps only the regressor; preprocessing stays outside ----
ridge_log = TransformedTargetRegressor(
    regressor=Ridge(alpha=3.0), func=np.log1p, inverse_func=np.expm1
)
ridge_log_m, ridge_log_rep, _ = fit_and_eval("Ridge (log-target)", ridge_log, train_df, val_df, test_df)

huber, huber_rep, _ = fit_and_eval("HuberRegressor", HuberRegressor(alpha=1e-4), train_df, val_df, test_df)

rf, rf_rep, _ = fit_and_eval(
    "RandomForest",
    RandomForestRegressor(n_estimators=400, max_depth=None, min_samples_leaf=3, n_jobs=-1, random_state=SEED),
    train_df, val_df, test_df
)

# ======================
# Select best by Val RMSE
# ======================
candidates = [
    ("ridge", ridge, ridge_rep),
    ("ridge_log", ridge_log_m, ridge_log_rep),
    ("huber", huber, huber_rep),
    ("rf", rf, rf_rep)
]
best_name, best_model, best_rep = min(candidates, key=lambda t: t[2]["val"]["RMSE"])

print_header(f"Best (by Val RMSE): {best_name}")
print("Test metrics:", {k:v for k,v in best_rep["test"].items() if k!='sanity'})


Baseline (train-mean) metrics
Val : {'RMSE': 4.0194242709547305, 'MAE': 3.057494173207971, 'R2': -4.8989754345862835e-05}
Test: {'RMSE': 4.000061408635235, 'MAE': 3.046565622906377, 'R2': -7.301197698850181e-05}

Ridge(alpha=3.0) metrics
Train : {'RMSE': 3.1179297281505662, 'MAE': 2.1838660740025153, 'R2': 0.4035272912651202}
Val   : {'RMSE': 3.0895496166451735, 'MAE': 2.161242184535112, 'R2': 0.4091409083001437}
Test  : {'RMSE': 3.079397040220502, 'MAE': 2.1635665152188377, 'R2': 0.4073070433112076}

Ridge (log-target) metrics
Train : {'RMSE': 3.1194015342574053, 'MAE': 2.15671368455773, 'R2': 0.402964033299378}
Val   : {'RMSE': 3.087117153027243, 'MAE': 2.1310814022333187, 'R2': 0.41007093217877255}
Test  : {'RMSE': 3.08012016753548, 'MAE': 2.1364571612381256, 'R2': 0.40702864933768534}

HuberRegressor metrics
Train : {'RMSE': 3.1438072283896976, 'MAE': 2.1748523917227636, 'R2': 0.393585261181262}
Val   : {'RMSE': 3.113227610997899, 'MAE': 2.151950950957782, 'R2': 0.4000496359005095

In [9]:
# ======================
# SAVE: production-ready inference pipeline (pure sklearn, picklable)
# ======================
# We save ONLY a plain sklearn Pipeline: ('pre', ColumnTransformer) + ('model', final estimator).
# Do NOT save the custom CalibratedPipeline wrapper.
from sklearn.base import RegressorMixin

FEATS = NUM + CAT
X_train_full = train_df[FEATS].copy()
y_train_full = train_df["los_days"].astype(float).values

# Build clean inference pipeline
inference_pipe = Pipeline([
    ("pre", pre),
    ("model", best_model.pipe.steps[-1][1] if isinstance(best_model, CalibratedPipeline) else best_model.steps[-1][1])
    if isinstance(best_model, (CalibratedPipeline, Pipeline)) else
    ("model", best_model)  # fallback if already a regressor
])

# If the logic above confuses, simplify:
if isinstance(best_model, CalibratedPipeline):
    # take the underlying sklearn Pipeline and extract the final regressor
    final_reg = best_model.pipe.steps[-1][1]
    inference_pipe = Pipeline([("pre", pre), ("model", final_reg)])
elif isinstance(best_model, Pipeline):
    final_reg = best_model.steps[-1][1]
    inference_pipe = Pipeline([("pre", pre), ("model", final_reg)])
else:
    # plain estimator already
    inference_pipe = Pipeline([("pre", pre), ("model", best_model)])

# Fit on full training data
inference_pipe.fit(X_train_full, y_train_full)

# Persist model + feature list
MODEL_PATH = MODELS_DIR / "los_inference_pipeline.pkl"
joblib.dump(inference_pipe, MODEL_PATH)
with open(MODELS_DIR / "los_features.json", "w", encoding="utf-8") as f:
    json.dump({"features": FEATS}, f)

# Smoke test reload -> predict
reloaded = joblib.load(MODEL_PATH)
_ = reloaded.predict(X_train_full.head(5))
print_header("Saved inference pipeline")
print({"path": str(MODEL_PATH.resolve()), "features": FEATS})


Saved inference pipeline
{'path': 'D:\\HealthAI Project\\Models\\los_inference_pipeline.pkl', 'features': ['age_at_admit', 'comor_count', 'dow', 'month', 'hour', 'is_weekend', 'is_emergency', 'sex', 'admission_type', 'admission_location', 'discharge_location', 'insurance', 'language']}


In [10]:
# ======================
# (Optional) Save individual training-time wrappers to artifacts (debug only)
# ======================
with open(ART/"ridge_model.pkl","wb")      as f: pickle.dump(ridge, f)
with open(ART/"ridge_log_model.pkl","wb")  as f: pickle.dump(ridge_log_m, f)
with open(ART/"huber_model.pkl","wb")      as f: pickle.dump(huber, f)
with open(ART/"rf_model.pkl","wb")         as f: pickle.dump(rf, f)
with open(ART/"best_model_name.txt","w",encoding="utf-8") as f: f.write(best_name)
with open(ART/"best_model.pkl","wb")       as f: pickle.dump(best_model, f)

In [11]:
# ======================
# Consolidated report (+ sanity)
# ======================
def pack(rep):
    return {k: {mk: mv for mk, mv in v.items() if mk != "sanity"} for k, v in rep.items()}
report = {
    "label_cap_percentile": 0.95,
    "label_stats_days": {
        "mean": float(adm["los_days"].mean()),
        "std": float(adm["los_days"].std(ddof=0)),
        "median": float(adm["los_days"].median()),
        "p95": float(adm["los_days"].quantile(0.95)),
        "max": float(adm["los_days"].max())
    },
    "baseline": {
        "val": baseline_report(val_df, baseline_mean),
        "test": baseline_report(test_df, baseline_mean)
    },
    "ridge": pack(ridge_rep),
    "ridge_log": pack(ridge_log_rep),
    "huber": pack(huber_rep),
    "rf": pack(rf_rep),
    "best": {"name": best_name, "test": {k:v for k,v in best_rep["test"].items() if k!='sanity'}}
}

with open(ART/"report.json","w",encoding="utf-8") as f: json.dump(report, f, indent=2)

print_header("Artifacts saved")
print({
    "ridge_model.pkl": str((ART/"ridge_model.pkl").resolve()),
    "ridge_log_model.pkl": str((ART/"ridge_log_model.pkl").resolve()),
    "huber_model.pkl": str((ART/"huber_model.pkl").resolve()),
    "rf_model.pkl": str((ART/"rf_model.pkl").resolve()),
    "best_model.pkl": str((ART/"best_model.pkl").resolve()),
    "best_model_name.txt": str((ART/"best_model_name.txt").resolve()),
    "report.json": str((ART/"report.json").resolve()),
    "inference_pipeline.pkl": str(MODELS_DIR.joinpath("los_inference_pipeline.pkl").resolve()),
    "los_features.json": str(MODELS_DIR.joinpath("los_features.json").resolve())
})


Artifacts saved
{'ridge_model.pkl': 'D:\\HealthAI Project\\los_artifacts\\ridge_model.pkl', 'ridge_log_model.pkl': 'D:\\HealthAI Project\\los_artifacts\\ridge_log_model.pkl', 'huber_model.pkl': 'D:\\HealthAI Project\\los_artifacts\\huber_model.pkl', 'rf_model.pkl': 'D:\\HealthAI Project\\los_artifacts\\rf_model.pkl', 'best_model.pkl': 'D:\\HealthAI Project\\los_artifacts\\best_model.pkl', 'best_model_name.txt': 'D:\\HealthAI Project\\los_artifacts\\best_model_name.txt', 'report.json': 'D:\\HealthAI Project\\los_artifacts\\report.json', 'inference_pipeline.pkl': 'D:\\HealthAI Project\\Models\\los_inference_pipeline.pkl', 'los_features.json': 'D:\\HealthAI Project\\Models\\los_features.json'}


In [None]:
# ======================
# Merge all CSVs → single table for prediction
# ======================
print_header("Merging all sources into one prediction table")

# --- 1) Start from base (already includes LOS label, age/sex, time-features, emergency flag, comorb_count) ---
merged = base.copy()

# --- 2) OPTIONAL: Add vitals 24h summaries per admission if vitals are available ---
def summarize_vitals(vit_df: pd.DataFrame, adm_df: pd.DataFrame) -> pd.DataFrame:
    """
    Returns per-hadm_id summary of vitals.
    If 'charttime' exists, restrict to first 24h since admission.
    Otherwise, aggregates across all rows available for that hadm_id.
    """
    v = vit_df.copy()
    v.columns = [c.lower() for c in v.columns]
    # normalize ids
    if "hadm_id" not in v.columns:
        return pd.DataFrame({"hadm_id": pd.Series(dtype="Int64")})  # nothing to merge

    v["hadm_id"] = to_int_id(v["hadm_id"])

    # Try to standardize a timestamp column
    ts_col = None
    for cand in ("charttime", "chart_time", "measured_time", "event_time"):
        if cand in v.columns:
            ts_col = cand
            break
    if ts_col is not None:
        v[ts_col] = pd.to_datetime(v[ts_col], errors="coerce")

    # Common vital columns if present
    maybe_vital_cols = [
        # heart/bp
        "heartrate","heart_rate","hr",
        "sbp","sysbp","systolic_bp","bp_systolic",
        "dbp","diasbp","diastolic_bp","bp_diastolic",
        "mbp","meanbp","map",
        # resp/temp/spo2
        "resp_rate","respiratory_rate","rr",
        "temperature","temp_c","temp_f",
        "spo2","o2sat","oxygen_saturation",
    ]
    present = [c for c in maybe_vital_cols if c in v.columns]

    if not present:
        # nothing to aggregate
        return pd.DataFrame({"hadm_id": v["hadm_id"].dropna().unique()})

    # Keep rows within first 24h of admission if we can align by time
    if ts_col is not None and {"hadm_id", "admittime"}.issubset(adm_df.columns):
        adm_times = adm_df[["hadm_id", "admittime"]].dropna().copy()
        adm_times["hadm_id"] = to_int_id(adm_times["hadm_id"])
        v = v.merge(adm_times, on="hadm_id", how="left")
        # Keep only obs in [admittime, admittime+24h]
        v = v[(v[ts_col].notna()) & (v["admittime"].notna())]
        dt = (v[ts_col] - v["admittime"]).dt.total_seconds() / 3600.0
        v = v[(dt >= 0) & (dt <= 24)]

    # Aggregate (mean/median) by hadm_id
    agg_map = {c: ["mean", "median"] for c in present}
    g = v.groupby("hadm_id")[present].agg(agg_map)
    # Flatten MultiIndex columns: e.g., 'heartrate_mean', 'heartrate_median'
    g.columns = [f"{col}_{stat}" for col, stat in g.columns]
    g = g.reset_index()
    return g

vit_agg = pd.DataFrame({"hadm_id": pd.Series(dtype="Int64")})
if vit is not None:
    try:
        vit_agg = summarize_vitals(vit, adm)
        merged = merged.merge(vit_agg, on="hadm_id", how="left")
        print(f"Merged vitals summary: {vit_agg.shape}")
    except Exception as e:
        print(f"[WARN] Could not summarize/merge vitals: {e}")

# --- 3) (Already done above) comorbidities merged as comor_count in `base` ---

# --- 4) Order and type-clean for prediction ---
# Keep ID/time columns for traceability, then features your model expects
ID_TIME_COLS = ["subject_id", "hadm_id", "admittime", "dischtime"]
for c in ID_TIME_COLS:
    if c not in merged.columns:
        merged[c] = np.nan

# Model features (must match your training pipeline)
FEATS = ["age_at_admit","comor_count","dow","month","hour","is_weekend","is_emergency",
         "sex","admission_type","admission_location","discharge_location","insurance","language"]

# Ensure all FEATS exist (add NaN placeholders if missing so imputers can handle them)
for c in FEATS:
    if c not in merged.columns:
        merged[c] = np.nan

# Final column order: IDs/times → FEATS → (any vitals summaries if you want to keep them)
ordered_cols = ID_TIME_COLS + FEATS + [c for c in merged.columns
                                       if c not in ID_TIME_COLS + FEATS + ["los_days"]]
merged = merged.loc[:, ordered_cols + (["los_days"] if "los_days" in merged.columns else [])]

# De-duplicate per admission (keep the first row if merges created multiples)
merged = merged.sort_values(by=["subject_id","hadm_id","admittime"], na_position="last")
merged = merged.drop_duplicates(subset=["hadm_id"], keep="first").reset_index(drop=True)

print_header("Merged prediction table snapshot")
print(merged.head(5))

# Save a clean CSV you can use directly for inference (the model will use FEATS; extra cols are for reference)
MERGED_CSV = MODELS_DIR / "los_test_data.csv"
merged.to_csv(MERGED_CSV, index=False)
print(f"[OK] Wrote merged table for prediction → {MERGED_CSV.resolve()}")


Merging all sources into one prediction table
Merged vitals summary: (0, 1)

Merged prediction table snapshot
   subject_id   hadm_id           admittime  dischtime  age_at_admit  \
0    10000032  22595853 2180-05-06 22:23:00        NaN          52.0   
1    10000032  22841357 2180-06-26 18:27:00        NaN          52.0   
2    10000032  25742920 2180-08-05 23:44:00        NaN          52.0   
3    10000032  29079034 2180-07-23 12:35:00        NaN          52.0   
4    10000068  25022803 2160-03-03 23:16:00        NaN          19.0   

   comor_count  dow  month  hour  is_weekend  is_emergency sex  \
0          0.0    5      5    22         1.0             0   F   
1          0.0    0      6    18         0.0             0   F   
2          0.0    5      8    23         1.0             0   F   
3          0.0    6      7    12         1.0             0   F   
4          0.0    0      3    23         0.0             0   F   

   admission_type      admission_location discharge_locatio