# GANs &ndash; Part 1: Generation

Author: Sebastian Bujwid <bujwid@kth.se>

Refer to [./README.md](./README.md) for general comments.

---

## Scope

- Implementing a GAN model
- Training a GAN on two toy 2-D datasets:
    - Gaussian blobs
    - Swiss roll
    
## Objectives

The goal of this part of the practical is to familiarize yourself with a GAN model.
You are expected to learn the technical implementation details of a GAN and also to get a better feeling for the challenges of GAN training.
As this practical uses simple 2-D data, it will be easy to understand what the model does and what it has learned.

## GAN

The GAN model introduced by
[Goodfellow _et al._ (2014)](https://papers.nips.cc/paper/2014/hash/5ca3e9b122f61f8f06494c97b1afccf3-Abstract.html)
consists of two networks trained adversarially: a generator and a discriminator.
The generator's goal is to produce realistic samples, while the discriminator tries to tell apart the real data samples from the fake samples produced by the generator.
The better the generator gets at generating samples, the more difficult the discriminator's job becomes.

Note that, as a GAN contains two separate networks, each trained to be better than the other, the training dynamics can be quite different from standard supervised training. It's not uncommon to observe the losses going up and down during training, as one of the networks gets better than the other.

## Technical details

This Notebook has some `assert`s or checks which you can use to test your implementation of some of the functions. However, keep in mind that passing those checks does not necessarily mean that everything is correct!

Some of the `assert`s verify whether the shapes of tensors are as expected. A correct implementation does not strictly require those exact shapes, however. You are, of course, free to do things differently, but make sure you know what you're doing!

---

In [None]:
import numpy as np

import sklearn.datasets
import matplotlib.pyplot as plt
import seaborn as sns
import itertools

In [None]:
seed = 42

# Toy datasets

In [None]:
def dataset_swiss_roll(n_samples, seed):
    x, _ = sklearn.datasets.make_swiss_roll(n_samples, noise=0.5, random_state=seed)
    x = x[:, [0, 2]]
    return x / np.max(np.abs(x)) + np.array([0., 5.])

In [None]:
def dataset_s_curve(n_samples, seed):
    x, _ = sklearn.datasets.make_s_curve(n_samples, noise=0.15, random_state=seed)
    x = x[:, [0, 2]]
    return x

In [None]:
def dataset_blobs(n_samples, seed):
    x, _ = sklearn.datasets.make_blobs(
        n_samples=n_samples, n_features=2,
        centers=8, cluster_std=0.8,
        random_state=seed)
    return x

In [None]:
x = dataset_blobs(n_samples=10000, seed=seed)
sns.scatterplot(x=x[:, 0], y=x[:, 1])
plt.show()

# Loss functions

In [None]:
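import jax
import jax.numpy as jnp
try:  # newer JAX releases
    from jax.example_libraries import optimizers, stax
except ImportError:  # older JAX releases
    from jax.experimental import optimizers, stax
from jax import jit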

If you have a GPU, make sure you're using it!

In [None]:
jax.devices()

### (Task) Implement `sigmoid_cross_entropy` function

First, we'll implement a `sigmoid_cross_entropy` function that transforms the input logits with the sigmoid function and computes the cross-entropy loss.

The function should compute the following:

$$loss(t, l) = - \left[t \cdot \log S(l) + (1 - t) \cdot \log(1 - S(l)) \right]$$

where `t` represents the input target value; `l` the logit; and `S(x)` the sigmoid function:
$$S(x) = \frac{1}{1 + e^{-x}}$$

Note that the sigmoid and the cross-entropy can be thought of as separate functions, but here we expect them to be implemented as a single function.
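
As a quick sanity check of the formula above (a worked example, assuming $\log$ denotes the natural logarithm), take $t = 1$ and $l = 0$, so that $S(0) = \tfrac{1}{2}$:

$$loss(1, 0) = - \left[1 \cdot \log S(0) + 0 \cdot \log(1 - S(0)) \right] = -\log \tfrac{1}{2} = \log 2 \approx 0.693$$

By symmetry, $t = 0$ and $l = 0$ give the same value, since $1 - S(0) = \tfrac{1}{2}$ as well.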

**Question:**
- Is there any (objective) potential benefit of implementing the combination of sigmoid & cross-entropy as a single function?

In [None]:
def sigmoid_cross_entropy(*, targets, logits):
    assert targets.shape == logits.shape
    
    raise NotImplementedError('Task: implement!')
    loss = NotImplemented

    return loss

#### (Task) Test `sigmoid_cross_entropy`

In [None]:
sigmoid_cross_entropy(targets=jnp.array(1.), logits=jnp.array(0.))

In [None]:
sigmoid_cross_entropy(targets=jnp.array([[1.], [1.]]), logits=jnp.array([[0.], [0.]]))

In [None]:
targets = np.array([0, 1, 0, 1, 1])
logits = np.array([0.55, -0.22, 0.8, 0.1, -0.42])
returned = sigmoid_cross_entropy(targets=targets, logits=logits)
expected_mean_ce_loss = np.array(0.9110423)
assert returned.shape == logits.shape, \
    'cross-entropy loss is not expected here to be averaged over samples'
np.testing.assert_almost_equal(expected_mean_ce_loss, returned.mean())

### (Optional task) Numerically stable `sigmoid_cross_entropy` loss

Feel free to skip this task (or maybe try it at the end after completing everything else?).
You should (hopefully) be able to complete the rest of the practical even if your `sigmoid_cross_entropy` function is not able to handle very large or small values, producing NaNs.
Note that you typically do not want your model to internally (hidden features, logits, etc.) operate on very large numbers. Such a behavior might indicate some "issues with training".
However, it does not have to imply any major issues and sometimes might be even expected.

Numerical stability of an implementation can be important, and it is good to be at least familiar with the main sources of numerical stability issues and with the common "tricks" used to address them.


**(Optional) task:**
- Can your function compute the value for $t=0$ and $l=10000.0$ or do you get NaNs?
- Do you understand what the source of the numerical instability (NaNs) is?
    - Hint: Think of computing $\log(e^x)$ for $x = 10^9$, for example. The value of this expression cannot be calculated by `np.log(np.exp(x))` (which returns inf/NaN), but we know that $\log(e^x)=x$.
    
- Can you compute by hand the (almost?) exact value of the function for the inputs as above? **Try to transform the formula manually and see if you can implement a more numerically stable `sigmoid_cross_entropy` function that works with very large or small numbers.**

- Note that expressions in math are usually written in the form that is the most readable or intuitive. However, that form is not necessarily the best one to implement, not just because of numerical stability but also in terms of speed or memory usage.

- Hopefully, in the future you will be able to foresee numerical stability issues and address them! You have to be especially careful with operations like `log` and `exp`!
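
The hint above can be reproduced with a few lines of NumPy. This is only a small illustration of the overflow, not a reference solution; the value `x = 1000` is chosen for convenience, and `np.logaddexp` is shown as one possible building block for evaluating $\log(1 + e^x)$ without forming $e^x$ explicitly.

```python
import numpy as np

x = np.float32(1000.0)

# Naive evaluation: exp(1000) overflows to inf in floating point,
# so its log is inf as well, even though log(exp(x)) == x mathematically.
with np.errstate(over='ignore'):
    naive = np.log(np.exp(x))

# log(1 + exp(x)) suffers from the same overflow when written literally.
with np.errstate(over='ignore'):
    naive_softplus = np.log(1.0 + np.exp(x))

# np.logaddexp(a, b) computes log(exp(a) + exp(b)) in a stable way,
# so logaddexp(0, x) equals log(1 + exp(x)) without the overflow.
stable_softplus = np.logaddexp(0.0, x)

print(naive, naive_softplus, stable_softplus)
```

For large positive `x`, $\log(1 + e^x)$ is practically equal to `x`, which is what the stable variant returns.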


In [None]:
sigmoid_cross_entropy(targets=jnp.array(0.),
                      logits=jnp.array(10000.))

### GAN: objectives

In this practical, we're going to use the GAN training objectives as introduced by
[Goodfellow _et al._ (2014)](https://papers.nips.cc/paper/2014/hash/5ca3e9b122f61f8f06494c97b1afccf3-Abstract.html)
(note that $D(x) \in (0, 1)$):

<br>
<div align="center">(discriminator)</div>
$$
-\mathbb{E}_{x \sim p_{\text{data}}(x)}
\left[
\log D(x)
\right]
%
-\mathbb{E}_{z \sim p_z(z)}
\left[
\log (1 - D(G(z)) )
\right]
$$

<br>
<div align="center">(generator: non-saturating objective)</div>
$$
-\mathbb{E}_{z \sim p_z(z)}
\left[
\log D(G(z))
\right].
$$

Note that we use the **non-saturating** loss for the generator. The alternative is, of course, to have the generator optimize just the reverse of the discriminator's objective.
In the non-saturating loss, instead of training the generator to minimize the probability of the discriminator being correct, we maximize the probability of the discriminator incorrectly classifying generated samples as real.
Therefore, thanks to the formulation of the non-saturating loss, the generator can still learn (as it gets non-vanishing gradients) when the discriminator is very strong and correctly classifies generated samples as fake (that is, when $\log(1-D(G(z)))$ would be close to 0).
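
The difference between the two generator objectives can be checked numerically. The sketch below is an illustration only: it uses a single hypothetical logit `l = -6` (a discriminator that is confident the sample is fake), plain NumPy, and a finite-difference gradient with respect to the logit. The helper names are made up for this example.

```python
import numpy as np

def sigmoid(l):
    return 1.0 / (1.0 + np.exp(-l))

def saturating_loss(l):
    # log(1 - D(G(z))) with D = sigmoid(l): the generator minimizes this
    return np.log(1.0 - sigmoid(l))

def non_saturating_loss(l):
    # -log D(G(z)) with D = sigmoid(l)
    return -np.log(sigmoid(l))

def numerical_grad(f, l, eps=1e-5):
    # central finite difference approximation of df/dl
    return (f(l + eps) - f(l - eps)) / (2.0 * eps)

l = -6.0  # hypothetical logit: D(G(z)) is close to 0
grad_sat = numerical_grad(saturating_loss, l)
grad_non_sat = numerical_grad(non_saturating_loss, l)

print('saturating:    ', grad_sat)
print('non-saturating:', grad_non_sat)
```

Analytically, the two derivatives are $-S(l)$ and $-(1 - S(l))$, so for $l = -6$ they are roughly $-0.0025$ and $-0.9975$: the saturating objective passes almost no signal back to the generator, while the non-saturating one does.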


Alternatively, there are a number of modified objectives for training GANs, for example:
- [WGAN](https://arxiv.org/abs/1701.07875)
- [WGAN-GP](https://papers.nips.cc/paper/2017/hash/892c3b1c6dccd52936e27cbd0ff683d6-Abstract.html)
- [LSGAN](https://openaccess.thecvf.com/content_ICCV_2017/papers/Mao_Least_Squares_Generative_ICCV_2017_paper.pdf)

### (Task) Implement discriminator's loss

Hint: use the `sigmoid_cross_entropy` function

In [None]:
def discriminator_loss(*, real_logits, fake_logits):
    # NOTE: the inputs are expected to be logits, not probabilities!
    
    raise NotImplementedError('Task: implement!')
    loss = NotImplemented
    
    return loss

What result would you expect if both the real & fake input logits were equal to $0.0$?

Try out a couple of values! Does the function return what you would expect it to return?
What if you set the real logits high and fake logits low? What if it's vice-versa?

#### (Task) Test the discriminator loss

In [None]:
discriminator_loss(real_logits=jnp.array([[0.], [0.]]), fake_logits=jnp.array([[0.], [0.]]))

In [None]:
discriminator_loss(real_logits=jnp.array([0., 0.]), fake_logits=jnp.array([0., 0.]))

In [None]:
returned = discriminator_loss(
    real_logits=jnp.array([0.13, -0.11, -1.12, -0.02]),
    fake_logits=jnp.array([-0.88, 0.42, 0.42, -0.81])
)
assert returned.shape == (), \
        'discriminator loss is expected here to be averaged over samples'
expected = 1.5126383
np.testing.assert_almost_equal(returned, expected)

### (Task) Implement generator loss

Hint: use the `sigmoid_cross_entropy` function

In [None]:
def generator_loss(discriminator_fake_logits):
    # NOTE: the inputs are expected to be logits, not probabilities!
    
    raise NotImplementedError('Task: implement!')
    loss = NotImplemented
    
    return loss

#### (Task) Testing the generator loss

What result would you expect when the discriminator's logits for the generated samples (i.e. the pre-sigmoid output of $D(.)$ for the input $G(z)$) are all equal to $0.0$?
Does the result match your expectation?

In [None]:
generator_loss(discriminator_fake_logits=jnp.array([[0.], [0.,]]))

In [None]:
generator_loss(discriminator_fake_logits=jnp.array([0., 0.]))

In [None]:
returned = generator_loss(discriminator_fake_logits=jnp.array([1.12, -0.12, 0., 0.42, -0.82,]))
assert returned.shape == (), \
        'generator loss is expected here to be averaged over samples'
expected = 0.68409014
np.testing.assert_almost_equal(expected, returned)

# Training

In [None]:
from utils import plot_losses, generate_and_plot_samples, plot_logit_space

In [None]:
def iterable_dataset(key, data, batch_size):
    n_batches = len(data) // batch_size
    x_shuffled = jax.random.permutation(key, data)
    dataset = [x_shuffled[i * batch_size: (i + 1) * batch_size] for i in range(n_batches)]
    return dataset

In [None]:
from omegaconf import OmegaConf

In [None]:
def create_mlp(n_hid_layers, hid_dim, output_dim):
    layers = list(itertools.chain.from_iterable(
        [(stax.Dense(hid_dim), stax.Relu) for _ in range(n_hid_layers)]
    )) + [stax.Dense(output_dim)]
    mlp_init, mlp_apply = stax.serial(*layers)
    return mlp_init, mlp_apply

## (Task) Implement the `train_step` function

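Note that the generator's and the discriminator's parameters should be updated individually.
Make sure that the generator's gradients and the discriminator's gradients depend on the right "things", i.e. each set of gradients is taken with respect to the right parameters and the right loss.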
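
One way to keep two sets of updates separate is sketched below on a hypothetical two-player toy problem (this is not the GAN itself; `loss_a`, `loss_b`, `a`, `b` and `lr` are made-up names). Each loss takes the parameters of the player that owns it as its first argument, so `jax.value_and_grad` differentiates only with respect to those parameters.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy game: player A minimizes a * b, player B minimizes -(a * b).
def loss_a(a, b):
    return a * b

def loss_b(b, a):
    return -(a * b)

a, b = jnp.array(2.0), jnp.array(3.0)

# By default, gradients are taken with respect to the first argument only.
loss_a_val, grad_a = jax.value_and_grad(loss_a)(a, b)  # d(a*b)/da = b
loss_b_val, grad_b = jax.value_and_grad(loss_b)(b, a)  # d(-(a*b))/db = -a

# Both gradients are computed from the same (old) parameters,
# and only afterwards is each player updated with its own gradient.
lr = 0.1
a_new = a - lr * grad_a
b_new = b - lr * grad_b

print(grad_a, grad_b, a_new, b_new)
```

The same pattern carries over to parameter pytrees, such as the ones returned by `gen_get_params` and `dis_get_params`.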

In [None]:
def create_and_train_gan(hparams, data, seed=seed):
    data_dim = data[0].shape[-1]
    # Create models
    
    if hparams.model_type == 'mlp':
        generator_init, generator_apply = create_mlp(
            n_hid_layers=hparams.generator.n_layers,
            hid_dim=hparams.generator.hidden_dim,
            output_dim=data_dim,
        )
        discriminator_init, discriminator_apply = create_mlp(
            n_hid_layers=hparams.discriminator.n_layers,
            hid_dim=hparams.discriminator.hidden_dim,
            output_dim=1,
        )
        
    else:
        raise NotImplementedError(f'model_type: {hparams.model_type}')

    # Initialize
    key = jax.random.PRNGKey(seed)
    key, key_gen, key_dis = jax.random.split(key, 3)
    
    input_noise_shape = (-1, hparams.input_noise_dim)
    data_shape = (-1, data_dim)
    _, gen_params = generator_init(key_gen, input_noise_shape)
    _, dis_params = discriminator_init(key_dis, data_shape)

    ## Initialize: optimizers
    gen_opt_init, gen_opt_update, gen_get_params = optimizers.adam(
        step_size=hparams.gen_lr, b1=hparams.beta1)
    dis_opt_init, dis_opt_update, dis_get_params = optimizers.adam(
        step_size=hparams.dis_lr, b1=hparams.beta1)
    gen_opt_state = gen_opt_init(gen_params)
    dis_opt_state = dis_opt_init(dis_params)
    
    # NOTE: can comment out @jax.jit for debugging but
    #       make sure to use it while training, as it will be much faster
    # HINT: you are expected to define functions that contain forward passes
    @jax.jit
    def train_step(step, gen_opt_state, dis_opt_state, noise, real_samples):
        gen_params = gen_get_params(gen_opt_state)
        dis_params = dis_get_params(dis_opt_state)
        
        raise NotImplementedError('Task: implement!')
        gen_grads = NotImplemented
        dis_grads = NotImplemented
        gen_loss = NotImplemented
        dis_loss = NotImplemented

        gen_opt_state = gen_opt_update(step, gen_grads, gen_opt_state)
        dis_opt_state = dis_opt_update(step, dis_grads, dis_opt_state)
        return (gen_loss, dis_loss), (gen_opt_state, dis_opt_state)
    
    # Training
    losses = {'gen': [], 'dis': []}

    total_step = 0
    for epoch in range(hparams.epochs):
        key, key_data = jax.random.split(key)
        
        for batch_x in iterable_dataset(key=key_data, data=data, batch_size=hparams.batch_size):
            
            key, subkey = jax.random.split(key)
            noise = jax.random.normal(subkey, (hparams.batch_size, hparams.input_noise_dim))
            
            (gen_loss, dis_loss), (gen_opt_state, dis_opt_state) = train_step(
                step=total_step,
                gen_opt_state=gen_opt_state,
                dis_opt_state=dis_opt_state,
                noise=noise,
                real_samples=batch_x,
            )
            losses['gen'].append(gen_loss)
            losses['dis'].append(dis_loss)
            
            total_step += 1
            
        if epoch == 0 or (epoch < 100 and epoch % 10 == 9) or epoch % 100 == 99:
            print('-' * 30, 'epoch', epoch, '-' * 30)
            fig, axes = plt.subplots(
                1,3,
                figsize=(17,5))
            plot_losses(losses,axes[0])
            key, subkey = jax.random.split(key)
            samples = generate_and_plot_samples(
                subkey, 
                data,
                generator_apply, 
                gen_get_params(gen_opt_state),
                discriminator_apply,
                dis_get_params(dis_opt_state),
                n_samples=256, 
                input_noise_dim=hparams.input_noise_dim,
                ax=axes[1])
            plot_logit_space(dis_get_params(dis_opt_state),
                             discriminator_apply,
                             samples,
                             ax=axes[2],
                             ref_ax=axes[1],
                             apply_axis=[0,1])
            plt.show()

## (Task) Train the GAN model: simple blobs data

Below you can see the output of my training. Can you get similar results?
Note that the output you get might not be exactly the same as the one below but given that the code skeleton already covers most of the details that could matter, it should be fairly similar.

Write a few sentences about each of the following discussion points.
**Discussion points:**
* Is the generator able to cover all of the blobs?
* Are the individual blobs well covered by the generated samples?
* Can you observe mode collapse?
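
Besides looking at the plots, blob coverage can also be quantified. The sketch below is one possible diagnostic, not part of the original practical: `mode_coverage`, `centers` and `radius` are hypothetical names, and the "generated" samples are simulated here to mimic a generator that has collapsed onto two out of three blobs.

```python
import numpy as np

def mode_coverage(samples, centers, radius):
    """Count samples per nearest centre, ignoring samples farther than `radius`."""
    # Pairwise distances with shape (n_samples, n_centers)
    dists = np.linalg.norm(samples[:, None, :] - centers[None, :, :], axis=-1)
    nearest = dists.argmin(axis=1)
    close = dists.min(axis=1) <= radius
    counts = np.bincount(nearest[close], minlength=len(centers))
    return counts, close.mean()

# Hypothetical blob centres and simulated "generated" samples
centers = np.array([[0., 0.], [10., 0.], [0., 10.]])
rng = np.random.default_rng(0)
samples = np.concatenate([
    centers[0] + 0.1 * rng.standard_normal((100, 2)),
    centers[1] + 0.1 * rng.standard_normal((100, 2)),
])

counts, frac_close = mode_coverage(samples, centers, radius=2.0)
print(counts, frac_close)
```

A centre with a count of zero is a mode the generator does not cover, and very uneven counts hint at partial mode collapse.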

In [None]:
hparams = OmegaConf.create({
    'epochs': 1000,
    'batch_size': 256,
    'gen_lr': 0.0001,
    'dis_lr': 0.0001,
    'beta1': 0.5,  # NOTE: important!
    'model_type': 'mlp',
    'generator': {
        'n_layers': 5,
        'hidden_dim': 512,
    },
    'discriminator': {
        'n_layers': 4,
        'hidden_dim': 512,
    },
    'input_noise_dim': 2,
})
create_and_train_gan(hparams, dataset_blobs(n_samples=10000, seed=seed), seed=seed)

In [None]:
hparams = OmegaConf.create({
    'epochs': 1000,
    'batch_size': 256,
    'gen_lr': 0.0001,
    'dis_lr': 0.0001,
    'beta1': 0.5,  # NOTE: important!
    'model_type': 'mlp',
    'generator': {
        'n_layers': 6,
        'hidden_dim': 512,
    },
    'discriminator': {
        'n_layers': 5,
        'hidden_dim': 512,
    },
    'input_noise_dim': 2,
})
create_and_train_gan(hparams, dataset_blobs(n_samples=10000, seed=seed), seed=seed)

If you get a similar result, feel free to play with the hyperparameter values!
Which of them seem important? To which of them is the model not so sensitive?

You can also try to change the random seed and see if the results change significantly. Do the results seem more consistent across different seeds for some hyperparameter settings than for others?

## (Optional Task) Train the GAN model: Swiss roll data

Now, let's try a slightly more difficult problem!

In [None]:
hparams = OmegaConf.create({
    'epochs': 1000,
    'batch_size': 256,
    'gen_lr': 0.0001,
    'dis_lr': 0.0001,
    'beta1': 0.5,  # NOTE: important!
    'model_type': 'mlp',
    'generator': {
        'n_layers': 6,
        'hidden_dim': 512,
    },
    'discriminator': {
        'n_layers': 5,
        'hidden_dim': 512,
    },
    'input_noise_dim': 2,
})

create_and_train_gan(hparams, dataset_swiss_roll(n_samples=10000, seed=seed), seed=seed)

As you can see, the model (with the original hyperparameter values) is not really able to model the data well :/

Are you able to improve the model to get results like the ones in the following figure?

![Swiss roll](./imgs/step1_swiss_roll.png)

**Here are some suggestions for what to try:**
- Different hparams, architecture
- Updating the discriminator more often than the generator (e.g. 5 discriminator updates per generator update)
- Different initialization of weights (e.g. Normal with std = 0.02)
- Different GAN objectives ([WGAN](https://arxiv.org/abs/1701.07875),
[WGAN-GP](https://papers.nips.cc/paper/2017/hash/892c3b1c6dccd52936e27cbd0ff683d6-Abstract.html),
[LSGAN](https://openaccess.thecvf.com/content_ICCV_2017/papers/Mao_Least_Squares_Generative_ICCV_2017_paper.pdf))

Note that these are just suggestions! The point of them is to make you try out different things yourself. Some of them should help significantly, while others might not help much or even make things worse! You need to try them yourself to find out!
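
As a starting point for the weight-initialization suggestion, the sketch below draws a weight matrix from a normal distribution with std = 0.02 using `jax.nn.initializers.normal`. The layer shape (2 inputs, 512 hidden units) and the PRNG seed are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Initializer that draws weights from N(0, 0.02^2)
w_init = jax.nn.initializers.normal(stddev=0.02)

# Hypothetical layer shape: 2 inputs -> 512 hidden units
w = w_init(key, (2, 512))

print(w.shape, float(jnp.std(w)))
```

Such an initializer can be passed to `stax.Dense` through its `W_init` argument, e.g. inside `create_mlp`; whether it actually helps on the Swiss roll data is something to find out experimentally.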