# Q2: Cycle Generative Adversarial Network (CycleGAN)

In [1]:
# Load the necessary imports
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
from library import datasets, models
from library.visualizations.point_pairs import visualize_point_correspondence
from library.utils.numerical_checking import NumericalCheckingRecord
from library.utils.helper_functions import check_model_forward, plot_generated_samples, plot_samples
from flax import linen as nn
from jax import random, numpy as jnp
import optax
import plotly.io as pio
from plotly.subplots import make_subplots
import plotly.graph_objects as go
pio.renderers.default = 'notebook_connected'

## Dataset

In the previous notebook, you saw how a GAN tries to learn to generate samples from a desired distribution. In this notebook, you will use CycleGAN to learn how to map samples from one distribution to another. One of the key advantages of CycleGAN is that it does not require a paired dataset. Instead, it learns the sample-to-sample correspondence between two datasets by enforcing cycle consistency. In this homework, you will consider the mapping between two distributions of the blob dataset you saw in the GAN notebook. These are illustrated below.

In [2]:
# Seed one PRNG key and split it so each dataset gets its own subkey
key = random.PRNGKey(5)
key_A, key_B = random.split(key)

# Create dataset: both blob datasets share the same sampling settings and differ only in their key
blob_settings = dict(n_samples=1000, min_sigma=0, max_sigma=0.1)
A = datasets.utils.make_blobs(key=key_A, **blob_settings)
B = datasets.utils.make_blobs(key=key_B, **blob_settings)

# Sample tensors used for plotting and as generator inputs below
real_A = A.get_tensors()
real_B = B.get_tensors()

In [3]:
# Show the two real distributions side by side, one subplot per dataset
fig = make_subplots(rows=1, cols=2, subplot_titles=["Distribution A", "Distribution B"])
# (points, marker color, legend name), in subplot-column order
real_specs = [(real_A, "blue", "Real A"), (real_B, "green", "Real B")]
for col, (points, color, label) in enumerate(real_specs, start=1):
    fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode='markers', marker=dict(color=color), name=label),
        row=1,
        col=col,
    )
fig.update_layout(title_text="Real samples")
fig.show()

## Part (A): Initialize CycleGAN and interpret generated fake samples

Since you already learned how to create the discriminator and generator in the GAN notebook, you can reuse their architectures below. Remember that we have two different distributions now instead of one for GAN.

In [4]:
# Define ambient dimension, discriminator and generator
AMBIENT_DIM = 2


def _dense_relu_stack(hidden_widths, output_width):
    """Build an nn.Sequential with a Dense+relu pair per hidden width and a final bare Dense.

    Args:
        hidden_widths: iterable of ints, the feature size of each hidden Dense layer.
        output_width: int, the feature size of the last Dense layer (no activation after it).

    Returns:
        An nn.Sequential containing the layers in the order described above.
    """
    layers = []
    for width in hidden_widths:
        layers.extend([nn.Dense(width), nn.relu])
    layers.append(nn.Dense(output_width))
    return nn.Sequential(layers)


# Hidden layer widths shared by both networks of each kind
DISCRIMINATOR_WIDTHS = (8, 16, 16, 8)
GENERATOR_WIDTHS = (8, 8, 4)

# Discriminators end in a single output unit
discriminator_A = _dense_relu_stack(DISCRIMINATOR_WIDTHS, 1)  # TODO: Copy from GAN notebook

discriminator_B = _dense_relu_stack(DISCRIMINATOR_WIDTHS, 1)  # TODO: Copy from GAN notebook

# Generators end in AMBIENT_DIM output units
generator_AB = _dense_relu_stack(GENERATOR_WIDTHS, AMBIENT_DIM)  # TODO: Copy from GAN notebook

generator_BA = _dense_relu_stack(GENERATOR_WIDTHS, AMBIENT_DIM)  # TODO: Copy from GAN notebook


In [5]:
# Initialize model with generators, discriminators and loss function.
modules = dict(
    generator_AB=generator_AB,
    generator_BA=generator_BA,
    discriminator_A=discriminator_A,
    discriminator_B=discriminator_B,
)

# The same shape tuple is passed for both shape arguments
sample_shape = (AMBIENT_DIM,)
model = models.cyclegan.CycleGAN(modules, sample_shape, sample_shape)
model.initialize(optax.sigmoid_binary_cross_entropy)

In [8]:
# Get generators from untrained model and draw samples
# create_distribution() is unpacked into two objects, named here for the A->B and B->A directions
gan_AB, gan_BA = model.create_distribution()
# fake_A is produced by gan_BA from real_B; fake_B is produced by gan_AB from real_A
fake_A = gan_BA.draw_samples(real_B)
fake_B = gan_AB.draw_samples(real_A)

In [9]:
## Implementation check
# Load the stored reference outputs for each generator
expected_fake_A = NumericalCheckingRecord.load('library/checks/cyclegan_generatorBA_check')
expected_fake_B = NumericalCheckingRecord.load('library/checks/cyclegan_generatorAB_check')
checked_pairs = [(expected_fake_A, fake_A), (expected_fake_B, fake_B)]

# Shapes are verified for both outputs first, then the values, so a shape problem is reported first
for record, samples in checked_pairs:
    assert record.data.shape == samples.shape, "Ensure output shape matches what you set for GAN"
for record, samples in checked_pairs:
    assert record.check(samples), "Remember your architecture should match what you set for GAN"

In [10]:
# Overlay the fake samples (red) on the real samples of the distribution they are meant to imitate
fig = make_subplots(rows=1, cols=2, subplot_titles=["Distribution A", "Distribution B"])
# (points, marker color, legend name, subplot column, extra Scatter options), in drawing order
scatter_specs = [
    (real_A, "blue", "Real A", 1, {}),
    (real_B, "green", "Real B", 2, {}),
    (fake_A, "red", "Fake", 1, {}),
    (fake_B, "red", "Fake", 2, {"showlegend": False}),  # only one "Fake" legend entry is shown
]
for points, color, label, col, extra in scatter_specs:
    fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode='markers', marker=dict(color=color), name=label, **extra),
        row=1,
        col=col,
    )
fig.update_layout(title_text="Real samples and fake samples from untrained generators")
fig.show()

**Q. Are the generated fake samples as you expected? Why or why not?**

Write your answers in the written portion.

## Part (B): Implement cycle consistency loss and train CycleGAN

For the CycleGAN model, we aim to solve the minimax game between the generators and the discriminators:
$$\underset{G_{AB}, G_{BA}}{min}\: \underset{D_A, D_B}{max}\: V(G_{AB}, G_{BA}, D_A, D_B) = \mathcal{L}(G_{AB},G_{BA},D_A, D_B)$$

The full objective is defined as:
$$\mathcal{L}(G_{AB},G_{BA},D_A, D_B) = \mathcal{L}_{GAN}(G_{AB},D_B,A,B) + \mathcal{L}_{GAN}(G_{BA},D_A,B,A) +\lambda \mathcal{L}_{cyc}(G_{AB}, G_{BA})$$
where $\lambda$ controls the relative importance of the two objectives (i.e. GAN or cycle consistency loss) and where the GAN and cycle consistency losses are defined as:
\begin{align*}
\mathcal{L}_{GAN}(G_{AB}, D_B, A, B) &= \mathbb{E}_{b \sim p_{data}(b)}\left[log \: D_B(b)\right]+\mathbb{E}_{a \sim p_{data}(a)}\left[log(1-D_B(G_{AB}(a)))\right] \\
\mathcal{L}_{cyc}(G_{AB}, G_{BA}) &= \mathbb{E}_{a \sim p_{data}(a)}\left[||G_{BA}(G_{AB}(a))-a||_1\right] + \mathbb{E}_{b \sim p_{data}(b)}\left[||G_{AB}(G_{BA}(b))-b||_1\right]
\end{align*}
Note that $\mathcal{L}_{GAN}(G_{BA},D_A,B,A)$ is calculated in a similar fashion to $\mathcal{L}_{GAN}(G_{AB},D_B,A,B)$.


In this part, you should implement the GAN and cycle consistency loss functions from above in the train method in `cyclegan.py`. Once you have successfully implemented the functions, you can train the CycleGAN model below.

In [None]:
# Train CycleGAN
lambda_ = 5  # passed as cycle_loss_weight below
optimizer = optax.adam(learning_rate=5e-3)
history = model.train(
    A,
    B,
    optimizer,
    print_every=5,
    batch_size=1000,
    num_epochs=1000,
    cycle_loss_weight=lambda_,
)

In [None]:
## Implementation check
# Compare every loss series recorded in `history` against the stored reference values.
expected_history = NumericalCheckingRecord.load('library/checks/cyclegan_history_check')
# The loop variable is `loss_name`, not `key`, so the PRNG `key` created at the top of the
# notebook is not overwritten with a string when this cell runs.
for loss_name, loss_values in history.items():
    actual = jnp.array(loss_values)
    expected = jnp.array(expected_history.data[loss_name])
    # jnp.equal broadcasts, so check the shapes first: a length mismatch must fail outright
    assert actual.shape == expected.shape, f"Loss history shape mismatch for {loss_name}"
    assert jnp.equal(actual, expected).all().item(), f"Losses are not equal for {loss_name}"

## Part (C): CycleGAN loss curves

In [None]:
# Plot every loss series recorded in `history` on one shared axis.
fig, ax = plt.subplots(figsize=(10, 6))
# The loop variable is `loss_name`, not `key`, so the notebook-level PRNG `key` is not overwritten
for loss_name, loss_values in history.items():
    ax.plot(range(len(loss_values)), loss_values, label=loss_name)
ax.legend()
ax.set_title("Loss curves of generators and discriminators")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
plt.show()

**Q. Looking at the loss curves above, can you identify a limitation of CycleGAN?**

Write your answers in the written portion.

Now, let's visualize how the generators transformed samples across distributions.

In [None]:
# Fetch the generators again from the model and map each real set through the opposite-direction generator
# NOTE(review): this cell calls `.generator(...)` while the untrained-model cell used
# `.draw_samples(...)` — confirm the two are equivalent for this purpose.
gan_AB, gan_BA = model.create_distribution()
fake_A = gan_BA.generator(real_B)
fake_B = gan_AB.generator(real_A)

In [None]:
# Overlay the fake samples (red) on the real samples of the distribution they are meant to imitate
fig = make_subplots(rows=1, cols=2, subplot_titles=["Distribution A", "Distribution B"])
# (points, marker color, legend name, subplot column, extra Scatter options), in drawing order
scatter_specs = [
    (real_A, "blue", "Real A", 1, {}),
    (real_B, "green", "Real B", 2, {}),
    (fake_A, "red", "Fake", 1, {}),
    (fake_B, "red", "Fake", 2, {"showlegend": False}),  # only one "Fake" legend entry is shown
]
for points, color, label, col, extra in scatter_specs:
    fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode='markers', marker=dict(color=color), name=label, **extra),
        row=1,
        col=col,
    )
fig.update_layout(title_text="Real samples and fake samples from trained generators")
fig.show()


## Part (D): Visualize sample correspondence

Explanation of the graphs:

The function `visualize_point_correspondence` visualizes the correspondence across two sets of points by giving points that correspond to each other the same color. The left subplot corresponds to the points passed in as the first argument, while the right subplot corresponds to the second argument. For example, if you call `visualize_point_correspondence(real_A, fake_B)`, where `fake_B` is transformed from `real_A` by the generator, then an orange point in the left subplot is transformed by the generator to the corresponding orange point in the right subplot.

In [None]:
# Left subplot: real_A (first argument); right subplot: fake_A (second argument)
visualize_point_correspondence(real_A, fake_A, "Real A", "Fake A")

In [None]:
# Left subplot: real_B (first argument); right subplot: fake_B (second argument)
visualize_point_correspondence(real_B, fake_B, "Real B", "Fake B")

**Why does it seem that the clusters of real A samples and the clusters of fake A samples are switched, despite a low cycle loss observed at the end of training?**

Answer the question in the written portion.

In [None]:
# TODO: the remaining cells in this question are SOLUTION ONLY. Remove them in the homework.
# Round trip for A: apply gan_AB to real_A, then gan_BA to the result
translated_A = gan_AB.generator(real_A)
recon_A = gan_BA.generator(translated_A)
# Round trip for B: apply gan_BA to real_B, then gan_AB to the result
translated_B = gan_BA.generator(real_B)
recon_B = gan_AB.generator(translated_B)

In [None]:
# Left subplot: real_A; right subplot: recon_A, i.e. real_A passed through gan_AB and then gan_BA
visualize_point_correspondence(real_A, recon_A, "Real A", "Reconstructed A")

In [None]:
# Left subplot: real_B; right subplot: recon_B, i.e. real_B passed through gan_BA and then gan_AB
visualize_point_correspondence(real_B, recon_B, "Real B", "Reconstructed B")

In [None]:
# Left subplot: real_A; right subplot: real_B (the two real datasets, with no generator involved)
visualize_point_correspondence(real_A, real_B, "Real A", "Real B")

## Part (E): Ablation study with cycle-consistency loss
In this part, an ablation study is conducted on the impact of the cycle-consistency loss. In Part (B) you implemented cycle-consistency loss which consists of forward-consistency (from distribution A to B) and backward-consistency (from distribution B to A). In this part, you will conduct an ablation study examining the generator performance when only having forward-consistency. For your intuition, consider the cycle-consistency loss function.

$$\mathcal{L}_{cyc}(G_{AB}, G_{BA}) = \mathcal{L}_{forward}(G_{AB}, G_{BA}) + \mathcal{L}_{backward}(G_{AB}, G_{BA})$$
where 
\begin{align*}
\mathcal{L}_{forward}(G_{AB}, G_{BA}) &= \mathbb{E}_{a \sim p_{data}(a)}\left[||G_{BA}(G_{AB}(a))-a||_1\right] \\
\mathcal{L}_{backward}(G_{AB},G_{BA}) &= \mathbb{E}_{b \sim p_{data}(b)}\left[||G_{AB}(G_{BA}(b))-b||_1\right]
\end{align*}

In [6]:
# Construct a new CycleGAN object from the `modules` dict defined earlier and initialize it with
# the same loss function as before; this rebinds `model`, replacing the previously trained one.
model = models.cyclegan.CycleGAN(modules, (AMBIENT_DIM,), (AMBIENT_DIM,))
model.initialize(optax.sigmoid_binary_cross_entropy)

In [7]:
# Ablation run: same settings as the full training call but for 500 epochs and with a cycle-loss mask
lambda_ = 5
optimizer = optax.adam(learning_rate=5e-3)
# NOTE(review): confirm in cyclegan.py whether cycle_loss_mask='BAB' removes or keeps the
# B -> A -> B cycle term; the notebook text describes this run as forward-consistency only.
history = model.train(
    A,
    B,
    optimizer,
    print_every=5,
    batch_size=1000,
    num_epochs=500,
    cycle_loss_weight=lambda_,
    cycle_loss_mask='BAB',
)

Epoch 0; Generator AB GAN loss:  0.6995; Generator BA GAN loss:  0.7239; Generator AB Cycle loss:  0.8255; Generator BA Cycle loss:  0.7572; Discriminator A loss:  0.6781; Discriminator B loss:  0.7197: : 1it [00:08,  8.08s/it]
Epoch 1; Generator AB GAN loss:  0.6928; Generator BA GAN loss:  0.7248; Generator AB Cycle loss:  0.8220; Generator BA Cycle loss:  0.7760; Discriminator A loss:  0.6762; Discriminator B loss:  0.7119: : 1it [00:00,  1.68it/s]
Epoch 2; Generator AB GAN loss:  0.6895; Generator BA GAN loss:  0.7241; Generator AB Cycle loss:  0.8218; Generator BA Cycle loss:  0.7964; Discriminator A loss:  0.6744; Discriminator B loss:  0.7079: : 1it [00:00,  1.99it/s]
Epoch 3; Generator AB GAN loss:  0.6881; Generator BA GAN loss:  0.7232; Generator AB Cycle loss:  0.8189; Generator BA Cycle loss:  0.8059; Discriminator A loss:  0.6728; Discriminator B loss:  0.7039: : 1it [00:00,  2.00it/s]
Epoch 4; Generator AB GAN loss:  0.6869; Generator BA GAN loss:  0.7223; Generator AB Cy

In [8]:
# Fetch the generators of the model trained with cycle_loss_mask='BAB' and map each real set
# through the opposite-direction generator
gan_AB, gan_BA = model.create_distribution()
fake_A = gan_BA.generator(real_B)
fake_B = gan_AB.generator(real_A)

In [9]:
# Pass the real sets and the sets generated by the cycle_loss_mask='BAB' model to the library plotting helper
plot_generated_samples(real_A, real_B, fake_A, fake_B)

## Part (F): Pretraining with GAN loss only

For this part, run pretraining with only the GAN loss, and then introduce the cycle consistency loss after 500 epochs. You do not need to implement any code for Parts (E) and (F).

### Pretrain with $\lambda = 0$

In [None]:
# Reinitialize
# Construct a new CycleGAN object from the `modules` dict defined earlier and initialize it with
# the same loss function as before; this rebinds `model`, replacing the previously trained one.
model = models.cyclegan.CycleGAN(modules, (AMBIENT_DIM,), (AMBIENT_DIM,))
model.initialize(optax.sigmoid_binary_cross_entropy)

In [None]:
# Pretrain
lambda_pretrain = 0  # a cycle_loss_weight of 0 is passed for this phase
optimizer = optax.adam(learning_rate=5e-3)
history = model.train(
    A,
    B,
    optimizer,
    print_every=5,
    batch_size=1000,
    num_epochs=500,
    cycle_loss_weight=lambda_pretrain,
)

In [None]:
# Obtain generators and samples
gan_AB, gan_BA = model.create_distribution()
fake_A = gan_BA.draw_samples(real_B)
fake_B = gan_AB.draw_samples(real_A)

# Visualize samples
fig = make_subplots(rows=1, cols=2, subplot_titles=["Distribution A", "Distribution B"])
# (points, marker color, legend name, subplot column, extra Scatter options), in drawing order
scatter_specs = [
    (real_A, "blue", "Real A", 1, {}),
    (real_B, "green", "Real B", 2, {}),
    (fake_A, "red", "Fake", 1, {}),
    (fake_B, "red", "Fake", 2, {"showlegend": False}),  # only one "Fake" legend entry is shown
]
for points, color, label, col, extra in scatter_specs:
    fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode='markers', marker=dict(color=color), name=label, **extra),
        row=1,
        col=col,
    )
fig.update_layout(title_text="Real samples and fake samples from pretrained generators")
fig.show()

### Now, train with both GAN loss and cycle loss.

In [None]:
# Further training
lambda_train = 5  # the cycle loss weight is now nonzero
optimizer = optax.adam(learning_rate=5e-3)
history = model.train(
    A,
    B,
    optimizer,
    print_every=5,
    batch_size=1000,
    num_epochs=500,
    cycle_loss_weight=lambda_train,
)

In [None]:
# Obtain generators and samples
gan_AB, gan_BA = model.create_distribution()
fake_A = gan_BA.draw_samples(real_B)
fake_B = gan_AB.draw_samples(real_A)

# Visualize samples
fig = make_subplots(rows=1, cols=2, subplot_titles=["Distribution A", "Distribution B"])
# (points, marker color, legend name, subplot column, extra Scatter options), in drawing order
scatter_specs = [
    (real_A, "blue", "Real A", 1, {}),
    (real_B, "green", "Real B", 2, {}),
    (fake_A, "red", "Fake", 1, {}),
    (fake_B, "red", "Fake", 2, {"showlegend": False}),  # only one "Fake" legend entry is shown
]
for points, color, label, col, extra in scatter_specs:
    fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode='markers', marker=dict(color=color), name=label, **extra),
        row=1,
        col=col,
    )
fig.update_layout(title_text="Real samples and fake samples from the trained generators")
fig.show()

**Q. How does pretraining without cycle loss affect subsequent training where cycle loss is present? Can you give a possible explanation?**

Write your answers in the written portion.

## Part (G): Pretraining with GAN loss and cycle consistency loss

In [None]:
# Reinitialize
# Construct a new CycleGAN object from the `modules` dict defined earlier and initialize it with
# the same loss function as before; this rebinds `model`, replacing the previously trained one.
model = models.cyclegan.CycleGAN(modules, (AMBIENT_DIM,), (AMBIENT_DIM,))
model.initialize(optax.sigmoid_binary_cross_entropy)

### Pretrain with $\lambda = 5$

In [None]:
# Pretrain
lambda_pretrain = 5  # the cycle loss weight is nonzero during this phase
optimizer = optax.adam(learning_rate=5e-3)
history = model.train(
    A,
    B,
    optimizer,
    print_every=5,
    batch_size=1000,
    num_epochs=1000,
    cycle_loss_weight=lambda_pretrain,
)

In [None]:
# Obtain generators and samples
gan_AB, gan_BA = model.create_distribution()
fake_A = gan_BA.draw_samples(real_B)
fake_B = gan_AB.draw_samples(real_A)

# Visualize samples
fig = make_subplots(rows=1, cols=2, subplot_titles=["Distribution A", "Distribution B"])
# (points, marker color, legend name, subplot column, extra Scatter options), in drawing order
scatter_specs = [
    (real_A, "blue", "Real A", 1, {}),
    (real_B, "green", "Real B", 2, {}),
    (fake_A, "red", "Fake", 1, {}),
    (fake_B, "red", "Fake", 2, {"showlegend": False}),  # only one "Fake" legend entry is shown
]
for points, color, label, col, extra in scatter_specs:
    fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode='markers', marker=dict(color=color), name=label, **extra),
        row=1,
        col=col,
    )
fig.update_layout(title_text="Real samples and fake samples from pretrained generators")
fig.show()

### Now, train with GAN loss ONLY

In [None]:
# Fine tuning
lambda_train = 0 # TODO: Enter lambda value for fine tuning
optimizer = optax.adam(learning_rate=5e-3)
history = model.train(
    A,
    B,
    optimizer,
    print_every=5,
    batch_size=1000,
    num_epochs=500,
    cycle_loss_weight=lambda_train,
)

In [None]:
# Obtain generators and samples
gan_AB, gan_BA = model.create_distribution()
fake_A = gan_BA.draw_samples(real_B)
fake_B = gan_AB.draw_samples(real_A)

# Visualize samples
fig = make_subplots(rows=1, cols=2, subplot_titles=["Distribution A", "Distribution B"])
# (points, marker color, legend name, subplot column, extra Scatter options), in drawing order
scatter_specs = [
    (real_A, "blue", "Real A", 1, {}),
    (real_B, "green", "Real B", 2, {}),
    (fake_A, "red", "Fake", 1, {}),
    (fake_B, "red", "Fake", 2, {"showlegend": False}),  # only one "Fake" legend entry is shown
]
for points, color, label, col, extra in scatter_specs:
    fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode='markers', marker=dict(color=color), name=label, **extra),
        row=1,
        col=col,
    )
fig.update_layout(title_text="Real samples and fake samples from the trained generators")
fig.show()

**Q. How does pretraining with cycle loss affect subsequent training where cycle loss is absent? Can you give a possible explanation?**

Write your answers in the written portion.