# Targeted Selection Demo For Biomedical Datasets With Rare Classes

### Imports 

In [1]:
# !git clone https://github.com/decile-team/trust.git
# !pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ submodlib
# !git clone https://github.com/decile-team/distil.git
# !pip install medmnist

In [2]:
import time
import random
import datetime
import copy
import numpy as np
from tabulate import tabulate
import os
import csv
import json
import subprocess
import sys
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from matplotlib import pyplot as plt
from trust.trust.utils.models.resnet import ResNet18
from trust.trust.utils.models.resnet import ResNet50
from trust.trust.utils.custom_dataset_medmnist import load_biodataset_custom
from torch.utils.data import Subset
from torch.autograd import Variable
import tqdm
from math import floor
from sklearn.metrics.pairwise import cosine_similarity, pairwise_distances
from trust.trust.strategies.smi import SMI
from trust.trust.strategies.random_sampling import RandomSampling
from distil.distil.active_learning_strategies.entropy_sampling import EntropySampling
from distil.distil.active_learning_strategies.badge import BADGE

# Seed every random number generator the notebook draws from so that a
# fresh "Restart & Run All" starts from the same RNG state.
seed = 42
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(seed)
from trust.trust.utils.utils import *
from trust.trust.utils.viz import tsne_smi

### Helper functions

In [3]:
def model_eval_loss(data_loader, model, criterion):
    """Accumulate `criterion` over every batch of `data_loader` without gradients.

    Each batch is moved to the module-level `device` before the forward pass.

    Returns:
        The sum of the per-batch loss values (a sum, not a mean over batches).
    """
    running_loss = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device, non_blocking=True)
            batch_loss = criterion(model(inputs), targets)
            running_loss += batch_loss.item()
    return running_loss

def init_weights(m):
    """Xavier-initialise a single layer; meant to be used as `model.apply(init_weights)`.

    Conv2d and Linear weights are drawn with `xavier_uniform_`. Linear biases are
    set to the constant 0.01. Every other module type is left untouched.

    Args:
        m: the module visited by `nn.Module.apply`.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        torch.nn.init.xavier_uniform_(m.weight)
    # A Linear layer built with bias=False has `m.bias is None`; skip it
    # instead of failing on `None.data`.
    if isinstance(m, nn.Linear) and m.bias is not None:
        m.bias.data.fill_(0.01)

def weight_reset(m):
    """Re-draw the parameters of a Conv2d or Linear layer via `reset_parameters()`.

    Modules of any other type are ignored, so this is safe to pass to
    `model.apply`.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    m.reset_parameters()
                
def create_model(name, num_cls, device, embedding_type):
    """Build the task network, apply `init_weights` to it and move it to `device`.

    Args:
        name: one of 'ResNet18', 'ResNet50', 'MnistNet', 'ResNet164'.
        num_cls: number of output classes, passed to the Trust ResNet
            constructors and to ResNet164.
        device: target passed to `model.to(...)`.
        embedding_type: when equal to "gradients" the Trust ResNet18/ResNet50
            implementations are used; any other value selects the torchvision
            models instead.

    Returns:
        The initialised model on `device`.

    Raises:
        ValueError: if `name` is not a supported architecture.
    """
    use_trust_resnet = embedding_type == "gradients"
    if name == 'ResNet18':
        # NOTE(review): the torchvision variants are built with their default
        # head size rather than `num_cls` -- confirm this is intended.
        model = ResNet18(num_cls) if use_trust_resnet else models.resnet18()
    elif name == 'ResNet50':
        model = ResNet50(num_cls) if use_trust_resnet else models.resnet50()
    elif name == 'MnistNet':
        # NOTE(review): MnistNet / ResNet164 are not imported explicitly in this
        # notebook; presumably they arrive through a star import -- confirm.
        model = MnistNet()
    elif name == 'ResNet164':
        model = ResNet164(num_cls)
    else:
        # Previously an unknown name fell through and failed later with an
        # UnboundLocalError on `model`.
        raise ValueError(
            "Unsupported model name: %r (expected 'ResNet18', 'ResNet50', "
            "'MnistNet' or 'ResNet164')" % (name,)
        )
    model.apply(init_weights)
    model = model.to(device)
    return model

def loss_function():
    """Return two cross-entropy criteria: the default one and a `reduction='none'` one.

    Returns:
        (criterion, criterion_nored) -- the second yields one loss value per sample.
    """
    return nn.CrossEntropyLoss(), nn.CrossEntropyLoss(reduction='none')

def optimizer_with_scheduler(model, num_epochs, learning_rate, m=0.9, wd=5e-4):
    """Create an SGD optimizer for `model` together with a cosine-annealing LR schedule.

    Args:
        model: module whose parameters are optimised.
        num_epochs: used as `T_max` of `CosineAnnealingLR`.
        learning_rate: SGD learning rate.
        m: SGD momentum.
        wd: SGD weight decay.

    Returns:
        (optimizer, scheduler)
    """
    sgd_options = dict(lr=learning_rate, momentum=m, weight_decay=wd)
    optimizer = optim.SGD(model.parameters(), **sgd_options)
    cosine_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    return optimizer, cosine_schedule

def optimizer_without_scheduler(model, learning_rate, m=0.9, wd=5e-4):
    """Create a plain SGD optimizer (no LR schedule) over `model.parameters()`.

    Args:
        model: module whose parameters are optimised.
        learning_rate: SGD learning rate.
        m: SGD momentum.
        wd: SGD weight decay.
    """
    sgd_options = dict(lr=learning_rate, momentum=m, weight_decay=wd)
    return optim.SGD(model.parameters(), **sgd_options)

def generate_cumulative_timing(mod_timing):
    """Turn per-round durations in seconds into a running total expressed in hours.

    Args:
        mod_timing: sequence of durations in seconds.

    Returns:
        A float NumPy array of the same length holding the cumulative time
        divided by 3600.
    """
    seconds_elapsed = np.cumsum(np.asarray(mod_timing, dtype=np.float64))
    return seconds_elapsed / 3600

def displayTable(val_err_log, tst_err_log):
    """Print per-class validation/test accuracy as a table, then the test accuracies line by line.

    Args:
        val_err_log: per-class validation error percentages.
        tst_err_log: per-class test error percentages.

    Accuracy is computed as `100 - error` and rounded to two decimals so that
    float artefacts such as 13.519999999999996 are not printed.
    """
    val_acc = [str(round(100 - err, 2)) for err in val_err_log]
    tst_acc = [str(round(100 - err, 2)) for err in tst_err_log]
    # Derive the class column from the data instead of a hard-coded range(10),
    # which silently truncated the table for datasets with more than 10 classes.
    class_ids = [str(cls) for cls in range(min(len(val_acc), len(tst_acc)))]
    rows = [list(row) for row in zip(class_ids, val_acc, tst_acc)]
    print(tabulate(rows, headers=['Class', 'Val Accuracy', 'Test Accuracy'], tablefmt='orgtbl'))
    print("Testing accuracy is as follows - ")
    for acc in tst_acc:
        print(acc)

def find_err_per_class(test_set, val_set, final_val_classifications, final_val_predictions, final_tst_classifications, 
                       final_tst_predictions, saveDir, prefix):
    """Compute per-class error percentages on the validation and test sets and print them.

    Args:
        test_set, val_set: datasets exposing a `.targets` attribute.
        final_val_classifications, final_tst_classifications: per-sample
            correctness flags (True = correctly classified) in dataset order.
        final_val_predictions, final_tst_predictions, saveDir, prefix: accepted
            for interface compatibility; not used by this function.

    Returns:
        (tst_err_log, val_err_log, val_class_err_idxs) where the two logs hold
        one error percentage per class followed by their mean as the last
        element, and `val_class_err_idxs[c]` is the set of validation indices
        of class `c` that were misclassified.

    Reads the module-level `num_cls` for the number of classes.
    """
    # Positions of the misclassified samples; built once as sets because they
    # do not change across classes.
    val_err_idx = set(np.where(np.array(final_val_classifications) == False)[0])
    tst_err_idx = set(np.where(np.array(final_tst_classifications) == False)[0])
    tst_targets = torch.Tensor(test_set.targets)
    val_targets = torch.Tensor(val_set.targets.float())

    val_class_err_idxs = []
    tst_err_log = []
    val_err_log = []
    for i in range(num_cls):
        tst_class_idxs = list(torch.where(tst_targets == i)[0].cpu().numpy())
        val_class_idxs = list(torch.where(val_targets == i)[0].cpu().numpy())
        # Misclassified samples belonging to class i.
        val_err_class_idx = val_err_idx.intersection(val_class_idxs)
        tst_err_class_idx = tst_err_idx.intersection(tst_class_idxs)
        # A class may be absent from a split (e.g. per_imbclass_val of 0);
        # report 0 instead of dividing by zero.
        if len(val_class_idxs) > 0:
            val_error_perc = round((len(val_err_class_idx)/len(val_class_idxs))*100, 2)
        else:
            val_error_perc = 0
        if len(tst_class_idxs) > 0:
            tst_error_perc = round((len(tst_err_class_idx)/len(tst_class_idxs))*100, 2)
        else:
            tst_error_perc = 0
        val_class_err_idxs.append(val_err_class_idx)
        tst_err_log.append(tst_error_perc)
        val_err_log.append(val_error_perc)
    displayTable(val_err_log, tst_err_log)
    # The mean over classes is appended as the last entry of each log.
    tst_err_log.append(sum(tst_err_log)/len(tst_err_log))
    val_err_log.append(sum(val_err_log)/len(val_err_log))
    return tst_err_log, val_err_log, val_class_err_idxs


def aug_train_subset(train_set, lake_set, true_lake_set, subset, lake_subset_idxs, budget, augrandom=False):
    """Move the selected lake samples into the training set.

    Args:
        train_set: current labeled training set.
        lake_set: current unlabeled pool.
        true_lake_set: copy of the pool whose `.targets` are used as the labels
            of the selected samples.
        subset: lake indices chosen by the selection strategy.
        lake_subset_idxs: lake indices to remove from the pool.
        budget: number of samples that should be added this round.
        augrandom: when True and `subset` holds fewer than `budget` indices,
            the remainder is filled with uniformly random unused lake indices.

    Returns:
        (aug_train_set, remain_lake_set, remain_true_lake_set, lake_ss) where
        `lake_ss` is the labeled subset that was appended to the training set.
    """
    def _float_targets(dataset):
        # Targets may be a NumPy array or a torch tensor. `np.float` was
        # removed in NumPy 1.24, so convert through torch directly.
        targets = dataset.targets
        if isinstance(targets, np.ndarray):
            return torch.as_tensor(targets, dtype=torch.float32)
        return targets.float()

    all_lake_idx = list(range(len(lake_set)))
    # Work on a copy so the caller's list is not extended behind its back.
    subset = list(subset)
    shortfall = int(budget) - len(subset)
    if augrandom and shortfall > 0:
        print("Budget not filled, adding ", str(shortfall), " randomly.")
        remain_lake_idx = list(set(all_lake_idx) - set(subset))
        # Never ask for more samples than the pool still holds.
        n_extra = min(shortfall, len(remain_lake_idx))
        if n_extra > 0:
            random_subset_idx = list(np.random.choice(np.array(remain_lake_idx), size=n_extra, replace=False))
            subset += random_subset_idx

    lake_ss = SubsetWithTargets(true_lake_set, subset, _float_targets(true_lake_set)[subset])

    # Everything that went into the training set must leave the pool, including
    # the random top-up. Previously that only worked because the caller passed
    # the same list object as both `subset` and `lake_subset_idxs`.
    taken = set(lake_subset_idxs) | set(subset)
    remain_lake_idx = list(set(all_lake_idx) - taken)
    remain_lake_set = SubsetWithTargets(lake_set, remain_lake_idx, _float_targets(lake_set)[remain_lake_idx])
    remain_true_lake_set = SubsetWithTargets(true_lake_set, remain_lake_idx, _float_targets(true_lake_set)[remain_lake_idx])

    aug_train_set = torch.utils.data.ConcatDataset([train_set, lake_ss])
    return aug_train_set, remain_lake_set, remain_true_lake_set, lake_ss
                        
def getQuerySet(val_set, val_class_err_idxs, imb_cls_idx, miscls):
    """Build the query (target) set from the validation set.

    Args:
        val_set: validation dataset exposing a tensor `.targets`.
        val_class_err_idxs: per-class collections of misclassified validation
            indices (only read when `miscls` is true).
        imb_cls_idx: class ids of the imbalanced (targeted) classes.
        miscls: if true, the query set is the misclassified validation samples
            of the imbalanced classes; otherwise it is every validation sample
            of those classes.

    Returns:
        (Subset of `val_set`, the targets of that subset).
    """
    query_idx = []
    if miscls:
        for cls, err_idxs in enumerate(val_class_err_idxs):
            if cls in imb_cls_idx:
                query_idx.extend(err_idxs)
        print("Total misclassified examples from imbalanced classes (Size of query set): ", len(query_idx))
    else:
        for cls in imb_cls_idx:
            cls_mask = torch.Tensor(val_set.targets.float()) == cls
            query_idx.extend(list(torch.where(cls_mask)[0].cpu().numpy()))
        print("Total samples from imbalanced classes as targets (Size of query set): ", len(query_idx))
    query_targets = val_set.targets[query_idx]
    return Subset(val_set, query_idx), query_targets

def getPerClassSel(lake_set, subset, num_cls):
    """Count how many of the selected lake samples fall into each class.

    Args:
        lake_set: dataset exposing `.targets` as a NumPy array or torch tensor.
        subset: indices into `lake_set` that were selected.
        num_cls: number of classes; class ids are 0 .. num_cls-1.

    Returns:
        A list of `num_cls` integers, entry `c` being the number of selected
        samples whose target equals `c`.
    """
    targets = lake_set.targets
    # `np.float` was removed in NumPy 1.24; convert through torch instead and
    # use isinstance rather than comparing the type's string form.
    if isinstance(targets, np.ndarray):
        all_cls = torch.as_tensor(targets, dtype=torch.float32)
    else:
        all_cls = targets.float()
    subset_cls = all_cls[subset]
    return [int((subset_cls == cls).sum().item()) for cls in range(num_cls)]

def print_final_results(res_dict, sel_cls_idx):
    """Print the accuracy gain between the first two recorded rounds.

    Args:
        res_dict: results dictionary with 'test_acc' (overall accuracies) and
            'all_class_acc' (per-class accuracies), each indexed by round.
        sel_cls_idx: indices of the targeted classes; the second printed number
            is the mean gain over these classes.
    """
    overall = res_dict['test_acc']
    print("Gain in overall test accuracy: ", overall[1]-overall[0])
    per_class = res_dict['all_class_acc']
    targeted_before = np.array(per_class[0])[sel_cls_idx]
    targeted_after = np.array(per_class[1])[sel_cls_idx]
    targeted_gain = targeted_after - targeted_before
    print("Gain in targeted test accuracy: ", np.mean(targeted_gain))

# Data, Model & Experimental Settings
This notebook uses the BloodMNIST dataset from MedMNIST (`data_name = 'bloodmnist'`), which has 8 classes. We use the `load_biodataset_custom()` function in Trust to simulate a class-imbalance scenario using the `split_cfg` dictionary given below. We then use a ResNet18 model as our task DNN and train it on the simulated imbalanced version of the dataset. Next, we perform targeted selection using various SMI functions and compare their gain in overall accuracy as well as on the imbalanced classes.

In [4]:
import math

# --- Experiment identity and data location ---
feature = "classimb"
run = "exp_4_2"
datadir = 'data/'
data_name = 'bloodmnist'

# --- Model and optimisation ---
model_name = 'ResNet18'
learning_rate = 0.0003
embedding_type = "gradients"  # Type of the representation to use (gradients/features)
device_id = 0
device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"

# --- Selection settings ---
num_cls = 8
budget = 20
miscls = True  # Set to True if only the misclassified examples from the imbalanced classes is to be used
computeClassErrorLog = True
visualize_tsne = False

# Geometric per-class sample counts and imbalance factor. They are not used by
# the split below, which spells its counts out explicitly.
imbf = 60  # imbalance factor
tns = [math.ceil(5 * (1.4 ** k)) for k in range(1, num_cls + 1)]  # train_num_samples

# Classes 0, 2, 4 and 5 are the rare ones: few train/lake samples, and the
# validation set only contains samples from these classes.
split_cfg = {"num_cls_imbalance":4,
             "sel_cls_idx":[0,1,2,3,4,5,6,7],
             "per_imbclass_train":{0:7,1:50,2:7,3:50,4:7,5:7,6:50,7:50}, 
             "per_imbclass_val":{0:20,1:0,2:20,3:0,4:20,5:20,6:0,7:0},
             "per_imbclass_lake":{0:56,1:400,2:56,3:400,4:56,5:56,6:400,7:400},
             "per_imbclass_test":{0:243,1:243,2:243,3:243,4:243,5:243,6:243,7:243}}
print("split_cfg:",split_cfg)

# Path of the cached initial model, shared by every run with the same settings.
initModelPath = f"./{data_name}_{model_name}_{learning_rate}_{split_cfg['num_cls_imbalance']}"

split_cfg: {'num_cls_imbalance': 4, 'sel_cls_idx': [0, 1, 2, 3, 4, 5, 6, 7], 'per_imbclass_train': {0: 7, 1: 50, 2: 7, 3: 50, 4: 7, 5: 7, 6: 50, 7: 50}, 'per_imbclass_val': {0: 20, 1: 0, 2: 20, 3: 0, 4: 20, 5: 20, 6: 0, 7: 0}, 'per_imbclass_lake': {0: 56, 1: 400, 2: 56, 3: 400, 4: 56, 5: 56, 6: 400, 7: 400}, 'per_imbclass_test': {0: 243, 1: 243, 2: 243, 3: 243, 4: 243, 5: 243, 6: 243, 7: 243}}


# Targeted Selection Algorithm
1. Given: an initial labeled set of examples $E$, a large unlabeled dataset $U$, a target subset/slice $T$ on which we want to improve accuracy, and a loss function $\mathcal L$ for learning
2. Train model with loss $\mathcal L$ on labeled set $E$ and obtain parameters $\theta_E$
3. Compute the gradients $\{\nabla_{\theta_E} \mathcal L(x_i, y_i), i \in U\}$ (using hypothesized labels) and $\{\nabla_{\theta_E} \mathcal L(x_i, y_i), i \in T\}$. 
(This notebook uses gradients for representation. However, any other representation can be used. Trust also supports using features via the API.)
4. Compute the similarity kernels $S$ (this includes kernel of the elements within $U$, within $T$ and between $U$ and $T$) and define a submodular function $f$ and diversity function $g$
5. Compute subset $\hat{A}$ by maximizing the SMI function: $\hat{A} \gets \arg\max_{A \subseteq U, |A|\leq k} I_f(A;T) + \gamma g(A)$
6. Obtain the labels of the elements in $\hat{A}$: $L(\hat{A})$
7. Train a model on the combined labeled set $E \cup L(\hat{A})$

In [5]:
def _evaluate_split(model, data_loader, criterion, device, collect=True):
    """Evaluate `model` on every batch of `data_loader` with gradients disabled.

    Returns a 5-tuple: the summed per-batch loss, the number of correct
    predictions, the number of samples seen, the list of predicted labels and
    the list of per-sample correctness flags. The two lists stay empty when
    `collect` is False.
    """
    total_loss, n_correct, n_seen = 0, 0, 0
    predictions, classifications = [], []
    model.eval()
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device, non_blocking=True)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            hits = predicted.eq(targets)
            n_seen += targets.size(0)
            n_correct += hits.sum().item()
            if collect:
                predictions += list(predicted.cpu().numpy())
                classifications += list(hits.cpu().numpy())
    return total_loss, n_correct, n_seen, predictions, classifications


def run_targeted_selection(dataset_name, datadir, feature, model_name, budget, split_cfg, learning_rate, run,
                device, computeErrorLog, strategy="SIM", sf=""):
    """Run the targeted-selection experiment: train, select from the lake, retrain, repeat.

    Round 0 trains the initial model (or loads it from `initModelPath` if that
    file exists). Every later round selects `budget` samples from the unlabeled
    lake with the chosen strategy, moves them into the training set, and trains
    a freshly initialised model on the augmented set.

    Args:
        dataset_name: dataset identifier passed to `load_biodataset_custom`.
        datadir: directory holding the data.
        feature: experiment setting; "classimb" enables query-set construction
            for the SMI strategies.
        model_name: architecture name understood by `create_model`.
        budget: number of lake samples to add per selection round.
        split_cfg: class-imbalance split configuration.
        learning_rate: SGD learning rate.
        run: run identifier, only used to build the results directory name.
        device: torch device string.
        computeErrorLog: whether to compute and print per-class errors before
            each selection and at the end.
        strategy: "SIM", "AL" or "random".
        sf: selection function name (e.g. 'fl1mi', 'fl2mi', 'gcmi', 'badge', 'us').

    Raises:
        ValueError: for an unsupported strategy / selection function, or when a
            misclassification-based query set is requested while
            `computeErrorLog` is False.

    Reads the module-level `initModelPath`, `embedding_type`, `miscls` and
    `visualize_tsne`. Returns None; progress is reported through `print`.
    """
    #load the dataset in the class imbalance setting
    train_set, val_set, test_set, lake_set, sel_cls_idx, num_cls = load_biodataset_custom(datadir, dataset_name, feature, split_cfg, False, False)
    print("Indices of randomly selected classes for imbalance: ", sel_cls_idx)

    #Set batch size for train, validation and test datasets
    trn_batch_size = 20
    val_batch_size = 10
    tst_batch_size = 100

    #Create dataloaders
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=trn_batch_size,
                                              shuffle=True, pin_memory=True)
    valloader = torch.utils.data.DataLoader(val_set, batch_size=val_batch_size,
                                            shuffle=False, pin_memory=True)
    tstloader = torch.utils.data.DataLoader(test_set, batch_size=tst_batch_size,
                                            shuffle=False, pin_memory=True)
    # Keeps the true labels of the lake so selected samples can be "labeled".
    true_lake_set = copy.deepcopy(lake_set)
    # Budget for subset selection
    bud = budget

    # Round 0 trains the initial model; each of the remaining rounds performs
    # one selection followed by retraining.
    num_rounds = 10
    timing = np.zeros(num_rounds)
    val_acc = np.zeros(num_rounds)
    full_trn_acc = np.zeros(num_rounds)
    tst_acc = np.zeros(num_rounds)
    final_val_predictions = []
    final_val_classifications = []
    final_tst_predictions = []
    final_tst_classifications = []
    csvlog = []
    val_csvlog = []
    val_class_err_idxs = None
    # Results directory name (results are only written if saving is enabled by hand).
    all_logs_dir = './results/' + dataset_name  + '/' + feature + '/'+  sf + '/' + str(bud) + '/' + str(run)
    print("Saving results to: ", all_logs_dir)

    #Create a dictionary for storing results and the experimental setting
    res_dict = {"dataset":dataset_name,  # was the global `data_name`, ignoring the argument
                "feature":feature,
                "sel_func":sf,
                "sel_budget":budget,
                "num_selections":num_rounds-1,
                "model":model_name,
                "learning_rate":learning_rate,
                "setting":split_cfg,
                "all_class_acc":None,
                "test_acc":[],
                "sel_per_cls":[],
                "sel_cls_idx":sel_cls_idx}

    # Model Creation. The same embedding type drives both the architecture and
    # the selection strategy so the two cannot disagree.
    model = create_model(model_name, num_cls, device, embedding_type)
    strategy_args = {'batch_size': 20, 'device':device, 'embedding_type':embedding_type, 'keep_embedding':True}
    unlabeled_lake_set = LabeledToUnlabeledDataset(lake_set)

    if strategy == "AL":
        # NOTE(review): GLISTER, GradMatchActive, CoreSet, LeastConfidence and
        # MarginSampling are not imported explicitly in this notebook -- confirm
        # they are in scope before using these selection functions.
        if sf == "badge":
            strategy_sel = BADGE(train_set, unlabeled_lake_set, model, num_cls, strategy_args)
        elif sf == "us":
            strategy_sel = EntropySampling(train_set, unlabeled_lake_set, model, num_cls, strategy_args)
        elif sf == "glister" or sf == "glister-tss":
            strategy_sel = GLISTER(train_set, unlabeled_lake_set, model, num_cls, strategy_args, val_set, typeOf='rand', lam=0.1)
        elif sf == "gradmatch-tss":
            strategy_sel = GradMatchActive(train_set, unlabeled_lake_set, model, num_cls, strategy_args, val_set)
        elif sf == "coreset":
            strategy_sel = CoreSet(train_set, unlabeled_lake_set, model, num_cls, strategy_args)
        elif sf == "leastconf":
            strategy_sel = LeastConfidence(train_set, unlabeled_lake_set, model, num_cls, strategy_args)
        elif sf == "margin":
            strategy_sel = MarginSampling(train_set, unlabeled_lake_set, model, num_cls, strategy_args)
        else:
            raise ValueError("Unsupported AL selection function: %r" % (sf,))
    elif strategy == "SIM":
        strategy_args['smi_function'] = sf
        strategy_sel = SMI(train_set, unlabeled_lake_set, val_set, model, num_cls, strategy_args)
    elif strategy == "random":
        strategy_sel = RandomSampling(train_set, unlabeled_lake_set, model, num_cls, strategy_args)
    else:
        # Fail fast instead of hitting a NameError on `strategy_sel` after the
        # initial model has already been trained.
        raise ValueError("Unsupported strategy: %r (expected 'SIM', 'AL' or 'random')" % (strategy,))

    # Does this configuration steer the selection with a query set?
    needs_query = ((strategy == "SIM" and sf.endswith("mi") and feature == "classimb")
                   or (strategy == "AL" and sf in ("glister-tss", "gradmatch-tss")))
    if needs_query and miscls and not computeErrorLog:
        raise ValueError("computeErrorLog must be True when miscls is enabled: "
                         "the query set is built from the per-class validation errors")

    # Loss Functions
    criterion, criterion_nored = loss_function()

    # Getting the optimizer
    optimizer = optimizer_without_scheduler(model, learning_rate)

    for i in range(num_rounds):
        if i == 0:
            print("Initial training epoch")
            if os.path.exists(initModelPath): #Read the initial trained model if it exists
                model.load_state_dict(torch.load(initModelPath, map_location=device))
                print("Init model loaded from disk, skipping init training: ", initModelPath)
                val_loss, val_correct, val_total, final_val_predictions, final_val_classifications = \
                    _evaluate_split(model, valloader, criterion, device)
                tst_loss, tst_correct, tst_total, final_tst_predictions, final_tst_classifications = \
                    _evaluate_split(model, tstloader, criterion, device)
                val_acc[i] = val_correct / val_total
                tst_acc[i] = tst_correct / tst_total
                res_dict["test_acc"].append(tst_acc[i]*100)
                continue
        else:
            #Remove true labels from the unlabeled dataset, the hypothesized labels are computed when select is called
            unlabeled_lake_set = LabeledToUnlabeledDataset(lake_set)
            strategy_sel.update_data(train_set, unlabeled_lake_set)
            #compute the error log before every selection
            if computeErrorLog:
                tst_err_log, val_err_log, val_class_err_idxs = find_err_per_class(test_set, val_set, final_val_classifications, final_val_predictions, final_tst_classifications, final_tst_predictions, all_logs_dir, sf+"_"+str(bud))
                csvlog.append([100-x for x in tst_err_log])
                val_csvlog.append([100-x for x in val_err_log])

            if needs_query:
                # getQuerySet returns (query subset, its targets); only the
                # subset is handed to the strategy. The AL branch used to pass
                # the whole tuple.
                miscls_set, miscls_set_targets = getQuerySet(val_set, val_class_err_idxs, sel_cls_idx, miscls)
                strategy_sel.update_queries(miscls_set)
                if strategy == "AL":
                    print("reinit AL with targeted miscls samples")

            strategy_sel.update_model(model)
            subset = strategy_sel.select(budget)
            print("#### Selection Complete, Now re-training with augmented subset ####")
            # The plot needs the query embedding, which only exists when a
            # query set was supplied.
            if visualize_tsne and needs_query:
                tsne_plt = tsne_smi(strategy_sel.unlabeled_data_embedding.cpu(),
                                    lake_set.targets,
                                    strategy_sel.query_embedding.cpu(),
                                    miscls_set_targets,
                                    subset)
                print("Computed TSNE plot of the selection")
            lake_subset_idxs = subset #indices wrt to lake that need to be removed from the lake
            perClsSel = getPerClassSel(true_lake_set, lake_subset_idxs, num_cls)
            res_dict['sel_per_cls'].append(perClsSel)

            #augment the train_set with selected indices from the lake
            train_set, lake_set, true_lake_set, add_val_set = aug_train_subset(train_set, lake_set, true_lake_set, subset, lake_subset_idxs, budget, True) #aug train with random if budget is not filled
            print("After augmentation, size of train_set: ", len(train_set), " unlabeled set: ", len(lake_set), " val set: ", len(val_set))

            # Rebuild the train loader on the new split and start from a fresh model.
            trainloader = torch.utils.data.DataLoader(train_set, batch_size=trn_batch_size, shuffle=True, pin_memory=True)
            model = create_model(model_name, num_cls, device, strategy_args['embedding_type'])
            optimizer = optimizer_without_scheduler(model, learning_rate)

        #Start training: stop at 99% training accuracy or after 99 epochs
        start_time = time.time()
        num_ep = 1
        full_trn_loss = 0
        while full_trn_acc[i] < 0.99 and num_ep < 100:
            model.train()
            for inputs, targets in trainloader:
                inputs, targets = inputs.to(device), targets.to(device, non_blocking=True)
                # This will zero out the gradients for this batch.
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

            #Compute Train accuracy
            full_trn_loss, full_trn_correct, full_trn_total, _, _ = \
                _evaluate_split(model, trainloader, criterion, device, collect=False)
            full_trn_acc[i] = full_trn_correct / full_trn_total
            print("Selection Epoch ", i, " Training epoch [" , num_ep, "]" , " Training Acc: ", full_trn_acc[i], end="\r")
            num_ep += 1
            timing[i] = time.time() - start_time

        #Compute Val and Test accuracy of the model trained in this round
        val_loss, val_correct, val_total, final_val_predictions, final_val_classifications = \
            _evaluate_split(model, valloader, criterion, device)
        tst_loss, tst_correct, tst_total, final_tst_predictions, final_tst_classifications = \
            _evaluate_split(model, tstloader, criterion, device)
        val_acc[i] = val_correct / val_total
        tst_acc[i] = tst_correct / tst_total
        res_dict["test_acc"].append(tst_acc[i]*100)
        print('Epoch:', i + 1, 'FullTrn,TrainAcc,ValLoss,ValAcc,TstLoss,TstAcc,Time:', full_trn_loss, full_trn_acc[i], val_loss, val_acc[i], tst_loss, tst_acc[i], timing[i])
        print("Gain in accuracy: ",res_dict['test_acc'][i]-res_dict['test_acc'][i-1])
        if i == 0:
            print("Saving initial model")
            torch.save(model.state_dict(), initModelPath) #save initial train model if not present

    #Compute the statistics of the final model
    if computeErrorLog:
        print("**** Final Metrics after Targeted Learning ****")
        tst_err_log, val_err_log, val_class_err_idxs = find_err_per_class(test_set, val_set, final_val_classifications, final_val_predictions, final_tst_classifications, final_tst_predictions, all_logs_dir, sf+"_"+str(bud))
        csvlog.append([100-x for x in tst_err_log])
        val_csvlog.append([100-x for x in val_err_log])
        res_dict["all_class_acc"] = csvlog
        res_dict["all_val_class_acc"] = val_csvlog

    print("Total gain in accuracy: ",res_dict['test_acc'][-1]-res_dict['test_acc'][0])
    # To persist results, create `all_logs_dir` and dump `res_dict` as JSON there.
    

In [6]:
# Monotonic reference timestamp taken before the selection runs below.
# NOTE(review): nothing in the visible cells reads this value -- presumably it
# is used later to report total elapsed time; confirm or remove.
start_time = time.monotonic()

# Submodular Mutual Information (SMI)

We let $V$ denote the ground-set of $n$ data points $V = \{1, 2, 3,...,n \}$ and a set function $f:
 2^{V} \xrightarrow{} \Re$. Given a set of items $A, B \subseteq V$, the submodular mutual information (MI)[1,3] is defined as $I_f(A; B) = f(A) + f(B) - f(A \cup B)$. Intuitively, this measures the similarity between $B$ and $A$ and we refer to $B$ as the query set.

In [2], they extend MI to handle the case when the target can come from an auxiliary set $V^{\prime}$ different from the ground set $V$. For targeted data subset selection, $V$ is the source set of data instances and the target is a subset of data points (validation set or the specific set of examples of interest).
Let $\Omega  = V \cup V^{\prime}$. We define a set function $f: 2^{\Omega} \rightarrow \Re$. Although $f$ is defined on $\Omega$, the discrete optimization problem will only be defined on subsets $A \subseteq V$. To find an optimal subset given a query set $Q \subseteq V^{\prime}$, we can define $g_{Q}(A) = I_f(A; Q)$, $A \subseteq V$ and maximize the same.

# FL1MI

In the first variant of FL, we set the unlabeled dataset to be $V$. The SMI instantiation of FL1MI can be defined as:
\begin{align}
I_f(A;Q)=\sum_{i \in V}\min(\max_{j \in A}s_{ij}, \eta \max_{j \in Q}sq_{ij})
\end{align}

The first term in the min(.) of FL1MI models diversity, and the second term models query relevance. An increase in the value of $\eta$ causes the resulting summary to become more relevant to the query.

In [7]:
# Targeted selection with the FL1MI submodular mutual information function.
run_targeted_selection(
    dataset_name=data_name,
    datadir=datadir,
    feature=feature,
    model_name=model_name,
    budget=budget,
    split_cfg=split_cfg,
    learning_rate=learning_rate,
    run=run,
    device=device,
    computeErrorLog=computeClassErrorLog,
    strategy="SIM",
    sf='fl1mi',
)

bloodmnist Custom dataset stats: Train size:  228 Val size:  80 Lake size:  1824
Indices of randomly selected classes for imbalance:  [0, 1, 2, 3, 4, 5, 6, 7]
Saving results to:  ./results/bloodmnist/classimb/fl1mi/20/exp_4_2
Initial training epoch
Epoch: 1 FullTrn,TrainAcc,ValLoss,ValAcc,TstLoss,TstAcc,Time: 1.2726227566599846 0.9912280701754386 14.229120910167694 0.3 27.315001368522644 0.7293189125986553 28.732062816619873
Gain in accuracy:  0.0
Saving initial model
|   Class |   Val Accuracy |   Test Accuracy |
|---------+----------------+-----------------|
|       0 |             10 |           13.52 |
|       1 |            100 |           93.43 |
|       2 |             35 |           35.37 |
|       3 |            100 |           88.77 |
|       4 |             75 |           52.26 |
|       5 |              0 |            4.93 |
|       6 |            100 |           96.7  |
|       7 |            100 |          100    |
Testing accuracy is as follows - 
13.519999999999996
93.4

# FL2MI

In the V2 variant, we set $D$ to be $V \cup Q$. The SMI instantiation of FL2MI can be defined as:
\begin{align} \label{eq:FL2MI}
I_f(A;Q)=\sum_{i \in Q} \max_{j \in A} sq_{ij} + \eta\sum_{i \in A} \max_{j \in Q} sq_{ij}
\end{align}
FL2MI is very intuitive for query relevance as well. It measures the representation of data points that are the most relevant to the query set and vice versa. It can also be thought of as a bidirectional representation score.

In [8]:
# Launch one targeted-selection run. The shared experiment settings are passed
# positionally, followed by the two literals "SIM" and 'fl2mi'.
# NOTE(review): the last two arguments are presumably the strategy family and
# the SMI function to use -- confirm against run_targeted_selection's signature.
run_targeted_selection(
    data_name, datadir, feature, model_name, budget,
    split_cfg, learning_rate, run, device, computeClassErrorLog,
    "SIM", "fl2mi",
)

bloodmnist Custom dataset stats: Train size:  228 Val size:  80 Lake size:  1824
Indices of randomly selected classes for imbalance:  [0, 1, 2, 3, 4, 5, 6, 7]
Saving results to:  ./results/bloodmnist/classimb/fl2mi/20/exp_4_2
Initial training epoch
Init model loaded from disk, skipping init training:  ./bloodmnist_ResNet18_0.0003_4
|   Class |   Val Accuracy |   Test Accuracy |
|---------+----------------+-----------------|
|       0 |             10 |           13.52 |
|       1 |            100 |           93.43 |
|       2 |             35 |           35.37 |
|       3 |            100 |           88.77 |
|       4 |             70 |           52.26 |
|       5 |              0 |            4.93 |
|       6 |            100 |           96.7  |
|       7 |            100 |          100    |
Testing accuracy is as follows - 
13.519999999999996
93.43
35.370000000000005
88.77
52.26
4.930000000000007
96.7
100.0
Total misclassified examples from imbalanced classes (Size of query set):  57

# GCMI

The SMI instantiation of graph-cut (GCMI) is defined as:
\begin{align}
I_f(A;Q)=2\sum_{i \in A} \sum_{j \in Q} sq_{ij}
\end{align}
Since maximizing GCMI maximizes the joint pairwise sum with the query set, it will lead to a subset similar to the query set $Q$.

In [9]:
# Launch one targeted-selection run. The shared experiment settings are passed
# positionally, followed by the two literals "SIM" and 'gcmi'.
# NOTE(review): the last two arguments are presumably the strategy family and
# the SMI function to use -- confirm against run_targeted_selection's signature.
run_targeted_selection(
    data_name, datadir, feature, model_name, budget,
    split_cfg, learning_rate, run, device, computeClassErrorLog,
    "SIM", "gcmi",
)

bloodmnist Custom dataset stats: Train size:  228 Val size:  80 Lake size:  1824
Indices of randomly selected classes for imbalance:  [0, 1, 2, 3, 4, 5, 6, 7]
Saving results to:  ./results/bloodmnist/classimb/gcmi/20/exp_4_2
Initial training epoch
Init model loaded from disk, skipping init training:  ./bloodmnist_ResNet18_0.0003_4
|   Class |   Val Accuracy |   Test Accuracy |
|---------+----------------+-----------------|
|       0 |             15 |           13.52 |
|       1 |            100 |           93.43 |
|       2 |             40 |           35.37 |
|       3 |            100 |           88.77 |
|       4 |             70 |           52.26 |
|       5 |              5 |            4.93 |
|       6 |            100 |           96.7  |
|       7 |            100 |          100    |
Testing accuracy is as follows - 
13.519999999999996
93.43
35.370000000000005
88.77
52.26
4.930000000000007
96.7
100.0
Total misclassified examples from imbalanced classes (Size of query set):  54


# LOGDETMI

The SMI instantiation of LogDetMI can be defined as:
\begin{align}
I_f(A;Q)=\log\det(S_{A}) -\log\det(S_{A} - \eta^2 S_{A,Q}S_{Q}^{-1}S_{A,Q}^T)
\end{align}
$S_{A, B}$ denotes the cross-similarity matrix between the items in sets $A$ and $B$. The similarity matrix is constructed in such a way that the cross-similarity between $A$ and $Q$ is multiplied by $\eta$ to control the trade-off between query-relevance and diversity.

In [10]:
# Launch one targeted-selection run. The shared experiment settings are passed
# positionally, followed by the two literals "SIM" and 'logdetmi'.
# NOTE(review): the last two arguments are presumably the strategy family and
# the SMI function to use -- confirm against run_targeted_selection's signature.
run_targeted_selection(
    data_name, datadir, feature, model_name, budget,
    split_cfg, learning_rate, run, device, computeClassErrorLog,
    "SIM", "logdetmi",
)

bloodmnist Custom dataset stats: Train size:  228 Val size:  80 Lake size:  1824
Indices of randomly selected classes for imbalance:  [0, 1, 2, 3, 4, 5, 6, 7]
Saving results to:  ./results/bloodmnist/classimb/logdetmi/20/exp_4_2
Initial training epoch
Init model loaded from disk, skipping init training:  ./bloodmnist_ResNet18_0.0003_4
|   Class |   Val Accuracy |   Test Accuracy |
|---------+----------------+-----------------|
|       0 |             15 |           13.52 |
|       1 |            100 |           93.43 |
|       2 |             30 |           35.37 |
|       3 |            100 |           88.77 |
|       4 |             75 |           52.26 |
|       5 |              0 |            4.93 |
|       6 |            100 |           96.7  |
|       7 |            100 |          100    |
Testing accuracy is as follows - 
13.519999999999996
93.43
35.370000000000005
88.77
52.26
4.930000000000007
96.7
100.0
Total misclassified examples from imbalanced classes (Size of query set): 

# Random

In [11]:
# Launch one targeted-selection run. The shared experiment settings are passed
# positionally, followed by the literal "random" given twice.
# NOTE(review): the last two arguments are presumably the strategy family and
# the selection method, here a random baseline -- confirm against
# run_targeted_selection's signature.
run_targeted_selection(
    data_name, datadir, feature, model_name, budget,
    split_cfg, learning_rate, run, device, computeClassErrorLog,
    "random", "random",
)

bloodmnist Custom dataset stats: Train size:  228 Val size:  80 Lake size:  1824
Indices of randomly selected classes for imbalance:  [0, 1, 2, 3, 4, 5, 6, 7]
Saving results to:  ./results/bloodmnist/classimb/random/20/exp_4_2
Initial training epoch
Init model loaded from disk, skipping init training:  ./bloodmnist_ResNet18_0.0003_4
|   Class |   Val Accuracy |   Test Accuracy |
|---------+----------------+-----------------|
|       0 |             20 |           13.52 |
|       1 |            100 |           93.43 |
|       2 |             30 |           35.37 |
|       3 |            100 |           88.77 |
|       4 |             70 |           52.26 |
|       5 |              0 |            4.93 |
|       6 |            100 |           96.7  |
|       7 |            100 |          100    |
Testing accuracy is as follows - 
13.519999999999996
93.43
35.370000000000005
88.77
52.26
4.930000000000007
96.7
100.0
#### Selection Complete, Now re-training with augmented subset ####
After au

# US

In [12]:
# Launch one targeted-selection run. The shared experiment settings are passed
# positionally, followed by the two literals "AL" and 'us'.
# NOTE(review): "AL" presumably selects an active-learning baseline and 'us'
# the specific method -- confirm against run_targeted_selection's signature.
run_targeted_selection(
    data_name, datadir, feature, model_name, budget,
    split_cfg, learning_rate, run, device, computeClassErrorLog,
    "AL", "us",
)

bloodmnist Custom dataset stats: Train size:  228 Val size:  80 Lake size:  1824
Indices of randomly selected classes for imbalance:  [0, 1, 2, 3, 4, 5, 6, 7]
Saving results to:  ./results/bloodmnist/classimb/us/20/exp_4_2
Initial training epoch
Init model loaded from disk, skipping init training:  ./bloodmnist_ResNet18_0.0003_4
|   Class |   Val Accuracy |   Test Accuracy |
|---------+----------------+-----------------|
|       0 |             15 |           13.52 |
|       1 |            100 |           93.43 |
|       2 |             40 |           35.37 |
|       3 |            100 |           88.77 |
|       4 |             65 |           52.26 |
|       5 |              0 |            4.93 |
|       6 |            100 |           96.7  |
|       7 |            100 |          100    |
Testing accuracy is as follows - 
13.519999999999996
93.43
35.370000000000005
88.77
52.26
4.930000000000007
96.7
100.0
#### Selection Complete, Now re-training with augmented subset ####
After augmen

# BADGE

In [13]:
# Launch one targeted-selection run. The shared experiment settings are passed
# positionally, followed by the two literals "AL" and 'badge'.
# NOTE(review): "AL" presumably selects an active-learning baseline and 'badge'
# the specific method -- confirm against run_targeted_selection's signature.
run_targeted_selection(
    data_name, datadir, feature, model_name, budget,
    split_cfg, learning_rate, run, device, computeClassErrorLog,
    "AL", "badge",
)

bloodmnist Custom dataset stats: Train size:  228 Val size:  80 Lake size:  1824
Indices of randomly selected classes for imbalance:  [0, 1, 2, 3, 4, 5, 6, 7]
Saving results to:  ./results/bloodmnist/classimb/badge/20/exp_4_2
Initial training epoch
Init model loaded from disk, skipping init training:  ./bloodmnist_ResNet18_0.0003_4
|   Class |   Val Accuracy |   Test Accuracy |
|---------+----------------+-----------------|
|       0 |             15 |           13.52 |
|       1 |            100 |           93.43 |
|       2 |             35 |           35.37 |
|       3 |            100 |           88.77 |
|       4 |             70 |           52.26 |
|       5 |              0 |            4.93 |
|       6 |            100 |           96.7  |
|       7 |            100 |          100    |
Testing accuracy is as follows - 
13.519999999999996
93.43
35.370000000000005
88.77
52.26
4.930000000000007
96.7
100.0
#### Selection Complete, Now re-training with augmented subset ####
After aug

# References
[1] Rishabh Iyer, Ninad Khargonkar, Jeff Bilmes, and Himanshu Asnani. Submodular combinatorial information measures with applications in machine learning. arXiv preprint arXiv:2006.15412, 2020.


[2] Kaushal V, Kothawade S, Ramakrishnan G, Bilmes J, Iyer R. PRISM: A Unified Framework of Parameterized Submodular Information Measures for Targeted Data Subset Selection and Summarization. arXiv preprint arXiv:2103.00128. 2021 Feb 27.


[3] Anupam Gupta and Roie Levin. The online submodular cover problem. In ACM-SIAM Symposium on Discrete Algorithms, 2020.

In [14]:
# Read the monotonic clock again and report the time elapsed since start_time
# (set in an earlier cell), converted from seconds to minutes.
end_time = time.monotonic()
print(
    'Time to complete all strategies is -',
    (end_time - start_time) / 60,
    'mins',
)

Time to complete all strategies is - 59.71914701256513 mins
