In [None]:
! pip install medmnist
! pip install libauc==1.2.0

In [1]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from medmnist import PneumoniaMNIST
from libauc.losses import APLoss, AUCMLoss
from libauc.optimizers import SOAP, PESG
from libauc.models import resnet18
import torch.optim as optim
import matplotlib.pyplot as plt

In [2]:
# Custom transformations.
# Training pipeline: pad by 4 px and randomly crop back to 28x28, which gives a
# random translation of up to 4 px. Without the padding, RandomCrop(28) on the
# default 28x28 MedMNIST images always returns the whole image, i.e. it is a
# no-op. Then apply a random horizontal flip and resize to 32x32.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
])

# Evaluation pipeline: deterministic, using only tensor conversion and the same
# 32x32 resize as training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(32),
])

In [3]:
# Load datasets: fetch the PneumoniaMNIST train/test splits (downloading them if
# missing) and batch them 64 at a time. Only the training loader is shuffled.
train_dataset, test_dataset = (
    PneumoniaMNIST(split=split_name, transform=split_transform, download=True)
    for split_name, split_transform in (('train', train_transform), ('test', test_transform))
)

train_loader, test_loader = (
    DataLoader(split_dataset, batch_size=64, shuffle=reshuffle)
    for split_dataset, reshuffle in ((train_dataset, True), (test_dataset, False))
)

Downloading https://zenodo.org/records/10519652/files/pneumoniamnist.npz?download=1 to /home/grads/s/skpaul/.medmnist/pneumoniamnist.npz


100%|██████████| 4.17M/4.17M [00:01<00:00, 3.92MB/s]

Using downloaded and verified file: /home/grads/s/skpaul/.medmnist/pneumoniamnist.npz





In [4]:
# Modify ResNet18 to work with grayscale images
def get_resnet18(num_classes=2, in_channels=1):
    """Build a LibAUC ResNet-18 whose first convolution accepts ``in_channels`` channels.

    Args:
        num_classes: output size of the network head, forwarded to
            ``libauc.models.resnet18``. The default of 2 keeps the original
            behaviour. NOTE(review): LibAUC's AUC losses (AUCMLoss / APLoss)
            are normally fed one score per sample, i.e. ``num_classes=1``.
            Confirm which head the chosen loss expects.
        in_channels: number of input image channels. Defaults to 1 (grayscale).

    Returns:
        The model, with ``conv1`` replaced by a freshly initialised
        ``Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)``.
    """
    model = resnet18(num_classes=num_classes)
    # Swap the stem so it takes `in_channels` channels instead of the stock input.
    model.conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    return model

In [None]:
# Function to train model and record AUPRC and AUROC
def train_model(model, loss_fn, optimizer, num_epochs=15, scheduler=None,
                train_loader=None, test_loader=None, evaluate_fn=None,
                decay_epoch=9, decay_factor=10):
    """Train ``model`` for ``num_epochs`` epochs and record AUPRC / AUROC after each epoch.

    Args:
        model: the network to optimise. Batches are moved to the device of its
            first parameter (CPU if it has none), so move the model first.
        loss_fn: called as ``loss_fn(outputs, labels)`` with the raw model
            outputs and 1-D labels. Must return a tensor supporting ``backward()``.
        optimizer: optimizer exposing ``zero_grad``, ``step`` and ``param_groups``.
        num_epochs: number of passes over ``train_loader``.
        scheduler: either an object with a ``step()`` method (for example a
            ``torch.optim.lr_scheduler`` instance), which is stepped once per
            epoch, or any other truthy value. A truthy non-scheduler divides
            every param group's ``lr`` by ``decay_factor`` once, after epoch
            index ``decay_epoch``. ``None`` or falsy leaves the LR unchanged.
        train_loader, test_loader: iterables of ``(inputs, labels, ...)``
            batches. They default to the notebook-level ``train_loader`` and
            ``test_loader`` globals.
        evaluate_fn: ``evaluate_fn(model, loader) -> (auprc, auroc)``.
            Defaults to ``evaluate_model``.
        decay_epoch, decay_factor: schedule for the manual LR drop described
            under ``scheduler``.

    Returns:
        ``(train_auprc, test_auprc, train_auroc, test_auroc)``: four lists
        with one entry per epoch.
    """
    # Fall back to the notebook globals only when nothing explicit is given.
    if train_loader is None:
        train_loader = globals()["train_loader"]
    if test_loader is None:
        test_loader = globals()["test_loader"]
    if evaluate_fn is None:
        evaluate_fn = evaluate_model

    # Put inputs where the weights live instead of relying on a global `device`
    # that has to be defined elsewhere before this cell runs.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    train_auprc, test_auprc, train_auroc, test_auroc = [], [], [], []

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            inputs = batch[0].to(device)
            # Flatten the labels to shape (B,). Unlike squeeze(), reshape(-1)
            # keeps a 1-sample batch 1-D instead of collapsing it to a 0-d scalar.
            labels = batch[1].reshape(-1).to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

        # LR schedule: step a real scheduler; otherwise apply the one-off manual decay.
        if scheduler is not None and hasattr(scheduler, "step"):
            scheduler.step()
        elif scheduler and epoch == decay_epoch:
            for param_group in optimizer.param_groups:
                param_group["lr"] /= decay_factor

        # Calculate metrics
        train_prc, train_roc = evaluate_fn(model, train_loader)
        test_prc, test_roc = evaluate_fn(model, test_loader)

        train_auprc.append(train_prc)
        test_auprc.append(test_prc)
        train_auroc.append(train_roc)
        test_auroc.append(test_roc)

    return train_auprc, test_auprc, train_auroc, test_auroc

# Evaluation function for AUPRC and AUROC
def evaluate_model(model, loader):
    """Compute AUPRC (average precision) and AUROC of ``model`` over ``loader``.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards. Each batch is read as
    ``(inputs, labels, ...)``. Labels are flattened to 1-D, and a label equal
    to 1 is the positive class.

    One ranking score per sample is taken from the model output:
      * a 1-D output or a single output column: that value;
      * two output columns: ``out[:, 1] - out[:, 0]``, which orders samples
        exactly like the softmax probability of class 1.

    Returns:
        ``(auprc, auroc)`` as Python floats. AUROC is ``nan`` when only one
        class is present, AUPRC is ``nan`` when there are no positives, and
        both are ``nan`` for an empty loader.

    Raises:
        ValueError: if the model emits more than two output columns.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    score_chunks, label_chunks = [], []
    try:
        with torch.no_grad():
            for batch in loader:
                outputs = model(batch[0].to(device))
                score_chunks.append(_positive_scores(outputs).detach().cpu())
                label_chunks.append(batch[1].reshape(-1).cpu())
    finally:
        model.train(was_training)

    if not score_chunks:
        return float("nan"), float("nan")
    y_score = torch.cat(score_chunks).double()
    y_true = torch.cat(label_chunks) == 1
    return _average_precision(y_score, y_true), _auroc(y_score, y_true)


def _positive_scores(outputs):
    """Reduce model outputs to one score per sample (higher means more likely positive)."""
    if outputs.dim() <= 1:
        return outputs.reshape(-1)
    flat = outputs.reshape(outputs.shape[0], -1)
    if flat.shape[1] == 1:
        return flat[:, 0]
    if flat.shape[1] == 2:
        # The logit difference is a monotone function of softmax P(class 1).
        return flat[:, 1] - flat[:, 0]
    raise ValueError(f"expected 1 or 2 output columns for binary AUC, got {flat.shape[1]}")


def _auroc(y_score, y_true):
    """AUROC via the Mann-Whitney U statistic; tied scores receive their average rank."""
    n_pos = int(y_true.sum())
    n_neg = y_true.numel() - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    _, inverse, counts = torch.unique(y_score, return_inverse=True, return_counts=True)
    ends = torch.cumsum(counts, dim=0).double()
    # Mean 1-based rank of each tie group, whose ranks run from end - count + 1 to end.
    avg_rank = ends - (counts.double() - 1) / 2
    rank_sum = avg_rank[inverse][y_true].sum().item()
    return (rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def _average_precision(y_score, y_true):
    """Average precision: sum over descending thresholds of (R_k - R_{k-1}) * P_k."""
    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")
    order = torch.argsort(y_score, descending=True)
    sorted_score = y_score[order]
    hits = y_true[order].double()
    tps = torch.cumsum(hits, dim=0)
    fps = torch.cumsum(1.0 - hits, dim=0)
    # Evaluate only at the last element of each run of tied scores, so tied
    # samples share a single threshold.
    last_of_run = torch.ones_like(sorted_score, dtype=torch.bool)
    last_of_run[:-1] = sorted_score[1:] != sorted_score[:-1]
    tps, fps = tps[last_of_run], fps[last_of_run]
    precision = tps / (tps + fps)
    recall = tps / n_pos
    prev_recall = torch.cat([recall.new_zeros(1), recall[:-1]])
    return float(((recall - prev_recall) * precision).sum())

# Function to plot curves
def plot_curves(train_auprc, test_auprc, train_auroc, test_auroc, model_name):
    """Plot train/test AUPRC and AUROC per epoch on one 10x5 figure and show it.

    Every sequence is plotted against epochs 1..len(train_auprc), so all four
    are expected to have that length. Returns None.
    """
    epoch_axis = range(1, len(train_auprc) + 1)
    _, ax = plt.subplots(figsize=(10, 5))
    series = (
        (train_auprc, 'Train AUPRC'),
        (test_auprc, 'Test AUPRC'),
        (train_auroc, 'Train AUROC'),
        (test_auroc, 'Test AUROC'),
    )
    for values, label in series:
        ax.plot(epoch_axis, values, label=label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Score')
    ax.set_title(f'AUPRC and AUROC for {model_name}')
    ax.legend()
    plt.show()