Import Statements

In [5]:
# PyTorch
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import os

# Baseline Model and Dataset

os.system("pwd")
from dotPyfiles.LeNet import *
from dotPyfiles.class_imbalance_dataset import *
import matplotlib.pyplot as plt
# Imported explicitly: later cells use `np`, which would otherwise only exist
# if one of the star imports above happened to leak it into the namespace.
import numpy as np
from tqdm import tqdm
os.system("pwd")


import matplotlib
matplotlib.rcParams.update({'errorbar.capsize': 5})


Getting Data

In [None]:
# Shared hyper-parameters: SGD learning rate, mini-batch size, number of training steps.
H_PARAMS = dict(
    lr=1e-3,
    batch_size=100,
    num_iterations=2000,
)

# Binary MNIST task restricted to the digits 8 and 0. The training loader is built
# with proportion=0.8 and the test loader with proportion=0.5 (presumably the share
# of one of the two classes -- confirm in class_imbalance_dataset).
data_loader = get_mnist_loader(
    H_PARAMS['batch_size'], classes=[8, 0], proportion=0.8, mode="train"
)
test_loader = get_mnist_loader(
    H_PARAMS['batch_size'], classes=[8, 0], proportion=0.5, mode="test"
)

In [30]:
def gpu_avail(x, requires_grad=True):
    """Wrap a tensor in a Variable, moving it to the GPU first when CUDA is available.

    Args:
        x: tensor to wrap.
        requires_grad: forwarded to ``Variable``; defaults to True.

    Returns:
        ``Variable(x, requires_grad=requires_grad)``, built from the CUDA copy of
        ``x`` if ``torch.cuda.is_available()`` and from ``x`` itself otherwise.
    """
    tensor = x.cuda() if torch.cuda.is_available() else x
    return Variable(tensor, requires_grad=requires_grad)


# Validation images and labels held by the training loader's dataset, wrapped once
# (without gradient tracking) so later cells can reuse them as notebook globals.
loaded_imgs, loaded_lbls = (
    gpu_avail(held_out, requires_grad=False)
    for held_out in (data_loader.dataset.data_val, data_loader.dataset.labels_val)
)

In [20]:
def get_Lemodel():
    """Create a fresh single-output LeNet together with its SGD optimizer.

    The network is moved to the GPU when CUDA is available. The optimizer is
    plain ``torch.optim.SGD`` over ``model.params()`` with the learning rate
    taken from ``H_PARAMS["lr"]``.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    net = LeNet(n_out=1)
    on_gpu = torch.cuda.is_available()
    if on_gpu:
        net.cuda()

    sgd = torch.optim.SGD(net.params(), lr=H_PARAMS["lr"])
    return net, sgd

Regular Training Loop - Baseline

In [None]:
import time

# Baseline run: plain SGD on the training loader, with no example re-weighting.
total_loss = []            # bias-corrected running loss, one value per iteration
step_size_for_fig = 100    # evaluate on test_loader and redraw the figure every N iterations

model, optimizer = get_Lemodel()
model_l = 0                # exponential moving average of the batch loss (uncorrected)
a = 0.9                    # smoothing factor of that moving average
accuracy_log = []          # one (1, 2) array [iteration, accuracy] per evaluation

iters = H_PARAMS['num_iterations']
start_time = time.time()
for i in tqdm(range(iters)):
    model.train()
    # NOTE(review): a new iterator is created on every step, so each step takes the
    # first batch of a fresh pass over the loader -- presumably the loader shuffles; confirm.
    img, labels = next(iter(data_loader))

    img = gpu_avail(img, requires_grad=False)
    labels = gpu_avail(labels, requires_grad=False)

    model_output = model(img)
    cel = F.binary_cross_entropy_with_logits(model_output, labels)

    optimizer.zero_grad()
    cel.backward()
    optimizer.step()

    model_l = a * model_l + (1 - a) * cel.item()

    # Dividing by (1 - a^(i+1)) removes the bias caused by starting the average at 0.
    # The value is named smoothed_loss so the built-in eval() is not shadowed.
    smoothed_loss = model_l / (1 - a**(i+1))
    total_loss.append(smoothed_loss)

    if i % step_size_for_fig == 0:
        model.eval()

        # Collect per-sample correctness for THIS evaluation only; the list must be
        # empty at the start of each pass or earlier evaluations leak into the mean.
        preds = []
        with torch.no_grad():  # inference only: skip building autograd graphs
            for test_img, test_label in test_loader:
                test_img = gpu_avail(test_img, requires_grad=False)
                test_label = gpu_avail(test_label, requires_grad=False)

                output = model(test_img)
                predicted = torch.sigmoid(output) > 0.5
                float_prediction = (predicted.int() == test_label.int()).float()
                preds.append(float_prediction)

        # .item() yields a Python float, which is safe to store in a NumPy array
        # regardless of the device the tensor lives on.
        accuracy = torch.cat(preds, dim=0).mean().item()
        accuracy_log.append(np.array([i, accuracy])[None])

        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        ax1, ax2 = axes.ravel()

        ax1.plot(total_loss, label='total_loss')
        ax1.set_ylabel("Losses")
        ax1.set_xlabel("Iteration")
        ax1.legend()

        # Kept as a notebook global: the following cell re-plots this array.
        history_of_accuracies = np.concatenate(accuracy_log, axis=0)
        ax2.plot(history_of_accuracies[:, 0], history_of_accuracies[:, 1])
        ax2.set_ylabel('Accuracy')
        ax2.set_xlabel('Iteration')
        plt.show()
print(f"Diff = {time.time() - start_time}")

In [None]:
# Test accuracy of the baseline run: column 0 holds the iteration, column 1 the accuracy.
plt.plot(*history_of_accuracies[:, :2].T)
plt.ylabel('Accuracy')
plt.title("Regular Training - 80% Imbalance")
plt.xlabel('Iteration')

In [None]:

# Gradient of Swish, x * sigmoid(w * x), versus gradient of ReLU(w * x), both taken
# with respect to x on 1000 evenly spaced points in [-10, 10] with w = 0.1.
x = torch.linspace(-10, 10, 1000, requires_grad=True)
w = torch.tensor([0.1], requires_grad=True)

# Swish: back-propagate the summed output, copy x.grad, then clear it for the next pass.
swish_output = x * torch.sigmoid(w * x)
torch.sum(swish_output).backward()
swish_grad = x.grad.clone()
x.grad.zero_()

# ReLU: same procedure on relu(w * x).
relu_output = torch.relu(w * x)
torch.sum(relu_output).backward()
relu_grad = x.grad.clone()
x.grad.zero_()

# The cloned gradients do not track autograd history, so they convert to NumPy directly.
plt.plot(x.detach().numpy(), swish_grad.numpy(), label='Swish')
plt.plot(x.detach().numpy(), relu_grad.numpy(), label='ReLU')
plt.title("Vanishing Gradient Problem")
plt.xlabel('Input')
plt.ylabel('Gradient')
plt.legend()
plt.show()


Original Paper Algorithm - Re-weighted Training

In [23]:
def reAlgo():
    """Train a fresh LeNet with per-example weights obtained from a meta-gradient.

    On every iteration a copy of the current model (``meta_model``) computes the
    per-example training loss weighted by ``epsilon`` (initialised to zeros), is
    updated through ``update_params`` with the resulting gradients, and is then
    evaluated on the validation tensors. The negated, clamped and normalised
    gradient of that validation loss with respect to ``epsilon`` becomes the
    per-example weight vector used for the real optimizer step.

    Reads the notebook-level globals ``H_PARAMS``, ``data_loader``,
    ``test_loader``, ``loaded_imgs`` and ``loaded_lbls``.

    Returns:
        Mean of the test accuracies in rows ``[-6:-1]`` of this run's accuracy
        log, or NaN when no evaluation was recorded (zero iterations).
    """
    model, optimizer = get_Lemodel()
    meta_losses_clean, total_loss, history_of_accuracies = [], [], []
    step_size_for_fig = 100
    iters = H_PARAMS['num_iterations']
    a = 0.9                              # smoothing factor for both running losses
    meta_l, model_l, MIN = 0, 0, 0       # MIN is the lower clamp bound for the weights

    for i in tqdm(range(iters)):
        model.train()
        img, labels = next(iter(data_loader))

        # Throw-away copy of the current model used only for the look-ahead step.
        meta_model = LeNet(n_out=1)
        meta_model.load_state_dict(model.state_dict())
        if torch.cuda.is_available():
            meta_model.cuda()

        img = gpu_avail(img, requires_grad=False)
        labels = gpu_avail(labels, requires_grad=False)

        # Per-example loss; reduction='none' replaces the deprecated reduce=False.
        meta_model_out = meta_model(img)
        cel = F.binary_cross_entropy_with_logits(meta_model_out, labels, reduction='none')

        # epsilon tracks gradients (gpu_avail default) so it can be differentiated below.
        epsilon = gpu_avail(torch.zeros(cel.size()))
        sum_MM_out = torch.sum(cel * epsilon)

        meta_model.zero_grad()

        # create_graph=True keeps the graph so the validation loss stays differentiable
        # with respect to epsilon after the parameter update.
        grads = torch.autograd.grad(sum_MM_out, (meta_model.params()), create_graph=True)
        meta_model.update_params(H_PARAMS['lr'], source_params=grads)

        meta_model_prime = meta_model(loaded_imgs)
        cel_prime = F.binary_cross_entropy_with_logits(meta_model_prime, loaded_lbls)
        epsilon_prime = torch.autograd.grad(cel_prime, epsilon)[0]

        # Keep only examples whose up-weighting lowers the validation loss, then
        # normalise the weights to sum to one (skipped when all weights are zero).
        w = torch.clamp(-epsilon_prime, min=MIN)
        LNorm = torch.sum(w)
        if LNorm != 0:
            w = w / LNorm

        model_out = model(img)
        cost = F.binary_cross_entropy_with_logits(model_out, labels, reduction='none')
        sum_MM_out = torch.sum(cost * w)

        optimizer.zero_grad()
        sum_MM_out.backward()
        optimizer.step()

        # Bias-corrected exponential moving averages of both losses.
        meta_l = a * meta_l + (1 - a) * cel_prime.item()
        meta_losses_clean.append(meta_l / (1 - a**(i+1)))

        model_l = a * model_l + (1 - a) * sum_MM_out.item()
        total_loss.append(model_l / (1 - a**(i+1)))

        if i % step_size_for_fig == 0:
            model.eval()

            preds = []
            with torch.no_grad():  # inference only: skip building autograd graphs
                for test_img, test_label in test_loader:
                    test_img = gpu_avail(test_img, requires_grad=False)
                    test_label = gpu_avail(test_label, requires_grad=False)

                    output = model(test_img)
                    predicted = torch.sigmoid(output) > 0.5
                    float_prediction = (predicted.int() == test_label.int()).float()
                    preds.append(float_prediction)

            # .item() gives a Python float that NumPy can store on any device.
            accuracy = torch.cat(preds, dim=0).mean().item()
            history_of_accuracies.append(np.array([i, accuracy])[None])

            fig, axes = plt.subplots(1, 2, figsize=(10, 4))
            ax1, ax2 = axes.ravel()

            ax1.plot(meta_losses_clean, label='meta_losses_clean')
            ax1.plot(total_loss, label='total_loss')
            ax1.set_ylabel("Losses")
            ax1.set_xlabel("Iteration")
            ax1.legend()

            # Plot THIS run's accuracy log (not the baseline cell's global accuracy_log).
            logs = np.concatenate(history_of_accuracies, axis=0)
            ax2.plot(logs[:, 0], logs[:, 1])
            ax2.set_title('Re-weighted Training - 80% Imbalance')
            ax2.set_ylabel('Accuracy')
            ax2.set_xlabel('Iteration')
            # NOTE(review): figures are neither shown nor closed here, so they stay
            # open until the calling cell finishes -- confirm this is intended.
            # plt.show()

    if not history_of_accuracies:
        return float("nan")

    logs = np.concatenate(history_of_accuracies, axis=0)
    # NOTE(review): the slice [-6:-1] leaves out the most recent evaluation -- confirm intended.
    mean = np.mean(logs[-6:-1, 1])
    return mean

In [None]:
import time

# Sweep over training-set proportions: for each one rebuild the loader, refresh the
# validation tensors that reAlgo() reads, and record the accuracy it returns.
num_repeats = 1
# Earlier sweeps: [0.9, 0.95, 0.98, 0.99, 0.995], [0.8], [0.9, 0.95, 0.99]
proportions = [(0.90 + 0.95)/2, (0.95+0.99)/2]
history_of_accuracies = {}   # proportion -> list of accuracies, one per repeat
start_time = time.time()
batches = H_PARAMS['batch_size']
for prop in proportions:
    # reAlgo() takes no arguments; it reads data_loader, loaded_imgs and loaded_lbls
    # from the notebook namespace, so those exact names must be rebound here.
    data_loader = get_mnist_loader(batches, classes=[8, 0], proportion=prop, mode="train")
    loaded_imgs = gpu_avail(data_loader.dataset.data_val, requires_grad=False)
    loaded_lbls = gpu_avail(data_loader.dataset.labels_val, requires_grad=False)
    # Aliases kept for any later cell that still refers to the old names.
    imgs, lbls = loaded_imgs, loaded_lbls

    for k in range(num_repeats):
        accuracy = reAlgo()
        history_of_accuracies.setdefault(prop, []).append(accuracy)

plt.figure(figsize=(10,4))
for prop in proportions:
    accuracies = history_of_accuracies[prop]
    plt.scatter([prop] * len(accuracies), accuracies)

# plot the trend line with error bars that correspond to standard deviation
# Statistics are gathered in the order of `proportions`, so they stay aligned with
# the x values passed to errorbar whatever order the list is written in.
accuracies_mean = np.array([np.mean(history_of_accuracies[prop]) for prop in proportions])
accuracies_std = np.array([np.std(history_of_accuracies[prop]) for prop in proportions])
print(f"Diff = {time.time() - start_time}")
plt.errorbar(proportions, accuracies_mean, yerr=accuracies_std)
plt.title('Accuracies with imbalance proportions')
plt.xlabel('Imbalance Proportions')
plt.ylabel('Accuracies')
plt.show()

In [None]:
# Mean accuracy per proportion from the sweep, with the standard deviation as error bars.
plt.errorbar(x=proportions, y=accuracies_mean, yerr=accuracies_std)
plt.xlabel('Imbalance')
plt.ylabel('Accuracy')
plt.title('Imbalance Proportion Accuracy')
plt.show()

In [None]:
# Hard-coded accuracy per proportion (presumably copied from earlier sweep runs -- verify).
x = [0.9, 0.925, 0.95, 0.97, 0.99]
acc = [0.74, 0.52, 0.92, 0.76, 0.61]
plt.scatter(x=x, y=acc)
plt.xlabel('Imbalance Proportions')
plt.title('Accuracies with imbalance proportions')
plt.ylabel('Accuracies')