In [1]:
!pip install libauc medmnist tensorboardX pytorch_lightning --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.6/73.6 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m719.0/719.0 kB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.3/88.3 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m54.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.6/149.6 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m29.8 MB/s[

In [2]:
# Import required libraries
from tqdm import tqdm
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from libauc.losses import AUCMLoss
from libauc.losses import AUCM_MultiLabel
from libauc.optimizers import PESG
from libauc.losses import APLoss
from libauc.losses.auc import pAUCLoss
from libauc.models import resnet18 as ResNet18
from libauc.optimizers import SOAP, SOPA
import torchvision.transforms as transforms
from libauc.sampler import DualSampler
from torch.utils.data import Dataset
import medmnist
from medmnist import INFO, Evaluator
from torch.utils.data import random_split
from libauc.metrics import auc_roc_score
from sklearn.metrics import accuracy_score, roc_auc_score
from libauc.metrics import auc_prc_score
from libauc.metrics import pauc_roc_score
# from libauc.utils import auroc
import torchvision
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
import logging
# Quiet the (PyTorch) Lightning loggers.
# Level 0 is logging.NOTSET: the 'lightning' logger defers to its parent's level.
logging.getLogger('lightning').setLevel(0)
for _logger_name in ('pytorch_lightning',
                     'lightning.pytorch.utilities.rank_zero',
                     'lightning.pytorch.accelerators.cuda'):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
del _logger_name
# Globally drop every record at CRITICAL and below (all standard levels).
# While this is in effect, the per-logger levels above have no visible effect.
logging.disable(logging.CRITICAL)

In [28]:
# Dataset wrapper used to feed MedMNIST test images to the models below.
class ImageDataset(Dataset):
    """In-memory image dataset whose items are ``(idx, image, target)`` triples.

    ``image`` is passed through ``ToTensor`` and ``Normalize(mean=[.5], std=[.5])``.

    Args:
        images: array of images; cast to ``uint8`` on construction.
        targets: labels, indexed in parallel with ``images``.
        image_size, crop_size: accepted for signature compatibility; unused here.
        mode: ``'train'`` selects ``transform_train``; any other value selects
            ``transform_test``.
        data_aug_flag: stored on the instance; not used by this class.
        data_flag: MedMNIST dataset key; ``'synapsemnist3d'`` inserts a
            ``Resize((32, 32))`` step into the test-time transform only.
    """

    def __init__(self, images, targets, image_size=32, crop_size=25, mode='train', data_aug_flag=False, data_flag=''):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        self.data_aug_flag = data_aug_flag

        def normalized_pipeline(*extra_steps):
            # ToTensor -> optional extra steps -> Normalize, built fresh on every call.
            steps = [transforms.ToTensor()]
            steps.extend(extra_steps)
            steps.append(transforms.Normalize(mean=[.5], std=[.5]))
            return transforms.Compose(steps)

        self.transform_train = normalized_pipeline()
        # NOTE(review): the resize only affects the test transform; confirm synapse
        # training images were preprocessed the same way.
        if data_flag == 'synapsemnist3d':
            self.transform_test = normalized_pipeline(transforms.Resize((32, 32)))
        else:
            self.transform_test = normalized_pipeline()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        pipeline = self.transform_train if self.mode == 'train' else self.transform_test
        return idx, pipeline(self.images[idx]), self.targets[idx]

def test(model, data_flag, BATCH_SIZE):
    info = INFO[data_flag]
    task = info['task']
    n_channels = info['n_channels']
    n_classes = len(info['label'])

    DataClass = getattr(medmnist, info['python_class'])
    test_dataset = DataClass(split='test', download=True)
    test_data = test_dataset.imgs 
    if data_flag != 'chestmnist':
        test_labels = test_dataset.labels[:, 0] 
        test_labels[test_labels != 0] = 99
        test_labels[test_labels == 0] = 1
        test_labels[test_labels == 99] = 0
    else:
        test_labels = test_dataset.labels

    # test_data = test_data/255.0 
    testSet = ImageDataset(test_data, test_labels, mode='test', data_flag=data_flag)
    test_loader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    model.eval()
    best_val_auc = 0
    train_pred_list = []
    train_true_list = []
    with torch.no_grad():
        for index, inputs, targets in test_loader:
            inputs  = inputs.cuda()
            outputs = model(inputs)
            train_pred_list.append(outputs.cpu().detach().numpy())
            train_true_list.append(targets.numpy())
        train_true = np.concatenate(train_true_list)
        train_pred = np.concatenate(train_pred_list)
        train_auc = np.mean(auc_roc_score(train_true, train_pred))
        train_pauc = np.mean(auc_prc_score(train_true, train_pred))
        best_auc = train_auc
        if data_flag in ['adrenalmnist3d','vesselmnist3d','nodulemnist3d']:
            best_auc = train_pauc
        if best_val_auc < train_auc:
            best_val_auc = train_auc
        print('AUC for test is: ' + str(best_auc))
    return None

In [4]:
# Breast MNIST: the ResNet stem is rebuilt for single-channel input.
model = ResNet18(pretrained=False)
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Deserialize onto the CPU so loading does not depend on the GPU the checkpoint
# was saved from; the model is moved to CUDA right after.
# torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('breast_mnist_model.pt', map_location='cpu'))
model = model.cuda()
print('==> Evaluating ...')
print('\n')
test(model, 'breastmnist', 256)

==> Evaluating ...


Downloading https://zenodo.org/record/6496656/files/breastmnist.npz?download=1 to /root/.medmnist/breastmnist.npz


100%|██████████| 559580/559580 [00:01<00:00, 290382.12it/s]


AUC for test is: 0.9312865497076024


In [5]:
# Pneumonia MNIST: the ResNet stem is rebuilt for single-channel input.
model = ResNet18(pretrained=False)
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Deserialize onto the CPU so loading does not depend on the GPU the checkpoint
# was saved from; the model is moved to CUDA right after.
# torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('pneumonia_mnist_model.pt', map_location='cpu'))
model = model.cuda()
print('==> Evaluating ...')
print('\n')
test(model, 'pneumoniamnist', 256)

==> Evaluating ...


Downloading https://zenodo.org/record/6496656/files/pneumoniamnist.npz?download=1 to /root/.medmnist/pneumoniamnist.npz


100%|██████████| 4170669/4170669 [00:14<00:00, 278106.33it/s]

AUC for test is: 0.9624863028709183





In [29]:
# Chest MNIST: multi-label task. It uses torchvision's resnet18 with 14 outputs
# (one per label column) instead of the libauc ResNet18 used for the other datasets.
from torchvision.models import resnet18
model = resnet18(pretrained=False, num_classes=14)
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Deserialize onto the CPU so loading does not depend on the GPU the checkpoint
# was saved from; the model is moved to CUDA right after.
# torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('chest_mnist_model.pt', map_location='cpu'))
model = model.cuda()
print('==> Evaluating ...')
test(model, 'chestmnist', 256)

==> Evaluating ...
Using downloaded and verified file: /root/.medmnist/chestmnist.npz
AUC for test is: 0.7278451652570814


In [8]:
# Nodule MNIST (3-D): the ResNet stem takes 28 input channels.
# NOTE(review): presumably the 28 depth slices of each volume act as channels
# (ToTensor maps the last array axis to channels) -- confirm against training code.
model = ResNet18(pretrained=False)
model.conv1 = nn.Conv2d(28, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Deserialize onto the CPU so loading does not depend on the GPU the checkpoint
# was saved from; the model is moved to CUDA right after.
# torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('nodule_mnist_prc_model.pt', map_location='cpu'))
model = model.cuda()
print('==> Evaluating ...')
print('\n')
test(model, 'nodulemnist3d', 256)

==> Evaluating ...


Downloading https://zenodo.org/record/6496656/files/nodulemnist3d.npz?download=1 to /root/.medmnist/nodulemnist3d.npz


100%|██████████| 29299364/29299364 [01:47<00:00, 273379.79it/s]


AUC for test is: 0.9482103522389413


In [9]:
# Adrenal MNIST (3-D): the ResNet stem takes 28 input channels.
# NOTE(review): presumably the 28 depth slices of each volume act as channels
# (ToTensor maps the last array axis to channels) -- confirm against training code.
model = ResNet18(pretrained=False)
model.conv1 = nn.Conv2d(28, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Deserialize onto the CPU so loading does not depend on the GPU the checkpoint
# was saved from; the model is moved to CUDA right after.
# torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('adrenal_mnist_prc_model.pt', map_location='cpu'))
model = model.cuda()
print('==> Evaluating ...')
print('\n')
test(model, 'adrenalmnist3d', 256)

==> Evaluating ...


Downloading https://zenodo.org/record/6496656/files/adrenalmnist3d.npz?download=1 to /root/.medmnist/adrenalmnist3d.npz


100%|██████████| 276833/276833 [00:01<00:00, 233491.42it/s]

AUC for test is: 0.9404990524475865





In [10]:
# Vessel MNIST (3-D): the ResNet stem takes 28 input channels.
# NOTE(review): presumably the 28 depth slices of each volume act as channels
# (ToTensor maps the last array axis to channels) -- confirm against training code.
model = ResNet18(pretrained=False)
model.conv1 = nn.Conv2d(28, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Deserialize onto the CPU so loading does not depend on the GPU the checkpoint
# was saved from; the model is moved to CUDA right after.
# torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('vessel_mnist_prc_model.pt', map_location='cpu'))
model = model.cuda()
print('==> Evaluating ...')
print('\n')
test(model, 'vesselmnist3d', 256)

==> Evaluating ...


Downloading https://zenodo.org/record/6496656/files/vesselmnist3d.npz?download=1 to /root/.medmnist/vesselmnist3d.npz


100%|██████████| 398373/398373 [00:01<00:00, 241067.12it/s]

AUC for test is: 0.9751638697864619





In [31]:
# Synapse MNIST (3-D): the ResNet stem takes 28 input channels; the test
# transform for this dataset also resizes to 32x32 (see ImageDataset).
# NOTE(review): presumably the 28 depth slices of each volume act as channels
# (ToTensor maps the last array axis to channels) -- confirm against training code.
model = ResNet18(pretrained=False)
model.conv1 = nn.Conv2d(28, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Deserialize onto the CPU so loading does not depend on the GPU the checkpoint
# was saved from; the model is moved to CUDA right after.
# torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('synapse_mnist_model.pth', map_location='cpu'))
model = model.cuda()
print('==> Evaluating ...')
print('\n')
test(model, 'synapsemnist3d', 256)

==> Evaluating ...


Using downloaded and verified file: /root/.medmnist/synapsemnist3d.npz
AUC for test is: 0.79295515052222


In [None]:
# To reproduce the trained models, copy the training methods (load_train_test_data,
# load_train_test_data_prc) from the original notebook, then run them with the
# parameters below. Each model is evaluated with `test` defined above.
# Breast MNIST
model, train_log, val_log = load_train_test_data('breastmnist', 100, 0.25, 0, False, False, 0.1, 128, 0.9, 'pesg', 1e-5, 1e-3, 100)
print('==> Evaluating ...')
test(model, 'breastmnist', 256)
# Pneumonia MNIST
model, train_log, val_log = load_train_test_data('pneumoniamnist', 100, 0.25, 0, False, False, 0.01, 128, 0.9, 'pesg', 1e-2, 0.1, 100)
print('==> Evaluating ...')
test(model, 'pneumoniamnist', 256)
# Chest MNIST: one imbalance ratio per label column (14 labels).
imratio_list = [0.1, 0, 0.1, 0.2, 0,
                0, 0, 0, 0, 0,
                0, 0, 0, 0]
model, train_log, val_log = load_train_test_data('chestmnist', 5, imratio_list, 0, False, False, 0.001, 128, 0.9, 'pesg', 1e-3, 0.001, 0.001, 10)
print('==> Evaluating ...')
test(model, 'chestmnist', 256)
# Nodule MNIST
model, train_log_prc, val_log_prc = load_train_test_data_prc('nodulemnist3d', 100, 0.25, 0, True, 0.1, 128, 0.9, '', 1e-3, 0.1, 100)
print('==> Evaluating ...')
test(model, 'nodulemnist3d', 256)
# Adrenal MNIST
model, train_log_prc, val_log_prc = load_train_test_data_prc('adrenalmnist3d', 100, 0.2, 0, True, 0.1, 128, 0.9, '', 1e-3, 0.1, 100)
print('==> Evaluating ...')
test(model, 'adrenalmnist3d', 256)
# Vessel MNIST
model, train_log_prc, val_log_prc = load_train_test_data_prc('vesselmnist3d', 100, 0.1, 0, True, 0.1, 128, 0.9, '', 1e-3, 0.1, 100)
print('==> Evaluating ...')
# `test` already reports AUPRC for vesselmnist3d (test_prc is not defined here).
test(model, 'vesselmnist3d', 256)
# Synapse MNIST
model, train_log, val_log = load_train_test_data('synapsemnist3d', 100, 0.27, 0, True, False, 0.1, 128, 0.9, 'pesg', 1e-1, 0, 0)
print('==> Evaluating ...')
test(model, 'synapsemnist3d', 256)