# Module 4: Evaluating ML Models in Medicine
## Beyond Accuracy: Thresholds, Capacity, and Clinical Trade-offs

**Goal:** Learn how evaluation changes when model outputs are used for real decisions in a clinical workflow.

### How to use this notebook
- Run cells from top to bottom.
- Change thresholds and policy settings with widgets.
- Focus on how decisions shift when the clinical context changes.

### Learning objectives
1. Compare model performance using sensitivity, specificity, precision, F1, ROC-AUC, and PR-AUC.
2. Understand why threshold choice is a policy decision, not just a technical default.
3. Simulate resource-limited triage where only top-risk patients can be flagged.
4. Use simple cost assumptions to pick an operating threshold.
5. Produce a transparent policy summary for team discussion.

## Section 0: Clinical Problem
A readmission model gives each patient a risk probability.
The clinical team must decide **where to set the action threshold** and **how many patients can receive interventions**.
This module focuses on the decision layer after modeling.

## Helper Functions
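Run this cell once at the start. In Google Colab it clones the course repository (if it is not already present) and switches to the notebook directory. On local Jupyter it only prints the current working directory and changes nothing.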

In [None]:
import os
import sys
import subprocess
from pathlib import Path

def setup_repo_for_colab(
    repo_url='https://github.com/aaekay/Medical-AI-101.git',
    repo_dir='/content/Medical-AI-101',
    notebook_dir='chapters',
):
    if 'google.colab' not in sys.modules:
        print(f'Local runtime detected. Working directory: {Path.cwd()}')
        return

    repo_path = Path(repo_dir)
    if not repo_path.exists():
        # Report the actual source and target instead of a hard-coded path.
        print(f'Cloning {repo_url} into {repo_path} ...')
        subprocess.check_call(['git', 'clone', repo_url, str(repo_path)])

    target = repo_path / notebook_dir
    os.chdir(target)
    print(f'Colab ready. Working directory: {Path.cwd()}')

setup_repo_for_colab()


In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from IPython.display import display
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)

try:
    import ipywidgets as widgets
except ImportError as exc:
    raise ImportError('ipywidgets is required for interactive demos in this notebook.') from exc


def resolve_data_path(filename):
    candidates = [Path('../data') / filename, Path('data') / filename]
    for cand in candidates:
        if cand.exists():
            return cand
    raise FileNotFoundError(f'Could not locate {filename} in ../data or data/. Run Module 3 first.')


PRED_PATH = resolve_data_path('module_03_test_predictions.csv')
df_pred = pd.read_csv(PRED_PATH)
print(f'Loaded {len(df_pred)} prediction rows from {PRED_PATH}.')


In [None]:
required_cols = [
    'encounter_id',
    'actual_readmit_30d',
    'lr_probability',
    'tree_probability',
]
missing = [c for c in required_cols if c not in df_pred.columns]
if missing:
    raise ValueError(f'Missing required columns: {missing}')

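for col in ['lr_probability', 'tree_probability']:
    # NaN fails both range comparisons below, so check for missing values explicitly.
    if df_pred[col].isna().any():
        raise ValueError(f'{col} has missing values.')
    if ((df_pred[col] < 0) | (df_pred[col] > 1)).any():
        raise ValueError(f'{col} has values outside [0, 1].')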

prevalence = df_pred['actual_readmit_30d'].mean()
print(f"Outcome prevalence in this test set: {prevalence:.3f}")
display(df_pred.head(8))


## Section 1: Metrics at a Default Threshold (0.50)
A 0.50 cutoff is the usual default in quick evaluations. Later sections show why a fixed 0.50 is rarely the right operating point in clinical operations.
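
As a small illustration of why accuracy at one cutoff can mislead, the sketch below uses made-up labels and probabilities (not the course dataset). The helper name `confusion_counts` and the toy values are assumptions chosen for this example only.

```python
def confusion_counts(y_true, y_prob, threshold):
    """Return (tp, fp, fn, tn) when flagging probabilities >= threshold."""
    tp = fp = fn = tn = 0
    for y, p in zip(y_true, y_prob):
        flagged = p >= threshold
        if flagged and y == 1:
            tp += 1
        elif flagged and y == 0:
            fp += 1
        elif y == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


# Toy data (illustrative): 10 patients, 2 readmitted (20% prevalence).
toy_y = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
toy_p = [0.45, 0.30, 0.40, 0.20, 0.15, 0.10, 0.10, 0.05, 0.05, 0.02]

tp, fp, fn, tn = confusion_counts(toy_y, toy_p, threshold=0.50)
accuracy = (tp + tn) / len(toy_y)
sensitivity = tp / (tp + fn)
print(f"accuracy={accuracy:.2f}, sensitivity={sensitivity:.2f}")
# prints: accuracy=0.80, sensitivity=0.00
```

In this toy case no patient reaches 0.50, so the model looks 80% accurate while catching none of the readmissions. Lowering the cutoff to 0.25 would flag three patients and catch both.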

In [None]:
y_true = df_pred['actual_readmit_30d'].astype(int).values
model_probs = {
    'Logistic Regression': df_pred['lr_probability'].values,
    'Decision Tree': df_pred['tree_probability'].values,
}


def metric_bundle(y, proba, threshold):
    """Threshold-dependent counts and rates, plus threshold-free ROC-AUC and PR-AUC."""
    pred = (proba >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y, pred, labels=[0, 1]).ravel()

    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total else np.nan
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    # Use 0.0 (not NaN) when F1 is undefined, matching precision and sensitivity above.
    f1 = 2 * precision * sensitivity / (precision + sensitivity) if (precision + sensitivity) else 0.0
    roc_auc = roc_auc_score(y, proba)
    pr_auc = average_precision_score(y, proba)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'sensitivity_recall': sensitivity,
        'specificity': specificity,
        'f1': f1,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'tp': tp,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'flag_rate': pred.mean(),
    }


default_rows = []
for model_name, proba in model_probs.items():
    row = metric_bundle(y_true, proba, threshold=0.50)
    row['model'] = model_name
    default_rows.append(row)

default_metrics = pd.DataFrame(default_rows).set_index('model')
display(default_metrics[['accuracy', 'precision', 'sensitivity_recall', 'specificity', 'f1', 'roc_auc', 'pr_auc', 'flag_rate']].round(3))


## Section 2: ROC and Precision-Recall Curves
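The ROC curve summarizes how well a model ranks readmitted patients above non-readmitted ones across all thresholds. The precision-recall (PR) curve is often more informative when positives are rare, because precision shows what fraction of flagged patients are truly readmitted. Its chance-level baseline is the outcome prevalence, not a fixed diagonal.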
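
One way to read ROC-AUC as a ranking measure is as the fraction of (readmitted, not readmitted) patient pairs in which the readmitted patient gets the higher score. The sketch below computes that directly on made-up scores. `pairwise_auc` is a hypothetical helper for illustration, not part of scikit-learn, although on the same inputs it should agree with `roc_auc_score`.

```python
def pairwise_auc(y_true, y_score):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data (illustrative): 2 positives, 4 negatives.
toy_y = [1, 0, 1, 0, 0, 0]
toy_s = [0.9, 0.8, 0.6, 0.4, 0.3, 0.1]

auc = pairwise_auc(toy_y, toy_s)
print(f"pairwise AUC = {auc:.3f}")  # 7 of 8 pairs ranked correctly -> 0.875

# A scorer that gives everyone the same score cannot rank at all.
print(pairwise_auc(toy_y, [0.5] * len(toy_y)))  # 0.5
```

Note that this quantity depends only on the ordering of scores, which is why a model can have a high AUC and still need a carefully chosen threshold.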

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4.2))

# ROC
for model_name, proba in model_probs.items():
    fpr, tpr, _ = roc_curve(y_true, proba)
    axes[0].plot(fpr, tpr, linewidth=2, label=f"{model_name} (AUC={roc_auc_score(y_true, proba):.2f})")
axes[0].plot([0, 1], [0, 1], '--', color='gray', linewidth=1)
axes[0].set_title('ROC Curve')
axes[0].set_xlabel('False Positive Rate')
axes[0].set_ylabel('True Positive Rate (Sensitivity)')
axes[0].legend()

# PR
baseline = y_true.mean()
for model_name, proba in model_probs.items():
    precision, recall, _ = precision_recall_curve(y_true, proba)
    ap = average_precision_score(y_true, proba)
    axes[1].plot(recall, precision, linewidth=2, label=f"{model_name} (AP={ap:.2f})")
axes[1].hlines(baseline, 0, 1, linestyles='--', color='gray', linewidth=1, label=f'Prevalence={baseline:.2f}')
axes[1].set_title('Precision-Recall Curve')
axes[1].set_xlabel('Recall')
axes[1].set_ylabel('Precision')
axes[1].legend()

plt.tight_layout()
plt.show()


## Section 3: Interactive Threshold Explorer
Move the threshold slider and watch how clinical workload (patients flagged per 100 discharges) and missed readmissions (false negatives per 100 discharges) change.
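
If widgets are unavailable, the same trade-off can be read from a static sweep. The sketch below uses made-up labels and probabilities, and `workload_and_misses` is a hypothetical helper name for this example.

```python
def workload_and_misses(y_true, y_prob, threshold):
    """Return (flagged per 100 patients, missed positives per 100 patients)."""
    n = len(y_true)
    flagged = sum(1 for p in y_prob if p >= threshold)
    missed = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p < threshold)
    return 100 * flagged / n, 100 * missed / n


# Toy data (illustrative): 10 patients, 3 readmitted.
toy_y = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
toy_p = [0.80, 0.55, 0.30, 0.60, 0.35, 0.25, 0.20, 0.15, 0.10, 0.05]

for t in (0.25, 0.50, 0.75):
    flagged_per_100, missed_per_100 = workload_and_misses(toy_y, toy_p, t)
    print(f"threshold={t:.2f}  flagged/100={flagged_per_100:.0f}  missed/100={missed_per_100:.0f}")
# threshold=0.25  flagged/100=60  missed/100=0
# threshold=0.50  flagged/100=30  missed/100=10
# threshold=0.75  flagged/100=10  missed/100=20
```

Raising the threshold reduces workload but increases missed readmissions, which is the trade-off the slider lets you explore on the real predictions.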

In [None]:
def threshold_explorer(model='Logistic Regression', threshold=0.50):
    proba = model_probs[model]
    m = metric_bundle(y_true, proba, threshold)

    flagged_per_100 = m['flag_rate'] * 100
    missed_per_100 = (m['fn'] / len(y_true)) * 100

    print(f'Model: {model}')
    print(f'Threshold: {threshold:.2f}')
    print(
        f"Sensitivity: {m['sensitivity_recall']:.3f} | Specificity: {m['specificity']:.3f} | "
        f"Precision: {m['precision']:.3f}"
    )
    print(f'Flagged for intervention: {flagged_per_100:.1f} per 100 discharges')
    print(f'Missed true readmissions (false negatives): {missed_per_100:.1f} per 100 discharges')

    counts = pd.Series({'TP': m['tp'], 'TN': m['tn'], 'FP': m['fp'], 'FN': m['fn']})

    fig, ax = plt.subplots(figsize=(6.2, 3.2))
    ax.bar(counts.index, counts.values, color=['#54a24b', '#4c78a8', '#f58518', '#e45756'])
    ax.set_title('Confusion Counts at Selected Threshold')
    ax.set_ylabel('Patients')
    plt.tight_layout()
    plt.show()


widgets.interact(
    threshold_explorer,
    model=widgets.Dropdown(options=list(model_probs.keys()), value='Logistic Regression', description='Model'),
    threshold=widgets.FloatSlider(value=0.50, min=0.05, max=0.95, step=0.05, description='Threshold', continuous_update=False),
)


## Section 4: Capacity-Constrained Triage
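Some teams can only review a fixed percentage of patients, regardless of how many exceed a risk threshold.
Here we simulate a top-K policy: rank patients by predicted risk and flag only the K highest, where K is set by the team's capacity.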
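
The ranking step can be sketched on a handful of made-up patients as follows. The helper `top_k_flags` and the toy values are assumptions for illustration, and ties are broken by original order here, which is one of several reasonable choices.

```python
import math


def top_k_flags(y_prob, capacity_pct):
    """Flag the ceil(capacity_pct% of n) highest-risk patients; at least one."""
    n = len(y_prob)
    # Multiply before dividing so exact percentages are not rounded up by float error.
    k = max(1, math.ceil(capacity_pct * n / 100))
    # sorted() is stable, so tied risks keep their original order.
    ranked = sorted(range(n), key=lambda i: -y_prob[i])
    chosen = set(ranked[:k])
    return [1 if i in chosen else 0 for i in range(n)]


# Toy data (illustrative): 10 patients, 3 readmitted.
toy_y = [0, 1, 0, 1, 0, 0, 1, 0, 0, 0]
toy_p = [0.10, 0.90, 0.70, 0.40, 0.20, 0.05, 0.60, 0.30, 0.15, 0.50]

flags = top_k_flags(toy_p, capacity_pct=30)
tp = sum(1 for f, y in zip(flags, toy_y) if f == 1 and y == 1)
recall = tp / sum(toy_y)
precision = tp / sum(flags)
print(f"flagged={sum(flags)}  recall={recall:.2f}  precision={precision:.2f}")
# prints: flagged=3  recall=0.67  precision=0.67
```

Unlike a probability threshold, this policy always flags the same number of patients, so recall depends entirely on how well the model ranks true readmissions into the top K.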

In [None]:
def topk_policy_metrics(y, proba, capacity_pct):
    n = len(y)
    # Multiply before dividing so float error cannot push an exact count up by one in ceil.
    k = max(1, int(np.ceil(capacity_pct * n / 100.0)))

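    # Stable sort keeps tied probabilities in original row order, so results are reproducible.
    order = np.argsort(-proba, kind='stable')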
    flagged = np.zeros(n, dtype=int)
    flagged[order[:k]] = 1

    tn, fp, fn, tp = confusion_matrix(y, flagged, labels=[0, 1]).ravel()
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0

    return {
        'capacity_pct': capacity_pct,
        'k_flagged': k,
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'tn': tn,
        'recall': recall,
        'precision': precision,
        'flag_rate': flagged.mean(),
    }


def capacity_explorer(model='Logistic Regression', capacity_pct=20):
    proba = model_probs[model]
    m = topk_policy_metrics(y_true, proba, capacity_pct)

    print(f'Model: {model}')
    print(f'Capacity: {capacity_pct:.0f}% of patients ({m["k_flagged"]} out of {len(y_true)})')
    print(f"Recall captured by top-K: {m['recall']:.3f}")
    print(f"Precision among flagged: {m['precision']:.3f}")

    bars = pd.Series({'TP captured': m['tp'], 'FP flagged': m['fp'], 'FN missed': m['fn']})
    fig, ax = plt.subplots(figsize=(6.2, 3.2))
    ax.bar(bars.index, bars.values, color=['#54a24b', '#f58518', '#e45756'])
    ax.set_title('Top-K Capacity Policy Outcome Counts')
    ax.set_ylabel('Patients')
    ax.tick_params(axis='x', rotation=10)
    plt.tight_layout()
    plt.show()


widgets.interact(
    capacity_explorer,
    model=widgets.Dropdown(options=list(model_probs.keys()), value='Logistic Regression', description='Model'),
    capacity_pct=widgets.IntSlider(value=20, min=5, max=80, step=5, description='Capacity %', continuous_update=False),
)


## Section 5: Utility-Based Threshold Selection
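Choose the threshold that minimizes expected policy cost per patient, where false negatives and false positives carry different penalties. Costs are in arbitrary relative units: scaling both by the same factor does not change which threshold is chosen, so only the ratio of FN cost to FP cost matters.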
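
The cost calculation is simply (FN count × FN cost + FP count × FP cost) divided by the number of patients, evaluated at each candidate threshold. A minimal sketch on made-up data follows. `cost_per_patient`, the toy values, and the coarse four-point grid are assumptions for illustration only.

```python
def cost_per_patient(y_true, y_prob, threshold, fn_cost, fp_cost):
    """Average cost when flagging probabilities >= threshold."""
    fn = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p < threshold)
    fp = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p >= threshold)
    return (fn * fn_cost + fp * fp_cost) / len(y_true)


# Toy data (illustrative): 10 patients, 3 readmitted.
toy_y = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
toy_p = [0.80, 0.55, 0.30, 0.60, 0.35, 0.25, 0.20, 0.15, 0.10, 0.05]
thresholds = [0.1, 0.3, 0.5, 0.7]

# Misses are expensive: a low threshold wins.
costs = {t: cost_per_patient(toy_y, toy_p, t, fn_cost=10.0, fp_cost=1.0) for t in thresholds}
best_t = min(costs, key=costs.get)
print(costs, "best:", best_t)  # best: 0.3

# False alarms are expensive: the best threshold moves up.
alt_costs = {t: cost_per_patient(toy_y, toy_p, t, fn_cost=1.0, fp_cost=5.0) for t in thresholds}
alt_best_t = min(alt_costs, key=alt_costs.get)
print(alt_costs, "best:", alt_best_t)  # best: 0.7
```

As a sanity check on any grid-search result: for a well-calibrated model with no cost attached to correct decisions, decision theory suggests a cutoff near `fp_cost / (fp_cost + fn_cost)`, which is about 0.09 for costs of 1 and 10. The empirical optimum on a finite, imperfectly calibrated test set need not match that value.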

In [None]:
def expected_cost(y, proba, threshold, fn_cost=10.0, fp_cost=1.0):
    pred = (proba >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y, pred, labels=[0, 1]).ravel()
    total_cost = (fn * fn_cost) + (fp * fp_cost)
    cost_per_patient = total_cost / len(y)
    return cost_per_patient, {'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp}


def find_best_threshold(y, proba, fn_cost=10.0, fp_cost=1.0):
    grid = np.round(np.arange(0.05, 0.96, 0.01), 2)
    rows = []
    for t in grid:
        c, cm = expected_cost(y, proba, t, fn_cost=fn_cost, fp_cost=fp_cost)
        m = metric_bundle(y, proba, t)
        rows.append({
            'threshold': t,
            'cost_per_patient': c,
            'sensitivity_recall': m['sensitivity_recall'],
            'specificity': m['specificity'],
            'precision': m['precision'],
            'flag_rate': m['flag_rate'],
            'tp': cm['tp'],
            'fp': cm['fp'],
            'fn': cm['fn'],
            'tn': cm['tn'],
        })
    df_grid = pd.DataFrame(rows)
    best = df_grid.loc[df_grid['cost_per_patient'].idxmin()]
    return df_grid, best


def utility_explorer(model='Logistic Regression', fn_cost=10.0, fp_cost=1.0):
    proba = model_probs[model]
    grid, best = find_best_threshold(y_true, proba, fn_cost=fn_cost, fp_cost=fp_cost)

    print(f'Model: {model}')
    print(f'False negative cost: {fn_cost:.1f} | False positive cost: {fp_cost:.1f}')
    print(f"Best threshold by expected cost: {best['threshold']:.2f}")
    print(f"Expected cost per patient: {best['cost_per_patient']:.3f}")
    print(
        f"Sensitivity: {best['sensitivity_recall']:.3f} | Specificity: {best['specificity']:.3f} | "
        f"Precision: {best['precision']:.3f}"
    )

    fig, ax = plt.subplots(figsize=(7, 3.5))
    ax.plot(grid['threshold'], grid['cost_per_patient'], color='#4c78a8', linewidth=2)
    ax.axvline(best['threshold'], color='#e45756', linestyle='--', label=f"Best={best['threshold']:.2f}")
    ax.set_title('Expected Cost per Patient Across Thresholds')
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Cost per patient')
    ax.legend()
    plt.tight_layout()
    plt.show()


widgets.interact(
    utility_explorer,
    model=widgets.Dropdown(options=list(model_probs.keys()), value='Logistic Regression', description='Model'),
    fn_cost=widgets.FloatSlider(value=10.0, min=1.0, max=20.0, step=1.0, description='FN cost', continuous_update=False),
    fp_cost=widgets.FloatSlider(value=1.0, min=0.5, max=10.0, step=0.5, description='FP cost', continuous_update=False),
)


## Section 6: Scenario Table for Team Discussion
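This cell builds a policy summary for three illustrative clinical scenarios. Each scenario is defined by a pair of false-negative and false-positive costs, and the table reports the cost-minimizing threshold and resulting metrics for each model.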

In [None]:
scenarios = [
    {'scenario': 'Safety-first clinic', 'fn_cost': 12.0, 'fp_cost': 1.0},
    {'scenario': 'Balanced operations', 'fn_cost': 6.0, 'fp_cost': 2.0},
    {'scenario': 'Resource-limited service', 'fn_cost': 3.0, 'fp_cost': 3.0},
]

summary_rows = []
for s in scenarios:
    for model_name, proba in model_probs.items():
        grid, best = find_best_threshold(y_true, proba, fn_cost=s['fn_cost'], fp_cost=s['fp_cost'])
        summary_rows.append({
            'scenario': s['scenario'],
            'model': model_name,
            'fn_cost': s['fn_cost'],
            'fp_cost': s['fp_cost'],
            'recommended_threshold': float(best['threshold']),
            'expected_cost_per_patient': float(best['cost_per_patient']),
            'sensitivity_recall': float(best['sensitivity_recall']),
            'specificity': float(best['specificity']),
            'precision': float(best['precision']),
            'flag_rate': float(best['flag_rate']),
        })

policy_summary = pd.DataFrame(summary_rows)
policy_summary = policy_summary.sort_values(['scenario', 'expected_cost_per_patient']).reset_index(drop=True)
display(policy_summary.round(3))


## Section 7: Save Policy Summary for Module 5
This file can be reused in later modules on workflow design, governance, and deployment decisions.

In [None]:
out_path = PRED_PATH.parent / 'module_04_threshold_policy_summary.csv'
policy_summary.to_csv(out_path, index=False)
print(f'Saved policy summary to {out_path}')


## Wrap-up: Key Takeaways
- Good ranking performance does not automatically define a good clinical policy.
- Threshold and capacity settings directly change workload and missed cases.
- Costs and risk tolerance should be explicit and discussable with stakeholders.
- A transparent policy table is often more useful than one headline metric.