<a href="https://colab.research.google.com/github/ageraustine/Day-Night-Image-CycleGANs/blob/master/Day_Night_Translator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorflow_addons

Collecting tensorflow_addons
  Downloading tensorflow_addons-0.16.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 7.2 MB/s 
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.16.1


DATA FETCHING

In [None]:
# https://www.kaggle.com/datasets/raman77768/day-time-and-night-time-road-images/download
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import numpy as np
from matplotlib import pyplot as plt
import os
import zipfile

# Signed Google Cloud Storage link to the Kaggle dataset archive (see the Kaggle page above).
# NOTE(review): the query string carries X-Goog-Date=20220502T074036Z and
# X-Goog-Expires=259199, i.e. it is a time-limited signed URL that has presumably
# expired — regenerate it from the Kaggle download page before running.
url = "https://storage.googleapis.com/kaggle-data-sets/801580/1383092/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20220502%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220502T074036Z&X-Goog-Expires=259199&X-Goog-SignedHeaders=host&X-Goog-Signature=17a60bdb553871ad7ea563dcf664f3962a19670731e9e5df7a194c8567fa33c963ccc38d93a9478204265ec6755dc05f001706d71860a8544f5e4f248cfc14d4c9113484f1dc45a7ed1781f006e00c45efecc726ed35eb614748a2fb0af1ad35a8150fcaab59b1cc9223fd4804aa55f5de2672a16e2c4140cb25dab5213670837f117a69d1d67e67e534a63285ba5dee9678bceb5c4c53f6f28c58b083079aea142c08b9e98116b7a610470710063c25f0cd205ebd043208893cb97c4e68cb9d0b5c0cbc58cce4020eb1dc8c896700cd7d7a567752c948c97db87bb7fed09c2cc3a59d171f2d4a24affb3cede1ffe7e2cf3287eef2491f6c330f476e237bc8a1"
# Local cache path for the downloaded archive. Despite the ".tar.gz" name the
# file is a zip archive (extract() opens it with zipfile.ZipFile).
data_zpfile_path = os.path.join(os.getcwd(), "data.tar.gz")

def download():
  """Fetch the dataset archive from `url` into `data_zpfile_path`, unless it is already there."""
  already_cached = os.path.exists(data_zpfile_path)
  if already_cached:
    return
  keras.utils.get_file(fname=data_zpfile_path, origin=url)

def extract(zip_path=None, dest="data"):
  """Unpack the downloaded dataset zip archive into `dest`.

  Extraction is skipped only when `dest` already exists AND is non-empty.
  Previously the directory was created first and its mere existence was the
  skip condition. A run interrupted between `os.makedirs` and the unzip
  therefore left an empty `data/` folder, and every later run silently skipped
  extraction.

  Args:
    zip_path: path of the zip archive. Defaults to the module-level
      `data_zpfile_path`, which is looked up at call time.
    dest: extraction directory, resolved against the current working
      directory unless absolute. Defaults to "data", which the later cells read.
  """
  if zip_path is None:
    zip_path = data_zpfile_path
  extpath = os.path.abspath(dest)
  if os.path.isdir(extpath) and os.listdir(extpath):
    return
  os.makedirs(extpath, exist_ok=True)
  with zipfile.ZipFile(zip_path, "r") as zp:
    zp.extractall(extpath)
   

def get_data():
  """Ensure the dataset archive is downloaded and unpacked (each step skips itself if already done)."""
  for step in (download, extract):
    step()

get_data()

DATASET PREPARATION

In [None]:
import math

def _list_images(subdir):
  """Return absolute paths of the entries of data/<subdir>, sorted by file name.

  os.listdir returns entries in arbitrary order. Because the train/test split
  below is index-based, an unsorted listing made the split differ between
  machines and filesystems; sorting makes it reproducible.
  """
  folder = os.path.join(os.getcwd(), "data", subdir)
  return [os.path.join(folder, x) for x in sorted(os.listdir(folder))]

night_images_paths = _list_images("night time road images/night time road images")
day_images_paths = _list_images("day light road images/day light road images")
# Define the standard image size.
# Training images are resized to this (height, width) before being randomly
# cropped down to `input_img_size` (see preprocess_train_image).
orig_img_size = (286, 286)
# Size of the random crops to be used during training.
# Also the (height, width, channels) shape given to layers.Input in the
# generator and discriminator builders.
input_img_size = (256, 256, 3)
# Weights initializer for the layers.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
# Gamma initializer for instance normalization.
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

# Presumably the shuffle-buffer and batch sizes for the tf.data pipelines.
# NOTE(review): as written, the dataset cells build their datasets with
# Dataset.from_tensors and never reference these two values — confirm intent.
buffer_size = 256
batch_size = 1

def normalize_img(img):
    """Cast `img` to float32 and map pixel values from [0, 255] to [-1, 1]."""
    as_float = tf.cast(img, dtype=tf.float32)
    return as_float / 127.5 - 1.0

def preprocess_train_image(img):
    """Training-time augmentation pipeline.

    Random horizontal flip, then resize to `orig_img_size`, then random crop
    to `input_img_size`, then scale pixel values to [-1, 1].
    """
    flipped = tf.image.random_flip_left_right(img)
    resized = tf.image.resize(flipped, list(orig_img_size))
    cropped = tf.image.random_crop(resized, size=list(input_img_size))
    return normalize_img(cropped)

def preprocess_test_image(img):
    """Test-time preprocessing: resize to the model's (H, W), then scale to [-1, 1].

    No augmentation is applied. `input_img_size` is (H, W, C), and
    `tf.image.resize` expects a 2-element size, so only the spatial part is
    passed. The original passed all three entries, which raises an error.
    """
    img = tf.image.resize(img, list(input_img_size[:2]))
    img = normalize_img(img)
    return img

def load_train_images(image_paths):
  """Read, decode (3-channel JPEG) and augment every file in `image_paths`.

  Returns a Python list of preprocessed image tensors, in the same order as
  `image_paths`. Every image is materialized in memory.
  """
  def _load_one(path):
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    return preprocess_train_image(decoded)

  return [_load_one(path) for path in image_paths]

def load_test_images(image_paths):
  """Read, decode (3-channel JPEG) and preprocess every file in `image_paths` for evaluation.

  Returns a Python list of image tensors (no augmentation), in the same order
  as `image_paths`.

  Fix: the original called `preprocess_test_image` on the file *path* instead
  of the decoded image, so the decoded tensor was discarded.
  """
  imgs = []
  for path in image_paths:
    raw_bytes = tf.io.read_file(path)
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    imgs.append(preprocess_test_image(image))
  return imgs

# Fraction of each domain's images used for training; the rest is the test set.
# NOTE(review): only 10% goes to training — presumably to keep the fully
# materialized training tensors small enough for memory. Confirm intent.
train_fraction = 0.1
# Use the smaller domain's size so that both domains are split at the same
# index and both test lists have the same length. The original used the night
# count for both lists, even when slicing the day list.
total_index = min(len(day_images_paths), len(night_images_paths))
split_index = math.floor(train_fraction * total_index)

# train images paths
train_day_paths = day_images_paths[:split_index]
train_night_paths = night_images_paths[:split_index]

# test images paths
test_day_paths = day_images_paths[split_index:total_index]
test_night_paths = night_images_paths[split_index:total_index]

In [None]:
train_day = load_train_images(train_day_paths)
# One dataset element per image, shuffled and batched. `from_tensors` made a
# single element containing the whole (N, 256, 256, 3) stack, so every training
# step would have processed all N images at once.
train_A_ds = (
    tf.data.Dataset.from_tensor_slices(train_day)
    .shuffle(buffer_size)
    .batch(batch_size)
)
 

In [None]:
train_night = load_train_images(train_night_paths)
# One dataset element per image, shuffled and batched. `from_tensors` made a
# single element containing the whole (N, 256, 256, 3) stack.
train_B_ds = (
    tf.data.Dataset.from_tensor_slices(train_night)
    .shuffle(buffer_size)
    .batch(batch_size)
)
train_B_ds

<TensorDataset element_spec=TensorSpec(shape=(1696, 256, 256, 3), dtype=tf.float32, name=None)>

THE BUILDING BLOCKS OF OUR CYCLEGAN

In [None]:
class ReflectionPadding(layers.Layer):
  """Reflection-pads axes 1 and 2 of a 4-D tensor via `tf.pad(..., mode="REFLECT")`.

  Args:
    padding: (width_pad, height_pad). `height_pad` pixels are added on each
      side of axis 1 and `width_pad` pixels on each side of axis 2; axes 0
      and 3 are not padded.
  """

  def __init__(self, padding=(1, 1), **kwargs):
    self.padding = tuple(padding)
    super(ReflectionPadding, self).__init__(**kwargs)

  def call(self, input_tensor, mask=None):
    padding_width, padding_height = self.padding
    padding_tensor = [
        [0, 0],
        [padding_height, padding_height],
        [padding_width, padding_width],
        [0, 0],
    ]
    return tf.pad(input_tensor, padding_tensor, mode="REFLECT")

  def get_config(self):
    """Return the layer config including `padding`.

    `__init__` takes a constructor argument, so the layer must expose it in
    its config for Keras to serialize and re-create models that contain it.
    """
    config = super(ReflectionPadding, self).get_config()
    config.update({"padding": self.padding})
    return config

def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    padding="valid",
    strides=(1, 1),
    gamma_initializer=gamma_init,
    use_bias=False):
  """Residual block: skip connection around two padded conv + instance-norm stages.

  Computes input + [reflect-pad -> Conv2D -> InstanceNorm -> activation ->
  reflect-pad -> Conv2D -> InstanceNorm](input). Both convolutions keep the
  input's channel count (`x.shape[-1]`).

  Args:
    x: input Keras tensor.
    activation: callable (e.g. a Keras layer) applied after the first norm.
    kernel_initializer, kernel_size, padding, strides, use_bias: forwarded to
      both Conv2D layers.
    gamma_initializer: forwarded to both InstanceNormalization layers.

  Returns:
    The sum of the input and the transformed tensor. The add only works when
    the convolutions preserve spatial size, as with the defaults (1-pixel
    reflection padding, 3x3 kernel, "valid" padding, stride 1).

  Fix: the first Conv2D was constructed but never applied to `x`, so
  InstanceNormalization received a layer object instead of a tensor.
  """
  dim = x.shape[-1]
  input_tensor = x
  x = ReflectionPadding()(input_tensor)
  x = layers.Conv2D(dim, kernel_size, strides=strides, padding=padding, kernel_initializer=kernel_initializer, use_bias=use_bias)(x)
  x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
  x = activation(x)
  x = ReflectionPadding()(x)
  x = layers.Conv2D(dim, kernel_size, strides=strides, padding=padding, kernel_initializer=kernel_initializer, use_bias=use_bias)(x)
  x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
  x = layers.add([input_tensor, x])
  return x

def downsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    kernel_initializer=kernel_init,
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False):
  """Strided Conv2D -> InstanceNormalization -> optional activation.

  With the default stride (2, 2) and "same" padding, the spatial size is
  roughly halved. `activation` is any callable (e.g. a Keras layer); pass a
  falsy value to skip it.
  """
  conv = layers.Conv2D(
      filters,
      kernel_size,
      strides=strides,
      kernel_initializer=kernel_initializer,
      padding=padding,
      use_bias=use_bias,
  )
  y = conv(x)
  y = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(y)
  return activation(y) if activation else y

def upsample(
        x,
        filters,
        activation,
        kernel_size=(3, 3),
        strides=(2, 2),
        kernel_initializer=kernel_init,
        padding="same",
        gamma_initializer=gamma_init,
        use_bias=False):
  """Strided Conv2DTranspose -> InstanceNormalization -> optional activation.

  With the default stride (2, 2) and "same" padding, the spatial size is
  doubled. `activation` is any callable (e.g. a Keras layer); pass a falsy
  value to skip it.
  """
  deconv = layers.Conv2DTranspose(
      filters,
      kernel_size,
      strides=strides,
      kernel_initializer=kernel_initializer,
      padding=padding,
      use_bias=use_bias,
  )
  y = deconv(x)
  y = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(y)
  return activation(y) if activation else y


In [None]:
def resnet_generator(
        filters=64,
        num_downsampling_blocks=2,
        num_upsampling_blocks=2,
        num_residual_blocks=9,
        gamma_initializer=gamma_init,
        name=None,
        ):
  """Build the ResNet-style CycleGAN generator as a functional Keras model.

  Pipeline: 7x7 conv stem, then `num_downsampling_blocks` strided convs
  (doubling filters), then `num_residual_blocks` residual blocks, then
  `num_upsampling_blocks` transposed convs (halving filters), then a 7x7 conv
  to 3 channels and tanh.

  Input shape is `input_img_size`. Output values lie in [-1, 1], the same
  range `normalize_img` maps inputs to.

  Fix: the defaults for `num_upsampling_blocks` (was 9) and
  `num_residual_blocks` (was 2) were swapped. Upsampling must undo the
  downsampling so the output matches the input size, which the cycle and
  identity losses compare. Halving 256 filters 9 times also reaches 0 filters.
  """
  img_input = layers.Input(shape=input_img_size, name=name)
  x = ReflectionPadding(padding=(3, 3))(img_input)
  x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(x)
  x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
  x = layers.Activation("relu")(x)

  # Downsampling: each block halves H and W and doubles the channel count.
  for _ in range(num_downsampling_blocks):
    filters *= 2
    x = downsample(x, filters=filters, activation=layers.Activation("relu"))

  # Residual blocks at the bottleneck resolution (channel count unchanged).
  for _ in range(num_residual_blocks):
    x = residual_block(x, activation=layers.Activation("relu"))

  # Upsampling: each block doubles H and W and halves the channel count.
  for _ in range(num_upsampling_blocks):
    filters //= 2
    x = upsample(x, filters=filters, activation=layers.Activation("relu"))

  # Final block: project back to 3 channels and squash to [-1, 1].
  x = ReflectionPadding(padding=(3, 3))(x)
  x = layers.Conv2D(3, (7, 7), padding="valid")(x)
  x = layers.Activation("tanh")(x)

  model = keras.models.Model(img_input, x, name=name)
  return model

In [None]:
def get_discriminator(filters=64, num_downsampling_blocks=3, kernel_initializer=kernel_init, name=None):
  """Build the discriminator as a functional Keras model.

  A 4x4 stride-2 conv stem is followed by `num_downsampling_blocks` conv
  blocks that double the filters; only the first two are strided. A final
  1-channel conv then produces a spatial map of real/fake scores rather than a
  single scalar.

  Args:
    filters: filter count of the stem conv.
    num_downsampling_blocks: number of conv + instance-norm + LeakyReLU blocks.
    kernel_initializer: initializer for the stem and output convs.
    name: model name; the input layer is named "<name>_img_input".

  Fix: the Input layer was given `image=` (not a valid argument, raises
  TypeError), and the input name was built as `name + ...` even when `name`
  is None.
  """
  input_name = name + "_img_input" if name else None
  img_input = layers.Input(shape=input_img_size, name=input_name)
  x = layers.Conv2D(filters, (4, 4), strides=(2, 2), padding="same", kernel_initializer=kernel_initializer)(img_input)
  x = layers.LeakyReLU(0.2)(x)

  num_filters = filters
  for block_num in range(num_downsampling_blocks):
    num_filters *= 2
    # The first two blocks halve the resolution; later blocks keep it.
    strides = (2, 2) if block_num < 2 else (1, 1)
    x = downsample(x, num_filters, activation=layers.LeakyReLU(0.2), kernel_size=(4, 4), strides=strides)
  x = layers.Conv2D(1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer)(x)

  model = keras.models.Model(inputs=img_input, outputs=x, name=name)
  return model

# Generators: in CycleGAN.train_step, G maps X -> Y and F maps Y -> X.
# Per the `fit` call, X is the day dataset (train_A_ds) and Y the night one (train_B_ds).
gen_G, gen_F = (resnet_generator(name=n) for n in ("generator_G", "generator_F"))

# Discriminators: one per domain (X and Y).
disc_X, disc_Y = (get_discriminator(name=n) for n in ("discriminator_X", "discriminator_Y"))

In [None]:
class CycleGAN(keras.Model):
  """CycleGAN trainer wrapping two generators and two discriminators.

  Each training batch is a pair (real_x, real_y). `generator_G` maps X -> Y,
  `generator_F` maps Y -> X, `discriminator_X` scores domain-X images and
  `discriminator_Y` scores domain-Y images. The sub-models are exposed as
  `gen_G`, `gen_F`, `disc_X`, `disc_Y`.

  Fixes relative to the previous version:
    * subclasses `keras.Model` (a `Layer` has no `compile`/`fit`) and calls
      `super().__init__()` so Keras tracks the sub-models' variables;
    * stores the sub-models under the names `train_step` reads (`gen_G`, ...);
    * `compile` takes `disc_X_optimizer`/`disc_Y_optimizer`, matching the
      attributes `train_step` uses. The old parameter names differed, and one
      was stored under the misspelled `disc_optimize_Y`.
  """

  def __init__(
          self,
          generator_G,
          generator_F,
          discriminator_X,
          discriminator_Y,
          lambda_cycle=10,
          lambda_identity=0.5
          ):
    super(CycleGAN, self).__init__()
    self.gen_G = generator_G
    self.gen_F = generator_F
    self.disc_X = discriminator_X
    self.disc_Y = discriminator_Y
    # Weight of the cycle-consistency (L1) loss.
    self.lambda_cycle = lambda_cycle
    # Identity-loss weight; train_step multiplies it by lambda_cycle as well.
    self.lambda_identity = lambda_identity

  def compile(self, gen_G_optimizer, gen_F_optimizer, disc_X_optimizer, disc_Y_optimizer, gen_loss_fn, disc_loss_fn):
    """Store the optimizers and loss functions used by `train_step`.

    Args:
      gen_G_optimizer, gen_F_optimizer: optimizers for the two generators.
      disc_X_optimizer, disc_Y_optimizer: optimizers for the two discriminators.
      gen_loss_fn: callable(disc_output_on_fake) -> generator adversarial loss.
      disc_loss_fn: callable(disc_output_on_real, disc_output_on_fake) -> discriminator loss.
    """
    super(CycleGAN, self).compile()
    self.gen_G_optimizer = gen_G_optimizer
    self.gen_F_optimizer = gen_F_optimizer
    self.disc_X_optimizer = disc_X_optimizer
    self.disc_Y_optimizer = disc_Y_optimizer
    self.gen_loss_fn = gen_loss_fn
    self.disc_loss_fn = disc_loss_fn
    self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
    self.identity_loss_fn = keras.losses.MeanAbsoluteError()

  def train_step(self, batch_data):
    """Run one optimization step for all four networks.

    Returns a dict of the two total generator losses and the two discriminator losses.
    """
    real_x, real_y = batch_data

    # Persistent because four separate gradients are taken from one tape.
    with tf.GradientTape(persistent=True) as tape:
      # x to y
      fake_y = self.gen_G(real_x, training=True)
      # y to x
      fake_x = self.gen_F(real_y, training=True)
      # Reconstruct the fakes back into their original domain
      reconstructed_y = self.gen_G(fake_x, training=True)
      reconstructed_x = self.gen_F(fake_y, training=True)
      # Identity mapping: each generator fed an image already in its output domain
      same_x = self.gen_F(real_x, training=True)
      same_y = self.gen_G(real_y, training=True)
      # Discriminator outputs
      disc_real_x = self.disc_X(real_x, training=True)
      disc_fake_x = self.disc_X(fake_x, training=True)

      disc_real_y = self.disc_Y(real_y, training=True)
      disc_fake_y = self.disc_Y(fake_y, training=True)

      # Generator adversarial loss
      gen_G_loss = self.gen_loss_fn(disc_fake_y)
      gen_F_loss = self.gen_loss_fn(disc_fake_x)
      # Generator cycle loss
      cycle_loss_G = self.cycle_loss_fn(real_y, reconstructed_y) * self.lambda_cycle
      cycle_loss_F = self.cycle_loss_fn(real_x, reconstructed_x) * self.lambda_cycle
      # Generator identity loss
      id_loss_G = (
          self.identity_loss_fn(real_y, same_y)
          * self.lambda_cycle
          * self.lambda_identity
      )
      id_loss_F = (
          self.identity_loss_fn(real_x, same_x)
          * self.lambda_cycle
          * self.lambda_identity
      )
      # Total generator loss
      total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
      total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F
      # Discriminator loss
      disc_X_loss = self.disc_loss_fn(disc_real_x, disc_fake_x)
      disc_Y_loss = self.disc_loss_fn(disc_real_y, disc_fake_y)

    # Gradients for the generators
    grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
    grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)
    # Gradients for the discriminators
    disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
    disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)
    # Update the generator weights
    self.gen_G_optimizer.apply_gradients(
        zip(grads_G, self.gen_G.trainable_variables)
    )
    self.gen_F_optimizer.apply_gradients(
        zip(grads_F, self.gen_F.trainable_variables)
    )
    # Update the discriminator weights
    self.disc_X_optimizer.apply_gradients(
        zip(disc_X_grads, self.disc_X.trainable_variables)
    )
    self.disc_Y_optimizer.apply_gradients(
        zip(disc_Y_grads, self.disc_Y.trainable_variables)
    )
    return {
        "G_loss": total_loss_G,
        "F_loss": total_loss_F,
        "D_X_loss": disc_X_loss,
        "D_Y_loss": disc_Y_loss,
    }

A callback that, after every epoch, plots sample translations produced by generator G on training images and saves them as PNG files

In [None]:
class GANMonitor(keras.callbacks.Callback):
  """After every epoch, plot input/translation pairs from generator G and save the translations.

  Each translation is written to "generated_img_{i}_{epoch}.png" in the
  current working directory (epoch is 1-based).

  Fixes:
    * calls `super().__init__()`;
    * reads `self.num_imgs` (was the misspelled `self.num_img`);
    * sizes the subplot grid by `num_imgs` instead of a hard-coded 4;
    * saves the PNGs without min-max re-stretching, so they match the plot.

  Args:
    num_imgs: number of batches taken from the dataset. The first image of
      each batch is shown.
    dataset: dataset of image batches scaled to [-1, 1]. Defaults to the
      module-level `train_A_ds`, looked up when the callback runs.
  """

  def __init__(self, num_imgs=4, dataset=None):
    super(GANMonitor, self).__init__()
    self.num_imgs = num_imgs
    self.dataset = dataset

  def on_epoch_end(self, epoch, logs=None):
    dataset = self.dataset if self.dataset is not None else train_A_ds
    # squeeze=False keeps `ax` 2-D even when num_imgs == 1.
    fig, ax = plt.subplots(self.num_imgs, 2, figsize=(12, 3 * self.num_imgs), squeeze=False)
    for i, img in enumerate(dataset.take(self.num_imgs)):
      prediction = self.model.gen_G(img)[0].numpy()
      # Undo the [-1, 1] scaling back to uint8 pixels.
      prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
      img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

      ax[i, 0].imshow(img)
      ax[i, 1].imshow(prediction)
      ax[i, 0].set_title("Input image")
      ax[i, 1].set_title("Translated image")
      ax[i, 0].axis("off")
      ax[i, 1].axis("off")

      # The array is already uint8 in [0, 255]. The default scale=True would
      # min-max stretch it, so the saved file would differ from the plot.
      prediction = keras.preprocessing.image.array_to_img(prediction, scale=False)
      prediction.save(
          "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
      )
    plt.show()
    plt.close(fig)

TRAINING THE MODEL

In [None]:
# Adversarial losses are mean squared errors of the discriminator scores
# against all-ones ("real") or all-zeros ("fake") targets.
adv_loss = keras.losses.MeanSquaredError()

def generator_loss_fn(fake):
  """Generator adversarial loss: MSE between the discriminator's scores on generated images and 1."""
  return adv_loss(tf.ones_like(fake), fake)

def discriminator_loss_fn(real, fake):
  """Discriminator loss: average of the MSE of real scores vs 1 and fake scores vs 0."""
  real_term = adv_loss(tf.ones_like(real), real)
  fake_term = adv_loss(tf.zeros_like(fake), fake)
  return 0.5 * (real_term + fake_term)

# Create cycle gan model
cycle_gan_model = CycleGAN(
    generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
)

# Compile the model (same Adam settings for all four networks)
cycle_gan_model.compile(
    gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)

# Callbacks
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
# save_weights_only=True: CycleGAN is a subclassed model that overrides
# train_step but defines no call(), so full-model saving is expected to fail.
# The weights are enough to restore the four networks.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
)

# Each batch is (day, night), i.e. (X, Y) for CycleGAN.train_step.
cycle_gan_model.fit(
    tf.data.Dataset.zip((train_A_ds, train_B_ds)),
    epochs=1,
    callbacks=[plotter, model_checkpoint_callback],
)