In [1]:
from IPython import display

import torch
import torch.nn as nn
from torch.nn import init
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision.datasets as dset

import numpy as np

import time
from visdom import Visdom

from lib.VisdomWrapper import *
from lib.DataManager import *
from lib.GANs import *
from lib.DataCreationWrapper import *

In [2]:
#Options
# One switch per experiment, named <dataset>_<variant>:
#   dataset: bal (balanced), zer (low zero), nin (high nine)
#   variant: raw (real data), syn (synthetic data), aug (augmented data)
bal_raw, bal_syn, bal_aug = False, False, False
zer_raw, zer_syn, zer_aug = False, False, False
nin_raw, nin_syn, nin_aug = True, True, True

In [3]:
# Hyper-parameters shared by every experiment in this notebook.
batch_size = 512
num_epochs = 30
img_width = 28 #hardcoded
n_features = img_width**2
n_noise_features = 100
n_classes = 10

# Shared classification loss and Visdom plotting front-end.
loss = nn.CrossEntropyLoss()
vis = VisdomController()

# MNIST training split (downloaded into ./input when missing).
mnist = dset.MNIST('input', train=True, download=True, transform=T.ToTensor())

# MNIST test split used to evaluate every classifier. The loader is batched:
# DataLoader's default batch_size of 1 made each evaluation pass run one forward
# call per test image.
mnist_test=dset.MNIST('input', train=False, download=True, transform=T.ToTensor())
balanced_test = DataLoader(mnist_test, batch_size=batch_size)

Setting up a new session...


In [4]:
def calc_balance(labels, minlength=0):
    """Return how many samples each class is missing to match the largest class.

    Args:
        labels: 1-D integer tensor of class labels.
        minlength: minimum number of classes to report (forwarded to
            ``torch.bincount``). Pass the number of classes to get an entry for
            classes that never occur in ``labels``; the default of 0 keeps the
            previous behaviour (length is ``labels.max() + 1``).

    Returns:
        Integer numpy array where entry ``c`` is ``max_count - count[c]``.
        An empty array is returned when there are no classes to report.
    """
    cnts = torch.bincount(labels, minlength=minlength)
    print(cnts)
    if cnts.numel() == 0:
        # torch.max would raise on an empty tensor; there is nothing to balance.
        return np.zeros(0, dtype='int')
    largest = torch.max(cnts)  # named `largest` so the builtin max() is not shadowed
    print(largest)
    # Pure integer arithmetic; no intermediate float tensor is needed.
    ret = (largest - cnts).cpu().numpy().astype('int')
    print(ret)
    return ret

In [5]:
def count_class_frequencies(data_loader, n_classes=10):
    """Count how often each label occurs across all batches of ``data_loader``.

    Args:
        data_loader: iterable of ``(inputs, labels)`` pairs; ``labels`` must be a
            1-D integer tensor.
        n_classes: number of bins in the result.

    Returns:
        Float tensor of length ``n_classes`` holding the per-class counts.
    """
    cnts = torch.zeros(n_classes)
    for _, y in data_loader:
        # minlength keeps every per-batch histogram at n_classes bins; without it a
        # batch that lacks the highest label yields a shorter tensor and the
        # in-place add fails with a shape mismatch.
        cnts += torch.bincount(y.cpu(), minlength=n_classes)
    return cnts


def plot_frequency(data_loader, title, n_classes=10):
    """Plot the class-frequency bar chart of ``data_loader`` through ``vis``.

    Args:
        data_loader: iterable of ``(inputs, labels)`` batches.
        title: title passed to the bar plot.
        n_classes: number of classes on the x axis (defaults to 10).
    """
    x = np.arange(n_classes)
    cnts = count_class_frequencies(data_loader, n_classes)
    vis.CreateStaticBarPlot(cnts, x, title, "Class", "Count", "FreqPlot")

In [6]:
#Load Data
# Generator checkpoints. Forward slashes are used because the former "models\gen_..."
# literals depended on "\g" being an unrecognised escape sequence (a SyntaxWarning on
# recent Python versions) and only resolved as paths on Windows.
bal_nn_file = "models/gen_nn_1_03-29-06"
zer_nn_file = "models/gen_nn_low_zero_BEST"
nin_nn_file = "models/gen_nn_high_nine_OLD"

# NOTE(review): a recorded run of this cell failed in load_state_dict for
# OLD_GeneratorNetwork -- the checkpoint's keys and weight shapes did not match that
# class. Confirm which generator class each checkpoint was saved from.


def _cache_in_memory(source_loader):
    """Collapse ``source_loader`` into tensors and wrap them in a shuffled DataLoader."""
    images, labels = data_loader_to_tensor(source_loader)
    return DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=True)


def _load_generator(generator_cls, weights_path):
    """Build ``generator_cls``, load its weights, move it to CUDA and set eval mode."""
    generator = generator_cls(n_noise_features, n_features, n_classes)
    generator.load_state_dict(torch.load(weights_path))
    generator = generator.cuda()
    generator.eval()
    return generator


def _synthetic_loader(generator, title, samples_per_class=6000):
    """Synthesize ``samples_per_class`` images per class, plot their frequencies, return a loader."""
    per_class = np.full(n_classes, samples_per_class, dtype='int')
    images, labels = synthesize_data_of_each_label(generator, gaussian_noise, per_class)
    images = images.view(len(labels), 1, img_width, img_width)
    loader = DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=True)
    plot_frequency(loader, title)
    return loader


def _augmented_loader(generator, train_loader, tag, title):
    """Top up every class of ``train_loader`` with synthetic samples to the largest class count."""
    print("Augmenting " + tag)
    real_images, real_labels = train_loader.dataset.tensors
    images, labels = synthesize_data_of_each_label(generator, gaussian_noise, calc_balance(real_labels))
    # Synthetic samples come first, followed by the real ones, all on the GPU.
    images = torch.cat((images.view(len(labels), 1, img_width, img_width).cuda(), real_images.cuda()), dim=0)
    labels = torch.cat((labels.cuda(), real_labels.cuda()))
    loader = DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=True)
    plot_frequency(loader, title)
    return loader


with torch.no_grad():
    # --- real data -------------------------------------------------------------
    if bal_raw or bal_aug:
        bal_train = _cache_in_memory(DataLoader(mnist, batch_size=1000))
        if bal_raw:
            plot_frequency(bal_train, "Balanced Raw Class Frequencies")

    if zer_raw or zer_aug:
        zer_train = _cache_in_memory(
            get_unbalanced_mnist([.01, .1, .1, .1, .1, .1, .1, .1, .1, .1], batch_size=1000))
        if zer_raw:
            plot_frequency(zer_train, "Low Zero Raw Class Frequencies")

    if nin_raw or nin_aug:
        nin_train = _cache_in_memory(
            get_unbalanced_mnist([.1, .1, .1, .1, .1, .1, .1, .1, .1, 1], batch_size=1000))
        if nin_raw:
            plot_frequency(nin_train, "High Nine Raw Class Frequencies")

    # --- synthetic / augmented data --------------------------------------------
    if bal_syn or bal_aug:
        bal_gen = _load_generator(Conv_GeneratorNetwork, bal_nn_file)
        if bal_syn:
            syn_bal_data_loader = _synthetic_loader(bal_gen, "Balanced Synthetic Class Frequencies")
        if bal_aug:
            aug_bal_data_loader = _augmented_loader(
                bal_gen, bal_train, "bal", "Balanced Augmented Class Frequencies")

    if zer_syn or zer_aug:
        zer_gen = _load_generator(OLD_GeneratorNetwork, zer_nn_file)
        if zer_syn:
            syn_low_zero_data_loader = _synthetic_loader(zer_gen, "Low Zero Synthetic Class Frequencies")
        if zer_aug:
            aug_zer_data_loader = _augmented_loader(
                zer_gen, zer_train, "zer", "Low Zero Augmented Class Frequencies")

    if nin_syn or nin_aug:
        nin_gen = _load_generator(OLD_GeneratorNetwork, nin_nn_file)
        if nin_syn:
            syn_nin_data_loader = _synthetic_loader(nin_gen, "High Nine Synthetic Class Frequencies")
        if nin_aug:
            aug_nin_data_loader = _augmented_loader(
                nin_gen, nin_train, "nin", "High Nine Augmented Class Frequencies")
            

RuntimeError: Error(s) in loading state_dict for OLD_GeneratorNetwork:
	Missing key(s) in state_dict: "label_embedding.weight", "hidden0.1.weight", "hidden0.1.bias", "hidden0.1.running_mean", "hidden0.1.running_var", "hidden1.1.weight", "hidden1.1.bias", "hidden1.1.running_mean", "hidden1.1.running_var", "hidden2.0.weight", "hidden2.0.bias", "hidden2.1.weight", "hidden2.1.bias", "hidden2.1.running_mean", "hidden2.1.running_var", "out.0.weight", "out.0.bias". 
	Unexpected key(s) in state_dict: "process_noise.0.weight", "process_noise.0.bias", "label_embedding.0.weight", "label_embedding.1.weight", "label_embedding.1.bias", "hidden0.2.weight", "hidden0.2.bias", "hidden0.2.running_mean", "hidden0.2.running_var", "hidden0.2.num_batches_tracked". 
	size mismatch for hidden0.0.weight: copying a param with shape torch.Size([129, 64, 4, 4]) from checkpoint, the shape in current model is torch.Size([256, 110]).
	size mismatch for hidden0.0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for hidden1.0.weight: copying a param with shape torch.Size([64, 1, 4, 4]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for hidden1.0.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([512]).

In [None]:
def build_classifier():
    """Create the digit classifier: two conv/pool stages and a two-layer dense head.

    Returns:
        A flat ``nn.Sequential`` mapping a 1-channel image batch to 10 class scores.
    """
    leak = 0.01
    flat_features = 4 * 4 * 64  # 64 channels of 4x4, which is what 28x28 inputs reduce to

    def conv_stage(in_channels, out_channels):
        # 5x5 convolution -> leaky ReLU -> 2x2 max-pool
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1),
            nn.LeakyReLU(leak),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

    layers = conv_stage(1, 32)
    layers += conv_stage(32, 64)
    layers += [
        nn.Flatten(),
        nn.Linear(flat_features, flat_features, bias=True),
        nn.LeakyReLU(leak),
        nn.Linear(flat_features, 10, bias=True),
    ]
    return nn.Sequential(*layers)

In [None]:
def get_optimizer(model):
    """Return an Adam optimizer (learning rate 1e-5) over all parameters of ``model``."""
    return optim.Adam(model.parameters(), lr=1e-5)

In [None]:
def train_classifier(classifier, optimizer, data, key="Loss"):
    """Train ``classifier`` on ``data`` for ``num_epochs`` epochs.

    Uses the module-level ``loss``, ``batch_size``, ``num_epochs`` and ``vis``.
    Batches are moved to CUDA before the forward pass. After each epoch the loss
    of the last processed batch is printed and plotted under ``key``.

    Args:
        classifier: model to optimise.
        optimizer: optimizer already bound to ``classifier``'s parameters.
        data: iterable of ``(x, y)`` batches that supports ``len()``.
        key: name passed to ``vis.PlotLoss`` for the loss curve.

    Raises:
        ValueError: if an epoch contains no batch of exactly ``batch_size``
            samples (previously this surfaced as an UnboundLocalError).
    """
    for epoch in range(num_epochs):
        last_loss = None
        n_batch = -1
        for n_batch, (x, y) in enumerate(data):
            # Only full batches are used; partial batches are skipped.
            if len(x) != batch_size:
                continue
            optimizer.zero_grad()
            x = x.cuda()
            y = y.cuda()
            scores = classifier(x)
            out = loss(scores, y)
            out.backward()
            optimizer.step()
            last_loss = out
        if last_loss is None:
            raise ValueError(
                "train_classifier: no batch of size {} found in data".format(batch_size))
        loss_value = last_loss.item()  # read once, reused for print and plot
        display.clear_output(True)
        print("Epoch {}, {} / {}".format(epoch, n_batch, len(data)))
        print("Loss: ", loss_value)
        vis.loss_axis = epoch
        vis.PlotLoss(key, loss_value)
            

In [None]:
def test(model, device, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()


    print('Accuracy: {}/{} ({:.0f}%)\n'.format(correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
def confusion_counts(model, data_loader, nb_classes=10):
    """Count (true class, predicted class) pairs of ``model`` over ``data_loader``.

    Args:
        model: classifier returning one score per class for each sample.
        data_loader: iterable of ``(inputs, classes)`` batches.
        nb_classes: number of classes (rows/columns of the result).

    Returns:
        Float tensor of shape ``(nb_classes, nb_classes)``; rows are the correct
        class, columns the predicted class.
    """
    # Run on whatever device the model lives on instead of relying on a global
    # `device` that is only assigned in some later cells.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    n_cells = nb_classes * nb_classes
    counts = torch.zeros(n_cells, dtype=torch.long)
    with torch.no_grad():
        for inputs, classes in data_loader:
            outputs = model(inputs.to(run_device))
            _, preds = torch.max(outputs, 1)
            # Encode each (true, predicted) pair as one flat index and histogram it.
            flat = classes.to(run_device).view(-1).long() * nb_classes + preds.view(-1).long()
            counts += torch.bincount(flat.cpu(), minlength=n_cells)
    return counts.view(nb_classes, nb_classes).float()


def plot_confusion_matrix(model, key):
    """Plot the misclassification heat map of ``model`` on ``balanced_test``.

    The diagonal (correct predictions) is zeroed so only errors are shown.
    y is correct, x is predicted.

    Args:
        model: classifier to evaluate.
        key: name passed to ``vis.PlotHeatMap``.
    """
    confusion_matrix = confusion_counts(model, balanced_test).numpy()
    np.fill_diagonal(confusion_matrix, 0)
    vis.PlotHeatMap(confusion_matrix, key, False)

In [None]:
if bal_raw:
    # Train a fresh classifier on the real balanced training loader.
    balanced_net = build_classifier().cuda().train()
    train_classifier(
        balanced_net,
        get_optimizer(balanced_net),
        bal_train,
        key="Balanced Raw",
    )

In [None]:
if bal_aug:
    # Train a fresh classifier on the balanced data topped up with synthetic samples.
    balanced_aug_net = build_classifier().cuda().train()
    train_classifier(
        balanced_aug_net,
        get_optimizer(balanced_aug_net),
        aug_bal_data_loader,
        key="Balanced Aug",
    )

In [None]:
if bal_syn:
    # Train a fresh classifier on purely synthetic data from the balanced generator.
    balanced_syn_net = build_classifier().cuda().train()
    train_classifier(
        balanced_syn_net,
        get_optimizer(balanced_syn_net),
        syn_bal_data_loader,
        key="Balanced Synth",
    )

In [None]:
if bal_raw:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(balanced_net.eval(), device, balanced_test)
    plot_confusion_matrix(balanced_net, "Balanced Raw")

In [None]:
if bal_aug:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(balanced_aug_net.eval(), device, balanced_test)
    plot_confusion_matrix(balanced_aug_net, "Balanced Augmented")

In [None]:
if bal_syn:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(balanced_syn_net.eval(), device, balanced_test)
    plot_confusion_matrix(balanced_syn_net, "Balanced Synthetic")

In [None]:
if zer_raw:
    # Train a fresh classifier on the real low-zero training loader.
    low_zero_net = build_classifier().cuda().train()
    train_classifier(
        low_zero_net,
        get_optimizer(low_zero_net),
        zer_train,
        key="Low Zero Raw",
    )

In [None]:
if zer_syn:
    # Train a fresh classifier on purely synthetic data from the low-zero generator.
    low_zero_syn_net = build_classifier().cuda().train()
    train_classifier(
        low_zero_syn_net,
        get_optimizer(low_zero_syn_net),
        syn_low_zero_data_loader,
        key="Low Zero Synth",
    )

In [None]:
if zer_aug:
    # Train a fresh classifier on the low-zero data topped up with synthetic samples.
    low_zero_aug_net = build_classifier().cuda().train()
    train_classifier(
        low_zero_aug_net,
        get_optimizer(low_zero_aug_net),
        aug_zer_data_loader,
        key="Low Zero Aug",
    )

In [None]:
if zer_raw:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(low_zero_net.eval(), device, balanced_test)
    plot_confusion_matrix(low_zero_net, "Low Zero Raw")

In [None]:
if zer_syn:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(low_zero_syn_net.eval(), device, balanced_test)
    plot_confusion_matrix(low_zero_syn_net, "Low Zero Synthetic")

In [None]:
if zer_aug:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(low_zero_aug_net.eval(), device, balanced_test)
    plot_confusion_matrix(low_zero_aug_net, "Low Zero Augmented")

In [None]:
if nin_raw:
    # Train a fresh classifier on the real high-nine training loader.
    high_nine_net = build_classifier().cuda().train()
    train_classifier(
        high_nine_net,
        get_optimizer(high_nine_net),
        nin_train,
        key="High Nine Raw",
    )

In [None]:
if nin_syn:
    # Train a fresh classifier on purely synthetic data from the high-nine generator.
    nin_syn_net = build_classifier().cuda().train()
    train_classifier(
        nin_syn_net,
        get_optimizer(nin_syn_net),
        syn_nin_data_loader,
        key="High Nine Synth",
    )

In [None]:
if nin_aug:
    # Train a fresh classifier on the high-nine data topped up with synthetic samples.
    nin_aug_net = build_classifier().cuda().train()
    train_classifier(
        nin_aug_net,
        get_optimizer(nin_aug_net),
        aug_nin_data_loader,
        key="High Nine Aug",
    )

In [None]:
if nin_raw:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(high_nine_net.eval(), device, balanced_test)
    plot_confusion_matrix(high_nine_net, "High Nine Raw")

In [None]:
if nin_aug:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(nin_aug_net.eval(), device, balanced_test)
    plot_confusion_matrix(nin_aug_net, "High Nine Augmented")

In [None]:
if nin_syn:
    # Evaluate on the MNIST test loader: print accuracy, then plot the confusion heat map.
    device = torch.device("cuda")
    test(nin_syn_net.eval(), device, balanced_test)
    plot_confusion_matrix(nin_syn_net, "High Nine Synthetic")

In [None]:
# vis.ShowImages(format_to_image(synthesize_data_from_each_label(zer_gen, gaussian_noise, n_classes).cpu().detach(), n_classes, img_width), "Fortnite")