In [1]:
#Install medmnist package for first run
#pip install medmnist

### Load the libraries

In [2]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms

import numpy as np
from enum import unique
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.metrics import auc, roc_curve
import matplotlib.pyplot as plt

import medmnist
from medmnist import Evaluator
from medmnist import INFO, BreastMNIST, ChestMNIST, PneumoniaMNIST
from medmnist.evaluator import Evaluator


Performs EDA

In [3]:

#list of datasets to be used for training and evaluation
dataset_names = ['breastmnist', 'pneumoniamnist', 'chestmnist']

for dataset_name in dataset_names:
    #Performs basic EDA on MedMNIST datasets
    print(f'----- EDA for {dataset_name.capitalize()} -----')
    info = INFO[dataset_name]
    task = info['task']
    sample = info['n_samples']
    n_channels = info['n_channels']
    dataset_train = sample['train']
    dataset_val = sample['val']
    dataset_test = sample['test']
    DataClass = getattr(medmnist, info['python_class'])

    temp_dataset = DataClass(split = 'train', transform = transforms.ToTensor(), download = True)
    train_labels = temp_dataset.labels


    # print(f'{info}')
    print(f'Task Type: {task}')
    print(f'Number of Channels: {n_channels}')
    print(f'Training Samples: {dataset_train}')
    print(f'Validation Samples: {dataset_val}')
    print(f'Testing Samples: {dataset_test}')
    print(f'Label Info: {info["label"]}\n\n')

    if len(info["label"]) > 2:
        labels_counts = np.sum(train_labels, axis = 0)

        #Use color map
        cmap = plt.cm.get_cmap('tab20')
        colors = cmap(np.arange(len(info["label"])))

        labels = [info['label'][str(i)] for i in range(len(info["label"]))]

        #Plot bar plot for multi-class
        plt.bar(labels, labels_counts, color = colors)
        plt.title(f"{dataset_name.capitalize()} - Label Distribution")
        plt.xlabel("Labels")
        plt.ylabel("Number of Positive Samples")
        #Vertical label
        plt.xticks(rotation = 45, ha = 'right')
    else:
        unique_labels, counts = np.unique(train_labels, return_counts = True)
        labels = [info['label'][str(i)] for i in range(len(unique_labels))]

        #plot bar plot
        colors = ['orange', 'blue']
        plt.bar(labels, counts, color = colors)
        plt.title(f"{dataset_name.capitalize()} - Label Distribution")
        plt.xlabel("Class Labels")
        plt.ylabel("Number of Samples")
    plt.grid(True)
    plt.savefig(f'figures/{dataset_name} bar.png')
    plt.close()

----- EDA for Breastmnist -----
Task Type: binary-class
Number of Channels: 1
Training Samples: 546
Validation Samples: 78
Testing Samples: 156
Label Info: {'0': 'malignant', '1': 'normal, benign'}


----- EDA for Pneumoniamnist -----
Task Type: binary-class
Number of Channels: 1
Training Samples: 4708
Validation Samples: 524
Testing Samples: 624
Label Info: {'0': 'normal', '1': 'pneumonia'}


----- EDA for Chestmnist -----
Task Type: multi-label, binary-class
Number of Channels: 1
Training Samples: 78468
Validation Samples: 11219
Testing Samples: 22433
Label Info: {'0': 'atelectasis', '1': 'cardiomegaly', '2': 'effusion', '3': 'infiltration', '4': 'mass', '5': 'nodule', '6': 'pneumonia', '7': 'pneumothorax', '8': 'consolidation', '9': 'edema', '10': 'emphysema', '11': 'fibrosis', '12': 'pleural', '13': 'hernia'}




  cmap = plt.cm.get_cmap('tab20')


### Valitaion function

In [4]:
### Valiation
def calculate_validation_metrics(model, val_load, device, dataset_name):
    model.eval()
    y_true = torch.tensor([])
    y_score = torch.tensor([])

    task = INFO[dataset_name]['task']

    with torch.no_grad():
        for inputs, targets in val_load:
            outputs = model(inputs.to(device))

            if task == 'binary-class':
                scores = F.softmax(outputs, dim = 1)
            else:
                scores = torch.sigmoid(outputs)

            y_true = torch.cat((y_true, targets.cpu()), 0)
            y_score = torch.cat((y_score, scores.cpu()), 0)

    evaluator = Evaluator(dataset_name, split = 'val')
    scores_numpy = y_score.numpy()

    if task == 'binary-class':
        scores_numpy = scores_numpy[:, 1]
    val_auc, _ = evaluator.evaluate(scores_numpy)
    return val_auc

Visualization Function

In [5]:
def plot_metrics_for_binary_task(y_true, y_score, dataset_name):
    plt.figure(figsize = (12, 8))
    # ----- ROC Curve -----
    fpr, tpr, threshold = roc_curve(y_true, y_score)
    calculated_roc_auc = auc(fpr, tpr)

    plt.plot(fpr, tpr, label = f'ROC curve (area = {calculated_roc_auc:.4f})')
    plt.plot([0, 1], [0,1], linestyle = '--', label = 'Random Classifier')
    plt.title(f'{dataset_name.capitalize()} ROC Curve')
    plt.xlabel('False Positive Rate (FPR)')
    plt.ylabel('True Positive Rate (TPR)')
    plt.legend()
    plt.savefig(f'figures/{dataset_name} ROC.png')
    plt.close()
    # ----- Score Distribution -----
    score_class_0 = y_score[y_true == 0]
    score_class_1 = y_score[y_true == 1]

    plt.hist(score_class_0, bins = 20, alpha = 0.5, density = True, label = 'True Negative Scores (Class 0)')
    plt.hist(score_class_1, bins = 20, alpha = 0.5, density = True, label = 'True Positive Scores (Class 1)')
    plt.xlabel('Predicted Probability (Score)')
    plt.ylabel('Frequency')
    plt.title(f'{dataset_name.capitalize()} Score Distribution')
    plt.legend()
    plt.savefig(f'figures/{dataset_name} score.png')
    plt.close()

def plot_score_distribution(y_true, y_score, dataset_name):
    plt.figure(figsize = (12, 8))

    scores_class_0 = y_true[y_true == 0]
    scores_class_1 = y_score[y_true == 1]

    plt.hist(scores_class_0, bins = 20, alpha = 0.6, label = 'True Negative Scores (Class 0)')
    plt.hist(scores_class_1, bins = 20, alpha = 0.6, label = 'True Positive Scores (Class 1)')

    plt.xlabel('Predicted Probability (Score)')
    plt.ylabel('Normalize Frequency (Density)')
    plt.title(f'{dataset_name.capitalize()} Score Distribution')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'figures/{dataset_name} score.png')
    plt.close()

### Training and Evaluating model

In [6]:
#list of datasets to be used for training and evaluation
dataset_names = ['breastmnist', 'pneumoniamnist', 'chestmnist']

for dataset_name in dataset_names:
    print(f"----- Training and Evaluation for {dataset_name.capitalize()} -----")
    ### Define and train model
    info = INFO[dataset_name]
    task = info['task'] #identifies
    n_channels = info['n_channels'] #Number of input channels
    n_classes = len(info['label']) #Number of output channels

    #dynamic loss function
    if task == 'multi-label, binary-class':
        crit = nn.BCEWithLogitsLoss()
    else:
        crit = nn.CrossEntropyLoss()

    #Calculate Dataset Statistic for Normalization
    DataClass = getattr(medmnist, info['python_class'])
    #Load dataset with ToTensor() to calculate mean and std deviation
    temp_dataset = DataClass(split = 'train', transform = transforms.ToTensor(), download = True)
    temp_loader = data.DataLoader(temp_dataset, batch_size = len(temp_dataset))
    images, _ = next(iter(temp_loader))

    dataset_mean = images.mean(dim = [0, 2, 3]).tolist()
    dataset_std = images.std(dim = [0, 2, 3]).tolist()

    print(f"Stats for {dataset_name}: Mean = {dataset_mean[0]:.4f}, Standard Deviation = {dataset_std[0]:.4f}")

    ### pre-processing
    transformed = transforms.Compose([
        transforms.ToTensor(),
        #Normalize data using the calculated mean and std for better model convergence
        transforms.Normalize(mean =dataset_mean, std = dataset_std)
    ])
    #Load actual data
    train_dataset = DataClass(split = 'train', transform = transformed, download = True)
    val_dataset = DataClass(split = 'val', transform = transformed, download = True)
    test_dataset = DataClass(split = 'test', transform = transformed, download = True)

    sizes = []
    for img, label in train_dataset:
        sizes.append(img.numpy().flatten())

    all_sizes = np.concatenate(sizes, axis = 0)
    bmean = all_sizes.mean()
    bstd = all_sizes.std()

    #Define dataloader obj
    BATCH_SIZE = 128
    train_load = data.DataLoader(dataset = train_dataset, batch_size = BATCH_SIZE, shuffle = True)
    val_load = data.DataLoader(dataset = val_dataset, batch_size = BATCH_SIZE, shuffle = False)
    test_load = data.DataLoader(dataset = test_dataset, batch_size = BATCH_SIZE, shuffle = False)

    #Model initialization (ResNet-18)
    model = models.resnet18(weights = None)
    #ajdust the input layer if it's a grayscale img
    if n_channels == 1:
        model.conv1 = nn.Conv2d(1, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)

    #adjust the final fully connected larer for the number of classes
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, n_classes)

    #move to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    #define loss, optim, training loop
    optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
    NUM_EPOCHS = 30

    ### Training data
    best_val_auc = float('-inf')
    # Initialize lists to store per-epoch data
    training_loss = [] # This will now store average loss per epoch
    val_auc_list = []  # This already stores AUC per epoch

    for epoch in range(NUM_EPOCHS):
        model.train()
        current_epoch_total_loss = 0.0
        num_samples_in_epoch = 0
        for inputs, targets in train_load:
            optimizer.zero_grad()
            inputs = inputs.to(device)

            #prepare targets based on task
            if task == 'multi-label, binary-class':
                targets = targets.to(device).float() #BCEWithLogitsLoss require float target
            else:
                targets = targets.to(device).long().squeeze() #CrossEntropyLoss require long target

            outputs = model(inputs)
            loss = crit(outputs, targets)
            loss.backward()
            optimizer.step()

            current_epoch_total_loss += loss.item() * inputs.size(0)
            num_samples_in_epoch += inputs.size(0)

        #calculate average loss for the current epoch
        if num_samples_in_epoch > 0:
            avg_loss_this_epoch = current_epoch_total_loss / num_samples_in_epoch
            training_loss.append(avg_loss_this_epoch) # Store average loss per epoch
        else:
            training_loss.append(0.0) # Handle empty train_load

        #validation check
        val_auc = calculate_validation_metrics(model, val_load, device, dataset_name)
        val_auc_list.append(val_auc) # Store validation AUC per epoch

        #save best model based on validation AUC
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            #--------------- Save checkpoint ---------------
            save_ckpt = f'checkpoints/{dataset_name}_ckpt.pth'
            torch.save(model.state_dict(), save_ckpt)
            print(f"Model Best Val AUC: {val_auc:.4f} saved to: {save_ckpt}")
        model.train()

    #evaluation
    model.eval()
    y_true = torch.tensor([])
    y_score = torch.tensor([])

    with torch.no_grad():
        for inputs, targets in tqdm(test_load):
            inputs = inputs.to(device)
            targets = targets.long().squeeze().to(device)

            #get output
            outputs = model(inputs)

            #multi class
            # scores = F.softmax(outputs, dim = 1)
            scores = torch.sigmoid(outputs)

            #collect true label
            y_true = torch.cat((y_true, targets.cpu()),0)
            y_score = torch.cat((y_score, scores.cpu()),0)

    evaluator = Evaluator(dataset_name, split = 'test')
    auc_score, acc_score = evaluator.evaluate(y_score.numpy(), save_folder = None, run = False)
    print(f"\nTest AUC: {auc_score:.4f}, Test ACC: {acc_score:.4f}")

    ### Load the saved checkpoint for another test
    #initialize the model architecture (ResNet18)
    model_test = models.resnet18(weights = None)

    #Adjust first layer for single-channel input
    model_test.conv1 = nn.Conv2d(n_channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)

    #Adjust the final fully connected layer for 2 classes
    num_ftrs_test = model_test.fc.in_features
    model_test.fc = nn.Linear(num_ftrs, n_classes)

    #Move to the appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_test.to(device)
    print("model_test architecture created successfully and ready to load weights.")

    #load weights from the best-performing epoch found during validation
    model_test.load_state_dict(torch.load(save_ckpt, map_location = device))

    y_true_list = []
    y_score_list = []
    model_test.eval()

    #evaluate the best model on test set
    with torch.no_grad():
        for inputs, targets in tqdm(test_load, desc = "Testing Loaded Checkpoint"):
            inputs = inputs.to(device)
            targets = targets.long().squeeze().to(device)

            #use the loaded model_test instance
            outputs = model_test(inputs)

            #get probabilities
            scores = F.softmax(outputs, dim = 1)

            if dataset_name in ['breastmnist', 'pneumoniamnist']:
              y_score_list.append(scores.cpu()[:, 1])
            elif dataset_name == 'chestmnist':
              scores = torch.sigmoid(outputs)
              y_score_list.append(scores.cpu())

            #collect true label
            y_true_list.append(targets.cpu())

    y_true = torch.cat(y_true_list, 0)
    y_score = torch.cat(y_score_list, 0)

    y_true_np = y_true.numpy()
    y_score_np = y_score.numpy()

    evaluator = Evaluator(dataset_name, split = 'test')

    #metrics
    auc_score, acc_score = evaluator.evaluate(y_score.numpy())

    print(f"Final test AUC: {auc_score:.4f}")
    print(f"Final test ACC: {acc_score:.4f}\n\n")

    #plot final test result (ROC)
    if task == 'binary-class':
        plot_metrics_for_binary_task(y_true_np, y_score_np, dataset_name)
    else:
        plot_score_distribution(y_true_np, y_score_np, dataset_name)

----- Training and Evaluation for Breastmnist -----
Stats for breastmnist: Mean = 0.3276, Standard Deviation = 0.2057
Model Best Val AUC: 0.6483 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.6775 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.7151 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.7678 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.8329 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.8346 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.8630 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.8906 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.8914 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.9114 saved to: checkpoints/breastmnist_ckpt.pth
Model Best Val AUC: 0.9373 saved to: checkpoints/breastmnist_ckpt.pth


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 26.84it/s]



Test AUC: 0.8340, Test ACC: 0.8269
model_test architecture created successfully and ready to load weights.


Testing Loaded Checkpoint: 100%|█████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 24.29it/s]


Final test AUC: 0.8552
Final test ACC: 0.8077


----- Training and Evaluation for Pneumoniamnist -----
Stats for pneumoniamnist: Mean = 0.5719, Standard Deviation = 0.1684
Model Best Val AUC: 0.9791 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9883 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9893 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9920 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9926 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9948 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9951 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9961 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9964 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9966 saved to: checkpoints/pneumoniamnist_ckpt.pth
Model Best Val AUC: 0.9969 saved to: checkpoints/pneumoniamnist_ckpt.pth


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 17.01it/s]



Test AUC: 0.9575, Test ACC: 0.8526
model_test architecture created successfully and ready to load weights.


Testing Loaded Checkpoint: 100%|█████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 14.77it/s]


Final test AUC: 0.9469
Final test ACC: 0.8830


----- Training and Evaluation for Chestmnist -----
Stats for chestmnist: Mean = 0.4936, Standard Deviation = 0.2380
Model Best Val AUC: 0.6733 saved to: checkpoints/chestmnist_ckpt.pth
Model Best Val AUC: 0.6952 saved to: checkpoints/chestmnist_ckpt.pth
Model Best Val AUC: 0.7029 saved to: checkpoints/chestmnist_ckpt.pth
Model Best Val AUC: 0.7083 saved to: checkpoints/chestmnist_ckpt.pth
Model Best Val AUC: 0.7107 saved to: checkpoints/chestmnist_ckpt.pth
Model Best Val AUC: 0.7172 saved to: checkpoints/chestmnist_ckpt.pth


100%|████████████████████████████████████████████████████████████████████████████████| 176/176 [00:11<00:00, 14.84it/s]



Test AUC: 0.6498, Test ACC: 0.9339
model_test architecture created successfully and ready to load weights.


Testing Loaded Checkpoint: 100%|█████████████████████████████████████████████████████| 176/176 [00:10<00:00, 16.03it/s]


Final test AUC: 0.7054
Final test ACC: 0.9470


