# Module 5C: Metrics, Thresholds, and Hyperparameter Decisions
## Turning Model Scores Into Clinical Action

**Goal:** Use predictions from Module 5B to choose thresholds and hyperparameters based on clinical trade-offs rather than a single accuracy number.

### Learning objectives
1. Compare ROC and PR curves for chest X-ray classification.
2. Understand when PR-AUC is more informative than ROC-AUC.
3. Select thresholds using sensitivity/specificity/workload constraints.
4. Compare hyperparameter runs by validation outcomes.
5. Export a simple model policy card for team discussion.

## Section 0: Load Outputs From Module 5A and 5B
Required for full analysis:
- `data/module_05b_test_predictions.csv`
- `data/module_05b_training_history.csv`

Optional baseline comparison:
- `data/module_05a_baseline_test_predictions.csv`

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from IPython.display import display
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)

try:
    import ipywidgets as widgets
except ImportError as exc:
    raise ImportError('ipywidgets is required for Module 5C interactive controls.') from exc


def resolve_data_file(name):
    """Locate a module data file by name.

    Looks for ``name`` first under ``../data`` and then under ``data``, both
    relative to the current working directory, so the notebook can be run
    from either location.

    Args:
        name: File name to look for (e.g. ``'module_05b_test_predictions.csv'``).

    Returns:
        The first existing candidate ``Path``, or ``None`` if neither exists.
    """
    search_roots = (Path('../data'), Path('data'))
    candidates = (root / name for root in search_roots)
    return next((path for path in candidates if path.exists()), None)


cnn_test_path = resolve_data_file('module_05b_test_predictions.csv')
cnn_hist_path = resolve_data_file('module_05b_training_history.csv')
baseline_path = resolve_data_file('module_05a_baseline_test_predictions.csv')

# Required input: per-image CNN test predictions. When the file is missing, an
# empty frame with the expected columns lets later cells detect it via `.empty`.
if cnn_test_path is not None:
    cnn_test = pd.read_csv(cnn_test_path)
    print(f'Loaded CNN test predictions: {len(cnn_test)} rows')
else:
    print('Missing module_05b_test_predictions.csv. Run Module 5B first.')
    cnn_test = pd.DataFrame(columns=['label', 'probability'])

# Per-epoch training history used by the hyperparameter comparison.
if cnn_hist_path is not None:
    cnn_history = pd.read_csv(cnn_hist_path)
    print(f'Loaded training history rows: {len(cnn_history)}')
else:
    cnn_history = pd.DataFrame()
    print('No training history found yet. Run at least one Module 5B experiment.')

# Optional Module 5A baseline predictions for curve comparison.
if baseline_path is not None:
    baseline_test = pd.read_csv(baseline_path)
    print(f'Loaded baseline predictions: {len(baseline_test)} rows')
else:
    baseline_test = pd.DataFrame()
    print('Baseline file not found (optional).')


In [None]:
# Sanity-check the loaded predictions: show the first rows, then report the
# positive-class share (the mean of a 0/1 label column).
if not cnn_test.empty:
    display(cnn_test.head(5))
    prevalence = cnn_test['label'].mean()
    print(f'Positive prevalence in test set: {prevalence:.3f}')


## Section 1: Metric Functions
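We compute threshold-dependent metrics (accuracy, precision, sensitivity, specificity, F1, flag rate) and threshold-independent metrics (ROC-AUC and PR-AUC/average precision) explicitly, so each number can be traced back to the confusion counts or to the ranked scores.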

In [None]:
def metric_bundle(y_true, y_prob, threshold=0.5):
    """Compute binary classification metrics at one decision threshold.

    Args:
        y_true: Array-like of 0/1 ground-truth labels.
        y_prob: Array-like of predicted positive-class scores, same length
            as ``y_true``.
        threshold: Scores ``>= threshold`` are predicted positive (flagged).

    Returns:
        dict with ``accuracy``, ``precision``, ``recall_sensitivity``,
        ``specificity``, ``f1``, the confusion counts ``tp``/``tn``/``fp``/``fn``,
        ``flag_rate`` (fraction of images predicted positive), ``roc_auc`` and
        ``pr_auc`` (average precision). Precision, recall, specificity and F1
        are reported as 0.0 when their denominator is zero. ``roc_auc`` and
        ``pr_auc`` do not depend on ``threshold``; they are NaN when
        ``y_true`` contains a single class, where ROC-AUC is undefined.
    """
    # Accept plain lists as well as arrays/Series: `list >= float` would raise.
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob, dtype=float)

    y_pred = (y_prob >= threshold).astype(int)
    # labels=[0, 1] keeps a 2x2 matrix even if one class is never predicted.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    # roc_auc_score raises ValueError when only one class is present (e.g. a
    # tiny or filtered split). Report NaN instead so the threshold sweep and
    # the interactive explorers that call this function keep working.
    if np.unique(y_true).size > 1:
        roc_auc = roc_auc_score(y_true, y_prob)
        pr_auc = average_precision_score(y_true, y_prob)
    else:
        roc_auc = pr_auc = float('nan')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall_sensitivity': recall,
        'specificity': specificity,
        'f1': f1,
        'tp': tp,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'flag_rate': y_pred.mean(),
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
    }


def threshold_grid(y_true, y_prob, thresholds=None):
    """Evaluate ``metric_bundle`` across a sweep of decision thresholds.

    Args:
        y_true: 0/1 ground-truth labels.
        y_prob: Predicted positive-class scores.
        thresholds: Optional iterable of thresholds to evaluate. Defaults to
            0.05, 0.06, ..., 0.95 (91 evenly spaced values).

    Returns:
        DataFrame with one row per threshold: a ``threshold`` column followed
        by every metric returned by ``metric_bundle``.
    """
    if thresholds is None:
        # linspace with an explicit count avoids np.arange's float-step
        # endpoint ambiguity; rounding keeps clean labels such as 0.3 rather
        # than 0.30000000000000004.
        thresholds = np.round(np.linspace(0.05, 0.95, 91), 2)
    rows = [
        {'threshold': t, **metric_bundle(y_true, y_prob, threshold=t)}
        for t in thresholds
    ]
    return pd.DataFrame(rows)


## Section 2: PR vs ROC Comparison
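Use ROC-AUC to summarize how well the model ranks positives above negatives overall, and PR-AUC (average precision) to focus on how reliably it retrieves the positive class.
In screening settings with low prevalence, the PR curve often better reflects operational reality. In ROC space, false positives are diluted by the many negatives, whereas precision directly shows what fraction of flagged images are truly positive. A no-skill model scores ROC-AUC ≈ 0.5 but PR-AUC ≈ prevalence (the dashed horizontal line in the PR plot).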

In [None]:
if cnn_test.empty:
    print('No CNN predictions available for curve plotting.')
else:
    y = cnn_test['label'].astype(int).values
    p = cnn_test['probability'].values

    # Overlay the baseline only when it scores the same images in the same
    # order. Equal length alone is not sufficient: if the baseline file has
    # its own 'label' column, those labels must match the CNN labels row for
    # row, otherwise its curves would be computed against the wrong truth.
    baseline_aligned = (
        not baseline_test.empty
        and len(baseline_test) == len(cnn_test)
        and 'probability' in baseline_test.columns
        and (
            'label' not in baseline_test.columns
            or np.array_equal(baseline_test['label'].values, y)
        )
    )
    if not baseline_test.empty and not baseline_aligned:
        print('Baseline predictions are not row-aligned with the CNN test set; skipping baseline curves.')
    if baseline_aligned:
        pb = baseline_test['probability'].values

    fig, axes = plt.subplots(1, 2, figsize=(12, 4.2))

    # Left panel: ROC curve with the chance diagonal for reference.
    fpr, tpr, _ = roc_curve(y, p)
    axes[0].plot(fpr, tpr, color='#4c78a8', linewidth=2, label=f'CNN AUC={roc_auc_score(y, p):.2f}')
    axes[0].plot([0, 1], [0, 1], '--', color='gray', linewidth=1)

    if baseline_aligned:
        fpr_b, tpr_b, _ = roc_curve(y, pb)
        axes[0].plot(fpr_b, tpr_b, color='#e45756', linewidth=2, label=f'Baseline AUC={roc_auc_score(y, pb):.2f}')

    axes[0].set_title('ROC Curve')
    axes[0].set_xlabel('False Positive Rate')
    axes[0].set_ylabel('True Positive Rate')
    axes[0].legend()

    # Right panel: precision-recall curve; the no-skill level is the prevalence.
    precision, recall, _ = precision_recall_curve(y, p)
    axes[1].plot(recall, precision, color='#4c78a8', linewidth=2, label=f'CNN AP={average_precision_score(y, p):.2f}')

    if baseline_aligned:
        precision_b, recall_b, _ = precision_recall_curve(y, pb)
        axes[1].plot(recall_b, precision_b, color='#e45756', linewidth=2, label=f'Baseline AP={average_precision_score(y, pb):.2f}')

    axes[1].hlines(y.mean(), 0, 1, color='gray', linestyle='--', linewidth=1, label=f'Prevalence={y.mean():.2f}')
    axes[1].set_title('Precision-Recall Curve')
    axes[1].set_xlabel('Recall')
    axes[1].set_ylabel('Precision')
    axes[1].legend()

    plt.tight_layout()
    plt.show()


## Section 3: Interactive Threshold Selection
Pick thresholds by clinical objective: sensitivity-first, precision-first, or balanced operations.

In [None]:
def threshold_explorer(threshold=0.50):
    """Report CNN operating metrics at ``threshold`` and plot confusion counts.

    Reads the module-level ``cnn_test`` predictions. Prints a notice and
    returns early when no predictions are loaded.

    Args:
        threshold: Scores at or above this value are flagged positive.
    """
    if cnn_test.empty:
        print('No CNN predictions available.')
        return

    labels = cnn_test['label'].astype(int).values
    scores = cnn_test['probability'].values
    stats = metric_bundle(labels, scores, threshold=threshold)

    print(f'Threshold: {threshold:.2f}')
    print(f"Sensitivity: {stats['recall_sensitivity']:.3f} | Specificity: {stats['specificity']:.3f} | Precision: {stats['precision']:.3f}")
    print(f"PR-AUC: {stats['pr_auc']:.3f} | ROC-AUC: {stats['roc_auc']:.3f}")
    # flag_rate is a fraction of images; x100 expresses it per 100 images.
    print(f"Flagged workload: {stats['flag_rate'] * 100:.1f} per 100 images")

    counts = pd.Series({key.upper(): stats[key] for key in ('tp', 'tn', 'fp', 'fn')})
    palette = ['#54a24b', '#4c78a8', '#f58518', '#e45756']
    fig, ax = plt.subplots(figsize=(6.2, 3.2))
    ax.bar(counts.index, counts.values, color=palette)
    ax.set_title('CNN Confusion Counts')
    ax.set_ylabel('Images')
    plt.tight_layout()
    plt.show()


# `interact` displays the slider widget itself and returns the wrapped
# function; binding the result to a name keeps Jupyter from also echoing the
# function's repr as the cell output.
threshold_explorer_ui = widgets.interact(
    threshold_explorer,
    threshold=widgets.FloatSlider(value=0.50, min=0.05, max=0.95, step=0.05, description='Threshold', continuous_update=False),
)


## Section 4: Cost-Aware Threshold Recommendation
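If missing pneumonia is more costly than a false alarm, the cost-minimizing threshold moves lower: flagging more images trades extra false positives for fewer missed cases. For each threshold on the grid, expected cost per image is $\frac{\mathrm{FN}\cdot c_{FN} + \mathrm{FP}\cdot c_{FP}}{N}$, and the threshold with the lowest value is recommended. Note: here the threshold is tuned on the test set for illustration; in practice, choose it on a validation set and confirm it on held-out test data.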

In [None]:
def choose_threshold_by_cost(y_true, y_prob, fn_cost=10.0, fp_cost=1.0):
    """Pick the threshold that minimizes expected misclassification cost.

    Args:
        y_true: 0/1 ground-truth labels.
        y_prob: Predicted positive-class scores.
        fn_cost: Relative cost of one false negative (missed positive).
        fp_cost: Relative cost of one false positive (false alarm).

    Returns:
        ``(grid, best)`` where ``grid`` is the ``threshold_grid`` table with an
        added ``expected_cost_per_image`` column, and ``best`` is its
        lowest-cost row (the first such row if several tie).
    """
    table = threshold_grid(y_true, y_prob)
    total_cost = table['fn'] * fn_cost + table['fp'] * fp_cost
    table['expected_cost_per_image'] = total_cost / len(y_true)
    cheapest_idx = table['expected_cost_per_image'].idxmin()
    return table, table.loc[cheapest_idx]


def cost_explorer(fn_cost=10.0, fp_cost=1.0):
    """Show the cost-minimizing CNN threshold for the given FN/FP costs.

    Uses the module-level ``cnn_test`` predictions, prints the recommended
    operating point, and plots expected cost per image over the threshold
    grid. Prints a notice and returns early when no predictions are loaded.

    Args:
        fn_cost: Relative cost of one false negative.
        fp_cost: Relative cost of one false positive.
    """
    if cnn_test.empty:
        print('No CNN predictions available.')
        return

    labels = cnn_test['label'].astype(int).values
    scores = cnn_test['probability'].values
    sweep, chosen = choose_threshold_by_cost(labels, scores, fn_cost=fn_cost, fp_cost=fp_cost)
    best_t = chosen['threshold']

    print(f'FN cost={fn_cost:.1f}, FP cost={fp_cost:.1f}')
    print(f"Recommended threshold: {best_t:.2f}")
    print(f"Sensitivity={chosen['recall_sensitivity']:.3f} | Specificity={chosen['specificity']:.3f} | Precision={chosen['precision']:.3f}")
    print(f"Expected cost per image: {chosen['expected_cost_per_image']:.3f}")

    fig, ax = plt.subplots(figsize=(7, 3.5))
    ax.plot(sweep['threshold'], sweep['expected_cost_per_image'], color='#4c78a8', linewidth=2)
    ax.axvline(best_t, color='#e45756', linestyle='--', label=f"Best={best_t:.2f}")
    ax.set(title='Expected Cost vs Threshold', xlabel='Threshold', ylabel='Cost per image')
    ax.legend()
    plt.tight_layout()
    plt.show()


# `interact` displays the slider widgets itself and returns the wrapped
# function; binding the result to a name keeps Jupyter from also echoing the
# function's repr as the cell output.
cost_explorer_ui = widgets.interact(
    cost_explorer,
    fn_cost=widgets.FloatSlider(value=10.0, min=1.0, max=20.0, step=1.0, description='FN cost', continuous_update=False),
    fp_cost=widgets.FloatSlider(value=1.0, min=0.5, max=10.0, step=0.5, description='FP cost', continuous_update=False),
)


## Section 5: Hyperparameter Run Comparison
Compare experiments from Module 5B to choose stable settings before final training.

In [None]:
if cnn_history.empty:
    print('No training history to compare yet.')
else:
    # One summary row per run_id. After sorting by epoch, 'last' takes the
    # latest-epoch (non-null) value; the validation ranking metrics take the
    # best (max) value reached at any epoch.
    last_value_cols = ['batch_size', 'img_size', 'learning_rate', 'weight_decay', 'dropout', 'epochs', 'elapsed_sec']
    summary_spec = {col: (col, 'last') for col in last_value_cols}
    summary_spec.update(
        best_val_roc_auc=('val_roc_auc', 'max'),
        best_val_pr_auc=('val_pr_auc', 'max'),
        final_val_loss=('val_loss', 'last'),
    )

    history_by_epoch = cnn_history.sort_values('epoch')
    run_summary = history_by_epoch.groupby('run_id', as_index=False).agg(**summary_spec)
    # Rank runs by their best validation PR-AUC, highest first.
    run_summary = run_summary.sort_values('best_val_pr_auc', ascending=False).reset_index(drop=True)

    display(run_summary.round(4))

    # Marker area follows elapsed_sec (clipped to keep markers readable);
    # colour encodes the run's dropout.
    marker_area = np.clip(run_summary['elapsed_sec'], 20, 400)
    fig, ax = plt.subplots(figsize=(7, 4))
    sc = ax.scatter(
        run_summary['best_val_roc_auc'],
        run_summary['best_val_pr_auc'],
        s=marker_area,
        c=run_summary['dropout'],
        cmap='viridis',
        alpha=0.8,
    )
    ax.set(
        xlabel='Best Validation ROC-AUC',
        ylabel='Best Validation PR-AUC',
        title='Hyperparameter Run Comparison',
    )
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label('Dropout')
    plt.tight_layout()
    plt.show()


## Section 6: Export Model Policy Card
This creates a compact table for team review and deployment planning.

In [None]:
if cnn_test.empty:
    print('Cannot export policy card without CNN predictions.')
else:
    y = cnn_test['label'].astype(int).values
    p = cnn_test['probability'].values

    # Each scenario expresses a risk posture as the relative cost of a missed
    # positive (FN) versus a false alarm (FP).
    scenarios = [
        {'scenario': 'Safety-first screening', 'fn_cost': 12.0, 'fp_cost': 1.0},
        {'scenario': 'Balanced workflow', 'fn_cost': 6.0, 'fp_cost': 2.0},
        {'scenario': 'Resource-limited follow-up', 'fn_cost': 3.0, 'fp_cost': 3.0},
    ]

    rows = []
    for s in scenarios:
        # Only the lowest-cost row is needed here; the full grid is discarded.
        best = choose_threshold_by_cost(y, p, fn_cost=s['fn_cost'], fp_cost=s['fp_cost'])[1]
        rows.append({
            'scenario': s['scenario'],
            'fn_cost': s['fn_cost'],
            'fp_cost': s['fp_cost'],
            'recommended_threshold': float(best['threshold']),
            'sensitivity_recall': float(best['recall_sensitivity']),
            'specificity': float(best['specificity']),
            'precision': float(best['precision']),
            'flag_rate': float(best['flag_rate']),
            'pr_auc': float(best['pr_auc']),
            'roc_auc': float(best['roc_auc']),
        })

    policy_card = pd.DataFrame(rows)
    display(policy_card.round(3))

    out_path = (cnn_test_path.parent if cnn_test_path is not None else Path('../data')) / 'module_05c_model_policy_card.csv'
    # The fallback directory may not exist yet; create it so to_csv cannot fail on it.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    policy_card.to_csv(out_path, index=False)
    print(f'Saved policy card to {out_path}')


## Wrap-up
- PR and ROC should be interpreted together, not as substitutes.
- Thresholds are operational choices tied to risk tolerance and staffing.
- Hyperparameter selection should be evidence-based, not trial-and-error guesswork.
- With a policy card, teams can discuss deployment using transparent assumptions.