# Module 5A: Chest X-ray Problem Framing and Method Map
## From Clinical Need to Candidate AI Approaches

**Goal:** Start with a real chest X-ray problem, compare solution families, and build a non-deep-learning baseline before training a CNN.

### Why this chapter first?
Before writing any deep-learning code, clinicians should understand **what problem we are solving**, **which methods are available**, and **what baseline performance looks like**.

### Learning objectives
1. Define a clinically meaningful binary task in chest radiology.
2. Compare rule-based, traditional ML, and deep learning options.
3. Build a handcrafted-feature baseline on a real chest X-ray dataset.
4. Understand why PR-AUC and ROC-AUC both matter.
5. Prepare the handoff to deep-learning training in Module 5B.

## Section 0: Clinical Problem
**Problem statement:** At triage, can we flag chest X-rays likely to show pneumonia so radiologists can prioritize review?

This is a **binary classification** task:
- Class 0: `NORMAL`
- Class 1: `PNEUMONIA`

Clinical framing:
- False negatives can delay treatment.
- False positives increase workload.
- The right threshold depends on the service context.

## Helper Functions
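Run this cell once at the start. On Google Colab it clones the course repository into `/content` (if it is not already there) and switches into the `chapters` folder. On a local Jupyter runtime it only prints the current working directory.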

In [None]:
import os
import sys
import subprocess
from pathlib import Path

def setup_repo_for_colab(
    repo_url='https://github.com/aaekay/Medical-AI-101.git',
    repo_dir='/content/Medical-AI-101',
    notebook_dir='chapters',
):
    """Prepare the working directory for this notebook.

    On Google Colab (detected by ``google.colab`` already being imported),
    clone ``repo_url`` into ``repo_dir`` unless that path already exists, then
    ``chdir`` into ``repo_dir/notebook_dir`` so relative data paths resolve.
    On any other runtime, only report the current working directory.

    Returns:
        None in both cases.
    """
    on_colab = 'google.colab' in sys.modules
    if on_colab:
        clone_root = Path(repo_dir)
        if not clone_root.exists():
            print('Cloning Medical-AI-101 into /content ...')
            subprocess.check_call(['git', 'clone', repo_url, str(clone_root)])
        os.chdir(clone_root / notebook_dir)
        print(f'Colab ready. Working directory: {Path.cwd()}')
    else:
        print(f'Local runtime detected. Working directory: {Path.cwd()}')


setup_repo_for_colab()


In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

from IPython.display import display
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)

try:
    import ipywidgets as widgets
except ImportError as exc:
    raise ImportError('ipywidgets is required for interactive demos in Module 5A.') from exc


def resolve_cxr_root():
    """Return the first existing ``chest_xray_small`` data folder, or None.

    Looks under ``../data`` first and then under ``data``. Both paths are
    relative to the current working directory, so either the ``chapters/``
    folder or the repository root works as the launch location.
    """
    search_order = (
        Path('../data/chest_xray_small'),
        Path('data/chest_xray_small'),
    )
    return next((folder for folder in search_order if folder.exists()), None)


def build_manifest(root):
    """Index the image files of a ``split/CLASS/`` chest X-ray folder tree.

    Expects ``root/{train,val,test}/{NORMAL,PNEUMONIA}/`` sub-folders and
    skips any that are missing. Within each folder, regular files with a
    ``.jpeg``, ``.jpg`` or ``.png`` suffix (case-insensitive) are collected
    in sorted order.

    Args:
        root: ``pathlib.Path`` of the dataset root.

    Returns:
        DataFrame with columns ``path`` (str), ``split``, ``label_name`` and
        ``label`` (0 = NORMAL, 1 = PNEUMONIA). The columns are present even
        when no image is found, so an empty manifest has the same schema as
        one built from data.
    """
    columns = ['path', 'split', 'label_name', 'label']
    image_suffixes = {'.jpeg', '.jpg', '.png'}
    rows = []
    for split in ('train', 'val', 'test'):
        for label_name, label in (('NORMAL', 0), ('PNEUMONIA', 1)):
            folder = root / split / label_name
            if not folder.is_dir():
                continue
            for p in sorted(folder.glob('*')):
                # is_file() skips sub-directories that happen to end in an image suffix.
                if p.is_file() and p.suffix.lower() in image_suffixes:
                    rows.append({'path': str(p), 'split': split, 'label_name': label_name, 'label': label})
    return pd.DataFrame(rows, columns=columns)


CXR_ROOT = resolve_cxr_root()
if CXR_ROOT is not None:
    manifest = build_manifest(CXR_ROOT)
    print(f'Found chest X-ray root: {CXR_ROOT}')
    print(f'Manifest rows: {len(manifest)}')
else:
    # No data yet: print setup instructions, then fall back to an empty manifest
    # with the expected columns so later cells can branch on `manifest.empty`.
    for hint in (
        'Dataset folder not found. Set it up first, then rerun this notebook.',
        'One-time command (from repo root): python scripts/setup_chest_xray_from_gdrive.py',
        'See: data/chest_xray_small/README.md',
    ):
        print(hint)
    manifest = pd.DataFrame(columns=['path', 'split', 'label_name', 'label'])


In [None]:
if not manifest.empty:
    display(manifest.head(10))
    # One row per (split, class) pair with its image count.
    summary = manifest.groupby(['split', 'label_name']).size().reset_index(name='n_images')
    display(summary)

    # Wide table (split x class) for a grouped bar chart; missing pairs become 0.
    pivot = summary.pivot(index='split', columns='label_name', values='n_images').fillna(0)
    ax = pivot.plot(kind='bar', figsize=(7, 3.6), color=['#4c78a8', '#e45756'])
    ax.set_title('Class Balance by Split')
    ax.set_ylabel('Images')
    ax.tick_params(axis='x', rotation=0)
    plt.tight_layout()
    plt.show()
else:
    print('No images discovered yet. This notebook will still show the method map and setup logic.')


## Section 1: Method Map (Top-Down)
Different methods answer the same clinical question with different assumptions and trade-offs.

In [None]:
# One row per solution family; columns compare how each approaches the task.
method_map = pd.DataFrame(
    [
        [
            'Rule-based heuristics',
            'Hand-written imaging rules (e.g., brightness zones, edge patterns)',
            'Transparent and explainable',
            'Brittle; poor generalization to varied scans',
            'Conceptual baseline',
        ],
        [
            'Traditional ML on handcrafted features',
            'Engineer features from images, then train logistic regression / tree',
            'Fast on CPU; useful baseline',
            'Feature engineering limits performance ceiling',
            'Practical baseline in this notebook',
        ],
        [
            'Deep learning (CNNs)',
            'Learn image features directly from pixels',
            'High performance on visual tasks',
            'Needs more data, compute, and careful evaluation',
            'Main model in Module 5B',
        ],
    ],
    columns=['Method family', 'How it works', 'Strength', 'Limitation', 'Role in course'],
)

display(method_map)


## Section 2: Visual Sanity Check
Inspect sample images from each class before modeling. This catches many pipeline mistakes early.

In [None]:
def show_class_examples(manifest_df, n_per_class=4, split='train', random_state=42):
    """Plot example images as a grid with one row per class and ``n_per_class`` columns.

    Args:
        manifest_df: manifest DataFrame with ``path``, ``split`` and ``label`` columns.
        n_per_class: number of images (grid columns) drawn per class.
        split: which split to sample from.
        random_state: seed for ``DataFrame.sample`` so reruns show the same images.

    If the split has no rows, prints a message and returns without plotting.
    A class with fewer than ``n_per_class`` images leaves the remaining cells blank.
    """
    subset = manifest_df[manifest_df['split'] == split]
    if subset.empty:
        print(f'No rows found for split={split}.')
        return

    # squeeze=False keeps `axes` 2-D even when n_per_class == 1; otherwise
    # matplotlib returns a 1-D array and `axes[r, c]` raises IndexError.
    _, axes = plt.subplots(2, n_per_class, figsize=(3 * n_per_class, 6), squeeze=False)
    labels = [('NORMAL', 0), ('PNEUMONIA', 1)]

    for r, (label_name, label) in enumerate(labels):
        for ax in axes[r]:
            ax.axis('off')
        group = subset[subset['label'] == label]
        if group.empty:
            continue
        sample = group.sample(n=min(n_per_class, len(group)), random_state=random_state)
        for ax, path in zip(axes[r], sample['path']):
            # The context manager closes the file handle once pixels are converted.
            with Image.open(path) as src:
                img = src.convert('L')
            ax.imshow(img, cmap='gray')
            ax.set_title(label_name)

    plt.tight_layout()
    plt.show()


if manifest.empty:
    # No dataset: show a bundled placeholder instead. Look under '../assets'
    # first, then under 'assets' (both relative to the working directory).
    placeholder = Path('../assets/chest_xray_placeholder.png')
    if not placeholder.exists():
        placeholder = Path('assets/chest_xray_placeholder.png')
    if placeholder.exists():
        # The context manager closes the file handle after conversion.
        with Image.open(placeholder) as src:
            img = src.convert('L')
        plt.figure(figsize=(4, 4))
        plt.imshow(img, cmap='gray')
        plt.title('Placeholder image (dataset not loaded)')
        plt.axis('off')
        plt.show()
    else:
        print('No dataset and no placeholder image found.')
else:
    show_class_examples(manifest, n_per_class=4, split='train')


## Section 3: Handcrafted Feature Baseline
Before CNNs, we build a lightweight baseline to anchor expectations.

In [None]:
def extract_features(image_path, size=128):
    img = Image.open(image_path).convert('L').resize((size, size))
    arr = np.asarray(img, dtype=np.float32) / 255.0

    # Simple handcrafted radiographic proxies for teaching.
    mean_intensity = arr.mean()
    std_intensity = arr.std()
    p90 = np.percentile(arr, 90)
    center = arr[size // 4 : 3 * size // 4, size // 4 : 3 * size // 4].mean()
    periphery = np.concatenate([
        arr[: size // 6, :].ravel(),
        arr[-size // 6 :, :].ravel(),
        arr[:, : size // 6].ravel(),
        arr[:, -size // 6 :].ravel(),
    ]).mean()
    grad_x = np.abs(np.diff(arr, axis=1)).mean()
    grad_y = np.abs(np.diff(arr, axis=0)).mean()
    symmetry_gap = np.abs(arr - np.fliplr(arr)).mean()

    return {
        'mean_intensity': mean_intensity,
        'std_intensity': std_intensity,
        'p90': p90,
        'center_intensity': center,
        'periphery_intensity': periphery,
        'edge_x': grad_x,
        'edge_y': grad_y,
        'symmetry_gap': symmetry_gap,
    }


def make_feature_table(manifest_df):
    """Run ``extract_features`` on every manifest row and collect the results.

    Each output row holds the feature columns, followed by ``split``, an
    integer ``label`` and the image ``path``. An empty manifest yields an
    empty DataFrame.
    """
    return pd.DataFrame([
        {**extract_features(rec.path), 'split': rec.split, 'label': int(rec.label), 'path': rec.path}
        for rec in manifest_df.itertuples(index=False)
    ])


if manifest.empty:
    feature_df = pd.DataFrame()
    print('Feature extraction skipped because dataset is not loaded.')
else:
    feature_df = make_feature_table(manifest)
    print(f'Feature rows: {len(feature_df)}')
    display(feature_df.head())


In [None]:
if feature_df.empty:
    baseline_test = pd.DataFrame()
    print('Baseline training skipped: no feature table available.')
else:
    # Fit on train + val together (no hyperparameters are tuned here) and keep
    # the test split held out for the reported metrics.
    train_df = feature_df[feature_df['split'].isin(['train', 'val'])].copy()
    test_df = feature_df[feature_df['split'] == 'test'].copy()

    feature_cols = [
        'mean_intensity',
        'std_intensity',
        'p90',
        'center_intensity',
        'periphery_intensity',
        'edge_x',
        'edge_y',
        'symmetry_gap',
    ]

    # Several steps fail on incomplete splits:
    # - LogisticRegression needs both classes in the training data.
    # - ROC-AUC is undefined when the test labels contain a single class
    #   (sklearn raises ValueError).
    # - predict_proba fails on an empty test split.
    # Skip with a message so the later cells take their "no predictions" branch.
    if train_df['label'].nunique() < 2 or test_df['label'].nunique() < 2:
        baseline_test = pd.DataFrame()
        print('Baseline training skipped: train/val and test splits must each contain both NORMAL and PNEUMONIA images.')
    else:
        X_train = train_df[feature_cols].values
        y_train = train_df['label'].values
        X_test = test_df[feature_cols].values
        y_test = test_df['label'].values

        # class_weight='balanced' reweights samples inversely to class frequency.
        baseline_model = LogisticRegression(max_iter=1000, class_weight='balanced', random_state=42)
        baseline_model.fit(X_train, y_train)

        test_proba = baseline_model.predict_proba(X_test)[:, 1]
        roc_auc = roc_auc_score(y_test, test_proba)
        pr_auc = average_precision_score(y_test, test_proba)

        print(f'Baseline ROC-AUC: {roc_auc:.3f}')
        print(f'Baseline PR-AUC:  {pr_auc:.3f}')

        baseline_test = test_df[['path', 'label']].copy()
        baseline_test['probability'] = test_proba

        out_path = (CXR_ROOT.parent if CXR_ROOT is not None else Path('../data')) / 'module_05a_baseline_test_predictions.csv'
        baseline_test.to_csv(out_path, index=False)
        print(f'Saved baseline predictions to {out_path}')


## Section 4: Why PR and ROC?
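- **ROC-AUC**: Measures ranking quality across all thresholds (true-positive rate vs false-positive rate). A value of 0.5 corresponds to random ranking, which is the dashed diagonal in the ROC plot.
- **PR-AUC** (reported here as average precision, AP): Focuses on retrieving the positive class (precision vs recall). It is often more informative when positives are less frequent. Its chance level equals the positive prevalence, which is the dashed horizontal line in the PR plot.
- Use both, then pick a threshold based on workflow constraints (see the interactive demo below).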

In [None]:
if not baseline_test.empty:
    # Labels and scores as plain NumPy arrays for sklearn.
    y_true = baseline_test['label'].to_numpy()
    proba = baseline_test['probability'].to_numpy()

    fpr, tpr, _ = roc_curve(y_true, proba)
    precision, recall, _ = precision_recall_curve(y_true, proba)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4.2))
    roc_ax, pr_ax = axes

    # Left panel: ROC curve against the chance diagonal.
    roc_ax.plot(fpr, tpr, color='#4c78a8', linewidth=2)
    roc_ax.plot([0, 1], [0, 1], '--', color='gray', linewidth=1)
    roc_ax.set_title(f'ROC Curve (AUC={roc_auc_score(y_true, proba):.2f})')
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')

    # Right panel: PR curve; the dashed line is the positive prevalence.
    pr_ax.plot(recall, precision, color='#e45756', linewidth=2)
    pr_ax.hlines(y_true.mean(), 0, 1, linestyles='--', color='gray', linewidth=1)
    pr_ax.set_title(f'PR Curve (AP={average_precision_score(y_true, proba):.2f})')
    pr_ax.set_xlabel('Recall')
    pr_ax.set_ylabel('Precision')

    plt.tight_layout()
    plt.show()
else:
    print('No baseline predictions available for curve plotting.')


In [None]:
def threshold_demo(threshold=0.50):
    """Report baseline metrics and confusion counts at one decision threshold.

    Uses the module-level ``baseline_test`` predictions (``label`` and
    ``probability`` columns). A case is predicted PNEUMONIA when its
    probability is ``>= threshold``. Any ratio whose denominator is zero is
    reported as 0.0. If no predictions exist, prints a notice and draws nothing.
    """
    if not baseline_test.empty:
        truth = baseline_test['label'].values
        scores = baseline_test['probability'].values
        predicted = (scores >= threshold).astype(int)
        # labels=[0, 1] fixes the 2x2 layout even if one class is never predicted.
        tn, fp, fn, tp = confusion_matrix(truth, predicted, labels=[0, 1]).ravel()

        def safe_ratio(numerator, denominator):
            return numerator / denominator if denominator else 0.0

        precision = safe_ratio(tp, tp + fp)
        recall = safe_ratio(tp, tp + fn)
        specificity = safe_ratio(tn, tn + fp)

        print(f'Threshold: {threshold:.2f}')
        print(f'Precision: {precision:.3f} | Recall: {recall:.3f} | Specificity: {specificity:.3f}')

        counts = pd.Series({'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn})
        _, ax = plt.subplots(figsize=(6, 3.2))
        ax.bar(counts.index, counts.values, color=['#54a24b', '#4c78a8', '#f58518', '#e45756'])
        ax.set_title('Baseline Confusion Counts')
        ax.set_ylabel('Images')
        plt.tight_layout()
        plt.show()
    else:
        print('No baseline predictions available.')


# Slider from 0.05 to 0.95; the demo re-runs only when the slider is released.
threshold_slider = widgets.FloatSlider(
    value=0.50,
    min=0.05,
    max=0.95,
    step=0.05,
    description='Threshold',
    continuous_update=False,
)
widgets.interact(threshold_demo, threshold=threshold_slider)


## Wrap-up and Handoff to Module 5B
- You framed the clinical problem first.
- You compared method families before committing to deep learning.
- You created a traditional ML baseline on real chest X-ray files.
- Next: build and train a CNN with explicit label handling, train/val split checks, and hyperparameter controls.