# Q1: Generative Adversarial Networks (GANs)

In [1]:
# Load the necessary imports
%load_ext autoreload
%autoreload 2

from library import datasets, models
from library.utils.numerical_checking import NumericalCheckingRecord
from library.utils.helper_functions import check_model_forward
from flax import linen as nn
from jax import random, numpy as jnp
import optax
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = 'notebook_connected'



## Dataset
The generator and discriminator are both neural networks whose architecture depends heavily on the task at hand. GANs are most commonly applied to images, so generators and discriminators are typically built from convolutions and ResNet blocks. Because training deep models on images is computationally expensive, this homework instead uses a points dataset, in which data points are generated according to a multivariate distribution with noise. As a result, the networks are shallower, and their architecture deviates from the one used in the original CycleGAN paper.

The points dataset in this homework is similar to the cluster dataset used in the [GAN Lab](https://poloclub.github.io/ganlab/), which also provides a great visualization of a GAN as it trains. In essence, we want to train a GAN whose outputs fall within the two clusters seen in the plot.
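
To make the description concrete, here is a minimal sketch of how a two-cluster points dataset could be sampled in JAX. It is not the implementation of `datasets.utils.make_blobs`; the function name `sample_two_blobs`, the cluster centers, and the noise scale are all assumptions chosen for illustration.

```python
import jax.numpy as jnp
from jax import random


def sample_two_blobs(key, n_samples, centers, sigma):
    """Draw points around the given centers with isotropic Gaussian noise."""
    key_assign, key_noise = random.split(key)
    # Pick a cluster index for every sample, uniformly at random.
    idx = random.randint(key_assign, (n_samples,), 0, centers.shape[0])
    # Add Gaussian noise of scale sigma around the chosen center.
    noise = sigma * random.normal(key_noise, (n_samples, centers.shape[1]))
    return centers[idx] + noise


# Hypothetical cluster centers and noise level, for illustration only.
centers = jnp.array([[-0.5, -0.5], [0.5, 0.5]])
points = sample_two_blobs(random.PRNGKey(0), 1000, centers, sigma=0.05)
```

Each row of `points` is one 2-D sample, which is the same layout the plotting code expects from `data`.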


In [2]:
# Load dataset
dataset = datasets.utils.make_blobs(n_samples=1000, min_sigma=0, max_sigma=0.1)
data = dataset.get_tensors()

# Visualize dataset
fig = make_subplots(rows=1, cols=1, subplot_titles=["Real samples"])
fig.add_trace(go.Scatter(x=data[:, 0], y=data[:, 1], mode='markers', marker=dict(color="blue"),name='Real'), row=1, col=1)
fig.show()

## Part (B): Define discriminator and generator

For this homework, consider the discriminator and generator as multi-layer perceptrons with ReLU activations between the layers. The number of layers and neurons is a design choice, but the final layers of the two networks must have different dimensions. The generator should produce output with the same dimension as the data whose distribution you wish to learn. The discriminator should produce a single logit, indicating whether it believes a given input comes from the true distribution or from the generator.

Using the above, define the output dimensions of the discriminator and generator below. The rest of the architecture has already been given to you.

In [3]:
# Define ambient dimension, discriminator and generator
AMBIENT_DIM = 2

discriminator = nn.Sequential([
    nn.Dense(8),
    nn.relu,
    nn.Dense(16),
    nn.relu,
    nn.Dense(16),
    nn.relu,
    nn.Dense(8),
    nn.relu,
    nn.Dense(1), # TODO: Complete output dimension
])

generator = nn.Sequential([
    nn.Dense(8),
    nn.relu,
    nn.Dense(8),
    nn.relu,
    nn.Dense(4),
    nn.relu,
    nn.Dense(AMBIENT_DIM), # TODO: Complete output dimension
])

In [4]:
# Implementation check
key = random.PRNGKey(0) # DO NOT CHANGE
test_discriminator_input = random.normal(key, data.shape)
test_generator_input = random.normal(key, (1,4))
d_out = check_model_forward(discriminator, test_discriminator_input)
g_out = check_model_forward(generator, test_generator_input)

d_expected_out = NumericalCheckingRecord.load("checks/gan_discriminator_check")
g_expected_out = NumericalCheckingRecord.load("checks/gan_generator_check")

assert d_expected_out.data.shape == d_out.shape, "Output dimensions do not match the expected ones"
assert g_expected_out.data.shape == g_out.shape, "Output dimensions do not match the expected ones"
assert d_expected_out.check(d_out), "Output does not match the expected. Remember to output logits"
assert g_expected_out.check(g_out), "Output does not match the expected."

## Part (C): GAN training

Before training can commence, you need to implement the loss functions used to train the generator and discriminator.
The discriminator is trained to maximize the probability of assigning the correct label to both real and generated samples, while the generator is simultaneously trained to fool the discriminator. In other words, the generator and discriminator play a minimax game with the value function $V(G, D)$:

$$\min_{G}\,\max_{D}\; V(G, D) = \mathbb{E}_{x\sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$$

The objective can be decomposed into an individual loss function for the discriminator and for the generator.

The discriminator *D* is updated by ascending the stochastic gradient of its objective, which is the negative of the binary cross-entropy with label 1 for real samples and label 0 for generated samples:
$$\nabla_{\theta_d} \frac{1}{m}\sum_{i=1}^m \left[\log D\left(x^{(i)}\right)+\log\left(1-D\left(G\left(z^{(i)}\right)\right)\right)\right]$$
where *m* is the number of examples in the minibatch, $x^{(i)}$ are the real samples, $z^{(i)}$ are the noise samples, and the generator *G* is frozen.

The generator *G* is updated by descending the stochastic gradient of its objective:
$$\nabla_{\theta_g} \frac{1}{m}\sum_{i=1}^m \left[\log\left(1-D\left(G\left(z^{(i)}\right)\right)\right)\right]$$
where the discriminator *D* is frozen. Notice the difference between the two losses: the discriminator must classify both real and fake samples correctly, whereas the generator is only concerned with fooling the discriminator.
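
The two objectives above can be written directly in terms of discriminator logits. The sketch below is one possible formulation and is not the contents of `vanilla_gan.py`; the function names and the example logits are assumptions. It relies on the identities $\log D = \log\sigma(\ell)$ and $\log(1 - D) = \log\sigma(-\ell)$ for a logit $\ell$.

```python
import jax
import jax.numpy as jnp


def discriminator_loss(real_logits, fake_logits):
    # log D(x) = log_sigmoid(logit); log(1 - D(G(z))) = log_sigmoid(-logit).
    log_d_real = jax.nn.log_sigmoid(real_logits)
    log_one_minus_d_fake = jax.nn.log_sigmoid(-fake_logits)
    # Negate so that minimizing this loss maximizes the discriminator objective.
    return -(jnp.mean(log_d_real) + jnp.mean(log_one_minus_d_fake))


def generator_loss(fake_logits):
    # The generator minimizes the mean of log(1 - D(G(z))) with D held fixed.
    return jnp.mean(jax.nn.log_sigmoid(-fake_logits))


# Hypothetical logits, for illustration only.
real_logits = jnp.array([[2.0], [0.5]])
fake_logits = jnp.array([[-1.0], [0.3]])
d_loss = discriminator_loss(real_logits, fake_logits)
g_loss = generator_loss(fake_logits)
```

`optax.sigmoid_binary_cross_entropy(logits, labels)` computes the same per-sample terms with the opposite sign: label 1 yields $-\log D$ and label 0 yields $-\log(1 - D)$.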


In [5]:
MODEL_LATENT_DIM = 4

# Define GAN model
model = models.vanilla_gan.VanillaGAN(
    generator=generator,
    discriminator=discriminator,
    latent_shape=(MODEL_LATENT_DIM,),
    ambient_shape=(AMBIENT_DIM,)
)

model.initialize(loss_fn=optax.sigmoid_binary_cross_entropy)

To train the model, please complete the TODOs in `vanilla_gan.py`.

In [6]:
# Code cell with training loop. Complete blank parts
model.train(
    datasets.base.TensorDataset(data), 
    optax.adam(learning_rate=1e-3), 
    print_every=5, 
    batch_size=1000, 
    num_epochs=500
)

Epoch 0; Generator loss:  0.7024; Discriminator loss:  0.6916: : 1it [00:13, 13.17s/it]
Epoch 1; Generator loss:  0.7025; Discriminator loss:  0.6915: : 1it [00:00,  1.49it/s]
Epoch 2; Generator loss:  0.7027; Discriminator loss:  0.6914: : 1it [00:00,  1.41it/s]
Epoch 3; Generator loss:  0.7033; Discriminator loss:  0.6911: : 1it [00:00,  1.51it/s]
Epoch 4; Generator loss:  0.7039; Discriminator loss:  0.6908: : 1it [00:00,  1.37it/s]
Epoch 5; Generator loss:  0.7031; Discriminator loss:  0.6911: : 1it [00:00,  1.60it/s]
Epoch 6; Generator loss:  0.7042; Discriminator loss:  0.6906: : 1it [00:00,  1.32it/s]
Epoch 7; Generator loss:  0.7049; Discriminator loss:  0.6903: : 1it [00:00,  1.32it/s]
Epoch 8; Generator loss:  0.7050; Discriminator loss:  0.6902: : 1it [00:00,  1.47it/s]
Epoch 9; Generator loss:  0.7060; Discriminator loss:  0.6898: : 1it [00:00,  1.20it/s]
Epoch 10; Generator loss:  0.7065; Discriminator loss:  0.6897: : 1it [00:00,  1.55it/s]
Epoch 11; Generator loss:  0.70

The losses are harder to interpret than you might expect. Because the generator and discriminator are adversaries, their losses cannot both keep decreasing: an improvement for one is a setback for the other. In practice the discriminator loss often decreases while the generator loss oscillates, but neither is a reliable indicator of whether training succeeded. To judge that, let's visualize samples from the distribution the GAN has learned.
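
One useful reference point, assuming the logged values are mean sigmoid cross-entropy terms (the notebook passes `optax.sigmoid_binary_cross_entropy` as `loss_fn`, but how `vanilla_gan.py` reduces it is not shown here): a discriminator that outputs $D = 1/2$ for every input incurs, per sample,

$$-\log D(x) = -\log\left(1 - D(G(z))\right) = \log 2 \approx 0.693$$

Under that assumption, losses hovering near 0.69 are consistent with a discriminator that cannot yet tell real samples from generated ones.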

In [7]:
gan_distribution = model.create_distribution()
random_noise = random.uniform(random.PRNGKey(0), (len(data), MODEL_LATENT_DIM), minval=-1, maxval=1)
fake_samples = gan_distribution.draw_samples(random_noise)

In [8]:
## Implementation check
expected_fake_samples = NumericalCheckingRecord.load('checks/gan_output_check')
assert expected_fake_samples.check(fake_samples), "Output does not match expected. Please revisit your implementation."

In [9]:
fig = make_subplots(rows=1, cols=1, subplot_titles=["Real and fake samples from trained generator"])
fig.add_trace(go.Scatter(x=data[:, 0], y=data[:, 1], mode='markers', marker=dict(color="blue"),name='Real'), row=1, col=1)
fig.add_trace(go.Scatter(x=fake_samples[:, 0], y=fake_samples[:, 1], mode='markers', marker=dict(color="red"),name='Fake'), row=1, col=1)
fig.show()

**Q. Did the GAN manage to learn the underlying data distribution? Do you observe anything odd? Why or why not?**

Answer the question in the written portion.