# ML for Healthcare with Weights & Biases: Interpretable Baselines and Clinical Evaluation

# 1. Introduction

We continue shifting our focus from coding models to *understanding* them: how well they perform, how trustworthy their probabilities are, and how they behave across patient groups. This strengthens the foundation for clinically aware ML practice.

**Objective**  
Build, track, and interpret models predicting in-hospital mortality in ICU patients using Weights & Biases (W&B).  
This notebook turns standard Machine Learning (ML) practice into an auditable, interpretable, and clinically meaningful workflow.

**You will learn**
- How to track model training and evaluation runs using Weights & Biases  
- How to interpret models and understand their calibration  
- How to examine fairness and subgroup performance  
- How to communicate model insights for healthcare decisions

**Models**
We’ll compare three interpretable baselines:
1. **Logistic Regression**: simple linear reference, easy to explain  
2. **Decision Tree (shallow)**: intuitive splits, visually transparent  
3. **Random Forest**: robust ensemble, main focus for tuning and interpretability

**Dataset**
[PhysioNet Challenge 2012 dataset](https://physionet.org/content/challenge-2012/1.0.0/), containing clinical measurements, demographics, and the target:  
`In-hospital_death` (binary: 1 = patient died during stay, 0 = survived).

**Weights & Biases**
Used to:
- Track configurations, metrics, and plots  
- Compare models and hyperparameters  
- Log interpretability results (feature importance, calibration, subgroups)  
- Support model transparency and documentation



# 2. Setup

In [4]:
# Setup: imports and reproducibility

import os
import sys
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shap
import joblib  # For saving trained models

# Weights & Biases and its built-in plotting helpers
import wandb
from wandb.plot import roc_curve as wandb_roc_curve
from wandb.plot import pr_curve as wandb_pr_curve
from wandb.plot import confusion_matrix as wandb_cm
from wandb import Api as WandbApi  # For querying runs and sweeps programmatically

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.metrics import (
    roc_auc_score, 
    average_precision_score, 
    brier_score_loss,
    roc_curve, 
    precision_recall_curve,
    confusion_matrix,
)

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [5]:
# Defined once here and used by all model baseline cells
def calibration_table(y_true, y_prob, n_bins=10):
    """
    Calculates calibration metrics (ECE) and returns a table.
    It bins predictions, compares average predicted prob to actual positive rate,
    and computes the weighted average error (ECE).
    """
    if isinstance(y_true, pd.Series):
        y_true = y_true.values # Convert from pandas to numpy if needed
        
    # Use pd.qcut to bin probabilities into `n_bins` groups (quantiles)
    q = pd.qcut(y_prob, q=n_bins, duplicates="drop")
    dfb = pd.DataFrame({"y_true": y_true, "y_prob": y_prob, "bin": q})
    
    # Group by the bins and calculate (observed=True skips empty bins,
    # whose NaN means would otherwise turn the ECE into NaN):
    agg = dfb.groupby("bin", observed=True).agg(
        mean_prob=("y_prob", "mean"),     # The average *predicted risk* in this bin
        observed_rate=("y_true", "mean"), # The actual *mortality rate* in this bin
        count=("y_true", "size")          # How many patients are in this bin
    ).reset_index()
    
    # Convert bin (Interval object) to string for W&B Table
    agg["bin"] = agg["bin"].astype(str)
    
    # Calculate ECE
    weights = agg["count"] / agg["count"].sum()
    ece = float(np.sum(weights * np.abs(agg["observed_rate"] - agg["mean_prob"])))
    
    return agg, ece # Return the table and the single ECE score
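The ECE returned by `calibration_table` is a count-weighted average of the per-bin gaps between observed mortality and mean predicted risk. With $B$ bins, $n_b$ patients in bin $b$, $N$ patients in total, $\bar{o}_b$ the bin's `observed_rate`, and $\bar{p}_b$ its `mean_prob`:

$$
\mathrm{ECE} = \sum_{b=1}^{B} \frac{n_b}{N}\,\bigl|\,\bar{o}_b - \bar{p}_b\,\bigr|
$$

As a hypothetical worked example with two bins of 50 patients each: if bin 1 has $\bar{p}_1 = 0.05$, $\bar{o}_1 = 0.04$ and bin 2 has $\bar{p}_2 = 0.30$, $\bar{o}_2 = 0.40$, then $\mathrm{ECE} = 0.5 \times 0.01 + 0.5 \times 0.10 = 0.055$. Because the helper uses quantile (equal-frequency) bins, the weights $n_b/N$ stay roughly equal unless tied probabilities force bins to merge.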

In [6]:
# !wandb login
# wandb.login(relogin=True)  # Python equivalent that forces a fresh login

In [7]:
WB_PROJECT = "ml-healthcare-intro"

# wandb.login() # Uncomment if not logged in
run = wandb.init(
    project=WB_PROJECT, 
    job_type="data-exploration", # Groups this run with other data-exploration runs
    name="01-data-exploration",  # Human-readable run name in the UI
    config={
        "seed": SEED,
        "framework": "scikit-learn",
        "dataset": "physionet2012_set_a"
    }
)

# Log environment metadata
wandb.config.update({
    "python_version": sys.version.split()[0],
    "pandas_version": pd.__version__,
    "numpy_version": np.__version__,
}, allow_val_change=True)

print(f"Weights & Biases tracking URL: {run.url}")

Weights & Biases tracking URL: https://wandb.ai/idiazl/ml-healthcare-intro/runs/j26qxxof


# 3. Data loading and initial checks
We will load the dataset, confirm the target, and log basic summaries to Weights & Biases. This lets us explore class balance, missingness, and a quick data preview directly in the dashboard.

### Action Items
- Open the W&B run link generated above
- Inspect the `data_preview_table` to understand the features
- Check the `class_balance_table` to confirm the target imbalance
- Review `missingness_top5_table` to identify problematic features

In [8]:
# --- 1. Load Data ---
PATH = "PhysionetChallenge2012-set-a.csv.gz"

# Simple check to ensure the data file exists before trying to load it
if not os.path.exists(PATH):
    raise FileNotFoundError(
        f"Error: The data file was not found at '{PATH}'. "
        "Please ensure the dataset is in the correct directory."
    )

ICU = pd.read_csv(PATH, compression="gzip")

TARGET = "In-hospital_death"
ID_COL = "recordid" if "recordid" in ICU.columns else None

if TARGET not in ICU.columns:
    raise ValueError(f"Target column '{TARGET}' not found in dataset")

# Ensure target is numeric and binary
# (note: missing or non-numeric labels are coerced to 0, i.e. counted as survived)
ICU[TARGET] = pd.to_numeric(ICU[TARGET], errors="coerce").fillna(0).astype(int)

# Basic facts
n_rows, n_cols = ICU.shape
pos_rate = float(ICU[TARGET].mean())

# Class balance table
cb_series = ICU[TARGET].value_counts().sort_index()
class_balance_tbl = wandb.Table(data=[[int(k), int(v), float(v / n_rows)] for k, v in cb_series.items()],
                                columns=["label", "count", "fraction"])

# Missingness table (top 5)
miss = ICU.isna().mean().sort_values(ascending=False)
miss_top = miss.head(5).reset_index()
miss_top.columns = ["column", "missing_fraction"]
missing_tbl = wandb.Table(data=miss_top.values.tolist(), columns=list(miss_top.columns))

# Data preview table (sample up to 5 rows for UI responsiveness)
preview = ICU.sample(n=min(5, len(ICU)), random_state=SEED)
preview_tbl = wandb.Table(dataframe=preview)

wandb.log({
    "dataset_rows": n_rows,
    "dataset_cols": n_cols,
    "positive_rate": pos_rate,
    "class_balance_table": class_balance_tbl,
    "missingness_top5_table": missing_tbl,
    "data_preview_table": preview_tbl
})

print(f"Loaded ICU with shape {ICU.shape} and positive rate {pos_rate:.3f}")

Loaded ICU with shape (4000, 120) and positive rate 0.139


# 4. Simple preprocessing
### Preprocessing and splits
We will split the data into train, validation, and test sets, impute missing values and one-hot encode categorical variables, and log feature lists and split sizes to Weights & Biases for transparency.

In [9]:
# Preprocessing: split data and prepare simple pipelines

# Define all columns to drop: the target, the ID, and the leaking 'Survival' feature
LEAKAGE_COL = "Survival"
COLS_TO_DROP = [c for c in [TARGET, ID_COL, LEAKAGE_COL] if c in ICU.columns]

# Drop the target, ID, and leakage columns to form the feature matrix
X = ICU.drop(columns=COLS_TO_DROP)
y = ICU[TARGET]

# Identify categorical and numeric columns
cat_cols = [c for c in X.columns if X[c].dtype == "object"]
num_cols = [c for c in X.columns if c not in cat_cols]

# Split data (60 percent train, 20 percent validation, 20 percent test)
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=SEED
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_full, y_train_full, test_size=0.25, stratify=y_train_full, random_state=SEED
)

print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

# Define transformations
num_transformer = SimpleImputer(strategy="median")
cat_transformer = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=False))
])

preprocessor = ColumnTransformer(
    transformers=[
        ("num", num_transformer, num_cols),
        ("cat", cat_transformer, cat_cols)
    ],
    verbose_feature_names_out=False
)

preprocessor.set_output(transform="pandas")

# Fit and transform
X_train_t = preprocessor.fit_transform(X_train)
X_val_t = preprocessor.transform(X_val)
X_test_t = preprocessor.transform(X_test)

# Log split sizes
wandb.log({
    "train_rows": len(X_train),
    "val_rows": len(X_val),
    "test_rows": len(X_test),
    "n_num_features_raw": len(num_cols),
    "n_cat_features_raw": len(cat_cols),
    "n_features_transformed": X_train_t.shape[1]
})

# Log feature lists as W&B Tables for inspectability
num_tbl = wandb.Table(data=[[c, "numeric"] for c in num_cols], columns=["feature", "type"])
cat_tbl = wandb.Table(data=[[c, "categorical"] for c in cat_cols], columns=["feature", "type"])
wandb.log({"feature_list_numeric": num_tbl, "feature_list_categorical": cat_tbl})

print("Preprocessing complete")

# Finish the data exploration run
run.finish()

Train: (2400, 117), Val: (800, 117), Test: (800, 117)
Preprocessing complete


Run summary for `01-data-exploration`:

| metric | value |
|---|---|
| dataset_rows | 4000 |
| dataset_cols | 120 |
| positive_rate | 0.1385 |
| train_rows | 2400 |
| val_rows | 800 |
| test_rows | 800 |
| n_num_features_raw | 117 |
| n_cat_features_raw | 0 |
| n_features_transformed | 117 |


Use the feature tables and split sizes in Weights & Biases to verify the preprocessing choices.  
All subsequent models consume the same transformed matrices, which keeps their comparisons fair.
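The imputer above is fit on the training split only, so medians from validation and test never leak into training. If you later move to cross-validation, one way to keep that property is to wrap preprocessing and model in a single scikit-learn `Pipeline`, so every fold refits the imputer on its own training portion. A minimal sketch on synthetic data; the feature matrix, the 20% missingness rate, and the model settings are illustrative assumptions, not the notebook's:

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline

rng = np.random.default_rng(0)
n, d = 1000, 5
X = rng.normal(size=(n, d))
# Synthetic outcome driven by the first feature; roughly 14% positives
logit = -2.5 + 1.5 * X[:, 0]
y = (rng.random(n) < 1 / (1 + np.exp(-logit))).astype(int)
# Blank out ~20% of values at random to mimic missing measurements
X[rng.random(X.shape) < 0.2] = np.nan

pipe = Pipeline(steps=[
    ("impute", SimpleImputer(strategy="median")),  # refit inside each CV fold
    ("model", LogisticRegression(max_iter=1000)),
])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(pipe, X, y, cv=cv, scoring="roc_auc")
print("fold AUROCs:", scores.round(3), "mean:", round(scores.mean(), 3))
```

Because the imputer lives inside the pipeline, `cross_val_score` never lets a fold's held-out rows influence the medians used to train on that fold.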


# 5. Logistic Regression: establishing a simple reference

Before exploring complex models, it is helpful to start with a simple, interpretable baseline. Logistic Regression models the log-odds of the outcome as a linear function of the features, which helps us judge whether more flexible models (like Random Forests) truly add value.

1. Train a Logistic Regression model and evaluate it on the validation and test sets
2. Log all metrics to Weights & Biases for later comparison
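For reference, the standard form of the model, where $p$ is the predicted probability of in-hospital death and $x_j$ are the (imputed) features:

$$
\log\frac{p}{1-p} = \beta_0 + \sum_{j} \beta_j x_j,
\qquad
p = \frac{1}{1 + e^{-\left(\beta_0 + \sum_{j} \beta_j x_j\right)}}
$$

Each $\beta_j$ is the change in log-odds per one-unit increase in $x_j$ with the other features held fixed, so $e^{\beta_j}$ is an odds ratio. For example, $\beta_j = 0.5$ gives $e^{0.5} \approx 1.65$, i.e. about 65% higher odds of death per unit of $x_j$.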


In [10]:
# Logistic Regression baseline with W&B's built-in plotting
run = wandb.init(
    project=WB_PROJECT,
    job_type="baseline",
    name="02-baseline-logistic-regression", # Human-readable run name
    config={"model_type": "logistic_regression", "seed": SEED},
    reinit=True,
)

# Uses the shared 'calibration_table' helper defined in Section 2

# Train
log_reg = LogisticRegression(max_iter=500, solver="liblinear", random_state=SEED)
log_reg.fit(X_train_t, y_train)

# Predict probabilities
y_val_prob = log_reg.predict_proba(X_val_t)[:, 1]
y_test_prob = log_reg.predict_proba(X_test_t)[:, 1]

# Metrics
val_auc = roc_auc_score(y_val, y_val_prob)
val_pr  = average_precision_score(y_val, y_val_prob)
val_brier = brier_score_loss(y_val, y_val_prob)
test_auc = roc_auc_score(y_test, y_test_prob)
test_pr  = average_precision_score(y_test, y_test_prob)
test_brier = brier_score_loss(y_test, y_test_prob)

wandb.log({
    "val_auc": val_auc,
    "val_pr": val_pr,
    "val_brier": val_brier,
    "test_auc": test_auc,
    "test_pr": test_pr,
    "test_brier": test_brier
})

# --- Log Calibration Metrics (using global helper) ---
cal_val_tbl, val_ece = calibration_table(y_val, y_val_prob)
cal_test_tbl, test_ece = calibration_table(y_test, y_test_prob)
wandb.log({
    "val_ece": val_ece, "test_ece": test_ece, # ECE stands for Expected Calibration Error
    "calibration_table_val": wandb.Table(dataframe=cal_val_tbl),
    "calibration_table_test": wandb.Table(dataframe=cal_test_tbl)
})

# wandb.plot.roc_curve / pr_curve expect per-class probabilities of shape (n_samples, 2)
y_val_probas_2d = np.stack([1.0 - y_val_prob, y_val_prob], axis=1)

# Log interactive ROC and PR curves for the validation set
wandb.log({
    "roc_curve_val": wandb_roc_curve(y_val.values, y_val_probas_2d),
    "pr_curve_val": wandb_pr_curve(y_val.values, y_val_probas_2d)
})

# Coefficients for transparency
coef_df = pd.DataFrame({"feature": X_train_t.columns, "coefficient": log_reg.coef_[0]})
coef_tbl = wandb.Table(dataframe=coef_df.sort_values("coefficient", ascending=False))
wandb.log({"log_reg_coefficients": coef_tbl})

# Predictions table sample
pred_sample = pd.DataFrame({
    "id": X_val.index, # Original ICU row index (ID_COL was dropped from X; look up ICU.loc[id, ID_COL] if needed)
    "y_true": y_val.values,
    "y_prob": y_val_prob
}).sample(n=min(500, len(y_val)), random_state=SEED)
wandb.log({"predictions_val_sample": wandb.Table(dataframe=pred_sample)})

run.finish()

print(f"LR validation AUROC {val_auc:.3f}, AUPRC {val_pr:.3f}, Brier {val_brier:.3f}, ECE {val_ece:.3f}")



Run summary for `02-baseline-logistic-regression`:

| metric | validation | test |
|---|---|---|
| AUROC (`val_auc` / `test_auc`) | 0.8355 | 0.83594 |
| AUPRC (`val_pr` / `test_pr`) | 0.46973 | 0.43254 |
| Brier (`val_brier` / `test_brier`) | 0.09637 | 0.09979 |
| ECE (`val_ece` / `test_ece`) | 0.0363 | 0.03062 |


LR validation AUROC 0.835, AUPRC 0.470, Brier 0.096, ECE 0.036
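The AUPRC of 0.470 and Brier score of 0.096 are easier to judge against no-skill references. With prevalence $\pi$ (about 0.139 in this cohort), a model that ranks patients at random has an expected average precision of roughly $\pi$, and a model that predicts $\pi$ for every patient has a Brier score of exactly $\pi(1-\pi) \approx 0.12$ when $\pi$ is the evaluated data's own prevalence. On that reading, the LR's AUPRC is roughly 3.4 times the chance level, and its Brier skill score, $1 - 0.096/0.12$, is about 0.2 (assuming the validation prevalence is close to the overall 0.139). A quick synthetic check of both references, using made-up labels rather than the notebook's data:

```python
import numpy as np
from sklearn.metrics import average_precision_score, brier_score_loss

rng = np.random.default_rng(0)
# Synthetic labels with a prevalence similar to this dataset (~14%)
y = (rng.random(20000) < 0.139).astype(int)
prevalence = y.mean()

# No-skill ranking: scores that are independent of the outcome
random_scores = rng.random(len(y))
ap_no_skill = average_precision_score(y, random_scores)

# No-skill probabilities: predict the prevalence for every patient
brier_no_skill = brier_score_loss(y, np.full(len(y), prevalence))

print(f"prevalence={prevalence:.3f}  AP(no skill)={ap_no_skill:.3f}  "
      f"Brier(no skill)={brier_no_skill:.4f}")
```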


### To Do
- Open the new "02-baseline-logistic-regression" run in W&B
- Examine the interactive `roc_curve_val` and `pr_curve_val` plots
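- Sort the `log_reg_coefficients` table to find the strongest predictors. Because the features were imputed but not standardized, coefficient magnitudes depend on each feature's units, so treat this ranking as rough (the permutation importance in Section 8 gives a unit-free comparison)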
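To see why unscaled coefficients can mislead, here is a small synthetic sketch: two hypothetical features carry the same signal per standard deviation but are measured on scales that differ by 100x. The raw coefficients end up about 100x apart, while coefficients fitted after `StandardScaler` are similar and read as log-odds per standard deviation (their exponentials as odds ratios per SD). The data and feature scales are illustrative assumptions:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n = 5000
x_small = rng.normal(0, 1, n)    # hypothetical feature with SD 1
x_large = rng.normal(0, 100, n)  # hypothetical feature with SD 100
# Both features carry the same effect per standard deviation
logit = -2.0 + 0.8 * x_small + 0.8 * (x_large / 100)
y = (rng.random(n) < 1 / (1 + np.exp(-logit))).astype(int)
X = np.column_stack([x_small, x_large])

raw = LogisticRegression(max_iter=1000).fit(X, y)
scaled = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000)).fit(X, y)

raw_coef = raw.coef_[0]
std_coef = scaled[-1].coef_[0]  # last pipeline step is the fitted LogisticRegression
odds_ratio_per_sd = np.exp(std_coef)
print("raw coefficients:", raw_coef.round(4))
print("standardized coefficients:", std_coef.round(3))
print("odds ratio per SD:", odds_ratio_per_sd.round(2))
```

In the notebook, the analogous change would be adding a `StandardScaler` after the numeric imputer, which would alter the LR results reported above.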


# 6. Decision Tree baseline

A shallow tree is easy to read and helps us see simple non-linear rules. We will train a small tree and log metrics, ROC/PR curves, feature importances, and a predictions sample to Weights & Biases.
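One way to actually read a shallow tree's rules is scikit-learn's `sklearn.tree.export_text`, which prints the fitted splits as indented if/else text. A minimal sketch on synthetic data; the "Age" and "GCS" features and the risk formula are hypothetical stand-ins for ICU variables:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(0)
n = 1000
# Hypothetical features loosely inspired by ICU variables (synthetic values)
age = rng.uniform(20, 90, n)
gcs = rng.integers(3, 16, n)  # Glasgow Coma Scale ranges from 3 to 15
# Synthetic risk: low GCS and older age raise the chance of death
risk = 0.02 + 0.3 * (gcs < 8) + 0.1 * (age > 70)
y = (rng.random(n) < risk).astype(int)
X = np.column_stack([age, gcs])

tree = DecisionTreeClassifier(max_depth=2, min_samples_leaf=20, random_state=0)
tree.fit(X, y)

# export_text renders the fitted splits as indented rules ending in class labels
rules = export_text(tree, feature_names=["Age", "GCS"])
print(rules)
```

Applied to the notebook's tree, the same call would be `export_text(dt, feature_names=list(X_train_t.columns))`.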


In [11]:
# 6. Decision Tree baseline
run = wandb.init(
    project=WB_PROJECT,
    job_type="baseline",
    name="02-baseline-decision-tree", # Human-readable run name
    config={
        "model_type": "decision_tree",
        "seed": SEED,
        "max_depth": 4,
        "min_samples_leaf": 20
    },
    reinit=True,
)

# Uses the shared 'calibration_table' helper defined in Section 2

# Train a small, readable tree
dt = DecisionTreeClassifier(
    max_depth=wandb.config.max_depth,
    min_samples_leaf=wandb.config.min_samples_leaf,
    random_state=SEED
)
dt.fit(X_train_t, y_train)

# Predict probabilities
y_val_prob = dt.predict_proba(X_val_t)[:, 1]
y_test_prob = dt.predict_proba(X_test_t)[:, 1]

# Metrics
val_auc = roc_auc_score(y_val, y_val_prob)
val_pr  = average_precision_score(y_val, y_val_prob)
val_brier = brier_score_loss(y_val, y_val_prob)
test_auc = roc_auc_score(y_test, y_test_prob)
test_pr  = average_precision_score(y_test, y_test_prob)
test_brier = brier_score_loss(y_test, y_test_prob)

wandb.log({
    "val_auc": val_auc,
    "val_pr": val_pr,
    "val_brier": val_brier,
    "test_auc": test_auc,
    "test_pr": test_pr,
    "test_brier": test_brier
})

# --- Log Calibration Metrics (using global helper) ---
cal_val_tbl, val_ece = calibration_table(y_val, y_val_prob)
cal_test_tbl, test_ece = calibration_table(y_test, y_test_prob)
wandb.log({
    "val_ece": val_ece, "test_ece": test_ece,
    "calibration_table_val": wandb.Table(dataframe=cal_val_tbl),
    "calibration_table_test": wandb.Table(dataframe=cal_test_tbl)
})

# wandb.plot.roc_curve / pr_curve expect per-class probabilities of shape (n_samples, 2)
y_val_probas_2d = np.stack([1.0 - y_val_prob, y_val_prob], axis=1)

# Log interactive ROC and PR curves for the validation set
wandb.log({
    "roc_curve_val": wandb_roc_curve(y_val.values, y_val_probas_2d),
    "pr_curve_val": wandb_pr_curve(y_val.values, y_val_probas_2d)
})

# Feature importances
imp_df = (
    pd.DataFrame({"feature": X_train_t.columns, "importance": dt.feature_importances_})
    .sort_values("importance", ascending=False)
)
wandb.log({"feature_importances": wandb.Table(dataframe=imp_df)})

# Predictions sample
pred_sample = pd.DataFrame({
    "id": X_val.index,
    "y_true": y_val.values,
    "y_prob": y_val_prob
}).sample(n=min(500, len(y_val)), random_state=SEED)
wandb.log({"predictions_val_sample": wandb.Table(dataframe=pred_sample)})

run.finish()

print(f"DT validation AUROC {val_auc:.3f}, AUPRC {val_pr:.3f}, Brier {val_brier:.3f}, ECE {val_ece:.3f}")

Run summary for `02-baseline-decision-tree`:

| metric | validation | test |
|---|---|---|
| AUROC (`val_auc` / `test_auc`) | 0.78984 | 0.80022 |
| AUPRC (`val_pr` / `test_pr`) | 0.41323 | 0.42453 |
| Brier (`val_brier` / `test_brier`) | 0.1011 | 0.10108 |
| ECE (`val_ece` / `test_ece`) | 0.04323 | 0.04175 |


DT validation AUROC 0.790, AUPRC 0.413, Brier 0.101, ECE 0.043


# 7. Random Forest baseline

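A Random Forest averages many decorrelated trees to improve stability and performance. In a single comprehensive run, we will train a baseline model, log metrics, ROC/PR curves, feature importances, and a predictions sample to W&B, and then analyze calibration, choose clinical operating thresholds, and check subgroup performance. Later we will tune hyperparameters with a short sweep.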

In [13]:
# 7. Random Forest Comprehensive Baseline Analysis
# This single cell trains one Random Forest model and then performs a
# deep-dive analysis covering performance, calibration, thresholding, and fairness.

# Start a new W&B run to log everything
run = wandb.init(
    project=WB_PROJECT,
    job_type="baseline-comprehensive", # job_type used to group deep-dive analysis runs in the W&B UI
    name="02-baseline-random-forest-full", # A clear name for the W&B dashboard
    config={
        # --- Model Hyperparameters ---
        "model_type": "random_forest",
        "seed": SEED,
        "n_estimators": 300,  # Number of trees in the forest
        "max_depth": None,  # Let trees grow as deep as they want
        "max_features": "sqrt", # Number of features to consider at each split
        "min_samples_leaf": 5,  # Minimum samples required to be at a leaf node
        "n_jobs": -1, # Use all available CPU cores for training
        
        # --- Analysis Parameters (for steps 3, 4, 5) ---
        "calibration_bins": 10,       # How many bins to use for the ECE calculation
        "target_sensitivity": 0.85, # A clinical goal: "We must find at least 85% of mortality cases"
        "target_specificity": 0.90, # A clinical goal: "We must correctly clear at least 90% of survival cases"
        "subgroups": ["SOFA_bin", "CSRU"] # Features to use for the fairness/bias check
    },
    reinit=True, # Allows running wandb.init() again in the same notebook
)

# --- 1. Train Model ---

# Create a 'cfg' shortcut to access the config dictionary
cfg = wandb.config 

# Initialize the RandomForestClassifier with hyperparameters from our config
rf = RandomForestClassifier(
    n_estimators=cfg.n_estimators,
    max_depth=cfg.max_depth,
    max_features=cfg.max_features,
    min_samples_leaf=cfg.min_samples_leaf,
    random_state=SEED,
    n_jobs=cfg.n_jobs
)
# Train (fit) the model on the training data
rf.fit(X_train_t, y_train)

# Get predicted probabilities for the positive class (mortality)
# .predict_proba() returns [prob_of_0, prob_of_1], so [:, 1] selects just prob_of_1
y_val_prob = rf.predict_proba(X_val_t)[:, 1]
y_test_prob = rf.predict_proba(X_test_t)[:, 1]

# --- 2. Log Baseline Metrics & Plots ---

# Calculate standard performance metrics on both validation and test sets
val_auc = roc_auc_score(y_val, y_val_prob)
val_pr  = average_precision_score(y_val, y_val_prob) # AUPRC
val_brier = brier_score_loss(y_val, y_val_prob)      # Brier score: mean squared error of predicted probabilities vs. 0/1 outcomes

test_auc = roc_auc_score(y_test, y_test_prob)
test_pr  = average_precision_score(y_test, y_test_prob)
test_brier = brier_score_loss(y_test, y_test_prob)

# Group simple metrics into a dictionary for a single wandb.log() call
metrics_log = {
    "val_auc": val_auc, "val_pr": val_pr, "val_brier": val_brier,
    "test_auc": test_auc, "test_pr": test_pr, "test_brier": test_brier
}
wandb.log(metrics_log)

# W&B's built-in plotting functions (wandb_roc_curve, wandb_pr_curve)
# expect a 2D array of probabilities: [prob_for_class_0, prob_for_class_1]
y_val_probas_2d = np.stack([1.0 - y_val_prob, y_val_prob], axis=1)

# Log the interactive plots to the W&B run
wandb.log({
    "roc_curve_val": wandb_roc_curve(y_val.values, y_val_probas_2d),
    "pr_curve_val": wandb_pr_curve(y_val.values, y_val_probas_2d)
})

# Create a DataFrame of feature importances from the trained model
imp_df = (
    pd.DataFrame({"feature": X_train_t.columns, "importance": rf.feature_importances_})
    .sort_values("importance", ascending=False)
)
# Log this as an interactive W&B Table
wandb.log({"feature_importances": wandb.Table(dataframe=imp_df)})

# Create a DataFrame with a sample of predictions for manual inspection
pred_sample = pd.DataFrame({
    "id": X_val.index, 
    "y_true": y_val.values, 
    "y_prob": y_val_prob
}).sample(n=min(500, len(y_val)), random_state=SEED) # Sample max 500 rows
# Log this sample as a W&B Table
wandb.log({"predictions_val_sample": wandb.Table(dataframe=pred_sample)})

print(f"RF validation AUROC {val_auc:.3f}, AUPRC {val_pr:.3f}, Brier {val_brier:.3f}")

# --- 3. Calibration Analysis (Trustworthiness Check) ---

# Uses the shared 'calibration_table' helper defined in Section 2

# Calculate calibration for both validation and test sets
cal_val_tbl, val_ece = calibration_table(y_val, y_val_prob, n_bins=cfg.calibration_bins)
cal_test_tbl, test_ece = calibration_table(y_test, y_test_prob, n_bins=cfg.calibration_bins)

# Log the ECE scores and the full calibration tables to W&B
wandb.log({
    "val_ece": val_ece, 
    "test_ece": test_ece,
    "calibration_table_val": wandb.Table(dataframe=cal_val_tbl),
    "calibration_table_test": wandb.Table(dataframe=cal_test_tbl)
})
print(f"RF validation ECE {val_ece:.3f}, Test ECE {test_ece:.3f}")


# --- 4. Threshold Selection (Finding Clinical Cutoffs) ---

# Helper function to get detailed metrics for a *single* probability threshold
def metrics_at_threshold(y_true, y_prob, thr):
    """Calculates confusion matrix metrics for a given probability threshold."""
    if isinstance(y_true, pd.Series):
        y_true = y_true.values # Convert to numpy array
        
    # Convert probabilities to binary predictions (0 or 1) based on the threshold
    y_pred = (y_prob >= thr).astype(int)
    
    # Calculate the confusion matrix components
    # .ravel() flattens the 2x2 matrix into a 1D array [tn, fp, fn, tp]
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0,1]).ravel()
    
    # Calculate key clinical metrics. Handle division by zero if a class is empty.
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else np.nan # Recall
    specificity = tn / (tn + fp) if (tn + fp) > 0 else np.nan
    ppv = tp / (tp + fp) if (tp + fp) > 0 else np.nan # Precision
    npv = tn / (tn + fn) if (tn + fn) > 0 else np.nan
    prevalence = (tp + fn) / (tp + tn + fp + fn)
    
    # Return all metrics as a dictionary
    return dict(
        threshold=float(thr), tp=int(tp), fp=int(fp), tn=int(tn), fn=int(fn),
        sensitivity=float(sensitivity), specificity=float(specificity),
        ppv=float(ppv), npv=float(npv), prevalence=float(prevalence)
    )

# --- Find Optimal Thresholds on Validation Set ---

# Build a data-adaptive grid of up to 501 candidate thresholds from quantiles
# of the validation probabilities (np.unique removes duplicate values)
grid = np.unique(np.quantile(y_val_prob, q=np.linspace(0, 1, 501)))

# Run the helper function for *every* threshold in the grid on the *validation* data
val_rows = [metrics_at_threshold(y_val, y_val_prob, thr) for thr in grid]
val_tbl = pd.DataFrame(val_rows) # Convert list of dictionaries to a DataFrame

# Find the threshold that gets *closest* to our target sensitivity (from config)
# .abs() finds the absolute difference
# .argsort() finds the row index of the *smallest* difference
# .iloc[0] selects that best row
thr_for_sens = val_tbl.iloc[(val_tbl["sensitivity"] - cfg.target_sensitivity).abs().argsort()].iloc[0]["threshold"]

# Find the threshold that gets *closest* to our target specificity (from config)
thr_for_spec = val_tbl.iloc[(val_tbl["specificity"] - cfg.target_specificity).abs().argsort()].iloc[0]["threshold"]

# Log these chosen thresholds back to the W&B run's config
wandb.config.update({
    "chosen_threshold_sensitivity": float(thr_for_sens),
    "chosen_threshold_specificity": float(thr_for_spec)
}, allow_val_change=True) # allow_val_change permits overwriting values of keys that already exist

# --- Apply Chosen Thresholds to Test Set ---

# Now, use the thresholds we found on validation to evaluate the *test* set
test_at_sens = metrics_at_threshold(y_test, y_test_prob, thr_for_sens)
test_at_spec = metrics_at_threshold(y_test, y_test_prob, thr_for_spec)

# Log the full validation threshold sweep (every candidate threshold) for review
wandb.log({"validation_threshold_sweep": wandb.Table(dataframe=val_tbl[["threshold","sensitivity","specificity","ppv","npv","prevalence"]])})

# Create a small DataFrame summarizing the test set performance at our chosen points
test_results_df = pd.DataFrame([
    dict(target="sensitivity", **test_at_sens), # "**" unpacks the dictionary
    dict(target="specificity", **test_at_spec)
])
# Log this summary table
wandb.log({"test_operating_points": wandb.Table(dataframe=test_results_df)})

# --- Log Interactive Confusion Matrices ---
# Create binary predictions for the test set using our two chosen thresholds
y_pred_sens = (y_test_prob >= thr_for_sens).astype(int)
y_pred_spec = (y_test_prob >= thr_for_spec).astype(int)

# Log interactive confusion matrix plots to W&B
wandb.log({
    "confusion_matrix_test_at_sensitivity": wandb_cm(y_true=y_test.values, preds=y_pred_sens, class_names=["Survived", "Died"]),
    "confusion_matrix_test_at_specificity": wandb_cm(y_true=y_test.values, preds=y_pred_spec, class_names=["Survived", "Died"])
})
print(f"Chosen thresholds -> Sensitivity target: {thr_for_sens:.3f}, Specificity target: {thr_for_spec:.3f}")

# --- 5. Subgroup Performance (Fairness & Bias Check) ---

# Define the original column names for our subgroups
SOFA_COL = "SOFA" # Sequential Organ Failure Assessment score (higher = more organ dysfunction)
CSRU_COL = "CSRU" # Binary flag: Cardiac Surgery Recovery Unit vs. other ICU types

# Get the original (non-preprocessed) subgroup features for the *test set* patients
test_idx = X_test.index # Get the original index of the test set rows
sofa_test = ICU.loc[test_idx, SOFA_COL] # Get SOFA scores for test patients
csru_test = ICU.loc[test_idx, CSRU_COL] # Get CSRU status for test patients

# Create the subgroup bins/categories
# Bin SOFA scores into up to 5 quantile bins; labels are interval strings such as "(a, b]"
# (duplicate edges are dropped; patients with a missing SOFA score, if any, form a "nan" group)
sofa_bins = pd.qcut(sofa_test, q=5, duplicates="drop").astype(str)
# Bin CSRU into "CSRU" vs "non_CSRU"
csru_group = np.where(pd.to_numeric(csru_test, errors="coerce").fillna(0).astype(int) == 1, "CSRU", "non_CSRU")

# Helper function to get metrics for a subgroup (almost identical to the one in step 4)
def subgroup_metrics_fixed_threshold(y_true, y_prob, thr):
    """Calculates confusion matrix metrics for a subgroup."""
    if isinstance(y_true, pd.Series):
        y_true = y_true.values
        
    y_pred = (y_prob >= thr).astype(int)
    
    # labels=[0, 1] keeps the matrix 2x2 even when a slice contains only one class;
    # the except branch below is a defensive fallback for unexpected inputs
    try:
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    except ValueError: 
        # Failsafe for empty or single-class slices
        tn, fp, fn, tp = 0, 0, 0, 0
        if len(y_pred) > 0:
            # Manually calculate if possible
            tn = np.sum((y_true == 0) & (y_pred == 0))
            fp = np.sum((y_true == 0) & (y_pred == 1))
            fn = np.sum((y_true == 1) & (y_pred == 0))
            tp = np.sum((y_true == 1) & (y_pred == 1))

    # Calculate metrics, checking for division by zero
    sens = tp / (tp + fn) if (tp + fn) else np.nan
    spec = tn / (tn + fp) if (tn + fp) else np.nan
    ppv  = tp / (tp + fp) if (tp + fp) else np.nan
    npv  = tn / (tn + fn) if (tn + fn) else np.nan
    prev = (tp + fn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else np.nan
    
    return dict(
        tp=int(tp), fp=int(fp), tn=int(tn), fn=int(fn),
        sensitivity=float(sens), specificity=float(spec),
        ppv=float(ppv), npv=float(npv), prevalence=float(prev)
    )

# This list will hold all results, one row per group per threshold
subgroup_rows = []

def add_group(group_name, group_values):
    """
    Loops through all unique values in a group (e.g., all 5 SOFA bins),
    calculates metrics for that slice, and appends to the subgroup_rows list.
    """
    series = pd.Series(group_values, index=test_idx).astype(str)
    
    # Iterate over each unique value (e.g., "CSRU", then "non_CSRU")
    for g in sorted(series.unique()):
        
        # Create a boolean mask to select only patients in this group
        mask = (series == g).values
        
        # Slice the test set using the mask
        y_true_g = y_test.values[mask] # True labels for this group
        y_prob_g = y_test_prob[mask]  # Predictions for this group
        
        # Skip if the group is too small to calculate meaningful metrics
        if len(y_true_g) < 10: continue
            
        if isinstance(y_true_g, pd.Series):
            y_true_g = y_true_g.values
        
        # Calculate overall metrics (AUROC, AUPRC) for this subgroup
        # Use try/except in case a subgroup has only 1 class (e.g., all survived)
        try:
            auroc = roc_auc_score(y_true_g, y_prob_g)
        except ValueError:
            auroc = np.nan
        try:
            auprc = average_precision_score(y_true_g, y_prob_g)
        except ValueError:
            auprc = np.nan
            
        # Calculate metrics at the *fixed thresholds* (found in step 4)
        m_sens = subgroup_metrics_fixed_threshold(y_true_g, y_prob_g, thr_for_sens)
        m_spec = subgroup_metrics_fixed_threshold(y_true_g, y_prob_g, thr_for_spec)
        
        # Add two rows to our results: one for each target threshold
        subgroup_rows.append({"subgroup_type": group_name, "subgroup_value": g, "n": int(len(y_true_g)), "auroc": float(auroc), "auprc": float(auprc), "target": "sensitivity", "threshold": float(thr_for_sens), **m_sens})
        subgroup_rows.append({"subgroup_type": group_name, "subgroup_value": g, "n": int(len(y_true_g)), "auroc": float(auroc), "auprc": float(auprc), "target": "specificity", "threshold": float(thr_for_spec), **m_spec})

# Run the subgroup analysis function for our two defined groups
add_group("SOFA_bin", sofa_bins) # This will add ~10 rows (5 bins * 2 thresholds)
add_group("ICU_unit", csru_group) # This will add ~4 rows (2 bins * 2 thresholds)

# Convert the final list of dictionaries into a DataFrame
subgroup_df = pd.DataFrame(subgroup_rows)
# Log the complete subgroup analysis as a W&B Table
wandb.log({"subgroup_metrics_test": wandb.Table(dataframe=subgroup_df)})
print("Logged subgroup metrics for SOFA bins and CSRU")

# --- 6. Finish Comprehensive Run ---
# Mark the W&B run as complete
run.finish()



RF validation AUROC 0.858, AUPRC 0.525, Brier 0.092
RF validation ECE 0.042, Test ECE 0.055
Chosen thresholds -> Sensitivity target: 0.137, Specificity target: 0.240
Logged subgroup metrics for SOFA bins and CSRU


Run summary for `02-baseline-random-forest-full`:

| metric | validation | test |
|---|---|---|
| AUROC (`val_auc` / `test_auc`) | 0.85844 | 0.86647 |
| AUPRC (`val_pr` / `test_pr`) | 0.5251 | 0.5463 |
| Brier (`val_brier` / `test_brier`) | 0.09222 | 0.0905 |
| ECE (`val_ece` / `test_ece`) | 0.04241 | 0.05455 |
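The cell above picks the threshold whose validation sensitivity is *closest* to the target, which can land slightly below it (for example, just under 0.85 when the goal is "at least 85%"). If the clinical requirement is a floor, a variant is to take the highest threshold that still meets the target, which keeps specificity as high as possible among qualifying thresholds. A sketch using `sklearn.metrics.roc_curve`; the helper name `threshold_for_min_sensitivity` and the synthetic scores are illustrative assumptions:

```python
import numpy as np
from sklearn.metrics import roc_curve

def threshold_for_min_sensitivity(y_true, y_prob, target_sens):
    """Hypothetical helper: highest threshold whose sensitivity is >= target_sens.

    Returns (threshold, sensitivity, specificity) at that operating point.
    """
    # drop_intermediate=False keeps every distinct score as a candidate threshold
    fpr, tpr, thresholds = roc_curve(y_true, y_prob, drop_intermediate=False)
    # Thresholds are sorted in decreasing order, so tpr (sensitivity) is non-decreasing;
    # the first index meeting the target is therefore the highest qualifying threshold
    idx = int(np.flatnonzero(tpr >= target_sens)[0])
    return float(thresholds[idx]), float(tpr[idx]), float(1.0 - fpr[idx])

# Synthetic validation data: ~14% positives whose scores run higher on average
rng = np.random.default_rng(0)
y = (rng.random(2000) < 0.14).astype(int)
scores = np.clip(0.14 + 0.25 * y + rng.normal(0, 0.15, 2000), 0, 1)

thr, sens, spec = threshold_for_min_sensitivity(y, scores, target_sens=0.85)
print(f"threshold={thr:.3f} sensitivity={sens:.3f} specificity={spec:.3f}")
```

In the notebook, you would call it on `y_val` and `y_val_prob`, then evaluate the returned threshold on the test set exactly as the original cell does.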


### To Do
- Go to the W&B project page and compare the three baseline runs (LR, DT, RF). The LR and DT runs use `job_type="baseline"` while the RF run uses `job_type="baseline-comprehensive"`, so include both job types when filtering
- Add `val_auc`, `val_pr`, and `val_brier` to the comparison table
- Open the "02-baseline-random-forest-full" run. This run contains everything:
    - **Metrics**: Check `val_pr` and `test_pr`
    - **Plots and tables**: Review the `roc_curve_val` plot and the `calibration_table_val` table
    - **Thresholds**: Inspect the `test_operating_points` table and the `confusion_matrix_test_at_*` plots
    - **Fairness**: Analyze the `subgroup_metrics_test` table. How do `auroc` and `sensitivity` change across SOFA bins?

# 8. Interpretability for understanding

Model interpretability connects predictive performance to clinical meaning:
- Use **Permutation Importance** (the drop in validation score when one feature's values are shuffled) for Logistic Regression, Decision Tree, and Random Forest  
- Use **SHAP** on the Random Forest to see which features drive individual predictions  

All outputs are logged to Weights & Biases for exploration and comparison.
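To make the idea concrete, here is a minimal sketch of what permutation importance measures: shuffle one feature column at a time and record how much a score drops. It assumes synthetic data (the features `signal` and `noise` and the helper `manual_permutation_importance` are made up for illustration), scores with average precision to match this notebook's focus on AUPRC, and, for brevity, scores on the training data rather than a validation set.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score

rng = np.random.default_rng(0)
n = 2000

# Synthetic, illustrative data: only "signal" is related to the outcome
X = pd.DataFrame({"signal": rng.normal(size=n), "noise": rng.normal(size=n)})
p_true = 1.0 / (1.0 + np.exp(-(-2.0 + 1.5 * X["signal"].to_numpy())))
y = (rng.random(n) < p_true).astype(int)

model = LogisticRegression().fit(X, y)


def manual_permutation_importance(model, X, y, n_repeats=10, seed=0):
    """Hypothetical helper: mean drop in average precision when each column is shuffled."""
    perm_rng = np.random.default_rng(seed)
    base = average_precision_score(y, model.predict_proba(X)[:, 1])
    drops = {}
    for col in X.columns:
        scores = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffling breaks the link between this feature and the outcome
            X_perm[col] = perm_rng.permutation(X_perm[col].to_numpy())
            scores.append(average_precision_score(y, model.predict_proba(X_perm)[:, 1]))
        drops[col] = base - float(np.mean(scores))
    return drops


importances = manual_permutation_importance(model, X, y)
print(importances)  # expect a clear drop for "signal" and roughly zero for "noise"
```

scikit-learn's `permutation_importance(model, X, y, scoring="average_precision", n_repeats=10, random_state=0)` computes the same quantity and also reports the spread across repeats, which is why the notebook logs both `importance_mean` and `importance_std`.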

In [14]:
# Interpretability: Permutation Importance and SHAP for Random Forest
# Logs all interpretability outputs to W&B

run = wandb.init(
    project=WB_PROJECT,
    job_type="interpretability",
    name="03-interpretability-report", # Add clean name
    config={
        "models": ["logistic_regression", "decision_tree", "random_forest"],
        "shap_sample_size": 500
    },
    reinit=True,
)

# 1) Permutation Importance for all models on Validation
models = {
    "logistic_regression": log_reg,
    "decision_tree": dt,
    "random_forest": rf
}

for model_name, model in models.items():
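    # Score with average precision (AUPRC) to match the notebook's main metric;
    # for classifiers, scikit-learn's default scorer is accuracy, a poor guide under imbalance
    result = permutation_importance(
        model, X_val_t, y_val, scoring="average_precision",
        n_repeats=10, random_state=SEED, n_jobs=-1
    )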
    imp_df = (
        pd.DataFrame({
            "feature": X_val_t.columns,
            "importance_mean": result.importances_mean,
            "importance_std": result.importances_std
        })
        .sort_values("importance_mean", ascending=False)
        .reset_index(drop=True)
    )
    wandb.log({f"{model_name}_permutation_importance": wandb.Table(dataframe=imp_df)})

# 2) SHAP for Random Forest
shap_sample_n = min(wandb.config.shap_sample_size, len(X_val_t))
shap_sample = X_val_t.sample(n=shap_sample_n, random_state=SEED)

explainer = shap.TreeExplainer(
    rf,
    model_output="raw",
    feature_perturbation="tree_path_dependent"
)

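# Normalize SHAP output across shap versions: older releases return a list with one
# array per class, newer ones a single (samples, features, classes) array; keep class 1 (death)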
shap_values_raw = explainer.shap_values(shap_sample, check_additivity=False)
if isinstance(shap_values_raw, list):
    sv = np.asarray(shap_values_raw[1])
else:
    sv = np.asarray(shap_values_raw)
if sv.ndim == 3:
    cls_axis = sv.shape[-1]
    pick = 1 if cls_axis >= 2 else 0
    sv = sv[..., pick]
sv = np.squeeze(sv)
assert sv.ndim == 2, f"Expected 2D SHAP values, got {sv.shape}"
assert sv.shape[1] == shap_sample.shape[1], "Feature count mismatch"

# Global summary plot
shap.summary_plot(sv, shap_sample, show=False)
wandb.log({"shap_summary_plot": wandb.Image(plt.gcf())})
plt.close()

# Ranked mean absolute SHAP for table
mean_abs_shap = np.abs(sv).mean(axis=0).reshape(-1)
feat_names = list(shap_sample.columns)
shap_df = pd.DataFrame(
    {"feature": feat_names, "mean_abs_shap": mean_abs_shap}
).sort_values("mean_abs_shap", ascending=False).reset_index(drop=True)
wandb.log({"rf_shap_feature_importance": wandb.Table(dataframe=shap_df)})

print("Logging SHAP dependence plots for top 2 features...")
top_features = shap_df["feature"].head(2).tolist()

for feature_name in top_features:
    try:
        fig, ax = plt.subplots()
        shap.dependence_plot(feature_name, sv, shap_sample, ax=ax, show=False, interaction_index=None)
        wandb.log({f"shap_dependence_{feature_name}": wandb.Image(fig)})
        plt.close(fig)
    except Exception as e:
        print(f"Could not plot dependence for {feature_name}: {e}")
        plt.close('all') # Close any open figures to avoid bleed-over

run.finish()



Logging SHAP dependence plots for top 2 features...


### TO DO
- In the `03-interpretability-report` run, compare the permutation importance tables across the three models
- Examine the `shap_summary_plot` to see which features drive RF predictions (e.g., SOFA, CSRU)
- Review the `shap_dependence_...` plots to understand *how* top features impact mortality risk

# 9. Random Forest hyperparameter sweep

We will run a short Weights & Biases Sweep to tune Random Forest hyperparameters with the goal of **maximizing the validation area under the Precision-Recall curve (AUPRC)**.
- Remember that Precision-Recall is more informative than ROC under class imbalance: with few positives, many false positives barely move the false-positive rate, but they pull precision down directly

We log validation AUROC and Brier score as secondary signals for discrimination and calibration.
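A quick way to see why AUPRC is the more demanding metric: for an uninformative model, AUROC sits near 0.5 regardless of class balance, while average precision sits near the prevalence of the positive class. The sketch below uses random scores and an assumed prevalence of 14% (an illustrative value, not computed from the dataset).

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

rng = np.random.default_rng(42)
n = 20_000
prevalence = 0.14  # illustrative minority-class rate

y_true = (rng.random(n) < prevalence).astype(int)
y_random = rng.random(n)  # scores carry no information about y_true

random_auroc = roc_auc_score(y_true, y_random)
random_ap = average_precision_score(y_true, y_random)
observed_prevalence = y_true.mean()

print(f"prevalence   {observed_prevalence:.3f}")
print(f"random AUROC {random_auroc:.3f}  (chance level ~0.5)")
print(f"random AUPRC {random_ap:.3f}  (chance level ~prevalence)")
```

This is why `val_pr` values in the sweep should be judged against the outcome prevalence, not against 0.5.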


In [15]:
# Random Forest sweep optimized for imbalanced data using validation AUPRC

# Sweep training function
def train_rf_sweep():
    # Use job_type to group sweep agents
    run = wandb.init(project=WB_PROJECT, job_type="sweep-agent", reinit=True)
    cfg = wandb.config

    model = RandomForestClassifier(
        n_estimators=cfg.n_estimators,
        max_depth=None if cfg.max_depth == 0 else cfg.max_depth,
        max_features=cfg.max_features,
        min_samples_leaf=cfg.min_samples_leaf,
        class_weight="balanced",   # upweight the minority class; may inflate probabilities and worsen Brier/calibration
        random_state=SEED,
        n_jobs=-1
    )
    model.fit(X_train_t, y_train)

    # Validation probabilities
    y_val_prob = model.predict_proba(X_val_t)[:, 1]

    # Metrics
    val_pr   = average_precision_score(y_val, y_val_prob)     # sweep objective
    val_auc  = roc_auc_score(y_val, y_val_prob)
    val_brier = brier_score_loss(y_val, y_val_prob)

    # Log to W&B
    wandb.log({
        "val_pr": val_pr,
        "val_auc": val_auc,
        "val_brier": val_brier
    })
    run.finish()

# Compact search space
sweep_config = {
    "name": "rf_pr_tuning_v2",  # sweep name shown in the W&B UI
    "method": "bayes",
    "metric": {"name": "val_pr", "goal": "maximize"},
    "parameters": {
        "n_estimators": {"values": [100, 200, 300, 500]},
        "max_depth": {"values": [0, 8, 12, 16]},          # 0 means None
        "max_features": {"values": ["sqrt", "log2"]},
        "min_samples_leaf": {"values": [1, 3, 5, 10]}
    },
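    # Each trial logs its metrics once, so Hyperband never sees the intermediate
    # steps it needs (min_iter=3) and will not stop any trial early here
    "early_terminate": {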
        "type": "hyperband",
        "min_iter": 3
    }
}

# Launch the sweep
sweep_id = wandb.sweep(sweep_config, project=WB_PROJECT)
print(f"Sweep started. ID: {sweep_id}")

# Run one agent that executes 5 trials (count=5) for the demo
wandb.agent(sweep_id, function=train_rf_sweep, count=5)

print("Sweep complete. In Weights & Biases, sort by val_pr and inspect Parameter Importance.")

Create sweep with ID: be45g8em
Sweep URL: https://wandb.ai/idiazl/ml-healthcare-intro/sweeps/be45g8em
Sweep started. ID: be45g8em


wandb: Agent Starting Run: w6tf2l0o with config:
wandb: 	max_depth: 0
wandb: 	max_features: log2
wandb: 	min_samples_leaf: 5
wandb: 	n_estimators: 300


W&B run summary:

| Metric | Value |
|---|---|
| val_auc | 0.86573 |
| val_brier | 0.0996 |
| val_pr | 0.49982 |


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: db11b3dv with config:
wandb: 	max_depth: 16
wandb: 	max_features: sqrt
wandb: 	min_samples_leaf: 10
wandb: 	n_estimators: 200


W&B run summary:

| Metric | Value |
|---|---|
| val_auc | 0.86054 |
| val_brier | 0.10989 |
| val_pr | 0.51334 |


wandb: Agent Starting Run: mggpjllp with config:
wandb: 	max_depth: 8
wandb: 	max_features: log2
wandb: 	min_samples_leaf: 10
wandb: 	n_estimators: 200


W&B run summary:

| Metric | Value |
|---|---|
| val_auc | 0.85894 |
| val_brier | 0.12475 |
| val_pr | 0.48954 |


wandb: Agent Starting Run: 2y77c2iv with config:
wandb: 	max_depth: 16
wandb: 	max_features: sqrt
wandb: 	min_samples_leaf: 10
wandb: 	n_estimators: 300


W&B run summary:

| Metric | Value |
|---|---|
| val_auc | 0.86518 |
| val_brier | 0.10901 |
| val_pr | 0.51956 |


wandb: Agent Starting Run: hau3gvc4 with config:
wandb: 	max_depth: 16
wandb: 	max_features: sqrt
wandb: 	min_samples_leaf: 10
wandb: 	n_estimators: 500


W&B run summary:

| Metric | Value |
|---|---|
| val_auc | 0.8673 |
| val_brier | 0.10815 |
| val_pr | 0.525 |


Sweep complete. In Weights & Biases, sort by val_pr and inspect Parameter Importance.


### TO DO
- Open the W&B Sweep link (e.g., `rf_pr_tuning_v2`)
- View the Parallel Coordinates and Parameter Importance plots to see which hyperparameters matter most
- Sort the sweep table by `val_pr` (descending) to find the best run

# 10. Best model selection, Test evaluation, and Model Registry

Let's pick the Random Forest configuration with the highest validation AUPRC (`val_pr`) from the sweep.
- We will use the **W&B API** to programmatically fetch the best run's config
- We refit that model on Train + Validation
- We evaluate on Test
- We log the final model as a **W&B Artifact** and register it in the **Model Registry**
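As a cross-check that does not depend on the W&B API, the five sweep trials printed above can be collected into a DataFrame and ranked by `val_pr` directly. The values below are copied from this notebook's sweep output; the `to_rf_params` helper is hypothetical and mirrors the `max_depth == 0 -> None` convention used in `train_rf_sweep`.

```python
import pandas as pd

# Configs and validation metrics copied from the sweep log above
sweep_results = pd.DataFrame([
    {"run": "w6tf2l0o", "max_depth": 0, "max_features": "log2", "min_samples_leaf": 5,
     "n_estimators": 300, "val_pr": 0.49982, "val_auc": 0.86573, "val_brier": 0.0996},
    {"run": "db11b3dv", "max_depth": 16, "max_features": "sqrt", "min_samples_leaf": 10,
     "n_estimators": 200, "val_pr": 0.51334, "val_auc": 0.86054, "val_brier": 0.10989},
    {"run": "mggpjllp", "max_depth": 8, "max_features": "log2", "min_samples_leaf": 10,
     "n_estimators": 200, "val_pr": 0.48954, "val_auc": 0.85894, "val_brier": 0.12475},
    {"run": "2y77c2iv", "max_depth": 16, "max_features": "sqrt", "min_samples_leaf": 10,
     "n_estimators": 300, "val_pr": 0.51956, "val_auc": 0.86518, "val_brier": 0.10901},
    {"run": "hau3gvc4", "max_depth": 16, "max_features": "sqrt", "min_samples_leaf": 10,
     "n_estimators": 500, "val_pr": 0.525, "val_auc": 0.8673, "val_brier": 0.10815},
])


def to_rf_params(row):
    """Hypothetical helper: map one sweep row to RandomForestClassifier kwargs."""
    return {
        "n_estimators": int(row["n_estimators"]),
        "max_depth": None if row["max_depth"] == 0 else int(row["max_depth"]),  # 0 means None
        "max_features": row["max_features"],
        "min_samples_leaf": int(row["min_samples_leaf"]),
        "class_weight": "balanced",  # fixed in the sweep
    }


best_row = sweep_results.sort_values("val_pr", ascending=False).iloc[0]
best_params = to_rf_params(best_row)
print(best_row["run"], best_params)
```

This configuration differs from the fallback parameters in the cell below (`n_estimators=300, max_depth=12, min_samples_leaf=3`), which are what the final run actually used because the API lookup failed.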

In [None]:
# Final Random Forest evaluation using the best sweep config

# --- 1. Fetch Best Run from Sweep ---
print("Initializing W&B API to find best sweep run...")
api = WandbApi()

# We need the full sweep path: f"{entity}/{project}/{sweep_id}"
# We can get the entity by starting a temporary run
try:
    temp_run = wandb.init(project=WB_PROJECT, job_type="api_helper", reinit=True)
    ENTITY = temp_run.entity
    temp_run.finish()
    
    sweep_path = f"{ENTITY}/{WB_PROJECT}/{sweep_id}"
    print(f"Accessing sweep at: {sweep_path}")
    sweep = api.sweep(sweep_path)
    
    best_run = sweep.best_run()
    print(f"Found best run: {best_run.name} with val_pr: {best_run.summary['val_pr']:.4f}")

    # --- 2. Get Best Parameters ---
    best_params_config = best_run.config
    
    # Re-create the logic from the sweep function (e.g., max_depth=0 -> None)
    rf_params = {
        "n_estimators": best_params_config["n_estimators"],
        "max_depth": None if best_params_config["max_depth"] == 0 else best_params_config["max_depth"],
        "max_features": best_params_config["max_features"],
        "min_samples_leaf": best_params_config["min_samples_leaf"],
        "class_weight": "balanced" # This was fixed in our sweep
    }

except Exception as e:
    print(f"Error fetching sweep data. Using fallback parameters. Error: {e}")
    # Fallback in case API fails in a restricted environment
    rf_params = {
        "n_estimators": 300, "max_depth": 12, "max_features": "sqrt",
        "min_samples_leaf": 3, "class_weight": "balanced"
    }
    best_run = None # Flag that we used fallback

# --- 3. Start Final Evaluation Run ---
run = wandb.init(
    project=WB_PROJECT,
    job_type="final_eval",
    name=f"04-final-model-{'fallback' if best_run is None else best_run.name}",
    config=rf_params, # Log the actual params used
    reinit=True,
)
if best_run:
    wandb.config.update({"source_sweep": sweep_path, "source_run_id": best_run.id})

# `calibration_table` is defined in Section 2 and returns (calibration table, ECE)

# --- 4. Retrain Model on Train+Val ---
print("Retraining best model on Train + Validation data...")
X_train_full_t = pd.concat([X_train_t, X_val_t])
y_train_full = pd.concat([y_train, y_val])

rf_best = RandomForestClassifier(
    **rf_params, # Unpack the fetched params
    random_state=SEED,
    n_jobs=-1
)
rf_best.fit(X_train_full_t, y_train_full)

# --- 5. Evaluate and Log on Test ---
y_test_prob = rf_best.predict_proba(X_test_t)[:, 1]

test_pr = average_precision_score(y_test, y_test_prob)
test_auc = roc_auc_score(y_test, y_test_prob)
test_brier = brier_score_loss(y_test, y_test_prob)

wandb.log({
    "test_pr": test_pr,
    "test_auc": test_auc,
    "test_brier": test_brier
})

# Log final test ROC and PR curves (W&B plots expect one probability column per class)
y_test_probas_2d = np.stack([1.0 - y_test_prob, y_test_prob], axis=1)
wandb.log({
    "roc_curve_test_final": wandb_roc_curve(y_test.values, y_test_probas_2d),
    "pr_curve_test_final": wandb_pr_curve(y_test.values, y_test_probas_2d)
})

# Log calibration table and predictions
cal_test_tbl, test_ece = calibration_table(y_test, y_test_prob)
wandb.log({
    "test_ece_final": test_ece,
    "calibration_table_test_final": wandb.Table(dataframe=cal_test_tbl)
})
pred_tbl = pd.DataFrame({"id": X_test.index, "y_true": y_test.values, "y_prob": y_test_prob})
wandb.log({"final_test_predictions": wandb.Table(dataframe=pred_tbl)})

print(f"Final Test AUROC {test_auc:.3f}, AUPRC {test_pr:.3f}, Brier {test_brier:.3f}, ECE {test_ece:.3f}")


# --- 6. Subgroup (fairness) analysis on the final model ---
print("Running subgroup (fairness) analysis on *final* tuned model...")

# Helper functions (copied from Cell 7)
def subgroup_metrics_fixed_threshold(y_true, y_prob, thr):
    if isinstance(y_true, pd.Series):
        y_true = y_true.values
    y_pred = (y_prob >= thr).astype(int)
    try:
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    except ValueError: 
        tn, fp, fn, tp = 0, 0, 0, 0
        if len(y_pred) > 0:
            tn = np.sum((y_true == 0) & (y_pred == 0))
            fp = np.sum((y_true == 0) & (y_pred == 1))
            fn = np.sum((y_true == 1) & (y_pred == 0))
            tp = np.sum((y_true == 1) & (y_pred == 1))
    sens = tp / (tp + fn) if (tp + fn) else np.nan
    spec = tn / (tn + fp) if (tn + fp) else np.nan
    ppv  = tp / (tp + fp) if (tp + fp) else np.nan
    npv  = tn / (tn + fn) if (tn + fn) else np.nan
    prev = (tp + fn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else np.nan
    return dict(
        tp=int(tp), fp=int(fp), tn=int(tn), fn=int(fn),
        sensitivity=float(sens), specificity=float(spec),
        ppv=float(ppv), npv=float(npv), prevalence=float(prev)
    )

subgroup_rows_final = [] # Use a new list name
def add_group_final(group_name, group_values, y_true_all, y_prob_all, thr_sens, thr_spec):
    series = pd.Series(group_values, index=y_true_all.index).astype(str)
    for g in sorted(series.unique()):
        mask = (series == g).values
        y_true_g = y_true_all.values[mask]
        y_prob_g = y_prob_all[mask]
        if len(y_true_g) < 10:
            continue  # skip subgroups too small for stable metrics
        try:
            auroc = roc_auc_score(y_true_g, y_prob_g)
        except ValueError:
            auroc = np.nan
        try:
            auprc = average_precision_score(y_true_g, y_prob_g)
        except ValueError:
            auprc = np.nan
        m_sens = subgroup_metrics_fixed_threshold(y_true_g, y_prob_g, thr_sens)
        m_spec = subgroup_metrics_fixed_threshold(y_true_g, y_prob_g, thr_spec)
        subgroup_rows_final.append({"subgroup_type": group_name, "subgroup_value": g, "n": int(len(y_true_g)), "auroc": float(auroc), "auprc": float(auprc), "target": "sensitivity", "threshold": float(thr_sens), **m_sens})
        subgroup_rows_final.append({"subgroup_type": group_name, "subgroup_value": g, "n": int(len(y_true_g)), "auroc": float(auroc), "auprc": float(auprc), "target": "specificity", "threshold": float(thr_spec), **m_spec})

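# --- Run Subgroup Analysis ---
# Caveat: thr_for_sens / thr_for_spec were chosen on validation for the baseline RF.
# The final model uses different hyperparameters and is refit on Train+Val, so its
# probability scale may differ and realized sensitivity/specificity can drift from the targets.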
try:
    _ = thr_for_sens
    _ = thr_for_spec
except NameError:
    print("Warning: Thresholds from Cell 7 not found. Using default 0.5.")
    thr_for_sens = 0.5
    thr_for_spec = 0.5

# Get subgroup features (same as Cell 7)
SOFA_COL = "SOFA"
CSRU_COL = "CSRU"
test_idx = X_test.index
sofa_test = ICU.loc[test_idx, SOFA_COL]
csru_test = ICU.loc[test_idx, CSRU_COL]
sofa_bins = pd.qcut(sofa_test, q=5, duplicates="drop").astype(str)
csru_group = np.where(pd.to_numeric(csru_test, errors="coerce").fillna(0).astype(int) == 1, "CSRU", "non_CSRU")

# Run the analysis
add_group_final("SOFA_bin", sofa_bins, y_test, y_test_prob, thr_for_sens, thr_for_spec)
add_group_final("ICU_unit", csru_group, y_test, y_test_prob, thr_for_sens, thr_for_spec)

# Log the final subgroup analysis as a W&B Table
subgroup_df_final = pd.DataFrame(subgroup_rows_final)
wandb.log({"subgroup_metrics_test_FINAL": wandb.Table(dataframe=subgroup_df_final)})
print("Logged *final* subgroup metrics for SOFA bins and CSRU")


# --- 7. Log Model as Artifact and Register ---
print("Logging model to W&B Artifacts and Model Registry...")
model_file = "final_rf_model.joblib"
joblib.dump(rf_best, model_file)

# Define the artifact
model_at = wandb.Artifact(
    "best-rf-model",
    type="model",
    description="Final Random Forest model trained on train+val, tuned for AUPRC.",
    metadata=rf_params
)
model_at.add_file(model_file)

# Log the artifact to the run
run.log_artifact(model_at, aliases=["production_candidate"])

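# Register the model by linking the artifact to a registered-model collection.
# The target-path format depends on your W&B version (the legacy Model Registry
# documents "model-registry/<name>"); check the Models tab to confirm where it lands.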
try:
    run.link_artifact(model_at, f"{WB_PROJECT}/ICU_Mortality_RF_Model")
    print("Successfully logged and registered model artifact")
except Exception as e:
    print(f"Note: Could not auto-register model. Logged artifact instead. Error: {e}")

run.finish()

Initializing W&B API to find best sweep run...


Accessing sweep at: idiazl/ml-healthcare-intro/be45g8em


wandb: Sorting runs by -summary_metrics.val_pr


Error fetching sweep data. Using fallback parameters. Error: string indices must be integers, not 'str'


Retraining best model on Train + Validation data...
Final Test AUROC 0.885, AUPRC 0.589, Brier 0.092, ECE 0.079
Running subgroup (fairness) analysis on *final* tuned model...
Logged *final* subgroup metrics for SOFA bins and CSRU
Logging model to W&B Artifacts and Model Registry...
Successfully logged and registered model artifact


W&B run summary:

| Metric | Value |
|---|---|
| test_auc | 0.88546 |
| test_brier | 0.09191 |
| test_ece_final | 0.07939 |
| test_pr | 0.58917 |


### TO DO
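- Open the `04-final-model-...` run (job type `final_eval`). If its name ends in `fallback`, as in the output above, the sweep lookup failed and the model was trained with the fallback parameters rather than the best sweep configuration
- Compare its `test_pr` metric against the `02-baseline-random-forest-full` run. Keep in mind that the final model is also refit on Train + Validation, so the difference reflects both the hyperparameters and the extra training data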
- Go to the **Artifacts** tab in the W&B project (left-hand sidebar)
- Find the `best-rf-model` artifact and inspect its contents and metadata
- Go to the **Models** tab to see the registered `ICU_Mortality_RF_Model` and its version history

# 11. Final Documentation: Model Card and Reporting

The final step in a responsible ML workflow is documentation. This ensures that the model's performance, limitations, and intended use are understood by all stakeholders (e.g., clinicians, regulators, engineers).

In Weights & Biases, you can create a **W&B Report** to weave together your findings into a comprehensive model card.

### A good report for this project would include:

-   **Objective**: The clinical goal (predicting in-hospital mortality)
-   **Data**: A link to the `01-data-exploration` run, showing class balance and missingness
-   **Models**: The "Model Comparison" table (from comparing `baseline` runs) showing why Random Forest was chosen
-   **Tuning**: Key visualizations from the `rf_pr_tuning_v2` sweep
-   **Final Performance**: Final metrics from the `04-final-model-...` run (Test AUROC, AUPRC, Brier, ECE)
-   **Interpretability**: The `shap_summary_plot` and `shap_dependence_...` plots from the `03-interpretability-report` run
-   **Fairness & Safety**: The `subgroup_metrics_test` (baseline) and `subgroup_metrics_test_FINAL` (final model) tables, discussing performance across SOFA bins and CSRU vs non-CSRU units
-   **Clinical Use**: The `test_operating_points` table, explaining the trade-offs between the sensitivity and specificity thresholds

You can also attach this information directly to the registered model in the W&B Model Registry UI. This creates a "model card" that lives with the model, ensuring downstream users understand its strengths and limitations before deployment.
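If you also want a plain-text version of the card to version alongside the model, a few lines of Python can render one. This is a minimal sketch: the section layout and the `render_model_card` helper are hypothetical, the metric values are copied from the final evaluation output above, and the limitations restate caveats from this notebook. The resulting file could be added to the model artifact with `model_at.add_file(...)` before the artifact is logged.

```python
from pathlib import Path


def render_model_card(name, objective, metrics, limitations):
    """Hypothetical helper: render a minimal model card as Markdown text."""
    lines = [f"# Model card: {name}", "", "## Objective", objective, "", "## Test performance", ""]
    lines += ["| Metric | Value |", "|---|---|"]
    lines += [f"| {metric} | {value:.3f} |" for metric, value in metrics.items()]
    lines += ["", "## Known limitations"]
    lines += [f"- {item}" for item in limitations]
    return "\n".join(lines) + "\n"


card = render_model_card(
    name="ICU_Mortality_RF_Model",
    objective="Predict in-hospital mortality (In-hospital_death) for ICU patients.",
    # Values copied from the final test evaluation printed above
    metrics={"AUROC": 0.885, "AUPRC": 0.589, "Brier": 0.092, "ECE": 0.079},
    limitations=[
        "Final run used fallback hyperparameters because the sweep API lookup failed.",
        "Operating thresholds were chosen for the baseline Random Forest, not the final model.",
        "Subgroup results cover only SOFA bins and CSRU vs non-CSRU units.",
    ],
)
Path("model_card.md").write_text(card, encoding="utf-8")
print(card)
```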