## 1D GAN distribution sampling

This notebook contains a PyTorch implementation of the GAN example as presented at   
http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/

It's a relatively simple problem in which the generator tries to mimic a Gaussian distribution.
The generator is fed samples from a noise distribution, while the discriminator tries to distinguish samples drawn from the Gaussian distribution from samples that the generator produces from the noise.

In [1]:
import numpy as np

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, ConcatDataset, DataLoader

from bokeh.models import ColumnDataSource
from bokeh.plotting import output_notebook, figure, show
from bokeh.io import push_notebook

In [2]:
# Embed bokeh figures in the notebook output
# (output_notebook is imported from bokeh.plotting).
output_notebook()

### Define the sampling

In [3]:
# Seed both random number generators this notebook uses: NumPy draws the
# real and noise samples, while torch draws the initial network weights
# (nn.init in Generator / Discriminator). Seeding only NumPy left the
# weight initialisation, and hence the training outcome, unreproducible.
np.random.seed(42)
torch.manual_seed(42)

In [4]:
def sample_real(n, sigma=1.0, mean=-1.0):
    """Draw ``n`` samples from a Gaussian N(mean, sigma**2).

    The samples are returned as a 1-D float array in ascending order.
    """
    draws = np.random.normal(mean, sigma, n)
    draws.sort()
    return draws

In [5]:
def sample_noise(n, bound=5.0):
    """Return ``n`` stratified noise samples spanning roughly [-bound, bound].

    The samples are ``n`` evenly spaced points from -bound to bound, each
    shifted up by a uniform offset in [0, 0.01). Because the grid is ordered,
    the mapping learned by the generator keeps the input ordering (see the
    blog for details).
    """
    grid = np.linspace(-bound, bound, n)
    jitter = 0.01 * np.random.random(n)
    return grid + jitter

In [6]:
class Sampler(object):
    """Hold the real and noise datasets together with a DataLoader for each.

    Args:
        real_sigma: standard deviation of the real Gaussian.
        real_mean: mean of the real Gaussian.
        noise_range: bound of the noise samples (passed to FakeData).
        batch_size: number of samples drawn per ``sample()`` call and per
            minibatch.
        num_workers: DataLoader worker processes (default 0 = load in the
            main process). The datasets hold only ``batch_size`` items and are
            resampled in the main process before every pass. Worker processes
            would therefore add a process start-up per pass for no benefit.
            On platforms that start workers with ``spawn`` (Windows, macOS),
            they typically also fail to unpickle dataset classes defined
            inside a notebook.
    """

    def __init__(self, real_sigma, real_mean,
                 noise_range, batch_size, num_workers=0):
        self.batch_size = batch_size
        # initialize data providers; each holds exactly one batch
        self.real_data = RealData(self.batch_size,
                                  real_sigma, real_mean)
        self.fake_data = FakeData(self.batch_size, noise_range)
        # shuffle=False keeps the sorted / stratified sample order intact
        self.real_inputs = DataLoader(
            self.real_data, batch_size=self.batch_size,
            shuffle=False, num_workers=num_workers)
        self.fake_inputs = DataLoader(
            self.fake_data, batch_size=self.batch_size,
            shuffle=False, num_workers=num_workers)

### Wrap sampling in PyTorch Datasets

This lets us use PyTorch's data-loading utilities (`Dataset`, `DataLoader`) when training the networks.

In [7]:
class RealData(Dataset):
    """Dataset view of one freshly drawn batch of real (Gaussian) samples.

    ``x`` only exists after the first call to ``sample()``, so call it before
    indexing the dataset.
    """

    def __init__(self, n, sigma, mean):
        # number of samples produced by each sample() call
        self.data_size = n
        # parameters of the Gaussian passed to sample_real
        self.sigma = sigma
        self.mean = mean

    def sample(self):
        """Redraw the batch and store it in ``self.x`` as an (n, 1) float tensor."""
        values = sample_real(self.data_size, self.sigma, self.mean)
        self.x = torch.FloatTensor(values).view(-1, 1)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        item = self.x[idx]
        return item

In [8]:
class FakeData(Dataset):
    """Dataset view of one batch of stratified noise samples (generator input).

    ``x`` only exists after the first call to ``sample()``.
    """

    def __init__(self, n, bound):
        # number of samples produced by each sample() call
        self.data_size = n
        # noise grid runs from -bound to +bound (see sample_noise)
        self.bound = bound

    def sample(self):
        """Redraw the noise batch and store it in ``self.x`` as an (n, 1) float tensor."""
        noise = sample_noise(self.data_size, self.bound)
        self.x = torch.FloatTensor(noise).view(-1, 1)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        item = self.x[idx]
        return item

### Define the networks

In [9]:
class Generator(nn.Module):
    """Map noise samples to fake data: Linear -> Softplus -> Linear(1).

    Args:
        input_dim: number of features per input sample.
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super(Generator, self).__init__()
        # first linear fully connected layer
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # apply nonlinear activation using softplus
        self.softplus = nn.Softplus()
        # reduce to single output using fully connected layer
        self.fc2 = nn.Linear(hidden_dim, 1)

        # Weights ~ N(0, 1), biases 0. The in-place ``*_`` initialisers
        # replace the deprecated ``nn.init.normal`` / ``nn.init.constant``
        # aliases and run without autograd tracking, so ``.data`` is not
        # needed.
        nn.init.normal_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0.0)
        nn.init.normal_(self.fc2.weight)
        nn.init.constant_(self.fc2.bias, 0.0)

    def forward(self, x):
        """Return fc2(softplus(fc1(x))): one output value per input row."""
        return self.fc2(self.softplus(self.fc1(x)))

In [10]:
class Discriminator(nn.Module):
    """Score samples as real (output near 1) or fake (output near 0).

    Architecture: Linear(input_dim, hidden_dim) -> ReLU, then
    ``num_hidden_layers`` x [Linear(hidden_dim, hidden_dim) -> ReLU],
    then Linear(hidden_dim, 1) -> Sigmoid.

    Args:
        input_dim: number of features per input sample.
        hidden_dim: width of every hidden layer.
        num_hidden_layers: number of hidden-to-hidden layers. Defaults to 2,
            the architecture this class originally hard-coded.
    """

    def __init__(self, input_dim, hidden_dim, num_hidden_layers=2):
        super(Discriminator, self).__init__()

        self.fc_layers = nn.ModuleList(
            [nn.Linear(input_dim, hidden_dim)] +
            [nn.Linear(hidden_dim, hidden_dim)
             for _ in range(num_hidden_layers)] +
            [nn.Linear(hidden_dim, 1)])

        # a ReLU after every layer except the last; the final Sigmoid squashes
        # the score into (0, 1) for use with BCELoss
        self.activations = nn.ModuleList(
            [nn.ReLU() for _ in range(num_hidden_layers + 1)] +
            [nn.Sigmoid()])

        # Weights ~ N(0, 1), biases 0, using the in-place initialisers
        # (the ``nn.init.normal`` / ``nn.init.constant`` aliases are deprecated).
        for layer in self.fc_layers:
            nn.init.normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        """Apply each linear layer followed by its activation."""
        for fc, activ in zip(self.fc_layers, self.activations):
            x = activ(fc(x))
        return x

### Define the GAN

This combines sampler, generator and discriminator, and defines the loss function and optimizers.

In [11]:
class GAN(nn.Module):
    """Combine a Sampler, a Generator and a Discriminator for training.

    Holds the BCE criterion, one Adam optimiser per network, and the most
    recent losses (as Python floats) for logging.

    Args:
        sampler: a Sampler providing ``batch_size``, the ``real_data`` /
            ``fake_data`` datasets (with ``sample()``) and the
            ``real_inputs`` / ``fake_inputs`` minibatch loaders.
        hidden_dim: hidden width of the generator; the discriminator gets
            twice as many hidden units.
    """

    def __init__(self, sampler, hidden_dim):
        super(GAN, self).__init__()

        self.sampler = sampler

        # initialize labels used for training (one per sample in a batch)
        self.real_labels = torch.FloatTensor(
            np.ones((self.sampler.batch_size, 1)))
        self.fake_labels = torch.FloatTensor(
            np.zeros((self.sampler.batch_size, 1)))

        # construct models
        self.generator = Generator(1, hidden_dim)
        # to make sure the discriminator does not get overwhelmed,
        # we give it two times more hidden neurons
        self.discriminator = Discriminator(1, 2 * hidden_dim)
        self.criterion = nn.BCELoss()  # binary cross entropy loss
        self.optim_gen = optim.Adam(
            self.generator.parameters(), lr=0.001)
        self.optim_discrim = optim.Adam(
            self.discriminator.parameters(), lr=0.001)

        # for logging
        self.discrim_loss = {'real': None, 'fake': None}
        self.gen_loss = None

    def learn(self):
        """Run one pass in which the generator step reuses the discriminator's noise.

        Both datasets are resampled once. For each minibatch pair the
        discriminator is updated first, then the generator is updated on the
        same ``fake_x`` (no resampling for the generator step).
        """
        self.sampler.real_data.sample()
        self.sampler.fake_data.sample()

        # iterate over minibatches (with the Sampler defined here the dataset
        # size equals batch_size, so there is only one)
        for real_x, fake_x in zip(self.sampler.real_inputs,
                                  self.sampler.fake_inputs):
            self.discrim_loss['real'], self.discrim_loss['fake'] = \
                self._step_discriminator(real_x, fake_x)
            self.gen_loss = self._step_generator(fake_x)

    def learn_discriminator(self):
        """Resample real and noise data, then update only the discriminator."""
        self.sampler.real_data.sample()
        self.sampler.fake_data.sample()

        for real_x, fake_x in zip(self.sampler.real_inputs,
                                  self.sampler.fake_inputs):
            self.discrim_loss['real'], self.discrim_loss['fake'] = \
                self._step_discriminator(real_x, fake_x)

    def learn_generator(self):
        """Resample the noise, then update only the generator."""
        self.sampler.fake_data.sample()

        for x in self.sampler.fake_inputs:
            self.gen_loss = self._step_generator(x)

    def _step_discriminator(self, real_x, fake_x):
        """Do one discriminator update while keeping the generator fixed.

        Minimises -log(D(real_x)) - log(1 - D(G(fake_x))). The two terms are
        backpropagated separately; their gradients add up in ``.grad``
        before the single optimiser step.

        Returns:
            (loss_real, loss_fake) as Python floats.
        """
        self.discriminator.zero_grad()
        output = self.discriminator(real_x)
        loss_real = self.criterion(output, self.real_labels)
        loss_real.backward()

        forged = self.generator(fake_x)
        # detach: the generator is not trained in this step, so do not
        # backpropagate into its parameters
        output = self.discriminator(forged.detach())
        loss_fake = self.criterion(output, self.fake_labels)
        loss_fake.backward()

        self.optim_discrim.step()

        # .item() turns the 0-dim loss tensors into floats; indexing them
        # with .data[0] raises IndexError in current PyTorch releases
        return loss_real.item(), loss_fake.item()

    def _step_generator(self, x):
        """Do one generator update while keeping the discriminator fixed.

        Uses the non-saturating loss -log(D(G(x))), i.e. BCE against the
        real labels.

        Returns:
            The generator loss as a Python float.
        """
        self.generator.zero_grad()
        output = self.discriminator(self.generator(x))
        loss = self.criterion(output, self.real_labels)
        # backward() also fills the discriminator's .grad buffers. That is
        # harmless: only optim_gen.step() follows, and it updates generator
        # parameters only. _step_discriminator calls zero_grad() on the
        # discriminator before its own update, so these gradients never
        # reach optim_discrim.
        loss.backward()
        self.optim_gen.step()

        return loss.item()

### Logging and plotting functions

Logging of losses.

In [12]:
def log_discriminator(loss_real, loss_fake):
    """Print the real, fake and total discriminator losses on one line."""
    total = loss_real + loss_fake
    line = 'discriminator | {:.4f} | {:.4f} | {:.4f}'.format(
        loss_real, loss_fake, total)
    print(line)
    
def log_generator(loss):
    """Print the generator loss rounded to four decimals."""
    message = 'generator | {:.4f}'.format(loss)
    print(message)

We will plot the discriminator's decision boundary, the real Gaussian distribution, and the forged distribution produced by the generator. The following functions build the data structures that hold the data needed to update the bokeh figure.

In [13]:
def fig_data_real(numb_points, numb_bins, sampler):
    """Build the plot data for the real distribution as a normalised histogram.

    Args:
        numb_points: number of samples drawn from the real Gaussian.
        numb_bins: number of histogram bin edges (yields numb_bins - 1 bins).
        sampler: Sampler whose ``real_data`` provides ``mean`` / ``sigma``
            and whose ``fake_data.bound`` sets the plotted range.

    Returns:
        dict with 'hist_real' (density per bin) and 'x' (bin centres), both
        of length numb_bins - 1.
    """
    mean = sampler.real_data.mean
    sigma = sampler.real_data.sigma
    # take the range from the sampler argument, not from a global ``gan``
    sample_range = sampler.fake_data.bound

    samples_real = sample_real(numb_points, sigma, mean)

    hist_bins = np.linspace(-sample_range, sample_range, numb_bins)
    hist_real, _ = np.histogram(samples_real, bins=hist_bins, density=True)
    # plot each density value at the centre of its bin
    centres = 0.5 * (hist_bins[:-1] + hist_bins[1:])

    return {'hist_real': hist_real, 'x': centres}

In [14]:
def fig_data_gan(numb_points, numb_bins, gan):
    """Build the plot data for the generator output and the decision boundary.

    Args:
        numb_points: number of evenly spaced noise inputs fed to the generator.
        numb_bins: number of histogram bin edges (yields numb_bins - 1 bins).
        gan: object exposing ``generator``, ``discriminator`` and
            ``sampler.fake_data.bound`` (the plotted range).

    Returns:
        dict with 'hist_fake' (density of generated samples per bin),
        'labels' (discriminator output at each bin centre) and 'x' (bin
        centres). All three have length numb_bins - 1 so they can share one
        bokeh data source.
    """
    sample_range = gan.sampler.fake_data.bound
    hist_bins = np.linspace(-sample_range, sample_range, numb_bins)
    centres = 0.5 * (hist_bins[:-1] + hist_bins[1:])

    # evaluation only: no_grad avoids building an autograd graph
    with torch.no_grad():
        # fake distribution: push an even grid of noise through the generator
        gen_x = torch.FloatTensor(np.linspace(
            -sample_range, sample_range, numb_points)).unsqueeze(-1)
        samples_fake = gan.generator(gen_x).numpy().squeeze(-1)
        # decision boundary, evaluated at the histogram's x positions
        discr_x = torch.FloatTensor(centres).unsqueeze(-1)
        labels = gan.discriminator(discr_x).numpy().squeeze(-1)

    hist_fake, _ = np.histogram(samples_fake, bins=hist_bins, density=True)

    return {'hist_fake': hist_fake, 'labels': labels, 'x': centres}

This function displays the figure.

In [15]:
def show_data(data_real, data_gan, bound):
    """Display the real histogram, the decision boundary and the fake histogram.

    Args:
        data_real: ColumnDataSource with columns 'x' and 'hist_real'.
        data_gan: ColumnDataSource with columns 'x', 'labels' and 'hist_fake'.
        bound: the x axis spans [-bound, bound].

    The figure is shown with ``notebook_handle=True`` so that later
    ``push_notebook()`` calls can refresh it in place when the sources change.
    """
    fig = figure(x_range=(-bound, bound))

    lw = 2
    # ``legend_label`` replaces the ``legend`` keyword, which Bokeh
    # deprecated in 1.4 and no longer accepts in 3.x.
    fig.line('x', 'hist_real', source=data_real,
             line_width=lw, line_color='blue',
             legend_label='real distribution')
    fig.line('x', 'labels', source=data_gan,
             line_width=lw, line_color='green',
             legend_label='decision')
    fig.line('x', 'hist_fake', source=data_gan,
             line_width=lw, line_color='red',
             legend_label='fake distribution')

    fig.legend.location = "top_left"
    # optional: let users hide a curve by clicking its legend entry
    # fig.legend.click_policy = "hide"

    show(fig, notebook_handle=True)

### Train the GAN

Note that this problem is not stable and can converge to different solutions, depending on the random seed or on how many times you run the cells. As explained in the blog, the generator typically converges to a distribution with a similar range but a narrower shape. To improve convergence, the decision boundary can be pre-trained, as explained in https://blog.evjang.com/2016/06/generative-adversarial-nets-in.html 

In [16]:
# Target distribution N(real_mean, real_sigma**2); the noise samples span
# roughly [-noise_range, noise_range].
real_sigma = 0.5
real_mean = 4.0
noise_range = 8.0
# generator hidden width (the discriminator doubles it) and samples per batch
hidden_dim = 4
batch_size = 8
sampler = Sampler(real_sigma=real_sigma, real_mean=real_mean,
                  noise_range=noise_range, batch_size=batch_size)
gan = GAN(sampler=sampler, hidden_dim=hidden_dim)

In [17]:
nb_points = 4000
nb_bins = 100

log_step = 10
number_epochs = 5000

# For logging losses, set log_switch=True.
# For plotting, set log_switch=False.
log_switch = False

# cannot initialize to empty dict
# without bokeh generating warning
data_real = ColumnDataSource(
    data=fig_data_real(nb_points, nb_bins, sampler))
data_gan = ColumnDataSource(
    data={'x': [], 'labels': [], 'hist_fake': []})

show_data(data_real, data_gan,
          gan.sampler.fake_data.bound)

for epoch in range(number_epochs):
    # Alternate one discriminator update and one generator update, each on
    # freshly sampled data. gan.learn() would instead reuse the
    # discriminator's noise batch for the generator step.
    gan.learn_discriminator()
    gan.learn_generator()

    if epoch % log_step == 0:
        if log_switch:
            print('epoch {:d}'.format(epoch))
            log_discriminator(gan.discrim_loss['real'],
                              gan.discrim_loss['fake'])
            log_generator(gan.gen_loss)
        else:
            # same resolution as the configured nb_points / nb_bins
            data_gan.data = fig_data_gan(nb_points, nb_bins, gan)
            push_notebook()

# Always plot the final state. When logging losses the figure has not been
# updated at all; when plotting, the last epoch is usually not a multiple of
# log_step, so the final update above would otherwise be skipped.
data_gan.data = fig_data_gan(nb_points, nb_bins, gan)
push_notebook()

### Optional

#### Test generator

This test validates the behavior of the generator against an equivalent NumPy implementation of the same model. It also illustrates how the weight and bias values of PyTorch layers can be set before evaluating a model.

In [None]:
def generator_numpy():
    """Reference NumPy forward pass of the Generator, used to validate it.

    Uses fixed weights w1 = w2^T = [1, 2, 3, 4] with zero biases and
    evaluates the inputs 1, 2, 3, 4 at once (one input per column).

    Returns:
        ndarray of shape (1, 4): the generator output for each input.
    """
    in_vals = np.array([[1., 2., 3., 4.]])    # (1, 4): one input per column
    w1 = np.array([[1.], [2.], [3.], [4.]])   # (4, 1): fc1 weights
    w2 = np.array([[1., 2., 3., 4.]])         # (1, 4): fc2 weights
    res1 = w1 @ in_vals                       # (4, 4): hidden pre-activations
    # softplus log(1 + exp(x)), computed without overflow for large x
    res2 = np.logaddexp(0., res1)
    return w2 @ res2

In [None]:
hidden_dims_test = 4
test_weights = np.array([[1.], [2.], [3.], [4.]])
# Parameter values in the shapes nn.Linear uses: weight is
# (out_features, in_features) and bias is 1-D (out_features,).
weights_gen_fc1 = test_weights                 # (4, 1)
bias_gen_fc1 = np.zeros(hidden_dims_test)      # (4,)
weights_gen_fc2 = test_weights.transpose()     # (1, 4)
bias_gen_fc2 = np.zeros(1)                     # (1,)

gen = Generator(1, hidden_dims_test)
# Copy the values in place. Check each shape first, so a mismatch is
# reported instead of silently broadcasting.
with torch.no_grad():
    for param, value in [(gen.fc1.weight, weights_gen_fc1),
                         (gen.fc1.bias, bias_gen_fc1),
                         (gen.fc2.weight, weights_gen_fc2),
                         (gen.fc2.bias, bias_gen_fc2)]:
        value = torch.FloatTensor(value)
        assert param.shape == value.shape, (param.shape, value.shape)
        param.copy_(value)

res = gen(torch.FloatTensor([[1.], [2.], [3.], [4.]]))

# Outputs reach about 120, so compare with a relative tolerance that suits
# float32 rather than a fixed number of decimals.
np.testing.assert_allclose(
    res.detach().numpy(),
    np.asarray(generator_numpy()).transpose(), rtol=1e-5)

#### Test discriminator

This performs a test of the discriminator model, using baseline values obtained by running an equivalent model in TensorFlow.

In [None]:
hidden_dims_test = 4
test_weights = np.array([[0.5, -1.5, 0.33, -0.15]])
# np.fliplr returns a negative-stride view, which pytorch does not support;
# np.array(...) makes a positive-stride copy first.
weights = [test_weights.transpose(),
           np.tile(test_weights, (4, 1)),
           np.tile(test_weights, (4, 1)).transpose(),
           np.array(np.fliplr(test_weights))]

discrim = Discriminator(1, hidden_dims_test)
# Set every weight and zero every bias explicitly, so the comparison does not
# depend on how Discriminator initialises its parameters.
with torch.no_grad():
    for layer, w in zip(discrim.fc_layers, weights):
        layer.weight.copy_(torch.FloatTensor(w))
        layer.bias.zero_()

res = discrim(torch.FloatTensor(
    [[1.0], [0.5], [3.0], [0.0]]))

# The TensorFlow baseline is rounded to 7 decimals, so compare with a
# tolerance that float32 arithmetic can meet reliably.
np.testing.assert_allclose(
    res.detach().numpy(),
    [[0.3061263], [0.3991169], [0.0790827], [0.5]],
    atol=1e-6)