<a href="https://colab.research.google.com/github/Dewwbe/-Real-Estate-Document-Collection-/blob/main/Swin_Tiny_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

🩺 Chest X-ray Pneumonia Detection using Vision Transformers (Swin-Tiny & ViT-B/16)
📘 Project Overview

This project focuses on developing a deep learning–based diagnostic model to automatically classify chest X-ray images as either NORMAL or PNEUMONIA.
We leverage state-of-the-art Transformer architectures — specifically Swin-Tiny and ViT-B/16 — to explore their capability in medical image analysis and compare their performance.

Unlike traditional CNNs, Vision Transformers (ViTs) capture long-range dependencies and contextual cues, which are especially valuable for subtle patterns in medical imaging.
The Swin Transformer (Shifted Window Transformer) offers hierarchical, windowed self-attention for better efficiency, while ViT-B/16 uses global attention for strong discriminative features.

🎯 Objectives

Build and train a binary classifier to detect pneumonia from chest X-rays.

Compare Swin-Tiny (trained from scratch or fine-tuned) vs ViT-B/16 (fine-tuned) performance.

Analyze training dynamics, evaluation metrics, and calibration reliability.

Visualize attention-based interpretability maps to understand model focus regions.

⚙️ Key Features

Data Augmentation: Realistic geometric and photometric transformations to improve robustness.

Balanced Sampling: WeightedRandomSampler ensures class balance during training.

Mixed Precision Training: Efficient GPU utilization via torch.cuda.amp.

Comprehensive Metrics: Accuracy, Precision, Recall, F1, ROC-AUC, Brier score, and calibration plots.

Interpretability Maps: attention rollout for ViT-B/16 and feature-activation heatmaps for Swin-Tiny, highlighting the image regions each model focuses on.

Side-by-Side Comparison: Quantitative and qualitative evaluation of Swin-Tiny vs ViT-B/16.

📊 Expected Outcomes

High recall (sensitivity) for pneumonia detection — prioritizing diagnostic safety.

Well-calibrated confidence scores to improve clinical interpretability.

Visual explanations showing how transformer attention aligns with pathological regions.

🧠 Dataset

The dataset used follows the standard Chest X-Ray (NORMAL vs PNEUMONIA) structure:

/chest_xray/
    train/
        NORMAL/
        PNEUMONIA/
    val/
        NORMAL/
        PNEUMONIA/
    test/
        NORMAL/
        PNEUMONIA/


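Total samples ≈ 5,856 images in the widely used Kaggle release (Kermany et al.), which ships a very unbalanced split of roughly 5,216 train / 16 val / 624 test images rather than a 70/15/15 split. The code uses whatever split the train/, val/ and test/ folders contain; with only a handful of validation images, validation loss (used for model selection) is very noisy, so consider re-splitting.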


🔧 Setup & Library Installation
# Installs required packages and imports all dependencies.

In [None]:
!pip -q install --upgrade scikit-learn scipy

import os, math, copy, time, random, warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models, utils as tv_utils

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, roc_curve, confusion_matrix, classification_report,
    brier_score_loss
)
from sklearn.calibration import calibration_curve
from scipy.stats import pearsonr, spearmanr

print("✅ Libraries loaded successfully.")

⚙️ Configuration and Hyperparameters
# Sets up seeds, device, dataset paths, and model/training configs.

In [None]:
# Reproducibility: seed Python, NumPy and PyTorch (CPU + all CUDA devices).
# NOTE(review): cuDNN kernels may still be nondeterministic unless
# torch.backends.cudnn.deterministic is also set — confirm if exact reruns matter.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
# Single device used for every model and batch in this notebook.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# Dataset root (Google Drive mount in Colab); each split folder holds one sub-folder per class.
DATA_DIR = "/content/drive/MyDrive/chest_xray"  # must contain train/, val/, test/
TRAIN_DIR = os.path.join(DATA_DIR, "train")
VAL_DIR   = os.path.join(DATA_DIR, "val")
TEST_DIR  = os.path.join(DATA_DIR, "test")

# Model toggles
USE_PRETRAINED_SWINT = False   # True = fine-tune ImageNet weights, False = train from scratch
RUN_VIT = True                 # True = also run ViT-B/16 comparison

# Training parameters
# From-scratch Swin gets many more epochs than fine-tuning.
EPOCHS_SWINT = 25 if USE_PRETRAINED_SWINT else 80
EPOCHS_VIT   = 20
PATIENCE = 6                   # presumably early-stopping patience, in epochs
BATCH_SIZE = 32
NUM_WORKERS = 2
# Square input side length used by all transforms.
# NOTE(review): the ViT-B/16 SWAG_E2E weights are documented at 384x384 — keep in sync.
IMG_SIZE = 384

# Optimization
LR_HEAD = 1e-3                 # LR for head-only warmup phases
# Full-network LR / weight decay / label smoothing switch with the Swin recipe:
# gentle settings for fine-tuning, stronger regularisation when training from scratch.
LR_ALL_SWINT = 3e-5 if USE_PRETRAINED_SWINT else 3e-4
LR_ALL_VIT   = 3e-5
WEIGHT_DECAY_SWINT = 1e-4 if USE_PRETRAINED_SWINT else 5e-2
LABEL_SMOOTH = 0.05 if not USE_PRETRAINED_SWINT else 0.0

# Destination for checkpoints and CSV exports.
OUT_DIR = "/content"
os.makedirs(OUT_DIR, exist_ok=True)

🖼️ Data Loading and Exploration

In [None]:
# Training-time augmentation: resize ~15% larger, random-crop back to IMG_SIZE, and
# apply small flips / rotations / shifts. Normalisation uses ImageNet channel statistics.
# NOTE(review): horizontal flips mirror left/right anatomy on chest X-rays; confirm
# this augmentation is acceptable for the task.
train_tfms = transforms.Compose([
    transforms.Resize(int(IMG_SIZE*1.15)),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.80, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=10, translate=(0.02,0.02), scale=(0.98,1.02)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

# Deterministic evaluation pipeline: same resize and normalisation, centre crop.
eval_tfms = transforms.Compose([
    transforms.Resize(int(IMG_SIZE*1.15)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

train_ds = datasets.ImageFolder(TRAIN_DIR, transform=train_tfms)
val_ds   = datasets.ImageFolder(VAL_DIR,   transform=eval_tfms)
test_ds  = datasets.ImageFolder(TEST_DIR,  transform=eval_tfms)
class_names = train_ds.classes
print("Classes:", class_names)

# Balance sampler: weight each training image by the inverse frequency of its class so
# sampled batches are class-balanced in expectation. minlength keeps one count per
# class even if a class folder is empty.
targets = np.array(train_ds.targets)
class_counts = np.bincount(targets, minlength=len(class_names))
class_weights = 1.0 / (class_counts + 1e-9)
sample_weights = class_weights[targets]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runtimes.
PIN_MEMORY = torch.cuda.is_available()

# drop_last on the training loader: the classifier heads contain BatchNorm1d, which
# raises in training mode if the final batch holds a single sample. With replacement
# sampling, dropping the short tail batch loses nothing systematic.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# Overview
def show_dataset_overview(ds, title):
    """Print the size of ``ds`` and its per-class image counts.

    Args:
        ds: dataset exposing ``classes`` (list of class names), ``targets`` (one integer
            label per sample) and ``__len__``, as torchvision's ImageFolder does.
        title: prefix for the summary line (e.g. "Train").

    Returns:
        NumPy array of per-class counts, aligned with ``ds.classes``.
    """
    print(f"{title}: {len(ds)} images | classes={ds.classes}")
    # minlength: a class with no images (e.g. an empty last folder) still gets a zero
    # count instead of an IndexError in the loop below.
    cnt = np.bincount(np.asarray(ds.targets, dtype=np.int64), minlength=len(ds.classes))
    for name, count in zip(ds.classes, cnt):
        print(f"  {name}: {count}")
    return cnt

# Print per-split image counts (one line per class) for train / val / test.
show_dataset_overview(train_ds, "Train")
show_dataset_overview(val_ds, "Val")
show_dataset_overview(test_ds, "Test")

# Sample visualization
def show_samples(ds, n=12, nrow=6, title='Random training samples'):
    """Display ``n`` randomly chosen images from ``ds`` as one grid.

    Images are min-max rescaled for display by ``make_grid(normalize=True)``, so
    normalised tensors render correctly.

    Args:
        ds: dataset returning ``(image_tensor, label)`` pairs.
        n: number of images to draw (capped at ``len(ds)``).
        nrow: images per grid row.
        title: figure title; pass e.g. ``'Random test samples'`` for another split.
    """
    if len(ds) == 0:
        print("Dataset is empty; nothing to show.")
        return
    idxs = np.random.choice(len(ds), size=min(n, len(ds)), replace=False)
    imgs = [ds[int(i)][0] for i in idxs]
    grid = tv_utils.make_grid(torch.stack(imgs), nrow=nrow, padding=2, normalize=True)
    plt.figure(figsize=(10, 5))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.title(title)
    plt.show()

# Sanity check: random grid of (augmented, since train_ds uses train_tfms) training images.
show_samples(train_ds)

🔁 Training Utilities
# Contains helper functions for training loops, evaluation, and plotting.

In [None]:
# Module-level gradient scaler used by run_epoch for mixed-precision training;
# disabled on CPU-only runtimes.
# NOTE(review): newer PyTorch releases deprecate torch.cuda.amp.GradScaler in favour of
# torch.amp.GradScaler("cuda"); confirm against the installed version.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

def bce_logits_smooth(logits, targets, eps=0.0):
    """Binary cross-entropy on logits with optional symmetric label smoothing.

    When ``eps > 0`` each target t becomes ``t * (1 - eps) + eps / 2`` (so 1 -> 1 - eps/2,
    0 -> eps/2) before the loss; otherwise targets are used unchanged.
    """
    smoothed = targets * (1 - eps) + 0.5 * eps if eps > 0 else targets
    return F.binary_cross_entropy_with_logits(logits, smoothed)

def run_epoch(model, loader, criterion, optimizer=None, grad_clip=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Uses CUDA autocast plus the module-level ``scaler`` when a GPU is available.

    Args:
        model: module mapping an image batch to one logit per image, shape (B, 1).
        loader: iterable of ``(images, integer labels)`` batches.
        criterion: loss callable ``criterion(logits, float_targets)``.
        optimizer: if not None, parameters are updated after every batch.
        grad_clip: optional max global gradient norm (training only).

    Returns:
        (avg_loss, accuracy, y_true, y_prob, y_pred): sample-weighted mean loss over the
        samples actually seen, accuracy at a 0.5 threshold, and per-sample arrays of
        float labels, sigmoid probabilities and 0/1 predictions.
    """
    train_mode = optimizer is not None
    model.train(train_mode)
    use_amp = torch.cuda.is_available()
    total_loss, n_seen = 0.0, 0
    y_true, y_prob = [], []
    if train_mode:
        optimizer.zero_grad(set_to_none=True)
    for images, labels in loader:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.float().unsqueeze(1).to(DEVICE, non_blocking=True)
        with torch.set_grad_enabled(train_mode):
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(images)
                loss = criterion(logits, labels)
            if train_mode:
                scaler.scale(loss).backward()
                if grad_clip:
                    # Gradients are still multiplied by the loss scale here; unscale them
                    # first so the clip threshold applies to the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
        batch = images.size(0)
        total_loss += loss.item() * batch
        n_seen += batch
        # .float(): take the sigmoid in fp32 so probabilities are not fp16-quantised.
        y_prob.extend(torch.sigmoid(logits.detach().float()).cpu().numpy().ravel().tolist())
        y_true.extend(labels.detach().cpu().numpy().ravel().tolist())
    # Divide by samples actually seen (a sampler or drop_last may differ from len(dataset)).
    avg_loss = total_loss / max(n_seen, 1)
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    y_pred = (y_prob >= 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred) if n_seen else float('nan')
    return avg_loss, acc, y_true, y_prob, y_pred

def plot_history(history, title="Training Curves"):
    """Draw side-by-side train/val loss and accuracy curves, one point per epoch.

    ``history`` holds equal-length lists under 'train_loss', 'val_loss',
    'train_acc' and 'val_acc'.
    """
    epochs = range(1, len(history['train_loss']) + 1)
    panels = (
        ('Loss', 'train_loss', 'val_loss'),
        ('Accuracy', 'train_acc', 'val_acc'),
    )
    plt.figure(figsize=(12, 4))
    for position, (metric, train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, history[train_key], label='Train')
        plt.plot(epochs, history[val_key], label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.title(metric)
        plt.legend()
        plt.grid(True, ls='--', alpha=0.4)
    plt.suptitle(title)
    plt.show()

🧩 Swin-Tiny Model Definition & Training

In [None]:
def build_swin_tiny(use_pretrained=True):
    """Build torchvision Swin-T with an MLP head that emits a single logit.

    Args:
        use_pretrained: load IMAGENET1K_V1 weights when True, random init otherwise;
            dropout in the head is slightly stronger for the from-scratch case.

    Returns:
        ``nn.Sequential(backbone, head)`` on DEVICE. The backbone's own head is replaced
        by Identity so it outputs pooled features for the custom head.
    """
    weights = models.Swin_T_Weights.IMAGENET1K_V1 if use_pretrained else None
    backbone = models.swin_t(weights=weights)
    feat_dim = backbone.head.in_features
    backbone.head = nn.Identity()

    drop_first, drop_second = (0.3, 0.25) if use_pretrained else (0.4, 0.3)
    layers = []
    for d_in, d_out, p_drop in ((feat_dim, 1024, drop_first), (1024, 256, drop_second)):
        layers += [nn.Linear(d_in, d_out), nn.ReLU(inplace=True),
                   nn.BatchNorm1d(d_out), nn.Dropout(p_drop)]
    layers.append(nn.Linear(256, 1))

    return nn.Sequential(backbone, nn.Sequential(*layers)).to(DEVICE)

def fit_swin(model, epochs, ckpt_path, use_pretrained=True):
    """Train the Swin-Tiny classifier and keep the weights with the lowest validation loss.

    Recipe (mirrors the ViT-B/16 cell of this notebook):
      * use_pretrained=True: up to 3 head-only warmup epochs at LR_HEAD with the backbone
        frozen, then all parameters at LR_ALL_SWINT;
      * use_pretrained=False: all parameters from the first epoch at LR_ALL_SWINT, with
        label smoothing LABEL_SMOOTH on the training loss.
    The full phase uses AdamW (weight decay WEIGHT_DECAY_SWINT), cosine LR decay to 1e-6,
    and early stopping after PATIENCE epochs without a new best validation loss.

    Args:
        model: ``nn.Sequential(backbone, head)`` as returned by build_swin_tiny.
        epochs: maximum number of epochs, warmup included.
        ckpt_path: file the best ``state_dict`` is saved to.
        use_pretrained: selects the fine-tune or from-scratch recipe above.

    Returns:
        (model, history): the model with best-validation weights loaded, and a dict of
        per-epoch 'train_loss', 'val_loss', 'train_acc', 'val_acc' lists.
    """
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    backbone, head = model[0], model[1]
    smooth = 0.0 if use_pretrained else LABEL_SMOOTH

    def train_criterion(logits, targets):
        # Smoothing only regularises training; validation uses plain BCE so val loss
        # stays comparable across recipes and models.
        return bce_logits_smooth(logits, targets, eps=smooth)

    val_criterion = nn.BCEWithLogitsLoss()
    best_val = float('inf')
    best_w = copy.deepcopy(model.state_dict())

    def _epoch(tag, ep, total, optimizer):
        """Train + validate once, log, checkpoint on improvement; return True if improved."""
        nonlocal best_val, best_w
        tl, ta, *_ = run_epoch(model, train_loader, train_criterion, optimizer)
        vl, va, *_ = run_epoch(model, val_loader, val_criterion, optimizer=None)
        history['train_loss'].append(tl); history['val_loss'].append(vl)
        history['train_acc'].append(ta);  history['val_acc'].append(va)
        print(f"[{tag} {ep}/{total}] tl {tl:.4f} ta {ta:.3f} | vl {vl:.4f} va {va:.3f}")
        if vl < best_val:
            best_val, best_w = vl, copy.deepcopy(model.state_dict())
            torch.save(best_w, ckpt_path)
            return True
        return False

    # Phase 1 (fine-tune only): fit the freshly initialised head on a frozen backbone
    # before exposing the pretrained features to its early, large gradients.
    warmup = min(3, epochs) if use_pretrained else 0
    if warmup:
        for p in backbone.parameters():
            p.requires_grad = False
        for p in head.parameters():
            p.requires_grad = True
        optimizer = optim.AdamW(head.parameters(), lr=LR_HEAD, weight_decay=1e-4)
        for ep in range(1, warmup + 1):
            _epoch("Swin Warmup", ep, warmup, optimizer)

    # Phase 2: everything trainable, cosine-decayed LR, early stopping on val loss.
    for p in model.parameters():
        p.requires_grad = True
    remaining = epochs - warmup
    if remaining > 0:
        optimizer = optim.AdamW(model.parameters(), lr=LR_ALL_SWINT,
                                weight_decay=WEIGHT_DECAY_SWINT)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=remaining,
                                                         eta_min=1e-6)
        wait = 0
        for ep in range(warmup + 1, epochs + 1):
            improved = _epoch("Swin", ep, epochs, optimizer)
            scheduler.step()
            wait = 0 if improved else wait + 1
            if wait >= PATIENCE:
                print(f"Early stopping at epoch {ep}: no val-loss improvement for {PATIENCE} epochs.")
                break

    model.load_state_dict(best_w)
    torch.save(model.state_dict(), ckpt_path)
    return model, history

# Train Swin-Tiny
# Build the model, pass it to fit_swin (the checkpoint filename encodes fine-tune vs
# scratch), then plot the returned loss/accuracy history.
swin = build_swin_tiny(USE_PRETRAINED_SWINT)
swin_ckpt = os.path.join(OUT_DIR, f"swin_tiny_{'ft' if USE_PRETRAINED_SWINT else 'scratch'}_best.pth")
swin, swin_hist = fit_swin(swin, EPOCHS_SWINT, swin_ckpt, use_pretrained=USE_PRETRAINED_SWINT)
plot_history(swin_hist, f"Swin-Tiny Training Curves ({'fine-tune' if USE_PRETRAINED_SWINT else 'from scratch'})")

🤖 ViT-B/16 Model Fine-tuning — fine-tunes ViT-B/16 for comparison when RUN_VIT=True.

In [None]:
if RUN_VIT:
    # Backbone: torchvision ViT-B/16 with SWAG end-to-end weights (documented at 384x384
    # input, matching IMG_SIZE = 384 — re-check if IMG_SIZE changes). The stock head is
    # replaced by an MLP head emitting one logit.
    vit = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
    vit.heads = nn.Identity()
    vit_head = nn.Sequential(
        nn.Linear(vit.hidden_dim, 2048), nn.ReLU(inplace=True), nn.BatchNorm1d(2048), nn.Dropout(0.3),
        nn.Linear(2048, 512), nn.ReLU(inplace=True), nn.BatchNorm1d(512), nn.Dropout(0.3),
        nn.Linear(512, 1)
    )
    vit = nn.Sequential(vit, vit_head).to(DEVICE)

    # Warmup: train head only (backbone frozen) at LR_HEAD.
    for p in vit[0].parameters(): p.requires_grad = False
    for p in vit[1].parameters(): p.requires_grad = True
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, vit.parameters()), lr=LR_HEAD, weight_decay=1e-4)
    crit = nn.BCEWithLogitsLoss().to(DEVICE)

    vit_hist = {'train_loss':[], 'val_loss':[], 'train_acc':[], 'val_acc':[]}
    # best_val / best_w track the lowest validation loss over both phases; wait counts
    # consecutive fine-tune epochs without improvement (early stopping).
    best_val, best_w, wait = float('inf'), copy.deepcopy(vit.state_dict()), 0
    warmup_epochs = 3

    for ep in range(1, warmup_epochs+1):
        tl, ta, *_ = run_epoch(vit, train_loader, crit, optimizer)
        vl, va, *_ = run_epoch(vit, val_loader,   crit, optimizer=None)
        vit_hist['train_loss'].append(tl); vit_hist['val_loss'].append(vl)
        vit_hist['train_acc'].append(ta);  vit_hist['val_acc'].append(va)
        print(f"[ViT Warmup {ep}/{warmup_epochs}] tl {tl:.4f} ta {ta:.3f} | vl {vl:.4f} va {va:.3f}")
        if vl < best_val: best_val, best_w = vl, copy.deepcopy(vit.state_dict())

    # Unfreeze all and fine-tune with cosine LR decay (T_max >= 1 so the scheduler is
    # valid even when EPOCHS_VIT <= warmup_epochs).
    for p in vit.parameters(): p.requires_grad = True
    optimizer = optim.AdamW(vit.parameters(), lr=LR_ALL_VIT, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, EPOCHS_VIT - warmup_epochs), eta_min=1e-6)

    for ep in range(warmup_epochs+1, EPOCHS_VIT+1):
        tl, ta, *_ = run_epoch(vit, train_loader, crit, optimizer)
        vl, va, *_ = run_epoch(vit, val_loader,   crit, optimizer=None)
        scheduler.step()
        vit_hist['train_loss'].append(tl); vit_hist['val_loss'].append(vl)
        vit_hist['train_acc'].append(ta);  vit_hist['val_acc'].append(va)
        print(f"[ViT {ep}/{EPOCHS_VIT}] tl {tl:.4f} ta {ta:.3f} | vl {vl:.4f} va {va:.3f}")
        if vl < best_val:
            best_val, best_w, wait = vl, copy.deepcopy(vit.state_dict()), 0
        else:
            wait += 1
            if wait >= PATIENCE:
                print(f"[ViT] Early stopping at epoch {ep}: no val-loss improvement for {PATIENCE} epochs.")
                break

    # Restore and persist the best-validation weights.
    vit.load_state_dict(best_w)
    vit_ckpt = os.path.join(OUT_DIR, "vit_b16_ft_best.pth")
    torch.save(vit.state_dict(), vit_ckpt)

    plot_history(vit_hist, "ViT-B/16 Training Curves (fine-tune)")
    print("✅ ViT-B/16 fine-tuning completed.")

📊 Model Evaluation — evaluates test-set performance: classification metrics, ROC, calibration, and correlations.

In [None]:
def _plot_roc(y_true, y_prob, auc, title):
    """Plot the ROC curve of ``y_prob`` against ``y_true`` with the chance diagonal."""
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure(figsize=(5, 4))
    plt.plot(fpr, tpr, lw=2, label=f'ROC (AUC={auc:.3f})')
    plt.plot([0, 1], [0, 1], '--', lw=1)
    plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate'); plt.title(f'ROC - {title}')
    plt.grid(True, ls='--', alpha=0.4); plt.legend(); plt.show()


def _plot_calibration(y_true, y_prob, brier, title):
    """Plot a 10-bin (quantile) reliability diagram against the perfect-calibration line."""
    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=10, strategy='quantile')
    plt.figure(figsize=(5, 4))
    plt.plot(prob_pred, prob_true, marker='o', label='Reliability')
    plt.plot([0, 1], [0, 1], '--', label='Perfect')
    plt.xlabel('Predicted probability'); plt.ylabel('Observed frequency')
    plt.title(f'Calibration - {title} (Brier={brier:.4f})')
    plt.grid(True, ls='--', alpha=0.4); plt.legend(); plt.show()


def evaluate_full(model, loader, title="Model"):
    """Evaluate a binary classifier on ``loader``; print metrics and plot ROC + calibration.

    The positive class is label 1 (PNEUMONIA under ImageFolder's alphabetical class
    order); predictions use a fixed 0.5 threshold on sigmoid(logit).

    Args:
        model: module mapping an image batch to one logit per image, shape (B, 1).
        loader: DataLoader of ``(images, integer labels)``; iterated once, in order.
        title: label used in printed headers and plot titles.

    Returns:
        dict with scalars 'loss', 'accuracy', 'precision', 'recall', 'f1', 'auc',
        'brier', 'pearson_r', 'spearman_rho' and per-sample arrays 'y_true' (int),
        'y_prob', 'y_pred' in loader order. AUC and the correlations are NaN when the
        labels contain a single class.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    criterion = nn.BCEWithLogitsLoss()
    y_true, y_prob = [], []
    total_loss, n_seen = 0.0, 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(DEVICE)
            labels = labels.float().unsqueeze(1).to(DEVICE)
            logits = model(images).float()
            total_loss += criterion(logits, labels).item() * images.size(0)
            n_seen += images.size(0)
            y_prob.extend(torch.sigmoid(logits).cpu().numpy().ravel().tolist())
            y_true.extend(labels.cpu().numpy().ravel().tolist())
    if n_seen == 0:
        raise ValueError(f"{title}: loader yielded no samples to evaluate.")

    y_true = np.asarray(y_true, dtype=int)
    y_prob = np.asarray(y_prob, dtype=float)
    y_pred = (y_prob >= 0.5).astype(int)
    tl = total_loss / n_seen
    acc  = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec  = recall_score(y_true, y_pred, zero_division=0)
    f1   = f1_score(y_true, y_pred, zero_division=0)
    # ROC-AUC, the ROC curve and the correlations are undefined with a single class.
    both_classes = np.unique(y_true).size == 2
    auc = roc_auc_score(y_true, y_prob) if both_classes else float('nan')
    bs = brier_score_loss(y_true, y_prob)

    print(f"\n=== {title} (Test) ===")
    print(f"Loss: {tl:.4f} | Acc: {acc:.4f} | P: {prec:.4f} | R: {rec:.4f} | F1: {f1:.4f} | AUC: {auc:.4f}")
    # labels=...: keep one report row per class even if a class is absent from y_true.
    report = classification_report(y_true, y_pred, labels=list(range(len(class_names))),
                                   target_names=class_names, digits=4, zero_division=0)
    print("\nClassification Report:\n", report)

    if both_classes:
        _plot_roc(y_true, y_prob, auc, title)
    else:
        print("ROC curve skipped: labels contain a single class.")
    _plot_calibration(y_true, y_prob, bs, title)

    if both_classes:
        pr, p_p = pearsonr(y_true, y_prob)
        sr, s_p = spearmanr(y_true, y_prob)
    else:
        pr = p_p = sr = s_p = float('nan')
    print(f"Pearson r = {pr:.4f} (p={p_p:.3g}) | Spearman ρ = {sr:.4f} (p={s_p:.3g})")

    return {
        'loss': tl, 'accuracy': acc, 'precision': prec, 'recall': rec,
        'f1': f1, 'auc': auc, 'brier': bs,
        'pearson_r': pr, 'spearman_rho': sr,
        'y_true': y_true, 'y_prob': y_prob, 'y_pred': y_pred
    }

# --- Run evaluations ---
# Test-set metrics and plots per model; the returned dicts (swin_test / vit_test)
# are reused by the gallery, CSV export and summary cells below.
swin_test = evaluate_full(swin, test_loader, title=f"Swin-Tiny ({'FT' if USE_PRETRAINED_SWINT else 'Scratch'})")
if RUN_VIT:
    vit_test = evaluate_full(vit, test_loader, title="ViT-B/16 (Fine-tune)")


🔍 Interpretability (Attention Rollout / Activation Maps)
# Generates Grad-CAM-style heatmaps: attention rollout for ViT-B/16 and a feature-activation map for Swin-Tiny.

In [None]:
def _minmax_per_image(t):
    """Rescale each image of a (B, C, H, W) batch independently to [0, 1]."""
    lo = t.amin(dim=(1, 2, 3), keepdim=True)
    hi = t.amax(dim=(1, 2, 3), keepdim=True)
    return (t - lo) / (hi - lo + 1e-6)


def _vit_attention_rollout(net, x, discard_ratio):
    """Attention rollout for a torchvision VisionTransformer; returns (B, 1, h, w)."""
    # _process_input already applies the patch embedding (conv_proj) and flattens to
    # (B, num_patches, hidden_dim); do not project again.
    tokens = net._process_input(x)
    cls = net.class_token.expand(tokens.size(0), -1, -1)
    tokens = torch.cat((cls, tokens), dim=1) + net.encoder.pos_embedding
    tokens = net.encoder.dropout(tokens)

    b, n = tokens.size(0), tokens.size(1)
    eye = torch.eye(n, device=tokens.device).unsqueeze(0)
    joint = eye.repeat(b, 1, 1)
    for blk in net.encoder.layers:
        # Re-run torchvision's pre-norm EncoderBlock step by step to capture its
        # head-averaged attention weights, shape (B, n, n).
        h = blk.ln_1(tokens)
        attn_out, attn = blk.self_attention(h, h, h, need_weights=True,
                                            average_attn_weights=True)
        tokens = tokens + blk.dropout(attn_out)
        tokens = tokens + blk.mlp(blk.ln_2(tokens))

        # Zero the lowest `discard_ratio` fraction of weights per image to suppress noise.
        flat = attn.float().reshape(b, -1).clone()
        k = int(flat.size(1) * discard_ratio)
        if k > 0:
            _, idx = torch.topk(flat, k, dim=1, largest=False)
            flat.scatter_(1, idx, 0.0)
        attn = flat.view(b, n, n)
        # Account for the residual connection, then re-normalise rows (diagonal >= 0.5,
        # so rows never sum to zero).
        attn = 0.5 * attn + 0.5 * eye
        attn = attn / attn.sum(dim=-1, keepdim=True)
        joint = attn @ joint

    mask = joint[:, 0, 1:]  # CLS-token row over the patch tokens
    hw = math.isqrt(mask.size(-1))
    return mask.reshape(b, 1, hw, hw)


def _feature_norm_map(net, x):
    """Channel-wise L2 norm of ``net.features(x)``; returns (B, 1, h, w).

    torchvision's Swin stages output channels-last tensors (B, h, w, C), so the norm
    is taken over the last axis.
    NOTE(review): a backbone whose ``features`` is channels-first would need dim=1.
    """
    feats = net.features(x).float()
    return feats.norm(dim=-1).unsqueeze(1)


def attention_rollout_vitlike(model, images, size=IMG_SIZE, discard_ratio=0.9):
    """Return a per-image saliency map in [0, 1] with shape (B, 1, size, size), on CPU.

    * torchvision ``VisionTransformer`` backbones: attention rollout (Abnar & Zuidema,
      2020) over all encoder blocks, read from the CLS-token row.
    * Any other backbone (Swin here): an activation map, the channel L2 norm of the last
      ``features`` stage — not attention rollout.

    Args:
        model: a backbone, or ``nn.Sequential(backbone, head)`` as built in this
            notebook (only ``model[0]`` is inspected then).
        images: image batch (B, 3, H, W), presumably normalised like eval_tfms.
        size: output side length; maps are bilinearly resized to (size, size).
        discard_ratio: fraction of smallest attention weights dropped per layer (ViT only).
    """
    model.eval()
    net = model[0] if isinstance(model, nn.Sequential) else model
    with torch.no_grad():
        x = images.to(DEVICE)
        if isinstance(net, models.VisionTransformer):
            heat = _vit_attention_rollout(net, x, discard_ratio)
        else:
            heat = _feature_norm_map(net, x)
        heat = F.interpolate(heat, size=(size, size), mode='bilinear', align_corners=False)
        return _minmax_per_image(heat).cpu()

def show_gallery(model, ds, y_true, y_pred, title="Interpretability Gallery", n_per_group=4):
    """Plot saliency overlays for random TP / FP / TN / FN examples.

    ``y_true`` / ``y_pred`` must be aligned with the indices of ``ds`` (true when they
    come from evaluate_full over a loader built with shuffle=False on the same dataset).
    Each non-empty group gets one figure with up to ``n_per_group`` images: the image in
    grayscale with the attention_rollout_vitlike map drawn on top as a semi-transparent
    'jet' colormap, so the heatmap stays distinguishable from the anatomy.
    """
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    groups = [
        ("TP", (y_true == 1) & (y_pred == 1)),
        ("FP", (y_true == 0) & (y_pred == 1)),
        ("TN", (y_true == 0) & (y_pred == 0)),
        ("FN", (y_true == 1) & (y_pred == 0)),
    ]
    for name, selected in groups:
        idxs = np.flatnonzero(selected)
        if idxs.size == 0:
            print(f"No {name} examples.")
            continue
        pick = np.random.choice(idxs, size=min(n_per_group, idxs.size), replace=False)
        plt.figure(figsize=(2.5 * n_per_group, 5))
        for j, k in enumerate(pick):
            img, _ = ds[int(k)]
            # Collapse channels and rescale to [0, 1] for a grayscale backdrop.
            gray = img.mean(dim=0)
            gray = (gray - gray.min()) / (gray.max() - gray.min() + 1e-6)
            mask = attention_rollout_vitlike(model, img.unsqueeze(0), size=img.shape[-1])[0, 0]
            ax = plt.subplot(1, n_per_group, j + 1)
            ax.imshow(gray.numpy(), cmap='gray')
            ax.imshow(mask.numpy(), cmap='jet', alpha=0.4)
            ax.set_title(f"idx {int(k)}", fontsize=8)
            ax.axis('off')
        plt.suptitle(f'{title}: {name} (Transformer attention/activation overlays)')
        plt.show()

# Interpretability galleries on the test set, grouped into TP / FP / TN / FN using the
# labels and predictions returned by evaluate_full.
show_gallery(swin, test_ds, swin_test['y_true'], swin_test['y_pred'], title='Swin-Tiny')
if RUN_VIT:
    show_gallery(vit, test_ds, vit_test['y_true'], vit_test['y_pred'], title='ViT-B/16')


📁 Save Results & Summary — exports metrics, training logs, and summary comparisons.

In [None]:
# ============================================================
# 📁 Save Results & Project Summary
# Saves training logs, metrics, checkpoints, and prints summaries.
# ============================================================

def save_history_csv(history, path_csv):
    """Write a per-epoch history dict (equal-length lists) to ``path_csv`` without an index column."""
    frame = pd.DataFrame(history)
    frame.to_csv(path_csv, index=False)
def save_metrics_csv(metrics_dict, path_csv):
    """Write the scalar entries of a metrics dict as a one-row CSV (no index column).

    NumPy-array values (per-sample y_true / y_prob / y_pred) are skipped because they
    do not fit a single row.
    """
    scalars = {}
    for key, value in metrics_dict.items():
        if isinstance(value, np.ndarray):
            continue
        scalars[key] = value
    pd.DataFrame([scalars]).to_csv(path_csv, index=False)

# Swin exports: per-epoch history, scalar test metrics, and the current in-memory
# weights (torch.save overwrites swin_ckpt).
swin_hist_csv = os.path.join(OUT_DIR, f"swin_tiny_{'ft' if USE_PRETRAINED_SWINT else 'scratch'}_history.csv")
swin_metrics_csv = os.path.join(OUT_DIR, f"swin_tiny_{'ft' if USE_PRETRAINED_SWINT else 'scratch'}_test_metrics.csv")
save_history_csv(swin_hist, swin_hist_csv); save_metrics_csv(swin_test, swin_metrics_csv)
torch.save(swin.state_dict(), swin_ckpt)

# ViT exports: history and metrics only; its weights were already saved to
# vit_b16_ft_best.pth in the fine-tuning cell.
if RUN_VIT:
    vit_hist_csv = os.path.join(OUT_DIR, "vit_b16_ft_history.csv")
    vit_metrics_csv = os.path.join(OUT_DIR, "vit_b16_ft_test_metrics.csv")
    save_history_csv(vit_hist, vit_hist_csv); save_metrics_csv(vit_test, vit_metrics_csv)

def print_summary(model_name, history, tm, ds_sizes, img_size, notes, test_metrics=None):
    """Print a dataset / architecture / training / test-metrics report for one model.

    Args:
        model_name: display name; also selects the default metrics (see test_metrics).
        history: dict with per-epoch 'train_acc' and 'val_acc' lists.
        tm: the trained nn.Module (only its parameters are counted).
        ds_sizes: (train_len, val_len, test_len).
        img_size: square input side length.
        notes: free-text "key feature" line.
        test_metrics: dict returned by evaluate_full. If None, the notebook globals
            ``swin_test`` (when 'Swin' is in model_name) or ``vit_test`` are used.
    """
    if test_metrics is None:
        # ``tm`` is the model, not a metrics dict, so fall back to the matching
        # evaluate_full() result instead of indexing the module.
        test_metrics = swin_test if 'Swin' in model_name else vit_test

    def _last(key):
        # Final-epoch value, or "n/a" when training recorded no epochs.
        values = history.get(key, [])
        return f"{values[-1]:.4f}" if len(values) else "n/a"

    train_len, val_len, test_len = ds_sizes
    print(f"\n{model_name.upper()} PNEUMONIA DETECTION - PROJECT SUMMARY")
    print("="*70)
    print("DATASET STATISTICS:")
    print(f"   Training samples: {train_len}")
    print(f"   Validation samples: {val_len}")
    print(f"   Test samples: {test_len}")
    print(f"   Classes: {class_names}")
    print("\nMODEL ARCHITECTURE:")
    total_params = sum(p.numel() for p in tm.parameters())
    trainable_params = sum(p.numel() for p in tm.parameters() if p.requires_grad)
    print(f"   Model: {model_name}")
    print(f"   Input Size: {img_size}x{img_size}x3")
    print(f"   Trainable Parameters: {trainable_params:,}")
    print(f"   Total Parameters: {total_params:,}")
    print(f"   Key Feature: {notes}")
    print("\nTRAINING PERFORMANCE:")
    print(f"   Final Train Accuracy: {_last('train_acc')}")
    print(f"   Final Val   Accuracy: {_last('val_acc')}")
    print("\nTEST PERFORMANCE:")
    m = test_metrics
    print(f"   Test Accuracy: {m['accuracy']:.4f}")
    print(f"   Test Precision: {m['precision']:.4f}")
    print(f"   Test Recall: {m['recall']:.4f}")
    print(f"   Test F1-Score: {m['f1']:.4f}")
    print(f"   Test AUC: {m['auc']:.4f}")
    print(f"   Brier Score: {m['brier']:.4f}")
    print(f"   Pearson r: {m['pearson_r']:.4f} | Spearman ρ: {m['spearman_rho']:.4f}")

# Swin-Tiny report (dataset sizes, parameter counts, final accuracies, test metrics).
print_summary(
    f"Swin-Tiny ({'Fine-tune' if USE_PRETRAINED_SWINT else 'From-scratch'})",
    swin_hist, swin, (len(train_ds), len(val_ds), len(test_ds)), IMG_SIZE,
    "Windowed self-attention (data-efficient); transformer overlays for interpretability"
)

# ViT-B/16 report, only when the comparison model was trained.
if RUN_VIT:
    print_summary(
        "ViT-B/16 (Fine-tune)",
        vit_hist, vit, (len(train_ds), len(val_ds), len(test_ds)), IMG_SIZE,
        "Global self-attention; transformer overlays for interpretability"
    )

# Side-by-side table of the headline test metrics for both models.
if RUN_VIT:
    comp = pd.DataFrame([
        {"Model":"Swin-Tiny"+(" FT" if USE_PRETRAINED_SWINT else " Scratch"),
         "Accuracy":swin_test['accuracy'], "Precision":swin_test['precision'],
         "Recall":swin_test['recall'], "F1":swin_test['f1'], "AUC":swin_test['auc'],
         "Brier":swin_test['brier']},
        {"Model":"ViT-B/16 FT",
         "Accuracy":vit_test['accuracy'], "Precision":vit_test['precision'],
         "Recall":vit_test['recall'], "F1":vit_test['f1'], "AUC":vit_test['auc'],
         "Brier":vit_test['brier']}
    ])
    print("\n📊 Comparison Table:\n", comp)
