In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
from torchvision import transforms, datasets
from torch import nn, optim
from torch.autograd.variable import Variable

from lib.VisdomWrapper import *
from lib.GANs import *
from lib.DataCreationWrapper import *
from lib.DataManager import *

import numpy as np
import pandas as pd
import time
from datetime import datetime
from IPython import display

In [4]:
batch_size = 32
class_weights = np.ones(10)
VRAM = False

data_loader = get_unbalanced_mnist(class_weights, batch_size=batch_size)

if VRAM:
    X, Y = data_loader_to_tensor(data_loader)
    data_loader = DataLoader(TensorDataset(X, Y), batch_size=batch_size, shuffle=True)
    X, Y = None, None

n_samples = len(data_loader)
n_classes = 10
img_width = 28 #hardcoded
n_features = img_width**2
n_noise_features = 50 #don't change

# noise_maker = Noisifier(2)

In [5]:
# Visualizer
vis = VisdomController()

Setting up a new session...


In [6]:
loss_key = 0
add_noise = False

discr_nn = None
gen_nn = None
discr_optimizer = None
gen_optimizer = None
loss_function = None
real_target = None
fake_target = None
if(loss_key == 0):
    # Networks
#     discr_nn = DiscriminatorNetwork(n_features, n_classes, loss_key)
#     gen_nn = GeneratorNetwork(n_noise_features, n_features, n_classes)
    discr_nn = Conv_DiscriminatorNetwork(n_features, n_classes, loss_key)
    gen_nn = Conv_GeneratorNetwork(n_noise_features, n_features, n_classes)
    
    # Optimizers
#     discr_optimizer = optim.Adam(discr_nn.parameters(), lr=1e-3)
#     gen_optimizer = optim.Adam(gen_nn.parameters(), lr=1e-3)
    discr_optimizer = optim.Adam(discr_nn.parameters(), lr=1e-4, weight_decay=1e-5)
    gen_optimizer = optim.Adam(gen_nn.parameters(), lr=1e-4, weight_decay=1e-5)
    
    # Loss Function
    loss_function = nn.BCELoss()
    
    reg_constant = 10
    
    # Targets
    real_target = make_target(batch_size, 0.9)
    fake_target = make_target(batch_size, 0.1)
elif(loss_key == 1):
    # Networks
#     discr_nn = DiscriminatorNetwork(n_features, n_classes, loss_key)
#     gen_nn = GeneratorNetwork(n_noise_features, n_features, n_classes)
    discr_nn = Conv_DiscriminatorNetwork(n_features, n_classes, loss_key)
    gen_nn = Conv_GeneratorNetwork(n_noise_features, n_features, n_classes)
    
    # Optimizers    
#     discr_optimizer = optim.RMSprop(discr_nn.parameters(), lr=1e-5)
#     gen_optimizer = optim.RMSprop(gen_nn.parameters(), lr=1e-5)
    discr_optimizer = optim.Adam(discr_nn.parameters(), lr=1e-4, betas = (.9, .99), weight_decay=1e-5)
    gen_optimizer = optim.Adam(gen_nn.parameters(), lr=1e-4, betas = (.9, .99), weight_decay=1e-5)
    
    # Loss Function
    loss_function = wasserstein_loss
    
    reg_constant = 5
    
    # Targets
    real_target = make_target(batch_size, -0.99)
    fake_target = make_target(batch_size, 0.99)


In [7]:
if torch.cuda.is_available():
    discr_nn.cuda()
    gen_nn.cuda()

In [8]:
# Params
num_epochs = 40
noise_function = gaussian_noise
num_scatter_points = 60
class_to_mimic = 0

# Visdom Initialization
vis.ClearPlots()
avg_class_img = get_avg_img(data_loader, class_to_mimic, img_width)
vis.vis.image(avg_class_img)
formatted_data_sample = get_sample(data_loader, num_scatter_points, class_to_mimic, 100, 500)
vis.PlotRealFeatureDistributionComparison(295, 515, formatted_data_sample, num_scatter_points)
cov_real = pd.DataFrame(formatted_data_sample.numpy()).corr().fillna(1).to_numpy()
vis.PlotHeatMap(cov_real, "real_vs_fake_cov_map", True)

(784, 784)


In [9]:
vis.ClearPlots()
torch.manual_seed(2)
for epoch in range(num_epochs):
    for n_batch, (batch, labels) in enumerate(data_loader):
        # Move to CUDA
        if(add_noise):
#             print(batch.shape)
            batch = noise_maker.add_noise_random(batch.view(batch_size,-1), scale=.1).view(batch_size, 1, img_width, img_width)
#             batch = noise_maker.add_noise_directed(batch.view(batch_size,-1), labels, scale=.1).view(batch_size, 1, img_width, img_width)
        
        real_batch = Variable(batch)
        real_labels = Variable(labels)
        if torch.cuda.is_available():
            real_batch = real_batch.cuda()
            real_labels = real_labels.cuda()
            
        t_start = millis = time.time()
        batch_size = real_batch.size(0)
        
        discr_loss_real, discr_loss_fake, reg_loss = train_discriminator(discr_nn,
                                                               discr_optimizer,
                                                               loss_function,
                                                               gen_nn,
                                                               real_batch,
                                                               noise_function,
                                                               n_classes,
                                                               real_labels,
                                                               real_target,
                                                               fake_target,
                                                               loss_key,
                                                               reg_constant
                                                              )
        gen_loss = train_generator(gen_nn,
                                   gen_optimizer,
                                   loss_function,
                                   discr_nn, real_batch,
                                   noise_function,
                                   n_classes,
                                   real_target
                                  )
        
        if (n_batch % 100 == 0):
            display.clear_output(True)
            # Basic Data
            discr_loss = discr_loss_real + discr_loss_fake + reg_loss
            print("Epoch {}, {} / {}".format(epoch, n_batch, len(data_loader)))
            print("discr_loss : ", discr_loss)
            print("gen_loss : ", gen_loss) 
            print("critic_loss : ", abs(gen_loss - discr_loss))
            if(reg_loss != 0):
                print("reg_loss : ", reg_loss)
            
            # Visualization            
            vis.PlotLoss("Discr Loss Real", discr_loss_real.item())
            vis.PlotLoss("Discr Loss Fake", discr_loss_fake.item())
            vis.PlotLoss("Gen Loss", gen_loss.item())
            vis.loss_axis += 1
            
            t_end = millis = time.time()
            print("Time Elapsed : ", t_end - t_start)
    if epoch % 2 == 0:
        vis.PlotFakeFeatureDistributionComparison(295, 515, gen_nn, num_scatter_points, noise_function, class_to_mimic)

        cov_fake = pd.DataFrame(synthesize_data_from_label(gen_nn, num_scatter_points, noise_function, class_to_mimic).cpu().detach().numpy()).corr().fillna(1).to_numpy()
    
        vis.PlotHeatMap(np.abs(cov_real - cov_fake), "real_vs_fake_cov_map", True)
    
        vis.ShowImages(format_to_image(synthesize_data_from_each_label(gen_nn, noise_function, n_classes).cpu().detach(), n_classes, img_width), "Generated Data, Epoch " + str(epoch))
#         vis.ShowImages(get_visual_embeddings(discr_nn, n_classes).cpu().detach(), "Embeddings, Epoch " + str(epoch))

Epoch 0, 1800 / 1875
discr_loss :  tensor(1.0343, device='cuda:0', grad_fn=<AddBackward0>)
gen_loss :  tensor(1.2444, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
critic_loss :  tensor(0.2101, device='cuda:0', grad_fn=<AbsBackward>)
Time Elapsed :  0.9109973907470703


ValueError: Must pass 2-d input

In [None]:
current_time = datetime.now().strftime("%H-%M-%S")
weights_string = ",".join([str(c_w) for c_w in class_weights])
file_name = "models\gen_nn_" + str(loss_key) + "_" + str(current_time)
with open(file_name + ".txt", "w") as f:
    f.write(weights_string)
    
torch.save(gen_nn.state_dict(), file_name)

In [None]:
def show_img_seed(seed):
    torch.manual_seed(2)
    torch.manual_seed(seed)
    vis.ShowImages(format_to_image(synthesize_data_from_each_label(gen_nn, noise_function, n_classes).cpu().detach(), n_classes, img_width), "Generated Data, Seeded" + str(seed))


In [None]:
show_img_seed(22)