In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms, datasets
from torch import nn, optim
from torch.autograd.variable import Variable

from lib.VisdomWrapper import *
from lib.GANs import *
from lib.DataCreationWrapper import *
from lib.DataManager import *

import numpy as np
import pandas as pd
import time
from IPython import display

In [4]:
batch_size = 16
data_loader = get_unbalanced_mnist(np.ones(10), batch_size=batch_size)

n_samples = len(data_loader)
n_classes = len(data_loader.dataset.classes)
img_width = 28 #hardcoded
n_features = img_width**2
n_noise_features = 100

In [5]:
# Visualizer
vis = VisdomController()

Setting up a new session...


In [6]:
loss_key = 0

discr_nn = None
gen_nn = None
discr_optimizer = None
gen_optimizer = None
loss_function = None
real_target = None
fake_target = None
if(loss_key == 0):
    # Networks
#     discr_nn = DiscriminatorNetwork(n_features, n_classes)
#     gen_nn = GeneratorNetwork(n_noise_features, n_features, n_classes)
    discr_nn = Conv_DiscriminatorNetwork(n_features, n_classes)
    gen_nn = Conv_GeneratorNetwork(n_noise_features, n_features, n_classes)
    
    # Optimizers
#     discr_optimizer = optim.Adam(discr_nn.parameters(), lr=1e-3)
#     gen_optimizer = optim.Adam(gen_nn.parameters(), lr=1e-3)
    discr_optimizer = optim.Adam(discr_nn.parameters(), lr=1e-4)
    gen_optimizer = optim.Adam(gen_nn.parameters(), lr=1e-4)
    
    # Loss Function
    loss_function = nn.BCELoss()
    
    # Targets
    real_target = make_target(batch_size, 0.9)
    fake_target = make_target(batch_size, 0.1)
elif(loss_key == 1):
    # Networks
#     discr_nn = DiscriminatorNetwork(n_features, n_classes)
#     gen_nn = GeneratorNetwork(n_noise_features, n_features, n_classes)
    discr_nn = Conv_DiscriminatorNetwork(n_features, n_classes)
    gen_nn = Conv_GeneratorNetwork(n_noise_features, n_features, n_classes)
    
    # Optimizers    
    discr_optimizer = optim.RMSprop(discr_nn.parameters(), lr=1e-5)
    gen_optimizer = optim.RMSprop(gen_nn.parameters(), lr=1e-5)
    
    # Loss Function
    loss_function = wasserstein_loss
    
    # Targets
    real_target = make_target(batch_size, -0.9)
    fake_target = make_target(batch_size, 0.9)


In [7]:
if torch.cuda.is_available():
    discr_nn.cuda()
    gen_nn.cuda()

In [8]:
# Params
num_epochs = 40
noise_function = gaussian_noise
num_scatter_points = 60
class_to_mimic = 1

# Visdom Initialization
vis.ClearPlots()
avg_class_img = get_avg_img(data_loader, class_to_mimic, img_width)
vis.vis.image(avg_class_img)
formatted_data_sample = get_sample(data_loader, num_scatter_points, class_to_mimic, 100, 500)
vis.PlotRealFeatureDistributionComparison(295, 515, formatted_data_sample, num_scatter_points)
cov_real = pd.DataFrame(formatted_data_sample.numpy()).corr().fillna(1).to_numpy()
# vis.PlotHeatMap(cov_real, "real_vs_fake_cov_map", True)

In [9]:
vis.ClearPlots()
for epoch in range(num_epochs):
    for n_batch, (batch, labels) in enumerate(data_loader):
        
        # Move to CUDA      
        real_batch = Variable(batch)
        real_labels = Variable(labels)
        if torch.cuda.is_available():
            real_batch = real_batch.cuda()
            real_labels = real_labels.cuda()
            
        t_start = millis = time.time()
        batch_size = real_batch.size(0)
        
        discr_loss_real, discr_loss_fake = train_discriminator(discr_nn,
                                                               discr_optimizer,
                                                               loss_function,
                                                               gen_nn,
                                                               real_batch,
                                                               noise_function,
                                                               n_classes,
                                                               real_labels,
                                                               real_target,
                                                               fake_target
                                                              )
        gen_loss = train_generator(gen_nn,
                                   gen_optimizer,
                                   loss_function,
                                   discr_nn, real_batch,
                                   noise_function,
                                   n_classes,
                                   real_target
                                  )
        
        if (n_batch % 100 == 0):
            display.clear_output(True)
            # Basic Data
            discr_loss = discr_loss_real + discr_loss_fake
            print("Epoch {}, {} / {}".format(epoch, n_batch, len(data_loader)))
            print("discr_loss : ", discr_loss)
            print("gen_loss : ", gen_loss) 
            print("critic_loss : ", abs(gen_loss - discr_loss))
            
            # Visualization            
            vis.PlotLoss("Discr Loss Real", discr_loss_real.item())
            vis.PlotLoss("Discr Loss Fake", discr_loss_fake.item())
            vis.PlotLoss("Gen Loss", gen_loss.item())
            vis.loss_axis += 1
            
            t_end = millis = time.time()
            print("Time Elapsed : ", t_end - t_start)
    if epoch % 2 == 0:
#     vis.PlotFakeFeatureDistributionComparison(295, 515, gen_nn, num_scatter_points, noise_function, class_to_mimic)

#     cov_fake = pd.DataFrame(synthesize_data_from_label(gen_nn, num_scatter_points, noise_function, class_to_mimic).cpu().detach().numpy()).corr().fillna(1).to_numpy()
    
#     vis.PlotHeatMap(np.abs(cov_real - cov_fake), "real_vs_fake_cov_map", True)
    
        vis.ShowImages(format_to_image(synthesize_data_from_each_label(gen_nn, noise_function, n_classes).cpu().detach(), n_classes, img_width), "Generated Data, Epoch " + str(epoch))

# torch.save(gen_nn.state_dict(),"models\gen_nn" + str(data_to_mimic))

Epoch 32, 900 / 3750
discr_loss :  tensor(0.7415, device='cuda:0', grad_fn=<AddBackward0>)
gen_loss :  tensor(2.2877, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
critic_loss :  tensor(1.5462, device='cuda:0', grad_fn=<AbsBackward>)


KeyboardInterrupt: 