In [3]:
# Chest X-ray Multi-Label Classification with Gamma Correction Variants
# Key Features:
# - Patient-level random split (70/10/20 of patients; not label-stratified)
# - Gamma correction variants [0.6, 0.8, 1.0, 1.2], applied to training images only
# - Data augmentation for training
# - BCEWithLogitsLoss
# - Early stopping, LR scheduler, AUROC evaluation
# - Best model saving, per-class results, loss curve plotting

import os
import numpy as np
import pandas as pd
from PIL import Image
import json
import matplotlib.pyplot as plt
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.transforms import functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import timm

# Target labels: "No Finding" plus 14 findings. This order fixes the column
# order of the multi-hot targets and of every per-class result reported below.
CLASSES = [
    "No Finding",
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural_Thickening",
    "Hernia",
]
num_classes = len(CLASSES)

# Load the label table and encode each image's "|"-separated "Finding Labels"
# string as a multi-hot vector whose columns follow the order of CLASSES;
# the vector for each row is stored in df["labels"].
df = pd.read_csv("/student/csc490_project/shared/labels.csv")
df["label_list"] = df["Finding Labels"].apply(lambda x: x.split("|"))
mlb = MultiLabelBinarizer(classes=CLASSES)
df["labels"] = list(mlb.fit_transform(df["label_list"]))

# Patient-level split to avoid data leakage: all images of a patient land in
# exactly one of train/val/test. This is a seeded random shuffle of patient
# IDs, NOT a label-stratified split, and the 70/10/20 ratios are over
# patients, not images.
# The seed and the in-place shuffle below determine the split; changing either
# changes which patients land in each set (and breaks comparability with
# results from earlier runs).
unique_patients = df["Patient ID"].unique()
np.random.seed(42)
np.random.shuffle(unique_patients)
train_end = int(0.7 * len(unique_patients))  # first 70% of patients -> train
val_end = int(0.8 * len(unique_patients))  # next 10% -> val; remaining 20% -> test
train_patients = unique_patients[:train_end]
val_patients = unique_patients[train_end:val_end]
test_patients = unique_patients[val_end:]

# Select each split's image rows by patient membership; reset the index so
# positional .iloc access in the Dataset starts at 0.
train_df = df[df["Patient ID"].isin(train_patients)].reset_index(drop=True)
val_df = df[df["Patient ID"].isin(val_patients)].reset_index(drop=True)
test_df = df[df["Patient ID"].isin(test_patients)].reset_index(drop=True)

# Deterministic preprocessing for validation/test images: no augmentation and
# no gamma correction. The grayscale image is replicated to 3 channels,
# resized to 224x224, and normalized with the standard ImageNet mean/std.
val_test_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def get_train_transform(gamma):
    """
    Returns the training transform pipeline with gamma correction and augmentations.

    Gamma correction is applied to every training image before the random
    flip/rotation augmentations. Validation/test images use
    ``val_test_transform``, which has no gamma step.
    NOTE(review): gamma is therefore a train-time-only change; confirm that
    this is the intended experimental design.

    Args:
        gamma (float): Gamma correction value (1.0 leaves intensities unchanged).

    Returns:
        torchvision.transforms.Compose: Transform pipeline
    """
    from functools import partial

    return transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        # A functools.partial of a module-level function is picklable, unlike a
        # lambda closure, so DataLoader worker processes also work under the
        # "spawn" start method (the default on macOS and Windows).
        transforms.Lambda(partial(F.adjust_gamma, gamma=gamma)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# Dataset class to load images and their multilabel targets
class ChestXrayDataset(Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs.

    Each row of ``df`` supplies the image file name (column ``"Image Index"``,
    resolved relative to ``root_dir``) and the encoded target vector
    (column ``"labels"``), returned as a float32 tensor.
    """

    def __init__(self, df, root_dir, transform=None):
        """
        Args:
            df (pd.DataFrame): DataFrame with image file names and encoded labels.
            root_dir (str): Directory containing the image files.
            transform (callable, optional): Applied to the grayscale PIL image
                when provided.
        """
        self.df = df
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        # Images are loaded as single-channel ("L"); any channel expansion is
        # left to the transform.
        picture = Image.open(os.path.join(self.root_dir, record["Image Index"])).convert("L")
        if self.transform:
            picture = self.transform(picture)
        target = torch.tensor(record["labels"], dtype=torch.float32)
        return picture, target

def get_model(model_name, num_classes):
    """
    Builds a pretrained timm model with a freshly sized classification head.

    Passing ``num_classes`` to ``timm.create_model`` makes timm create the
    classifier with that many outputs instead of the pretrained head.

    Args:
        model_name (str): timm model identifier.
        num_classes (int): Number of output classes.

    Returns:
        nn.Module: Model instance
    """
    network = timm.create_model(
        model_name,
        pretrained=True,
        num_classes=num_classes,
    )
    return network

def compute_auroc(y_true, y_pred):
    """
    Computes per-class and mean AUROC.

    AUROC is undefined for a class whose ground truth contains a single value
    (all positive or all negative). Such a class is reported as NaN and left
    out of the mean, so the scores of the other classes are not thrown away.

    Args:
        y_true (list): Batches of ground-truth binary labels; stacked into a
            2-D (samples x classes) array.
        y_pred (list): Batches of predicted probabilities with the same shapes.

    Returns:
        tuple: (mean_auroc, list of per-class AUROC). mean_auroc is the mean
        over classes with a defined AUROC, or 0.0 when no class has one.
    """
    y_true = np.vstack(y_true)
    y_pred = np.vstack(y_pred)

    per_class_auroc = []
    for c in range(y_true.shape[1]):
        if np.unique(y_true[:, c]).size < 2:
            # roc_auc_score raises ValueError for single-class ground truth.
            per_class_auroc.append(float("nan"))
        else:
            per_class_auroc.append(float(roc_auc_score(y_true[:, c], y_pred[:, c])))

    defined = [score for score in per_class_auroc if not np.isnan(score)]
    # 0.0 (not NaN) keeps "val_auroc > best" comparisons in the caller meaningful.
    mean_auroc = float(np.mean(defined)) if defined else 0.0
    return mean_auroc, per_class_auroc

def evaluate(model, device, loader):
    """
    Scores ``model`` on every batch of ``loader`` and returns its AUROC.

    The model is switched to eval mode and run without gradient tracking.
    Logits are turned into probabilities with a sigmoid before scoring.

    Args:
        model (nn.Module): Model to evaluate.
        device (torch.device): Device the inputs are moved to.
        loader (DataLoader): Yields ``(images, labels)`` batches.

    Returns:
        tuple: (mean_auroc, list of per-class AUROC), as returned by compute_auroc.
    """
    model.eval()
    prob_batches = []
    label_batches = []
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            logits = model(batch_imgs.to(device))
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            label_batches.append(batch_labels.numpy())
    return compute_auroc(label_batches, prob_batches)

def _validate_epoch(model, device, loader, criterion):
    """
    Runs one no-grad pass over ``loader`` and returns ``(mean loss per batch, mean AUROC)``.

    This combines the validation-loss loop and the AUROC evaluation, so the
    validation set goes through the model once per epoch instead of twice.
    """
    model.eval()
    total_loss = 0.0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            total_loss += criterion(logits, labels).item()
            all_preds.append(torch.sigmoid(logits).cpu().numpy())
            all_labels.append(labels.cpu().numpy())
    mean_auroc, _ = compute_auroc(all_labels, all_preds)
    return total_loss / len(loader), mean_auroc


def train_one_model(model_name, gamma, num_epochs=25, early_stopping_patience=2):
    """
    Trains a single model with early stopping and AUROC tracking.

    The checkpoint with the best validation AUROC is saved to
    ``{gamma_dir}/{model_name}.pt`` and reloaded before the final val/test
    evaluation. The reported metrics therefore describe the saved model, not
    the (possibly degraded) weights of the last epoch.

    Reads the module-level ``train_loader``, ``val_loader``, ``test_loader``
    and ``gamma_dir``.

    Args:
        model_name (str): timm model name.
        gamma (float): Gamma value for gamma correction (used for logging).
        num_epochs (int): Max training epochs.
        early_stopping_patience (int): Consecutive epochs without a val-AUROC
            improvement before training stops.

    Returns:
        tuple: (val_auroc, test_auroc, val_aurocs, test_aurocs, model_path,
        final_train_loss, train_losses, val_losses, val_aurocs_epoch).
        model_path is None if no epoch was ever saved. final_train_loss is
        the last epoch's mean training loss (None if no epoch ran).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(model_name, num_classes).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # NOTE(review): ReduceLROnPlateau lowers the LR only after MORE than
    # `patience` (2) bad epochs, while early stopping breaks after
    # `early_stopping_patience` (2) of them. With these defaults the LR
    # reduction gets little or no chance to act. Confirm this is intended.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

    best_val_auroc = 0
    epochs_since_improvement = 0
    best_model_path = None
    train_losses = []
    val_losses = []
    val_aurocs_epoch = []

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(train_loader)
        train_losses.append(epoch_loss)

        val_epoch_loss, val_auroc = _validate_epoch(model, device, val_loader, criterion)
        val_aurocs_epoch.append(val_auroc)
        val_losses.append(val_epoch_loss)

        scheduler.step(val_auroc)

        print(f"Epoch {epoch+1}/{num_epochs} - Gamma: {gamma} - Train Loss: {epoch_loss:.4f} - Val Loss: {val_epoch_loss:.4f} - Val AUROC: {val_auroc:.4f}")

        if val_auroc > best_val_auroc:
            best_val_auroc = val_auroc
            epochs_since_improvement = 0
            best_model_path = f"{gamma_dir}/{model_name.replace('/', '_')}.pt"
            torch.save(model.state_dict(), best_model_path)
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= early_stopping_patience:
                break

    # Evaluate the best saved checkpoint, not whatever the last epoch left behind.
    if best_model_path is not None:
        model.load_state_dict(torch.load(best_model_path, map_location=device))

    val_auroc, val_aurocs = evaluate(model, device, val_loader)
    test_auroc, test_aurocs = evaluate(model, device, test_loader)
    final_train_loss = train_losses[-1] if train_losses else None
    return val_auroc, test_auroc, val_aurocs, test_aurocs, best_model_path, final_train_loss, train_losses, val_losses, val_aurocs_epoch

# Main training loop with gamma correction variants.
# For each gamma value this saves checkpoints, per-model and combined result
# CSVs, the loss/AUROC histories as JSON, and one training-curve plot per model.
# train_one_model reads train_loader / val_loader / test_loader / gamma_dir
# as module-level globals, so those names are assigned here.

data_root = "/student/csc490_project/shared/preprocessed_images/preprocessed_images"
model_names = [
    'coatnet_2_rw_224.sw_in12k_ft_in1k',
    'convnext_large.fb_in22k',
    'densenet121',
    'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k',
    'swin_large_patch4_window7_224',
    'vgg19.tv_in1k'
]

gamma_values = [0.6, 0.8, 1.0, 1.2]

# Validation/test data use the fixed val_test_transform and do not depend on
# gamma, so these loaders are built once and reused for every gamma.
val_loader = DataLoader(ChestXrayDataset(val_df, data_root, val_test_transform), batch_size=16, shuffle=False, num_workers=4)
test_loader = DataLoader(ChestXrayDataset(test_df, data_root, val_test_transform), batch_size=16, shuffle=False, num_workers=4)

for gamma in gamma_values:
    print(f"Starting training for gamma={gamma}")
    train_transform = get_train_transform(gamma)

    # Only the training loader carries the gamma-specific transform.
    train_loader = DataLoader(ChestXrayDataset(train_df, data_root, train_transform), batch_size=16, shuffle=True, num_workers=4)

    gamma_dir = f"/student/csc490_project/shared/training_gamma/training_gamma_{gamma}"
    os.makedirs(gamma_dir, exist_ok=True)
    os.makedirs(f"{gamma_dir}/results", exist_ok=True)
    os.makedirs(f"{gamma_dir}/loss_plots", exist_ok=True)

    results = {}
    for model_name in model_names:
        safe_name = model_name.replace('/', '_')  # usable as a file name
        print(f"Training {model_name} with gamma={gamma}...")
        val_auroc, test_auroc, val_aurocs, test_aurocs, model_path, final_loss, train_loss_history, val_loss_history, val_aurocs_epoch = train_one_model(model_name, gamma)

        # Save training results and metrics
        model_results = {
            "val_auroc": val_auroc,
            "val_aurocs_epoch": val_aurocs_epoch,
            "test_auroc": test_auroc,
            "model_path": model_path,
            "final_train_loss": final_loss,
            "train_loss_history": train_loss_history,
            "val_loss_history": val_loss_history
        }
        for i, cls in enumerate(CLASSES):
            model_results[f"val_{cls}"] = val_aurocs[i]
            model_results[f"test_{cls}"] = test_aurocs[i]
        results[model_name] = model_results

        pd.DataFrame([model_results]).to_csv(f"{gamma_dir}/results/{safe_name}.csv", index=False)

    pd.DataFrame(results).T.to_csv(f"{gamma_dir}/model_aurocs_per_class.csv")

    # Per-epoch histories, persisted to JSON and plotted directly from memory.
    histories = {
        name: {
            "train": res["train_loss_history"],
            "val": res["val_loss_history"],
            "val_auroc": res["val_aurocs_epoch"],
        }
        for name, res in results.items()
    }
    with open(f"{gamma_dir}/all_loss_histories.json", "w") as f:
        json.dump(histories, f)

    for model_name, hist in histories.items():
        epochs = range(1, len(hist["train"]) + 1)  # 1-based, matching the printed log
        fig, loss_ax = plt.subplots(figsize=(8, 5))
        loss_ax.plot(epochs, hist["train"], label="Train Loss", linewidth=2)
        loss_ax.plot(epochs, hist["val"], label="Val Loss", linewidth=2)
        loss_ax.set_xlabel("Epoch")
        loss_ax.set_ylabel("Loss")
        loss_ax.grid(True)

        # AUROC uses a different scale from BCE loss, so give it its own y-axis.
        auroc_ax = loss_ax.twinx()
        auroc_ax.plot(epochs, hist["val_auroc"], label="Val AUROC", linewidth=2, color="tab:green")
        auroc_ax.set_ylabel("Val AUROC")

        handles = loss_ax.get_lines() + auroc_ax.get_lines()
        loss_ax.legend(handles, [h.get_label() for h in handles])
        loss_ax.set_title(f"{model_name} - Gamma={gamma} Loss / AUROC Curves")
        fig.tight_layout()
        fig.savefig(f"{gamma_dir}/loss_plots/{model_name.replace('/', '_')}_loss_curve.png")
        plt.close(fig)

    print(f"Finished training for gamma={gamma}, saved to '{gamma_dir}'")


Starting training for gamma=0.6
Training coatnet_2_rw_224.sw_in12k_ft_in1k with gamma=0.6...


model.safetensors:   0%|          | 0.00/296M [00:00<?, ?B/s]

Epoch 1/25 - Gamma: 0.6 - Train Loss: 0.1998 - Val Loss: 0.1891 - Val AUROC: 0.7596
Epoch 2/25 - Gamma: 0.6 - Train Loss: 0.1876 - Val Loss: 0.1889 - Val AUROC: 0.7824
Epoch 3/25 - Gamma: 0.6 - Train Loss: 0.1824 - Val Loss: 0.1822 - Val AUROC: 0.7959
Epoch 4/25 - Gamma: 0.6 - Train Loss: 0.1792 - Val Loss: 0.1799 - Val AUROC: 0.8028
Epoch 5/25 - Gamma: 0.6 - Train Loss: 0.1764 - Val Loss: 0.1788 - Val AUROC: 0.8125
Epoch 6/25 - Gamma: 0.6 - Train Loss: 0.1743 - Val Loss: 0.1785 - Val AUROC: 0.8144
Epoch 7/25 - Gamma: 0.6 - Train Loss: 0.1724 - Val Loss: 0.1772 - Val AUROC: 0.8204
Epoch 8/25 - Gamma: 0.6 - Train Loss: 0.1708 - Val Loss: 0.1770 - Val AUROC: 0.8177
Epoch 9/25 - Gamma: 0.6 - Train Loss: 0.1692 - Val Loss: 0.1746 - Val AUROC: 0.8231
Epoch 10/25 - Gamma: 0.6 - Train Loss: 0.1675 - Val Loss: 0.1746 - Val AUROC: 0.8244
Epoch 11/25 - Gamma: 0.6 - Train Loss: 0.1657 - Val Loss: 0.1780 - Val AUROC: 0.8202
Epoch 12/25 - Gamma: 0.6 - Train Loss: 0.1643 - Val Loss: 0.1748 - Val AUR

model.safetensors:   0%|          | 0.00/919M [00:00<?, ?B/s]

Epoch 1/25 - Gamma: 0.6 - Train Loss: 0.1860 - Val Loss: 0.1765 - Val AUROC: 0.8142
Epoch 2/25 - Gamma: 0.6 - Train Loss: 0.1731 - Val Loss: 0.1738 - Val AUROC: 0.8241
Epoch 3/25 - Gamma: 0.6 - Train Loss: 0.1653 - Val Loss: 0.1734 - Val AUROC: 0.8268
Epoch 4/25 - Gamma: 0.6 - Train Loss: 0.1550 - Val Loss: 0.1745 - Val AUROC: 0.8256
Epoch 5/25 - Gamma: 0.6 - Train Loss: 0.1389 - Val Loss: 0.1909 - Val AUROC: 0.8073
Training densenet121 with gamma=0.6...


model.safetensors:   0%|          | 0.00/32.3M [00:00<?, ?B/s]

Epoch 1/25 - Gamma: 0.6 - Train Loss: 0.1932 - Val Loss: 0.1835 - Val AUROC: 0.7822
Epoch 2/25 - Gamma: 0.6 - Train Loss: 0.1803 - Val Loss: 0.1794 - Val AUROC: 0.8069
Epoch 3/25 - Gamma: 0.6 - Train Loss: 0.1763 - Val Loss: 0.1785 - Val AUROC: 0.8072
Epoch 4/25 - Gamma: 0.6 - Train Loss: 0.1729 - Val Loss: 0.1770 - Val AUROC: 0.8156
Epoch 5/25 - Gamma: 0.6 - Train Loss: 0.1699 - Val Loss: 0.1761 - Val AUROC: 0.8159
Epoch 6/25 - Gamma: 0.6 - Train Loss: 0.1670 - Val Loss: 0.1753 - Val AUROC: 0.8150
Epoch 7/25 - Gamma: 0.6 - Train Loss: 0.1643 - Val Loss: 0.1782 - Val AUROC: 0.8106
Training maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k with gamma=0.6...


model.safetensors:   0%|          | 0.00/465M [00:00<?, ?B/s]

Epoch 1/25 - Gamma: 0.6 - Train Loss: 0.1890 - Val Loss: 0.1843 - Val AUROC: 0.7916
Epoch 2/25 - Gamma: 0.6 - Train Loss: 0.1776 - Val Loss: 0.1780 - Val AUROC: 0.8114
Epoch 3/25 - Gamma: 0.6 - Train Loss: 0.1729 - Val Loss: 0.1751 - Val AUROC: 0.8203
Epoch 4/25 - Gamma: 0.6 - Train Loss: 0.1689 - Val Loss: 0.1734 - Val AUROC: 0.8247
Epoch 5/25 - Gamma: 0.6 - Train Loss: 0.1652 - Val Loss: 0.1739 - Val AUROC: 0.8252
Epoch 6/25 - Gamma: 0.6 - Train Loss: 0.1612 - Val Loss: 0.1753 - Val AUROC: 0.8293
Epoch 7/25 - Gamma: 0.6 - Train Loss: 0.1568 - Val Loss: 0.1756 - Val AUROC: 0.8244
Epoch 8/25 - Gamma: 0.6 - Train Loss: 0.1512 - Val Loss: 0.1831 - Val AUROC: 0.8148
Training swin_large_patch4_window7_224 with gamma=0.6...


model.safetensors:   0%|          | 0.00/788M [00:00<?, ?B/s]

Epoch 1/25 - Gamma: 0.6 - Train Loss: 0.1917 - Val Loss: 0.1815 - Val AUROC: 0.7964
Epoch 2/25 - Gamma: 0.6 - Train Loss: 0.1803 - Val Loss: 0.1767 - Val AUROC: 0.8073
Epoch 3/25 - Gamma: 0.6 - Train Loss: 0.1765 - Val Loss: 0.1748 - Val AUROC: 0.8193
Epoch 4/25 - Gamma: 0.6 - Train Loss: 0.1724 - Val Loss: 0.1771 - Val AUROC: 0.8154
Epoch 5/25 - Gamma: 0.6 - Train Loss: 0.1694 - Val Loss: 0.1758 - Val AUROC: 0.8182
Training vgg19.tv_in1k with gamma=0.6...


model.safetensors:   0%|          | 0.00/575M [00:00<?, ?B/s]

Epoch 1/25 - Gamma: 0.6 - Train Loss: 0.2007 - Val Loss: 0.1933 - Val AUROC: 0.7446
Epoch 2/25 - Gamma: 0.6 - Train Loss: 0.1916 - Val Loss: 0.1870 - Val AUROC: 0.7683
Epoch 3/25 - Gamma: 0.6 - Train Loss: 0.1869 - Val Loss: 0.1838 - Val AUROC: 0.7781
Epoch 4/25 - Gamma: 0.6 - Train Loss: 0.1834 - Val Loss: 0.1826 - Val AUROC: 0.7800
Epoch 5/25 - Gamma: 0.6 - Train Loss: 0.1809 - Val Loss: 0.1797 - Val AUROC: 0.7909
Epoch 6/25 - Gamma: 0.6 - Train Loss: 0.1790 - Val Loss: 0.1795 - Val AUROC: 0.7928
Epoch 7/25 - Gamma: 0.6 - Train Loss: 0.1772 - Val Loss: 0.1786 - Val AUROC: 0.8005
Epoch 8/25 - Gamma: 0.6 - Train Loss: 0.1755 - Val Loss: 0.1799 - Val AUROC: 0.7909
Epoch 9/25 - Gamma: 0.6 - Train Loss: 0.1740 - Val Loss: 0.1778 - Val AUROC: 0.8026
Epoch 10/25 - Gamma: 0.6 - Train Loss: 0.1726 - Val Loss: 0.1807 - Val AUROC: 0.7990
Epoch 11/25 - Gamma: 0.6 - Train Loss: 0.1712 - Val Loss: 0.1786 - Val AUROC: 0.7990
Finished training for gamma=0.6, saved to '/student/csc490_project/shared/