# pneumonia mnist resnet18 (colab ready)

simple end-to-end notebook for training a resnet18 on pneumonia mnist and explaining its predictions with shap.


## setup

* switch runtime to gpu in colab (runtime > change runtime type > gpu)
* mount drive so checkpoints and figures are saved


In [None]:
# attach google drive at /gdrive so checkpoints and figures outlive the colab runtime
from google.colab import drive

drive.mount(mountpoint='/gdrive')


In [None]:
# install needed libraries (torch + medmnist + shap)
# keep versions explicit so results are repeatable
# NOTE(review): pinning torch==2.1.2 replaces whatever torch build the colab image ships;
# confirm the pinned wheel matches the runtime's python/cuda versions before relying on it.
!pip install --quiet torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2
!pip install --quiet medmnist==2.2.2 shap==0.44.1 scikit-learn==1.3.2 matplotlib==3.8.2 seaborn==0.13.1


In [None]:
# imports
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve
import medmnist
from medmnist import PneumoniaMNIST
import shap

# make runs repeatable: seed every rng the notebook touches
SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# seeding alone does not make gpu runs repeatable: cudnn may autotune and pick
# nondeterministic convolution algorithms. pinning it reduces (though does not
# fully remove) run-to-run variation, at some cost in training speed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

# folder to save checkpoints and plots inside drive (mounted at /gdrive)
output_dir = '/gdrive/MyDrive/pneumonia_mnist_resnet18'
os.makedirs(output_dir, exist_ok=True)


## data

pneumonia mnist is already preprocessed. we just load it with medmnist and add a simple normalization.


In [None]:
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then shifts them to [-1, 1].
# a single mean/std entry keeps the image as one grayscale channel.
img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# the three medmnist splits, all sharing the same preprocessing
train_ds, val_ds, test_ds = (
    PneumoniaMNIST(split=split_name, download=True, transform=img_transform)
    for split_name in ('train', 'val', 'test')
)

batch_size = 128
# only the training split is shuffled; val/test keep a fixed order
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=batch_size, shuffle=False, num_workers=2)
    for split_ds in (val_ds, test_ds)
)

# number of distinct label values present in the training split
num_classes = np.unique(train_ds.labels).size
print('classes:', num_classes)


## model helper

we use torchvision resnet18 with imagenet weights. first conv is adjusted for 1 channel images. transfer learning happens by choosing which blocks to freeze.


In [None]:
def build_resnet18(freeze_parts=None, n_classes=None):
    """Build an imagenet-pretrained resnet18 adapted to 1-channel input.

    Args:
        freeze_parts: iterable of parameter-name prefixes (e.g. 'conv1',
            'layer2'). Every parameter whose name starts with one of them gets
            requires_grad=False. None or empty freezes nothing.
        n_classes: output size of the new classification head. Defaults to the
            notebook-level ``num_classes`` computed from the training labels.

    Returns:
        the model (not moved to any device).
    """
    if n_classes is None:
        n_classes = num_classes

    # load imagenet pretrained weights
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # replace the 3-channel stem conv with a 1-channel conv of identical geometry
    rgb_conv = model.conv1
    model.conv1 = nn.Conv2d(
        1,
        rgb_conv.out_channels,
        kernel_size=rgb_conv.kernel_size,
        stride=rgb_conv.stride,
        padding=rgb_conv.padding,
        bias=False,
    )
    # init from the pretrained rgb filters averaged over the input-channel axis,
    # so the pretrained spatial patterns carry over to grayscale input
    with torch.no_grad():
        model.conv1.weight.copy_(rgb_conv.weight.mean(dim=1, keepdim=True))

    # swap the imagenet head for one sized to this dataset
    model.fc = nn.Linear(model.fc.in_features, n_classes)

    # freeze selected parts; str.startswith accepts a tuple of prefixes
    if freeze_parts:
        prefixes = tuple(freeze_parts)
        for name, param in model.named_parameters():
            if name.startswith(prefixes):
                param.requires_grad = False
    # NOTE(review): freezing only stops gradient updates. batchnorm layers inside
    # frozen blocks still update their running mean/var whenever the model is in
    # train() mode; switch them to eval() during training if fully frozen
    # statistics are wanted.
    return model

# transfer-learning variants: each entry lists the resnet18 parameter-name
# prefixes to freeze, from "train only layer4 + head" up to full fine-tuning
freeze_configs = [
    dict(label='freeze_to_layer3', freeze_parts=['conv1', 'bn1', 'layer1', 'layer2', 'layer3']),
    dict(label='freeze_to_layer2', freeze_parts=['conv1', 'bn1', 'layer1', 'layer2']),
    dict(label='full_finetune', freeze_parts=[]),
]


## training utils

simple training + validation loops that save the best checkpoint (by validation accuracy) for each config.


In [None]:
def run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``; train if an optimizer is given, else evaluate.

    Args:
        model: classifier returning logits of shape (batch, classes).
        loader: yields (images, labels) batches; labels are flattened to 1-D
            int64 before the loss.
        criterion: loss taking (logits, int64 targets), e.g. CrossEntropyLoss.
        optimizer: if given, the model is put in train mode and updated each
            batch; if None, the model runs in eval mode with autograd disabled.

    Returns:
        (mean per-sample loss, accuracy) over the whole dataset.
    """
    is_train = optimizer is not None
    model.train(is_train)
    running_loss = 0.0
    preds, targets = [], []

    # evaluation passes need no autograd graph; skipping it saves memory and time
    with torch.set_grad_enabled(is_train):
        for images, labels in loader:
            images = images.to(device)
            # reshape(-1) instead of squeeze(): squeeze() would collapse a final
            # batch of size 1 to a 0-d tensor and break the loss / extend below
            labels = labels.reshape(-1).long().to(device)

            if is_train:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            # criterion averages over the batch; re-weight so the epoch mean is per-sample
            running_loss += loss.item() * images.size(0)

            if is_train:
                loss.backward()
                optimizer.step()

            pred_labels = torch.argmax(outputs, dim=1)
            preds.extend(pred_labels.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    avg_loss = running_loss / len(loader.dataset)
    acc = accuracy_score(targets, preds)
    return avg_loss, acc


def train_one_setting(label, freeze_parts, epochs=8, lr=1e-3):
    """Train one freeze configuration and keep its best-validation checkpoint.

    Args:
        label: config name; also used in the checkpoint file name.
        freeze_parts: parameter-name prefixes passed to build_resnet18.
        epochs: number of train/validate epochs.
        lr: adam learning rate.

    Returns:
        (path to the best checkpoint, best validation accuracy). When
        epochs >= 1 the checkpoint file is always written at least once.
    """
    # '\n' escape: a literal line break inside the quotes is a syntax error
    print(f'\n>>> training config: {label}')
    model = build_resnet18(freeze_parts).to(device)

    # only hand trainable (non-frozen) params to the optimizer
    optim_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(optim_params, lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    # start below any reachable accuracy so the first epoch always saves a
    # checkpoint; starting at 0.0 could leave no file for the caller to load
    best_val_acc = float('-inf')
    best_path = os.path.join(output_dir, f'{label}_best.pt')

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion)
        print(f'epoch {epoch}: train_loss {train_loss:.4f} acc {train_acc:.4f} | val_loss {val_loss:.4f} acc {val_acc:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_path)
            print('saved new best model')
    return best_path, best_val_acc


## run transfer learning experiments

train with different freeze depths to see what works best. adjust epochs upward if the gpu session is stable.


In [None]:
# train every freeze config and record its best validation accuracy + checkpoint
results = []
for cfg in freeze_configs:
    ckpt, val_acc = train_one_setting(cfg['label'], cfg['freeze_parts'], epochs=10, lr=1e-3)
    results.append({'label': cfg['label'], 'val_acc': val_acc, 'ckpt': ckpt})

# rank configs by validation accuracy (best first); the top one goes to the test set
results.sort(key=lambda r: r['val_acc'], reverse=True)
best = results[0]
# '\n' escape: a literal line break inside the quotes is a syntax error
print('\nbest config:', best)


## evaluation on test set

load the best checkpoint, evaluate, and draw confusion matrix + roc curve.


In [None]:
# rebuild the architecture and load the best checkpoint's weights; freeze flags
# only matter for training, so no freeze_parts are needed here
best_model = build_resnet18()
# the checkpoint is a plain state_dict of tensors, so weights_only=True loads it
# without unpickling arbitrary python objects
state_dict = torch.load(best['ckpt'], map_location=device, weights_only=True)
best_model.load_state_dict(state_dict)
best_model.to(device)
best_model.eval()

# collect logits and labels for the whole test split
all_logits, all_targets = [], []
with torch.no_grad():
    for images, labels in test_loader:
        logits = best_model(images.to(device))
        all_logits.append(logits.cpu())
        # reshape(-1) keeps a size-1 final batch 1-D; squeeze() would make it
        # 0-d and torch.cat below would reject it
        all_targets.append(labels.reshape(-1).long())

logits = torch.cat(all_logits)
# hand sklearn plain numpy arrays
targets = torch.cat(all_targets).numpy()
# softmax score of class 1 is used as the positive-class score (binary task assumed)
probs = torch.softmax(logits, dim=1)[:, 1].numpy()
pred_labels = torch.argmax(logits, dim=1).numpy()

acc = accuracy_score(targets, pred_labels)
cm = confusion_matrix(targets, pred_labels)

fpr, tpr, _ = roc_curve(targets, probs)
roc_auc = roc_auc_score(targets, probs)

print('test accuracy:', acc)
print('roc auc:', roc_auc)

# plot confusion matrix (rows = true class, columns = predicted class)
plt.figure(figsize=(5,4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('predicted')
plt.ylabel('true')
plt.title('confusion matrix')
cm_path = os.path.join(output_dir, 'confusion_matrix.png')
plt.savefig(cm_path, dpi=150, bbox_inches='tight')
plt.show()

# plot roc curve against the chance diagonal
plt.figure(figsize=(5,4))
plt.plot(fpr, tpr, label=f'auc={roc_auc:.3f}')
plt.plot([0,1],[0,1],'k--')
plt.xlabel('false positive rate')
plt.ylabel('true positive rate')
plt.title('roc curve')
plt.legend()
roc_path = os.path.join(output_dir, 'roc_curve.png')
plt.savefig(roc_path, dpi=150, bbox_inches='tight')
plt.show()


## shap gradient explainer

use a small background batch (50 training images) as the reference distribution for the gradient explainer. we look at one correct and one incorrect test prediction (from different true classes when possible) to see which pixels drive each decision.


In [None]:
# background: first 50 images of one shuffled training batch, used by the
# gradient explainer as its reference distribution
background_images, _ = next(iter(train_loader))
background_images = background_images[:50].to(device)

# prepare explainer
best_model.eval()
grad_explainer = shap.GradientExplainer(best_model, background_images)


def _to_channels_last(values):
    """Convert NCHW array(s) to the NHWC layout that shap.image_plot expects."""
    if isinstance(values, list):
        # shap 0.44 returns one NCHW array per model output
        return [np.transpose(v, (0, 2, 3, 1)) for v in values]
    values = np.asarray(values)
    if values.ndim == 5:
        # NOTE(review): newer shap releases reportedly stack outputs on a trailing
        # axis, (N, C, H, W, outputs); split them so image_plot gets one per output
        return [np.transpose(values[..., k], (0, 2, 3, 1)) for k in range(values.shape[-1])]
    return np.transpose(values, (0, 2, 3, 1))


# scan the test set once, keeping the first misclassified sample and the first
# correctly classified sample of each true class, so the correct example can
# come from a different class than the incorrect one when possible
incorrect_example = None
first_correct_by_class = {}
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.reshape(-1).long()
        preds = torch.argmax(best_model(images), dim=1).cpu()
        for img, label, pred in zip(images, labels.tolist(), preds.tolist()):
            sample = (img.unsqueeze(0), label, pred)
            if pred == label:
                first_correct_by_class.setdefault(label, sample)
            elif incorrect_example is None:
                incorrect_example = sample
        if incorrect_example is not None and any(
            cls != incorrect_example[1] for cls in first_correct_by_class
        ):
            break

correct_candidates = list(first_correct_by_class.values())
if incorrect_example is not None:
    # stable sort: correct samples whose true class differs from the incorrect one first
    correct_candidates.sort(key=lambda s: s[1] == incorrect_example[1])
correct_example = correct_candidates[0] if correct_candidates else None

examples = {'correct_case': correct_example, 'incorrect_case': incorrect_example}

for tag, sample in examples.items():
    if sample is None:
        print(f'skipping {tag} because it was not found')
        continue
    img, true_label, pred_label = sample
    # computed outside no_grad: the explainer needs gradients w.r.t. the input
    shap_values = grad_explainer.shap_values(img)
    # torch tensors are NCHW; shap.image_plot wants NHWC for both the
    # attributions and the pixel values, otherwise the plot is garbled
    shap.image_plot(
        _to_channels_last(shap_values),
        _to_channels_last(img.cpu().numpy()),
        show=False,
    )
    # image_plot draws a grid of subplots, so title the figure rather than one axis
    plt.gcf().suptitle(f'{tag} true={true_label} pred={pred_label}')
    fig_path = os.path.join(output_dir, f'{tag}_shap.png')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.show()


## notes for the report

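* keep an eye on whether freezing more layers hurts validation accuracy. layer4 and the fc head are trainable in every config, so the comparison is really about how many earlier blocks (layer3, layer2, then layer1 and the stem) are also fine-tuned.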
* confusion matrix shows where false positives/negatives happen.
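* shap maps highlight the pixels that push each decision toward a class; attribution concentrated on image borders or outside the lungs may mean the model relies on spurious cues rather than lung findings.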
* store your favorite checkpoint and figures in drive for later write-up.
