In [None]:
import tensorflow as tf
from tensorflow.keras import datasets,layers,models
from tensorflow_addons.layers import InstanceNormalization
import numpy as np

In [None]:
import tensorflow_datasets as tfds
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
AUTOTUNE = tf.data.experimental.AUTOTUNE  # let tf.data pick parallelism / buffer sizes at runtime
tfds.disable_progress_bar()  # keep tfds progress bars out of the notebook output

In [None]:
# Unpaired horse/zebra images as (image, label) pairs; `metadata` is the
# DatasetInfo describing the splits.
dataset, metadata = tfds.load(
    'cycle_gan/horse2zebra',
    with_info=True,
    as_supervised=True,
)
horses_train = dataset['trainA']
zebras_train = dataset['trainB']
horses_test = dataset['testA']
zebras_test = dataset['testB']

In [None]:
# Dataset metadata (splits, sizes, features). tfds.load(..., with_info=True)
# above already returned this DatasetInfo as `metadata`, so reuse it instead
# of constructing a second (and misleadingly named `mnist_builder`) builder.
info = metadata
print(info)


In [None]:
IMG_HEIGHT = 256
IMG_WIDTH = 256
# Batch size 1: instance normalization normalizes each image on its own,
# which is the usual choice for style-translation models like this one.
BATCH_SIZE = 1
BUFFER_SIZE = 900  # tf.data shuffle-buffer size

In [None]:
def crop(img):
  """Take a random IMG_HEIGHT x IMG_WIDTH x 3 window out of a single image."""
  return tf.image.random_crop(img, size=[IMG_HEIGHT, IMG_WIDTH, 3])

In [None]:

def normalize(img):
  """Cast to float32 and rescale pixels to [-1, 1].

  Assumes 8-bit pixel values in [0, 255] (0 -> -1, 255 -> 1).
  """
  as_float = tf.cast(img, tf.float32)
  return as_float / 127.5 - 1

In [None]:
def process_image(img, resize_to=286):
  """Random-jitter augmentation for one training image.

  Upscales with nearest-neighbour resizing, randomly crops back to
  IMG_HEIGHT x IMG_WIDTH (via `crop`), then randomly mirrors left/right.

  Args:
    img: a single image tensor of shape (height, width, 3).
    resize_to: side length of the intermediate square resize. It must be at
      least IMG_HEIGHT and IMG_WIDTH so that the random crop fits. Defaults
      to 286, the value that was previously hard-coded here.

  Returns:
    The augmented IMG_HEIGHT x IMG_WIDTH x 3 image tensor.
  """
  img = tf.image.resize(img, [resize_to, resize_to],
                        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  img = crop(img)
  img = tf.image.random_flip_left_right(img)
  return img

In [None]:
def preprocess_image_train(img, label):
  """Dataset.map() callback for training pairs: jitter-augment, then scale to [-1, 1].

  `label` is part of the (image, label) element produced by
  as_supervised=True and is ignored.
  """
  augmented = process_image(img)
  return normalize(augmented)

In [None]:
def preprocess_image_test(image, label):
  """Dataset.map() callback for test pairs: scale to [-1, 1], no augmentation.

  `label` is ignored.
  """
  return normalize(image)

In [None]:
def _build_pipeline(images, preprocess_fn, augment):
  """Turn a raw (image, label) split into a shuffled, batched input pipeline.

  Args:
    images: tf.data.Dataset yielding (image, label) pairs.
    preprocess_fn: map callback taking (image, label) and returning an image.
    augment: True when `preprocess_fn` is random. Random ops must run AFTER
      `.cache()`. Otherwise the first epoch's crops/flips are cached and
      replayed unchanged in every later epoch, which defeats augmentation.

  Returns:
    The dataset, batched to BATCH_SIZE and prefetched.
  """
  if augment:
    # Cache the raw decoded images; re-draw the augmentation every epoch.
    images = images.cache().map(preprocess_fn, num_parallel_calls=AUTOTUNE)
  else:
    # Deterministic preprocessing: its output can safely be cached.
    images = images.map(preprocess_fn, num_parallel_calls=AUTOTUNE).cache()
  return images.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTOTUNE)


# Apply the preprocessing functions to every image of each split.
train_horses = _build_pipeline(horses_train, preprocess_image_train, augment=True)
train_zebras = _build_pipeline(zebras_train, preprocess_image_train, augment=True)
test_horses = _build_pipeline(horses_test, preprocess_image_test, augment=False)
test_zebras = _build_pipeline(zebras_test, preprocess_image_test, augment=False)

In [None]:
# One batch from each training pipeline, used by the plots below;
# sample_horse is also the image previewed after every training epoch.
sample_horse, sample_zebra = next(iter(train_horses)), next(iter(train_zebras))

In [None]:
# Side-by-side view of the two samples; `x * 0.5 + 0.5` maps [-1, 1] back to
# [0, 1] for imshow.
for position, (caption, batch) in enumerate(
    [('Horse', sample_horse), ('Zebra', sample_zebra)], start=1):
  plt.subplot(1, 2, position)
  plt.title(caption)
  plt.imshow(batch[0] * 0.5 + 0.5)

In [None]:
def residual_block(inputs):
  """ResNet-style residual block for the CycleGAN generator bottleneck.

  conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm, then the block
  input is ADDED back (identity shortcut).

  Args:
    inputs: channels-last Keras tensor whose channel count is statically known.

  Returns:
    Tensor with the same shape as `inputs`.
  """
  initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
  # Keep the channel count equal to the shortcut's so the element-wise Add is valid.
  filters = inputs.shape[-1]
  y = tf.keras.layers.Conv2D(filters, (3, 3), padding='same', kernel_initializer=initializer)(inputs)
  y = InstanceNormalization(axis=-1)(y)
  y = tf.keras.layers.Activation('relu')(y)
  y = tf.keras.layers.Conv2D(filters, (3, 3), padding='same', kernel_initializer=initializer)(y)
  y = InstanceNormalization(axis=-1)(y)
  # Residual connection. This used to be Concatenate, which is not a residual
  # and grows the width by `filters` per block (256 -> 2560 after 9 blocks).
  return tf.keras.layers.Add()([y, inputs])

In [None]:
# The CycleGAN uses two generators (horse->zebra, zebra->horse) and two
# discriminators (one per domain).
def make_generator(n_residual_blocks=9):
    """Build the ResNet-style CycleGAN generator.

    Layout: 7x7 conv (64) -> stride-2 convs (128, 256) -> `n_residual_blocks`
    residual blocks -> stride-2 transposed convs (128, 64) -> 7x7 conv (3)
    -> tanh. Every conv except the last is followed by InstanceNorm + ReLU.

    Args:
      n_residual_blocks: number of `residual_block`s in the bottleneck
        (default 9, as before).

    Returns:
      tf.keras.Model mapping (IMG_HEIGHT, IMG_WIDTH, 3) images to images of
      the same size with values in (-1, 1).
    """
    inputs = tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
    initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    x = tf.keras.layers.Conv2D(64, (7, 7), padding='same', kernel_initializer=initializer)(inputs)
    x = InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)

    # Downsampling: 256x256 -> 128x128 -> 64x64 (for 256x256 inputs).
    x = tf.keras.layers.Conv2D(128, (3, 3), strides=2, padding='same', kernel_initializer=initializer)(x)
    x = InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(256, (3, 3), strides=2, padding='same', kernel_initializer=initializer)(x)
    x = InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)

    for _ in range(n_residual_blocks):
        x = residual_block(x)

    # Upsampling back to the input resolution.
    x = tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=2, padding='same', kernel_initializer=initializer)(x)
    x = InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=2, padding='same', kernel_initializer=initializer)(x)
    x = InstanceNormalization(axis=-1)(x)
    x = tf.keras.layers.Activation('relu')(x)

    # Output layer: conv -> tanh with NO normalization. Instance-normalizing
    # here would pin every generated image's per-channel mean/std to the
    # layer's learned beta/gamma regardless of the input. The reference
    # CycleGAN generator ends with conv -> tanh.
    x = tf.keras.layers.Conv2D(3, (7, 7), padding='same', kernel_initializer=initializer)(x)
    outputs = tf.keras.layers.Activation('tanh')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)


In [None]:
# horse -> zebra and zebra -> horse generators, identical architectures.
gen_h_to_z, gen_z_to_h = make_generator(), make_generator()

In [None]:
tf.keras.utils.plot_model(model=gen_h_to_z, dpi=64, show_shapes=True)  # architecture diagram with tensor shapes

In [None]:
# Layer-by-layer summary of the horse->zebra generator (output shapes, parameter counts).
gen_h_to_z.summary()

In [None]:
# Translate each sample once with the not-yet-trained generators.
fake_zebra, fake_horse = gen_h_to_z(sample_horse), gen_z_to_h(sample_zebra)

In [None]:
images = [sample_horse, fake_zebra, sample_zebra, fake_horse]
title = ['Horse', 'Fake Zebra', 'Zebra', 'Fake Horse']
# Generated images (odd positions) are amplified by `contrast` before display
# so any faint structure is visible; values pushed outside [0, 1] are clipped
# by imshow.
contrast = 8
plt.figure(figsize=(8, 8))

for i, (img, caption) in enumerate(zip(images, title)):
  plt.subplot(2, 2, i + 1)
  plt.title(caption)
  scale = 0.5 * contrast if i % 2 == 1 else 0.5
  plt.imshow(img[0] * scale + 0.5)
plt.show()


In [None]:
def make_discriminator():
    """Build a PatchGAN discriminator.

    Maps an (IMG_HEIGHT, IMG_WIDTH, 3) image to a 1-channel grid of raw
    scores (IMG_HEIGHT/16 x IMG_WIDTH/16), one per image patch, instead of
    a single scalar. Each 4x4 conv is followed by LeakyReLU(0.2), and all but
    the first are instance-normalized. The final 4x4 conv has no activation.
    """
    inputs = tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
    initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    # (filters, strides, apply InstanceNormalization?)
    conv_specs = [
        (64, 2, False),
        (128, 2, True),
        (256, 2, True),
        (512, 2, True),
        (512, 1, True),
    ]
    x = inputs
    for filters, strides, use_norm in conv_specs:
        x = tf.keras.layers.Conv2D(filters, (4, 4), strides=strides, padding='same',
                                   kernel_initializer=initializer)(x)
        if use_norm:
            x = InstanceNormalization(axis=-1)(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

    scores = tf.keras.layers.Conv2D(1, (4, 4), padding='same', kernel_initializer=initializer)(x)
    return tf.keras.Model(inputs=inputs, outputs=scores)

In [None]:
# dis_h scores horse images, dis_z scores zebra images.
dis_h, dis_z = make_discriminator(), make_discriminator()

In [None]:
tf.keras.utils.plot_model(model=dis_h, dpi=64, show_shapes=True)  # discriminator diagram with tensor shapes

In [None]:
# Horse-discriminator score maps for a real horse and a generated horse.
real, fake = dis_h(sample_horse), dis_h(fake_horse)

In [None]:
# Visualize the discriminator's per-patch score maps (first image, only channel).
plt.figure(figsize=(8, 8))
panels = [
    ('Discriminator Output for real horse', real),
    ('Discriminator Output for fake horse', fake),
]
for i, (caption, score_map) in enumerate(panels):
  plt.subplot(1, 2, i + 1)
  plt.title(caption)
  plt.imshow(score_map[0, :, :, 0], cmap='RdBu_r')
plt.show()


In [None]:
Lambda = 10  # weight of the cycle-consistency loss (identity loss uses half of it)
squareloss = tf.keras.losses.MeanSquaredError()  # least-squares adversarial objective


In [None]:
def discriminator_loss(real_output, fake_output):
  """Least-squares discriminator loss.

  Args:
    real_output: discriminator scores for real images (target 1).
    fake_output: discriminator scores for generated images (target 0).

  Returns:
    Half the sum of the two mean-squared errors.
  """
  real_term = squareloss(tf.ones_like(real_output), real_output)
  fake_term = squareloss(tf.zeros_like(fake_output), fake_output)
  total = real_term + fake_term
  return total * 0.5

In [None]:
def generator_loss(fake_output):
  """Adversarial loss for a generator.

  Args:
    fake_output: discriminator scores for generated images.

  Returns:
    Mean squared error between those scores and 1, i.e. the generator is
    rewarded when its fakes are scored as real.
  """
  target = tf.ones_like(fake_output)
  return squareloss(target, fake_output)


In [None]:
def cycle_loss(real_input, cycled_input):
  """Cycle-consistency loss: Lambda-weighted mean absolute error."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_input - cycled_input))
  return Lambda * mean_abs_error


In [None]:
def identity_loss(real_input, generated_input):
  """Identity loss: mean absolute error weighted by half the cycle weight (0.5 * Lambda).

  `generated_input` is a generator's output when fed an image already in its
  target domain, so ideally it equals `real_input`.
  """
  mean_abs_error = tf.reduce_mean(tf.abs(real_input - generated_input))
  return mean_abs_error * 0.5 * Lambda

In [None]:
def _make_adam():
  """Adam with lr=2e-4 and beta_1=0.5; one fresh instance per network."""
  return tf.keras.optimizers.Adam(2e-4, beta_1=0.5)


# Separate optimizers so each network keeps its own moment estimates.
generator_htoz_optimizer = _make_adam()
generator_ztoh_optimizer = _make_adam()

discriminator_h_optimizer = _make_adam()
discriminator_z_optimizer = _make_adam()


In [None]:
@tf.function
def train_step(horses, zebras):
  """Run one CycleGAN optimization step on a horse batch and a zebra batch.

  Updates both generators (adversarial + cycle + identity loss) and both
  discriminators (least-squares loss) with their own optimizers.

  Args:
    horses: batch of real horse images.
    zebras: batch of real zebra images.

  Returns:
    dict of scalar loss tensors, so callers can log training progress.
    Previously the losses were computed and then discarded; callers that
    ignore the return value are unaffected.
  """
  # persistent=True: the same tape is asked for four sets of gradients below.
  with tf.GradientTape(persistent=True) as tape:
    # Translate to the other domain, then cycle back.
    fake_zebras = gen_h_to_z(horses, training=True)
    fake_horses = gen_z_to_h(zebras, training=True)

    cycled_horses = gen_z_to_h(fake_zebras, training=True)
    cycled_zebras = gen_h_to_z(fake_horses, training=True)

    real_horse_disc = dis_h(horses, training=True)
    fake_horse_disc = dis_h(fake_horses, training=True)

    real_zebra_disc = dis_z(zebras, training=True)
    fake_zebra_disc = dis_z(fake_zebras, training=True)

    # Identity mapping: each generator fed an image already in its target domain.
    generated_horse = gen_z_to_h(horses, training=True)
    generated_zebra = gen_h_to_z(zebras, training=True)

    disc_h_loss = discriminator_loss(real_horse_disc, fake_horse_disc)
    disc_z_loss = discriminator_loss(real_zebra_disc, fake_zebra_disc)

    gen_htoz_loss = generator_loss(fake_zebra_disc)
    gen_ztoh_loss = generator_loss(fake_horse_disc)

    # Shared by both generators: each direction's cycle involves both of them.
    cycleloss = cycle_loss(horses, cycled_horses) + cycle_loss(zebras, cycled_zebras)

    # adversarial loss + cycle loss + identity loss
    total_htoz_loss = gen_htoz_loss + cycleloss + identity_loss(zebras, generated_zebra)
    total_ztoh_loss = gen_ztoh_loss + cycleloss + identity_loss(horses, generated_horse)

  gen_htoz_gradients = tape.gradient(total_htoz_loss, gen_h_to_z.trainable_variables)
  gen_ztoh_gradients = tape.gradient(total_ztoh_loss, gen_z_to_h.trainable_variables)
  disc_h_gradients = tape.gradient(disc_h_loss, dis_h.trainable_variables)
  disc_z_gradients = tape.gradient(disc_z_loss, dis_z.trainable_variables)

  generator_htoz_optimizer.apply_gradients(zip(gen_htoz_gradients, gen_h_to_z.trainable_variables))
  generator_ztoh_optimizer.apply_gradients(zip(gen_ztoh_gradients, gen_z_to_h.trainable_variables))
  discriminator_h_optimizer.apply_gradients(zip(disc_h_gradients, dis_h.trainable_variables))
  discriminator_z_optimizer.apply_gradients(zip(disc_z_gradients, dis_z.trainable_variables))

  return {
      'gen_htoz_total': total_htoz_loss,
      'gen_ztoh_total': total_ztoh_loss,
      'cycle': cycleloss,
      'disc_h': disc_h_loss,
      'disc_z': disc_z_loss,
  }

     

In [None]:
def train(epochs):
  """Train both generators and discriminators for `epochs` passes.

  Each epoch iterates the zipped horse/zebra training sets once, calling
  `train_step` per batch pair. Afterwards it clears the cell output, shows
  `gen_h_to_z`'s translation of `sample_horse`, and prints the epoch time.
  """
  paired = tf.data.Dataset.zip((train_horses, train_zebras))
  # Dataset.zip stops at the shorter input, so ask tf.data for that length
  # instead of hard-coding it. cardinality() is negative (UNKNOWN/INFINITE)
  # when it cannot be determined, and in that case the Progbar runs without
  # a target.
  steps = int(tf.data.experimental.cardinality(paired))
  target = steps if steps >= 0 else None

  for epoch in range(epochs):
    start = time.time()
    progbar = tf.keras.utils.Progbar(target)
    # Re-iterating `paired` re-shuffles the underlying datasets each epoch.
    for step, (image_x, image_y) in enumerate(paired, start=1):
      train_step(image_x, image_y)
      progbar.update(step)
    clear_output(wait=True)
    show_output(gen_h_to_z, sample_horse)
    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time() - start))


In [None]:
 def show_output(model,test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Real image', 'Generated Fake Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()


In [None]:
# Full training run. Each epoch ends by previewing gen_h_to_z on sample_horse.
epochs=100
train(epochs)

In [None]:
# Qualitative check: translate 5 held-out horse batches into zebras.
for test_batch in test_horses.take(5):
  show_output(gen_h_to_z, test_batch)


In [None]:
# Qualitative check: translate 5 held-out zebra batches into horses.
for test_batch in test_zebras.take(5):
  show_output(gen_z_to_h, test_batch)
