# Module 5B: Labels, Splits, and Deep Learning Training
## Building a Chest X-ray CNN End to End

**Goal:** Move from problem framing to a trainable CNN workflow with real image files, clean split logic, and reproducible experiments.

### Learning objectives
1. Validate folder-based labels and inspect subtype artifacts in filenames.
2. Build a CPU-friendly subset from a real chest X-ray dataset.
3. Create PyTorch datasets and dataloaders without hidden magic.
4. Train a compact CNN and track ROC-AUC + PR-AUC on validation data.
5. Run small hyperparameter experiments and save outputs for Module 5C.

## Section 0: Dataset Choice and Setup
Recommended dataset: **Chest X-Ray Images (Pneumonia)** (Kermany et al.).

Expected structure:
- `data/chest_xray_small/train/NORMAL`, `.../train/PNEUMONIA`
- `data/chest_xray_small/val/NORMAL`, `.../val/PNEUMONIA`
- `data/chest_xray_small/test/NORMAL`, `.../test/PNEUMONIA`

If your dataset is larger, this notebook downsamples to a balanced subset.

In [None]:
from pathlib import Path
import time
import copy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

from IPython.display import display

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import average_precision_score, roc_auc_score

try:
    import ipywidgets as widgets
except ImportError as exc:
    raise ImportError('ipywidgets is required for Module 5B interactive controls.') from exc


def set_seed(seed=42):
    """Seed the global NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to ``np.random.seed`` and
            ``torch.manual_seed``.
    """
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)


# Seed NumPy and PyTorch once so the cells below start from a known RNG state.
set_seed(42)

# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


def resolve_cxr_root():
    """Return the first dataset root that exists, or ``None`` if none does.

    ``../data/chest_xray_small`` is probed before ``data/chest_xray_small``;
    both are resolved relative to the current working directory.
    """
    search_order = (Path('../data/chest_xray_small'), Path('data/chest_xray_small'))
    return next((root for root in search_order if root.exists()), None)


def build_manifest(root):
    """List every image file under ``root/<split>/<class>`` as a DataFrame.

    Splits ``train``/``val``/``test`` and class folders ``NORMAL`` (label 0)
    and ``PNEUMONIA`` (label 1) are scanned; missing folders are skipped.
    Only regular files with a ``.jpeg``, ``.jpg`` or ``.png`` suffix
    (case-insensitive) are kept, in sorted order within each folder.

    Args:
        root: ``pathlib.Path`` of the dataset root directory.

    Returns:
        DataFrame with columns ``path``, ``split``, ``label_name`` and
        ``label``. The columns are present even when no image is found, so
        downstream column lookups work on an empty manifest.
    """
    columns = ['path', 'split', 'label_name', 'label']
    image_suffixes = {'.jpeg', '.jpg', '.png'}

    rows = []
    for split in ['train', 'val', 'test']:
        for label_name, label in [('NORMAL', 0), ('PNEUMONIA', 1)]:
            folder = root / split / label_name
            if not folder.is_dir():
                continue
            rows.extend(
                {'path': str(p), 'split': split, 'label_name': label_name, 'label': label}
                for p in sorted(folder.glob('*'))
                # is_file() keeps out directories that merely look like images.
                if p.is_file() and p.suffix.lower() in image_suffixes
            )
    return pd.DataFrame(rows, columns=columns)


# Locate the dataset and build the image manifest used by every later cell.
CXR_ROOT = resolve_cxr_root()
if CXR_ROOT is None:
    # No dataset folder: keep an empty manifest with the expected columns so
    # later cells can test `manifest.empty` instead of failing on a NameError.
    DATA_READY = False
    manifest = pd.DataFrame(columns=['path', 'split', 'label_name', 'label'])
    print('Dataset not found. Add files under data/chest_xray_small and rerun.')
    print('See setup notes in data/chest_xray_small/README.md')
else:
    DATA_READY = True
    manifest = build_manifest(CXR_ROOT)
    print(f'Loaded manifest with {len(manifest)} rows from {CXR_ROOT}')


In [None]:
# Inspect the raw manifest: preview rows, then count and plot images per split and class.
if not DATA_READY or manifest.empty:
    print('No manifest to inspect yet.')
else:
    display(manifest.head(8))

    # One row per (split, class) pair with its image count.
    counts = manifest.groupby(['split', 'label_name']).size().rename('n_images').reset_index()
    display(counts)

    # Reshape to split x class; a class missing from a split is shown as 0.
    pivot = counts.pivot(index='split', columns='label_name', values='n_images').fillna(0)
    ax = pivot.plot(kind='bar', figsize=(7, 3.5), color=['#4c78a8', '#e45756'])
    ax.set_title('Raw Class Balance by Split')
    ax.set_ylabel('Images')
    ax.tick_params(axis='x', rotation=0)
    plt.tight_layout()
    plt.show()


## Section 1: Label Audit
Even with folder labels, inspect filenames for potential subtype hints (e.g., bacterial vs viral).
This helps plan future multi-label or multi-class extensions.

In [None]:
# Label audit: derive a coarse pneumonia subtype hint from each image's file name.
if not DATA_READY or manifest.empty:
    print('Label audit skipped: dataset missing.')
else:
    # Match against the file name only. Searching the whole path would let a
    # parent directory containing "virus" or "bacteria" tag every image below it.
    lower_name = manifest['path'].map(lambda p: Path(p).name.lower())
    has_virus = lower_name.str.contains('virus')
    has_bacteria = lower_name.str.contains('bacteria')

    # 'virus' takes precedence over 'bacteria' when a name contains both;
    # NORMAL images always get the hint 'normal'.
    pneumonia_hint = np.where(has_virus, 'virus', np.where(has_bacteria, 'bacteria', 'unspecified'))
    manifest['pneumonia_subtype_hint'] = np.where(
        manifest['label_name'] == 'PNEUMONIA',
        pneumonia_hint,
        'normal',
    )

    # Per-split breakdown of the hints for the PNEUMONIA class only.
    subtype_counts = (
        manifest[manifest['label_name'] == 'PNEUMONIA']
        .groupby(['split', 'pneumonia_subtype_hint'])
        .size()
        .rename('n_images')
        .reset_index()
    )
    display(subtype_counts)


## Section 2: Build a Balanced, CPU-Friendly Subset
Target subset per class (modifiable):
- Train: up to 600
- Val: up to 150
- Test: up to 150

Balanced subsets make teaching and debugging easier.

In [None]:
def build_balanced_subset(manifest_df, per_class=None, random_state=42):
    """Sample an equal number of images per class within each split.

    For every split the same number of rows is drawn from each class that is
    present: the requested target, capped by the size of the smallest class
    in that split. Capping each class independently would leave the subset
    unbalanced whenever one class has fewer images than the target.

    Args:
        manifest_df: DataFrame with at least ``split`` and ``label`` columns,
            where ``label`` is 0 or 1.
        per_class: Mapping of split name to the maximum number of images per
            class. Defaults to ``{'train': 600, 'val': 150, 'test': 150}``.
        random_state: Seed used for both the per-class sampling and the final
            shuffle.

    Returns:
        Shuffled DataFrame with a fresh 0..n-1 index, or an empty DataFrame
        with the input's columns when nothing could be sampled.
    """
    if per_class is None:
        per_class = {'train': 600, 'val': 150, 'test': 150}

    sampled = []
    for split, target in per_class.items():
        split_rows = manifest_df[manifest_df['split'] == split]
        groups = [split_rows[split_rows['label'] == label] for label in (0, 1)]
        groups = [group for group in groups if not group.empty]
        if not groups:
            continue
        # One shared quota per split keeps the classes the same size.
        take = min(target, min(len(group) for group in groups))
        sampled.extend(group.sample(n=take, random_state=random_state) for group in groups)

    if not sampled:
        return pd.DataFrame(columns=manifest_df.columns)

    out = pd.concat(sampled, ignore_index=True)
    return out.sample(frac=1.0, random_state=random_state).reset_index(drop=True)


# Build the working subset and write it to CSV next to the dataset folder.
if not DATA_READY or manifest.empty:
    # Keep an empty frame with the manifest's columns so later cells can
    # check `subset_manifest.empty`.
    subset_manifest = pd.DataFrame(columns=manifest.columns)
    print('Subset creation skipped: dataset missing.')
else:
    subset_manifest = build_balanced_subset(manifest)
    # Show how many images each (split, class) pair ended up with.
    display(
        subset_manifest.groupby(['split', 'label_name'])
        .size()
        .rename('n_images')
        .reset_index()
    )

    # Files go into the parent of the dataset root (the `data` directory).
    out_dir = CXR_ROOT.parent
    subset_manifest.to_csv(out_dir / 'module_05_manifest_subset.csv', index=False)
    # Also write one manifest per split.
    for split in ['train', 'val', 'test']:
        subset_manifest[subset_manifest['split'] == split].to_csv(out_dir / f'module_05_{split}_manifest.csv', index=False)
    print(f'Saved subset manifests in {out_dir}')


## Section 3: PyTorch Dataset and DataLoader
We avoid hidden abstractions and implement a small custom dataset class.

In [None]:
class ChestXrayDataset(Dataset):
    """Dataset that loads grayscale chest X-rays listed in a manifest frame.

    Each item is a ``(tensor, label, path)`` triple:

    * ``tensor``: float32 tensor of shape ``(1, img_size, img_size)``, scaled
      to [0, 1] and then normalized as ``(x - 0.5) / 0.25``.
    * ``label``: float32 scalar tensor built from the row's ``label`` value.
    * ``path``: the row's ``path`` value, passed through unchanged.

    Args:
        frame: DataFrame with ``path`` and ``label`` columns.
        img_size: Side length in pixels that every image is resized to.
        augment: When true, each image is flipped left-right with
            probability 0.5, drawn from NumPy's global random generator.
    """

    def __init__(self, frame, img_size=224, augment=False):
        # Reset the index so positional lookups match 0..len-1.
        self.frame = frame.reset_index(drop=True)
        self.img_size = img_size
        self.augment = augment

    def __len__(self):
        """Return the number of rows (images) in the manifest frame."""
        return len(self.frame)

    def __getitem__(self, idx):
        """Load, resize, optionally flip and normalize the image at ``idx``."""
        row = self.frame.iloc[idx]
        # The context manager closes the file as soon as the pixels are
        # decoded, rather than leaving one open handle per item until
        # garbage collection.
        with Image.open(row['path']) as raw:
            img = raw.convert('L').resize((self.img_size, self.img_size))
        arr = np.asarray(img, dtype=np.float32) / 255.0

        if self.augment and np.random.rand() < 0.5:
            # fliplr returns a negative-stride view; copy so that
            # torch.from_numpy accepts the array.
            arr = np.fliplr(arr).copy()

        # Add the channel dimension: (H, W) -> (1, H, W).
        tensor = torch.from_numpy(arr).unsqueeze(0)

        # Fixed normalization for teaching simplicity.
        tensor = (tensor - 0.5) / 0.25

        label = torch.tensor(float(row['label']), dtype=torch.float32)
        return tensor, label, row['path']


def make_loaders(subset_df, batch_size=16, img_size=224):
    """Create the train, validation and test DataLoaders from one manifest.

    Args:
        subset_df: DataFrame with a ``split`` column holding ``train``,
            ``val`` or ``test``, plus the columns ``ChestXrayDataset`` reads.
        batch_size: Batch size shared by all three loaders.
        img_size: Side length that images are resized to.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    loaders = []
    # Only the training split is augmented and shuffled; validation and test
    # keep a fixed order and untouched images.
    for split_name, is_train in (('train', True), ('val', False), ('test', False)):
        split_frame = subset_df[subset_df['split'] == split_name].copy()
        dataset = ChestXrayDataset(split_frame, img_size=img_size, augment=is_train)
        loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=is_train))

    return tuple(loaders)


# Smoke check: build the loaders and pull one training batch to confirm shapes.
if subset_manifest.empty:
    print('DataLoader setup pending dataset availability.')
else:
    train_loader, val_loader, test_loader = make_loaders(subset_manifest, batch_size=16, img_size=224)
    # Each batch is (images, labels, paths); the paths are not needed here.
    xb, yb, _ = next(iter(train_loader))
    print('Batch tensor shape:', tuple(xb.shape), '| labels shape:', tuple(yb.shape))


## Section 4: Define a Compact CNN
This network is intentionally small so students can train on CPU.

In [None]:
class SmallCXRNet(nn.Module):
    """Compact three-block CNN that outputs one logit per grayscale image.

    Each feature block is a 3x3 convolution, ReLU and 2x2 max-pool, so the
    spatial size is divided by 8 overall. The feature map is then resampled
    to 28x28 before the fully connected head. For 224x224 inputs the map is
    already 28x28 and the resampling is an identity, so results match a
    network without it. Other input sizes, such as those offered by the
    image-size slider, no longer fail with a shape mismatch in the first
    linear layer.

    Args:
        dropout: Dropout probability applied before the final linear layer.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Parameter-free, and kept as its own attribute so the state_dict
        # keys of `features` and `classifier` stay exactly as before.
        self.pool = nn.AdaptiveAvgPool2d((28, 28))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 28 * 28, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Map a ``(batch, 1, H, W)`` tensor to ``(batch,)`` raw logits."""
        x = self.pool(self.features(x))
        # Drop the trailing singleton dimension: (batch, 1) -> (batch,).
        return self.classifier(x).squeeze(1)


# Instantiate the network once and print its layer summary for inspection.
model = SmallCXRNet(dropout=0.3)
print(model)


## Section 5: Training and Validation Loops
Track both loss and ranking metrics (ROC-AUC, PR-AUC) each epoch.

In [None]:
def evaluate_model(model, loader, criterion):
    """Run the model over a loader and collect loss, metrics and predictions.

    The model is switched to eval mode and left there; callers that keep
    training must call ``model.train()`` again.

    Args:
        model: Module mapping an image batch to one logit per sample.
        loader: Iterable of ``(images, labels, paths)`` batches.
        criterion: Loss taking ``(logits, labels)``. It is assumed to return
            the batch mean, as ``nn.BCEWithLogitsLoss()`` does by default.

    Returns:
        Dict with keys ``loss`` (per-sample mean, NaN for an empty loader),
        ``roc_auc`` and ``pr_auc`` (0.0 when fewer than two classes are
        present, because the metrics are undefined there), ``y_true`` (int
        array), ``y_prob`` (float array of sigmoid probabilities) and
        ``paths`` (list).
    """
    model.eval()

    # Send batches to wherever the model's weights live instead of relying on
    # a module-level global; parameter-free models are evaluated on the CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    loss_sum = 0.0
    n_seen = 0
    all_probs = []
    all_labels = []
    all_paths = []

    with torch.no_grad():
        for xb, yb, paths in loader:
            xb = xb.to(run_device)
            yb = yb.to(run_device)
            logits = model(xb)
            loss = criterion(logits, yb)
            probs = torch.sigmoid(logits)

            # Weight each batch mean by its size so a short final batch does
            # not count as much as a full one.
            batch_n = int(yb.shape[0])
            loss_sum += loss.item() * batch_n
            n_seen += batch_n

            all_probs.extend(probs.cpu().numpy().tolist())
            all_labels.extend(yb.cpu().numpy().tolist())
            all_paths.extend(paths)

    y_true = np.array(all_labels, dtype=int)
    y_prob = np.array(all_probs, dtype=float)

    # Both ranking metrics need at least one positive and one negative label.
    has_both_classes = np.unique(y_true).size > 1
    roc_auc = roc_auc_score(y_true, y_prob) if has_both_classes else 0.0
    pr_auc = average_precision_score(y_true, y_prob) if has_both_classes else 0.0

    return {
        'loss': float(loss_sum / n_seen) if n_seen else np.nan,
        'roc_auc': float(roc_auc),
        'pr_auc': float(pr_auc),
        'y_true': y_true,
        'y_prob': y_prob,
        'paths': all_paths,
    }


def train_model(train_loader, val_loader, learning_rate=1e-3, weight_decay=1e-4, dropout=0.3, epochs=5):
    """Train a fresh ``SmallCXRNet`` and keep the best-validation weights.

    Args:
        train_loader: Loader yielding ``(images, labels, paths)`` batches.
        val_loader: Loader evaluated after every epoch.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        dropout: Dropout probability passed to the network.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(model, history)``: the model with the weights from the epoch that
        had the highest validation PR-AUC, and a DataFrame with one row per
        epoch (``epoch``, ``train_loss``, ``val_loss``, ``val_roc_auc``,
        ``val_pr_auc``).
    """
    model = SmallCXRNet(dropout=dropout).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    def fit_one_epoch():
        # One optimization pass over the training data; returns the mean of
        # the per-batch losses (NaN when the loader yields nothing).
        model.train()
        batch_losses = []
        for images, targets, _ in train_loader:
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(images), targets)
            batch_loss.backward()
            optimizer.step()

            batch_losses.append(batch_loss.item())
        return float(np.mean(batch_losses)) if batch_losses else np.nan

    records = []
    checkpoint = None
    top_val_pr = -1.0

    for epoch in range(1, epochs + 1):
        train_loss = fit_one_epoch()
        val_result = evaluate_model(model, val_loader, criterion)

        records.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_result['loss'],
            'val_roc_auc': val_result['roc_auc'],
            'val_pr_auc': val_result['pr_auc'],
        })

        print(
            f"Epoch {epoch:02d} | train_loss={train_loss:.4f} "
            f"| val_loss={val_result['loss']:.4f} | val_roc_auc={val_result['roc_auc']:.3f} "
            f"| val_pr_auc={val_result['pr_auc']:.3f}"
        )

        # Snapshot the weights on a strict improvement; the deep copy stops
        # later optimizer steps from overwriting the snapshot.
        if val_result['pr_auc'] > top_val_pr:
            top_val_pr = val_result['pr_auc']
            checkpoint = copy.deepcopy(model.state_dict())

    history = pd.DataFrame(records)
    if checkpoint is not None:
        model.load_state_dict(checkpoint)

    return model, history


## Section 6: Interactive Hyperparameter Runs
Use `Run Interact` to launch a short experiment.
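Saved artifacts (written next to the dataset folder):
- `data/module_05b_training_history.csv` (each run's per-epoch rows are appended)
- `data/module_05b_val_predictions.csv` (overwritten by the most recent run)
- `data/module_05b_test_predictions.csv` (overwritten by the most recent run)
- `data/module_05b_best_model.pt` (weights from the best validation PR-AUC epoch of the most recent run)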

In [None]:
def save_predictions_csv(path, eval_dict):
    """Write per-image predictions to a CSV file without the index column.

    Args:
        path: Destination file path.
        eval_dict: Mapping with equally long ``paths``, ``y_true`` and
            ``y_prob`` entries; they are written as the columns ``path``,
            ``label`` and ``probability``.
    """
    column_sources = (('path', 'paths'), ('label', 'y_true'), ('probability', 'y_prob'))
    table = pd.DataFrame({column: eval_dict[key] for column, key in column_sources})
    table.to_csv(path, index=False)


def run_experiment(batch_size=16, img_size=224, learning_rate=1e-3, weight_decay=1e-4, dropout=0.3, epochs=5, seed=42):
    """Train once with the given hyperparameters, then save and plot results.

    Reads the module-level ``subset_manifest`` and ``CXR_ROOT``. Prints a
    message and returns early when the subset is empty.

    Side effects, all written to the parent directory of the dataset root:
        * ``module_05b_training_history.csv``: this run's per-epoch rows are
          appended to any rows already in the file.
        * ``module_05b_val_predictions.csv`` and
          ``module_05b_test_predictions.csv``: overwritten with this run.
        * ``module_05b_best_model.pt``: overwritten with this run's weights.

    Args:
        batch_size: DataLoader batch size.
        img_size: Side length images are resized to.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        dropout: Dropout probability of the network.
        epochs: Number of training epochs.
        seed: Seed applied through ``set_seed`` before the loaders are built.

    Returns:
        None.
    """
    if subset_manifest.empty:
        print('No dataset available. Add chest X-ray files first.')
        return

    set_seed(seed)
    train_loader, val_loader, test_loader = make_loaders(subset_manifest, batch_size=int(batch_size), img_size=int(img_size))

    start = time.time()
    model, history = train_model(
        train_loader,
        val_loader,
        learning_rate=float(learning_rate),
        weight_decay=float(weight_decay),
        dropout=float(dropout),
        epochs=int(epochs),
    )
    elapsed = time.time() - start

    # Re-evaluate with the restored best-epoch weights returned by train_model.
    criterion = nn.BCEWithLogitsLoss()
    val_eval = evaluate_model(model, val_loader, criterion)
    test_eval = evaluate_model(model, test_loader, criterion)

    # Encode every tunable setting. Without img_size and weight_decay, runs
    # that differ only in those would share one ID in the history file.
    run_id = (
        f"run_bs{batch_size}_img{img_size}_lr{learning_rate}_wd{weight_decay}"
        f"_do{dropout}_ep{epochs}_seed{seed}"
    )

    # Attach the run configuration to every epoch row.
    history = history.copy()
    history['run_id'] = run_id
    history['batch_size'] = int(batch_size)
    history['img_size'] = int(img_size)
    history['learning_rate'] = float(learning_rate)
    history['weight_decay'] = float(weight_decay)
    history['dropout'] = float(dropout)
    history['epochs'] = int(epochs)
    history['elapsed_sec'] = float(elapsed)

    out_dir = CXR_ROOT.parent if CXR_ROOT is not None else Path('../data')
    hist_path = out_dir / 'module_05b_training_history.csv'
    val_path = out_dir / 'module_05b_val_predictions.csv'
    test_path = out_dir / 'module_05b_test_predictions.csv'
    model_path = out_dir / 'module_05b_best_model.pt'

    # History accumulates across runs; the other artifacts are replaced.
    if hist_path.exists():
        prev = pd.read_csv(hist_path)
        all_hist = pd.concat([prev, history], ignore_index=True)
    else:
        all_hist = history
    all_hist.to_csv(hist_path, index=False)

    save_predictions_csv(val_path, val_eval)
    save_predictions_csv(test_path, test_eval)
    torch.save({'model_state_dict': model.state_dict(), 'run_id': run_id}, model_path)

    print(f'Run ID: {run_id}')
    print(f'Training time: {elapsed:.1f} sec')
    print(f'Validation ROC-AUC: {val_eval["roc_auc"]:.3f} | PR-AUC: {val_eval["pr_auc"]:.3f}')
    print(f'Test ROC-AUC:       {test_eval["roc_auc"]:.3f} | PR-AUC: {test_eval["pr_auc"]:.3f}')
    print(f'Saved history to: {hist_path}')
    print(f'Saved test predictions to: {test_path}')

    # Plot this run's training and validation loss per epoch.
    fig, ax = plt.subplots(figsize=(7, 3.5))
    ax.plot(history['epoch'], history['train_loss'], marker='o', label='Train loss')
    ax.plot(history['epoch'], history['val_loss'], marker='o', label='Val loss')
    ax.set_title('Loss Curves')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.tight_layout()
    plt.show()


# Wire run_experiment to manual-run widgets: training starts only when the
# button is pressed, not on every slider change.
# The two FloatLogSlider controls take min/max as base-10 exponents, so LR
# spans 1e-4..1e-2 and WD spans 1e-6..1e-3 in steps of 0.2 in log10 space.
widgets.interact_manual(
    run_experiment,
    batch_size=widgets.IntSlider(value=16, min=8, max=32, step=8, description='Batch'),
    img_size=widgets.IntSlider(value=224, min=128, max=256, step=32, description='Img'),
    learning_rate=widgets.FloatLogSlider(value=1e-3, base=10, min=-4, max=-2, step=0.2, description='LR'),
    weight_decay=widgets.FloatLogSlider(value=1e-4, base=10, min=-6, max=-3, step=0.2, description='WD'),
    dropout=widgets.FloatSlider(value=0.3, min=0.0, max=0.6, step=0.1, description='Dropout'),
    epochs=widgets.IntSlider(value=5, min=2, max=10, step=1, description='Epochs'),
    seed=widgets.IntSlider(value=42, min=1, max=999, step=1, description='Seed'),
)


## Section 7: Recommended Starter Runs
Try these runs in order:
1. `batch=16`, `lr=1e-3`, `dropout=0.3`, `epochs=5`
2. `batch=16`, `lr=5e-4`, `dropout=0.3`, `epochs=8` (the LR slider moves in log-spaced steps, so pick the nearest available value)
3. `batch=32`, `lr=1e-3`, `dropout=0.2`, `epochs=6`

Then compare results in Module 5C.

## Wrap-up and Handoff to Module 5C
- You validated labels and dataset splits.
- You trained a compact CNN and tracked validation PR-AUC/ROC-AUC.
- You saved prediction files and training logs.
- Next: choose thresholds, compare PR vs ROC in context, and pick operating points using policy constraints.