# Install Dependencies

In [None]:
!pip install medmnist torch torchvision scikit-learn seaborn matplotlib

# Import Packages + Setup

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import (
    roc_auc_score, accuracy_score,
    precision_score, recall_score,
    f1_score, confusion_matrix,
    roc_curve
)

from sklearn.calibration import calibration_curve

from torchvision import models, transforms
from torch.utils.data import DataLoader
from medmnist import PneumoniaMNIST

from torch.cuda.amp import autocast, GradScaler

# Create the output directories up front ("models" holds the best checkpoint
# saved during training); existing directories are reused.
for _out_dir in ("models", "reports"):
    os.makedirs(_out_dir, exist_ok=True)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Later cells compare this string against "cuda" to enable mixed precision.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


# Dataset Preparation 

In [None]:
def get_pneumonia_data(batch_size=64):
    """Build the PneumoniaMNIST train/val/test DataLoaders.

    Each split is fetched with ``download=True``, so the first call may hit
    the network.

    Args:
        batch_size: Samples per batch, shared by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles and applies random augmentation; val/test are deterministic.
    """
    # Maps ToTensor's [0, 1] range to [-1, 1] for the single channel.
    normalize = transforms.Normalize(mean=[0.5], std=[0.5])

    # Random augmentation runs on the image as yielded by the dataset, i.e.
    # BEFORE ToTensor/Normalize (assumes the dataset yields PIL images, as the
    # MedMNIST 2D datasets do). Running RandomRotation after Normalize would
    # make its default fill of 0 correspond to pixel value 0.5 (mid-grey)
    # instead of 0 (black).
    # NOTE(review): horizontal flips mirror anatomy left-to-right; confirm
    # this augmentation is intended for chest X-rays.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = PneumoniaMNIST(split='train', transform=train_transform, download=True)
    val_dataset = PneumoniaMNIST(split='val', transform=test_transform, download=True)
    test_dataset = PneumoniaMNIST(split='test', transform=test_transform, download=True)

    return (
        DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        DataLoader(val_dataset, batch_size=batch_size),
        DataLoader(test_dataset, batch_size=batch_size),
    )


# Model Design

In [None]:
def build_model(in_channels=1, num_outputs=1):
    """Create an untrained ResNet-18 adapted to small grayscale images.

    Args:
        in_channels: Number of input image channels (1 = grayscale, the
            original hard-coded value).
        num_outputs: Width of the final linear layer. The default of 1 yields
            a single logit, matching ``BCEWithLogitsLoss`` for binary labels.

    Returns:
        A ``torchvision`` ResNet-18 ``nn.Module`` with randomly initialised
        weights.
    """
    # weights=None is the current spelling of the deprecated pretrained=False.
    model = models.resnet18(weights=None)
    # Replace the stem with a 3x3, stride-1 conv and drop the initial max-pool
    # so that small inputs keep more spatial resolution in early layers.
    model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(model.fc.in_features, num_outputs)
    return model


# Training with Early Stopping + Mixed Precision

In [None]:
BATCH_SIZE = 64
EPOCHS = 25
LR = 1e-3
PATIENCE = 5  # epochs without a val-AUC improvement before stopping

train_loader, val_loader, test_loader = get_pneumonia_data(BATCH_SIZE)

model = build_model().to(device)
# Single-logit output + BCEWithLogitsLoss: the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
# Cosine decay over the full epoch budget; stepped once per epoch below.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Mixed precision (autocast + gradient scaling) only on CUDA; with
# enabled=False both act as pass-throughs, so the same loop runs on CPU.
scaler = GradScaler(enabled=(device == "cuda"))

best_auc = 0
early_stop_counter = 0
train_losses = []

for epoch in range(EPOCHS):

    # ---- Training ----
    model.train()
    running_loss = 0.0

    for x, y in train_loader:
        # Flatten labels to 1-D so they line up with the (B,) logits below.
        x, y = x.to(device), y.float().view(-1).to(device)

        optimizer.zero_grad()

        with autocast(enabled=(device == "cuda")):
            # The fc layer has one output, so raw logits are (B, 1).
            # view(-1) always gives (B,); squeeze() would give a 0-d tensor
            # for a final batch holding a single sample.
            logits = model(x).view(-1)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight each batch mean by its size so a short last batch does not
        # skew the per-sample epoch average.
        running_loss += loss.item() * x.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_loss)

    # ---- Validation ----
    model.eval()
    all_labels, all_preds = [], []

    with torch.no_grad():
        for x, y in val_loader:
            logits = model(x.to(device)).view(-1)
            # .tolist() on 1-D tensors yields plain scalars; extending with a
            # 0-d numpy array (what squeeze() produced for a batch of one)
            # would raise TypeError.
            all_preds.extend(torch.sigmoid(logits).cpu().tolist())
            all_labels.extend(y.view(-1).tolist())

    val_auc = roc_auc_score(all_labels, all_preds)

    print(f"Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Val AUC: {val_auc:.4f}")

    scheduler.step()

    # ---- Checkpointing + early stopping on validation AUC ----
    if val_auc > best_auc:
        best_auc = val_auc
        torch.save(model.state_dict(), "models/best_model.pth")
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= PATIENCE:
            print("Early stopping triggered.")
            break


# Test Evaluation

In [None]:
# Restore the checkpoint with the best validation AUC saved by the training
# loop; map_location loads its tensors directly onto the active device.
model.load_state_dict(torch.load("models/best_model.pth", map_location=device))
model.eval()

all_preds, all_labels = [], []

with torch.no_grad():
    for x, y in test_loader:
        # view(-1) keeps a 1-D batch even when the last batch has a single
        # sample (squeeze() would yield a 0-d tensor that list.extend cannot
        # iterate); .tolist() stores plain scalars instead of (1,) arrays.
        logits = model(x.to(device)).view(-1)
        all_preds.extend(torch.sigmoid(logits).cpu().tolist())
        all_labels.extend(y.view(-1).tolist())

# Hard class predictions at a fixed 0.5 probability threshold.
pred_labels = (np.array(all_preds) >= 0.5).astype(int)

auc = roc_auc_score(all_labels, all_preds)
acc = accuracy_score(all_labels, pred_labels)
prec = precision_score(all_labels, pred_labels)
rec = recall_score(all_labels, pred_labels)
f1 = f1_score(all_labels, pred_labels)

print("\nFINAL TEST METRICS")
print(f"AUC: {auc:.4f}")
print(f"Accuracy: {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall (Sensitivity): {rec:.4f}")
print(f"F1 Score: {f1:.4f}")


# Confusion Matrix

In [None]:
# Rows = true class, columns = predicted class. labels=[0, 1] pins the 2x2
# layout (0 = Normal, 1 = Pneumonia, matching the tick labels) even if one
# class is absent from both arrays, so the tick labels always fit the matrix.
cm = confusion_matrix(all_labels, pred_labels, labels=[0, 1])

class_names = ["Normal", "Pneumonia"]

plt.figure(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt="d",
            xticklabels=class_names,
            yticklabels=class_names)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()


# ROC Curve

In [None]:
# Test-set ROC; the threshold array returned by roc_curve is not needed.
fpr, tpr, _ = roc_curve(all_labels, all_preds)

# Model curve, labelled with the test AUC computed in the evaluation cell.
curve_label = f"AUC = {auc:.2f}"
plt.plot(fpr, tpr, label=curve_label)

# Dashed diagonal: the reference line of a random classifier.
diagonal = [0, 1]
plt.plot(diagonal, diagonal, "--")

plt.title("ROC Curve")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.show()
