LibAUC Experiments

In [4]:
# Install libAUC and medMNIST
!pip install libauc==1.2.0
!pip install medmnist
!pip install tensorboardX
!pip install acsconv

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting libauc==1.2.0
  Downloading libauc-1.2.0-py3-none-any.whl (73 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.6/73.6 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: libauc
Successfully installed libauc-1.2.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting medmnist
  Downloading medmnist-2.2.1-py3-none-any.whl (21 kB)
Collecting fire
  Downloading fire-0.5.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.3/88.3 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.5.0-py2.py3-none-any.whl size=116952 sha256=9d70d305aa820b218f609047314807d5

In [5]:
from libauc.models import resnet18
from libauc.losses import AUCMLoss

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torch_data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO, Evaluator


from libauc.losses import AUCMLoss, CrossEntropyLoss
from libauc.optimizers import PESG, Adam
from libauc.models import densenet121 as DenseNet121
from libauc.datasets import CheXpert

import torch 
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score

import warnings
warnings.filterwarnings("ignore") 

In [6]:
# data_flag = 'pathmnist'
data_flag = 'chestmnist'
download = True

NUM_EPOCHS = 3
BATCH_SIZE = 8
lr = 0.001

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

In [7]:
task

'multi-label, binary-class'

In [8]:
data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.Grayscale(3),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=3),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ])

# load the data
train_dataset = DataClass(split='train', transform=data_transform, download=download)
val_dataset = DataClass(split='val', transform=data_transform, download=download)
test_dataset = DataClass(split='test', transform=data_transform, download=download)

pil_dataset = DataClass(split='train', download=download)

# encapsulate data into dataloader form
train_loader = torch_data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = torch_data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = torch_data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

Downloading https://zenodo.org/record/6496656/files/chestmnist.npz?download=1 to /root/.medmnist/chestmnist.npz


100%|██████████| 82802576/82802576 [00:06<00:00, 12137964.62it/s]


Using downloaded and verified file: /root/.medmnist/chestmnist.npz
Using downloaded and verified file: /root/.medmnist/chestmnist.npz
Using downloaded and verified file: /root/.medmnist/chestmnist.npz


In [11]:
lr = 0.1 # using smaller learning rate is better
epoch_decay = 0.03
weight_decay = 1e-5
margin = 1.0

model = resnet18(num_classes=14)
model = model.cuda()
criterion = AUCMLoss()
optimizer = PESG(model, 
                 loss_fn=criterion, 
                 lr=lr, 
                 margin=margin, 
                 epoch_decay=epoch_decay, 
                 weight_decay=weight_decay)

# training
best_val_auc = 0 
for epoch in range(200):
    train_losses = []
    for idx, data in enumerate(train_loader):
      print("Iteration started")
      train_data, train_labels = data
      train_data, train_labels  = train_data.cuda(), train_labels.cuda()
      y_pred = model(train_data)
      y_pred = torch.sigmoid(y_pred)
      loss = criterion(y_pred, train_labels.type(torch.LongTensor).cuda())
      train_losses.append(loss.item()/len(data))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      break

    print("Epoch : {:03d}  Train Loss : {:.5f} ".format(epoch, np.mean(train_losses)), end='')
    model.eval()
    with torch.no_grad():    
        test_pred = []
        test_true = [] 
        test_losses = []
        test_CE_losses = []
        for jdx, data in enumerate(test_loader):
            test_data, test_labels = data
            test_data = test_data.cuda()
            y_pred = model(test_data)
            test_pred.append(y_pred.cpu().detach().numpy())
            test_true.append(test_labels.numpy())
            test_losses.append(criterion(y_pred, test_labels.squeeze().type(torch.LongTensor).cuda()).item() / len(data))

        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        val_auc_mean = 0
        for i in range(14):
            val_auc_mean +=  roc_auc_score(test_true[:,i], test_pred[:,i]) 
        val_auc_mean/=14
        print("Val Loss : {:.5f}   ".format(np.mean(test_losses)), end = '')
        model.train()

        if best_val_auc < val_auc_mean:
            best_val_auc = val_auc_mean
            torch.save(model.state_dict(), 'pretrained_model.pth')

        print ('BatchID= {}   Val_AUC={:.4f}   Best_Val_AUC={:.4f}'.format(
            idx, val_auc_mean, best_val_auc ))
          
    print("Epoch : {}".format(epoch))

Iteration started
Epoch : 000  Train Loss : 0.12257 Val Loss : 0.02780   BatchID= 0   Val_AUC=0.4964   Best_Val_AUC=0.4964
Epoch : 0
Iteration started
Epoch : 001  Train Loss : 0.12370 Val Loss : 0.03827   BatchID= 0   Val_AUC=0.5037   Best_Val_AUC=0.5037
Epoch : 1
Iteration started
Epoch : 002  Train Loss : 0.11881 Val Loss : 0.03175   BatchID= 0   Val_AUC=0.5068   Best_Val_AUC=0.5068
Epoch : 2
Iteration started
Epoch : 003  Train Loss : 0.15592 Val Loss : 0.00663   BatchID= 0   Val_AUC=0.5043   Best_Val_AUC=0.5068
Epoch : 3
Iteration started
Epoch : 004  Train Loss : 0.09952 Val Loss : -0.02838   BatchID= 0   Val_AUC=0.5044   Best_Val_AUC=0.5068
Epoch : 4
Iteration started
Epoch : 005  Train Loss : 0.15463 Val Loss : -0.06696   BatchID= 0   Val_AUC=0.5085   Best_Val_AUC=0.5085
Epoch : 5
Iteration started
Epoch : 006  Train Loss : 0.16012 Val Loss : -0.10333   BatchID= 0   Val_AUC=0.4982   Best_Val_AUC=0.5085
Epoch : 6
Iteration started
Epoch : 007  Train Loss : 0.14152 Val Loss : -0.