# Homework: Cycle Generative Adversarial Network (CycleGAN)

In [19]:
# Load the necessary imports
%load_ext autoreload
%autoreload 2

from library import datasets, models
from flax import linen as nn
from jax import random, numpy as jnp
import optax
import pandas as pd
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'notebook_connected'



## Generative Adversarial Networks (GANs)

Generative models have grown hugely in popularity in recent years due to their unique properties. Their applications include tasks that directly require generation, creating additional training data for machine learning (data augmentation), and producing synthetic data to address privacy concerns. A popular class of generative models is Generative Adversarial Networks (GANs), which have been used for generating realistic photographs, image style transfer, face ageing in pictures and [many more impressive applications](https://machinelearningmastery.com/impressive-applications-of-generative-adversarial-networks/).

GANs are constructed by training two models, a generator and a discriminator, that compete against each other. The generator tries to learn the distribution of a given dataset and is used to generate fake samples following this distribution. The discriminator tries to determine whether a sample comes from the generated ("fake") distribution or from the true data distribution. The two models are therefore adversaries: the generator tries to fool the discriminator, while the discriminator tries not to be fooled. The [original GAN paper](https://arxiv.org/pdf/1406.2661.pdf) offers an illustrative analogy: the generative model is like a team of counterfeiters trying to produce fake currency and use it without detection, while the discriminator is like the police, trying to detect the counterfeit currency. Competition between the two drives both to improve their methods until the counterfeit currency is indistinguishable from real currency.
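For reference, the original GAN paper formalizes this competition as a two-player minimax game. With $D(x)$ the discriminator's estimated probability that $x$ is real, and $G(z)$ the generator's output for latent noise $z \sim p_z$, the value function is

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

The discriminator maximizes $V$ by assigning high probability to real samples and low probability to generated ones, while the generator minimizes $V$ by making $D(G(z))$ large.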

### Dataset
The generator and discriminator are both neural networks whose architecture depends greatly on the task at hand. The most popular domain for GANs is images, where convolutional layers and residual (ResNet) blocks are typical building blocks of both networks. Because training deep models on images is computationally expensive, this homework instead uses a two-dimensional points dataset in which the points are sampled around cluster centres ("blobs") with Gaussian noise. As a result, the neural networks are shallower, and their architecture deviates from the one used in the original CycleGAN paper.
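As a rough illustration of what such a dataset looks like, the sketch below samples two isotropic Gaussian clusters in JAX. The helper `sample_two_blobs`, the cluster centres and the noise level are hypothetical choices for illustration only; the library's `datasets.utils.make_blobs` may generate its data differently.

```python
from jax import random, numpy as jnp

def sample_two_blobs(key, n_samples, centres, sigma):
    """Hypothetical helper: sample points around the given cluster centres."""
    key_assign, key_noise = random.split(key)
    # Assign each point to a uniformly random cluster
    idx = random.randint(key_assign, (n_samples,), 0, centres.shape[0])
    # Add isotropic Gaussian noise with standard deviation sigma around the chosen centre
    noise = sigma * random.normal(key_noise, (n_samples, centres.shape[1]))
    return centres[idx] + noise

# Hypothetical centres: one cluster top-left, one lower-right
centres = jnp.array([[-1.0, 1.0], [1.0, -1.0]])
points = sample_two_blobs(random.PRNGKey(0), 1000, centres, sigma=0.1)
print(points.shape)  # (1000, 2)
```

Each row of `points` is one 2-D sample, which is why the networks below work with an ambient dimension of 2.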

The points dataset in this homework is similar to the cluster dataset used in [GAN Lab](https://poloclub.github.io/ganlab/), which also provides an interactive visualization of a GAN during training. For now, let's get better acquainted with the points dataset.


In [20]:
# Load dataset
dataset = datasets.utils.make_blobs(n_samples=1000, min_sigma=0, max_sigma=0.1)
data = dataset.get_tensors()
px.scatter(x=data[:, 0], y=data[:, 1], color=jnp.ones(len(data)), range_color=(0,1))

**Not done**
In essence, we are interested in training a GAN that can output data points within the two clusters seen in the plot above.

For this homework, the discriminator and generator are multi-layer perceptrons (MLPs) with ReLU activations between the layers. The number of layers and neurons is a design choice, but pay attention to the output dimension of the final layer, which differs between the two networks. The generator should output samples with the same dimension as the data whose distribution you wish to learn. The discriminator should output a single logit indicating whether it believes a given input comes from the true distribution or from the generated one.

Using the above, define the output dimensions of the discriminator and generator below.

In [21]:
# Define ambient dimension, discriminator and generator
AMBIENT_DIM = 2

discriminator = nn.Sequential([
    nn.Dense(8),
    nn.relu,
    nn.Dense(16),
    nn.relu,
    nn.Dense(16),
    nn.relu,
    nn.Dense(8),
    nn.relu,
    nn.Dense(1), # TODO: Complete output dimension
])

generator = nn.Sequential([
    nn.Dense(8),
    nn.relu,
    nn.Dense(8),
    nn.relu,
    nn.Dense(4),
    nn.relu,
    nn.Dense(AMBIENT_DIM), # TODO: Complete output dimension
])

### GAN training

**Write about GAN and discriminator loss**

In [22]:
MODEL_LATENT_DIM = 4

# Define GAN model
model = models.vanilla_gan.VanillaGAN(
    generator=generator,
    discriminator=discriminator,
    latent_shape=(MODEL_LATENT_DIM,),
    ambient_shape=(AMBIENT_DIM,)
)

# TODO: Initialize the GAN model by passing in the loss function of the discriminator
model.initialize(loss_fn=optax.sigmoid_binary_cross_entropy)

To train the model, complete the TODOs in `vanilla_gan.py`.

In [23]:
# Code cell with training loop. Complete blank parts
model.train(
    datasets.base.TensorDataset(data), 
    optax.adam(learning_rate=1e-3), 
    print_every=5, 
    batch_size=1000, 
    num_epochs=500
)

iteration 0; gen_loss:  7.02e-01; dis_loss:  6.92e-01; gen_grads_magnitude:  9.95e-03; dis_grads_magnitude:  2.97e-03
iteration 0; gen_loss:  7.03e-01; dis_loss:  6.91e-01; gen_grads_magnitude:  1.08e-02; dis_grads_magnitude:  3.36e-03
iteration 0; gen_loss:  7.03e-01; dis_loss:  6.91e-01; gen_grads_magnitude:  1.01e-02; dis_grads_magnitude:  3.66e-03
iteration 0; gen_loss:  7.03e-01; dis_loss:  6.91e-01; gen_grads_magnitude:  1.01e-02; dis_grads_magnitude:  3.89e-03
iteration 0; gen_loss:  7.04e-01; dis_loss:  6.91e-01; gen_grads_magnitude:  1.01e-02; dis_grads_magnitude:  4.27e-03
iteration 0; gen_loss:  7.03e-01; dis_loss:  6.91e-01; gen_grads_magnitude:  9.93e-03; dis_grads_magnitude:  4.67e-03
iteration 0; gen_loss:  7.04e-01; dis_loss:  6.91e-01; gen_grads_magnitude:  1.00e-02; dis_grads_magnitude:  4.65e-03

**Let's visualize the results by showing the output of the trained generator**

In [29]:
gan_distribution = model.create_distribution()
random_noise = random.uniform(random.PRNGKey(0), (len(data), MODEL_LATENT_DIM), minval=-1, maxval=1)
#random_noise = random.normal(random.PRNGKey(0), (len(data), MODEL_LATENT_DIM))
fake_samples = gan_distribution.draw_samples(random_noise)

In [32]:
import numpy as np

# Combine real and generated points into one DataFrame with a categorical label
df = pd.DataFrame({
    'x': np.concatenate((np.asarray(data[:, 0]), np.asarray(fake_samples[:, 0]))),
    'y': np.concatenate((np.asarray(data[:, 1]), np.asarray(fake_samples[:, 1]))),
    'label': ['True'] * data.shape[0] + ['Fake'] * fake_samples.shape[0],
})

px.scatter(df, x='x', y='y', color='label')


**Did the GAN manage to learn the underlying data distribution? Do you observe anything odd? Why or why not?**


Looking at the figure above, you might notice that the generator only produces points in the blob in the lower-right corner. Two common challenges when training GANs are non-convergence and mode collapse. Mode collapse occurs when the generator converges to a local minimum in which it only produces (or cycles through) a few distinct outputs that fool the discriminator, instead of covering the whole data distribution. If the generator strays from these outputs, the discriminator notices and classifies the new samples as fake, i.e., not coming from the true data distribution. The generator therefore cannot explore and learn to generate data in the blob in the top-left corner, since doing so would increase its loss.
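One way to quantify mode collapse on a clustered dataset is to assign each generated point to its nearest cluster centre and measure how the samples are spread across clusters. The helper `mode_coverage`, the centres and the "collapsed" samples below are hypothetical; in the notebook one could instead pass `fake_samples` together with cluster centres estimated from `data`.

```python
from jax import random, numpy as jnp

def mode_coverage(samples, centres):
    """Fraction of samples whose nearest cluster centre is each given centre."""
    # Pairwise squared distances, shape (num_samples, num_centres)
    d2 = jnp.sum((samples[:, None, :] - centres[None, :, :]) ** 2, axis=-1)
    nearest = jnp.argmin(d2, axis=1)
    counts = jnp.bincount(nearest, length=centres.shape[0])
    return counts / samples.shape[0]

# Hypothetical centres (top-left, lower-right) and a generator that only hits the second one
centres = jnp.array([[-1.0, 1.0], [1.0, -1.0]])
collapsed = centres[1] + 0.05 * random.normal(random.PRNGKey(0), (500, 2))
print(mode_coverage(collapsed, centres))  # all mass on the second (lower-right) cluster
```

A generator that had learned both blobs would give roughly equal fractions here, whereas a collapsed one puts nearly all of its samples in a single cluster.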

## CycleGAN

In recent years, many GAN variants have been proposed, several of which attempt to mitigate mode collapse. One of these variants is CycleGAN, originally proposed for unpaired image-to-image translation, which introduces a cycle-consistency loss and requires two data distributions (domains) during training. Because there are two distributions, CycleGAN needs two generators (one for each direction of translation between the distributions) and two discriminators (one for each distribution).
![Insert figure](hej.png) 
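For reference, the CycleGAN paper writes the two generators as $G: X \to Y$ and $F: Y \to X$, with discriminators $D_Y$ (for domain $Y$) and $D_X$ (for domain $X$). The cycle-consistency loss penalizes translations that cannot be undone, and the full objective adds it, weighted by $\lambda$, to the two adversarial losses:

$$\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]$$

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\text{GAN}}(F, D_X, Y, X) + \lambda\, \mathcal{L}_{\text{cyc}}(G, F)$$

One intuition for why this may help against mode collapse: if $G$ mapped many different inputs $x$ to the same output, $F$ could not recover each original $x$ from it, so the cycle term penalizes collapsing onto a few outputs.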

- Explain the introduction of cycle consistency loss with images
    - How does it combat mode collapse? 

In [3]:
# Code cell where the two generators and two discriminators are constructed

In [None]:
# Code cell with training loop. Implement cycle consistency loss

- Limitations of CycleGAN
    - Translations that require geometric changes (e.g., cat to dog)