Becoming an artist takes a lot of skill. And becoming an artist with the status enjoyed by [Claude Monet](https://en.wikipedia.org/wiki/Claude_Monet) takes an _incredible_ amount of skill. Hundreds of artists practice their craft for years and years in order to produce something slightly close to Monet.

Luckily, we have Deep Learning. 😉

What artists take years to accomplish, we can accomplish in a few hours (even less if we have the right hardware) thanks to the awesome power granted to us by deep learning in the form of [Neural Style Transfer](https://en.wikipedia.org/wiki/Neural_style_transfer) (NST).

NST is a problem in computer vision where the goal is to take an input image and transfer the "style" of another image to this image. That is, we move the style of the input image from its original (input) domain to a new (output) domain. The two domains can be anything. In our case, our input domain consists of real-life photos taken by a camera and our output domain consists of Monet paintings. The goal is to make these real-life photos look like Monet paintings.

While the problem of NST has been studied for decades, deep learning models made their way into the field with the introduction of the [NST model](https://arxiv.org/abs/1508.06576) in 2015. The model treats the problem as a supervised learning task, i.e. it needs image pairs as input, with one image being in the input domain and the other being its output-domain counterpart, so that it has "input features" and the ground truth "label". This involves a lot of effort in data collection since, for each training image, we also need the expected output image.

Luckily, we have Deep Learning. 😉

The introduction of [Generative Adversarial Networks](https://en.wikipedia.org/wiki/Generative_adversarial_network) (GANs) in 2014 and the subsequent introduction of [CycleGAN](https://arxiv.org/abs/1703.10593) in 2017 have allowed NST to be treated as an unsupervised task, thereby eliminating the need for paired data.

In this notebook, we will train a CycleGAN model from scratch on a TPU using [TensorFlow](https://github.com/tensorflow/tensorflow) to transform real-life photos to Monet-esque paintings.

> Note: For a quick overview on how GANs work, see this video: [A Friendly Introduction to Generative Adversarial Networks (GANs)](https://www.youtube.com/watch?v=8L11aMN5KY8).

> Note: For a nice explanation of the main ideas in the CycleGAN paper, see this video: [CycleGAN Paper Walkthrough](https://www.youtube.com/watch?v=5jziBapziYE).

# CycleGAN - Main Ideas

CycleGAN is based on two key ideas:
- NST should be "cycle consistent." That is, if we convert an image in domain A to one in domain B and then convert it back into domain A, we should get back the original image. In other words, the generators should be able to reverse each other's operations.
- If an image in domain A is given as input to the generator which generates images in domain A, the generator should do nothing since the image is already in the generator's domain. That is, the generator should act as an identity function, yielding the image back as the output. The same goes for the generator for domain B.

Both these ideas lead to a cyclic forward pass. This is why the model is called a _CycleGAN_.
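
As a toy illustration of these two properties (not the actual model), suppose the two "generators" are simple functions on numbers, with domain A being the non-negative numbers and domain B the non-positive ones. The names `to_a` and `to_b` are hypothetical stand-ins introduced only for this sketch.

```python
# Toy stand-ins for the two generators (hypothetical, for illustration only).
# Domain A: non-negative numbers; domain B: non-positive numbers.

def to_b(x):
    """Map any number into domain B."""
    return -abs(x)


def to_a(x):
    """Map any number into domain A."""
    return abs(x)


a = 3.0  # a sample from domain A

# Cycle consistency: A -> B -> A returns the original sample
cycled = to_a(to_b(a))

# Identity: a sample already in domain A passes through to_a unchanged
identity = to_a(a)

print(cycled, identity)
```

A real generator only approximates both properties, which is why they are enforced through loss terms rather than assumed.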

To achieve this, there are two sets of GANs. That is, we have two generators and two discriminators:
- There is a generator that takes images from domain A and generates images in domain B
- There is a discriminator which classifies images in domain B as real or fake
- There is a generator that takes images from domain B and generates images in domain A
- There is a discriminator which classifies images in domain A as real or fake

Thus, one GAN is responsible for converting an image in domain A to one in domain B and another GAN is responsible for converting an image in domain B to one in domain A. Both these GANs are trained simultaneously and the performance of one contributes to that of the other.

All discriminators are trained to output the probability of an image being real. An output of 1 implies that the discriminator thinks the image is a real image (not generated) and an output of 0 implies that it thinks the image is fake (generated).

For this dataset, the following diagram summarises the forward pass through the CycleGAN. The text in `()` denotes the domain each output/input is in (follow the color coded arrows and see "CycleGAN Model" section):

![CycleGAN Forward Pass](https://drive.google.com/uc?export=view&id=1EvPhRtsgOoD5WSy8ETsg91-Zl8YCNbkN)

To quantify the performance of the generators and discriminators, the following loss functions are involved:

- **Discriminator Loss** - This takes the outputs of a discriminator for a generated image and the real image, and applies a binary cross entropy loss on them since the discriminator is a binary classifier. The `y_true` is all 0's and all 1's respectively since the first is a fake image and the second is a real image.
- **Generator Loss** - This takes the output of the discriminator for its generated image and applies a binary cross entropy on it, with the `y_true` being all 1's since the generator's goal is to fool the discriminator into thinking that the generated image is real.
- **Cycle Consistency Loss (CCL)** - This takes the original image and the cycled image, and calculates how far the cycled image is from the original image in order to measure the cycle consistency property mentioned above.
- **Identity Loss** - This takes the original image and the identity image, and calculates how far the identity image is from the original image in order to measure the identity function property mentioned above.

The total loss for a generator is the sum of the generator loss, the _total_ (across both the generators) CCL and the identity loss. The total CCL is used for each generator so that the model is cycle consistent as a whole rather than one GAN individually.

The total loss for a discriminator is simply the discriminator loss.
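
Written out, and assuming a weight $\lambda$ on the cycle consistency term with half that weight on the identity term (which is how the loss functions defined later in this notebook scale them), the totals look roughly as follows. The symbols $G_A$, $G_B$, $D_A$ and $D_B$ are shorthand introduced here for the generators and discriminators of domains A and B.

$$
\begin{aligned}
\mathcal{L}_{G_A} &= \mathcal{L}_{\text{gen}}(G_A) + \lambda\,\big(\mathcal{L}_{\text{cyc}}(A) + \mathcal{L}_{\text{cyc}}(B)\big) + \tfrac{\lambda}{2}\,\mathcal{L}_{\text{id}}(A) \\
\mathcal{L}_{G_B} &= \mathcal{L}_{\text{gen}}(G_B) + \lambda\,\big(\mathcal{L}_{\text{cyc}}(A) + \mathcal{L}_{\text{cyc}}(B)\big) + \tfrac{\lambda}{2}\,\mathcal{L}_{\text{id}}(B) \\
\mathcal{L}_{D_A} &= \mathcal{L}_{\text{disc}}(D_A), \qquad \mathcal{L}_{D_B} = \mathcal{L}_{\text{disc}}(D_B)
\end{aligned}
$$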

All in all, we get a total of four loss values after every forward step:
- Domain A Generator Loss - Loss for the generator which generates images in domain A.
- Domain B Generator Loss - Loss for the generator which generates images in domain B.
- Domain A Discriminator Loss - Loss for the discriminator which classifies images in domain A as fake or real.
- Domain B Discriminator Loss - Loss for the discriminator which classifies images in domain B as fake or real.

The following diagram summarises the loss calculation for this dataset:

![CycleGAN Loss Calculation](https://drive.google.com/uc?export=view&id=1w5YHLl_XXUOfhMW6nMutbfUX4AwCO6gL)

# Imports

In [None]:
# Turn off annoying TensorFlow warnings
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "4"

In [None]:
import io
import os
import random
import warnings
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from kaggle_datasets import KaggleDatasets
from PIL import Image

from tensorflow.keras import (
    Input,
    Model,
    layers,
    optimizers,
    losses,
    utils
)

warnings.filterwarnings("ignore")

%matplotlib inline

In [None]:
def seed_everything(seed=42):
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    
seed_everything()

# Detect TPU

In [None]:
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
    print("Using TPU:")
except ValueError:
    print("No TPU found. Falling back to GPU/CPU:")
    strategy = tf.distribute.MirroredStrategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

# Configuration

We will define some basic configuration that will be used throughout the notebook.

In [None]:
# This gives us the URL on Google Cloud Storage where this dataset is stored
GCS_DS_PATH = KaggleDatasets().get_gcs_path("gan-getting-started")
GCS_DS_PATH

In [None]:
class Config:
    DATA_DIR = GCS_DS_PATH
    
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    
    IMG_SIZE = 256
    
    PREFIXES = {
        "monet": "monet_",
        "photo": "photo_"
    }
    
    # Paper uses a batch size of 1
    BATCH_SIZE = 1
    
    EPOCHS = 50

    LR = 2e-4
    
    # This is a scaling factor for the cycle consistency
    # And the identity losses
    LAMBDA = 10
    
    VAL_SPLIT = 0.1
    
    @classmethod
    def filepath(cls, prefix="photo"):
        prefixes = tuple(cls.PREFIXES)
        
        if prefix not in prefixes:
            raise ValueError(f"Unrecognized prefix. Should be in {prefixes}.")
            
        return os.path.join(cls.DATA_DIR, f"{cls.PREFIXES[prefix]}tfrec")

# Dataset Functions

In [None]:
# Decode a single image in a TFRecord
# It also standardizes the image
def decode_image(image, channels=3) -> tf.Tensor:
    img = tf.image.decode_jpeg(image, channels=channels)
    img = (tf.cast(img, tf.float32) / 127.5) - 1
    img = tf.reshape(img, [Config.IMG_SIZE, Config.IMG_SIZE, channels])
    return img
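
To make the scaling in `decode_image` concrete, here is a minimal NumPy sketch of the same arithmetic on a few made-up pixel values, along with the inverse mapping that is handy before plotting. The array `pixels` is hypothetical and only stands in for a decoded JPEG.

```python
import numpy as np

# Hypothetical 8-bit pixel values standing in for a decoded JPEG
pixels = np.array([0, 64, 127.5, 255], dtype=np.float32)

# Same scaling as decode_image: [0, 255] -> [-1, 1]
scaled = pixels / 127.5 - 1

# Inverse mapping, useful before plotting: [-1, 1] -> [0, 1]
displayable = scaled * 0.5 + 0.5

print(scaled)
print(displayable)
```

The $[-1, 1]$ range matters because the generators end in a `tanh`, whose outputs live in the same range.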

In [None]:
# Read one TFRecord
# There are no labels and so we only return the image
def read_tfrecord(example):
    tfrecord_format = {
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image'])
    return image

In [None]:
# Create TensorFlow dataset from the given list of filepaths
def get_dataset(filepaths, ordered=False) -> tf.data.Dataset:
    options = tf.data.Options()
    if ordered is False:
        options.experimental_deterministic = False

    dataset = tf.data.TFRecordDataset(filepaths, num_parallel_reads=Config.AUTOTUNE)
    dataset = dataset.with_options(options)

    dataset = dataset.map(read_tfrecord, num_parallel_calls=Config.AUTOTUNE)
    
    # Cache helps speed up data access by preventing multiple file I/O operations
    dataset = dataset.cache()
    
    dataset = dataset.shuffle(2048)
    
    return dataset

## Example

In [None]:
# Load dataset
path = os.path.join(Config.filepath("monet"), "*.tfrec")
monet_filenames = tf.io.gfile.glob(path)

path = os.path.join(Config.filepath("photo"), "*.tfrec")
photo_filenames = tf.io.gfile.glob(path)

monet_ds = get_dataset(monet_filenames).batch(1)
photo_ds = get_dataset(photo_filenames).batch(1)

In [None]:
# Fetch one batch from each dataset
example_monet = next(iter(monet_ds))
example_photo = next(iter(photo_ds))

In [None]:
# Shape: Each image is an RGB image with dimensions 256x256
example_monet.shape, example_photo.shape

In [None]:
# Plot a Monet painting and a photo
# Images are scaled to [-1, 1], so map them back to [0, 1] for imshow
monet = example_monet[0] * 0.5 + 0.5
img = example_photo[0] * 0.5 + 0.5

f, axs = plt.subplots(1, 2, figsize=(7, 7))

axs[0].imshow(monet)
axs[0].set_axis_off()
axs[0].set_title("Monet")

axs[1].imshow(img)
axs[1].set_axis_off()
axs[1].set_title("Photo")

# Discriminator

The discriminators in a CycleGAN have the following architecture:

![CycleGan Discriminator](https://drive.google.com/uc?export=view&id=1RJeYlKABS8w9wuKDAzCx9Dyzb4x3opwF)

We will first define a convolutional block (yellow and pink block in the architecture above), which is made up of the following components:
- `Conv2D` layer
- Optional `InstanceNormalization`
- `LeakyReLU` activation

It takes the input for the layer, the number of filters for `Conv2D`, the stride and a boolean indicating whether instance normalization should be added or not.

The kernel size is set to 4 and there is a zero padding before the `Conv2D` layer.

> Note: The paper uses mirror or reflection padding but TensorFlow TPU doesn't support that operation. Therefore, reflection padding has been replaced by zero padding.

In [None]:
def discriminator_convblock(ip, filters, strides=1, *, norm=True):
    x = layers.ZeroPadding2D(padding=1)(ip)
    
    x = layers.Conv2D(filters=filters, kernel_size=4, strides=strides)(x)
    
    if norm is True:
        x = tfa.layers.InstanceNormalization()(x)
        
    x = layers.LeakyReLU()(x)
    
    return x

Now, we will define the discriminator itself. It takes the number of input channels in the image (defaulting to 3 for RGB images) and an optional list of integers denoting the number of filters for the convolutional blocks before the output layer. If not provided, it uses the values in the architecture above (4 blocks of sizes 64, 128, 256 and 512). Note that the first hidden layer does not have instance normalization.

The output layer is a `Conv2D` layer with 1 output channel since the discriminator does binary classification (1 - Real, 0 - Fake). Sigmoid is not applied here since it will be applied as part of the loss function.

For an $N\times256\times256\times3$ batch of images (the model uses a channels-last layout), the output is $N\times30\times30\times1$, where $N$ is the batch size. Notice that the architecture refers to the discriminator as a "PatchGAN". This is because each pixel in this $30\times30$ output corresponds to a $70\times70$ patch in the input, i.e. each pixel, after applying sigmoid, is the probability of the corresponding $70\times70$ patch in the input being real.

> Note: Since the entire network is made up of convolutional layers and has no component that depends on input size, it can handle any input size.
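
The $30\times30$ output size and the $70\times70$ patch size can be checked with a little arithmetic. The sketch below assumes the default configuration described above (kernel size 4, zero padding of 1, strides of 2, 2, 2, 1 and 1 for the four blocks and the output layer); the helper `conv_out` is a hypothetical name introduced for this example.

```python
def conv_out(size, kernel, stride, pad):
    """Spatial output size of a convolution applied after zero padding."""
    return (size + 2 * pad - kernel) // stride + 1


# (kernel, stride) for the four blocks and the output layer; padding is 1 throughout
layers_spec = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

size = 256
receptive_field, jump = 1, 1
for kernel, stride in layers_spec:
    size = conv_out(size, kernel, stride, pad=1)
    # Each layer widens the receptive field by (kernel - 1) input-space steps
    receptive_field += (kernel - 1) * jump
    jump *= stride

print(size, receptive_field)  # 30 70
```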

In [None]:
def Discriminator(channels=3, conv_filters=None):
    ip = Input(shape=(Config.IMG_SIZE, Config.IMG_SIZE, channels))
    
    # If not provided, set conv_filters to default value
    if conv_filters is None:
        conv_filters = [64, 128, 256, 512]

    # The first block uses the first filter count and has no normalization
    filters = conv_filters[0]

    x = discriminator_convblock(ip=ip, filters=filters, strides=2, norm=False)

    for filters in conv_filters[1:-1]:
        x = discriminator_convblock(ip=x, filters=filters, strides=2)

    x = discriminator_convblock(ip=x, filters=conv_filters[-1], strides=1)

    x = layers.ZeroPadding2D(padding=1)(x)

    op = layers.Conv2D(filters=1, kernel_size=4, strides=1)(x)

    return Model(inputs=ip, outputs=op)

## Summary

We will initialize a Discriminator model and get a nice summary. We see that, with the default values, we have $2,766,529$ parameters to train.

In [None]:
d = Discriminator()
d.summary()

In [None]:
utils.plot_model(d)

In [None]:
# Free up RAM
del d

# Generator

The generators in a CycleGAN have the following architecture. It follows a UNet-like architecture with a downsampling part and then an upsampling part, with a little ResNet sprinkled in between.

![CycleGAN Generator](https://drive.google.com/uc?export=view&id=11bEYU1ZiffmrOeVi5ib7-x8KuSV0KRdb)

We will first define a convolutional block, which can handle both the yellow blocks and the green blocks in the architecture above. It is made up of the following components:
- Either a `Conv2D` (yellow) layer or a `Conv2DTranspose` (green) layer.
- Optional `InstanceNormalization`.
- Optional `ReLU` activation.

It takes: 
- Input
- Number of filters for the `Conv2D` or `Conv2DTranspose`
- Padding
- Boolean indicating whether to use `Conv2D` or `Conv2DTranspose`
- Boolean indicating whether to add instance normalization or not
- Boolean indicating whether to add `ReLU` or not
- Any additional arguments by keyword. All additional `kwargs` are passed to either `Conv2D` or `Conv2DTranspose` to set things like the kernel size, stride and output padding. Optional arguments that are not passed keep their Keras defaults.

In [None]:
def generator_convblock(
    ip,
    filters,
    padding=1,
    *,
    transpose=False,
    use_activation=True,
    norm=True,
    **kwargs,
):
    x = layers.ZeroPadding2D(padding=padding)(ip)
        
    klass = layers.Conv2DTranspose if transpose is True else layers.Conv2D
    
    if transpose is True:
        kwargs["padding"] = "same"
    
    x = klass(filters=filters, **kwargs)(x)
    
    if norm is True:
        x = tfa.layers.InstanceNormalization()(x)
        
    if use_activation is True:
        x = layers.ReLU()(x)
        
    return x

Next, we will define the residual blocks (orange blocks in the architecture above). These are made up of two of the above convolutional blocks with `Conv2D` as their convolution operation. The first one has an activation while the second one does not. They use a kernel size of 3, stride of 1, padding of 1. They have the same number of input and output channels (since they are residual blocks and use skip connections which involve element-wise addition).

In [None]:
def residual_block(ip, filters):
    x = generator_convblock(
        ip=ip,
        filters=filters,
        kernel_size=3,
        strides=1,
        padding=1,
    )
    
    x = generator_convblock(
        ip=x,
        filters=filters,
        kernel_size=3,
        strides=1,
        padding=1,
        use_activation=False,
    )

    # Skip connection
    x = layers.Add()([ip, x])
    
    return x

Finally, we will define the generator itself. It takes the following arguments:
- Number of input channels in the image (defaulting to 3 for RGB images).
- Optional list of integers containing the number of filters for the convolutional blocks with `Conv2D` as their convolutional operation. If not provided, it uses the values from the paper (3 blocks of size 64, 128 and 256).
- Optional filters for the residual blocks (defaulting to 256). This and the last value in the above argument should match.
- Optional list of integers containing the number of filters for the convolutional blocks with `Conv2DTranspose` as their convolutional operation. If not provided, it uses the values from the paper (2 blocks of size 128 and 64).
- Number of residual blocks (defaulting to 6).

Note that the first convolutional block has no instance normalization.

The output layer is a `Conv2D` layer with the same number of filters as the input channels in the image since the generator is responsible for generating another image. It has a `tanh` activation, which makes the output similar to a standardized image.

For an $N\times H\times W\times3$ batch of images (channels-last layout), the output is also $N\times H\times W\times3$.
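
To see why the spatial size is preserved, the sketch below traces a side length of 256 through the default layer configuration described above. The helpers `conv_out` and `transpose_out` are hypothetical names for this example, and the transposed-convolution formula is my understanding of how the output length works out with `padding="same"` and an explicit `output_padding`, so treat it as an assumption rather than a reference.

```python
def conv_out(size, kernel, stride, pad):
    """Output size of a convolution applied after explicit zero padding."""
    return (size + 2 * pad - kernel) // stride + 1


def transpose_out(size, kernel, stride, output_padding):
    """Assumed output size of a transposed convolution with padding='same'
    and an explicit output_padding (implicit padding of kernel // 2)."""
    pad = kernel // 2
    return (size - 1) * stride + kernel - 2 * pad + output_padding


trace = [256]
trace.append(conv_out(trace[-1], kernel=7, stride=1, pad=3))  # first block
trace.append(conv_out(trace[-1], kernel=3, stride=2, pad=1))  # downsample
trace.append(conv_out(trace[-1], kernel=3, stride=2, pad=1))  # downsample
# Residual blocks (3x3, stride 1, padding 1) leave the size unchanged
trace.append(transpose_out(trace[-1], kernel=3, stride=2, output_padding=1))
trace.append(transpose_out(trace[-1], kernel=3, stride=2, output_padding=1))
trace.append(conv_out(trace[-1], kernel=7, stride=1, pad=3))  # output layer

print(trace)  # [256, 256, 128, 64, 128, 256, 256]
```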

> Note: A mistake I made earlier was adding an `InstanceNormalization` layer after the output layer. This completely crippled the model and it did not learn at all. Do not make the same mistake. 🙂

In [None]:
def Generator(
    channels=3,
    conv_filters=None,
    res_filters=256,
    transpose_filters=None,
    n_residuals=6,
):    
    # If not provided, set conv_filters to the default value
    if conv_filters is None:
        conv_filters = [64, 128, res_filters]
    # If provided, make sure that the last value is same as res_filters
    elif conv_filters[-1] != res_filters:
        msg = (
            f"Make sure that the last value (={conv_filters[-1]}) "
            "in conv_filters is the same as "
            f"res_filters (={res_filters})"
        )
        raise ValueError(msg)

    # If not provided, set transpose_filters to default value
    if transpose_filters is None:
        transpose_filters = [128, 64]
        
    ip = Input(shape=(Config.IMG_SIZE, Config.IMG_SIZE, channels))
    
    filters = conv_filters[0]
    
    x = generator_convblock(
        ip=ip,
        filters=filters,
        kernel_size=7,
        strides=1,
        padding=3,
        norm=False
    )
    
    for filters in conv_filters[1:]:
        x = generator_convblock(
            ip=x,
            filters=filters,
            kernel_size=3,
            strides=2,
            padding=1,
        )
        
    for _ in range(n_residuals):
        x = residual_block(ip=x, filters=res_filters)
        
    for filters in transpose_filters:
        x = generator_convblock(
            ip=x,
            filters=filters,
            kernel_size=3,
            strides=2,
            padding=0,
            output_padding=1,
            transpose=True,
        )
        
        
    x = layers.ZeroPadding2D(padding=3)(x)
    # A tanh activation is added
    # This keeps the output in the range [-1, 1]
    # Which matches the scaling applied to the input images
    op = layers.Conv2D(
        filters=channels,
        kernel_size=7,
        strides=1,
        activation="tanh",
    )(x)    
    
    return Model(inputs=ip, outputs=op)

## Summary

We will initialize a Generator model and get a nice summary. We see that, with the default values, we have $7,844,995$ parameters to train.

In [None]:
g = Generator()
g.summary()
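
The parameter count reported by the summary can be reproduced by hand. The sketch below assumes that every convolution has a bias and that each instance normalization layer has one scale and one offset per channel; the helper names `conv_params` and `norm_params` are hypothetical.

```python
def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a (transposed) convolution."""
    return kernel * kernel * c_in * c_out + c_out


def norm_params(channels):
    """Assumed: one scale and one offset per channel for instance norm."""
    return 2 * channels


residual = 2 * (conv_params(3, 256, 256) + norm_params(256))

total = conv_params(7, 3, 64)                         # first block, no norm
total += conv_params(3, 64, 128) + norm_params(128)   # downsample 1
total += conv_params(3, 128, 256) + norm_params(256)  # downsample 2
total += 6 * residual                                 # 6 residual blocks
total += conv_params(3, 256, 128) + norm_params(128)  # upsample 1
total += conv_params(3, 128, 64) + norm_params(64)    # upsample 2
total += conv_params(7, 64, 3)                        # output layer

print(total)  # 7844995
```

Most of the parameters sit in the residual blocks, so `n_residuals` is the main knob for model size.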

In [None]:
utils.plot_model(g)

In [None]:
# Free up RAM
del g

# Loss Functions

We will define a function to calculate the loss for a discriminator.

The loss for a discriminator is the average of the binary cross entropy losses for the real image (image which has not been generated) and the fake image (image generated using a generator), since the discriminator is a binary classifier.

The discriminator is set to output the probability of an image being real. A probability of 1 implies that the discriminator thinks the image is real while a probability of 0 implies that it thinks the image is fake. Thus, for the real image, the target is all ones and for a fake image, the target is all zeros.

In [None]:
with strategy.scope():
    def discriminator_loss(real_disc, fake_disc):
        # Discriminator is a binary classifier
        # And so uses Binary Cross Entropy
        loss_fn = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)

        # The real image's y_true is all ones
        real_loss = loss_fn(tf.ones_like(real_disc), real_disc)
        # The fake image's y_true is all zeros
        fake_loss = loss_fn(tf.zeros_like(fake_disc), fake_disc)

        return (real_loss + fake_loss) * 0.5
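
For intuition about the values this loss takes, here is a small pure-Python sketch of the same calculation on single logits instead of whole $30\times30$ output maps. The function names are hypothetical, and `bce_with_logits` uses a standard numerically stable form of binary cross entropy on raw logits.

```python
import math


def bce_with_logits(target, logit):
    """Numerically stable binary cross entropy on a raw logit."""
    return max(logit, 0) - logit * target + math.log1p(math.exp(-abs(logit)))


def toy_discriminator_loss(real_logit, fake_logit):
    real_loss = bce_with_logits(1.0, real_logit)  # real target is 1
    fake_loss = bce_with_logits(0.0, fake_logit)  # fake target is 0
    return (real_loss + fake_loss) * 0.5


undecided = toy_discriminator_loss(0.0, 0.0)   # sigmoid(0) = 0.5 for both
confident = toy_discriminator_loss(4.0, -4.0)  # correct on both images
fooled = toy_discriminator_loss(-4.0, 4.0)     # wrong on both images

print(undecided, confident, fooled)
```

An undecided discriminator sits at $\ln 2 \approx 0.693$, so a discriminator loss hovering around that value during training suggests the generator and discriminator are roughly balanced.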

We will define the generator loss for a generator. This loss is used by the generator to measure how good it is at fooling the discriminator. The goal of the generator is to make the discriminator output something as close to all 1s as possible for an image it generates. Thus, this loss is measured using binary cross entropy, where the output of the discriminator for the generated image is the prediction and all 1s is the target.

In [None]:
with strategy.scope():
    def generator_loss(fake_disc):
        loss_fn = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)
        # The goal of the generator is to fool the discriminator into thinking
        # The generated image is real
        # Hence, the loss function uses all 1s as y_true
        return loss_fn(tf.ones_like(fake_disc), fake_disc)

Next, we will define the cycle consistency loss (CCL). CCL is the [L1 norm](https://en.wikipedia.org/wiki/Taxicab_geometry) or the mean absolute error (MAE) between the original image and the output after cycling the image back through the generator. This loss tells us how far the cycled image is from the original image. An alternative to this is the [L2 norm](https://en.wikipedia.org/wiki/Euclidean_distance) or mean squared error (MSE) but that has certain issues when there are outliers in the data.

`scale` is a scaling factor for the loss, which can be used to increase or decrease the contribution of the loss in the final generator loss.

In [None]:
with strategy.scope():
    def cycle_consistency_loss(original, cycled, scale):
        # Use L1 distance as loss function
        return tf.math.reduce_mean(tf.abs(original - cycled)) * scale

Finally, we will define the identity loss. The identity loss is measured in a way similar to CCL. This measures how far away the identity image (image generated by running original image through its own generator) is from the original image. Here, `scale` is halved, which implies that the identity loss will have half as much contribution to the generator loss as CCL.

In [None]:
with strategy.scope():
    def identity_loss(original, identity, scale):
        # Use L1 distance as loss function
        return tf.math.reduce_mean(tf.abs(original - identity)) * scale * 0.5
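
As a quick sanity check of how `scale` and the extra factor of 0.5 interact, the NumPy sketch below applies the same L1 arithmetic to a made-up pair of images that differ by 0.1 at every pixel. The helper `l1_loss` and the arrays are hypothetical and only mirror the two functions above.

```python
import numpy as np


def l1_loss(original, other, scale):
    """Mean absolute error between two images, multiplied by a weight."""
    return float(np.mean(np.abs(original - other))) * scale


# Hypothetical 4x4 RGB "images" with values in [-1, 1]
original = np.zeros((4, 4, 3), dtype=np.float32)
off_by_tenth = original + 0.1

# Cycle consistency uses the full weight, identity uses half of it
ccl = l1_loss(original, off_by_tenth, scale=10)
identity = l1_loss(original, off_by_tenth, scale=10) * 0.5

print(ccl, identity)
```

With a scale of 10, an average per-pixel error of 0.1 contributes about 1.0 to the generator loss through CCL and about 0.5 through the identity loss.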

# CycleGAN Model

We will define a CycleGAN model to put everything together. It takes the generator and discriminator for the Monet paintings, and the generator and discriminator for the real-life photos.

The forward pass takes a real-life photo (`img`) and a Monet painting (`monet_img`), and does the following:

- Step 1
    - Converts `img` to a Monet-esque painting using the generator for Monet paintings, giving `fake_monet`.
    - Converts `monet_img` to a real-life photo using the generator for real-life photos, giving `fake_img`.
- Step 2
    - Recreates `monet_img` from `fake_img` using the generator for Monet paintings to test how good the generator is at undoing the changes made on the image. This gives `cycled_monet`.
    - Repeats the same thing, but this time for `img` using the generator for real-life photos and `fake_monet`. This gives `cycled_img`.
- Step 3
    - Passes `monet_img` through the generator for Monet paintings to test how good the generator is at recognizing that `monet_img` is from its own domain and does not need any style transfer. This gives `identity_monet`.
    - Repeats the same thing, but this time for `img` using the generator for real-life photos. This gives `identity_img`.
- Step 4
    - Classifies `monet_img` using the discriminator for Monet paintings. The expected output is close to 1 since `monet_img` is not generated. This gives `real_monet_disc`.
    - Classifies `img` using the discriminator for real-life photos. The expected output is close to 1 since `img` is not generated. This gives `real_img_disc`.
- Step 5
    - Classifies `fake_monet` using the discriminator for Monet paintings. The expected output is close to 0 since `fake_monet` is generated. This gives `fake_monet_disc`.
    - Classifies `fake_img` using the discriminator for real-life photos. The expected output is close to 0 since `fake_img` is generated. This gives `fake_img_disc`.


All the outputs are returned in a dictionary.
    
    
The cyclic nature of the network can be seen in the above steps.

In [None]:
class CycleGAN(Model):
    def __init__(self, monet_gen, monet_disc, img_gen, img_disc, scale=10):
        super().__init__()
        self.monet_gen = monet_gen
        self.monet_disc = monet_disc
        self.img_gen = img_gen
        self.img_disc = img_disc
        self.scale = scale
        
    def compile(
        self,
        monet_gen_opt,
        monet_disc_opt,
        img_gen_opt,
        img_disc_opt,
        gen_loss_fn,
        disc_loss_fn,
        id_loss_fn,
        ccl_fn
    ):
        super().compile()
        
        self.monet_gen_opt = monet_gen_opt
        self.monet_disc_opt = monet_disc_opt
        
        self.img_gen_opt = img_gen_opt
        self.img_disc_opt = img_disc_opt
        
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.id_loss_fn = id_loss_fn
        self.ccl_fn = ccl_fn
        
    def call(self, data):
        img, monet_img = data
        
        # Generate the fake Monet painting from original image
        # Generate the fake image from the original Monet
        fake_monet = self.monet_gen(img)
        fake_img = self.img_gen(monet_img)

        # Recreate the Monet from the fake image
        # Recreate the original image from the fake Monet
        cycled_monet = self.monet_gen(fake_img)
        cycled_img = self.img_gen(fake_monet)

        # Generate the identity images from the images
        identity_monet = self.monet_gen(monet_img)
        identity_img = self.img_gen(img)

        # Use discriminators to classify the real images
        # Expected output is close to 1 since the images are real
        real_monet_disc = self.monet_disc(monet_img)
        real_img_disc = self.img_disc(img)

        # Use discriminators to classify the fake images
        # Expected output is close to 0 since the images are fake
        fake_monet_disc = self.monet_disc(fake_monet)
        fake_img_disc = self.img_disc(fake_img)
        
        # Store all the outputs related to the normal image
        img_output = {
            "fake_img": fake_img,
            "cycled_img": cycled_img,
            "identity_img": identity_img,
            "real_img_disc": real_img_disc,
            "fake_img_disc": fake_img_disc,
        }
        
        # Store all the outputs related to the Monet image
        monet_output = {
            "fake_monet": fake_monet,
            "cycled_monet": cycled_monet,
            "identity_monet": identity_monet,
            "real_monet_disc": real_monet_disc,
            "fake_monet_disc": fake_monet_disc,
        }
        
        # Return all the outputs
        return {
            **img_output,
            **monet_output,
        }
        
    def train_step(self, data):
        img, monet_img = data
        
        with tf.GradientTape(persistent=True) as tape:
            outputs = self(data)
            
            # Calculate the generator losses
            img_gen_loss = self.gen_loss_fn(outputs["fake_img_disc"])
            monet_gen_loss = self.gen_loss_fn(outputs["fake_monet_disc"])
            
            # Add the identity losses to the generator losses
            img_gen_loss += self.id_loss_fn(
                original=img,
                identity=outputs["identity_img"],
                scale=self.scale,
            )
            monet_gen_loss += self.id_loss_fn(
                original=monet_img,
                identity=outputs["identity_monet"],
                scale=self.scale,
            )  
            
            # Calculate the total CCL
            total_ccl = self.ccl_fn(
                original=img,
                cycled=outputs["cycled_img"],
                scale=self.scale,
            )
            total_ccl += self.ccl_fn(
                original=monet_img,
                cycled=outputs["cycled_monet"],
                scale=self.scale,
            )
            
            # Add the total CCL to the generator losses
            img_gen_loss += total_ccl
            monet_gen_loss += total_ccl
            
            # Calculate the discriminator losses
            img_disc_loss = self.disc_loss_fn(
                outputs["real_img_disc"],
                outputs["fake_img_disc"],
            )
            
            monet_disc_loss = self.disc_loss_fn(
                outputs["real_monet_disc"],
                outputs["fake_monet_disc"],
            )
            
        # Calculate gradients for img_gen
        img_gen_weights = self.img_gen.trainable_weights
        img_gen_grads = tape.gradient(img_gen_loss, img_gen_weights)
        
        # Calculate gradients for monet_gen
        monet_gen_weights = self.monet_gen.trainable_weights
        monet_gen_grads = tape.gradient(monet_gen_loss, monet_gen_weights)
        
        # Calculate gradients for img_disc
        img_disc_weights = self.img_disc.trainable_weights
        img_disc_grads = tape.gradient(img_disc_loss, img_disc_weights)
        
        # Calculate gradients for monet_disc
        monet_disc_weights = self.monet_disc.trainable_weights
        monet_disc_grads = tape.gradient(monet_disc_loss, monet_disc_weights)
        
        # Apply the gradients to the respective optimizers
        self.img_gen_opt.apply_gradients(zip(img_gen_grads, img_gen_weights))
        self.monet_gen_opt.apply_gradients(zip(monet_gen_grads, monet_gen_weights))
        
        self.img_disc_opt.apply_gradients(zip(img_disc_grads, img_disc_weights))
        self.monet_disc_opt.apply_gradients(zip(monet_disc_grads, monet_disc_weights))
        
        # Return all the loss values
        return {
            "img_gen_loss": img_gen_loss,
            "img_disc_loss": img_disc_loss,
            "monet_gen_loss": monet_gen_loss,
            "monet_disc_loss": monet_disc_loss,
        }

## Summary

We will initialize a `CycleGAN` model and get a nice summary. Since there are two sets of a generator and a discriminator, the total number of parameters should be exactly twice the sum of the generator ($7,844,995$) and discriminator parameter counts reported in their summaries above.

In [None]:
monet_gen = Generator()
monet_disc = Discriminator()

img_gen = Generator()
img_disc = Discriminator()

gan = CycleGAN(monet_gen=monet_gen, monet_disc=monet_disc, img_gen=img_gen, img_disc=img_disc)
gan.build(input_shape=[(None, 256, 256, 3), (None, 256, 256, 3)])
gan.summary()

In [None]:
# Free up RAM
del monet_gen
del monet_disc
del img_gen
del img_disc
del gan

# Training

## Load Datasets

In [None]:
path = os.path.join(Config.filepath("monet"), "*.tfrec")
monet_filenames = tf.io.gfile.glob(path)

path = os.path.join(Config.filepath("photo"), "*.tfrec")
photo_filenames = tf.io.gfile.glob(path)

In [None]:
monet_ds = get_dataset(monet_filenames).batch(Config.BATCH_SIZE)
photo_ds = get_dataset(photo_filenames).batch(Config.BATCH_SIZE)

# Prefetching lets TensorFlow load upcoming batches in parallel
# while the model trains on the current one
monet_ds = monet_ds.prefetch(Config.AUTOTUNE)
photo_ds = photo_ds.prefetch(Config.AUTOTUNE)

## Initialize CycleGAN

We will use the default values.

In [None]:
with strategy.scope():
    monet_gen = Generator()
    monet_disc = Discriminator()
    
    img_gen = Generator()
    img_disc = Discriminator()
    
    cycle_gan = CycleGAN(
        monet_gen=monet_gen,
        monet_disc=monet_disc,
        img_gen=img_gen,
        img_disc=img_disc,
        scale=Config.LAMBDA,
    )

## Initialize Optimizers and Compile the Model

In [None]:
with strategy.scope():
    monet_gen_opt = optimizers.Adam(learning_rate=Config.LR, beta_1=0.5)
    monet_disc_opt = optimizers.Adam(learning_rate=Config.LR, beta_1=0.5)
    
    img_gen_opt = optimizers.Adam(learning_rate=Config.LR, beta_1=0.5)
    img_disc_opt = optimizers.Adam(learning_rate=Config.LR, beta_1=0.5)
    
    cycle_gan.compile(
        monet_gen_opt=monet_gen_opt,
        monet_disc_opt=monet_disc_opt,
        img_gen_opt=img_gen_opt,
        img_disc_opt=img_disc_opt,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        id_loss_fn=identity_loss,
        ccl_fn=cycle_consistency_loss,
    )

## Get, Set and Go!

In [None]:
history = cycle_gan.fit(
    x=tf.data.Dataset.zip((photo_ds, monet_ds)),
    epochs=Config.EPOCHS,
)

# Loss Curve

We will visualize the loss curves for the generators and discriminators. 

In [None]:
# Calculate the mean loss per epoch
keys = ["img_gen_loss", "img_disc_loss", "monet_gen_loss", "monet_disc_loss"]

epoch_history = {"img": {}, "monet": {}}

for key in keys:
    img_type, model, _ = key.split("_")
    epoch_history[img_type][model] = np.array(
        [tf.reduce_mean(loss).numpy() for loss in history.history[key]]
    )
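
The grouping step can be tried without a trained model. The sketch below uses a hypothetical `fake_history` dictionary in place of `history.history`, assuming each per-epoch entry may hold more than one loss value (which is why a mean is taken), and uses `statistics.mean` instead of `tf.reduce_mean` so it runs with the standard library alone.

```python
from statistics import mean

# Hypothetical stand-in for `history.history`: one entry per epoch,
# each entry holding several loss values that get averaged.
fake_history = {
    "img_gen_loss": [[2.0, 4.0], [1.0, 3.0]],
    "img_disc_loss": [[0.5, 1.0], [0.25, 0.75]],
    "monet_gen_loss": [[3.0, 5.0], [2.0, 2.0]],
    "monet_disc_loss": [[1.0, 1.0], [0.5, 0.5]],
}

epoch_history = {"img": {}, "monet": {}}

for key, per_epoch in fake_history.items():
    # "img_gen_loss" -> ("img", "gen", "loss")
    img_type, model, _ = key.split("_")
    epoch_history[img_type][model] = [mean(values) for values in per_epoch]

print(epoch_history["img"]["gen"])  # [3.0, 2.0]
```

Each key splits into an image type and a model name, so the result is a two-level dictionary with one list of per-epoch means for every generator and discriminator.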

In [None]:
titles = {"img": "Image", "monet": "Monet"}

f, axs = plt.subplots(1, len(epoch_history), figsize=(10, 5))

for (key, losses), ax in zip(epoch_history.items(), axs.flatten()):
    ax.plot(losses["gen"], label="Generator")
    ax.plot(losses["disc"], label="Discriminator")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.set_title(titles[key])
    ax.legend()

plt.tight_layout()

# Visual Validation

We will plot a few photos next to the paintings that the Monet generator produces from them.

In [None]:
n_imgs = 5

f, axs = plt.subplots(
    nrows=2,
    ncols=n_imgs,
    figsize=(10, 10),
    sharex=True,
    sharey=True,
)

axs = axs.flatten()
ds = enumerate(photo_ds.take(n_imgs))
               
# Plot side-by-side: Photo in even axis and Monet painting in odd axis
for (idx, img), ax1, ax2 in zip(ds, axs[::2], axs[1::2]):
    prediction = cycle_gan.monet_gen(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    
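    # Rescale the photo to [0, 255] as well; this assumes `get_dataset`
    # normalizes photos to [-1, 1], matching the generator's output range
    photo = (img[0].numpy() * 127.5 + 127.5).astype(np.uint8)
    ax1.imshow(photo)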
    ax1.set_axis_off()
    ax1.set_title(f"Photo-{idx + 1}")
    
    ax2.imshow(prediction)
    ax2.set_axis_off()
    ax2.set_title(f"Gen. Monet-{idx + 1}")
    
plt.tight_layout()
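
The expression `prediction * 127.5 + 127.5` maps pixel values from $[-1, 1]$ to $[0, 255]$ before the cast to `uint8`. Here is a minimal NumPy sketch of that rescaling, assuming the generator output lies in $[-1, 1]$. The helper name `to_uint8` and the extra clipping step are additions for this sketch and do not appear in the notebook.

```python
import numpy as np


def to_uint8(image: np.ndarray) -> np.ndarray:
    """Map pixel values from [-1, 1] to [0, 255] as unsigned 8-bit integers."""
    scaled = image * 127.5 + 127.5
    # Clipping is an extra safeguard added in this sketch: it prevents
    # wrap-around if a value falls slightly outside [-1, 1].
    return np.clip(scaled, 0, 255).astype(np.uint8)


pixels = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
print(to_uint8(pixels))  # [  0 127 255]
```

The cast truncates rather than rounds, so the midpoint 127.5 becomes 127.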

# Submission

We will create the submission file by directly writing to a ZIP file.

In [None]:
from tqdm.notebook import tqdm

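# Re-batch to one photo per step so that every photo is converted,
# whatever the value of Config.BATCH_SIZE
single_photo_ds = photo_ds.unbatch().batch(1)

with tqdm(single_photo_ds, unit="imgs", total=7038) as ds, zipfile.ZipFile("images.zip", "w") as zf:
    for idx, img in enumerate(ds):
        prediction = cycle_gan.monet_gen(img, training=False)[0].numpy()
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
        # Encode the painting as JPEG in memory and add it to the archive
        painting = Image.fromarray(prediction)
        dest = io.BytesIO()
        painting.save(dest, format="JPEG")
        zf.writestr(f"{idx + 1}.jpg", dest.getvalue())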
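
The pattern of writing in-memory data straight into the archive can be tried on its own, without a model or an image library. In the sketch below, placeholder byte strings stand in for the encoded JPEG data, and the archive itself lives in a `BytesIO` buffer so that nothing is written to disk. The helper `write_numbered_files` is a hypothetical name used only here.

```python
import io
import zipfile


def write_numbered_files(zf: zipfile.ZipFile, payloads, extension: str = "jpg") -> None:
    """Write each payload as `<1-based index>.<extension>` inside the archive."""
    for idx, payload in enumerate(payloads):
        zf.writestr(f"{idx + 1}.{extension}", payload)


# Placeholder bytes stand in for encoded JPEG data.
fake_images = [b"first", b"second", b"third"]

buffer = io.BytesIO()  # in-memory archive, so the sketch leaves no file behind
with zipfile.ZipFile(buffer, "w") as zf:
    write_numbered_files(zf, fake_images)

# Re-open the archive for reading to confirm its contents.
with zipfile.ZipFile(buffer) as zf:
    names = zf.namelist()
    second = zf.read("2.jpg")

print(names)  # ['1.jpg', '2.jpg', '3.jpg']
```

Because `writestr` accepts raw bytes, no intermediate image files need to be saved before zipping.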

# References

1. Leon A. Gatys, Alexander S. Ecker, Matthias Bethge. 2015. [A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576).
2. Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio. 2014. [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661).
3. Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros. 2017. [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).