# Practical 2: Generative Adversarial Networks (GAN)
---

**Tutorial overview**
In this tutorial you will implement, train and analyse the results of a Generative Adversarial Network.


**Tutorial outline**
- [Theory recap](#theory-recap)
- [Setup](#setup)
  - Install and Import Packages
  - Dataset
  - Helper Functions
- [Implementing GAN components](#implement-gan)
  - Generator and Discriminator
  - Loss functions
- [Training and Visualisation](#training)
  - Train Utils
  - Main Train Loop
  - Run Training
- [Analysis](#analysis)


## Theory recap <a class="anchor" id="theory-recap"></a>
---


Generative Adversarial Networks (GANs) are a type of machine learning model that can be used to generate realistic data, such as images, text, and audio. GANs work by training two neural networks against each other: a generator and a discriminator.

The generator is responsible for generating new data, while the discriminator is responsible for distinguishing between real and fake data. During training, the generator tries to fool the discriminator by generating data that is indistinguishable from real data. The discriminator, in turn, tries to get better at detecting fake data. This adversarial process forces the generator to produce increasingly realistic data.

How GANs work in detail:

- The generator takes as input a random noise vector and produces a synthetic data sample.
- The discriminator takes as input a data sample (either real or fake) and outputs a probability that the sample is real.
- The generator and discriminator are trained alternately:
  - the generator is trained to maximize the probability that the discriminator classifies its output as real;
  - the discriminator is trained to maximize the probability that it correctly classifies real and fake data.
- Ideally, training converges to an equilibrium in which the generator's samples are indistinguishable from real data and the discriminator outputs 0.5 for every input. In practice, training is run for a fixed budget (a number of epochs, as in this tutorial).
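
The alternating procedure above corresponds to the minimax game of the original GAN formulation, where $D(x)$ is the discriminator's probability that $x$ is real and $G(z)$ is the sample generated from noise $z$:

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

In practice the generator usually minimizes the non-saturating loss below instead of $\log\big(1 - D(G(z))\big)$, because it gives stronger gradients early in training, when the discriminator easily rejects fake samples. The discriminator minimizes the binary cross-entropy on real and fake samples:

$$
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\big[\log D(G(z))\big],
\qquad
\mathcal{L}_D = -\tfrac{1}{2}\Big(\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]\Big)
$$

Here $D = \sigma(\ell)$, with $\ell$ the discriminator's logit. The loss functions later in this notebook implement these as binary cross-entropy on logits (`cross_entropy`), and the discriminator loss additionally smooths its targets to 0.9 (real) and 0.1 (fake).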

![gan_architecture](https://miro.medium.com/v2/resize:fit:1400/1*ZKUo2QtHasnr8-RiqeJ_YA.png)
[Image source: [Saul Dobilas, Medium](https://towardsdatascience.com/gans-generative-adversarial-networks-an-advanced-solution-for-data-generation-2ac9756a8a99)]

## Setup <a class="anchor" id="setup"></a>

**NOTES:**
<br>
- If you see the error `AttributeError: module 'numpy' has no attribute '_no_nep50_warning'`, restart the kernel and re-run the cells.
- Please use a GPU kernel.


### Install and Import Packages

In [None]:
#@title Install Packages
# ! pip uninstall numpy
# ! pip install numpy
! pip install chex -q
! pip install optax -q
! pip install distrax -q
! pip install dm_haiku -q
! pip install absl-py -q

In [None]:
! pip install typing_extensions -q

In [None]:
import os
from functools import partial
from time import time

import cv2  # opencv-python, only used by the CelebA loaders
import haiku as hk
import jax
import matplotlib.pyplot as plt
import optax
from haiku.initializers import Constant, RandomNormal
from jax import jit
from jax import numpy as jnp
from jax import random
from jax import value_and_grad as vgrad
import tensorflow as tf


### Dataset

In this tutorial we will use the [MNIST dataset](https://keras.io/api/datasets/mnist/).
<br>
Loaders for the other datasets (CIFAR-10, and CelebA at 64×64 and 128×128) are included below; using them is left as homework.

In [None]:

def load_images_mnist(batch_size=128, seed=0):
    def prepare_dataset(X):
        X = tf.cast(X, tf.float32)
        # Normalization, pixels in [-1, 1]
        X = (X / 255.0) * 2.0 - 1.0
        X = tf.expand_dims(X, axis=-1)
        # shape=(batch_size, 28, 28, 1)
        return X

    (X_train, _), (X_test, _) = tf.keras.datasets.mnist.load_data()
    X = tf.concat([X_train, X_test], axis=0)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.cache().shuffle(buffer_size=len(X), seed=seed)
    # Normalize each batch, then prefetch (tf.data.AUTOTUNE == -1)
    dataset = dataset.batch(batch_size).map(prepare_dataset)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset


def load_images_cifar10(batch_size=128, seed=0):
    def prepare_dataset(X):
        X = tf.cast(X, tf.float32)
        # Normalization, pixels in [-1, 1]
        X = (X / 255.0) * 2.0 - 1.0
        # shape=(batch_size, 32, 32, 3)
        return X

    (X_train, _), (X_test, _) = tf.keras.datasets.cifar10.load_data()
    X = tf.concat([X_train, X_test], axis=0)
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.cache().shuffle(buffer_size=len(X), seed=seed)
    # Normalize each batch, then prefetch (tf.data.AUTOTUNE == -1)
    dataset = dataset.batch(batch_size).map(prepare_dataset)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset




def load_images_celeba_64(batch_size=128, seed=0, path='data/CelebA/'):
    def generate_data():
        for f_name in os.listdir(path):
            img = cv2.imread(os.path.join(path, f_name))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)[20:-20, :, :]
            img = cv2.resize(img, (64, 64))
            img = tf.constant(img, dtype=tf.float32)
            img = (img / 255.0) * 2.0 - 1.0
            yield img

    dataset = tf.data.Dataset.from_generator(generate_data,
                                             output_types=tf.float32,
                                             output_shapes=(64, 64, 3))
    dataset = dataset.shuffle(buffer_size=202_600, seed=seed)
    dataset = dataset.batch(batch_size).prefetch(buffer_size=-1)
    dataset.__len__ = lambda: tf.constant(202_599 // batch_size + 1,
                                          dtype=tf.int64)
    return dataset


def load_images_celeba_128(batch_size=128, seed=0, path='data/CelebA/'):
    def generate_data():
        for f_name in os.listdir(path):
            img = cv2.imread(os.path.join(path, f_name))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)[20:-20, :, :]
            img = cv2.resize(img, (128, 128))
            img = tf.constant(img, dtype=tf.float32)
            img = (img / 255.0) * 2.0 - 1.0
            yield img

    dataset = tf.data.Dataset.from_generator(generate_data,
                                             output_types=tf.float32,
                                             output_shapes=(128, 128, 3))
    dataset = dataset.shuffle(buffer_size=202_600, seed=seed)
    dataset = dataset.batch(batch_size).prefetch(buffer_size=-1)
    dataset.__len__ = lambda: tf.constant(202_599 // batch_size + 1,
                                          dtype=tf.int64)
    return dataset
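
All four loaders map `uint8` pixels from [0, 255] to [-1, 1], which matches the output range of the generator's final `tanh`; `plot_tensor_images` later maps [-1, 1] back to [0, 1] for display. A minimal NumPy check of the two maps (the names `to_model_range` and `to_display_range` are illustrative, not part of the notebook):

```python
import numpy as np


def to_model_range(x_uint8):
    """[0, 255] -> [-1, 1], same formula as prepare_dataset above."""
    return (x_uint8.astype(np.float32) / 255.0) * 2.0 - 1.0


def to_display_range(x):
    """[-1, 1] -> [0, 1], same formula as plot_tensor_images."""
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)


pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
print(scaled)
print(to_display_range(scaled))  # recovers pixels / 255
```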



### Helper Functions

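Some helper functions:
- `Mean`, a class that keeps a running (cumulative) mean of its inputs, together with the history of all raw inputs
- `input_func`, which samples the generator's input noise
- Plotting functions for the images and loss curves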

In [None]:

class Mean(object):
    """ Compute dynamic mean of given inputs. """
    def __init__(self):
        self.val = 0.0
        self.count = 0
        # Keep the history of **given inputs**
        self.history = []

    def reset(self):
        self.val = 0.0
        self.count = 0

    def reset_history(self):
        self.history = []

    def __call__(self, val):
        if isinstance(val, jnp.ndarray):
            val = val.item()
        # Keep the history of **given inputs**
        self.history.append(val)
        self.val = (self.val * self.count + val) / (self.count + 1)
        self.count += 1
        return self.val

    def __str__(self):
        return str(self.val)

    def __repr__(self):
        return repr(self.val)

    def __format__(self, *args, **kwargs):
        return self.val.__format__(*args, **kwargs)


def input_func(key, batch_size, zdim):
    """ Input of generator ( = "noise" in classic GAN). """
    return random.normal(key, (batch_size, zdim))


# Plotting

def plot_tensor_images(images, num_images=(10, 10), cmap='gray'):
    # Normalize to [0, 1]
    if images.min() < 0:
        images = (images + 1.0) / 2.0
        images = jnp.clip(images, 0.0, 1.0)
    h, w = images.shape[1:3]
    nh, nw = num_images
    if len(images) < nh * nw:
        raise ValueError("Not enough images to show (number of images "
                         f"received: {len(images)}, number of images "
                         f"needed: {nh}x{nw}).")
    image_grid = images[:nh * nw].reshape(nh, nw, h, w, -1)
    image_grid = jnp.transpose(image_grid, (0, 2, 1, 3, 4))
    image_grid = image_grid.reshape(nh * h, nw * w, -1)
    plt.grid(False)
    plt.axis('off')
    plt.imshow(image_grid, cmap=cmap)


def plot_curves(history, n_epochs):
    loss_gen, loss_disc = history['loss_gen'], history['loss_disc']
    loss_gen, loss_disc = jnp.array(loss_gen), jnp.array(loss_disc)
    # Downsample the points to reduce the length of the plots
    len_gen, len_disc = min(1000, len(loss_gen)), min(1000, len(loss_disc))
    time_gen = jnp.linspace(0, n_epochs, len_gen)
    time_disc = jnp.linspace(0, n_epochs, len_disc)
    loss_gen = jnp.interp(time_gen, jnp.linspace(0, n_epochs, len(loss_gen)),
                          loss_gen)
    loss_disc = jnp.interp(time_disc, jnp.linspace(0, n_epochs,
                                                   len(loss_disc)), loss_disc)

    plt.plot(time_gen, loss_gen, color='#ff9100', label='generator loss')
    plt.plot(time_disc, loss_disc, color='#00aaff', label='discriminator loss')
    plt.ylim([0, 5.0])
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc='best')
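
`Mean.__call__` uses the incremental update $m_{n+1} = (n\,m_n + x_{n+1})/(n+1)$, so after $n$ calls `val` equals the arithmetic mean of the inputs since the last `reset()`. In `train` below, the printed losses are therefore per-epoch running means (reset at the end of each epoch), whereas `history` keeps every raw batch loss across all epochs, which `plot_curves` then downsamples to at most 1000 points by interpolation. A standalone sketch of the same update rule (`running_mean` is a hypothetical helper, not part of the notebook):

```python
def running_mean(values):
    """Return the running mean after each input, using the same
    incremental update rule as Mean.__call__."""
    val, count = 0.0, 0
    trace = []
    for v in values:
        val = (val * count + v) / (count + 1)
        count += 1
        trace.append(val)
    return trace


print(running_mean([1.0, 2.0, 3.0, 4.0]))  # [1.0, 1.5, 2.0, 2.5]
```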


## Implementing GAN components <a class="anchor" id="implement-gan"></a>

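### Generator and Discriminator

The code below implements the generator with `hk.Conv2DTranspose` layers and the discriminator with `hk.Conv2D` layers. The generator reshapes each noise vector to a 1×1×`z_dim` tensor and upsamples it to a 28×28×1 image with a final `tanh`, so its pixels lie in [-1, 1] like the normalized data. The discriminator downsamples a 28×28×1 image to a single logit per sample; no sigmoid is applied, because the loss works on logits.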


In [None]:

class Generator(hk.Module):
    def __init__(self):
        super().__init__()
        self.channels = (256, 128, 64, 1)
        self.ker_shapes = (3, 4, 3, 4)
        self.strides = (2, 1, 2, 2)
        self.padding = (0, 0, 0, 0)
        self.n_layers = len(self.channels)

        if isinstance(self.ker_shapes, int):
            self.ker_shapes = [self.ker_shapes] * self.n_layers
        if isinstance(self.strides, int):
            self.strides = [self.strides] * self.n_layers
        if isinstance(self.padding, int):
            self.padding = [self.padding] * self.n_layers

        self.layers = [
            hk.Conv2DTranspose(
                self.channels[i],
                kernel_shape=self.ker_shapes[i],
                stride=self.strides[i],
                padding='VALID' if self.padding[i] == 0 else 'SAME',
                with_bias=False,
                w_init=RandomNormal(stddev=0.02, mean=0.0))
            for i in range(self.n_layers)
        ]

        self.batch_norms = [
            hk.BatchNorm(False, False, 0.99) for _ in range(self.n_layers - 1)
        ]

    def __call__(self, z, is_training=jnp.asarray([True])):
        x = jnp.reshape(z, (-1, 1, 1, z.shape[-1]))
        for i in range(self.n_layers - 1):
            x = self.layers[i](x)
            x = self.batch_norms[i](x, is_training)
            x = jax.nn.relu(x)
        x = self.layers[-1](x)
        x = jnp.tanh(x)

        return x


class Discriminator(hk.Module):
    def __init__(self):
        super().__init__()
        self.channels = (16, 32, 1)
        self.ker_shapes = 4
        self.strides = 2
        self.padding = (0, 0, 0, 0)
        self.n_layers = len(self.channels)

        if isinstance(self.ker_shapes, int):
            self.ker_shapes = [self.ker_shapes] * self.n_layers
        if isinstance(self.strides, int):
            self.strides = [self.strides] * self.n_layers
        if isinstance(self.padding, int):
            self.padding = [self.padding] * self.n_layers

        self.layers = [
            hk.Conv2D(self.channels[i],
                      kernel_shape=self.ker_shapes[i],
                      stride=self.strides[i],
                      padding='VALID' if self.padding[i] == 0 else 'SAME',
                      w_init=RandomNormal(stddev=0.02, mean=0.0),
                      b_init=Constant(0.0)) for i in range(self.n_layers)
        ]
        self.batch_norms = [
            hk.BatchNorm(True, True, 0.99) for _ in range(self.n_layers - 1)
        ]

    def __call__(self, x, is_training=jnp.asarray([True])):

        if x.ndim == 3:
            x = jnp.expand_dims(x, axis=-1)
        for i in range(self.n_layers - 1):
            x = self.layers[i](x)
            x = self.batch_norms[i](x, is_training)
            x = jax.nn.leaky_relu(x, 0.2)

        x = self.layers[-1](x)
        x = jnp.squeeze(x)

        return x
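
With `'VALID'` padding, the layer settings above fully determine the spatial sizes. The sketch below traces them in plain Python, assuming the output-size rule of `jax.lax.conv_transpose` for transposed convolutions (`size * stride + max(kernel - stride, 0)`) and `(size - kernel) // stride + 1` for strided convolutions; the helper names are illustrative:

```python
def conv_transpose_valid_out(size, kernel, stride):
    """Spatial output size of a 'VALID' transposed convolution
    (assumed rule of jax.lax.conv_transpose)."""
    return size * stride + max(kernel - stride, 0)


def conv_valid_out(size, kernel, stride):
    """Spatial output size of a 'VALID' strided convolution."""
    return (size - kernel) // stride + 1


# Generator: z is reshaped to 1x1xz_dim, then upsampled layer by layer
size = 1
gen_sizes = [size]
for kernel, stride in zip((3, 4, 3, 4), (2, 1, 2, 2)):
    size = conv_transpose_valid_out(size, kernel, stride)
    gen_sizes.append(size)
print("generator:", gen_sizes)  # generator: [1, 3, 6, 13, 28]

# Discriminator: 28x28 input, kernel 4 and stride 2 in all three layers
size = 28
disc_sizes = [size]
for _ in range(3):
    size = conv_valid_out(size, 4, 2)
    disc_sizes.append(size)
print("discriminator:", disc_sizes)  # discriminator: [28, 13, 5, 1]
```

The generator therefore produces 28×28 images, matching MNIST, and the discriminator reduces each image to a 1×1×1 logit that `jnp.squeeze` turns into one value per sample. Changing a kernel size or stride without redoing this arithmetic will generally break the shape match with the data.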


In [None]:

@hk.without_apply_rng
@hk.transform_with_state
def gen_fwd(z, is_training):
    """ (transformed) Forward pass of generator. """

    generator = Generator()
    X_fake = generator(z, is_training=is_training)
    return X_fake


@hk.without_apply_rng
@hk.transform_with_state
def disc_fwd(X, is_training):
    """ (transformed) Forward pass of discriminator. """

    discriminator = Discriminator()
    y_pred = discriminator(X, is_training=is_training)
    return y_pred


def init_generator(key, config, z):
    """ Initialize the generator parameters/states
    and its optimizer."""
    params_gen, state_gen = gen_fwd.init(key,
                                         z,
                                         is_training=jnp.asarray([True]))

    opt_gen = optax.adam(learning_rate=config['lr'],
                         b1=config['beta1'],
                         b2=config['beta2'])
    opt_state_gen = opt_gen.init(params_gen)
    return state_gen, opt_gen, opt_state_gen, params_gen


def init_discriminator(key, config, x):
    """ Initialize the discriminator parameters/states
    and its optimizer."""
    params_disc, state_disc = disc_fwd.init(
        key,
        x,
        is_training=jnp.asarray([True]),
    )

    opt_disc = optax.adam(learning_rate=config['lr'],
                          b1=config['beta1'],
                          b2=config['beta2'])
    opt_state_disc = opt_disc.init(params_disc)
    return state_disc, opt_disc, opt_state_disc, params_disc


### Loss functions

In [None]:
@jit
def cross_entropy(logits, labels):
    return optax.sigmoid_binary_cross_entropy(logits, labels)



def fwd_loss_gen(params_gen, params_disc, state_gen, state_disc, z,
                 is_training):
    """ Computes the loss of the generator over one batch. """

    # Generate fake images from the noise vectors z
    X_fake, state_gen = gen_fwd.apply(params_gen,
                                      state_gen,
                                      z,
                                      is_training=is_training)

    # Score the fake images with the discriminator (logits)
    y_pred_fake, state_disc = disc_fwd.apply(params_disc,
                                             state_disc,
                                             X_fake,
                                             is_training=is_training)

    # Non-saturating generator loss: the generator wants the
    # discriminator to classify its samples as real (target 1)
    loss_gen = cross_entropy(y_pred_fake, jnp.ones_like(y_pred_fake))
    loss_gen = jnp.mean(loss_gen)
    return loss_gen, (state_gen, state_disc)


def fwd_loss_disc(params_disc, params_gen, state_disc, state_gen, z, X_real,
                  is_training):
    """ Computes the loss of the discriminator over one batch. """
    X_fake, state_gen = gen_fwd.apply(params_gen,
                                      state_gen,
                                      z,
                                      is_training=is_training)

    # Score fake and real images with the discriminator (logits)
    y_pred_fake, state_disc = disc_fwd.apply(params_disc,
                                             state_disc,
                                             X_fake,
                                             is_training=is_training)
    y_pred_real, state_disc = disc_fwd.apply(params_disc,
                                             state_disc,
                                             X_real,
                                             is_training=is_training)

    # Label smoothing: targets are 0.1 for fake and 0.9 for real images
    fake_loss = cross_entropy(y_pred_fake, jnp.zeros_like(y_pred_fake) + 0.1)
    real_loss = cross_entropy(y_pred_real, jnp.ones_like(y_pred_real) - 0.1)
    loss_disc = ((fake_loss + real_loss) / 2.0)
    loss_disc = jnp.mean(loss_disc)
    return loss_disc, (state_disc, state_gen)
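
For the generator, the usual choice is the non-saturating loss: binary cross-entropy of the discriminator's logits on fake samples against target 1, i.e. $-\log D(G(z))$, rather than minimizing $\log(1 - D(G(z)))$. The sketch below compares the gradients of both objectives with respect to the discriminator logit when the discriminator confidently rejects a fake sample (logit $-5$, chosen only for illustration; the function names are hypothetical):

```python
import jax
import jax.numpy as jnp


def saturating_gen_loss(logit):
    # Original minimax generator objective: minimize log(1 - D(G(z)))
    return jnp.log1p(-jax.nn.sigmoid(logit))


def non_saturating_gen_loss(logit):
    # BCE against target 1, i.e. -log D(G(z))
    return -jax.nn.log_sigmoid(logit)


logit = -5.0  # discriminator is confident the sample is fake
g_sat = jax.grad(saturating_gen_loss)(logit)
g_ns = jax.grad(non_saturating_gen_loss)(logit)
print(f"saturating gradient:     {float(g_sat):.4f}")  # -0.0067
print(f"non-saturating gradient: {float(g_ns):.4f}")   # -0.9933
```

Both gradients point in the same direction (increase the logit), but the saturating one is roughly 150 times smaller here, which is why the generator learns very slowly with it while the discriminator is winning.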



## Training and Visualisations <a class="anchor" id="training"></a>

### Train utils

In [None]:
@partial(jit, static_argnums=(4, 7))
def train_gen(
    params_gen,
    params_disc,
    state_gen,
    state_disc,
    opt_gen,
    opt_state_gen,
    z,
    is_training,
):
    """ (jit) Update the generator parameters/states and
    its optimizer over one batch. """
    (loss_gen, (state_gen, state_disc)), grads = vgrad(fwd_loss_gen,
                                                       has_aux=True)(
                                                           params_gen,
                                                           params_disc,
                                                           state_gen,
                                                           state_disc,
                                                           z,
                                                           is_training,
                                                       )

    updates, opt_state_gen = opt_gen.update(grads, opt_state_gen, params_gen)
    params_gen = optax.apply_updates(params_gen, updates)
    return params_gen, state_gen, state_disc, opt_state_gen, loss_gen


@partial(jit, static_argnums=(4, 8))
def train_disc(params_disc, params_gen, state_disc, state_gen, opt_disc,
               opt_state_disc, z, X_real, is_training):
    """ (jit) Update the discriminator parameters/states and
    its optimizer over one batch. """
    (loss_disc, (state_disc, state_gen)), grads = vgrad(fwd_loss_disc,
                                                        has_aux=True)(
                                                            params_disc,
                                                            params_gen,
                                                            state_disc,
                                                            state_gen,
                                                            z,
                                                            X_real,
                                                            is_training,
                                                        )
    updates, opt_state_disc = opt_disc.update(grads, opt_state_disc,
                                              params_disc)
    params_disc = optax.apply_updates(params_disc, updates)
    return params_disc, state_disc, state_gen, opt_state_disc, loss_disc


def cycle_train(X_real, key, params_gen, params_disc, state_gen, state_disc,
                opt_gen, opt_state_gen, opt_disc, opt_state_disc,
                mean_loss_gen, mean_loss_disc, config):
    """ Train the generator and the discriminator and update
    the means (mean_loss_gen and mean_loss_disc).
    """
    X_real = jnp.array(X_real)
    batch_size = X_real.shape[0]  # (can change at the end of epoch)
    key, *keys = random.split(key, 1 + config['disc_cycle_train'])

    # Train generator
    z = input_func(key, batch_size, config['z_dim'])
    (params_gen, state_gen, state_disc, opt_state_gen, loss_gen) = train_gen(
        params_gen,
        params_disc,
        state_gen,
        state_disc,
        opt_gen,
        opt_state_gen,
        z,
        True,
    )
    mean_loss_gen(loss_gen)

    # Train the discriminator config['disc_cycle_train'] times
    for k in range(config['disc_cycle_train']):
        z = input_func(keys[k], batch_size, config['z_dim'])

        (params_disc, state_disc, state_gen, opt_state_disc,
         loss_disc) = train_disc(
             params_disc,
             params_gen,
             state_disc,
             state_gen,
             opt_disc,
             opt_state_disc,
             z,
             X_real,
             True,
         )
        mean_loss_disc(loss_disc)

    return (params_gen, params_disc, state_gen, state_disc, opt_state_gen,
            opt_state_disc, mean_loss_gen, mean_loss_disc)


### Main training Loop

In [None]:
def train(
    dataset,
    key,
    config,
):

    X_real = jnp.array(next(dataset.take(1).as_numpy_iterator()))
    batch_size = X_real.shape[0]
    # Initialize generator and discriminator (independent keys for each)
    key, key_z, key_gen, key_disc = random.split(key, 4)
    z = input_func(key_z, batch_size, config['z_dim'])
    state_gen, opt_gen, opt_state_gen, params_gen = init_generator(
        key_gen, config, z)
    state_disc, opt_disc, opt_state_disc, params_disc = init_discriminator(
        key_disc, config, X_real)
    print('Initialization succeeded.')

    mean_loss_gen, mean_loss_disc = Mean(), Mean()
    len_ds = int(dataset.__len__())
    itr = 0
    start = time()
    for ep in range(config['epochs']):
        print('Epoch {}-{}'.format(ep + 1, config['epochs']))

        for i_batch, X_real in enumerate(dataset):
            t = time() - start
            eta = t / (itr + 1) * (config['epochs'] * len_ds - itr - 1)
            t_h, t_m, t_s = t // 3600, (t % 3600) // 60, t % 60
            eta_h, eta_m, eta_s = eta // 3600, (eta % 3600) // 60, eta % 60
            print(
                f'  batch {i_batch + 1}/{len_ds} - '
                f'gen loss:{mean_loss_gen: .5f} - '
                f'disc loss:{mean_loss_disc: .5f} - '
                f'time: {int(t_h)}h {int(t_m)}min {int(t_s)}sec - '
                f'eta: {int(eta_h)}h {int(eta_m)}min {int(eta_s)}sec    ',
                end='\r')
            key, subkey, train_key = random.split(key, 3)

            # Training both generator and discriminator
            (params_gen, params_disc, state_gen, state_disc, opt_state_gen,
             opt_state_disc, mean_loss_gen, mean_loss_disc) = cycle_train(
                 X_real, train_key, params_gen, params_disc, state_gen,
                 state_disc, opt_gen, opt_state_gen, opt_disc, opt_state_disc,
                 mean_loss_gen, mean_loss_disc, config)

            itr += 1
            # Plot images
            if itr % config['display_step'] == 0:
                z = input_func(
                    subkey, config['num_images'][0] * config['num_images'][1],
                    config['z_dim'])
                X_fake, state_gen = gen_fwd.apply(params_gen,
                                                  state_gen,
                                                  z,
                                                  is_training=False)

                plot_tensor_images(X_fake, num_images=config['num_images'])
                plt.title('Epoch {}/{} - iteration {}'.format(
                    ep + 1, config['epochs'], itr),
                          fontsize=15)
                plt.show(block=False)


        mean_loss_gen.reset()
        mean_loss_disc.reset()
        print()

    history = {
        'loss_gen': mean_loss_gen.history,
        'loss_disc': mean_loss_disc.history
    }

    z = input_func(key, config['num_images'][0] * config['num_images'][1],
                   config['z_dim'])
    X_fake, state_gen = gen_fwd.apply(
        params_gen,
        state_gen,
        z,
        is_training=False,
    )
    plt.figure(figsize=(40, 20))
    plt.subplot(1, 2, 1)
    plot_tensor_images(X_fake, num_images=config['num_images'])
    plt.title('Final images generation')
    plt.subplot(1, 2, 2)
    plt.title('Loss curves')
    plot_curves(history, config['epochs'])
    plt.show()
    return params_gen, state_gen, params_disc, state_disc, history


### Run training

In [None]:
config = {
    'z_dim': 64,
    'lr': 1e-4,
    'beta1': 0.5,
    'beta2': 0.999,
    'batch_size': 128,
    'epochs': 5,
    'disc_cycle_train': 5,
    'seed': 25,
    'display_step': 500,
    'num_images': (10, 10)
}
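
With this configuration, a quick calculation shows how much work one run does. The sketch assumes the standard MNIST split of 60,000 training and 10,000 test images, which `load_images_mnist` concatenates; the values are copied from `config` above:

```python
import math

n_images = 60_000 + 10_000  # MNIST train + test, concatenated by the loader
batch_size, epochs = 128, 5  # config['batch_size'], config['epochs']
disc_cycle_train, display_step = 5, 500  # config values

batches_per_epoch = math.ceil(n_images / batch_size)
total_iters = epochs * batches_per_epoch
print(batches_per_epoch)  # 547
print(total_iters)  # 2735 generator updates
print(total_iters * disc_cycle_train)  # 13675 discriminator updates
print(total_iters // display_step)  # 5 image grids shown during training
print(n_images % batch_size)  # 112: size of the last batch of each epoch
```

Since the last batch of each epoch is smaller, the jitted `train_gen` and `train_disc` are traced once more for that shape. `disc_cycle_train` sets how many discriminator updates are made per generator update.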

In [None]:
key = random.PRNGKey(config['seed'])
dataset = load_images_mnist(batch_size=config['batch_size'],
                            seed=config['seed'])

params_gen, state_gen, params_disc, state_disc, history = train(
    dataset=dataset,
    key=key,
    config=config,
)


## Analysis <a class="anchor" id="analysis"></a>

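1. Which learning rate and batch size work best? Try several values of `config['lr']` and `config['batch_size']` and compare the generated images and loss curves.

2. Which hyperparameter affects GAN training the most?

3. Increase `config['epochs']`. How does this affect the quality of the generated images?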