<a href="https://colab.research.google.com/github/KoroshRH/Image-Colorizer/blob/main/Colorizer_CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/tensorflow/examples.git

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import matplotlib.pyplot as plt
from IPython.display import clear_output

Here we set the parameters of the training procedure. You can change them to see how they affect the results.

In [None]:
# Buffer size passed to Dataset.shuffle() for every dataset slice.
BUFFER_SIZE = 1000
# Number of images per batch.
BATCH_SIZE = 1
# Spatial size every image is resized to during preprocessing.
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Weight of the cycle-consistency loss (the identity loss uses half of it).
LAMBDA = 10
# Number of passes over the training data.
EPOCHS = 10
# Number of images taken from the "train" split for EACH training domain.
TRAIN_WINDOWS_SIZE = 200
# Number of images taken, right after the two training slices, for test_x.
TEST_WINDOWS_SIZE = 50000

# Preprocessing
In the next four cells, we define the preprocessing functions that resize, randomly flip, convert to grayscale, and normalize the images.

In [None]:
# normalizing the images to [-1, 1]
def normalize(image):
  """Cast `image` to float32 and rescale it with `x / 127.5 - 1`.

  For pixel values in [0, 255] the result lies in [-1, 1].
  """
  as_float = tf.cast(image, tf.float32)
  return (as_float / 127.5) - 1

In [None]:
def random_jitter(image):
  """Resize `image` to IMG_HEIGHT x IMG_WIDTH and randomly mirror it.

  Despite the name, the only random operation applied here is the
  left/right flip; no random cropping is performed.
  """
  target_size = [IMG_HEIGHT, IMG_WIDTH]
  resized = tf.image.resize(image, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # Horizontal mirroring is the only augmentation.
  return tf.image.random_flip_left_right(resized)

In [None]:
def preprocess_image(image):
  """Run the full preprocessing chain: resize + random flip, then normalize."""
  return normalize(random_jitter(image))

In [None]:
def make_grayscale(img):
  """Return a gray version of `img` that still has three channels.

  The single-channel grayscale image is repeated three times along the
  last axis, so the result keeps the same channel count as an RGB image.
  """
  single_channel = tf.image.rgb_to_grayscale(img)
  return tf.concat([single_channel] * 3, axis=-1)

# Dataset
We use the CelebA dataset for this project, but you can use any other available dataset in this section and translate the attributes of its domains to each other.

In [None]:
# Location of the CelebA data (a Google Cloud Storage path).
gcs_base_dir = "gs://celeb_a_dataset/"

# Builder for version 2.0.0 of the "celeb_a" dataset, reading from gcs_base_dir.
celeb_a_builder = tfds.builder(
    "celeb_a",
    version='2.0.0',
    data_dir=gcs_base_dir,
)
celeb_a_builder.download_and_prepare()

In [None]:
# Three disjoint, consecutive slices of the "train" split are used:
#   train_x: the first TRAIN_WINDOWS_SIZE images,
#   train_y: the next TRAIN_WINDOWS_SIZE images,
#   test_x:  the following TEST_WINDOWS_SIZE images.
# Each slice is shuffled and reduced to its "image" feature.
train_x = (
    celeb_a_builder.as_dataset(f"train[:{TRAIN_WINDOWS_SIZE}]")
    .shuffle(BUFFER_SIZE)
    .map(lambda celeb: celeb["image"])
)
train_y = (
    celeb_a_builder.as_dataset(f"train[{TRAIN_WINDOWS_SIZE}:{2 * TRAIN_WINDOWS_SIZE}]")
    .shuffle(BUFFER_SIZE)
    .map(lambda celeb: celeb["image"])
)

test_x = (
    celeb_a_builder.as_dataset(
        f"train[{2 * TRAIN_WINDOWS_SIZE}:{2 * TRAIN_WINDOWS_SIZE + TEST_WINDOWS_SIZE}]"
    )
    .shuffle(BUFFER_SIZE)
    .map(lambda celeb: celeb["image"])
)

In [None]:
# Only domain X is converted to grayscale; train_y keeps its colors, so the
# two training domains are "gray images" (X) and "colored images" (Y).
train_x = train_x.map(make_grayscale)

In [None]:
# Apply the same preprocessing (resize, random flip, normalize) and batching
# to every dataset, in the order train_x, train_y, test_x.
train_x, train_y, test_x = (
    dataset.map(preprocess_image).batch(BATCH_SIZE)
    for dataset in (train_x, train_y, test_x)
)

# Model
Here we use the U-Net generator and the discriminator from [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py), but we arrange them in a cyclic structure.

In [None]:
# Number of channels produced by each generator.
OUTPUT_CHANNELS = 3

# Two U-Net generators with instance normalization. In train_step,
# generator_g translates X -> Y and generator_f translates Y -> X.
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2)
)

# One discriminator per domain. NOTE(review): target=False presumably makes
# the discriminator take a single image instead of an (input, target) pair;
# confirm against the pix2pix module.
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
)

In [None]:
# Binary cross-entropy shared by the generator and discriminator losses.
# from_logits=True: the predictions passed to it are treated as raw logits
# rather than probabilities.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Loss functions

In [None]:
def discriminator_loss(real, generated):
  """Return the discriminator loss for one domain.

  `real` holds the discriminator outputs for real images (target: all ones),
  `generated` holds its outputs for generated images (target: all zeros).
  The two cross-entropy terms are summed and then halved.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_generated = loss_obj(tf.zeros_like(generated), generated)

  return (loss_on_real + loss_on_generated) * 0.5

In [None]:
def generator_loss(generated):
  """Return the adversarial loss of a generator.

  `generated` holds the discriminator outputs for generated images; the
  loss is the cross-entropy against an all-ones target, so it is low when
  the discriminator rates the generated images as real.
  """
  wanted_verdict = tf.ones_like(generated)
  return loss_obj(wanted_verdict, generated)

## Cycle loss
This loss function is the main idea of CycleGAN.
In CycleGAN, we want to translate the attributes of one domain to another domain while keeping the main characteristics of the source domain.
This loss function returns the difference between the original image and the cycled image, i.e. the image obtained by translating the original to the other domain and then translating it back.

In [None]:
def calc_cycle_loss(real_image, cycled_image):
  """Return the cycle-consistency loss.

  This is the mean absolute difference between `real_image` and
  `cycled_image`, scaled by LAMBDA.
  """
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_error

In [None]:
def identity_loss(real_image, same_image):
  """Return the identity loss.

  This is the mean absolute difference between `real_image` and
  `same_image`, scaled by half of LAMBDA.
  """
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mean_abs_error

# Training

In [None]:
# One independent Adam optimizer per network, all with the same settings
# (learning rate 1e-4, beta_1 = 0.5).
(generator_g_optimizer,
 generator_f_optimizer,
 discriminator_x_optimizer,
 discriminator_y_optimizer) = (
    tf.keras.optimizers.Adam(1e-4, beta_1=0.5) for _ in range(4)
)

In [None]:
@tf.function
def train_step(real_x, real_y):
  """Run one optimization step on both generators and both discriminators.

  Args:
    real_x: batch of images from domain X.
    real_y: batch of images from domain Y.

  The function returns nothing; its effect is the update of the four
  networks through their optimizers.
  """
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    # Forward cycle: X -> Y -> X.
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    # Backward cycle: Y -> X -> Y.
    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    # Discriminator outputs for real and generated images of each domain.
    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    # The cycle loss of both directions is shared by the two generators.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss + identity loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)

  # Let each optimizer apply the gradients to its own network's variables
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))

In [None]:
# Take one batch from test_x as a fixed sample for monitoring the training,
# and build its grayscale version to feed to the generator.
sample_celeb = next(iter(test_x))
gray_sample = make_grayscale(sample_celeb)

### Helper function
This function allows us to display the grayscale input, the colorized prediction, and the original picture side by side so that we can compare them.


In [None]:
def generate_images(model, test_input, ground_truth):
  """Show the input, the model's prediction and the original side by side.

  Args:
    model: callable applied to `test_input` to produce the prediction.
    test_input: batch given to the model; only its first element is shown.
    ground_truth: batch holding the original image; only its first element
      is shown.
  """
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  panels = zip(
      ['Input Image', 'Predicted Image', 'Original Colored Image'],
      [test_input[0], prediction[0], ground_truth[0]],
  )

  for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    # x * 0.5 + 0.5 maps values from [-1, 1] to [0, 1] for display.
    plt.imshow(picture * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

In [None]:
# Main training loop: one pass over the zipped (X, Y) batches per epoch,
# followed by a preview of the generator on the fixed sample.
for epoch in range(EPOCHS):
  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_x, train_y)):
    train_step(image_x, image_y)
    # Progress indicator: one dot every 10 training steps.
    if not n % 10:
      print('.', end='')
    n += 1

  # Replace the dots and the previous preview with this epoch's output.
  clear_output(wait=True)

  print(f"{epoch + 1}/{EPOCHS}")
  generate_images(generator_g, gray_sample, sample_celeb)

In [None]:
# Colorize five batches from test_x: feed the grayscale version of each
# batch to generator_g and show the result next to the original.
for test in test_x.take(5):
  gray_sample = make_grayscale(test)
  generate_images(generator_g, gray_sample, test)