# CycleGAN

CycleGAN (Cycle-Consistent Generative Adversarial Network) differs from the other GANs. Instead of generating new observations from scratch (by feeding noise to the generator), a CycleGAN implements a method called *style transfer*, which creates a new observation from two inputs (a style input and an input to apply the style to).

It works by simultaneously transferring the style of one type of image to another type of image, in both directions. There are 3 key points in this learning process:
- It learns how to modify one type of image with the style of the other (Validity)
- The same method should be able to undo what happened to the images during the previous step (i.e. remove the other style) (Reconstruction)
- The same method should not modify an image that already has the target style (i.e. applying a style to an image that already has it should leave the image unchanged) (Identity)
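
These three objectives can be sketched as three loss terms. The snippet below is only an illustration with NumPy: `g_ab`, `g_ba` and `d_b` are hypothetical stand-ins (plain functions, not trained networks), and the choice of mean squared error for validity and mean absolute error for the other two mirrors the losses used when the model is compiled later in this notebook.

```python
import numpy as np

def mse(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))

def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred)))

# Hypothetical stand-ins for the real networks (assumptions for illustration only)
def g_ab(x):  # "generator" A -> B: shifts pixel values up
    return x + 0.5

def g_ba(x):  # "generator" B -> A: shifts pixel values down
    return x - 0.5

def d_b(x):   # "discriminator" for style B: one score per image
    return np.mean(x, axis=(1, 2, 3))

def generator_losses(real_a, real_b):
    fake_b = g_ab(real_a)
    validity = mse(np.ones(len(real_a)), d_b(fake_b))  # fake B should look real to d_B
    reconstruction = mae(real_a, g_ba(fake_b))         # A -> B -> A should give A back
    identity = mae(real_b, g_ab(real_b))               # B passed through G_AB should stay B
    return validity, reconstruction, identity

real_a = np.zeros((1, 4, 4, 3))
real_b = np.full((1, 4, 4, 3), 0.5)
validity, reconstruction, identity = generator_losses(real_a, real_b)
print(validity, reconstruction, identity)
```

With these toy functions the reconstruction term is zero because `g_ba` exactly undoes `g_ab`, while the identity term is not, because `g_ab` still shifts an image that is already in style B.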

Unlike other style transfer methods such as pix2pix, a CycleGAN does not need paired images for training (i.e. a dataset containing original images and the same images with a new style applied). This is possible because it is made of 4 models (2 generators and 2 discriminators) that are trained in an adversarial way:
- A generator G_AB able to convert images from style A to style B
- A generator G_BA able to convert images from style B to style A
- A discriminator d_A able to distinguish real images with style A from images generated by G_BA
- A discriminator d_B able to distinguish real images with style B from images generated by G_AB
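
As a rough sketch of what each discriminator is trained to do: real images get a target of 1, generated images get a target of 0, and the two losses are averaged. The least-squares (MSE) form matches the `mse` loss used later in this notebook; the function name `discriminator_loss` and the toy scores are assumptions for illustration.

```python
import numpy as np

def discriminator_loss(pred_real, pred_generated):
    """Least-squares discriminator loss: real -> 1, generated -> 0, then averaged."""
    loss_real = np.mean((pred_real - 1.0) ** 2)
    loss_generated = np.mean(pred_generated ** 2)
    return float(0.5 * (loss_real + loss_generated))

# A perfect discriminator scores real images 1 and generated images 0
perfect = discriminator_loss(np.ones(4), np.zeros(4))

# An undecided discriminator answers 0.5 for everything
undecided = discriminator_loss(np.full(4, 0.5), np.full(4, 0.5))

print(perfect, undecided)
```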

Note: unlike in other GANs, the discriminators do not output a single binary value. Instead, they return a tensor indicating whether each part of the sample has the correct style. This comes from an architecture called PatchGAN, where the discriminator divides each image into patches and guesses whether each patch is real or generated. In a CycleGAN, this helps the discriminator focus on the style rather than the content.
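
To get a feel for the size of that output tensor, here is a small sketch that computes the patch grid produced by a stack of stride-2 convolutions with "same" padding (each one halves the height and width, rounding up). The helper name `patch_grid` is hypothetical; the example values assume a 1072 x 1920 input and three stride-2 layers, as in the discriminator built later in this notebook.

```python
import math

def patch_grid(height, width, n_stride2_layers):
    """Grid size after `n_stride2_layers` stride-2 convolutions with "same" padding."""
    for _ in range(n_stride2_layers):
        height = math.ceil(height / 2)
        width = math.ceil(width / 2)
    return height, width

grid = patch_grid(1072, 1920, 3)
print(grid)  # (134, 240)
```

Each of the 134 x 240 cells holds one real/generated score, so the training targets must have that same spatial shape.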

Note: This type of GAN requires a dataset of images to transform and a dataset of style images. Applying the style of a single image to another single image can be done using Neural Style Transfer.

It is straightforward to create a CycleGAN using the Keras Model Subclassing API, as shown in this [TensorFlow tutorial](https://www.tensorflow.org/tutorials/generative/cyclegan).

Let's build a CycleGAN "from scratch" to have a better understanding!

## Hand Made CycleGAN

In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import pickle

import tensorflow as tf
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, UpSampling2D, Activation, Dropout, LayerNormalization, Concatenate, LeakyReLU, Add
from keras.optimizers import Adam
from keras.callbacks import Callback, LearningRateScheduler
from keras.callbacks import ModelCheckpoint
from keras.utils import plot_model

# Clear TensorFlow session
K.clear_session()

# Disable eager execution
# from tensorflow.python.framework.ops import disable_eager_execution
# disable_eager_execution()

# Tensorflow debugging
# tf.debugging.enable_check_numerics()

import matplotlib.pyplot as plt

### Load and Visualize Data

In [None]:
img_A = np.asarray(Image.open("inferno.png"))

print(type(img_A))
print(img_A.shape)

# Crop the last 8 rows so that the height is a multiple of 16
img_A = img_A[:-8]
print(img_A.shape)

plt.imshow(img_A)
plt.axis("off")

In [None]:
img_B = np.asarray(Image.open("italy.jpg"))

print(type(img_B))
print(img_B.shape)

# Crop the last 8 rows so that the height is a multiple of 16
img_B = img_B[:-8]
print(img_B.shape)

plt.imshow(img_B)
plt.axis("off")

In [None]:
IMG_WIDTH = 1920
IMG_HEIGHT = 1072 # 1072 is a multiple of 16 which is required for this specific U-Net architecture
IMG_DEPTH = 3

In [None]:
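# Add a batch dimension and scale pixels from [0, 255] to [-1, 1] to match the generators' tanh output
X_train_A = np.array([img_A], dtype="float32") / 127.5 - 1.0
X_train_B = np.array([img_B], dtype="float32") / 127.5 - 1.0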

In [None]:
print(X_train_A.shape)

### Generator (U-Net)

In [None]:
class Generator_U():

    def __init__(self, input_dim, generator_n_filters, model_name):
        self.input_dim = input_dim
        self.generator_n_filters = generator_n_filters

        def downsample(layer_input, filters, kernel_size, name_suffix):
            d = Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       strides=2,
                       padding="same",
                       name="generator_conv_" + name_suffix)(layer_input)
            d = LayerNormalization(axis=-1,
                                   center=False,
                                   scale=False,
                                   name="generator_instance_normalization_" + name_suffix)(d)
            d = Activation(activation="relu",
                           name="generator_activation_" + name_suffix)(d)
            return d
        
        def upsample(layer_input, layer_skip_input, filters, kernel_size, dropout_rate, name_suffix):
            u = UpSampling2D(size=2,
                             name="generator_upsampling_" + name_suffix)(layer_input)
            u = Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       strides=1,
                       padding="same",
                       name="generator_conv_" + name_suffix)(u)
            u = LayerNormalization(axis=-1,
                                   center=False,
                                   scale=False,
                                   name="generator_instance_normalization_" + name_suffix)(u)
            u = Activation(activation="relu",
                           name="generator_activation_" + name_suffix)(u)
            if dropout_rate:
                u = Dropout(rate=dropout_rate,
                            name="generator_dropout_" + name_suffix)(u)
                
            u = Concatenate(name="generator_concatenate_" + name_suffix)([u, layer_skip_input])
            return u
        
        # Input
        self.input = Input(shape=self.input_dim,
                           name="generator_input")

        # Downsampling
        d1 = downsample(self.input, self.generator_n_filters, 4, "1")
        d2 = downsample(d1, self.generator_n_filters*2, 4, "2")
        d3 = downsample(d2, self.generator_n_filters*4, 4, "3")
        d4 = downsample(d3, self.generator_n_filters*8, 4, "4")

        # Upsampling
        u1 = upsample(d4, d3, self.generator_n_filters*4, 4, 0, "5")
        u2 = upsample(u1, d2, self.generator_n_filters*2, 4, 0, "6")
        u3 = upsample(u2, d1, self.generator_n_filters, 4, 0, "7")
        u4 = UpSampling2D(size=2,
                          name="generator_upsampling_8")(u3)
        
        # Output
        self.output = Conv2D(filters=self.input_dim[2],
                             kernel_size=4,
                             strides=1,
                             padding="same",
                             activation="tanh",
                             name="generator_output")(u4)
        
        # Model
        self.model = Model(self.input, self.output, name=model_name)
        
    def summary(self):
        self.model.summary()

    def predict(self, x):
        return self.model.predict(x)

### Generator (ResNet)

In [None]:
class Generator_R():

    def __init__(self, input_dim, generator_n_filters, model_name):
        self.input_dim = input_dim
        self.generator_n_filters = generator_n_filters

        def downsample(layer_input, filters, kernel_size, name_suffix):
            d = Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       strides=2,
                       padding="same",
                       name="generator_down_sample_conv_" + name_suffix)(layer_input)
            d = LayerNormalization(axis=-1,
                                   center=False,
                                   scale=False,
                                   name="generator_down_sample_instance_normalization_" + name_suffix)(d)
            d = Activation(activation="relu",
                           name="generator_down_sample_activation_" + name_suffix)(d)
            return d
        
        def residual(layer_input, filters, name_suffix):
            shortcut = layer_input
            y = Conv2D(filters=filters,
                       kernel_size=(3,3),
                       strides=1,
                       padding="same",
                       name="generator_residual_conv_1_" + name_suffix)(layer_input)
            y = LayerNormalization(axis=-1,
                                   center=False,
                                   scale=False,
                                   name="generator_residual_instance_normalization_1_" + name_suffix)(y)
            y = Activation(activation="relu",
                           name="generator_residual_activation_" + name_suffix)(y)
            y = Conv2D(filters=filters,
                       kernel_size=(3,3),
                       strides=1,
                       padding="same",
                       name="generator_residual_conv_2_" + name_suffix)(y)
            y = LayerNormalization(axis=-1,
                                   center=False,
                                   scale=False,
                                   name="generator_residual_instance_normalization_2_" + name_suffix)(y)
            return Add()([shortcut, y])
        
        def upsample(layer_input, layer_skip_input, filters, kernel_size, dropout_rate, name_suffix):
            u = UpSampling2D(size=2,
                             name="generator_up_sample_upsampling_" + name_suffix)(layer_input)
            u = Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       strides=1,
                       padding="same",
                       name="generator_up_sample_conv_" + name_suffix)(u)
            u = LayerNormalization(axis=-1,
                                   center=False,
                                   scale=False,
                                   name="generator_up_sample_instance_normalization_" + name_suffix)(u)
            u = Activation(activation="relu",
                           name="generator_up_sample_activation_" + name_suffix)(u)
            if dropout_rate:
                u = Dropout(rate=dropout_rate,
                            name="generator_up_sample_dropout_" + name_suffix)(u)
                
            u = Concatenate(name="generator_up_sample_concatenate_" + name_suffix)([u, layer_skip_input])
            return u
        
        # Input
        self.input = Input(shape=self.input_dim,
                           name="generator_input")

        # Downsampling
        d1 = downsample(self.input, self.generator_n_filters, 4, "1")
        d2 = downsample(d1, self.generator_n_filters*2, 4, "2")
        d3 = downsample(d2, self.generator_n_filters*4, 4, "3")
        d4 = downsample(d3, self.generator_n_filters*8, 4, "4")

        # Residual
        r1 = residual(d4, self.generator_n_filters*8, "1")
        r2 = residual(r1, self.generator_n_filters*8, "2")
        r3 = residual(r2, self.generator_n_filters*8, "3")
        r4 = residual(r3, self.generator_n_filters*8, "4")
        r5 = residual(r4, self.generator_n_filters*8, "5")
        r6 = residual(r5, self.generator_n_filters*8, "6")

        # Upsampling (starts from the last residual block so the residual path is actually used)
        u1 = upsample(r6, d3, self.generator_n_filters*4, 4, 0, "5")
        u2 = upsample(u1, d2, self.generator_n_filters*2, 4, 0, "6")
        u3 = upsample(u2, d1, self.generator_n_filters, 4, 0, "7")
        u4 = UpSampling2D(size=2,
                          name="generator_upsampling_8")(u3)
        
        # Output
        self.output = Conv2D(filters=self.input_dim[2],
                             kernel_size=4,
                             strides=1,
                             padding="same",
                             activation="tanh",
                             name="generator_output")(u4)
        
        # Model
        self.model = Model(self.input, self.output, name=model_name)
        
    def summary(self):
        self.model.summary()

    def predict(self, x):
        return self.model.predict(x)

### Discriminator

In [None]:
class Discriminator():

    def __init__(self, input_dim, discriminator_n_filters, model_name):
        self.input_dim = input_dim
        self.discriminator_n_filters = discriminator_n_filters

        def conv4(layer_input, filters, strides, instance_normalisation, name_suffix):
            y = Conv2D(filters=filters,
                       kernel_size=4,
                       strides=strides,
                       padding="same",
                       name="discriminator_conv_" + name_suffix)(layer_input)
            
            if instance_normalisation:
                y = LayerNormalization(axis=-1,
                                       center=False,
                                       scale=False,
                                       name="discriminator_instance_normalization_" + name_suffix)(y)
                
            y = LeakyReLU(alpha=0.2,
                          name="discriminator_leaky_relu_" + name_suffix)(y)
            return y
        
        # Input
        self.input = Input(shape=self.input_dim)

        # Layers
        y = conv4(self.input, self.discriminator_n_filters, 2, False, "1")
        y = conv4(y, self.discriminator_n_filters*2, 2, True, "2")
        y = conv4(y, self.discriminator_n_filters*4, 2, True, "3")
        y = conv4(y, self.discriminator_n_filters*8, 1, True, "4")

        # Output
        self.output = Conv2D(filters=1,
                             kernel_size=4,
                             strides=1,
                             padding="same",
                             name="discriminator_output")(y)
        
        # Model
        self.model = Model(self.input, self.output, name=model_name)

    def summary(self):
        self.model.summary()

    def predict(self, x):
        return self.model.predict(x)

### CycleGAN

In [None]:
class CycleGAN():

    def __init__(self, input_dim, generator_n_filtres, discriminator_n_filters, learning_rate, lambda_validation, lambda_reconstruction, lambda_identity):
        self.input_dim = input_dim
        self.generator_n_filtres = generator_n_filtres
        self.discriminator_n_filters = discriminator_n_filters
        self.learning_rate = learning_rate
        self.lambda_validation = lambda_validation
        self.lambda_reconstruction = lambda_reconstruction
        self.lambda_identity = lambda_identity

        self.epoch = 0

        self.discriminator_A_loss = []
        self.discriminator_B_loss = []
        self.discriminator_loss = []
        self.global_loss = []

        # Discriminators
        self.discriminator_A = Discriminator(self.input_dim, self.discriminator_n_filters, model_name="discriminator_A")
        self.discriminator_A.model.compile(loss="mse",
                                           optimizer=Adam(learning_rate=self.learning_rate, beta_1=0.5),
                                           metrics=["accuracy"])
        self.discriminator_B = Discriminator(self.input_dim, self.discriminator_n_filters, model_name="discriminator_B")
        self.discriminator_B.model.compile(loss="mse",
                                           optimizer=Adam(learning_rate=self.learning_rate, beta_1=0.5),
                                           metrics=["accuracy"])
        
        # Generators
        # self.generator_AB = Generator_U(self.input_dim, self.generator_n_filtres, model_name="generator_AB")
        # self.generator_BA = Generator_U(self.input_dim, self.generator_n_filtres, model_name="generator_BA")
        self.generator_AB = Generator_R(self.input_dim, self.generator_n_filtres, model_name="generator_AB")
        self.generator_BA = Generator_R(self.input_dim, self.generator_n_filtres, model_name="generator_BA")
        
        self.discriminator_A.model.trainable = False
        self.discriminator_B.model.trainable = False

        self.real_A = Input(shape=self.input_dim, name="real_A")
        self.real_B = Input(shape=self.input_dim, name="real_B")

        self.generated_A = self.generator_BA.model(self.real_B)
        self.generated_B = self.generator_AB.model(self.real_A)

        self.valid_A = self.discriminator_A.model(self.generated_A)
        self.valid_B = self.discriminator_B.model(self.generated_B)

        self.reconstruct_A = self.generator_BA.model(self.generated_B)
        self.reconstruct_B = self.generator_AB.model(self.generated_A)

        self.identity_A = self.generator_BA.model(self.real_A)
        self.identity_B = self.generator_AB.model(self.real_B)

        self.model = Model(inputs=[self.real_A, self.real_B],
                           outputs=[self.valid_A, self.valid_B,
                                    self.reconstruct_A, self.reconstruct_B,
                                    self.identity_A, self.identity_B])
        self.model.compile(loss=["mse", "mse",
                                 "mae", "mae",
                                 "mae", "mae"],
                           loss_weights=[self.lambda_validation, self.lambda_validation,
                                         self.lambda_reconstruction, self.lambda_reconstruction,
                                         self.lambda_identity, self.lambda_identity],
                           optimizer=Adam(learning_rate=self.learning_rate, beta_1=0.5))
        
        self.discriminator_A.model.trainable = True
        self.discriminator_B.model.trainable = True

    def summary(self):
        self.model.summary()

    def plot_model(self, run_folder):
        plot_model(self.model, to_file=os.path.join(run_folder ,'viz/model.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.discriminator_A.model, to_file=os.path.join(run_folder ,'viz/discriminator_A.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.discriminator_B.model, to_file=os.path.join(run_folder ,'viz/discriminator_B.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.generator_AB.model, to_file=os.path.join(run_folder ,'viz/generator_AB.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.generator_BA.model, to_file=os.path.join(run_folder ,'viz/generator_BA.png'), show_shapes=True, show_layer_names=True)

    def fit(self, X_train_A, X_train_B, batch_size, epochs):
        # The discriminator has 3 stride-2 convolutions, so each side is divided by 2**3
        patch_height = self.input_dim[0] // 2**3
        patch_width = self.input_dim[1] // 2**3
        patch_dim = (patch_height, patch_width, 1)
        print(f"patch_dim = {patch_dim}")

        real = np.ones((batch_size,) + patch_dim)       # One response per patch
        generated = np.zeros((batch_size,) + patch_dim) # One response per patch

        for epoch in range(self.epoch, epochs):
            generated_A = self.generator_BA.predict(X_train_B)
            generated_B = self.generator_AB.predict(X_train_A)

            # Train discriminator A
            discriminator_A_loss_real = self.discriminator_A.model.train_on_batch(X_train_A, real)
            discriminator_A_loss_generated = self.discriminator_A.model.train_on_batch(generated_A, generated)
            discriminator_A_loss = 0.5 * np.add(discriminator_A_loss_real, discriminator_A_loss_generated)
            self.discriminator_A_loss.append([discriminator_A_loss_real, discriminator_A_loss_generated, discriminator_A_loss])

            # Train discriminator B
            discriminator_B_loss_real = self.discriminator_B.model.train_on_batch(X_train_B, real)
            discriminator_B_loss_generated = self.discriminator_B.model.train_on_batch(generated_B, generated)
            discriminator_B_loss = 0.5 * np.add(discriminator_B_loss_real, discriminator_B_loss_generated)
            self.discriminator_B_loss.append([discriminator_B_loss_real, discriminator_B_loss_generated, discriminator_B_loss])

            discriminator_loss = 0.5 * np.add(discriminator_A_loss, discriminator_B_loss)
            self.discriminator_loss.append(discriminator_loss)

            # Train generators
            g_loss = self.model.train_on_batch([X_train_A, X_train_B],
                                               [real, real,
                                                X_train_A, X_train_B,
                                                X_train_A, X_train_B])
            self.global_loss.append(g_loss)
            
            self.epoch += 1

    def load_weights(self, filepath="model/weights/model.weights.h5"):
        self.model.load_weights(filepath)

    def predict(self, x):
        # Translate images from style A to style B
        return self.generator_AB.predict(x)

    def save(self, folder="model"):
        # Create the output folders if needed
        for subfolder in ("viz", "weights", "images"):
            os.makedirs(os.path.join(folder, subfolder), exist_ok=True)

        # Save the constructor parameters
        with open(os.path.join(folder, "weights/params.pkl"), "wb") as f:
            pickle.dump([self.input_dim,
                         self.generator_n_filtres,
                         self.discriminator_n_filters,
                         self.learning_rate,
                         self.lambda_validation,
                         self.lambda_reconstruction,
                         self.lambda_identity], f)

        # Save the weights where load_weights() expects them
        self.model.save_weights(os.path.join(folder, "weights/model.weights.h5"))
        self.plot_model(folder)

In [None]:
cycle_gan = CycleGAN(input_dim=(IMG_HEIGHT,IMG_WIDTH,IMG_DEPTH),
                     generator_n_filtres=32,
                     discriminator_n_filters=32,
                     learning_rate=0.0002,
                     lambda_validation=1,
                     lambda_reconstruction=10,
                     lambda_identity=2)

In [None]:
cycle_gan.plot_model("model")

In [None]:
cycle_gan.discriminator_A.summary()

In [None]:
# cycle_gan.discriminator_B.summary()

In [None]:
cycle_gan.generator_AB.summary()

In [None]:
# cycle_gan.generator_BA.summary()

In [None]:
cycle_gan.summary()

### Train

In [None]:
BATCH_SIZE = 1
EPOCHS = 200

In [None]:
cycle_gan.fit(X_train_A, X_train_B, BATCH_SIZE, EPOCHS)

### Evaluate CycleGAN

In [None]:
# Plot discriminator A losses
# Each row is [real, generated, mean] and each entry is [loss, accuracy], so keep index 0
plt.plot([row[2][0] for row in cycle_gan.discriminator_A_loss], label="loss")
plt.plot([row[0][0] for row in cycle_gan.discriminator_A_loss], label="loss (real images)")
plt.plot([row[1][0] for row in cycle_gan.discriminator_A_loss], label="loss (generated images)")
plt.legend()
plt.show()

### Save model

In [None]:
cycle_gan.save()

### Load Pre-Trained Model

In [None]:
# cycle_gan.load_weights()

### Predictions
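
A generator with a `tanh` output produces values in [-1, 1], so predictions need to be mapped back to [0, 255] before they can be displayed. Here is a minimal sketch, assuming the training images were scaled to [-1, 1]; `to_uint8` is a hypothetical helper and the small array stands in for a real generator output.

```python
import numpy as np

def to_uint8(prediction):
    """Map generator output from [-1, 1] to displayable uint8 pixels in [0, 255]."""
    scaled = (np.clip(prediction, -1.0, 1.0) + 1.0) * 127.5
    return np.round(scaled).astype(np.uint8)

fake_output = np.array([[[-1.0, 0.0, 1.0]]])  # stand-in for one generated pixel
pixels = to_uint8(fake_output)
print(pixels)
```

In this notebook, something like `plt.imshow(to_uint8(cycle_gan.generator_AB.predict(X_train_A))[0])` could then be used to display the translated image.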