In [None]:
# Core
import os, time, math, random
from collections import Counter
from pathlib import Path

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# Metrics & plots
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt

import kagglehub

# Reproducibility
def set_seed(seed=42):
    """Seed every RNG used by this notebook and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG, PyTorch's CPU RNG and all CUDA
    device RNGs with the same value, then disables cuDNN autotuning so kernel
    selection does not vary between runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Download (or reuse the locally cached copy of) the Kaggle dataset.
path = kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")
# The rest of the notebook expects a `chest_xray/` folder under the returned
# path, containing one sub-folder per split.
base_dir = os.path.join(path, "chest_xray")
train_dir, val_dir, test_dir = (os.path.join(base_dir, split) for split in ('train', 'val', 'test'))

In [None]:
# Show where each split lives and what (class) sub-folders it contains.
_split_dirs = (
    ("Train dir:", "Train dir contents:", train_dir),
    ("Val dir:", "Val dir contents:", val_dir),
    ("Test dir:", "Test dir contents:", test_dir),
)
for _path_label, _, _split_path in _split_dirs:
    print(_path_label, _split_path)

print()
for _, _contents_label, _split_path in _split_dirs:
    print(_contents_label, os.listdir(_split_path))

In [None]:

# --- Training configuration ---
IMG_SIZE: int = 224        # square input resolution fed to the ViT
BATCH_SIZE: int = 32
EPOCHS: int = 20           # also the cosine-schedule period (T_max)
LR: float = 3e-4           # AdamW learning rate
WEIGHT_DECAY: float = 0.05
NUM_WORKERS: int = 2       # DataLoader worker processes; raise if the environment allows
PATIENCE: int = 5          # early stopping: epochs without validation improvement

In [None]:
# Transforms & dataloaders

# ImageFolder's default loader converts every image to 3-channel RGB (grayscale
# X-rays are replicated across channels), so ImageNet normalisation stats apply.
mean = (0.485, 0.456, 0.406)
std  = (0.229, 0.224, 0.225)

# Training: random augmentation for regularisation.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(7),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Validation / test: deterministic preprocessing only. Reusing the random
# training augmentation here would make metrics (and therefore early stopping
# and the reported test scores) noisy and non-reproducible.
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_ds = datasets.ImageFolder(train_dir, transform=transform)
val_ds   = datasets.ImageFolder(val_dir,   transform=eval_transform)
test_ds  = datasets.ImageFolder(test_dir,  transform=eval_transform)

# All splits must agree on the folder-name -> label mapping.
assert train_ds.class_to_idx == val_ds.class_to_idx == test_ds.class_to_idx, (
    f"class mapping differs between splits: {train_ds.class_to_idx}, "
    f"{val_ds.class_to_idx}, {test_ds.class_to_idx}"
)
# NOTE(review): check len(val_ds); if the provided val split is very small,
# early stopping on it will be noisy — consider carving a validation set out of train.

pin = device.type == 'cuda'  # pinned host memory only helps host->GPU copies
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=pin)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)

class_names = train_ds.classes
class_names

In [None]:
# Handle class imbalance (loss weights)

# BCEWithLogitsLoss treats ImageFolder label 1 as the positive class, and
# ImageFolder assigns labels in sorted folder-name order. Verify the mapping
# instead of assuming it, so precision/recall refer to PNEUMONIA.
assert len(class_names) == 2, f"expected a binary dataset, got classes {class_names}"
assert train_ds.class_to_idx.get('PNEUMONIA') == 1, (
    f"expected 'PNEUMONIA' to be label 1, got {train_ds.class_to_idx}"
)

train_counts = Counter(train_ds.targets)
print("Train counts:", train_counts, " -> class_names:", class_names)

# pos_weight = #negatives / #positives: scales the positive-class loss term.
# max(1, ...) avoids division by zero if no positive samples exist.
pos_weight_value = train_counts[0] / max(1, train_counts[1])
pos_weight = torch.tensor([pos_weight_value], device=device, dtype=torch.float32)
pos_weight

In [None]:
# Sanity check on the class-imbalance weight computed in the previous cell.
# (Recomputing it here would only duplicate that cell's work.)
assert pos_weight.numel() == 1 and pos_weight.item() > 0, f"invalid pos_weight: {pos_weight}"
pos_weight

In [None]:
# ViT building blocks: Patch Embedding

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal `patch_size` is the same as
    flattening every patch and applying one shared linear projection.

    Attributes:
        grid_size: patches per side (img_size // patch_size).
        num_patches: total patches per image (grid_size ** 2).
        proj: the patchifying convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_ch=3, embed_dim=384):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # [B, C, H, W] -> [B, E, H', W'] -> [B, E, N] -> [B, N, E]
        feature_map = self.proj(x)
        return torch.flatten(feature_map, start_dim=2).transpose(1, 2)

In [None]:
# Multi-head self-attention (from scratch)

class MSA(nn.Module):
    """Multi-head scaled dot-product self-attention, written out explicitly.

    Input and output are [B, N, E]. One fused Linear produces Q, K and V; each
    is split into `num_heads` heads of size E // num_heads. Scores are scaled by
    1/sqrt(head_dim) and softmax-normalised over the key axis. Heads are then
    concatenated and projected back to E.
    """

    def __init__(self, embed_dim=384, num_heads=6, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, seq_len, dim = x.shape
        # [B, N, 3E] -> [B, N, 3, H, D] -> [3, B, H, N, D]
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale   # [B, H, N, N]
        weights = self.attn_drop(scores.softmax(dim=-1))

        context = torch.matmul(weights, v)                           # [B, H, N, D]
        context = context.transpose(1, 2).reshape(batch, seq_len, dim)  # merge heads
        return self.proj_drop(self.proj(context))

In [None]:
# Transformer encoder block (Pre-LN)

class TransformerBlock(nn.Module):
    """Pre-LayerNorm encoder block.

    Computes  x = x + MSA(LN1(x))  followed by  x = x + MLP(LN2(x)).
    The MLP is Linear(E -> E*mlp_ratio) -> GELU -> Dropout -> Linear(-> E) -> Dropout.
    `drop` is used for the MLP dropouts and the attention output projection;
    `attn_drop` is applied to the attention weights.
    """

    def __init__(self, embed_dim=384, num_heads=6, mlp_ratio=4.0, drop=0.1, attn_drop=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn  = MSA(embed_dim=embed_dim, num_heads=num_heads,
                         attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Dropout(drop),
            nn.Linear(hidden_dim, embed_dim), nn.Dropout(drop),
        )

    def forward(self, x):
        h = x + self.attn(self.norm1(x))      # attention sub-layer + residual
        return h + self.mlp(self.norm2(h))    # feed-forward sub-layer + residual

In [None]:
# Vision Transformer model

class ViT(nn.Module):
    """Vision Transformer classifier built from PatchEmbed and TransformerBlock.

    The image is split into patches and a learnable [CLS] token is prepended.
    Learnable positional embeddings are added, and the sequence passes through
    `depth` pre-LN encoder blocks. The final-LayerNormed [CLS] vector is mapped
    to `num_classes` logits; e.g. a single logit when num_classes=1.
    """

    def __init__(self, img_size=224, patch_size=16, in_ch=3, num_classes=1,
                 embed_dim=384, depth=8, num_heads=6, mlp_ratio=4.0, drop=0.1):
        super().__init__()
        # Tokenizer: [B, C, H, W] -> [B, N, E]
        self.patch_embed = PatchEmbed(img_size, patch_size, in_ch, embed_dim)
        self.num_patches = self.patch_embed.num_patches
        seq_len = self.num_patches + 1  # patches + [CLS]

        # Learned [CLS] token and positional embeddings (zero-filled here,
        # re-initialised in _init_weights).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(drop)

        encoder_layers = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio, drop=drop, attn_drop=0.0)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)  # logits

        self._init_weights()

    @staticmethod
    def _init_module(module):
        """Per-module init applied via nn.Module.apply.

        Linear: truncated normal (std 0.02), zero bias.
        Conv2d: Kaiming normal (fan_out, ReLU gain), zero bias.
        LayerNorm: weight 1, bias 0.
        """
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _init_weights(self):
        # Small truncated-normal init for the learned tokens and the head.
        for tensor in (self.pos_embed, self.cls_token, self.head.weight):
            nn.init.trunc_normal_(tensor, std=0.02)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)
        # Then initialise every submodule. This also re-draws head.weight, so
        # the statement order above only matters for RNG reproducibility.
        self.apply(self._init_module)

    def forward(self, x):
        tokens = self.patch_embed(x)                              # [B, N, E]
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)      # [B, 1, E]
        h = torch.cat((cls, tokens), dim=1) + self.pos_embed      # [B, 1+N, E]
        h = self.pos_drop(h)
        for block in self.blocks:
            h = block(h)
        cls_repr = self.norm(h)[:, 0]                             # [B, E]
        return self.head(cls_repr)                                # [B, num_classes]

In [None]:
# Instantiate model, optimizer, scheduler, loss

model = ViT(
    img_size=IMG_SIZE, patch_size=16, in_ch=3, num_classes=1,
    embed_dim=384, depth=8, num_heads=6, mlp_ratio=4.0, drop=0.1
).to(device)

total_params = sum(p.numel() for p in model.parameters())/1e6
print(model.__class__.__name__, "params (M):", round(total_params, 2))

# Standard ViT practice: apply weight decay only to weight matrices/kernels.
# Biases and LayerNorm affine params (all 1-D), plus the learned [CLS] token
# and positional embeddings, are exempt. Decaying them toward zero hurts
# normalisation and token identity rather than regularising.
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if param.ndim <= 1 or param_name in ("cls_token", "pos_embed"):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": WEIGHT_DECAY},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=LR,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # handles imbalance
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))

In [None]:
# Utilities: metrics function

@torch.no_grad()
def compute_metrics_from_logits(logits, targets, threshold=0.5):
    """Binary-classification metrics from raw (pre-sigmoid) logits.

    Args:
        logits: tensor with N elements, e.g. [N] or [N, 1].
        targets: tensor with N ground-truth labels in {0, 1} (int or float).
        threshold: probability cut-off for predicting the positive class (label 1).

    Returns:
        (acc, precision, recall, f1, auc, preds, probs). Here `preds` (int64) and
        `probs` (float32) are NumPy arrays of length N. `auc` is NaN when
        `targets` contains only one class.
    """
    # Cast to float32 first: under autocast the logits may be fp16/bf16.
    # bfloat16 cannot be converted to NumPy, and fp16 sigmoid loses
    # resolution near the threshold.
    probs = torch.sigmoid(logits.reshape(-1).float()).cpu().numpy()
    preds = (probs >= threshold).astype(np.int64)
    y_true = targets.reshape(-1).cpu().numpy().astype(np.int64)

    acc = accuracy_score(y_true, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, preds, average='binary', zero_division=0)
    try:
        auc = roc_auc_score(y_true, probs)
    except ValueError:
        auc = float('nan')  # AUC is undefined when only one class is present in y_true
    return acc, precision, recall, f1, auc, preds, probs


In [None]:
# Train/validate loops (1 epoch helper)

def run_one_epoch(loader, train_mode=True):
    """Run one pass over `loader`, training when `train_mode` is True, else evaluating.

    Uses the notebook-level `model`, `criterion`, `optimizer`, `scaler` and `device`.

    Returns:
        (mean_loss, acc, precision, recall, f1, auc), computed over every sample seen.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.train(train_mode)
    loss_sum, n_seen = 0.0, 0
    all_logits, all_targets = [], []

    # Without this, every validation forward pass would build an autograd graph
    # that is never used, wasting GPU memory and time.
    with torch.set_grad_enabled(train_mode):
        for images, targets in loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True).float()

            with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
                logits = model(images).view(-1)  # [B]
                loss = criterion(logits, targets)

            if train_mode:
                optimizer.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size
            all_logits.append(logits.detach().float().unsqueeze(1))
            all_targets.append(targets.detach())

    if not all_logits:
        raise ValueError("run_one_epoch: loader produced no batches")

    # Average over the samples actually processed (robust to drop_last).
    epoch_loss = loss_sum / n_seen
    all_logits = torch.cat(all_logits, dim=0)   # [N,1]
    all_targets = torch.cat(all_targets, dim=0) # [N]
    acc, prec, rec, f1, auc, _preds, _probs = compute_metrics_from_logits(all_logits, all_targets)
    return epoch_loss, acc, prec, rec, f1, auc

In [None]:
# Per-epoch metric history, consumed by plot_curves below.
history = {
    "train_loss": [], "val_loss": [],
    "train_acc":  [], "val_acc":  [],
    "train_f1":   [], "val_f1":   [],
    "train_auc":  [], "val_auc":  [],
}

best_val = -np.inf       # best early-stopping score so far
best_state = None        # CPU snapshot of the weights that achieved best_val
no_improve = 0           # consecutive epochs without improvement
start_time = time.time()

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    tr_loss, tr_acc, _, _, tr_f1, tr_auc = run_one_epoch(train_loader, train_mode=True)
    va_loss, va_acc, _, _, va_f1, va_auc = run_one_epoch(val_loader,   train_mode=False)

    scheduler.step()  # one cosine-schedule step per epoch

    for key, value in (
        ("train_loss", tr_loss), ("val_loss", va_loss),
        ("train_acc", tr_acc),   ("val_acc", va_acc),
        ("train_f1", tr_f1),     ("val_f1", va_f1),
        ("train_auc", tr_auc),   ("val_auc", va_auc),
    ):
        history[key].append(value)

    # Early-stopping score: validation AUC, falling back to F1 when AUC is NaN
    # (i.e. the validation labels contain a single class).
    score = va_f1 if np.isnan(va_auc) else va_auc
    improved = score > best_val
    if improved:
        best_val = score
        best_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
    no_improve = 0 if improved else no_improve + 1

    print(f"Epoch {epoch:02d}/{EPOCHS} | "
          f"tr_loss {tr_loss:.4f} acc {tr_acc:.3f} f1 {tr_f1:.3f} auc {tr_auc:.3f} || "
          f"val_loss {va_loss:.4f} acc {va_acc:.3f} f1 {va_f1:.3f} auc {va_auc:.3f} "
          f"[{time.time()-t0:.1f}s]")

    if no_improve >= PATIENCE:
        print("Early stopping triggered.")
        break

print(f"Total time: {(time.time()-start_time)/60:.1f} min")

# Restore the best-scoring weights before evaluation.
if best_state is not None:
    model.load_state_dict(best_state)


In [None]:
def plot_curves(history):
    """Plot train vs. val learning curves side by side.

    Always draws loss, accuracy and F1. An AUC panel is added when `history`
    contains both 'train_auc' and 'val_auc'; AUC is the metric used for model
    selection. Expects keys of the form 'train_<metric>' / 'val_<metric>',
    each a per-epoch list.
    """
    panels = [("loss", "Loss"), ("acc", "Accuracy"), ("f1", "F1-score")]
    if "train_auc" in history and "val_auc" in history:
        panels.append(("auc", "ROC AUC"))

    epochs = range(1, len(history["train_loss"]) + 1)
    fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))
    for ax, (key, title) in zip(axes, panels):
        ax.plot(epochs, history[f"train_{key}"], label="train")
        ax.plot(epochs, history[f"val_{key}"],   label="val")
        ax.set(title=title, xlabel="epoch", ylabel=title)
        ax.legend()
    fig.tight_layout()
    plt.show()

plot_curves(history)


In [None]:
@torch.no_grad()
def evaluate(loader):
    """Evaluate the notebook-level `model` on `loader` in eval mode (fp32, no grad).

    Returns:
        dict with scalar 'acc', 'precision', 'recall', 'f1', 'auc'; 'cm', the
        2x2 confusion matrix (rows = true label 0/1, cols = predicted 0/1);
        and the per-sample NumPy arrays 'probs' and 'preds'.
    """
    model.eval()
    all_logits, all_targets = [], []
    for images, targets in loader:
        logits = model(images.to(device, non_blocking=True))
        all_logits.append(logits.float().cpu())
        all_targets.append(targets)  # labels are only needed on the CPU
    all_logits = torch.cat(all_logits, dim=0)   # [N,1]
    all_targets = torch.cat(all_targets, dim=0) # [N]

    acc, prec, rec, f1, auc, preds, probs = compute_metrics_from_logits(all_logits, all_targets)
    # Pin labels=[0, 1] so the matrix is always 2x2, even when a class is
    # missing from both the targets and the predictions.
    cm = confusion_matrix(all_targets.numpy(), preds, labels=[0, 1])
    return {"acc":acc,"precision":prec,"recall":rec,"f1":f1,"auc":auc,"cm":cm,"probs":probs,"preds":preds}

test_metrics = evaluate(test_loader)
# Display only the summary; 'probs'/'preds' hold one entry per test image.
{k: v for k, v in test_metrics.items() if k not in ("probs", "preds")}


In [None]:
# Human-readable summary of the held-out test metrics.
report_lines = [
    "=== Test set performance ===",
    f"Accuracy : {test_metrics['acc']:.4f}",
    f"Precision: {test_metrics['precision']:.4f}",
    f"Recall   : {test_metrics['recall']:.4f}",
    f"F1-score : {test_metrics['f1']:.4f}",
    f"AUC      : {test_metrics['auc']:.4f}",
]
print("\n".join(report_lines))
print("Confusion matrix (rows: true [0,1], cols: pred [0,1]):\n", test_metrics["cm"])
