LibAUC Experiments: AUC maximization (AUCM loss + PESG optimizer) with a ResNet-18 on MedMNIST ChestMNIST

In [2]:
from libauc.models import resnet18
from libauc.losses import AUCMLoss

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torch_data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO, Evaluator


from libauc.losses import AUCMLoss, CrossEntropyLoss
from libauc.optimizers import PESG, Adam
from libauc.models import densenet121 as DenseNet121
from libauc.datasets import CheXpert

import torch 
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score

import warnings
warnings.filterwarnings("ignore") 

In [14]:
# data_flag = 'pathmnist'
data_flag = 'chestmnist'
download = True

NUM_EPOCHS = 3
BATCH_SIZE = 8
lr = 0.001

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

In [4]:
# Task type recorded in the MedMNIST metadata for the selected dataset.
info['task']

'multi-label, binary-class'

In [15]:
# Preprocessing: replicate the single grayscale channel into 3 channels, convert to a
# float tensor in [0, 1], then shift/scale to [-1, 1] with mean = std = 0.5.
# The transform receives a PIL image, so Grayscale can be applied directly with no
# ToTensor -> ToPILImage round trip first.
# NOTE(review): medmnist's 2D datasets wrap each sample with Image.fromarray before
# calling `transform` -- confirm for the installed medmnist version.
data_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])

# Load the train / val / test splits with the preprocessing above.
train_dataset = DataClass(split='train', transform=data_transform, download=download)
val_dataset = DataClass(split='val', transform=data_transform, download=download)
test_dataset = DataClass(split='test', transform=data_transform, download=download)

# Copy of the training split with no transform applied.
pil_dataset = DataClass(split='train', download=download)

# Shuffled loader for training; unshuffled loaders (2x batch size) for evaluation.
train_loader = torch_data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = torch_data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = torch_data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

Using downloaded and verified file: C:\Users\Hari\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\Hari\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\Hari\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\Hari\.medmnist\chestmnist.npz


In [None]:
# PESG hyper-parameters. This lr overrides the lr = 0.001 set in the config cell.
lr = 0.01
epoch_decay = 2e-5
weight_decay = 1e-7
margin = 1.0

# Fall back to CPU instead of hard-coding .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = resnet18(num_classes=n_classes).to(device)
# NOTE(review): AUCMLoss is applied to all n_classes label columns at once.
# LibAUC also ships multi-label AUC losses; confirm which one fits the installed version.
criterion = AUCMLoss()
optimizer = PESG(model,
                 loss_fn=criterion,
                 lr=lr,
                 margin=margin,
                 epoch_decay=epoch_decay,
                 weight_decay=weight_decay)

# Select checkpoints on the validation split so the test split stays held out.
val_loader = torch_data.DataLoader(dataset=val_dataset, batch_size=2*BATCH_SIZE, shuffle=False)


def evaluate(model, loader, loss_fn, device):
    """Run `model` over `loader` without gradients and collect sigmoid scores.

    Returns a tuple (y_true, y_score, mean_loss, mean_auc):
      y_true, y_score -- stacked label / score arrays, one row per sample;
      mean_loss       -- average of the per-batch `loss_fn` values;
      mean_auc        -- ROC-AUC averaged over the label columns.
    The model is put back in train mode before returning.
    """
    model.eval()
    trues, scores, losses = [], [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Same sigmoid as in training, so the loss sees comparable scores
            # (ROC-AUC itself is unchanged because sigmoid is monotonic).
            y_score = torch.sigmoid(model(inputs))
            losses.append(loss_fn(y_score, labels.long()).item())
            scores.append(y_score.cpu().numpy())
            trues.append(labels.cpu().numpy())
    model.train()
    y_true = np.concatenate(trues)
    y_score = np.concatenate(scores)
    mean_auc = np.mean([roc_auc_score(y_true[:, i], y_score[:, i])
                        for i in range(y_true.shape[1])])
    return y_true, y_score, float(np.mean(losses)), float(mean_auc)


best_val_auc = 0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_losses = []
    for train_data, train_labels in tqdm(train_loader, desc='epoch {}'.format(epoch), leave=False):
        train_data, train_labels = train_data.to(device), train_labels.to(device)
        y_pred = torch.sigmoid(model(train_data))
        loss = criterion(y_pred, train_labels.long())
        # Record the batch loss as-is (dividing by len(data) divided by the
        # (inputs, labels) tuple length, not the batch size).
        train_losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    val_true, val_pred, val_loss, val_auc_mean = evaluate(model, val_loader, criterion, device)
    if best_val_auc < val_auc_mean:
        best_val_auc = val_auc_mean
        torch.save(model.state_dict(), 'pretrained_model.pth')

    print('Epoch : {:03d}  Train Loss : {:.5f}  Val Loss : {:.5f}  Val_AUC={:.4f}  Best_Val_AUC={:.4f}'.format(
        epoch, np.mean(train_losses), val_loss, val_auc_mean, best_val_auc))

# Reload the best-validation checkpoint (if one was saved) and report held-out test metrics.
if best_val_auc > 0:
    model.load_state_dict(torch.load('pretrained_model.pth', map_location=device))
test_true, test_pred, test_loss, test_auc_mean = evaluate(model, test_loader, criterion, device)
print('Test Loss : {:.5f}   Test_AUC={:.4f}'.format(test_loss, test_auc_mean))

In [23]:
# Shape of the stacked test labels: (num_test_samples, num_labels).
np.shape(test_true)

(22433, 14)