*   Author: Zhuoning Yuan, Qi Qi
*   Project: https://github.com/yzhuoning/LibAUC



# **Installing LibAUC**

In [1]:
!pip install libauc-1.1.5-py3-none-any.whl


# **Importing LibAUC**

In [2]:
from libauc.losses import APLoss_SH
from libauc.optimizers import SOAP_SGD, SOAP_ADAM
from libauc.models import ResNet18
from libauc.datasets import CIFAR10
from libauc.datasets import ImbalanceGenerator, ImbalanceSampler 

import torchvision.transforms as transforms
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import numpy as np
import torch
from PIL import Image


# **Reproducibility**

In [3]:
def set_all_seeds(SEED):
    # REPRODUCIBILITY
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# **Image Dataset**

In [4]:
class ImageDataset(Dataset):
    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
       self.images = images.astype(np.uint8)
       self.targets = targets
       self.mode = mode
       self.transform_train = transforms.Compose([                                                
                              transforms.RandomCrop(image_size, padding=4),
                              transforms.RandomHorizontalFlip(),
                              transforms.ToTensor(),
                              transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                              
                              ])
       self.transform_test = transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                              ])
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        target = self.targets[idx]
        image = Image.fromarray(image.astype('uint8'))
        if self.mode == 'train':
            image = self.transform_train(image)
        else:
            image = self.transform_test(image)
        return idx, image, target


# **Paramaters**

In [13]:
# paramaters
imratio = 0.02
SEED = 123
BATCH_SIZE = 64
lr =  1e-6
weight_decay = 2e-4
margin = 0.6
beta = 0.99 # this refers to gamma for moving average in the paper
posNum = 1

# **Loading datasets**

In [14]:
# dataloader 
(train_data, train_label), (test_data, test_label) = CIFAR10()
(train_images, train_labels) = ImbalanceGenerator(train_data, train_label, imratio=imratio, shuffle=True, random_seed=SEED)
(test_images, test_labels) = ImbalanceGenerator(test_data, test_label, is_balanced=True,  random_seed=SEED)

train_dataset = ImageDataset(train_images, train_labels)
test_dataset = ImageDataset(test_images, test_labels, mode='test')
testloader = torch.utils.data.DataLoader(test_dataset , batch_size=BATCH_SIZE, shuffle=False, num_workers=1,  pin_memory=True)

NUM_SAMPLES: [25510], POS:NEG: [510 : 25000], POS_RATIO: 0.0200
NUM_SAMPLES: [10000], POS:NEG: [5000 : 5000], POS_RATIO: 0.5000


# **Creating models & AUC Optimizer**

In [15]:
set_all_seeds(456)
model = ResNet18(pretrained=False, last_activation=None) 
model = model.cuda()

# load ce pretrained model & remove final fc layers
PATH = 'cifar10_resnet18_002.ckpt' 
state_dict = torch.load(PATH)
state_dict.pop('fc.weights', None)
state_dict.pop('fc.bias', None)
model.load_state_dict(state_dict['model'], strict=True)

# SOAPLoss requires ImbalanceSampler() with pos_num>=1!
Loss = APLoss_SH(margin=margin, beta=beta, data_len=train_labels.shape[0])
optimizer = SOAP_SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

# **Training**

In [16]:
# training 
model.train()
losses = []  
print ('-'*30)
total_iters = 0
for epoch in range(64):
    if epoch == 32:
       optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr']/10
    
    train_pred = []
    train_true = []
    model.train() 
       
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=ImbalanceSampler(train_labels.flatten().astype(int), BATCH_SIZE, pos_num=posNum), num_workers=2, pin_memory=True, drop_last=True) 
    for idx, (index, data, targets) in enumerate(trainloader):
        data, targets  = data.cuda(), targets.cuda()
        y_pred = model(data)
        y_prob = torch.sigmoid(y_pred)
        loss = Loss(y_prob, targets, index_s=index)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_pred.append(y_prob.cpu().detach().numpy())
        train_true.append(targets.cpu().detach().numpy())

    train_true = np.concatenate(train_true)
    train_pred = np.concatenate(train_pred)
    train_auc = roc_auc_score(train_true, train_pred) 
    train_prc = average_precision_score(train_true, train_pred)

    model.eval()
    test_pred = []
    test_true = [] 
    for j, data in enumerate(testloader):
        _, test_data, test_targets = data
        test_data = test_data.cuda()
        y_pred = model(test_data)
        y_prob = torch.sigmoid(y_pred)
        test_pred.append(y_prob.cpu().detach().numpy())
        test_true.append(test_targets.numpy())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
     
    val_auc =  roc_auc_score(test_true, test_pred) 
    val_prc = average_precision_score(test_true, test_pred)
    
    model.train()
    print("epoch: {}, train_ap:{:4f}, test_ap:{:4f},  lr:{:4f}".format(epoch, train_prc, val_prc,  optimizer.param_groups[0]['lr'] ))
    

------------------------------
epoch: 0, train_ap:0.038561, test_ap:0.684382,  lr:0.000001
epoch: 1, train_ap:0.037349, test_ap:0.687822,  lr:0.000001
epoch: 2, train_ap:0.038272, test_ap:0.697211,  lr:0.000001
epoch: 3, train_ap:0.052288, test_ap:0.697083,  lr:0.000001
epoch: 4, train_ap:0.048474, test_ap:0.707550,  lr:0.000001
epoch: 5, train_ap:0.055943, test_ap:0.706793,  lr:0.000001
epoch: 6, train_ap:0.047432, test_ap:0.717946,  lr:0.000001
epoch: 7, train_ap:0.046604, test_ap:0.705159,  lr:0.000001
epoch: 8, train_ap:0.052999, test_ap:0.708554,  lr:0.000001
epoch: 9, train_ap:0.078836, test_ap:0.731893,  lr:0.000001
epoch: 10, train_ap:0.071232, test_ap:0.704826,  lr:0.000001
epoch: 11, train_ap:0.062467, test_ap:0.705674,  lr:0.000001
epoch: 12, train_ap:0.058837, test_ap:0.713067,  lr:0.000001
epoch: 13, train_ap:0.078756, test_ap:0.730586,  lr:0.000001
epoch: 14, train_ap:0.082710, test_ap:0.704639,  lr:0.000001
epoch: 15, train_ap:0.093598, test_ap:0.700707,  lr:0.000001
epo