In [None]:
# Training Pipeline for Chest X-ray Multi-Label Classification
# Features:
# - Patient-level random train/val/test split (70/10/20 of patients; no patient appears in more than one split)
# - Data augmentation for the training set
# - Custom PyTorch Dataset class for image-label pairing
# - Pretrained model loading and fine-tuning using timm
# - BCEWithLogitsLoss for multilabel classification
# - ReduceLROnPlateau learning-rate scheduler (driven by validation AUROC) and early stopping
# - Per-epoch tracking of training loss, validation loss, and validation AUROC
# - Computation of overall and per-class AUROC on the validation and test sets
# - Best model checkpointing based on validation AUROC
# - Logging all results and loss histories to disk
# - Saving a loss/AUROC curve plot per model

import os
import numpy as np
import pandas as pd
from PIL import Image
import json
import matplotlib.pyplot as plt
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
import timm

# 1. The 15 target labels: 14 findings plus "No Finding". This order fixes the
#    column order of the multi-hot label vectors and of every per-class metric.
CLASSES = [
    "No Finding",
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural_Thickening",
    "Hernia",
]
num_classes = len(CLASSES)

# 2. Load the label table and split it by patient. The split is a seeded random
#    shuffle of patient IDs cut at 70% / 80% (train / val / test); it is NOT
#    stratified by finding.
df = pd.read_csv("/student/csc490_project/shared/labels.csv")
# "Finding Labels" holds pipe-separated findings, e.g. "Effusion|Mass".
df["label_list"] = df["Finding Labels"].map(lambda findings: findings.split("|"))
mlb = MultiLabelBinarizer(classes=CLASSES)
# One multi-hot row per image, columns ordered like CLASSES.
df["labels"] = list(mlb.fit_transform(df["label_list"]))

# Splitting on patient IDs (not images) keeps every image of a patient in a
# single split, so no patient leaks between train, val and test.
unique_patients = df["Patient ID"].unique()
np.random.seed(42)
np.random.shuffle(unique_patients)
patient_count = len(unique_patients)
train_end = int(0.7 * patient_count)
val_end = int(0.8 * patient_count)
train_patients, val_patients, test_patients = np.split(unique_patients, [train_end, val_end])

# Sanity check: the three patient groups share no IDs.
assert not set(train_patients) & set(val_patients)
assert not set(train_patients) & set(test_patients)
assert not set(val_patients) & set(test_patients)

# Image-level tables for each patient group, re-indexed from 0.
train_df = df.loc[df["Patient ID"].isin(train_patients)].reset_index(drop=True)
val_df = df.loc[df["Patient ID"].isin(val_patients)].reset_index(drop=True)
test_df = df.loc[df["Patient ID"].isin(test_patients)].reset_index(drop=True)

# 3. Image preprocessing. Grayscale X-rays are replicated to 3 channels and
#    normalized with ImageNet statistics to match the pretrained backbones.
#    Only the training pipeline applies random flip/rotation augmentation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _tensor_and_normalize():
    """Return the ToTensor + ImageNet-Normalize steps shared by every pipeline."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]


train_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.Resize((224, 224)),
    ]
    + _tensor_and_normalize()
)

val_test_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((224, 224)),
    ]
    + _tensor_and_normalize()
)

# 4. Dataset class for loading X-ray images and labels
class ChestXrayDataset(Dataset):
    """
    PyTorch Dataset for Chest X-ray multi-label classification.

    Each item is ``(image, label)``:
    - ``image`` is the file at ``root_dir / row["Image Index"]``, loaded as
      single-channel ("L") and then passed through ``transform`` if given.
    - ``label`` is ``row["labels"]`` as a float32 tensor.

    Args:
        df (pd.DataFrame): Dataframe with an "Image Index" column (file name)
            and a "labels" column (multi-hot label vector per image).
        root_dir (str): Path to the directory containing images.
        transform (callable, optional): Transformations to apply to images.
    """
    def __init__(self, df, root_dir, transform=None):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Look the row up once instead of calling iloc twice per item.
        row = self.df.iloc[idx]
        img_path = os.path.join(self.root_dir, row["Image Index"])
        # The context manager closes the file handle even if decoding fails;
        # convert() returns an in-memory copy that outlives the `with` block.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("L")  # Convert to grayscale
        if self.transform is not None:
            image = self.transform(image)
        label = torch.tensor(row["labels"], dtype=torch.float32)
        return image, label

# 5. DataLoaders for the three splits. They share batch size and worker count;
#    only the training loader shuffles, val/test iterate in row order.
data_root = "/student/csc490_project/shared/preprocessed_images/preprocessed_images"
_loader_opts = {"batch_size": 16, "num_workers": 4}
train_loader = DataLoader(
    ChestXrayDataset(train_df, data_root, train_transform), shuffle=True, **_loader_opts
)
val_loader = DataLoader(
    ChestXrayDataset(val_df, data_root, val_test_transform), shuffle=False, **_loader_opts
)
test_loader = DataLoader(
    ChestXrayDataset(test_df, data_root, val_test_transform), shuffle=False, **_loader_opts
)

# 6. Helper to load pretrained model

def get_model(model_name, num_classes):
    """
    Build a timm model with pretrained weights and a new output head.

    Args:
        model_name (str): timm model identifier, e.g. "densenet121".
        num_classes (int): Size of the replacement classification head.

    Returns:
        nn.Module: The timm model whose head emits ``num_classes`` logits.
    """
    model = timm.create_model(model_name, num_classes=num_classes, pretrained=True)
    return model

# 7. AUROC computation and evaluation routines

def compute_auroc(y_true, y_pred):
    """
    Compute mean and per-class AUROC.

    AUROC is only defined for a class whose ground-truth column contains both
    positive and negative examples. Each class is therefore scored on its own:
    - A class that cannot be scored gets NaN in the per-class output and is
      excluded from the mean.
    - Previously, one such class (e.g. a rare finding absent from a split)
      made the whole metric collapse to 0.0.

    Args:
        y_true (list): Per-batch arrays of binary labels; stacked with
            np.vstack into a (samples, classes) matrix.
        y_pred (list): Per-batch arrays of predicted probabilities, same layout.

    Returns:
        tuple:
            - Mean AUROC over the scorable classes, or 0.0 if no class is
              scorable.
            - np.ndarray of per-class AUROCs, with NaN for unscorable classes.
    """
    y_true = np.vstack(y_true)
    y_pred = np.vstack(y_pred)
    class_count = y_true.shape[1]
    per_class_auroc = np.full(class_count, np.nan)
    for k in range(class_count):
        column = y_true[:, k]
        # roc_auc_score raises ValueError when only one label value is present.
        if np.unique(column).size < 2:
            continue
        try:
            per_class_auroc[k] = roc_auc_score(column, y_pred[:, k])
        except ValueError:
            # e.g. non-finite predictions from a diverged model: keep NaN for
            # this class rather than aborting training (original best-effort).
            pass
    scored = ~np.isnan(per_class_auroc)
    if not scored.any():
        return 0.0, per_class_auroc
    return float(per_class_auroc[scored].mean()), per_class_auroc

def evaluate(model, device, loader):
    """
    Score a model on every batch of ``loader`` and return its AUROC.

    The model is switched to eval mode and run without gradient tracking.
    Logits are turned into probabilities with a sigmoid before scoring.

    Args:
        model (nn.Module): Model producing one logit per class.
        device (torch.device): Device the images are moved to.
        loader (DataLoader): Yields (images, labels) batches.

    Returns:
        tuple: Whatever compute_auroc returns for the collected labels and
            probabilities, i.e. (mean AUROC, per-class AUROCs).
    """
    model.eval()
    probs_per_batch = []
    targets_per_batch = []
    with torch.no_grad():
        for batch_imgs, batch_targets in loader:
            logits = model(batch_imgs.to(device))
            probs_per_batch.append(torch.sigmoid(logits).cpu().numpy())
            targets_per_batch.append(batch_targets.numpy())
    return compute_auroc(targets_per_batch, probs_per_batch)

# 8. Training loop with early stopping and checkpointing

def _train_epoch(model, device, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return total_loss / max(num_batches, 1)


def _validate_epoch(model, device, loader, criterion):
    """Single no-grad pass over ``loader``; return (mean per-batch loss, mean AUROC).

    Loss and AUROC come from the same forward pass, so validation inference
    runs once per epoch instead of twice.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            total_loss += criterion(outputs, labels).item()
            num_batches += 1
            all_preds.append(torch.sigmoid(outputs).cpu().numpy())
            all_labels.append(labels.cpu().numpy())
    mean_auroc, _ = compute_auroc(all_labels, all_preds)
    return total_loss / max(num_batches, 1), mean_auroc


def train_one_model(model_name, num_epochs=25, early_stopping_patience=2):
    """
    Train a single model with early stopping and save best checkpoint.

    Per epoch:
    - Train on ``train_loader``.
    - Measure validation loss and AUROC on ``val_loader``.
    - Step the LR scheduler on the validation AUROC.
    - Checkpoint the weights whenever the validation AUROC improves.

    Training stops after ``early_stopping_patience`` consecutive epochs
    without improvement.

    The final val/test metrics are computed with the best checkpointed
    weights, so they describe the model stored at ``model_path``.

    Args:
        model_name (str): Name of model to train.
        num_epochs (int): Number of training epochs.
        early_stopping_patience (int): Patience for early stopping.

    Returns:
        tuple: (val_auroc, test_auroc, val_aurocs, test_aurocs, model_path,
            final_train_loss, train_losses, val_losses, val_aurocs_epoch).
            ``final_train_loss`` is None if no epoch ran.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(model_name, num_classes).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # mode='max': the scheduler watches validation AUROC (higher is better).
    # NOTE(review): `verbose` is deprecated in newer PyTorch releases; confirm
    # against the installed version.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

    training_dir = "/student/csc490_project/shared/training"
    # Create output dirs up front. Previously they were created only on
    # improvement, but results/ is written to after every model.
    os.makedirs(os.path.join(training_dir, "results"), exist_ok=True)
    checkpoint_path = os.path.join(training_dir, f"{model_name.replace('/', '_')}.pt")

    # -inf guarantees the first scored epoch is checkpointed.
    best_val_auroc = float("-inf")
    epochs_since_improvement = 0
    best_model_path = None
    train_losses = []
    val_aurocs_epoch = []
    val_losses = []

    for epoch in range(num_epochs):
        epoch_loss = _train_epoch(model, device, train_loader, criterion, optimizer)
        train_losses.append(epoch_loss)

        val_epoch_loss, val_auroc = _validate_epoch(model, device, val_loader, criterion)
        val_aurocs_epoch.append(val_auroc)
        val_losses.append(val_epoch_loss)

        scheduler.step(val_auroc)
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_loss:.4f} - Val Loss: {val_epoch_loss:.4f} - Val AUROC: {val_auroc:.4f}")

        if val_auroc > best_val_auroc:
            best_val_auroc = val_auroc
            epochs_since_improvement = 0
            torch.save(model.state_dict(), checkpoint_path)
            best_model_path = checkpoint_path
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= early_stopping_patience:
                break

    # The in-memory model is now up to `early_stopping_patience` epochs past
    # the best one. Restore the checkpoint so the reported metrics belong to
    # the weights saved at best_model_path.
    if best_model_path is not None:
        model.load_state_dict(torch.load(best_model_path, map_location=device))

    val_auroc, val_aurocs = evaluate(model, device, val_loader)
    test_auroc, test_aurocs = evaluate(model, device, test_loader)
    final_train_loss = train_losses[-1] if train_losses else None
    return val_auroc, test_auroc, val_aurocs, test_aurocs, best_model_path, final_train_loss, train_losses, val_losses, val_aurocs_epoch

# 9. Train each model and log results
model_names = [
    'coatnet_2_rw_224.sw_in12k_ft_in1k',
    'convnext_large.fb_in22k',
    'densenet121',
    'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k',
    'swin_large_patch4_window7_224',
    'vgg19.tv_in1k'
]

results_dir = "/student/csc490_project/shared/training/results"
# Create the per-model results directory here so writing the CSV below never
# depends on a checkpoint having been saved during training.
os.makedirs(results_dir, exist_ok=True)

results = {}
for model_name in model_names:
    print(f"Training {model_name}...")
    (val_auroc, test_auroc, val_aurocs, test_aurocs, model_path, final_loss,
     train_loss_history, val_loss_history, val_aurocs_epoch) = train_one_model(model_name)
    model_results = {
        "val_auroc": val_auroc,
        "val_aurocs_epoch": val_aurocs_epoch,
        "test_auroc": test_auroc,
        "model_path": model_path,
        "final_train_loss": final_loss,
        "train_loss_history": train_loss_history,
        "val_loss_history": val_loss_history
    }
    # One column per class and split, e.g. "val_Effusion" / "test_Effusion".
    for i, cls in enumerate(CLASSES):
        model_results[f"val_{cls}"] = val_aurocs[i]
        model_results[f"test_{cls}"] = test_aurocs[i]
    results[model_name] = model_results

    # Write each model's row immediately so finished runs survive a later crash.
    pd.DataFrame([model_results]).to_csv(
        os.path.join(results_dir, f"{model_name.replace('/', '_')}.csv"), index=False
    )

# Save combined results and plots
training_dir = "/student/csc490_project/shared/training"
os.makedirs(training_dir, exist_ok=True)
pd.DataFrame(results).T.to_csv(os.path.join(training_dir, "model_aurocs_per_class.csv"))

# Per-epoch histories per model; also used directly for plotting below (no
# need to re-read the JSON that was just written).
loss_data = {
    name: {"train": r["train_loss_history"], "val": r["val_loss_history"], "val_auroc": r["val_aurocs_epoch"]}
    for name, r in results.items()
}
with open(os.path.join(training_dir, "all_loss_histories.json"), "w") as f:
    json.dump(loss_data, f)

plots_dir = os.path.join(training_dir, "loss_plots")
os.makedirs(plots_dir, exist_ok=True)

# Plot loss and AUROC curves for each model. Loss and AUROC have very
# different ranges, so AUROC goes on a secondary y-axis; sharing one axis
# visually flattened the loss curves and mislabelled AUROC as "Loss".
for model_name, losses in loss_data.items():
    epochs = range(1, len(losses["train"]) + 1)  # 1-based, matching the training log
    fig, loss_ax = plt.subplots(figsize=(8, 5))
    loss_ax.plot(epochs, losses["train"], label="Train Loss", linewidth=2)
    loss_ax.plot(epochs, losses["val"], label="Val Loss", linewidth=2)
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.grid(True)

    auroc_ax = loss_ax.twinx()
    # Explicit color: a twin axis restarts the color cycle and would reuse
    # the train-loss color.
    auroc_ax.plot(epochs, losses["val_auroc"], label="Val AUROC", linewidth=2, color="tab:green")
    auroc_ax.set_ylabel("Val AUROC")

    # Merge both axes' entries into a single legend.
    loss_handles, loss_labels = loss_ax.get_legend_handles_labels()
    auroc_handles, auroc_labels = auroc_ax.get_legend_handles_labels()
    loss_ax.legend(loss_handles + auroc_handles, loss_labels + auroc_labels)

    loss_ax.set_title(f"{model_name} - Loss and Val AUROC")
    fig.tight_layout()
    fig.savefig(os.path.join(plots_dir, f"{model_name.replace('/', '_')}_loss_curve.png"))
    plt.close(fig)

print("Saved loss curves, model weights, per-model and combined CSVs, and performance metrics to '/student/csc490_project/shared/training/'")



Training coatnet_2_rw_224.sw_in12k_ft_in1k...
Epoch 1/25 - Train Loss: 0.1997 - Val Loss: 0.1883 - Val AUROC: 0.7558
Epoch 2/25 - Train Loss: 0.1872 - Val Loss: 0.1943 - Val AUROC: 0.7865
Epoch 3/25 - Train Loss: 0.1821 - Val Loss: 0.1763 - Val AUROC: 0.8143
Epoch 4/25 - Train Loss: 0.1788 - Val Loss: 0.1763 - Val AUROC: 0.8139
Epoch 5/25 - Train Loss: 0.1768 - Val Loss: 0.1753 - Val AUROC: 0.8197
Epoch 6/25 - Train Loss: 0.1791 - Val Loss: 0.1768 - Val AUROC: 0.8170
Epoch 7/25 - Train Loss: 0.1735 - Val Loss: 0.1756 - Val AUROC: 0.8223
Epoch 8/25 - Train Loss: 0.1725 - Val Loss: 0.1752 - Val AUROC: 0.8276
Epoch 9/25 - Train Loss: 0.1709 - Val Loss: 0.1744 - Val AUROC: 0.8286
Epoch 10/25 - Train Loss: 0.1693 - Val Loss: 0.1725 - Val AUROC: 0.8316
Epoch 11/25 - Train Loss: 0.1676 - Val Loss: 0.1719 - Val AUROC: 0.8318
Epoch 12/25 - Train Loss: 0.1661 - Val Loss: 0.1706 - Val AUROC: 0.8351
Epoch 13/25 - Train Loss: 0.1643 - Val Loss: 0.1710 - Val AUROC: 0.8323
Epoch 14/25 - Train Loss: 0

model.safetensors:   0%|          | 0.00/919M [00:00<?, ?B/s]

Epoch 1/25 - Train Loss: 0.1855 - Val Loss: 0.1761 - Val AUROC: 0.8068
Epoch 2/25 - Train Loss: 0.1727 - Val Loss: 0.1722 - Val AUROC: 0.8269
Epoch 3/25 - Train Loss: 0.1648 - Val Loss: 0.1726 - Val AUROC: 0.8306
Epoch 4/25 - Train Loss: 0.1544 - Val Loss: 0.1756 - Val AUROC: 0.8276
Epoch 5/25 - Train Loss: 0.1382 - Val Loss: 0.1868 - Val AUROC: 0.8162
Training densenet121...


model.safetensors:   0%|          | 0.00/32.3M [00:00<?, ?B/s]

Epoch 1/25 - Train Loss: 0.1935 - Val Loss: 0.1785 - Val AUROC: 0.8025
Epoch 2/25 - Train Loss: 0.1801 - Val Loss: 0.1758 - Val AUROC: 0.8147
Epoch 3/25 - Train Loss: 0.1760 - Val Loss: 0.1739 - Val AUROC: 0.8205
Epoch 4/25 - Train Loss: 0.1727 - Val Loss: 0.1762 - Val AUROC: 0.8212
Epoch 5/25 - Train Loss: 0.1697 - Val Loss: 0.1739 - Val AUROC: 0.8246
Epoch 6/25 - Train Loss: 0.1672 - Val Loss: 0.1729 - Val AUROC: 0.8255
Epoch 7/25 - Train Loss: 0.1643 - Val Loss: 0.1756 - Val AUROC: 0.8246
Epoch 8/25 - Train Loss: 0.1615 - Val Loss: 0.1773 - Val AUROC: 0.8230
Training maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k...


model.safetensors:   0%|          | 0.00/465M [00:00<?, ?B/s]

Epoch 1/25 - Train Loss: 0.1903 - Val Loss: 0.1779 - Val AUROC: 0.8085
Epoch 2/25 - Train Loss: 0.1787 - Val Loss: 0.1802 - Val AUROC: 0.8060
Epoch 3/25 - Train Loss: 0.1738 - Val Loss: 0.1759 - Val AUROC: 0.8152
Epoch 4/25 - Train Loss: 0.1700 - Val Loss: 0.1724 - Val AUROC: 0.8263
Epoch 5/25 - Train Loss: 0.1662 - Val Loss: 0.1712 - Val AUROC: 0.8330
Epoch 6/25 - Train Loss: 0.1625 - Val Loss: 0.1738 - Val AUROC: 0.8287
Epoch 7/25 - Train Loss: 0.1585 - Val Loss: 0.1756 - Val AUROC: 0.8262
Training swin_large_patch4_window7_224...


model.safetensors:   0%|          | 0.00/788M [00:00<?, ?B/s]

Epoch 1/25 - Train Loss: 0.1910 - Val Loss: 0.1783 - Val AUROC: 0.8016
Epoch 2/25 - Train Loss: 0.1800 - Val Loss: 0.1766 - Val AUROC: 0.8125
Epoch 3/25 - Train Loss: 0.1755 - Val Loss: 0.1755 - Val AUROC: 0.8177
Epoch 4/25 - Train Loss: 0.1719 - Val Loss: 0.1734 - Val AUROC: 0.8248
Epoch 5/25 - Train Loss: 0.1689 - Val Loss: 0.1726 - Val AUROC: 0.8242
Epoch 6/25 - Train Loss: 0.1655 - Val Loss: 0.1746 - Val AUROC: 0.8244
Training vgg19.tv_in1k...


model.safetensors:   0%|          | 0.00/575M [00:00<?, ?B/s]

Epoch 1/25 - Train Loss: 0.2006 - Val Loss: 0.1928 - Val AUROC: 0.7319
Epoch 2/25 - Train Loss: 0.1921 - Val Loss: 0.1855 - Val AUROC: 0.7698
Epoch 3/25 - Train Loss: 0.1869 - Val Loss: 0.1858 - Val AUROC: 0.7650
Epoch 4/25 - Train Loss: 0.1835 - Val Loss: 0.1802 - Val AUROC: 0.7929
Epoch 5/25 - Train Loss: 0.1810 - Val Loss: 0.1805 - Val AUROC: 0.7923
Epoch 6/25 - Train Loss: 0.1789 - Val Loss: 0.1791 - Val AUROC: 0.7970
Epoch 7/25 - Train Loss: 0.1771 - Val Loss: 0.1774 - Val AUROC: 0.7984
Epoch 8/25 - Train Loss: 0.1757 - Val Loss: 0.1793 - Val AUROC: 0.7977
Epoch 9/25 - Train Loss: 0.1741 - Val Loss: 0.1786 - Val AUROC: 0.8058
Epoch 10/25 - Train Loss: 0.1725 - Val Loss: 0.1788 - Val AUROC: 0.8068
Epoch 11/25 - Train Loss: 0.1709 - Val Loss: 0.1777 - Val AUROC: 0.8113
Epoch 12/25 - Train Loss: 0.1693 - Val Loss: 0.1762 - Val AUROC: 0.8114
Epoch 13/25 - Train Loss: 0.1677 - Val Loss: 0.1757 - Val AUROC: 0.8149
Epoch 14/25 - Train Loss: 0.1663 - Val Loss: 0.1776 - Val AUROC: 0.8146
E

: 