In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms, datasets
from torch import nn, optim
from torch.autograd.variable import Variable

from lib.VisdomWrapper import *
from lib.GANs import *
from lib.DataCreationWrapper import *
from lib.DataManager import *

import numpy as np
import pandas as pd
import time
from IPython import display

In [4]:
batch_size = 100
data_loader = get_unbalanced_emnist(np.ones(10), batch_size=batch_size)
n_samples = len(data_loader)
img_width = 28
n_features = img_width**2
n_noise_features = 100

In [5]:
discr_nn = DiscriminatorNetwork(n_features)
gen_nn = GeneratorNetwork(n_noise_features, n_features)
if torch.cuda.is_available():
    discr_nn.cuda()
    gen_nn.cuda()

In [6]:
# Optimizers
discr_optimizer = optim.Adam(discr_nn.parameters(), lr=1e-3, betas=(0.5, 0.999))
gen_optimizer = optim.Adam(gen_nn.parameters(), lr=1e-3, betas=(0.5, 0.999))

In [7]:
# Visualizer
vis = VisdomController()

Setting up a new session...


In [8]:
# Params
num_epochs = 10
noise_function = gaussian_noise
num_scatter_points = 60
num_imgs_show = 5
loss_function = nn.BCELoss() #WGan has it built into trainer, fix later?
class_to_mimic = 1

# Visdom Initialization
vis.ClearPlots()
avg_class_img = get_avg_img(data_loader, class_to_mimic, img_width)
vis.vis.image(avg_class_img)
formatted_data_sample = get_sample(data_loader, num_scatter_points, class_to_mimic, 100, 500)
vis.PlotRealFeatureDistributionComparison(295, 515, formatted_data_sample, num_scatter_points)
cov_real = pd.DataFrame(formatted_data_sample.numpy()).corr().fillna(1)
# vis.PlotHeatMap(cov_real, "real_vs_fake_cov_map", True)

In [12]:
vis.ClearPlots()
for epoch in range(num_epochs):
    for n_batch, (batch, labels) in enumerate(data_loader):
        
        # Move to CUDA      
        real_batch = Variable(batch[labels == class_to_mimic])
        if torch.cuda.is_available():
            real_batch = real_batch.cuda()
            
        t_start = millis = time.time()
        batch_size = real_batch.size(0)
        
        discr_loss_real, discr_loss_fake = train_discriminator_wass(discr_nn, discr_optimizer, loss_function, gen_nn, real_batch, noise_function)
        gen_loss = train_generator_wass(gen_nn, gen_optimizer, loss_function, discr_nn, real_batch, noise_function)
        
        if (n_batch % 100 == 0):
            display.clear_output(True)
            # Basic Data
            discr_loss = discr_loss_real + discr_loss_fake
            print("Epoch {}, {} / {}".format(epoch, n_batch, len(data_loader)))
            print("discr_loss : ", discr_loss)
            print("gen_loss : ", gen_loss) 
            print("critic_loss : ", abs(gen_loss - discr_loss))
            
            # Visualization            
            vis.PlotLoss("Discr Loss Real", discr_loss_real.item())
            vis.PlotLoss("Discr Loss Fake", discr_loss_fake.item())
            vis.PlotLoss("Gen Loss", gen_loss.item())
            vis.loss_axis += 1
            
            t_end = millis = time.time()
            print("Time Elapsed : ", t_end - t_start)     
    vis.PlotFakeFeatureDistributionComparison(295, 515, gen_nn, num_scatter_points, noise_function)
    cov_fake = pd.DataFrame(synthesize_data(gen_nn, num_scatter_points, noise_function).cpu().detach().numpy()).corr().fillna(1)
    vis.PlotHeatMap(np.abs(cov_real - cov_fake), "real_vs_fake_cov_map", True)
    vis.ShowImages(format_to_image(synthesize_data(gen_nn, num_imgs_show, noise_function).cpu().detach(), num_imgs_show, img_width), "Epoch " + str(epoch))

# torch.save(gen_nn.state_dict(),"models\gen_nn" + str(data_to_mimic))

Epoch 9, 500 / 600
discr_loss :  tensor(5.5603e-07, device='cuda:0', grad_fn=<AddBackward0>)
gen_loss :  tensor(-3.9842e-06, device='cuda:0', grad_fn=<NegBackward>)
critic_loss :  tensor(4.5402e-06, device='cuda:0', grad_fn=<AbsBackward>)
Time Elapsed :  0.3390007019042969
