In [None]:
from IPython import display

import torch
import torch.nn as nn
from torch.nn import init
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision.datasets as dset

import numpy as np

import time
from visdom import Visdom

from lib.VisdomWrapper import *
from lib.DataManager import *
from lib.GANs import *
from lib.DataCreationWrapper import *

In [None]:
# Experiment switches. The prefix picks the training set (bal = balanced MNIST,
# zer = "low zero" set, nin = "high nine" set); the suffix picks the data
# variant (raw = real data, syn = generator output, aug = real + generated).
# Set a flag to True to run the matching load / train / evaluate cells below.
bal_raw, bal_syn, bal_aug = False, False, False
zer_raw, zer_syn, zer_aug = False, False, False
nin_raw, nin_syn, nin_aug = False, False, False

In [None]:
# Global configuration shared by every cell below.
torch.manual_seed(2)  # fixed seed so repeated runs are comparable
batch_size = 512
num_epochs = 30
img_width = 28 #hardcoded
n_features = img_width**2      # flattened pixel count of one image
n_noise_features = 100         # size of the generator's noise input
n_classes = 10

# Objects read as globals by train_classifier / plot_frequency.
loss = nn.CrossEntropyLoss()
vis = VisdomController()

# Training data. The lists passed to get_unbalanced_mnist hold one value per
# digit class -- presumably the fraction of that class to keep; confirm in
# lib.DataManager.
mnist = dset.MNIST('input', train=True, download=True, transform=T.ToTensor())
mnist_low_zero = get_unbalanced_mnist([.01, .1, .1, .1, .1, .1, .1, .1, .1, .1], batch_size=batch_size)
mnist_high_nine = get_unbalanced_mnist([.1, .1, .1, .1, .1, .1, .1, .1, .1, 1], batch_size=batch_size)

# Held-out test split. Batch it explicitly: DataLoader defaults to
# batch_size=1, which would mean one forward pass per test image.
mnist_test = dset.MNIST('input', train=False, download=True, transform=T.ToTensor())
balanced_test = DataLoader(mnist_test, batch_size=batch_size)

In [None]:
def calc_balance(labels, minlength=0):
    """Return, per class, how many samples are missing relative to the largest class.

    Args:
        labels: 1-D tensor of non-negative integer class labels.
        minlength: minimum number of classes to report; classes that never
            occur in ``labels`` are counted as zero. Defaults to 0, i.e. the
            result has ``labels.max() + 1`` entries.

    Returns:
        A float32 numpy array whose i-th entry is
        ``(count of the most frequent class) - (count of class i)``.
        An empty array is returned when there are no classes to count.
    """
    cnts = torch.bincount(labels, minlength=minlength)
    if cnts.numel() == 0:
        # torch.max raises on an empty tensor; nothing to balance here.
        return np.zeros(0, dtype=np.float32)
    # Broadcast the scalar maximum instead of building a CPU ones-vector, so
    # the arithmetic stays on whatever device the labels live on.
    deficit = cnts.max() - cnts
    return deficit.to(torch.float32).cpu().numpy()

In [None]:
#Load Data

def _load_generator(weights_path):
    """Build a GeneratorNetwork, load its weights from ``weights_path`` and put it in eval mode."""
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    generator = GeneratorNetwork(n_noise_features, n_features, n_classes)
    generator.load_state_dict(torch.load(weights_path))
    generator.eval()
    return generator


def _make_image_loader(images, labels):
    """Reshape ``images`` to (N, 1, img_width, img_width) and wrap them with ``labels`` in a shuffled DataLoader."""
    images = images.reshape(len(labels), 1, img_width, img_width)
    return DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=True)


# Building datasets needs no gradients, so generator sampling runs under no_grad.
with torch.no_grad():
    if bal_raw or bal_aug:
        # Materialise the whole MNIST training set as tensors, then re-batch it.
        bal_train = DataLoader(mnist, batch_size=1000)
        X, Y = data_loader_to_tensor(bal_train)
        bal_train = DataLoader(TensorDataset(X, Y), batch_size=batch_size, shuffle=True)
        X, Y = None, None  # drop the extra references

    if bal_syn or bal_aug:
        # Forward slashes work on every platform; "models\gen..." relied on an
        # invalid escape sequence and only resolved on Windows.
        bal_gen = _load_generator("models/gen_nn_bal")
        if bal_syn:
            synth_data, synth_labels = synthesize_data_of_each_label(bal_gen, gaussian_noise, 6000 * np.ones(10))
            synth_bal_data_loader = _make_image_loader(synth_data, synth_labels)
            synth_data, synth_labels = None, None
        if bal_aug:
            real_data, real_labels = bal_train.dataset.tensors
            # Generate as many samples per class as that class is short of the largest one.
            synth_data, synth_labels = synthesize_data_of_each_label(bal_gen, gaussian_noise, calc_balance(real_labels) * np.ones(10))
            # torch.cat takes a *sequence* of tensors. Bring both parts to the
            # same image shape, dtype and device first.
            # NOTE(review): assumes the generator output holds img_width**2
            # values per sample -- confirm against lib.DataCreationWrapper.
            synth_data = synth_data.reshape(len(synth_labels), 1, img_width, img_width).to(real_data)
            real_data = real_data.reshape(len(real_labels), 1, img_width, img_width)
            synth_data = torch.cat((synth_data, real_data))
            synth_labels = torch.cat((synth_labels.to(real_labels), real_labels))
            aug_bal_data_loader = _make_image_loader(synth_data, synth_labels)
            synth_data, synth_labels = None, None
            real_data, real_labels = None, None

    # NOTE: zer_aug / nin_aug only load the generator; no augmented loader is
    # built for them in this notebook.
    if zer_syn or zer_aug:
        zer_gen = _load_generator("models/gen_nn_low_zero")
        if zer_syn:
            synth_data, synth_labels = synthesize_data_of_each_label(zer_gen, gaussian_noise, 6000 * np.ones(10))
            syn_low_zero_data_loader = _make_image_loader(synth_data, synth_labels)
            synth_data, synth_labels = None, None

    if nin_syn or nin_aug:
        nin_gen = _load_generator("models/gen_nn_high_nine")
        # Was guarded by zer_syn and sampled from zer_gen (copy-paste slip).
        if nin_syn:
            synth_data, synth_labels = synthesize_data_of_each_label(nin_gen, gaussian_noise, 6000 * np.ones(10))
            syn_nin_data_loader = _make_image_loader(synth_data, synth_labels)
            synth_data, synth_labels = None, None

In [None]:
def plot_frequency(data_loader, title, num_classes=10):
    """Count how often each class label occurs in ``data_loader`` and plot it as a bar chart.

    Args:
        data_loader: iterable yielding ``(inputs, labels)`` batches, where
            ``labels`` is a 1-D tensor of integer class ids.
        title: title passed on to the Visdom bar plot.
        num_classes: number of classes on the x axis (default 10).
    """
    x = np.arange(num_classes)
    cnts = torch.zeros(num_classes)
    for _, y in data_loader:
        # minlength keeps every per-batch histogram at num_classes entries;
        # without it a batch lacking the highest classes yields a shorter
        # tensor and the in-place add fails.
        cnts += torch.bincount(y.cpu(), minlength=num_classes)
    vis.CreateStaticBarPlot(cnts, x, title, "Class", "Count", "FreqPlot")

In [None]:
def build_classifier():
    """Create a fresh, untrained CNN classifier for 1x28x28 images.

    Two conv(5x5) -> LeakyReLU -> max-pool(2x2) stages reduce the image to
    64 feature maps of 4x4, which are flattened and passed through two
    fully connected layers ending in 10 class scores.

    Returns:
        An ``nn.Sequential`` holding the ten layers in order.
    """
    flat_dim = 4 * 4 * 64  # 64 channels of 4x4 after the second pooling
    conv_stages = [
        nn.Conv2d(1, 32, kernel_size=5, stride=1),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 64, kernel_size=5, stride=1),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    dense_head = [
        nn.Flatten(),
        nn.Linear(flat_dim, flat_dim, bias=True),
        nn.LeakyReLU(0.01),
        nn.Linear(flat_dim, 10, bias=True),
    ]
    return nn.Sequential(*conv_stages, *dense_head)

In [None]:
def get_optimizer(model):
    """Return an Adam optimizer over all parameters of ``model`` with learning rate 1e-5."""
    return optim.Adam(model.parameters(), lr=1e-5)

In [None]:
def train_classifier(classifier, optimizer, data, key="Loss"):
    """Train ``classifier`` for ``num_epochs`` epochs and plot the loss once per epoch.

    Reads the notebook globals ``num_epochs``, ``batch_size``, ``loss`` and
    ``vis``.

    Args:
        classifier: model to optimise; batches are moved to the device its
            parameters live on.
        optimizer: optimizer already bound to ``classifier``'s parameters.
        data: iterable of ``(inputs, labels)`` batches with a ``len``.
        key: name of the loss curve passed to ``vis.PlotLoss``.

    Raises:
        ValueError: if an epoch contains no batch of exactly ``batch_size``
            samples, so no optimisation step was taken.
    """
    # Follow the model instead of hard-coding .cuda(); identical for models
    # that were moved to the GPU, and also works for CPU models.
    device = next(classifier.parameters()).device
    for epoch in range(num_epochs):
        out = None
        n_batch = -1
        for n_batch, (x, y) in enumerate(data):
            # Incomplete batches (typically the last one) are skipped.
            if len(x) != batch_size:
                continue
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            scores = classifier(x)
            out = loss(scores, y)
            out.backward()
            optimizer.step()
        if out is None:
            # Previously this surfaced as an unbound-variable error below.
            raise ValueError(
                "no batch of size {} found in the training data".format(batch_size))
        # Report the loss of the last processed batch of this epoch.
        epoch_loss = out.item()
        display.clear_output(True)
        print("Epoch {}, {} / {}".format(epoch, n_batch, len(data)))
        print("Loss: ", epoch_loss)
        vis.loss_axis = epoch
        vis.PlotLoss(key, epoch_loss)
            

In [None]:
def test(model, device, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()


    print('Accuracy: {}/{} ({:.0f}%)\n'.format(correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
# Train a classifier on the real MNIST training set (bal_train).
if bal_raw:  # the colon was missing, which made this cell a SyntaxError
    balanced_net = build_classifier().cuda()
    balanced_net.train()
    train_classifier(balanced_net, get_optimizer(balanced_net), bal_train, key="Balanced Raw")

In [None]:
# Train a classifier on the real training set topped up with generated samples.
if bal_aug:
    # .cuda() and .train() both return the module, so they can be chained.
    balanced_aug_net = build_classifier().cuda().train()
    train_classifier(
        balanced_aug_net,
        get_optimizer(balanced_aug_net),
        aug_bal_data_loader,
        key="Balanced Aug",
    )

In [None]:
# Train a classifier purely on samples drawn from the balanced generator.
if bal_syn:
    balanced_syn_net = build_classifier().cuda()
    balanced_syn_net.train()
    # The loader is created as `synth_bal_data_loader` in the load-data cell;
    # the previous name `syn_bal_data_loader` was never defined.
    train_classifier(balanced_syn_net, get_optimizer(balanced_syn_net), synth_bal_data_loader, key="Balanced Synth")

In [None]:
# Train a classifier on the real "low zero" data (mnist_low_zero).
if zer_raw:
    low_zero_net = build_classifier().cuda()
    low_zero_net.train()
    # This is the raw (non-augmented) run, so the curve is labelled "Raw";
    # it was previously plotted under "Low Zero Aug".
    train_classifier(low_zero_net, get_optimizer(low_zero_net), mnist_low_zero, key="Low Zero Raw")

In [None]:
# Train a classifier on the real "high nine" data (mnist_high_nine).
if nin_raw:
    # .cuda() and .train() both return the module, so they can be chained.
    high_nine_net = build_classifier().cuda().train()
    train_classifier(
        high_nine_net,
        get_optimizer(high_nine_net),
        mnist_high_nine,
        key="High Nine Raw",
    )

In [None]:
# Evaluate the balanced-raw classifier on the MNIST test split.
if bal_raw:  # the colon was missing, which made this cell a SyntaxError
    device = torch.device("cuda")
    balanced_net.eval()
    test(balanced_net, device, balanced_test)

In [None]:
# Evaluate the balanced-augmented classifier on the MNIST test split.
if bal_aug:
    balanced_aug_net.eval()
    device = torch.device("cuda")
    test(balanced_aug_net, device, balanced_test)

In [None]:
# Evaluate the low-zero raw classifier on the MNIST test split.
if zer_raw:
    low_zero_net.eval()
    device = torch.device("cuda")
    test(low_zero_net, device, balanced_test)

In [None]:
# Evaluate the high-nine raw classifier on the MNIST test split.
if nin_raw:
    high_nine_net.eval()
    device = torch.device("cuda")
    test(high_nine_net, device, balanced_test)

In [None]:
# Train a classifier purely on samples drawn from the low-zero generator.
if zer_syn:
    # .cuda() and .train() both return the module, so they can be chained.
    low_zero_syn_net = build_classifier().cuda().train()
    train_classifier(
        low_zero_syn_net,
        get_optimizer(low_zero_syn_net),
        syn_low_zero_data_loader,
        key="Low Zero Synth",
    )

In [None]:
# Evaluate the low-zero synthetic classifier on the MNIST test split.
if zer_syn:
    low_zero_syn_net.eval()
    device = torch.device("cuda")
    test(low_zero_syn_net, device, balanced_test)

In [None]:
# nb_classes = 10
# model=low_zero_net

# confusion_matrix = torch.zeros(nb_classes, nb_classes)
# with torch.no_grad():
#     for i, (inputs, classes) in enumerate(balanced_test):
#         inputs = inputs.to(device)
#         classes = classes.to(device)
#         outputs = model(inputs)
#         _, preds = torch.max(outputs, 1)
#         for t, p in zip(classes.view(-1), preds.view(-1)):
#                 confusion_matrix[t.long(), p.long()] += 1
# confusion_matrix = confusion_matrix.numpy()
# np.fill_diagonal(confusion_matrix, 0)
# vis.PlotHeatMap(confusion_matrix, "Low Zero Confusionsdf", False)
# Heat-map axes: y is the correct (true) class, x is the predicted class.

In [None]:
# print(confusion_matrix[5][9])