# MedMNIST Week 2 — EDA, Baselines & Calibration Example

This notebook is an example for **Week 2** of the `medmnist-ssl` project.

It assumes you have already chosen one binary MedMNIST2D dataset (`breastmnist` or `pneumoniamnist`) and follows the same style as the Week 1 notebook:
- Quick dataset inspection (EDA)
- Two supervised baselines (`smallcnn`, `resnet18 --finetune head`)
- First calibration snapshot (reliability diagram + ECE)
- A small misclassified example gallery


In [None]:
import torch
# Report the installed PyTorch version and whether a CUDA device is visible,
# so it is clear up front which device the training cells will use.
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

In [None]:
!pip -q install medmnist torch torchvision scikit-learn matplotlib tqdm

In [None]:
import medmnist
from medmnist import INFO

# Pick exactly ONE binary dataset and keep it for the entire project.
DATASET_KEY = 'pneumoniamnist'  # or 'breastmnist'

info = INFO[DATASET_KEY]
DataClass = getattr(medmnist, info['python_class'])

# Instantiate the three splits with download=True and no transform, in train/val/test order.
train_ds_raw, val_ds_raw, test_ds_raw = (
    DataClass(split=split_name, download=True)
    for split_name in ('train', 'val', 'test')
)

print("Cache root:", train_ds_raw.root)
for split_title, split_ds in (("Train", train_ds_raw), ("Val", val_ds_raw), ("Test", test_ds_raw)):
    print(f"{split_title} shape:", split_ds.imgs.shape)
print("Labels:", info['label'])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

def describe_split(name, ds):
    """Print the size of a dataset split followed by its per-class sample counts.

    Args:
        name: Label printed in front of the size line (e.g. "Train").
        ds: Object exposing a ``labels`` attribute that converts to a NumPy array.
    """
    flat_labels = np.array(ds.labels).squeeze()
    per_class = Counter(flat_labels.tolist())

    report = [f"{name} size: {len(flat_labels)}"]
    report.extend(f"  class {cls}: {n}" for cls, n in sorted(per_class.items()))
    # The trailing empty entry reproduces the blank separator line after each split.
    report.append("")
    print("\n".join(report))

for split_title, split_ds in (("Train", train_ds_raw), ("Val", val_ds_raw), ("Test", test_ds_raw)):
    describe_split(split_title, split_ds)

# Show the first few training images of every class, one class per row.
n_per_class = 4
labels = np.array(train_ds_raw.labels).squeeze()
classes = sorted(set(labels.tolist()))

rows, cols = len(classes), n_per_class
# squeeze=False keeps `axes` two-dimensional even when there is only one row.
fig, axes = plt.subplots(rows, cols, figsize=(cols * 2.0, rows * 2.0), squeeze=False)

for row, cls in enumerate(classes):
    idxs = np.where(labels == cls)[0][:n_per_class]
    for col, idx in enumerate(idxs):
        ax = axes[row, col]
        img = train_ds_raw.imgs[idx]
        # 2-D arrays are single-channel images and get an explicit gray colormap.
        ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
        ax.axis('off')
        ax.set_title(f"y={cls}")
plt.tight_layout()
plt.show()

In [None]:
import os, random, json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as T
import torchvision.models as tvm
from sklearn.metrics import roc_auc_score, accuracy_score

from medmnist import INFO

def set_seed(seed: int = 42):
    """Seed every random number generator used here and pin the cuDNN flags.

    Args:
        seed: Value handed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA devices).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Request deterministic cuDNN algorithms and turn off benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def _gray_to_rgb(x):
    """Repeat a tensor with one leading channel three times; return any other tensor unchanged."""
    return x.repeat(3, 1, 1) if x.shape[0] == 1 else x


def get_medmnist_dataset(key: str, split: str, as_rgb: bool = True, size: int = 64, download: bool = True):
    """Build one MedMNIST split with resize and tensor-conversion transforms attached.

    Args:
        key: Dataset key looked up in ``medmnist.INFO`` (e.g. ``'pneumoniamnist'``).
        split: Split name forwarded to the dataset class (``'train'``, ``'val'`` or ``'test'``).
        as_rgb: When True, single-channel tensors are repeated to three channels
            after conversion so they fit models that expect RGB input.
        size: Side length in pixels of the square image produced by the resize.
        download: Forwarded unchanged to the dataset class.

    Returns:
        An instance of the dataset class named by ``INFO[key]['python_class']``.
    """
    dataset_cls = getattr(medmnist, INFO[key]['python_class'])
    steps = [T.Resize((size, size)), T.ToTensor()]
    if as_rgb:
        # A named module-level function (instead of a lambda) keeps the transform
        # picklable, which DataLoader worker processes require when they are
        # started with the "spawn" method.
        steps.append(T.Lambda(_gray_to_rgb))
    return dataset_cls(split=split, transform=T.Compose(steps), download=download)

def get_loaders(key, batch_size: int = 128, num_workers: int = 2, label_frac: float = 1.0, seed: int = 42):
    """Create train/val/test DataLoaders for one MedMNIST dataset.

    Args:
        key: Dataset key looked up in ``medmnist.INFO``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker count passed to every ``DataLoader``.
        label_frac: Fraction of training samples to keep per class. Values
            strictly between 0 and 1 trigger a stratified subsample; anything
            else keeps the full training split.
        seed: Seed passed to ``set_seed`` and to the subsampling RNG.

    Returns:
        ``(train_loader, val_loader, test_loader, n_classes)``; only the
        training loader shuffles.

    Raises:
        ValueError: If subsampling is requested but the dataset stores more
            than one label per sample.
    """
    set_seed(seed)
    ds_train = get_medmnist_dataset(key, 'train')
    ds_val = get_medmnist_dataset(key, 'val')
    ds_test = get_medmnist_dataset(key, 'test')

    n_classes = len(INFO[key]['label'])

    # Optional: subsample labels for label-efficiency experiments (not used by default in Week 2)
    if 0 < label_frac < 1.0:
        # Read the labels straight from the dataset's label array rather than
        # iterating the dataset, which would load and transform every image.
        labels = np.asarray(ds_train.labels)
        labels = labels.reshape(len(labels), -1)
        if labels.shape[1] != 1:
            raise ValueError("label_frac subsampling requires exactly one label per sample")
        labels = labels[:, 0].astype(int)

        rng = np.random.RandomState(seed)
        idxs = []
        for c in np.unique(labels):
            c_idxs = np.where(labels == c)[0]
            rng.shuffle(c_idxs)
            # Keep at least one sample of every class, however small the fraction.
            n_keep = max(1, int(len(c_idxs) * label_frac))
            idxs.extend(c_idxs[:n_keep])
        # Plain Python ints in ascending order keep the subset's indexing simple.
        ds_train = Subset(ds_train, np.sort(np.array(idxs)).tolist())

    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader, n_classes

class SmallCNN(nn.Module):
    """Small CNN: three 3x3 convolutions, global average pooling, one linear layer.

    Args:
        in_ch: Number of channels in the input images.
        n_classes: Number of logits produced per image.
    """

    def __init__(self, in_ch: int = 3, n_classes: int = 2):
        super().__init__()
        widths = (32, 64, 128)
        layers = []
        prev_ch = in_ch
        for stage, out_ch in enumerate(widths):
            layers += [nn.Conv2d(prev_ch, out_ch, kernel_size=3, padding=1), nn.ReLU()]
            # Only the first two stages downsample; the last feeds the global pool.
            if stage < len(widths) - 1:
                layers.append(nn.MaxPool2d(2))
            prev_ch = out_ch
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.features = nn.Sequential(*layers)
        self.fc = nn.Linear(widths[-1], n_classes)

    def forward(self, x):
        """Return the logits for a batch of images."""
        pooled = self.features(x)
        return self.fc(torch.flatten(pooled, 1))

def make_resnet18(n_classes: int = 2, in_ch: int = 3, pretrained: bool = True):
    """Create a torchvision ResNet-18 with a fresh classification head.

    Args:
        n_classes: Output size of the replacement ``fc`` layer.
        in_ch: Number of input channels; any value other than 3 replaces the
            first convolution.
        pretrained: Load the ``IMAGENET1K_V1`` weights when True, otherwise
            pass ``weights=None``.

    Returns:
        The modified ResNet-18 module.
    """
    weights = tvm.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    net = tvm.resnet18(weights=weights)

    if in_ch != 3:
        rgb_kernel = net.conv1.weight
        net.conv1 = nn.Conv2d(in_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if in_ch == 1:
            # Grayscale input: initialize the new stem with the original filters
            # summed over their input-channel dimension.
            with torch.no_grad():
                net.conv1.weight.copy_(rgb_kernel.sum(dim=1, keepdim=True))

    net.fc = nn.Linear(net.fc.in_features, n_classes)
    return net

def compute_metrics(y_true, y_prob, n_classes: int):
    """Compute accuracy and AUROC from true labels and predicted probabilities.

    Args:
        y_true: Integer class labels, one per sample.
        y_prob: Per-class scores, one row per sample; the argmax of each row
            is taken as the predicted class.
        n_classes: Number of classes; 2 selects the binary AUROC on column 1,
            any other value the macro one-vs-rest AUROC on one-hot targets.

    Returns:
        ``{'acc': float, 'auroc': float}``; ``auroc`` is NaN whenever the
        AUROC computation raises.
    """
    targets = np.asarray(y_true)
    scores = np.asarray(y_prob)
    accuracy = accuracy_score(targets, np.argmax(scores, axis=1))

    if n_classes == 2:
        def _auroc():
            return roc_auc_score(targets, scores[:, 1])
    else:
        # Built outside the guarded call, so a bad label array still raises here.
        one_hot = np.eye(n_classes)[targets]

        def _auroc():
            return roc_auc_score(one_hot, scores, average='macro', multi_class='ovr')

    try:
        auroc = _auroc()
    except Exception:
        auroc = float('nan')

    return {'acc': float(accuracy), 'auroc': float(auroc)}

# Confirmation that the helper cell above ran to completion.
print('✅ Week 2 helpers ready.')

In [None]:
import math

def _predict_split(model, loader, device, keep_inputs: bool = False):
    """Run ``model`` in eval mode over ``loader``; return (logits, flat labels, inputs or None)."""
    model.eval()
    logit_chunks, label_chunks, input_chunks = [], [], []
    with torch.no_grad():
        for x, y in loader:
            if keep_inputs:
                input_chunks.append(x.cpu())
            logit_chunks.append(model(x.to(device)).cpu())
            label_chunks.append(y)
    logits = torch.cat(logit_chunks, dim=0)
    # reshape(-1) rather than squeeze(): the result stays 1-D even for a single sample.
    labels = torch.cat(label_chunks, dim=0).reshape(-1).numpy()
    inputs = torch.cat(input_chunks, dim=0) if keep_inputs else None
    return logits, labels, inputs


def train_one_model(model_name: str,
                    finetune: str = 'head',
                    epochs: int = 5,
                    lr: float = 3e-4,
                    weight_decay: float = 1e-4,
                    batch_size: int = 128,
                    label_frac: float = 1.0,
                    device: str = None):
    """Train one baseline on ``DATASET_KEY``, restore its best epoch, and evaluate on test.

    The model is validated after every epoch. The weights of the epoch with the
    highest validation AUROC (accuracy when AUROC is NaN) are reloaded before
    the test evaluation.

    Args:
        model_name: ``'smallcnn'`` or ``'resnet18'``.
        finetune: For ``'resnet18'`` only: ``'head'`` freezes everything except
            the final ``fc`` layer, any other value trains all parameters.
        epochs: Number of training epochs (also the cosine schedule's ``T_max``).
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: Batch size for all loaders.
        label_frac: Forwarded to ``get_loaders`` for optional label subsampling.
        device: Device string; defaults to CUDA when available, else CPU.

    Returns:
        ``(model, metrics_test, y_test, probs_test, test_imgs)``: the trained
        model, the test metrics dict, the flat test labels (NumPy), the softmax
        probabilities on the test split (NumPy), and the test inputs as one
        CPU tensor.

    Raises:
        ValueError: If ``model_name`` is not recognized.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\n=== Training {model_name} (finetune={finetune}) on {DATASET_KEY} ===")
    print("Using device:", device)

    train_loader, val_loader, test_loader, n_classes = get_loaders(
        DATASET_KEY, batch_size=batch_size, label_frac=label_frac, seed=42
    )

    if model_name == 'smallcnn':
        model = SmallCNN(in_ch=3, n_classes=n_classes)
        params = model.parameters()
    elif model_name == 'resnet18':
        model = make_resnet18(n_classes=n_classes, in_ch=3, pretrained=True)
        if finetune == 'head':
            # Freeze the backbone; only the new classification head is optimized.
            for p in model.parameters():
                p.requires_grad = False
            for p in model.fc.parameters():
                p.requires_grad = True
            params = model.fc.parameters()
        else:
            params = model.parameters()
    else:
        raise ValueError(f"Unknown model_name: {model_name}")

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_state = None
    best_score = -1.0

    for epoch in range(1, epochs + 1):
        # NOTE(review): train() also puts the frozen backbone's BatchNorm layers
        # in training mode, so their running statistics keep updating during
        # head-only fine-tuning -- confirm this is intended.
        model.train()
        loss_sum, n_sum = 0.0, 0
        for x, y in train_loader:
            x = x.to(device)
            # reshape(-1) instead of squeeze(): a final batch holding a single
            # sample must still yield a 1-D target for the loss.
            y = y.reshape(-1).long().to(device)
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += float(loss.item()) * x.size(0)
            n_sum += x.size(0)

        # validation
        val_logits, val_labels, _ = _predict_split(model, val_loader, device)
        probs_val = torch.softmax(val_logits, dim=1).numpy()
        metrics_val = compute_metrics(val_labels, probs_val, n_classes=n_classes)
        # The reported loss is the sample-weighted mean *training* loss of this epoch.
        train_loss = loss_sum / max(1, n_sum)
        score = metrics_val['auroc'] if not math.isnan(metrics_val['auroc']) else metrics_val['acc']
        print(f"[{model_name}][epoch {epoch:02d}] loss={train_loss:.4f} "
              f"val_acc={metrics_val['acc']:.4f} val_auroc={metrics_val['auroc']:.4f}")
        scheduler.step()

        if score > best_score:
            best_score = score
            # state_dict() hands back references to the live tensors; clone them
            # so later epochs cannot overwrite the saved best checkpoint.
            best_state = {name: tensor.detach().clone()
                          for name, tensor in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    # test evaluation
    test_logits, all_y, test_imgs = _predict_split(model, test_loader, device, keep_inputs=True)
    probs_test = torch.softmax(test_logits, dim=1).numpy()
    metrics_test = compute_metrics(all_y, probs_test, n_classes=n_classes)
    print(f"[{model_name}] TEST acc={metrics_test['acc']:.4f} auroc={metrics_test['auroc']:.4f}")

    return model, metrics_test, all_y, probs_test, test_imgs

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 5
BATCH_SIZE = 128
LR = 3e-4
WEIGHT_DECAY = 1e-4

results_dir = 'results_week2_example'
os.makedirs(results_dir, exist_ok=True)

# Both baselines share every hyperparameter; only the architecture differs.
shared_train_kwargs = dict(
    finetune='head',
    epochs=EPOCHS,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    batch_size=BATCH_SIZE,
    device=device,
)

smallcnn_model, metrics_small, y_test_small, p_test_small, imgs_test_small = train_one_model(
    model_name='smallcnn', **shared_train_kwargs
)

resnet_model, metrics_resnet, y_test_resnet, p_test_resnet, imgs_test_resnet = train_one_model(
    model_name='resnet18', **shared_train_kwargs
)

# Persist each model's test metrics as its own JSON file.
for json_name, metrics_dict in (
    ('metrics_smallcnn.json', metrics_small),
    ('metrics_resnet18_head.json', metrics_resnet),
):
    with open(os.path.join(results_dir, json_name), 'w') as f:
        json.dump(metrics_dict, f, indent=2)

print('Saved metrics JSON files to:', results_dir)
print('smallcnn metrics:', metrics_small)
print('resnet18-head metrics:', metrics_resnet)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def reliability_diagram_with_ece(y_true, y_prob, n_bins: int = 10, strategy: str = 'equal_width', title_prefix: str = ''):
    """Plot a reliability diagram and return the expected calibration error (ECE).

    Confidence is the maximum probability per row and the prediction is its
    argmax. Samples are binned by confidence; the ECE is the sample-weighted
    mean of ``|accuracy - mean confidence|`` over the non-empty bins.

    Args:
        y_true: Integer labels, one per sample; flattened to 1-D, so an
            ``(N, 1)`` array is accepted.
        y_prob: Array of shape ``(N, n_classes)`` with predicted probabilities.
        n_bins: Number of confidence bins.
        strategy: ``'equal_freq'`` places bin edges at confidence quantiles;
            any other value uses equally wide bins on ``[0, 1]``.
        title_prefix: Text placed in front of the plot title.

    Returns:
        The ECE as a Python float. The figure is displayed with ``plt.show()``.
    """
    # Flatten so an (N, 1) label array cannot broadcast against (N,) predictions.
    y_true = np.asarray(y_true).reshape(-1)
    y_prob = np.asarray(y_prob)
    confidences = y_prob.max(axis=1)
    preds = y_prob.argmax(axis=1)
    correct = (preds == y_true).astype(float)

    if strategy == 'equal_freq':
        bins = np.quantile(confidences, np.linspace(0.0, 1.0, n_bins + 1))
        bins[0], bins[-1] = 0.0, 1.0  # ensure full coverage
    else:  # 'equal_width'
        bins = np.linspace(0.0, 1.0, n_bins + 1)

    n = len(confidences)
    centers = (bins[:-1] + bins[1:]) / 2.0
    # Per-bin widths: equal-frequency bins are not uniform, so each bar must
    # span exactly its own [lo, hi] interval.
    widths = np.diff(bins)

    bin_accs, bin_confs, bin_counts = [], [], []
    ece = 0.0

    for i in range(n_bins):
        lo, hi = bins[i], bins[i + 1]
        # Bins are (lo, hi], except the first, which also includes its lower edge.
        if i == 0:
            mask = (confidences >= lo) & (confidences <= hi)
        else:
            mask = (confidences > lo) & (confidences <= hi)
        count = mask.sum()
        if count == 0:
            # Empty bins contribute nothing to the ECE; plot them as zero-height
            # bars with the confidence marker at the bin center.
            bin_accs.append(0.0)
            bin_confs.append(centers[i])
            bin_counts.append(0)
            continue
        acc_i = correct[mask].mean()
        conf_i = confidences[mask].mean()
        ece += abs(acc_i - conf_i) * (count / max(1, n))
        bin_accs.append(acc_i)
        bin_confs.append(conf_i)
        bin_counts.append(count)

    fig, ax1 = plt.subplots()
    ax1.plot([0, 1], [0, 1], linestyle='--')
    ax1.bar(centers, bin_accs, width=widths, alpha=0.6, edgecolor='k')
    ax1.plot(centers, bin_confs, marker='o')
    ax1.set_xlabel('Confidence')
    ax1.set_ylabel('Accuracy')

    # Show fraction of samples as a second y-axis (simple histogram)
    ax2 = ax1.twinx()
    ax2.bar(centers, np.array(bin_counts) / max(1, n), width=widths, alpha=0.3)
    ax2.set_ylabel('Fraction of samples')

    title = f"{title_prefix} Reliability (ECE={ece:.3f}, {strategy})"
    ax1.set_title(title)
    plt.tight_layout()
    plt.show()
    return float(ece)

# Example: ResNet18-head calibration snapshot, once per binning strategy
resnet_calibration_kwargs = dict(n_bins=15, title_prefix='ResNet18-head')

ece_resnet_width = reliability_diagram_with_ece(
    y_test_resnet, p_test_resnet, strategy='equal_width', **resnet_calibration_kwargs
)
ece_resnet_freq = reliability_diagram_with_ece(
    y_test_resnet, p_test_resnet, strategy='equal_freq', **resnet_calibration_kwargs
)

print('ECE (equal-width bins):', ece_resnet_width)
print('ECE (equal-frequency bins):', ece_resnet_freq)

In [None]:
# Optional: Temperature Scaling (TS) for ResNet18-head
# This is not required for Week 2, but many PneumoniaMNIST runs
# see a noticeable ECE drop after fitting a single temperature on the validation split.

import torch.nn as nn

def collect_logits_and_labels(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and gather logits and labels.

    The model is switched to eval mode and gradients are disabled.

    Args:
        model: Module called on each input batch.
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(logits, labels)``: the concatenated logits on the CPU, and the
        concatenated labels with size-one dimensions squeezed away.
    """
    model.eval()
    logit_chunks = []
    label_chunks = []
    with torch.no_grad():
        for inputs, targets in loader:
            logit_chunks.append(model(inputs.to(device)).cpu())
            label_chunks.append(targets)
    return torch.cat(logit_chunks, dim=0), torch.cat(label_chunks, dim=0).squeeze()

# Rebuild the loaders (the train loader and class count are discarded here) and
# collect the trained ResNet18's raw logits on the validation and test splits.
_, val_loader, test_loader, _ = get_loaders(DATASET_KEY, batch_size=BATCH_SIZE, label_frac=1.0, seed=42)
val_logits, val_y = collect_logits_and_labels(resnet_model, val_loader, device)
test_logits, test_y = collect_logits_and_labels(resnet_model, test_loader, device)

def tune_temperature(logits, labels, lr: float = 0.01, max_epochs: int = 100):
    """Fit one softmax temperature by minimizing cross-entropy on held-out logits.

    Args:
        logits: Tensor of shape ``(N, n_classes)`` with uncalibrated logits.
        labels: Integer class labels, one per sample; ``(N,)`` or ``(N, 1)``
            shapes and non-long integer dtypes are accepted.
        lr: Adam learning rate for the log-temperature.
        max_epochs: Number of full-batch optimizer steps.

    Returns:
        A detached one-element tensor holding ``T = exp(log_T)``, always positive.
    """
    # Detach so the optimization steps never backpropagate into whatever produced the logits.
    logits = torch.as_tensor(logits).detach()
    # Cross-entropy needs a flat long target vector on the same device as the logits.
    labels = torch.as_tensor(labels, device=logits.device).reshape(-1).long()

    # Optimize log_T to keep T positive
    log_T = torch.zeros(1, device=logits.device, requires_grad=True)
    optimizer = torch.optim.Adam([log_T], lr=lr)
    nll = nn.CrossEntropyLoss()
    # NOTE(review): Adam changes log_T by roughly `lr` per step, so the defaults
    # can move it only about lr * max_epochs = 1 away from 0. If the returned T
    # sits near exp(+/-1), raise max_epochs or lr and confirm convergence.
    for _ in range(max_epochs):
        optimizer.zero_grad()
        loss = nll(logits / torch.exp(log_T), labels)
        loss.backward()
        optimizer.step()
    return torch.exp(log_T).detach()

# Fit the temperature on validation logits only, then apply it to the test logits.
T_opt = tune_temperature(val_logits, val_y)
print('Optimal temperature T:', T_opt.item())

# Dividing the logits by a positive T rescales the softmax confidences without
# changing which class has the highest score.
p_test_TS = torch.softmax(test_logits / T_opt, dim=1).numpy()
ece_resnet_TS = reliability_diagram_with_ece(
    test_y.numpy(),
    p_test_TS,
    n_bins=15,
    strategy='equal_width',
    title_prefix='ResNet18-head + TS'
)
print('ECE after TS (equal-width bins):', ece_resnet_TS)

In [None]:
# Misclassified examples for ResNet18-head
import numpy as np
import matplotlib.pyplot as plt

def show_misclassified_gallery(images, y_true, y_prob, n_examples: int = 5):
    """Plot the misclassified samples the model was most confident about.

    Args:
        images: Indexable collection of images; tensors are converted with
            ``.numpy()``, anything else with ``np.array``.
        y_true: Integer labels, one per sample.
        y_prob: Per-class probabilities, one row per sample.
        n_examples: Maximum number of images to show in the single-row gallery.

    Returns:
        None. Prints a message and returns early when nothing is misclassified.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    preds = y_prob.argmax(axis=1)
    confidences = y_prob.max(axis=1)

    wrong = np.nonzero(preds != y_true)[0]
    if wrong.size == 0:
        print('No misclassified examples found.')
        return

    # Most confident mistakes first.
    by_confidence = np.argsort(-confidences[wrong])
    sel = wrong[by_confidence][:n_examples]

    # squeeze=False always returns a 2-D axes array; its only row is iterated below.
    fig, axes = plt.subplots(1, len(sel), figsize=(3 * len(sel), 3), squeeze=False)

    for ax, idx in zip(axes[0], sel):
        img = images[idx]
        img_np = img.numpy() if hasattr(img, 'numpy') else np.array(img)

        channel_first = img_np.ndim == 3 and img_np.shape[0] in (1, 3)
        if not channel_first:
            ax.imshow(img_np.squeeze(), cmap='gray')
        else:
            # Move the leading channel axis last, as imshow expects.
            hwc = np.transpose(img_np, (1, 2, 0))
            if hwc.shape[2] == 1:
                ax.imshow(hwc[:, :, 0], cmap='gray')
            else:
                ax.imshow(hwc)

        ax.axis('off')
        ax.set_title(f'true={int(y_true[idx])}, pred={int(preds[idx])}\nconf={confidences[idx]:.2f}')

    plt.tight_layout()
    plt.show()

# Show up to five of the ResNet18-head model's most confident test-set mistakes.
show_misclassified_gallery(imgs_test_resnet, y_test_resnet, p_test_resnet, n_examples=5)