# CycleGAN TensorFlow Implementation

Created by Paul Scott<br><br>
References:
* https://www.tensorflow.org/tutorials/generative/cyclegan
* https://machinelearningmastery.com/cyclegan-tutorial-with-keras/
* https://www.tensorflow.org/tutorials/generative/pix2pix

# Setup

### Install Packages

In [None]:
%%capture
# tensorflow_addons provides InstanceNormalization, used by downsample()/upsample().
# NOTE(review): TensorFlow Addons is no longer maintained upstream — confirm the
# installed version is compatible with the Colab TensorFlow version.
%pip install tensorflow_addons

### Imports

In [None]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import models, layers, losses, optimizers
import tensorflow_addons.layers as tfa_layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from IPython import display
from PIL import Image
import numpy as np
import time
import shutil

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

### Helper Functions

In [None]:
def preprocess(x):
  """Scale pixel values from [0, 255] to [-1, 1] (the generators' tanh range)."""
  half_range = 127.5
  return (x - half_range) / half_range


def postprocess(x):
  """Map values from [-1, 1] back to [0, 1] so they can be shown with imshow."""
  return 0.5 + 0.5 * x


def make_dir(directory):
  """Create ``directory`` (and any missing parents) unless it already exists.

  ``exist_ok=True`` replaces a separate ``isdir`` check, so there is no race
  between checking and creating. As before, FileExistsError is raised if the
  path exists but is not a directory.
  """
  os.makedirs(directory, exist_ok=True)


def remove_dir(directory):
  """Recursively delete ``directory`` if it is an existing directory; otherwise do nothing."""
  if not os.path.isdir(directory):
    return
  shutil.rmtree(directory)

# Preprocessing

### Choose Dataset And Define Directories

In [None]:
# The dataset name encodes both domains as '<class_x>2<class_y>'.
dataset_name = 'vangogh2photo'
class_x, class_y = dataset_name.split('2')

# Every persistent artifact lives under one environment directory on Drive.
env_dir = 'drive/MyDrive/colab_envs/cycle_gan_env'
checkpoint_dir = f'{env_dir}/checkpoints/{dataset_name}'
output_images_dir = f'{env_dir}/output_images/{dataset_name}'
sample_images_dir = f'{env_dir}/sample_images'
model_diagrams_dir = f'{env_dir}/model_diagrams'
generators_dir = f'{env_dir}/generators'
metrics_dir = f'{env_dir}/metrics'

for _directory in (
    env_dir,
    checkpoint_dir,
    output_images_dir,
    sample_images_dir,
    model_diagrams_dir,
    generators_dir,
    metrics_dir,
):
  make_dir(_directory)
del _directory

### Get Dataset From Google Drive

In [None]:
# Quietly extract the dataset archive from Drive into the working directory;
# the iterators below read from f'{dataset_name}/...'.
!unzip -q drive/MyDrive/datasets/{dataset_name}.zip -d .

### Create Dataset Iterators

In [None]:
img_dim = 256
img_shape = (img_dim, img_dim, 3)

# create_generator() applies log2(img_dim) stride-2 downsamples (down to 1x1),
# so the side length must be a power of two. The discriminator's three
# stride-2 downsamples followed by 'valid' 4x4 convolutions need some spatial
# room, presumably why 32 is the minimum.
# Explicit exceptions (not `assert`) keep these checks active under `python -O`,
# and the bit test avoids comparing floating-point logarithms.
if img_dim < 32:
  raise ValueError(f'img_dim must be at least 32, got {img_dim}')
if (img_dim & (img_dim - 1)) != 0:
  raise ValueError(f'img_dim must be a power of two, got {img_dim}')

def random_jitter(image):
  """Randomly augment one training image and scale it to [-1, 1].

  Used as the ImageDataGenerator preprocessing hook. The steps are:
  1. Enlarge the image by img_dim // 8 pixels per side using area resampling.
  2. Take a random img_dim x img_dim crop.
  3. Randomly mirror the crop horizontally.
  4. Clip to [0, 255] and apply preprocess().
  """
  enlarged_dim = img_dim + img_dim // 8
  enlarged = tf.image.resize(
      image, (enlarged_dim, enlarged_dim), method=tf.image.ResizeMethod.AREA)
  cropped = tf.image.random_crop(enlarged, size=img_shape)
  flipped = tf.image.random_flip_left_right(cropped)
  clipped = tf.clip_by_value(flipped, 0, 255)
  return preprocess(clipped)

# Training iterators: one unlabeled image per batch (class_mode=None), with
# random_jitter applied to every image as it is loaded.
datagen = ImageDataGenerator(preprocessing_function=random_jitter)

_train_flow_kwargs = dict(
    target_size=(img_dim, img_dim),
    batch_size=1,
    class_mode=None,
    interpolation='box',
)
train_x = datagen.flow_from_directory(f'{dataset_name}/train_x', **_train_flow_kwargs)
train_y = datagen.flow_from_directory(f'{dataset_name}/train_y', **_train_flow_kwargs)
del _train_flow_kwargs

### Plot Dataset Samples

In [None]:
num_samples = 9

# Grid layout: each row shows class-x samples in its left half and class-y
# samples in its right half.
num_rows = int(np.ceil(np.sqrt(num_samples)))
num_cols = num_rows * 2

samples_x = [train_x.next()[0] for _ in range(num_samples)]
samples_y = [train_y.next()[0] for _ in range(num_samples)]

# No fig.tight_layout() here: called before any subplot exists it had nothing
# to lay out, so the original call had no effect.
fig = plt.figure(figsize=(20, 10))
fig.subplots_adjust(top=0.9)
fig.suptitle(f'{class_x.capitalize()} to {class_y.capitalize()}', fontsize=24)

for row in range(num_rows):
  for col in range(num_cols):
    is_x = col < num_rows
    sample_idx = row * num_rows + (col if is_x else col - num_rows)
    # When num_samples is not a perfect square the last row is only partly
    # filled; leave those cells empty instead of raising IndexError.
    if sample_idx >= num_samples:
      continue
    sample = samples_x[sample_idx] if is_x else samples_y[sample_idx]
    fig.add_subplot(num_rows, num_cols, row * num_cols + col + 1)
    plt.title(class_x if is_x else class_y)
    plt.imshow(postprocess(sample))
    plt.axis('off')

plt.show()

# Create Models

### Define Model Building Functions

In [None]:
def create_generator():
  """Build a U-Net style generator mapping an img_shape image to an img_shape image.

  Encoder: log2(img_dim) stride-2 downsample blocks take the feature map down
  to 1x1. The first block has no normalization.

  Decoder: upsample blocks mirror the encoder. Each one is concatenated with
  the encoder output of the same resolution, and the first
  max(1, log2(img_dim) - 5) of them use dropout.

  Output: a stride-2 transposed conv with tanh gives 3 channels in [-1, 1].
  Filter counts start at img_dim // 4, double per level and are capped at 512.
  """
  init = tf.random_normal_initializer(0, 0.02)
  base_filters = img_dim // 4
  depth = int(np.log2(img_dim))
  dropout_blocks = max(1, depth - 5)

  def filters_for(level):
    return min(base_filters * 2 ** level, 512)

  inputs = keras.Input(shape=img_shape)

  # Encoder: keep each block's output for the skip connections.
  encoder_outputs = []
  x = inputs
  for level in range(depth):
    x = downsample(x, filters_for(level), batch_norm=level > 0)
    encoder_outputs.append(x)

  # Decoder: walk the encoder outputs from the second-deepest back to the first.
  for step, level in enumerate(range(depth - 2, -1, -1)):
    x = upsample(x, filters_for(level), dropout=step < dropout_blocks)
    x = layers.Concatenate()((x, encoder_outputs[level]))

  outputs = layers.Conv2DTranspose(
      3, 4, 2, 'same', kernel_initializer=init, activation='tanh')(x)

  return keras.Model(inputs=inputs, outputs=outputs)


def create_discriminator():
  """Build a discriminator that outputs a single-channel map of per-patch logits.

  The network runs in this order:
  1. Three stride-2 downsample blocks with img_dim//4, img_dim//2 and img_dim
     filters; the first block has no normalization.
  2. Zero padding, then a stride-1 'valid' downsample block with 2*img_dim
     filters.
  3. Zero padding, then a 4x4 'valid' Conv2D to 1 channel.

  The outputs are raw logits (the losses use from_logits=True). For
  img_dim=256 the map is 30x30.
  """
  init = tf.random_normal_initializer(0, 0.02)
  base_filters = img_dim // 4

  inputs = keras.Input(shape=img_shape)

  x = inputs
  for level, multiplier in enumerate((1, 2, 4)):
    x = downsample(x, base_filters * multiplier, batch_norm=level > 0)

  x = layers.ZeroPadding2D()(x)
  x = downsample(x, base_filters * 8, strides=1, padding='valid')
  x = layers.ZeroPadding2D()(x)

  outputs = layers.Conv2D(1, 4, 1, kernel_initializer=init)(x)
  return keras.Model(inputs=inputs, outputs=outputs)


def downsample(x, filters, strides=2, padding='same', batch_norm=True):
  """Apply a 4x4 Conv2D (no bias), optional normalization, then LeakyReLU.

  Note: despite its name, ``batch_norm`` toggles *instance* normalization
  (tensorflow_addons InstanceNormalization).
  """
  conv = layers.Conv2D(
      filters, 4, strides, padding,
      kernel_initializer=tf.random_normal_initializer(0, 0.02),
      use_bias=False,
  )
  out = conv(x)
  if batch_norm:
    out = tfa_layers.InstanceNormalization()(out)
  return layers.LeakyReLU()(out)


def upsample(x, filters, dropout=False):
  """Apply a stride-2 4x4 Conv2DTranspose (no bias), then normalization, optional dropout and ReLU.

  Normalization is InstanceNormalization. When ``dropout`` is set, Dropout(0.5)
  is applied before the ReLU.
  """
  deconv = layers.Conv2DTranspose(
      filters, 4, 2, 'same',
      kernel_initializer=tf.random_normal_initializer(0, 0.02),
      use_bias=False,
  )
  out = deconv(x)
  out = tfa_layers.InstanceNormalization()(out)
  if dropout:
    out = layers.Dropout(0.5)(out)
  return layers.ReLU()(out)


def plot_model(model, output_file):
  """Render ``model``'s architecture diagram to ``output_file`` and return keras' result.

  The diagram includes tensor shapes and layer activations.
  """
  diagram_options = dict(show_shapes=True, show_layer_activations=True)
  return keras.utils.plot_model(model, to_file=output_file, **diagram_options)

### Plot Generator Architecture

In [None]:
# Build a throwaway generator and write its architecture diagram to a PNG.
# As the cell's last expression, the returned value is shown as cell output.
plot_model(create_generator(), f'{model_diagrams_dir}/generator_{img_dim}.png')

### Plot Discriminator Architecture

In [None]:
# Build a throwaway discriminator and write its architecture diagram to a PNG.
# As the cell's last expression, the returned value is shown as cell output.
plot_model(create_discriminator(), f'{model_diagrams_dir}/discriminator_{img_dim}.png')

# Train CycleGAN
$\mathcal{L}_{G_{XY}} = BCE(\textbf{1}, D_Y(G_{XY}(x))) + \mathcal{L}_{cycle} + \mathcal{L}_{id_{Y}}$

$\mathcal{L}_{G_{YX}} = BCE(\textbf{1}, D_X(G_{YX}(y))) + \mathcal{L}_{cycle} + \mathcal{L}_{id_{X}}$

$\mathcal{L}_{D_X} = \frac{1}{2} (BCE(\textbf{1}, D_X(x)) + BCE(\textbf{0}, D_X(G_{YX}(y))))$

$\mathcal{L}_{D_Y} = \frac{1}{2} (BCE(\textbf{1}, D_Y(y)) + BCE(\textbf{0}, D_Y(G_{XY}(x))))$

$\mathcal{L}_{cycle} = \lambda * (mean(|x - G_{YX}(G_{XY}(x))|) + mean(|y - G_{XY}(G_{YX}(y))|))$

$\mathcal{L}_{id_{X}} = \frac{\lambda}{2} * mean(|x - G_{YX}(x)|)$

$\mathcal{L}_{id_{Y}} = \frac{\lambda}{2} * mean(|y - G_{XY}(y)|)$

### Define Training Functions

In [None]:
def train(num_epochs, restore_epoch=0, checkpoint_freq=5):
  """Run CycleGAN training for epoch indices restore_epoch .. num_epochs - 1.

  Args:
    num_epochs: Epoch index at which training stops (exclusive).
    restore_epoch: Number of epochs already completed. Pass this when resuming
      from a restored checkpoint so epoch numbering continues from there.
    checkpoint_freq: Save a checkpoint every this many epochs. A checkpoint is
      also saved after the final epoch, so the last epochs are never lost when
      num_epochs is not a multiple of checkpoint_freq.

  Side effects:
    - Saves the current display samples to sample_images_dir.
    - Saves checkpoints through checkpoint_manager.
    - Saves one progress figure per epoch to output_images_dir.
  """
  np.save(f'{sample_images_dir}/{dataset_name}_x.npy', samples_x)
  np.save(f'{sample_images_dir}/{dataset_name}_y.npy', samples_y)

  # One epoch = one pass over the domain with fewer batches. The value does
  # not change between epochs, so compute it once.
  iterations = min(len(train_x), len(train_y))

  avg_time_per_epoch = 0
  for epoch in range(restore_epoch, num_epochs):
    start = time.time()

    # train cycle gan
    for i in range(iterations):
      print(f'\rEpoch {epoch+1} Progress: {i+1}/{iterations}', end='')
      real_x = train_x.next()
      real_y = train_y.next()
      train_step(real_x, real_y)

    # save checkpoint periodically, and always after the final epoch
    is_last_epoch = epoch + 1 == num_epochs
    if (epoch + 1) % checkpoint_freq == 0 or is_last_epoch:
      checkpoint_manager.save()

    # display training progress
    generate_and_plot_images(f'Epoch {epoch+1}', plot_save_name=f'image_at_epoch_{(epoch+1):04d}')
    avg_time_per_epoch = print_time_string(avg_time_per_epoch, start, epoch, num_epochs, restore_epoch)


@tf.function
def train_step(real_x, real_y):
  """Run one CycleGAN optimization step on one batch from each domain.

  All four losses (two generators, two discriminators) are computed under a
  single persistent GradientTape. Each loss's gradients are then applied to
  its own model with its own Adam optimizer.

  Args:
    real_x: batch of domain-X images. From the training iterators this is
      presumably shape (1, img_dim, img_dim, 3), scaled to [-1, 1] by
      preprocess() inside random_jitter — TODO confirm for other callers.
    real_y: batch of domain-Y images, same conventions as real_x.
  """
  # persistent=True because tape.gradient() is called four times below.
  with tf.GradientTape(persistent=True) as tape:
    
    # cycle real x
    fake_y = generator_xy(real_x, training=True)
    cycled_x = generator_yx(fake_y, training=True)

    # cycle real y
    fake_x = generator_yx(real_y, training=True)
    cycled_y = generator_xy(fake_x, training=True)

    # identities: each generator fed an image already in its target domain;
    # identity_loss penalizes it for changing the image
    same_x = generator_yx(real_x, training=True)
    same_y = generator_xy(real_y, training=True)

    # discriminator on real images
    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)
    
    # discriminator on fake images; fakes are not detached, but the
    # discriminator gradients below are taken only w.r.t. discriminator variables
    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # generator losses: each generator wants its fakes scored as real
    generator_xy_loss = generator_loss(disc_fake_y)
    generator_yx_loss = generator_loss(disc_fake_x)

    # cycle loss (shared by both generators)
    cycle_loss = cycle_consistency_loss(real_x, cycled_x) + cycle_consistency_loss(real_y, cycled_y)

    # total generator losses
    generator_xy_loss += cycle_loss + identity_loss(real_y, same_y)
    generator_yx_loss += cycle_loss + identity_loss(real_x, same_x)

    # total discriminator losses
    discriminator_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    discriminator_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Each loss is differentiated only w.r.t. the variables of the model it trains.
  generator_xy_gradients = tape.gradient(generator_xy_loss, generator_xy.trainable_variables)
  generator_yx_gradients = tape.gradient(generator_yx_loss, generator_yx.trainable_variables)
  discriminator_x_gradients = tape.gradient(discriminator_x_loss, discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(discriminator_y_loss, discriminator_y.trainable_variables)

  generator_xy_optimizer.apply_gradients(zip(generator_xy_gradients, generator_xy.trainable_variables))
  generator_yx_optimizer.apply_gradients(zip(generator_yx_gradients, generator_yx.trainable_variables))
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))


def discriminator_loss(real, fake):
  """Average BCE of scoring real-image logits as 1 and fake-image logits as 0."""
  total = bce_loss(tf.ones_like(real), real) + bce_loss(tf.zeros_like(fake), fake)
  return total / 2


def generator_loss(fake):
  """BCE between the discriminator's logits on generated images and the 'real' label (1)."""
  real_labels = tf.ones_like(fake)
  return bce_loss(real_labels, fake)


def cycle_consistency_loss(real_image, cycled_image):
  """λ-weighted mean absolute error between an image and its round-trip reconstruction."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return λ * mean_abs_error


def identity_loss(real_image, same_image):
  """(λ / 2)-weighted mean absolute error between an image and the generator's output for it."""
  half_weight = λ / 2
  return half_weight * tf.reduce_mean(tf.abs(real_image - same_image))


def generate_and_plot_images(title, plot_save_name=None):
  """Plot each sample next to its translation and its cycle reconstruction.

  Reads the module-level ``samples_x`` / ``samples_y``. Each row shows:
    - left half: x, G_XY(x), G_YX(G_XY(x))
    - right half: y, G_YX(y), G_XY(G_YX(y))
  Generators are called without ``training=True``.

  Args:
    title: Figure super-title.
    plot_save_name: If given, the figure is also saved as
      ``{output_images_dir}/{plot_save_name}.png``.
  """
  display.clear_output(wait=True)

  fake_y = generator_xy(samples_x)
  cycled_x = generator_yx(fake_y)

  fake_x = generator_yx(samples_y)
  cycled_y = generator_xy(fake_x)

  # One row per sample. The previous layout hard-coded 3 rows, so add_subplot
  # failed for any other sample count. Height scales so 3 rows stay at 10.
  num_rows = max(len(samples_x), len(samples_y))

  # No fig.tight_layout() here: called before any subplot exists it had no effect.
  fig = plt.figure(figsize=(20, 10 * num_rows / 3))
  fig.subplots_adjust(top=0.9)
  fig.suptitle(title, fontsize=24)

  x_titles = [class_x, f'{class_x} to {class_y}', f'{class_x} to {class_y} to {class_x}']
  y_titles = [class_y, f'{class_y} to {class_x}', f'{class_y} to {class_x} to {class_y}']

  # Columns 1-3: x -> y -> x chain; columns 4-6: y -> x -> y chain.
  # `panel_title` (not `title`) avoids shadowing the parameter.
  for col_offset, panel_titles, rows in (
      (0, x_titles, zip(samples_x, fake_y, cycled_x)),
      (3, y_titles, zip(samples_y, fake_x, cycled_y)),
  ):
    for row, images in enumerate(rows):
      for col, (image, panel_title) in enumerate(zip(images, panel_titles)):
        fig.add_subplot(num_rows, 6, 6 * row + col_offset + col + 1)
        plt.title(panel_title)
        plt.imshow(postprocess(image))
        plt.axis('off')

  if plot_save_name:
    plt.savefig(f'{output_images_dir}/{plot_save_name}.png', bbox_inches='tight')
  plt.show()
  

def print_time_string(avg_time_per_epoch, start, epoch, num_epochs, restore_epoch):
  """Print the duration of the epoch just finished and an estimate of the time left.

  Args:
    avg_time_per_epoch: Running mean epoch duration in seconds, over the
      epochs completed in this run before the current one.
    start: ``time.time()`` value taken at the start of the current epoch.
    epoch: 0-based index of the epoch that just finished.
    num_epochs: Epoch index at which training stops (exclusive).
    restore_epoch: Epoch index this run started from.

  Returns:
    The running mean epoch duration, updated with the current epoch.
  """
  time_for_epoch = time.time() - start
  # Epochs completed in *this* run before the current one; the running average
  # only covers epochs timed in this run.
  epoch_adj = epoch - restore_epoch
  avg_time_per_epoch = (avg_time_per_epoch * epoch_adj + time_for_epoch) / (epoch_adj + 1)
  # Remaining work is measured from the absolute epoch index. Using epoch_adj
  # here overstated the ETA by restore_epoch epochs when resuming.
  remaining_epochs = num_epochs - (epoch + 1)
  remaining_time = remaining_epochs * avg_time_per_epoch
  print(f'Time For Epoch {epoch + 1}: {get_time_string(time_for_epoch)}')
  print(f'Remaining Time: {get_time_string(remaining_time)}')
  return avg_time_per_epoch


def get_time_string(total_seconds):
  """Format a duration in seconds as e.g. '1h 2m 5.5s'.

  Hours appear only when non-zero. Minutes appear only when at least one whole
  minute remains after removing the hours. Seconds are rounded to 2 decimals.
  """
  # Round once up front so e.g. 59.999 carries into the next minute
  # ('1m 0.0s') instead of printing as '60.0s'.
  total_seconds = round(total_seconds, 2)
  hours = int(total_seconds // 3600)
  remainder = total_seconds % 3600
  minutes = int(remainder // 60)
  seconds = round(remainder % 60, 2)
  parts = []
  if hours > 0:
    parts.append(f'{hours}h')
  if remainder >= 60:
    parts.append(f'{minutes}m')
  parts.append(f'{seconds}s')
  return ' '.join(parts)

### Setup Models And Configure Training

In [None]:
# Weight of the cycle-consistency loss; the identity loss uses λ / 2.
λ = 10

# One generator per translation direction and one discriminator per domain.
generator_xy = create_generator()
generator_yx = create_generator()
discriminator_x = create_discriminator()
discriminator_y = create_discriminator()

# Each model gets its own Adam optimizer with identical hyperparameters.
(generator_xy_optimizer,
 generator_yx_optimizer,
 discriminator_x_optimizer,
 discriminator_y_optimizer) = (optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

# Discriminators output raw logits, hence from_logits=True.
bce_loss = losses.BinaryCrossentropy(from_logits=True)

# Fixed samples used to visualize progress after every epoch.
samples_x = np.stack([train_x.next()[0] for _ in range(3)])
samples_y = np.stack([train_y.next()[0] for _ in range(3)])

checkpoint = tf.train.Checkpoint(
  generator_xy=generator_xy,
  generator_yx=generator_yx,
  discriminator_x=discriminator_x,
  discriminator_y=discriminator_y,
  generator_xy_optimizer=generator_xy_optimizer,
  generator_yx_optimizer=generator_yx_optimizer,
  discriminator_x_optimizer=discriminator_x_optimizer,
  discriminator_y_optimizer=discriminator_y_optimizer,
)
# Keep only the 5 most recent checkpoints in checkpoint_dir.
checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

### Restore Latest Checkpoint
(if necessary)

In [None]:
# Resume from the most recent checkpoint, if any, and reload the fixed display
# samples that train() saved at the start of the earlier run.
# NOTE: train() starts at restore_epoch=0 unless told otherwise; pass the
# number of completed epochs when resuming.
_latest_checkpoint = checkpoint_manager.latest_checkpoint
if _latest_checkpoint:
  checkpoint.restore(_latest_checkpoint)
  samples_x = np.load(f'{sample_images_dir}/{dataset_name}_x.npy')
  samples_y = np.load(f'{sample_images_dir}/{dataset_name}_y.npy')

### Train Model

In [None]:
# Train up to epoch 100, checkpointing every 5 epochs. When resuming from a
# restored checkpoint, also pass restore_epoch=<epochs already completed>.
train(100, checkpoint_freq=5)

### Save Generators

In [None]:
# Export both trained generators, named after their translation direction.
for _generator, _path in (
    (generator_xy, f'{generators_dir}/{class_x}2{class_y}'),
    (generator_yx, f'{generators_dir}/{class_y}2{class_x}'),
):
  models.save_model(_generator, _path)
del _generator, _path

# Results

### Load Test Set

In [None]:
# Evaluation iterators use the same size and loading options as training,
# but apply only preprocess() (no random jitter).
# img_dim / img_shape are intentionally NOT redefined here. They must match
# the value the generators were built with (set in "Create Dataset Iterators"),
# and re-hard-coding 256 would break inference for any other training size.
# NOTE(review): training reads '{dataset_name}/train_x' and '.../train_y', but
# these paths use 'testX' / 'testY' — confirm against the archive layout.
datagen = ImageDataGenerator(preprocessing_function=preprocess)

_test_flow_kwargs = dict(
  target_size=(img_dim, img_dim),
  batch_size=1,
  class_mode=None,
  interpolation='box',
)
x_test = datagen.flow_from_directory(f'{dataset_name}/testX', **_test_flow_kwargs)
y_test = datagen.flow_from_directory(f'{dataset_name}/testY', **_test_flow_kwargs)
del _test_flow_kwargs

# Replace the display samples with test images for the plot below.
samples_x = np.array([x_test.next()[0] for _ in range(3)])
samples_y = np.array([y_test.next()[0] for _ in range(3)])

### Plot Images Generated From Test Set

In [None]:
# Show the test samples with their translations and cycle reconstructions
# (no plot_save_name, so the figure is not written to disk).
generate_and_plot_images('Test Set')

### Test Images From The Web

In [None]:
from PIL import Image
import requests
from io import BytesIO

# Translate an arbitrary web photo with the Y -> X generator
# (for 'vangogh2photo', photo -> Van Gogh style).
image_url = 'https://media.cntraveler.com/photos/60e612ae0a709e97d73d9c60/1:1/w_3840,h_3840,c_limit/Beach%20Vacation%20Packing%20List-2021_GettyImages-1030311160.jpg'
# The timeout keeps a stalled server from hanging the cell. raise_for_status()
# fails loudly instead of handing an HTTP error page to PIL.
response = requests.get(image_url, timeout=30)
response.raise_for_status()

# convert('RGB') guarantees 3 channels. RGBA, greyscale or palette images would
# otherwise produce an array the generator's 3-channel input cannot accept.
img = Image.open(BytesIO(response.content)).convert('RGB')
img = img.resize((img_dim, img_dim))
# Add a batch axis and scale to [-1, 1] in float32 (the model's float type).
img = preprocess(np.expand_dims(np.asarray(img, dtype=np.float32), 0))
vg_img = postprocess(np.squeeze(generator_yx(img)))

plt.figure(figsize=(8, 8))
plt.imshow(vg_img)
plt.axis('off')
plt.show()