In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Turn off the tensorflow_datasets progress bars for the rest of the notebook.
tfds.disable_progress_bar()
# Passed as `num_parallel_calls` to the dataset `map` calls further down.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Load the horse2zebra splits. `as_supervised=True` is why the preprocessing
# functions below take an (image, label) pair.
dataset, metadata = tfds.load(
    'cycle_gan/horse2zebra', with_info=True, as_supervised=True)

# Split "A" is bound to the horse names and split "B" to the zebra names.
train_horses = dataset['trainA']
train_zebras = dataset['trainB']
test_horses = dataset['testA']
test_zebras = dataset['testB']

In [None]:
# Shuffle-buffer size handed to `Dataset.shuffle` in the input pipelines.
BUFFER_SIZE = 1000
# Number of images per batch.
BATCH_SIZE = 1
# Spatial size of the crop produced by `random_crop`.
IMG_WIDTH = 256
IMG_HEIGHT = 256

In [None]:
def random_crop(image):
    """Return a random IMG_HEIGHT x IMG_WIDTH crop (3 channels) of `image`."""
    crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
    return tf.image.random_crop(image, size=crop_shape)

In [None]:
def normalize(image):
    """Cast `image` to float32 and rescale it as ``image / 127.5 - 1``.

    For pixel values in [0, 255] the result lies in [-1, 1].
    """
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1

In [None]:
def random_jitter(image):
    """Apply random jitter: upscale, random-crop back down, random mirror.

    The image is resized to 286 x 286 with nearest-neighbour sampling, cropped
    by `random_crop`, and then flipped left/right at random.
    """
    enlarged = tf.image.resize(
        image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    cropped = random_crop(enlarged)
    return tf.image.random_flip_left_right(cropped)

In [None]:
def preprocess_image_train(image, label):
    """Training-time preprocessing: random jitter followed by normalization.

    `label` is accepted because the dataset yields (image, label) pairs; it is
    not used and is not returned.
    """
    return normalize(random_jitter(image))

In [None]:
def preprocess_image_test(image, label):
    """Test-time preprocessing: normalization only, no augmentation.

    `label` is accepted because the dataset yields (image, label) pairs; it is
    not used and is not returned.
    """
    return normalize(image)

In [None]:
# Training input pipelines.
#
# `cache()` is placed BEFORE the augmenting `map`: caching after it would store
# the first epoch's random crops/flips and replay exactly those images on every
# later epoch, which defeats the random jitter. Caching the raw images keeps
# the decode cost out of later epochs while still re-drawing the jitter.
#
# The pipelines are rebuilt from `dataset[...]` rather than from the previous
# value of `train_horses` / `train_zebras`, so re-running this cell does not
# stack a second round of preprocessing on top of the first.
train_horses = dataset['trainA'].cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = dataset['trainB'].cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

In [None]:
# Test input pipelines.
#
# Test preprocessing is deterministic (normalization only), so caching after
# the `map` is safe here. The batch size comes from BATCH_SIZE instead of a
# hard-coded 1, and the pipelines are rebuilt from `dataset[...]` so that
# re-running this cell does not apply the preprocessing twice.
test_horses = dataset['testA'].map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = dataset['testB'].map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

In [None]:
# Draw one batch of horses; it is reused later as a fixed reference image.
sample_horse = next(iter(train_horses))
horse_image = sample_horse[0]

# Left panel: the sample as produced by the pipeline, mapped back to [0, 1].
plt.subplot(1, 2, 1)
plt.title('Horse')
plt.imshow(horse_image * 0.5 + 0.5)

# Right panel: the same image after one more pass of random jitter.
jittered_horse = random_jitter(horse_image)
plt.subplot(1, 2, 2)
plt.title('Horse with random jitter')
plt.imshow(jittered_horse * 0.5 + 0.5)

In [None]:
# Draw one batch of zebras; it is reused later as a fixed reference image.
sample_zebra = next(iter(train_zebras))
zebra_image = sample_zebra[0]

# Left panel: the sample as produced by the pipeline, mapped back to [0, 1].
plt.subplot(1, 2, 1)
plt.title('Zebra')
plt.imshow(zebra_image * 0.5 + 0.5)

# Right panel: the same image after one more pass of random jitter.
jittered_zebra = random_jitter(zebra_image)
plt.subplot(1, 2, 2)
plt.title('Zebra with random jitter')
plt.imshow(jittered_zebra * 0.5 + 0.5)

In [None]:
OUTPUT_CHANNELS = 3

# Two U-Net generators with instance normalization. In `train_step`,
# generator_g1 maps X -> Y and generator_g2 maps Y -> X.
generator_g1, generator_g2 = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2))

# Two discriminators built with target=False. In `train_step`,
# discriminator_d1 is applied to X images and discriminator_d2 to Y images.
discriminator_d1, discriminator_d2 = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2))

In [None]:
# Sanity-check the generators on the two sample batches.
to_zebra = generator_g1(sample_horse)
to_horse = generator_g2(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for slot, (batch, caption) in enumerate(zip(imgs, title), start=1):
    plt.subplot(2, 2, slot)
    plt.title(caption)
    shown = batch[0] * 0.5
    # Even slots hold generator outputs; their signal is multiplied by
    # `contrast` before plotting.
    if slot % 2 == 0:
        shown = shown * contrast
    plt.imshow(shown + 0.5)
plt.show()

In [None]:
# Sanity-check the discriminators on the two sample batches.
#
# `train_step` applies discriminator_d1 to X (horse) images and
# discriminator_d2 to Y (zebra) images, so each panel must query the
# discriminator for that domain: d2 for the zebra sample, d1 for the horse
# sample. Only the last channel of the first batch element is plotted.
plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_d2(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_d1(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

In [None]:
# Shared binary cross-entropy used by the adversarial losses below.
# from_logits=True: predictions are passed in as raw logits.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
    """Discriminator loss: half the sum of the real and generated BCE terms.

    Args:
        real: discriminator output for real images (target: all ones).
        generated: discriminator output for generated images (target: all zeros).

    Returns:
        ``0.5 * (loss_on_real + loss_on_generated)``.
    """
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_generated) * 0.5

def generator_loss(generated):
    """Adversarial generator loss: BCE of `generated` against an all-ones target."""
    all_real = tf.ones_like(generated)
    return loss_obj(all_real, generated)

In [None]:
# Weight applied to the cycle-consistency and identity losses.
LAMBDA = 10

def calc_cycle_loss(real_image, cycled_image):
    """Cycle-consistency loss: LAMBDA times the mean absolute difference
    between `real_image` and `cycled_image`."""
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_diff

In [None]:
def identity_loss(real_image, same_image):
    """Identity loss: ``LAMBDA * 0.5`` times the mean absolute difference
    between `real_image` and `same_image`."""
    mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * mean_abs_diff

In [None]:
# Number of passes the training loop makes over the zipped training datasets.
EPOCHS=170

In [None]:
def generate_images(model, test_input):
    """Plot the first image of `test_input` next to `model`'s output for it.

    Args:
        model: callable applied to `test_input` to obtain the prediction.
        test_input: batch of images; only element 0 is displayed.
    """
    prediction = model(test_input)
    plt.figure(figsize=(12, 12))

    panels = (('Input Image', test_input[0]),
              ('Predicted Image', prediction[0]))

    for slot, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.title(caption)
        # Map the pixel values into [0, 1] for display.
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')

    plt.show()

In [None]:
# One independent Adam optimizer (learning rate 2e-4, beta_1 0.5) per network.
generator_g1_optimizer, generator_g2_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2))

discriminator_d1_optimizer, discriminator_d2_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(2))

In [None]:
checkpoint_path = "./checkpoints/train"

# Track all four networks and their optimizers in a single checkpoint.
ckpt = tf.train.Checkpoint(
    generator_g1=generator_g1,
    generator_g2=generator_g2,
    discriminator_d1=discriminator_d1,
    discriminator_d2=discriminator_d2,
    generator_g1_optimizer=generator_g1_optimizer,
    generator_g2_optimizer=generator_g2_optimizer,
    discriminator_d1_optimizer=discriminator_d1_optimizer,
    discriminator_d2_optimizer=discriminator_d2_optimizer,
)

# The manager is configured with max_to_keep=5.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the most recent checkpoint when one is present.
latest_ckpt_path = ckpt_manager.latest_checkpoint
if latest_ckpt_path:
    ckpt.restore(latest_ckpt_path)
    print('Latest checkpoint restored!!')

In [None]:
@tf.function
def train_step(real_x, real_y):
    """Run one optimization step for both generators and both discriminators.

    generator_g1 translates X -> Y and generator_g2 translates Y -> X.
    discriminator_d1 is applied to X images and discriminator_d2 to Y images.

    Args:
        real_x: batch of real images from domain X.
        real_y: batch of real images from domain Y.
    """
    # A persistent tape is required because four separate gradients are taken.
    with tf.GradientTape(persistent=True) as tape:
        # Step 1: forward passes.
        # X -> Y -> X cycle.
        fake_y = generator_g1(real_x, training=True)
        cycled_x = generator_g2(fake_y, training=True)

        # Y -> X -> Y cycle.
        fake_x = generator_g2(real_y, training=True)
        cycled_y = generator_g1(fake_x, training=True)

        # Each generator applied to an image already in its output domain;
        # these feed the identity loss.
        same_x = generator_g2(real_x, training=True)
        same_y = generator_g1(real_y, training=True)

        disc_real_x = discriminator_d1(real_x, training=True)
        disc_real_y = discriminator_d2(real_y, training=True)

        disc_fake_x = discriminator_d1(fake_x, training=True)
        disc_fake_y = discriminator_d2(fake_y, training=True)

        # Step 2: losses.
        adv_g1_loss = generator_loss(disc_fake_y)
        adv_g2_loss = generator_loss(disc_fake_x)

        # Both generators share the sum of the two cycle losses.
        total_cycle_loss = (calc_cycle_loss(real_x, cycled_x)
                            + calc_cycle_loss(real_y, cycled_y))

        # Total generator loss = adversarial + cycle + identity.
        total_gen_g1_loss = (adv_g1_loss + total_cycle_loss
                             + identity_loss(real_y, same_y))
        total_gen_g2_loss = (adv_g2_loss + total_cycle_loss
                             + identity_loss(real_x, same_x))

        disc_d1_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_d2_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # (loss, network, optimizer) for each of the four updates, in apply order.
    updates = (
        (total_gen_g1_loss, generator_g1, generator_g1_optimizer),
        (total_gen_g2_loss, generator_g2, generator_g2_optimizer),
        (disc_d1_loss, discriminator_d1, discriminator_d1_optimizer),
        (disc_d2_loss, discriminator_d2, discriminator_d2_optimizer),
    )

    # Step 3: compute every gradient before any variable is updated.
    all_gradients = [tape.gradient(loss, network.trainable_variables)
                     for loss, network, _ in updates]

    # Step 4: apply the gradients with each network's own optimizer.
    for gradients, (_, network, optimizer) in zip(all_gradients, updates):
        optimizer.apply_gradients(zip(gradients, network.trainable_variables))

In [None]:
for epoch in range(EPOCHS):
    start = time.time()
    epoch_number = epoch + 1

    # Pair each horse batch with a zebra batch and take one step per pair.
    paired_batches = tf.data.Dataset.zip((train_horses, train_zebras))
    for step, (image_x, image_y) in enumerate(paired_batches):
        train_step(image_x, image_y)

        # Lightweight progress marker: one dash every 10 steps.
        if step % 10 == 0:
            print('-', end='')

    clear_output(wait=True)
    # Always plot the same image (sample_horse) so the model's progress is
    # directly comparable from epoch to epoch.
    generate_images(generator_g1, sample_horse)

    # Save a checkpoint every fifth epoch.
    if epoch_number % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch_number,
                                                            ckpt_save_path))
    print('Time taken for epoch {} is {} sec\n'.format(epoch_number,
                                                       time.time() - start))

In [None]:
# Run the trained model on the test dataset.
# generator_g1 is the X -> Y (horse -> zebra) generator: the training loop
# feeds it horses as `real_x`, so it must be evaluated on test horses, not on
# test zebras.
for inp in test_horses.take(5):
    generate_images(generator_g1, inp)