## 1D GAN distribution sampling

This notebook contains a PyTorch implementation of a classic one-dimensional GAN example. It's a relatively simple problem in which the generator tries to mimic a Gaussian distribution. The generator is fed by sampling a noise distribution while the discriminator tries to distinguish between samples drawn from the Gaussian distribution and samples constructed by the generator using noise samples as input.

Several implementations have been proposed online; the goal here is to illustrate the impact of different models and the evolution of Wasserstein GAN (WGAN) training techniques.
More detailed descriptions of the problem can be found in the following blogs:

http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/

https://blog.evjang.com/2016/06/generative-adversarial-nets-in.html

https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f

Relevant publications:

GAN:
https://arxiv.org/abs/1406.2661

Wasserstein GAN:
https://arxiv.org/abs/1701.07875

Wasserstein GAN with improved training:
https://arxiv.org/abs/1704.00028

In [None]:
import numpy as np

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, ConcatDataset, DataLoader

from bokeh.models import ColumnDataSource
from bokeh.plotting import output_notebook, figure, show
from bokeh.layouts import row
from bokeh.io import push_notebook

In [None]:
# Embed bokeh figures in the notebook output so that show() and
# push_notebook() (used by the plotting/training helpers below) render inline.
output_notebook()

### Define the sampling

In [None]:
def sample_real(n, sigma=1.0, mean=-1.0):
    """Return ``n`` draws from N(mean, sigma**2), sorted in ascending order."""
    draws = np.random.normal(loc=mean, scale=sigma, size=n)
    return np.sort(draws)

In [None]:
def sample_noise(n, bound=5.0):
    """Draw ``n`` stratified noise samples spanning roughly ``[-bound, bound]``.

    Evenly spaced grid points get a small positive jitter (< 0.01). As long
    as the grid spacing exceeds 0.01 the samples stay in increasing order,
    which lets the generator learn an order-preserving mapping (see the
    aylien blog post for details).
    """
    grid = np.linspace(-bound, bound, n)
    jitter = 0.01 * np.random.random(n)
    return grid + jitter

In [None]:
def sample_epsilon(n):
    """Draw ``n`` interpolation weights uniformly from [0, 1).

    Used by the Wasserstein GAN with improved training (gradient penalty)
    to mix real and generated samples.
    """
    return np.random.uniform(low=0.0, high=1.0, size=n)

### Wrap sampling in PyTorch Datasets

This lets us use PyTorch's data-loading utilities when training the networks.

In [None]:
class RealData(Dataset):
    """Dataset view of one batch drawn from the target Gaussian.

    ``sample()`` must be called before indexing: it (re)draws ``n`` sorted
    samples from N(mean, sigma**2) and stores them in ``self.x`` as an
    ``(n, 1)`` float tensor.
    """

    def __init__(self, n, sigma, mean):
        self.data_size, self.sigma, self.mean = n, sigma, mean

    def sample(self):
        draws = sample_real(self.data_size, self.sigma, self.mean)
        self.x = torch.FloatTensor(draws).unsqueeze(-1)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        return self.x[idx]

In [None]:
class FakeData(Dataset):
    """Dataset view of one batch of stratified noise fed to the generator.

    ``sample()`` must be called before indexing: it (re)draws ``n`` noise
    values spanning roughly ``[-bound, bound]`` and stores them in
    ``self.x`` as an ``(n, 1)`` float tensor.
    """

    def __init__(self, n, bound):
        self.data_size, self.bound = n, bound

    def sample(self):
        noise = sample_noise(self.data_size, self.bound)
        self.x = torch.FloatTensor(noise).unsqueeze(-1)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        return self.x[idx]

In [None]:
class Sampler(object):
    """Bundle the real-data and noise datasets with their data loaders.

    Both datasets hold exactly ``batch_size`` samples, so each loader yields
    the whole dataset as one ``(batch_size, 1)`` minibatch. Call
    ``real_data.sample()`` / ``fake_data.sample()`` to draw a new batch
    before iterating the loaders.

    Parameters
    ----------
    real_sigma, real_mean : float
        Standard deviation and mean of the target Gaussian.
    noise_range : float
        The noise samples span roughly ``[-noise_range, noise_range]``.
    batch_size : int
        Number of samples per batch.
    num_workers : int, optional
        Worker processes for the DataLoaders. Defaults to 0 (load in the
        main process). The batches are tiny in-memory tensors, so starting
        worker processes on every pass over a loader costs far more than it
        saves. Also, on platforms that start workers with ``spawn``, classes
        defined interactively in a notebook may not be importable by them.
    """

    def __init__(self, real_sigma, real_mean,
                 noise_range, batch_size, num_workers=0):
        self.batch_size = batch_size
        # initialize data providers
        self.real_data = RealData(batch_size, real_sigma, real_mean)
        self.fake_data = FakeData(batch_size, noise_range)
        self.real_inputs = DataLoader(
            self.real_data, batch_size=batch_size,
            shuffle=False, num_workers=num_workers)
        self.fake_inputs = DataLoader(
            self.fake_data, batch_size=batch_size,
            shuffle=False, num_workers=num_workers)

    def sample_epsilon(self):
        """Return ``batch_size`` weights from U[0, 1) as a ``(batch_size, 1)`` tensor."""
        # `sample_epsilon` here is the module-level function, not this method.
        return torch.FloatTensor(sample_epsilon(
            self.batch_size)).unsqueeze(-1)

### Define the networks

First, we define base classes that collect the functionality common to all models.

In [None]:
class ModelBase(nn.Module):
    """Common base for the fully connected generators and discriminators.

    Subclasses define ``self.fc_layers`` (an ``nn.ModuleList`` of
    ``nn.Linear``). If they use ``_forward_first``, they also define
    ``self.activations``.
    """

    def __init__(self):
        super(ModelBase, self).__init__()

    def _init_layers(self):
        """Draw every weight from N(0, 1) and set every bias to zero."""
        for layer in self.fc_layers:
            # In-place initialisers (trailing underscore); the non-underscore
            # variants operating on `.data` are deprecated.
            nn.init.normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.0)

    def _forward_first(self, x):
        """Apply each linear layer followed by its paired activation.

        ``zip`` stops at the shorter list, so layers without a matching
        activation (typically the output layer) are left for ``forward``.
        """
        for fc, activ in zip(self.fc_layers, self.activations):
            x = activ(fc(x))
        return x

In [None]:
class Discriminator(ModelBase):
    """Base class for vanilla-GAN discriminators and WGAN critics.

    Subclasses provide ``fc_layers`` and ``activations``. The vanilla
    discriminator passes the last layer's output through a sigmoid. The
    Wasserstein critic (``wasserstein=True``) returns the raw, unbounded
    score.
    """

    def __init__(self, wasserstein=False):
        super(Discriminator, self).__init__()

        self.wasserstein = wasserstein

        if not wasserstein:
            self.last_activation = nn.Sigmoid()

    def is_wasserstein(self):
        """Return True if this model is a WGAN critic (no final sigmoid)."""
        return self.wasserstein

    def forward(self, x):
        # A regular `forward` method, instead of binding a method to
        # `self.forward` in __init__, keeps the standard nn.Module structure
        # that tooling and subclasses expect.
        if self.wasserstein:
            return self._forward_wasserstein(x)
        return self._forward_vanilla(x)

    def _forward_vanilla(self, x):
        return self.last_activation(
            self.fc_layers[-1](self._forward_first(x)))

    def _forward_wasserstein(self, x):
        return self.fc_layers[-1](self._forward_first(x))

Networks as proposed in http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/

In [None]:
class Generator1(ModelBase):
    """Generator from the aylien blog: Linear -> Softplus -> Linear."""

    def __init__(self, input_dim, hidden_dim):
        super(Generator1, self).__init__()

        hidden_layer = nn.Linear(input_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, 1)
        self.fc_layers = nn.ModuleList([hidden_layer, output_layer])
        self.activations = nn.ModuleList([nn.Softplus()])

        self._init_layers()

    def forward(self, x):
        hidden = self._forward_first(x)
        return self.fc_layers[-1](hidden)

In [None]:
class Discriminator1(Discriminator):
    """Discriminator from the aylien blog: four linear layers with ReLU in
    between, followed by the base class's sigmoid unless ``wasserstein``."""

    def __init__(self, input_dim, hidden_dim, wasserstein=False):
        super(Discriminator1, self).__init__(wasserstein)

        layers = [nn.Linear(input_dim, hidden_dim)]
        layers.extend(nn.Linear(hidden_dim, hidden_dim) for _ in range(2))
        layers.append(nn.Linear(hidden_dim, 1))
        self.fc_layers = nn.ModuleList(layers)

        self.activations = nn.ModuleList([nn.ReLU(), nn.ReLU(), nn.ReLU()])

        self._init_layers()

Networks as proposed in 
https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f

In [None]:
class Generator2(ModelBase):
    """Generator from the devnag post:
    Linear -> ELU -> Linear -> Sigmoid -> Linear."""

    def __init__(self, input_dim, hidden_dim):
        super(Generator2, self).__init__()

        shapes = [(input_dim, hidden_dim),
                  (hidden_dim, hidden_dim),
                  (hidden_dim, 1)]
        self.fc_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in shapes])
        self.activations = nn.ModuleList([nn.ELU(), nn.Sigmoid()])

        self._init_layers()

    def forward(self, x):
        features = self._forward_first(x)
        return self.fc_layers[-1](features)

In [None]:
class Discriminator2(Discriminator):
    """Discriminator from the devnag post: Linear -> ELU -> Linear -> ELU ->
    Linear, followed by the base class's sigmoid unless ``wasserstein``."""

    def __init__(self, input_dim, hidden_dim, wasserstein=False):
        super(Discriminator2, self).__init__(wasserstein)

        shapes = ((input_dim, hidden_dim),
                  (hidden_dim, hidden_dim),
                  (hidden_dim, 1))
        self.fc_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in shapes])
        self.activations = nn.ModuleList([nn.ELU(), nn.ELU()])

        self._init_layers()

Networks as proposed in https://github.com/kremerj/gan

In [None]:
class Generator3(ModelBase):
    """Generator from kremerj/gan: Linear -> ReLU -> Linear.

    Parameters
    ----------
    input_dim : int
        Number of input features (the notebook passes 1).
    hidden_dim : int
        Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super(Generator3, self).__init__()

        # The first layer takes `input_dim` features as the signature says.
        self.fc_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim),
                                        nn.Linear(hidden_dim, 1)])
        # Same `activations` convention as the other models, so the shared
        # ModelBase._forward_first can be reused.
        self.activations = nn.ModuleList([nn.ReLU()])

        self._init_layers()

    def forward(self, x):
        return self.fc_layers[-1](self._forward_first(x))

In [None]:
class Discriminator3(Discriminator):
    """Discriminator from kremerj/gan: Linear -> ReLU -> Linear, followed by
    the base class's sigmoid unless ``wasserstein``.

    Parameters
    ----------
    input_dim : int
        Number of input features (the notebook passes 1).
    hidden_dim : int
        Width of the hidden layer.
    wasserstein : bool
        Build a WGAN critic (no final sigmoid).
    """

    def __init__(self, input_dim, hidden_dim, wasserstein=False):
        super(Discriminator3, self).__init__(wasserstein)

        # The first layer takes `input_dim` features as the signature says.
        self.fc_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim),
                                        nn.Linear(hidden_dim, 1)])
        self.activations = nn.ModuleList([nn.ReLU()])

        self._init_layers()

### Define the GAN

This class combines the sampler, generator and discriminator, and defines the loss functions and optimizers.

In [None]:
class GAN(nn.Module):
    """Adversarial training of ``generator`` against ``discriminator``.

    Parameters
    ----------
    sampler : Sampler
        Provides ``batch_size``, the ``real_data`` / ``fake_data`` datasets
        with their ``real_inputs`` / ``fake_inputs`` loaders, and
        ``sample_epsilon()`` for the gradient penalty.
    generator, discriminator : nn.Module
        The two players. For the Wasserstein flavors the discriminator
        should return raw (unbounded) scores.
    flavor : {'vanilla', 'wasserstein', 'wasserstein_improved'}
        * ``'vanilla'``: binary cross-entropy loss, one pass of discriminator
          training per generator update.
        * ``'wasserstein'``: WGAN critic loss, five passes per generator
          update, critic parameters clipped to [-0.01, 0.01] after each step.
        * ``'wasserstein_improved'``: WGAN critic loss, five passes per
          generator update, plus a gradient penalty weighted by 0.1.

    Raises
    ------
    ValueError
        If ``flavor`` is not one of the values above.
    """

    FLAVORS = ('vanilla', 'wasserstein', 'wasserstein_improved')

    def __init__(self, sampler, generator, discriminator, flavor='vanilla'):
        super(GAN, self).__init__()

        # Fail early: any other value would leave the criterion/penalty
        # hooks below undefined and only crash later, inside training.
        if flavor not in self.FLAVORS:
            raise ValueError('unknown GAN flavor {!r}; expected one of {}'
                             .format(flavor, self.FLAVORS))

        # store the sampling
        self.sampler = sampler

        # store the models
        self.generator = generator
        self.discriminator = discriminator

        self.flavor = flavor

        # define the criterion and the per-flavor hooks
        if flavor == 'vanilla':
            self.number_discrim_iterations = 1
            self.real_labels = torch.ones(self.sampler.batch_size, 1)
            self.fake_labels = torch.zeros(self.sampler.batch_size, 1)
            self.criterion = nn.BCELoss()  # binary cross entropy loss
            self._eval_criterion_real = self._eval_criterion_real_vanilla
            self._eval_criterion_fake = self._eval_criterion_fake_vanilla
            self._loss_penalty = self._no_loss_penalty
            self._postprocess_discriminator = \
                self._no_postprocess_discriminator
        else:
            # Both Wasserstein flavors: the critic minimises
            # mean(D(fake)) - mean(D(real)) and is trained several times per
            # generator update.
            self.number_discrim_iterations = 5
            self.criterion = torch.mean
            self._eval_criterion_real = self._eval_criterion_real_wasserstein
            self._eval_criterion_fake = self._eval_criterion_fake_wasserstein
            if flavor == 'wasserstein':
                self._loss_penalty = self._no_loss_penalty
                self._postprocess_discriminator = \
                    self._postprocess_discriminator_wasserstein
            else:  # 'wasserstein_improved'
                self._loss_penalty = self._loss_penalty_wasserstein_improved
                # NOTE(review): the WGAN-GP paper uses lambda = 10; 0.1 is a
                # much weaker penalty -- confirm this is intended.
                self.regularization_lambda = 0.1
                self._postprocess_discriminator = \
                    self._no_postprocess_discriminator

        # create optimizers
        self.optim_gen = optim.Adam(
            self.generator.parameters(), lr=0.001, betas=(0.5, 0.9))
        self.optim_discrim = optim.Adam(
            self.discriminator.parameters(), lr=0.001, betas=(0.5, 0.9))

        # for logging: most recent loss values as Python floats
        self.discrim_loss = {'real': None, 'fake': None, 'penalty': None}
        self.gen_loss = None

    def _eval_criterion_real_vanilla(self, output):
        return self.criterion(output, Variable(self.real_labels))

    def _eval_criterion_fake_vanilla(self, output):
        return self.criterion(output, Variable(self.fake_labels))

    def _eval_criterion_real_wasserstein(self, output):
        return -self.criterion(output)

    def _eval_criterion_fake_wasserstein(self, output):
        return self.criterion(output)

    def _no_postprocess_discriminator(self):
        return

    def _postprocess_discriminator_wasserstein(self):
        # WGAN weight clipping.
        for p in self.discriminator.parameters():
            p.data.clamp_(-0.01, 0.01)

    def _postprocess_discriminator_wasserstein_improved(self):
        return

    def _no_loss_penalty(self, real_x, forged_x):
        return None

    def _loss_penalty_wasserstein_improved(self, real_x, forged_x):
        """Gradient penalty: push the critic's input-gradient norm towards 1
        at random interpolates between real and forged samples."""
        eps = self.sampler.sample_epsilon()
        penalty_input = Variable(eps * real_x + (1.0 - eps) * forged_x,
                                 requires_grad=True)

        # autograd.grad needs a scalar output or explicit grad_outputs; with
        # ones, each row's gradient is that sample's input gradient.
        output = self.discriminator(penalty_input)
        grads = torch.autograd.grad(output, penalty_input,
                                    grad_outputs=torch.ones_like(output),
                                    create_graph=True,
                                    retain_graph=True)[0]
        return self.regularization_lambda * torch.mean(
            (torch.norm(grads, 2, 1) - 1.0) ** 2)

    def _train_discriminator(self):
        """Run ``number_discrim_iterations`` passes of discriminator training
        over the current batches. Returns the last noise batch seen (or None
        if the loaders are empty)."""
        fake_x = None
        for _ in range(self.number_discrim_iterations):
            # iterate over minibatches (in this case only one)
            for real_x, fake_x in zip(self.sampler.real_inputs,
                                      self.sampler.fake_inputs):
                (self.discrim_loss['real'], self.discrim_loss['fake'],
                 self.discrim_loss['penalty']) = \
                    self._step_discriminator(real_x, fake_x)
        return fake_x

    def learn(self):
        """Train the discriminator on fresh batches, then take one generator
        step on the last noise batch it saw (no resampling)."""
        self.sampler.real_data.sample()
        self.sampler.fake_data.sample()
        fake_x = self._train_discriminator()
        self.gen_loss = self._step_generator(fake_x)

    def learn_discriminator(self):
        """Draw fresh real and noise batches and train the discriminator."""
        self.sampler.real_data.sample()
        self.sampler.fake_data.sample()
        self._train_discriminator()

    def learn_generator(self):
        """Draw a fresh noise batch and take generator step(s) on it."""
        self.sampler.fake_data.sample()

        for x in self.sampler.fake_inputs:
            self.gen_loss = self._step_generator(x)

    def _step_discriminator(self, real_x, fake_x):
        """One optimisation step of the discriminator (generator fixed).

        The loss is split into a real term, a fake term and an optional
        penalty. Each is backpropagated separately, which accumulates the
        same gradients as backpropagating their sum.

        Returns
        -------
        tuple of (float, float, float or None)
            Real loss, fake loss, and penalty (None if the flavor has none).
        """
        self.discriminator.zero_grad()
        output = self.discriminator(Variable(real_x))
        loss_real = self._eval_criterion_real(output)
        loss_real.backward()

        # Detach: the generator's parameters are not trained in this step.
        forged = self.generator(Variable(fake_x)).detach()
        output = self.discriminator(forged)
        loss_fake = self._eval_criterion_fake(output)
        loss_fake.backward()

        loss_penalty = self._loss_penalty(real_x, forged)
        if loss_penalty is not None:
            loss_penalty.backward()
            # .item(): the losses are 0-dim tensors, which current PyTorch
            # refuses to index (loss.data[0]).
            loss_penalty_value = loss_penalty.item()
        else:
            loss_penalty_value = None

        self.optim_discrim.step()

        self._postprocess_discriminator()

        return loss_real.item(), loss_fake.item(), loss_penalty_value

    def _step_generator(self, x):
        """One optimisation step of the generator (discriminator fixed).
        Returns the generator loss as a float."""
        self.generator.zero_grad()
        output = self.discriminator(self.generator(Variable(x)))
        # backward() also accumulates gradients in the discriminator's .grad
        # buffers. They are harmless: optim_gen only updates generator
        # parameters, and _step_discriminator zeroes the discriminator's
        # gradients before each of its own steps.
        loss = self._eval_criterion_real(output)
        loss.backward()
        self.optim_gen.step()

        return loss.item()

### Logging and plotting functions

Logging of losses.

In [None]:
def log_discriminator(loss_real, loss_fake):
    """Print the discriminator's real, fake and total loss on one line."""
    total = loss_real + loss_fake
    line = 'discriminator | {:.4f} | {:.4f} | {:.4f}'.format(
        loss_real, loss_fake, total)
    print(line)
    
def log_generator(loss):
    """Print the generator loss."""
    message = 'generator | {:.4f}'.format(loss)
    print(message)

We will plot the decision boundary, the real Gaussian distribution, and the forged distribution as generated by the generator in a figure. The evolution of the losses will be plotted in a second figure. The following functions construct data structures to store the data needed to update the figures in bokeh.

In [None]:
def get_data_real_hist(numb_points, numb_bins, sampler):
    """Histogram a fresh set of samples from the real (target) distribution.

    Parameters
    ----------
    numb_points : int
        Number of samples to draw.
    numb_bins : int
        Number of bin *edges* spanning ``[-bound, bound]``, where ``bound``
        is the sampler's noise range; the histogram has ``numb_bins - 1``
        bins.
    sampler : Sampler
        Supplies the real distribution's mean and sigma and the noise range.

    Returns
    -------
    dict
        ``'hist_real'`` (density per bin) and ``'x'`` (bin centres), both of
        length ``numb_bins - 1``.
    """
    mean = sampler.real_data.mean
    sigma = sampler.real_data.sigma
    sample_range = sampler.fake_data.bound

    samples_real = sample_real(numb_points, sigma, mean)

    hist_bins = np.linspace(-sample_range, sample_range, numb_bins)
    hist_real, edges = np.histogram(samples_real,
                                    bins=hist_bins, density=True)

    # Plot each density at the centre of its bin. Spreading numb_bins - 1
    # points evenly over [-bound, bound] instead would shift the curve by up
    # to half a bin width towards the edges.
    return {'hist_real': hist_real, 'x': 0.5 * (edges[:-1] + edges[1:])}

In [None]:
def get_data_gan_hist(numb_points, numb_bins, gan):
    """Collect plot data for the generator's output and the decision boundary.

    The generator is applied to ``numb_points`` evenly spaced noise values in
    ``[-bound, bound]`` (``bound`` = the sampler's noise range). Its output is
    histogrammed with ``numb_bins`` bin edges over the same interval. The
    discriminator is evaluated at the bin centres. Wasserstein critic scores
    are rescaled to [0, 1] (unless they are nearly constant) so they fit on
    the same axes as the densities.

    Returns
    -------
    dict
        ``'hist_fake'`` (density per bin), ``'labels'`` (discriminator output
        per bin centre) and ``'x'`` (bin centres), all of length
        ``numb_bins - 1``.
    """
    sample_range = gan.sampler.fake_data.bound
    hist_bins = np.linspace(-sample_range, sample_range, numb_bins)
    # All columns of one bokeh data source must have the same length, so
    # everything is expressed per bin and placed at the bin centre.
    xs = 0.5 * (hist_bins[:-1] + hist_bins[1:])

    dat = {}

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        # fake distribution
        gen_x = torch.FloatTensor(np.linspace(
            -sample_range, sample_range, numb_points)).unsqueeze(-1)
        samples_fake = gan.generator(gen_x).detach().numpy().squeeze(-1)
        dat['hist_fake'], _ = np.histogram(samples_fake,
                                           bins=hist_bins, density=True)

        # decision boundary
        discr_x = torch.FloatTensor(xs).unsqueeze(-1)
        decision_bound = gan.discriminator(
            discr_x).detach().numpy().squeeze(-1)

    if gan.discriminator.is_wasserstein():
        min_label = np.min(decision_bound)
        range_label = np.max(decision_bound) - min_label
        # Leave a (nearly) constant critic output unscaled rather than
        # dividing by ~0.
        if range_label > 1e-5:
            decision_bound = (decision_bound - min_label) / range_label

    dat['labels'] = decision_bound
    dat['x'] = xs

    return dat

In [None]:
def update_data_gan_loss(gan, old_dat, epoch):
    """Append the GAN's latest losses to the loss-plot column data.

    Parameters
    ----------
    gan : GAN
        Model whose ``discrim_loss`` dict and ``gen_loss`` hold the most
        recent loss values.
    old_dat : dict
        Previous column data.
    epoch : int
        Current epoch. ``epoch == 0`` starts fresh columns, discarding the
        placeholder values in ``old_dat``.

    Returns
    -------
    dict
        New column data; every column has the same length. The
        ``'loss_discrim_penalty'`` column is present for Wasserstein
        discriminators only.
    """
    def extend(key, value):
        # At epoch 0 the columns only hold placeholders, so start over.
        if epoch:
            return np.append(old_dat[key], [value])
        return np.array([value])

    dloss_real = gan.discrim_loss['real']
    dloss_fake = gan.discrim_loss['fake']
    dloss_total = dloss_real + dloss_fake

    dat = {}
    dat['x'] = extend('x', epoch)
    dat['loss_discrim_real'] = extend('loss_discrim_real', dloss_real)
    dat['loss_discrim_fake'] = extend('loss_discrim_fake', dloss_fake)
    if gan.discriminator.is_wasserstein():
        dloss_penalty = gan.discrim_loss['penalty']
        # The penalty is None when the GAN uses no gradient penalty (plain
        # WGAN with weight clipping). Record 0 so the column keeps the same
        # length as the others and the total stays a number.
        if dloss_penalty is None:
            dloss_penalty = 0.0
        dloss_total += dloss_penalty
        dat['loss_discrim_penalty'] = extend('loss_discrim_penalty',
                                             dloss_penalty)
    dat['loss_discrim_total'] = extend('loss_discrim_total', dloss_total)
    dat['loss_generator'] = extend('loss_generator', gan.gen_loss)

    return dat

This function displays the figures.

In [None]:
def show_data(data_real_hist, data_gan_hist, data_gan_loss, bound):
    """Display the distribution plot and the loss plot side by side.

    Left: real density, fake density and decision boundary over
    ``[-bound, bound]``. Right: discriminator and generator losses per
    logged epoch. The first three arguments are bokeh ``ColumnDataSource``
    objects; replacing their ``.data`` and calling ``push_notebook()``
    refreshes the figures.

    Returns whatever bokeh's ``show`` returns for ``notebook_handle=True``
    (a handle for later ``push_notebook`` calls when run in a notebook).
    """
    lw = 2

    # `width`/`height` and `legend_label` replace the removed/deprecated
    # `plot_width`/`plot_height` and `legend=` arguments in current bokeh.
    fig_hist = figure(width=350, height=350, x_range=(-bound, bound))
    for column, source, color, label in (
            ('hist_real', data_real_hist, 'blue', 'real distribution'),
            ('labels', data_gan_hist, 'green', 'decision'),
            ('hist_fake', data_gan_hist, 'red', 'fake distribution')):
        fig_hist.line('x', column, source=source, line_width=lw,
                      line_color=color, legend_label=label)

    # Legends cannot be moved interactively, so pin this one to a corner;
    # clicking an entry hides the corresponding line.
    fig_hist.legend.location = "top_left"
    fig_hist.legend.click_policy = "hide"

    fig_loss = figure(width=640, height=350)
    loss_lines = [('loss_discrim_real', 'blue', 'D-loss real'),
                  ('loss_discrim_fake', 'green', 'D-loss fake')]
    if 'loss_discrim_penalty' in data_gan_loss.data:
        loss_lines.append(('loss_discrim_penalty', 'magenta',
                           'D-loss penalty'))
    loss_lines += [('loss_discrim_total', 'red', 'D-loss total'),
                   ('loss_generator', 'black', 'G-loss')]
    for column, color, label in loss_lines:
        fig_loss.line('x', column, source=data_gan_loss, line_width=lw,
                      line_color=color, legend_label=label)

    return show(row(fig_hist, fig_loss), notebook_handle=True)

### Train the GAN

Encapsulate training in a function to make it reusable:

In [None]:
def train_gan(gan, nb_points, nb_bins, log_step, number_epochs):
    """Train ``gan`` for ``number_epochs`` epochs and plot its progress live.

    Each epoch trains the discriminator (``gan.learn_discriminator``) and
    then the generator (``gan.learn_generator``). Every ``log_step`` epochs,
    starting with epoch 0, the distribution and loss plots are updated and
    pushed to the notebook.

    Parameters
    ----------
    gan : GAN
        The model to train. Its own ``sampler`` defines the real
        distribution and the noise range.
    nb_points : int
        Number of samples used to build each histogram.
    nb_bins : int
        Number of histogram bin edges (``nb_bins - 1`` bins).
    log_step : int
        Plot-refresh interval in epochs.
    number_epochs : int
        Number of training epochs.
    """
    # Use the GAN's sampler, not a notebook-global `sampler` that may belong
    # to a different experiment.
    data_real = ColumnDataSource(
        data=get_data_real_hist(nb_points, nb_bins, gan.sampler))
    data_gan_hist = ColumnDataSource(
        data={'x': [], 'labels': [], 'hist_fake': []})
    # Placeholder values: with empty arrays the y-axis tick labels are not
    # shown. The loss columns are restarted at epoch 0.
    model_data = {'x': [0.],
                  'loss_discrim_real': [0.],
                  'loss_discrim_fake': [0.],
                  'loss_discrim_total': [0.],
                  'loss_generator': [0.]}
    if gan.discriminator.is_wasserstein():
        model_data['loss_discrim_penalty'] = [0.]
    data_gan_loss = ColumnDataSource(data=model_data)

    show_data(data_real, data_gan_hist, data_gan_loss,
              gan.sampler.fake_data.bound)

    for epoch in range(number_epochs):
        gan.learn_discriminator()
        gan.learn_generator()

        if epoch % log_step == 0:
            # Same resolution for the fake histogram as for the real one.
            data_gan_hist.data = get_data_gan_hist(nb_points, nb_bins, gan)
            data_gan_loss.data = update_data_gan_loss(
                gan, data_gan_loss.data, epoch)
            push_notebook()

### With Generator1 and Discriminator1

Note that this problem is not stable: training can converge to different solutions depending on the random seed or on how many times you run the cells. As explained in the blog, the generator typically converges to a distribution with a similar range but a narrower shape. To improve convergence, the decision boundary can be pre-trained, as explained in https://blog.evjang.com/2016/06/generative-adversarial-nets-in.html 

In [None]:
# Experiment settings. Re-seed here so that every GAN variant below sees the
# same NumPy random sequence.
np.random.seed(42)
real_mean, real_sigma = 4.0, 0.5   # target Gaussian N(4, 0.5**2)
noise_range = 8.0                  # noise grid spans about [-8, 8]
hidden_dim = 4
batch_size = 8
sampler = Sampler(real_sigma=real_sigma, real_mean=real_mean,
                  noise_range=noise_range, batch_size=batch_size)

#### Using vanilla GAN

In [None]:
# Vanilla GAN: Generator1 against a Discriminator1 twice as wide.
gen1 = Generator1(input_dim=1, hidden_dim=hidden_dim)
discrim1 = Discriminator1(input_dim=1, hidden_dim=2 * hidden_dim)
gan1 = GAN(sampler, generator=gen1, discriminator=discrim1,
           flavor='vanilla')

In [None]:
# Train for 5000 epochs, refreshing the plots every 10 epochs.
train_gan(gan1, nb_points=4000, nb_bins=100, log_step=10, number_epochs=5000)

#### Using Wasserstein GAN (WGAN)

This doesn't seem to work on this example. The discriminator loss collapses to zero for both the real and the fake terms, and the weights of both the generator and the discriminator become zero. This makes sense: if everything is mapped to zero, both the discriminator and the generator losses become zero.

In [None]:
# WGAN with weight clipping: the critic returns raw scores (no sigmoid).
gen1 = Generator1(input_dim=1, hidden_dim=hidden_dim)
discrim1 = Discriminator1(input_dim=1, hidden_dim=2 * hidden_dim,
                          wasserstein=True)
gan1 = GAN(sampler, generator=gen1, discriminator=discrim1,
           flavor='wasserstein')

In [None]:
# Train for 2000 epochs, refreshing the plots every 10 epochs.
train_gan(gan1, nb_points=4000, nb_bins=100, log_step=10, number_epochs=2000)

#### Using Wasserstein GAN with improved training

In [None]:
# WGAN with gradient penalty ("improved training").
gen1 = Generator1(input_dim=1, hidden_dim=hidden_dim)
discrim1 = Discriminator1(input_dim=1, hidden_dim=2 * hidden_dim,
                          wasserstein=True)
gan1 = GAN(sampler, generator=gen1, discriminator=discrim1,
           flavor='wasserstein_improved')

In [None]:
# Train for 1500 epochs, refreshing the plots every 10 epochs.
train_gan(gan1, nb_points=4000, nb_bins=100, log_step=10, number_epochs=1500)

### With Generator2 and Discriminator2

In [None]:
# Experiment settings. Re-seed here so that every GAN variant below sees the
# same NumPy random sequence.
np.random.seed(42)
real_mean, real_sigma = 4.0, 0.5   # target Gaussian N(4, 0.5**2)
noise_range = 8.0                  # noise grid spans about [-8, 8]
hidden_dim = 4
batch_size = 8
sampler = Sampler(real_sigma=real_sigma, real_mean=real_mean,
                  noise_range=noise_range, batch_size=batch_size)

#### Using vanilla GAN

In [None]:
# Vanilla GAN with the devnag models (same hidden width for both players).
gen2 = Generator2(input_dim=1, hidden_dim=hidden_dim)
discrim2 = Discriminator2(input_dim=1, hidden_dim=hidden_dim)
gan2 = GAN(sampler, generator=gen2, discriminator=discrim2,
           flavor='vanilla')

In [None]:
# Train for 5000 epochs, refreshing the plots every 10 epochs.
train_gan(gan2, nb_points=4000, nb_bins=100, log_step=10, number_epochs=5000)

#### Using Wasserstein GAN with improved training

In [None]:
# WGAN with gradient penalty, devnag models.
gen2 = Generator2(input_dim=1, hidden_dim=hidden_dim)
discrim2 = Discriminator2(input_dim=1, hidden_dim=hidden_dim,
                          wasserstein=True)
gan2 = GAN(sampler, generator=gen2, discriminator=discrim2,
           flavor='wasserstein_improved')

In [None]:
# Train for 1500 epochs, refreshing the plots every 10 epochs.
train_gan(gan2, nb_points=4000, nb_bins=100, log_step=10, number_epochs=1500)

### With Generator3 and Discriminator3

In [None]:
# Experiment settings. Re-seed here so that every GAN variant below sees the
# same NumPy random sequence. These models use a larger batch.
np.random.seed(42)
real_mean, real_sigma = 4.0, 0.5   # target Gaussian N(4, 0.5**2)
noise_range = 8.0                  # noise grid spans about [-8, 8]
hidden_dim = 4
batch_size = 64
sampler = Sampler(real_sigma=real_sigma, real_mean=real_mean,
                  noise_range=noise_range, batch_size=batch_size)

#### Using Vanilla GAN

This doesn't work well. Training gets stuck, probably because the discriminator is not powerful enough.

In [None]:
# Vanilla GAN with the kremerj models (one hidden layer each).
gen3 = Generator3(input_dim=1, hidden_dim=hidden_dim)
discrim3 = Discriminator3(input_dim=1, hidden_dim=hidden_dim)
gan3 = GAN(sampler, generator=gen3, discriminator=discrim3,
           flavor='vanilla')

In [None]:
# Train for 2000 epochs, refreshing the plots every 10 epochs.
train_gan(gan3, nb_points=4000, nb_bins=100, log_step=10, number_epochs=2000)

In [None]:
# WGAN with gradient penalty, kremerj models.
gen3 = Generator3(input_dim=1, hidden_dim=hidden_dim)
discrim3 = Discriminator3(input_dim=1, hidden_dim=hidden_dim,
                          wasserstein=True)
gan3 = GAN(sampler, generator=gen3, discriminator=discrim3,
           flavor='wasserstein_improved')

In [None]:
# Train for 1450 epochs, refreshing the plots every 10 epochs.
train_gan(gan3, nb_points=4000, nb_bins=100, log_step=10, number_epochs=1450)

### Optional

#### Test generator

This performs a test to validate the behavior of the generator. It's tested against an implementation of the same model using numpy. It illustrates how weights and bias values in PyTorch layers can be set to evaluate a model.

In [None]:
def generator_numpy():
    """Reference Linear -> Softplus -> Linear forward pass in NumPy.

    Inputs are 1..4. First-layer weights are [1, 2, 3, 4]^T, second-layer
    weights are [1, 2, 3, 4], and biases are zero. Returns a ``(1, 4)``
    array whose i-th entry is the model output for the i-th input.
    """
    # Plain 2-D arrays with np.dot: np.matrix is no longer recommended by
    # NumPy and may be removed.
    in_vals = np.array([[1., 2., 3., 4.]])     # (1, 4): one column per input
    w1 = np.array([[1.], [2.], [3.], [4.]])    # (hidden=4, 1)
    w2 = np.array([[1., 2., 3., 4.]])          # (1, hidden=4)
    res1 = np.dot(w1, in_vals)                 # (4, 4) pre-activations
    res2 = np.log(1. + np.exp(res1))           # softplus
    return np.dot(w2, res2)                    # (1, 4)

In [None]:
# Give a Generator1 (Linear -> Softplus -> Linear) known weights and zero
# biases, then check its output against the NumPy reference implementation.
hidden_dims_test = 4
test_weights = np.array([[1.], [2.], [3.], [4.]])
weights_gen_fc1 = test_weights                 # (hidden, 1)
bias_gen_fc1 = np.zeros(hidden_dims_test)      # biases are 1-D: (hidden,)
weights_gen_fc2 = test_weights.transpose()     # (1, hidden)
bias_gen_fc2 = np.zeros(1)

# The notebook defines Generator1/2/3 but no `Generator`; Generator1 is the
# architecture generator_numpy() mirrors. Its layers live in `fc_layers`.
gen = Generator1(1, hidden_dims_test)
fc1, fc2 = gen.fc_layers
fc1.weight.data = torch.FloatTensor(weights_gen_fc1)
fc1.bias.data = torch.FloatTensor(bias_gen_fc1)
fc2.weight.data = torch.FloatTensor(weights_gen_fc2)
fc2.bias.data = torch.FloatTensor(bias_gen_fc2)

res = gen(Variable(torch.FloatTensor([[1.], [2.], [3.], [4.]])))

# Outputs reach ~120, so use a relative tolerance suited to float32 instead
# of a fixed number of decimals.
np.testing.assert_allclose(
    res.data.numpy(),
    generator_numpy().transpose(), rtol=1e-6)

#### Test discriminator

This performs a test of the discriminator model, using baseline values obtained by running an equivalent model in TensorFlow.

In [None]:
# Give a Discriminator1 hand-picked weights (biases stay at the zeros set by
# _init_layers) and compare against baseline outputs from an equivalent
# TensorFlow model.
hidden_dims_test = 4
test_weights = np.array([[0.5, -1.5, 0.33, -0.15]])
weights = [test_weights.transpose(),                    # (4, 1)
           np.tile(test_weights, (4, 1)),               # (4, 4)
           np.tile(test_weights, (4, 1)).transpose(),   # (4, 4)
           np.array(np.fliplr(test_weights))]           # (1, 4)
# np.array(...) copies the flipped view: torch cannot build a tensor from an
# array with negative strides.

# `Discriminator` is only the base class (its constructor takes just
# `wasserstein` and defines no layers). Discriminator1 provides the
# Linear(1,4) -> 2 x Linear(4,4) -> Linear(4,1) stack with ReLU and sigmoid.
discrim = Discriminator1(1, hidden_dims_test)
for layer, w in zip(discrim.fc_layers, weights):
    layer.weight.data = torch.FloatTensor(w)

res = discrim(Variable(torch.FloatTensor(
    [[1.0], [0.5], [3.0], [0.0]])))

np.testing.assert_almost_equal(
    res.data.numpy(),
    [[0.3061263], [0.3991169], [0.0790827], [0.5]],
    decimal=7)

In [None]:
class test_grad(nn.Module):
    """Single linear layer y = 3*x0 + 7*x1 with zero bias.

    Used to check ``torch.autograd.grad``: the gradient of ``y`` with
    respect to each input row is the weight row [3, 7].
    """

    def __init__(self):
        super(test_grad, self).__init__()

        self.fc_layer1 = nn.Linear(2, 1)
        self.fc_layer1.weight.data = torch.FloatTensor([[3.0, 7.0]])
        # In-place initialiser; nn.init.constant (no underscore) is deprecated.
        nn.init.constant_(self.fc_layer1.bias, 0.0)

    def forward(self, x):
        return self.fc_layer1(x)

In [None]:
# Sanity-check torch.autograd.grad the way the WGAN gradient penalty uses it.
# For test_grad, d(output)/d(input) is the weight row [3, 7] for every input
# row.
g = test_grad()

inp = Variable(torch.FloatTensor([[1.0, 4.0],[2.0, 5.0]]), requires_grad=True)
outp = g(inp)
# grad_outputs of ones: each output row depends only on its own input row,
# so each row of gradspred is that row's input gradient, [3, 7].
gradspred, = torch.autograd.grad(outp, inp, 
                           grad_outputs=outp.data.new(outp.shape).fill_(1),
                           create_graph=True)
print(outp)
print(gradspred)
# Per-row penalty term (||grad|| - 1)^2.
print((torch.norm(gradspred,2,1) - 1.0) ** 2)

# Expected value: (||[3, 7]|| - 1)^2 = (sqrt(58) - 1)^2, approx. 6.6158^2.
print(6.6158 * 6.6158)
#print(g.fc_layer1.weight.grad)
#print(g.fc_layer2.weight.grad)
#print(g.fc_layer1.weight.data)