# 1. Introduction
# ML for Healthcare with Weights & Biases: Interpretable Baselines and Clinical Evaluation

We continue our focus on shifting from coding models to *understanding* them — how well they perform, how trustworthy their probabilities are, and how they behave across patient groups. Strenghtening the foundation for clinically aware ML practice.

**Objective**  
Build, track, and interpret models predicting in-hospital mortality in ICU patients using Weights & Biases (W&B).  
This notebook transforms standard Machine Learning (ML) practice into an auditable, interpretable, and clinically meaningful workflow

**You will learn**
- How to track model training and evaluation runs using Weights & Biases  
- How to interpret models and understand their calibration  
- How to examine fairness and subgroup performance  
- How to communicate model insights for healthcare decisions

**Models**
We’ll compare three interpretable baselines:
1. **Logistic Regression**: simple linear reference, easy to explain  
2. **Decision Tree (shallow)**: intuitive splits, visually transparent  
3. **Random Forest**: robust ensemble, main focus for tuning and interpretability

**Dataset**
[PhysioNet Challenge 2012 dataset](https://physionet.org/content/challenge-2012/1.0.0/), containing clinical measurements, demographics, and the target:  
`In-hospital_death` (binary: 1 = patient died during stay, 0 = survived).

**Weights & Biases**
Used to:
- Track configurations, metrics, and plots  
- Compare models and hyperparameters  
- Log interpretability results (feature importance, calibration, subgroups)  
- Support model transparency and documentation



# 2. Setup

In [19]:
# Setup: imports, reproducibility, and Weights & Biases initialization

import os
import sys
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shap
# Initialize Weights & Biases
import wandb

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.metrics import (
    roc_auc_score, 
    average_precision_score, 
    brier_score_loss,
    roc_curve, 
    precision_recall_curve,
    confusion_matrix,
)



# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)


In [2]:
# !wandb login
# !wandb.login(relogin=True)

In [3]:
WB_PROJECT = "ml-healthcare-intro"

# wandb.login() # Uncomment if not logged in
run = wandb.init(
    project=WB_PROJECT, 
    job_type="setup", 
        config={
        "seed": SEED,
        "framework": "scikit-learn",
        "dataset": "physionet2012_set_a"
    }
)

# Log environment metadata
wandb.config.update({
    "python_version": sys.version.split()[0],
    "pandas_version": pd.__version__,
    "numpy_version": np.__version__,
}, allow_val_change=True)

print(f"Weights & Biases tracking URL: {run.url}")

wandb: Currently logged in as: idiazl to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Weights & Biases tracking URL: https://wandb.ai/idiazl/ml-healthcare-intro/runs/c3c37ya9


# 3. Data loading and initial checks
We will load the dataset, confirm the target, and log basic summaries to Weights & Biases. This lets us explore class balance, missingness, and a quick data preview directly in the dashboard.

- Use the Weights & Biases tables to examine class balance, missingness, and a data preview  
- Check that the target definition and event rate look reasonable before training



In [4]:
# --- 1. Load Data ---
PATH = "PhysionetChallenge2012-set-a.csv.gz"

# Simple check to ensure the data file exists before trying to load it
if not os.path.exists(PATH):
    raise FileNotFoundError(
        f"Error: The data file was not found at '{PATH}'. "
        "Please ensure the dataset is in the correct directory."
    )

ICU = pd.read_csv(PATH, compression="gzip")

TARGET = "In-hospital_death"
ID_COL = "recordid" if "recordid" in ICU.columns else None

if TARGET not in ICU.columns:
    raise ValueError(f"Target column '{TARGET}' not found in dataset")

# Ensure target is numeric and binary
ICU[TARGET] = pd.to_numeric(ICU[TARGET], errors="coerce").fillna(0).astype(int)

# Basic facts
n_rows, n_cols = ICU.shape
pos_rate = float(ICU[TARGET].mean())

# Class balance table
cb_series = ICU[TARGET].value_counts().sort_index()
class_balance_tbl = wandb.Table(data=[[int(k), int(v), float(v / n_rows)] for k, v in cb_series.items()],
                                columns=["label", "count", "fraction"])

# Missingness table (top 30)
miss = ICU.isna().mean().sort_values(ascending=False)
miss_top = miss.head(30).reset_index()
miss_top.columns = ["column", "missing_fraction"]
missing_tbl = wandb.Table(data=miss_top.values.tolist(), columns=list(miss_top.columns))

# Data preview table (sample up to 200 rows for UI responsiveness)
preview = ICU.sample(n=min(200, len(ICU)), random_state=SEED)
preview_tbl = wandb.Table(dataframe=preview)

wandb.log({
    "dataset_rows": n_rows,
    "dataset_cols": n_cols,
    "positive_rate": pos_rate,
    "class_balance_table": class_balance_tbl,
    "missingness_top30_table": missing_tbl,
    "data_preview_table": preview_tbl
})

print(f"Loaded ICU with shape {ICU.shape} and positive rate {pos_rate:.3f}")

Loaded ICU with shape (4000, 120) and positive rate 0.139


# 4. Simple preprocessing
### Preprocessing and splits
We will split the data into train, validation, and test sets, impute missing values and one-hot encode categorical variables, and log feature lists and split sizes to Weights & Biases for transparency


In [5]:
# Preprocessing: split data and prepare simple pipelines

# Drop ID column if present
X = ICU.drop(columns=[c for c in [TARGET, ID_COL] if c in ICU.columns])
y = ICU[TARGET]

# Identify categorical and numeric columns
cat_cols = [c for c in X.columns if X[c].dtype == "object"]
num_cols = [c for c in X.columns if c not in cat_cols]

# Split data (60 percent train, 20 percent validation, 20 percent test)
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=SEED
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_full, y_train_full, test_size=0.25, stratify=y_train_full, random_state=SEED
)

print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

# Define transformations
num_transformer = SimpleImputer(strategy="median")
cat_transformer = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False))
])

preprocessor = ColumnTransformer(
    transformers=[
        ("num", num_transformer, num_cols),
        ("cat", cat_transformer, cat_cols)
    ],
    verbose_feature_names_out=False
)

preprocessor.set_output(transform="pandas")

# Fit and transform
X_train_t = preprocessor.fit_transform(X_train)
X_val_t = preprocessor.transform(X_val)
X_test_t = preprocessor.transform(X_test)

# Log split sizes
wandb.log({
    "train_rows": len(X_train),
    "val_rows": len(X_val),
    "test_rows": len(X_test),
    "n_num_features_raw": len(num_cols),
    "n_cat_features_raw": len(cat_cols),
    "n_features_transformed": X_train_t.shape[1]
})

# Log feature lists as W&B Tables for inspectability
num_tbl = wandb.Table(data=[[c, "numeric"] for c in num_cols], columns=["feature", "type"])
cat_tbl = wandb.Table(data=[[c, "categorical"] for c in cat_cols], columns=["feature", "type"])
wandb.log({"feature_list_numeric": num_tbl, "feature_list_categorical": cat_tbl})

print("Preprocessing complete")


Train: (2400, 118), Val: (800, 118), Test: (800, 118)
Preprocessing complete


Use the feature tables and split sizes in Weights & Biases to verify preprocessing choices  
All models next will consume the same transformed matrices for fair comparisons


# 5. Logistic Regression: establishing a simple reference

Before exploring complex models, it’s helpful to start with a simple and interpretable baseline. Logistic Regression gives a linear relationship between features and the log-odds of the outcome. Helping us understand whether more flexible models (like Random Forests) truly add value.

1. We’ll train a Logistic Regression model, evaluate it on the validation and test sets
2. Log all metrics to Weights & Biases to compare later


In [6]:
# Logistic Regression baseline with ROC/PR logged as W&B Tables
run = wandb.init(
    project=WB_PROJECT,
    job_type="baseline",
    config={"model": "logistic_regression", "seed": SEED},
    reinit=True,
    settings=wandb.Settings(start_method="thread")
)

# Train
log_reg = LogisticRegression(max_iter=500, solver="liblinear", random_state=SEED)
log_reg.fit(X_train_t, y_train)

# Predict probabilities
y_val_prob = log_reg.predict_proba(X_val_t)[:, 1]
y_test_prob = log_reg.predict_proba(X_test_t)[:, 1]

# Metrics
val_auc = roc_auc_score(y_val, y_val_prob)
val_pr  = average_precision_score(y_val, y_val_prob)
val_brier = brier_score_loss(y_val, y_val_prob)

test_auc = roc_auc_score(y_test, y_test_prob)
test_pr  = average_precision_score(y_test, y_test_prob)
test_brier = brier_score_loss(y_test, y_test_prob)

wandb.log({
    "val_auc": val_auc,
    "val_pr": val_pr,
    "val_brier": val_brier,
    "test_auc": test_auc,
    "test_pr": test_pr,
    "test_brier": test_brier
})

# ROC and PR curve points as Tables for interactive plots in W&B
fpr, tpr, roc_thresh = roc_curve(y_val, y_val_prob)
roc_table = wandb.Table(data=list(zip(fpr, tpr, roc_thresh)), columns=["fpr", "tpr", "threshold"])
wandb.log({"roc_curve_val": roc_table})

prec, rec, pr_thresh = precision_recall_curve(y_val, y_val_prob)
# sklearn returns thresholds length one less than precision/recall, pad with None for table alignment
pr_table = wandb.Table(data=list(zip(rec, prec, list(pr_thresh) + [None])), columns=["recall", "precision", "threshold"])
wandb.log({"pr_curve_val": pr_table})

# Coefficients for transparency
coef_df = pd.DataFrame({"feature": X_train_t.columns, "coefficient": log_reg.coef_[0]})
coef_tbl = wandb.Table(dataframe=coef_df.sort_values("coefficient", ascending=False))
wandb.log({"log_reg_coefficients": coef_tbl})

# Predictions table sample for later slicing in the UI
pred_sample = pd.DataFrame({
    "id": X_val.index if ID_COL is None else X_val.index,  # keep index for traceability
    "y_true": y_val.values,
    "y_prob": y_val_prob
}).sample(n=min(500, len(y_val)), random_state=SEED)

wandb.log({"predictions_val_sample": wandb.Table(dataframe=pred_sample)})

run.finish()

print(f"LR validation AUROC {val_auc:.3f}, AUPRC {val_pr:.3f}, Brier {val_brier:.3f}")



0,1
dataset_cols,▁
dataset_rows,▁
n_cat_features_raw,▁
n_features_transformed,▁
n_num_features_raw,▁
positive_rate,▁
test_rows,▁
train_rows,▁
val_rows,▁

0,1
dataset_cols,120.0
dataset_rows,4000.0
n_cat_features_raw,0.0
n_features_transformed,118.0
n_num_features_raw,118.0
positive_rate,0.1385
test_rows,800.0
train_rows,2400.0
val_rows,800.0




0,1
test_auc,▁
test_brier,▁
test_pr,▁
val_auc,▁
val_brier,▁
val_pr,▁

0,1
test_auc,0.8622
test_brier,0.09375
test_pr,0.4755
val_auc,0.87606
val_brier,0.08569
val_pr,0.55215


LR validation AUROC 0.876, AUPRC 0.552, Brier 0.086


### To Do

Open the Weights & Biases project  
Use the ROC and PR tables to create interactive line plots for this run  
Inspect the coefficients table for a first look at feature effects


# 6. Decision Tree baseline

A shallow tree is easy to read and helps us see simple non-linear rules. We will train a small tree, log metrics, curve points, feature importances, and a predictions sample to Weights & Biases


In [7]:
# 6. Decision Tree baseline
run = wandb.init(
    project=WB_PROJECT,
    job_type="baseline",
    config={
        "model_type": "decision_tree",
        "seed": SEED,
        "max_depth": 4,
        "min_samples_leaf": 20
    },
    reinit=True,
    settings=wandb.Settings(start_method="thread")
)

# Train a small, readable tree
dt = DecisionTreeClassifier(
    max_depth=4,
    min_samples_leaf=20,
    random_state=SEED
)
dt.fit(X_train_t, y_train)

# Predict probabilities
y_val_prob = dt.predict_proba(X_val_t)[:, 1]
y_test_prob = dt.predict_proba(X_test_t)[:, 1]

# Metrics
val_auc = roc_auc_score(y_val, y_val_prob)                    # Area under ROC
val_pr  = average_precision_score(y_val, y_val_prob)          # Area under PR
val_brier = brier_score_loss(y_val, y_val_prob)               # Calibration error

test_auc = roc_auc_score(y_test, y_test_prob)
test_pr  = average_precision_score(y_test, y_test_prob)
test_brier = brier_score_loss(y_test, y_test_prob)

wandb.log({
    "val_auc": val_auc,
    "val_pr": val_pr,
    "val_brier": val_brier,
    "test_auc": test_auc,
    "test_pr": test_pr,
    "test_brier": test_brier
})

# ROC and PR curve points as Tables
fpr, tpr, roc_thresh = roc_curve(y_val, y_val_prob)
wandb.log({"roc_curve_val": wandb.Table(data=list(zip(fpr, tpr, roc_thresh)), columns=["fpr", "tpr", "threshold"])})

prec, rec, pr_thresh = precision_recall_curve(y_val, y_val_prob)
wandb.log({"pr_curve_val": wandb.Table(data=list(zip(rec, prec, list(pr_thresh) + [None])), columns=["recall", "precision", "threshold"])})

# Feature importances
imp_df = (
    pd.DataFrame({"feature": X_train_t.columns, "importance": dt.feature_importances_})
    .sort_values("importance", ascending=False)
)
wandb.log({"feature_importances": wandb.Table(dataframe=imp_df)})

# Predictions sample for slicing in the UI
pred_sample = pd.DataFrame({
    "id": X_val.index,
    "y_true": y_val.values,
    "y_prob": y_val_prob
}).sample(n=min(500, len(y_val)), random_state=SEED)
wandb.log({"predictions_val_sample": wandb.Table(dataframe=pred_sample)})

run.finish()

print(f"DT validation AUROC {val_auc:.3f}, AUPRC {val_pr:.3f}, Brier {val_brier:.3f}")


0,1
test_auc,▁
test_brier,▁
test_pr,▁
val_auc,▁
val_brier,▁
val_pr,▁

0,1
test_auc,0.99098
test_brier,0.01047
test_pr,0.94055
val_auc,0.98514
val_brier,0.01532
val_pr,0.9238


DT validation AUROC 0.985, AUPRC 0.924, Brier 0.015


# 7. Random Forest baseline
### Random Forest baseline

As we know, a Random Forest averages many trees to improve stability and performance. We will train a baseline model and log metrics, curve points, feature importances, and a predictions sample to W&B. Later we will tune hyperparameters with a short sweep

In [9]:
# 7. Random Forest baseline
run = wandb.init(
    project=WB_PROJECT,
    job_type="baseline",
    config={
        "model_type": "random_forest",
        "seed": SEED,
        "n_estimators": 300,
        "max_depth": None,
        "max_features": "sqrt",
        "min_samples_leaf": 5,
        "n_jobs": -1
    },
    reinit=True,
    settings=wandb.Settings(start_method="thread")
)

rf = RandomForestClassifier(
    n_estimators=300,
    max_depth=None,
    max_features="sqrt",
    min_samples_leaf=5,
    random_state=SEED,
    n_jobs=-1
)
rf.fit(X_train_t, y_train)

# Predict probabilities
y_val_prob = rf.predict_proba(X_val_t)[:, 1]
y_test_prob = rf.predict_proba(X_test_t)[:, 1]

# Metrics
val_auc = roc_auc_score(y_val, y_val_prob)                    # Area under ROC
val_pr  = average_precision_score(y_val, y_val_prob)          # Area under PR
val_brier = brier_score_loss(y_val, y_val_prob)               # Calibration error

test_auc = roc_auc_score(y_test, y_test_prob)
test_pr  = average_precision_score(y_test, y_test_prob)
test_brier = brier_score_loss(y_test, y_test_prob)

wandb.log({
    "val_auc": val_auc,
    "val_pr": val_pr,
    "val_brier": val_brier,
    "test_auc": test_auc,
    "test_pr": test_pr,
    "test_brier": test_brier
})

# ROC and PR curve points as Tables
fpr, tpr, roc_thresh = roc_curve(y_val, y_val_prob)
wandb.log({"roc_curve_val": wandb.Table(data=list(zip(fpr, tpr, roc_thresh)), columns=["fpr", "tpr", "threshold"])})

prec, rec, pr_thresh = precision_recall_curve(y_val, y_val_prob)
wandb.log({"pr_curve_val": wandb.Table(data=list(zip(rec, prec, list(pr_thresh) + [None])), columns=["recall", "precision", "threshold"])})

# Feature importances
imp_df = (
    pd.DataFrame({"feature": X_train_t.columns, "importance": rf.feature_importances_})
    .sort_values("importance", ascending=False)
)
wandb.log({"feature_importances": wandb.Table(dataframe=imp_df)})

# Predictions sample for slicing in the UI
pred_sample = pd.DataFrame({
    "id": X_val.index,
    "y_true": y_val.values,
    "y_prob": y_val_prob
}).sample(n=min(500, len(y_val)), random_state=SEED)
wandb.log({"predictions_val_sample": wandb.Table(dataframe=pred_sample)})

run.finish()

print(f"RF validation AUROC {val_auc:.3f}, AUPRC {val_pr:.3f}, Brier {val_brier:.3f}")


0,1
test_auc,▁
test_brier,▁
test_pr,▁
val_auc,▁
val_brier,▁
val_pr,▁

0,1
test_auc,0.98419
test_brier,0.05763
test_pr,0.89292
val_auc,0.98135
val_brier,0.05921
val_pr,0.87774


RF validation AUROC 0.981, AUPRC 0.878, Brier 0.059


#### TO DO
- Use the Weights & Biases compare view to contrast Logistic Regression, Decision Tree, and Random Forest  
- Check whether the Random Forest improves area under ROC, area under PR, and calibration score  
- Review feature importances to see if top drivers align with clinical sense


### 7.1. Calibration focus

In clinical use, a calibrated model lets you set thresholds that align with safety targets. Let's log binned reliability tables and summary scores to Weights & Biases for Validation and Test


In [12]:
# Calibration for the current Random Forest baseline
# Logs reliability tables and summary scores to Weights & Biases

run = wandb.init(
    project=WB_PROJECT,
    job_type="calibration",
    config={
        "model_type": "random_forest",
        "seed": SEED,
        "calibration_bins": 10,
        "calibration_splits": ["val", "test"]
    },
    reinit=True,
    settings=wandb.Settings(start_method="thread")
)


# Helper to build a reliability table using equal-frequency bins
def calibration_table(y_true, y_prob, n_bins=10):
    # Rank-based bins for stable counts
    q = pd.qcut(y_prob, q=n_bins, duplicates="drop")
    dfb = pd.DataFrame({"y_true": y_true, "y_prob": y_prob, "bin": q})
    agg = dfb.groupby("bin").agg(
        mean_prob=("y_prob", "mean"),
        observed_rate=("y_true", "mean"),
        count=("y_true", "size")
    ).reset_index()
    # expose numeric bin edges for plotting in W&B
    agg["bin_low"] = agg["bin"].apply(lambda x: float(x.left))
    agg["bin_high"] = agg["bin"].apply(lambda x: float(x.right))
    agg = agg.drop(columns=["bin"])
    # Expected Calibration Error with equal-frequency bins
    # Weighted by bin count over total
    weights = agg["count"] / agg["count"].sum()
    ece = float(np.sum(weights * np.abs(agg["observed_rate"] - agg["mean_prob"])))
    return agg, ece

# Build tables and summary scores
cal_bins = 10
cal_val_tbl, val_ece = calibration_table(y_val, y_val_prob, n_bins=cal_bins)
cal_test_tbl, test_ece = calibration_table(y_test, y_test_prob, n_bins=cal_bins)

val_brier = brier_score_loss(y_val, y_val_prob)
test_brier = brier_score_loss(y_test, y_test_prob)

# Log metrics
wandb.log({
    "val_brier": val_brier,
    "test_brier": test_brier,
    "val_ece": val_ece,
    "test_ece": test_ece
})

# Log tables for interactive calibration plots in W&B
wandb.log({
    "calibration_table_val": wandb.Table(dataframe=cal_val_tbl),
    "calibration_table_test": wandb.Table(dataframe=cal_test_tbl)
})

run.finish()


  agg = dfb.groupby("bin").agg(
  agg = dfb.groupby("bin").agg(


0,1
test_brier,▁
test_ece,▁
val_brier,▁
val_ece,▁

0,1
test_brier,0.05763
test_ece,0.1074
val_brier,0.05921
val_ece,0.09969


### To Do

- Use the calibration tables in Weights & Biases to build line or bar charts  
- Pay special attention to the high risk bins where clinical actions concentrate  


# 8. Threshold selection for clinical use
- We'll choose operating thresholds on Validation to hit target sensitivity and specificity  
- We'll freeze those thresholds and evaluate on Test  
- Finally we'll log confusion matrices and clinical metrics to Weights & Biases for easy 

### Remember:
- Pick thresholds on Validation to meet clinical goals, then freeze and report Test performance  
- Sensitivity-first thresholds catch more true cases but raise alerts, specificity-first thresholds reduce false alarms  
- Use the Weights & Biases tables to compare predictive values and confusion matrices for each operating point


In [14]:
# 8. Threshold selection for the current Random Forest baseline

# Compute probabilities from the already-fitted Random Forest
y_val_prob = rf.predict_proba(X_val_t)[:, 1]
y_test_prob = rf.predict_proba(X_test_t)[:, 1]


run = wandb.init(
    project=WB_PROJECT,
    job_type="threshold_selection",
    config={
        "model_type": "random_forest",
        "seed": SEED,
        "target_sensitivity": 0.85,   # to be adjusted and defined per clinical use case
        "target_specificity": 0.90    # to be adjusted and defined per clinical use case
    },
    reinit=True,
    settings=wandb.Settings(start_method="thread")
)

# Helper: compute metrics at a given threshold
def metrics_at_threshold(y_true, y_prob, thr):
    y_pred = (y_prob >= thr).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0,1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else np.nan   # recall for positives
    specificity = tn / (tn + fp) if (tn + fp) > 0 else np.nan   # recall for negatives
    ppv = tp / (tp + fp) if (tp + fp) > 0 else np.nan           # precision
    npv = tn / (tn + fn) if (tn + fn) > 0 else np.nan
    prevalence = (tp + fn) / (tp + tn + fp + fn)
    return dict(
        threshold=float(thr),
        tp=int(tp), fp=int(fp), tn=int(tn), fn=int(fn),
        sensitivity=float(sensitivity),
        specificity=float(specificity),
        ppv=float(ppv), npv=float(npv),
        prevalence=float(prevalence)
    )

# Sweep a dense grid of thresholds on Validation
grid = np.unique(np.quantile(y_val_prob, q=np.linspace(0, 1, 501)))  # 0.2 percent steps by quantiles

val_rows = [metrics_at_threshold(y_val, y_val_prob, thr) for thr in grid]
val_tbl = pd.DataFrame(val_rows)

# Pick thresholds closest to targets on Validation
t_sens = run.config["target_sensitivity"]
t_spec = run.config["target_specificity"]

thr_for_sens = val_tbl.iloc[(val_tbl["sensitivity"] - t_sens).abs().argsort()].iloc[0]["threshold"]
thr_for_spec = val_tbl.iloc[(val_tbl["specificity"] - t_spec).abs().argsort()].iloc[0]["threshold"]

# Log chosen thresholds
wandb.config.update({
    "chosen_threshold_sensitivity": float(thr_for_sens),
    "chosen_threshold_specificity": float(thr_for_spec)
}, allow_val_change=True)

# Apply frozen thresholds on Test
test_at_sens = metrics_at_threshold(y_test, y_test_prob, thr_for_sens)
test_at_spec = metrics_at_threshold(y_test, y_test_prob, thr_for_spec)

# Build W&B Tables
val_table_wb = wandb.Table(dataframe=val_tbl[["threshold","sensitivity","specificity","ppv","npv","prevalence"]])
wandb.log({"validation_threshold_sweep": val_table_wb})

test_results_tbl = pd.DataFrame([
    dict(target="sensitivity", **test_at_sens),
    dict(target="specificity", **test_at_spec)
])

wandb.log({
    "test_operating_points": wandb.Table(dataframe=test_results_tbl[[
        "target","threshold","tp","fp","tn","fn","sensitivity","specificity","ppv","npv","prevalence"
    ]])
})

# Also log the two confusion matrices as compact tables
def cm_table(row):
    return wandb.Table(data=[
        ["Actual 0", row["tn"], row["fp"]],
        ["Actual 1", row["fn"], row["tp"]],
    ], columns=["", "Pred 0", "Pred 1"])

wandb.log({
    "confusion_matrix_test_at_sensitivity": cm_table(test_at_sens),
    "confusion_matrix_test_at_specificity": cm_table(test_at_spec)
})

print(f"Chosen thresholds -> Sensitivity target: {thr_for_sens:.3f}, Specificity target: {thr_for_spec:.3f}")
print("Test metrics at sensitivity target:", {k: v for k, v in test_at_sens.items() if k not in ["tp","fp","tn","fn"]})
print("Test metrics at specificity target:", {k: v for k, v in test_at_spec.items() if k not in ["tp","fp","tn","fn"]})

run.finish()


Chosen thresholds -> Sensitivity target: 0.321, Specificity target: 0.226
Test metrics at sensitivity target: {'threshold': 0.321223955496573, 'sensitivity': 0.9009009009009009, 'specificity': 0.9564586357039188, 'ppv': 0.7692307692307693, 'npv': 0.9835820895522388, 'prevalence': 0.13875}
Test metrics at specificity target: {'threshold': 0.2263541808615648, 'sensitivity': 1.0, 'specificity': 0.8969521044992743, 'ppv': 0.6098901098901099, 'npv': 1.0, 'prevalence': 0.13875}


# 9. Subgroup performance

We'll evaluate the Random Forest on clinically relevant subgroups to check consistency of performance
Subgroups (as identified during our CPH analysis):
- **SOFA** score (severity of illness)
- **CSRU** (Cardiac Surgery Recovery Unit)

These two variables were the most significant predictors of mortality in the Cox Proportional Hazards analysis.  
We will use the thresholds chosen on Validation and report metrics on Test in Weights & Biases


In [17]:
# Subgroup performance on Test using SOFA bands and CSRU flag
# Uses thresholds thr_for_sens and thr_for_spec chosen in the previous step


run = wandb.init(
    project=WB_PROJECT,
    job_type="subgroup_eval",
    config={
        "model_type": "random_forest",
        "seed": SEED,
        "subgroups": ["SOFA_bin", "CSRU"]
    },
    reinit=True,
    settings=wandb.Settings(start_method="thread")
)

# Define subgroup columns
SOFA_COL = "SOFA"
CSRU_COL = "CSRU"

# Extract subgroups from Test set
test_idx = X_test.index
sofa_test = ICU.loc[test_idx, SOFA_COL]
csru_test = ICU.loc[test_idx, CSRU_COL]

# Create SOFA quantile bins (5 bands by severity)
sofa_bins = pd.qcut(sofa_test, q=5, duplicates="drop").astype(str)

# Binary label for CSRU membership
csru_group = np.where(pd.to_numeric(csru_test, errors="coerce").fillna(0).astype(int) == 1, "CSRU", "non_CSRU")

# Helper to compute metrics at fixed threshold
def metrics_fixed_threshold(y_true, y_prob, thr):
    y_pred = (y_prob >= thr).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sens = tp / (tp + fn) if (tp + fn) else np.nan
    spec = tn / (tn + fp) if (tn + fp) else np.nan
    ppv  = tp / (tp + fp) if (tp + fp) else np.nan
    npv  = tn / (tn + fn) if (tn + fn) else np.nan
    prev = (tp + fn) / (tp + tn + fp + fn)
    return dict(
        tp=int(tp), fp=int(fp), tn=int(tn), fn=int(fn),
        sensitivity=float(sens), specificity=float(spec),
        ppv=float(ppv), npv=float(npv), prevalence=float(prev)
    )

# Aggregate metrics per subgroup
rows = []

def add_group(group_name, group_values):
    series = pd.Series(group_values, index=test_idx).astype(str)
    for g in sorted(series.unique()):
        mask = (series == g).values
        y_true_g = y_test.values[mask]
        y_prob_g = y_test_prob[mask]
        if len(y_true_g) < 10:
            continue
        auroc = roc_auc_score(y_true_g, y_prob_g)
        auprc = average_precision_score(y_true_g, y_prob_g)
        m_sens = metrics_fixed_threshold(y_true_g, y_prob_g, thr_for_sens)
        m_spec = metrics_fixed_threshold(y_true_g, y_prob_g, thr_for_spec)
        rows.append({
            "model_type": "random_forest",
            "subgroup_type": group_name,
            "subgroup_value": g,
            "n": int(len(y_true_g)),
            "auroc": float(auroc),
            "auprc": float(auprc),
            "target": "sensitivity",
            "threshold": float(thr_for_sens),
            **m_sens
        })
        rows.append({
            "model_type": "random_forest",
            "subgroup_type": group_name,
            "subgroup_value": g,
            "n": int(len(y_true_g)),
            "auroc": float(auroc),
            "auprc": float(auprc),
            "target": "specificity",
            "threshold": float(thr_for_spec),
            **m_spec
        })

# Add SOFA and CSRU subgroups
add_group("SOFA_bin", sofa_bins)
add_group("ICU_unit", csru_group)

subgroup_df = pd.DataFrame(rows)

wandb.log({
    "subgroup_metrics_test": wandb.Table(dataframe=subgroup_df[[
        "model_type","subgroup_type","subgroup_value","n",
        "target","threshold",
        "auroc","auprc","sensitivity","specificity","ppv","npv","prevalence",
        "tp","fp","tn","fn"
    ]]),
    "subgroup_columns_used": wandb.Table(data=[[SOFA_COL, CSRU_COL]], columns=["sofa_column","icu_unit_column"])
})

print("Logged subgroup metrics for SOFA bins and CSRU vs non-CSRU on Test")


Logged subgroup metrics for SOFA bins and CSRU vs non-CSRU on Test


### To Do

- In the Weights & Biases table, compare SOFA bins and CSRU vs non CSRU
- Look for drops in sensitivity or PPV at the chosen threshold
- If performance varies widely by subgroup, discuss mitigation options such as recalibration or separate thresholds


# 10. Interpretability for understanding

Model interpretability connects predictive performance to clinical meaning:
- Use **Permutation Importance** for Logistic Regression, Decision Tree, and Random Forest  
- Use **SHAP** on the Random Forest to see which features drive individual predictions  

All outputs are logged to Weights & Biases for exploration and comparison


In [24]:
# Interpretability: Permutation Importance and SHAP for Random Forest
# Logs all interpretability outputs to W&B

run = wandb.init(
    project=WB_PROJECT,
    job_type="interpretability",
    config={
        "models": ["logistic_regression", "decision_tree", "random_forest"],
        "shap_sample_size": 500
    },
    reinit=True,
    settings=wandb.Settings(start_method="thread")
)

# 1) Permutation Importance for all models on Validation
models = {
    "logistic_regression": log_reg,
    "decision_tree": dt,
    "random_forest": rf
}

for model_name, model in models.items():
    result = permutation_importance(
        model, X_val_t, y_val, n_repeats=10, random_state=SEED, n_jobs=-1
    )
    imp_df = (
        pd.DataFrame({
            "feature": X_val_t.columns,
            "importance_mean": result.importances_mean,
            "importance_std": result.importances_std
        })
        .sort_values("importance_mean", ascending=False)
        .reset_index(drop=True)
    )
    wandb.log({f"{model_name}_permutation_importance": wandb.Table(dataframe=imp_df)})

# 2) SHAP for Random Forest on a small Validation sample
# Use model_output="raw" for tree_path_dependent
# Log summary plot image and a ranked table of mean |SHAP| values


# Sample a manageable slice from Validation
shap_sample_n = min(500, len(X_val_t))
shap_sample = X_val_t.sample(n=shap_sample_n, random_state=SEED)

# Use raw output with tree_path_dependent for tree models
explainer = shap.TreeExplainer(
    rf,
    model_output="raw",
    feature_perturbation="tree_path_dependent"
)

# Avoid additivity mismatches across sklearn versions
shap_values_raw = explainer.shap_values(shap_sample, check_additivity=False)

# Select SHAP values for the positive class and ensure 2D shape (n_samples, n_features)
if isinstance(shap_values_raw, list):
    sv = np.asarray(shap_values_raw[1])  # class 1
else:
    sv = np.asarray(shap_values_raw)

# If 3D, last axis usually indexes classes
if sv.ndim == 3:
    # prefer class 1 if available, else class 0
    cls_axis = sv.shape[-1]
    pick = 1 if cls_axis >= 2 else 0
    sv = sv[..., pick]

# Safety: squeeze any trailing singleton dims
sv = np.squeeze(sv)
assert sv.ndim == 2, f"Expected 2D SHAP values after selection, got {sv.shape}"
assert sv.shape[1] == shap_sample.shape[1], "Feature count mismatch between SHAP and input data"

# Global summary plot
shap.summary_plot(sv, shap_sample, show=False)
wandb.log({"shap_summary_plot": wandb.Image(plt.gcf())})
plt.close()

# Ranked mean absolute SHAP for table
mean_abs_shap = np.abs(sv).mean(axis=0).reshape(-1)
feat_names = list(shap_sample.columns)

shap_df = pd.DataFrame(
    {"feature": feat_names, "mean_abs_shap": mean_abs_shap}
).sort_values("mean_abs_shap", ascending=False).reset_index(drop=True)

wandb.log({"rf_shap_feature_importance": wandb.Table(dataframe=shap_df)})

run.finish()


  shap.summary_plot(sv, shap_sample, show=False)


### To DO
In Weights & Biases:
- Use the permutation importance tables to see which features most influence predictions  
- Compare across models to note stability of top features  
- In the SHAP summary, focus on high impact variables — for instance, rising SOFA and CSRU membership should drive higher predicted mortality  