**Importing required libraries**

In [None]:
BATCH_SIZE = 4
import re
import os
import math
import cv2
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tensorflow.keras import Model, losses, optimizers
from tensorflow.keras.callbacks import Callback
from kaggle_datasets import KaggleDatasets

**Enabling TPU for implementation**

In [None]:
# Pick the distribution strategy: a TPU strategy when a TPU can be resolved,
# otherwise TensorFlow's default (single-device) strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as err:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, and report why the
    # fallback happened instead of hiding it.
    print('TPU not available, falling back to the default strategy:', err)
    tpu_strategy = tf.distribute.get_strategy()
print('Number of replicas:', tpu_strategy.num_replicas_in_sync)

# Lets tf.data choose parallelism / prefetch buffer sizes at runtime.
AUTO_TUNE = tf.data.experimental.AUTOTUNE

print("version:",tf.__version__)

**Importing the dataset.**

In [None]:
# Storage path of the 'gan-getting-started' dataset as reported by KaggleDatasets.
GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')

# TFRecord shard paths for the Monet paintings and for the photos.
monets_tfr = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
photos_tfr = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))

def count_data_items(filenames):
    """Sum the record counts embedded in TFRecord shard filenames.

    Each filename must contain a ``-<digits>.`` group (for example
    ``monet00-60.tfrec`` contributes 60); the first such group found in the
    name is used.

    Args:
        filenames: iterable of shard paths/names.

    Returns:
        The total count as a NumPy integer (0 for an empty iterable).

    Raises:
        ValueError: if a filename carries no ``-<digits>.`` group.
    """
    # Compile once instead of once per filename; `+` (not `*`) so an empty
    # digit run such as "-." is never mistaken for a count.
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(f"No item count found in TFRecord filename: {filename!r}")
        counts.append(int(match.group(1)))
    # Explicit integer dtype keeps the result an integer even for an empty list.
    return np.sum(counts, dtype=np.int64)

# Total number of images per domain, derived from the shard filenames.
monet_jpg = count_data_items(monets_tfr)
photo_jpg = count_data_items(photos_tfr)

# Intended number of training epochs (reported in the summary below).
EPOCHS = 30

# Summary of what was found.
print("Monet TFRecord files:", len(monets_tfr))
print("Monet image files:", monet_jpg)
print("Photo TFRecord files:", len(photos_tfr))
print("Photo image files:", photo_jpg)
print("EPOCHS:",EPOCHS)

**Adding a function to display sample images**

In [None]:
def view_data(dataset, nrows, ncols):
    """Plot the first image of consecutive batches in an nrows x ncols grid.

    Args:
        dataset: iterable of image batches with values in [-1, 1]; only the
            first image of each batch is drawn.
        nrows: number of grid rows.
        ncols: number of grid columns.
    """
    batches = iter(dataset)
    n_cells = nrows * ncols
    plt.figure(figsize=(15, int(15 * nrows / ncols)))
    for cell in range(1, n_cells + 1):
        batch = next(batches)
        plt.subplot(nrows, ncols, cell)
        plt.axis('off')
        # Map [-1, 1] back to [0, 1] for display.
        plt.imshow(batch[0] * 0.5 + 0.5)
    plt.show()

**Setting up Image size**

In [None]:
# Spatial size (in pixels) that every decoded image is reshaped to.
IMAGE_SIZE = [256, 256]

We read each image from the TFRecord files and decode it.

All the images are sized to 256x256. As these images are RGB images, set the channel to 3. Additionally, we need to scale the images to a [-1, 1] scale. Because we are building a generative model, we don't need the labels or the image id so we'll only return the image from the TFRecord.

In [None]:
def decode_image(image):
    """Decode JPEG bytes into a float32 RGB tensor scaled to [-1, 1].

    Args:
        image: scalar string tensor holding the encoded JPEG.

    Returns:
        A tensor reshaped to ``[*IMAGE_SIZE, 3]``.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    # 0..255 -> -1..1
    scaled = (pixels / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return only its decoded image.

    The record's ``image_name`` and ``target`` features are parsed but
    discarded, since the generative model needs neither ids nor labels.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

**IMAGE PRE-PROCESSING**

**Resizing image** (In this case it would not be necessary to do it, because the images are already in the necessary size. But with this step if we wanted to add new images it would not be necessary to scale them previously)

**Normalizing** the images to [-1, 1]

**Random jittering and mirroring** are applied to the training dataset. These image augmentation techniques help to avoid overfitting. Random jittering performs the following steps:

- Resize the image to a bigger height and width
- Randomly crop it to the target size
- Randomly flip the image horizontally

In [None]:
def data_augment(image):
    """Randomly jitter, rotate and mirror a single 256x256x3 image.

    Three independent uniform draws in [0, 1) gate the three augmentation
    families; the output always has the same shape as the input.
    """
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_mirror = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_jitter = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Random jitter: upscale, then crop back to the target size.
    if p_jitter > .5:
        enlarged = tf.image.resize(image, [286, 286])
        image = tf.image.random_crop(enlarged, size=[256, 256, 3])
        # Occasionally apply a second, stronger jitter on top of the first.
        if p_jitter > .9:
            enlarged = tf.image.resize(image, [300, 300])
            image = tf.image.random_crop(enlarged, size=[256, 256, 3])

    # Random rotation by a multiple of 90 degrees (no rotation below 0.5).
    if p_rotate > .9:
        image = tf.image.rot90(image, k=3)  # 270 degrees
    elif p_rotate > .7:
        image = tf.image.rot90(image, k=2)  # 180 degrees
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=1)  # 90 degrees

    # Random mirroring, with an occasional transpose on top.
    if p_mirror > .6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_mirror > .9:
            image = tf.image.transpose(image)

    return image

In [None]:
def load_dataset(filenames):
    """Return a dataset of decoded images read from the given TFRecord files."""
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTO_TUNE)

In [None]:
def get_gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Build the zipped (monet_batch, photo_batch) training dataset.

    Args:
        monet_files: TFRecord paths of the Monet paintings.
        photo_files: TFRecord paths of the photos.
        augment: optional per-image function mapped over both datasets.
        repeat: repeat both datasets indefinitely when true.
        shuffle: shuffle both datasets with a 2048-element buffer when true.
        batch_size: images per batch; incomplete trailing batches are dropped.

    Returns:
        A ``tf.data.Dataset`` yielding ``(monet_batch, photo_batch)`` tuples.
    """
    def _build_pipeline(filenames):
        # One domain's pipeline: decode -> cache -> augment -> repeat -> shuffle -> batch.
        ds = load_dataset(filenames)
        # Cache right after decoding, BEFORE any random or infinite stage.
        # Caching after repeat() would try to cache an endless stream, and
        # caching after augment/shuffle would replay the same augmented
        # images in the same order instead of drawing fresh ones.
        # NOTE(review): this keeps the decoded images in memory; confirm the
        # runtime has enough RAM for the photo set.
        ds = ds.cache()
        if augment:
            ds = ds.map(augment, num_parallel_calls=AUTO_TUNE)
        if repeat:
            ds = ds.repeat()
        if shuffle:
            ds = ds.shuffle(2048)
        ds = ds.batch(batch_size, drop_remainder=True)
        return ds.prefetch(AUTO_TUNE)

    monet_ds = _build_pipeline(monet_files)
    photo_ds = _build_pipeline(photo_files)

    return tf.data.Dataset.zip((monet_ds, photo_ds))

**load in our datasets.**

In [None]:
# Paired (monet, photo) training stream: augmented, repeated, shuffled, in batches of BATCH_SIZE.
data = get_gan_dataset(monets_tfr, photos_tfr, augment=data_augment, repeat=True, shuffle=True, batch_size=BATCH_SIZE)

**Display sample Monet images**

In [None]:
# 2x3 grid of unaugmented Monet samples (one image per batch).
view_data(load_dataset(monets_tfr).batch(1), 2, 3)

**Display sample photo images**

In [None]:
# 2x3 grid of unaugmented photo samples (one image per batch).
view_data(load_dataset(photos_tfr).batch(1),2,3)

**Define the path for the monet and photo images**

In [None]:
# Directories holding the JPEG versions of the images, relative to the notebook's working directory.
BASE_PATH = '../input/gan-getting-started/'
MONET_PATH = os.path.join(BASE_PATH, 'monet_jpg')
PHOTO_PATH = os.path.join(BASE_PATH, 'photo_jpg')

**Batch visualization of photo and monet images**

In [None]:
def batch_visualization(path, n_images, is_random=True, figsize=(16, 16)):
    """Show up to ``n_images`` images from a directory in a near-square grid.

    Args:
        path: directory containing the image files.
        n_images: number of images requested; capped at the number of
            entries in ``path``.
        is_random: pick a random sample when true, otherwise take the first
            entries in ``os.listdir`` order.
        figsize: matplotlib figure size.

    Nothing is drawn when there is no image to show.
    """
    all_names = os.listdir(path)

    # random.sample raises ValueError when asked for more items than exist,
    # so never request more than the directory holds.
    n_shown = min(n_images, len(all_names))
    if n_shown <= 0:
        # Also avoids the division by zero in the grid computation below.
        return

    if is_random:
        image_names = random.sample(all_names, n_shown)
    else:
        image_names = all_names[:n_shown]

    # Near-square grid: `cols` columns, as many rows as needed.
    cols = int(n_shown ** .5)
    rows = math.ceil(n_shown / cols)

    plt.figure(figsize=figsize)
    for ind, image_name in enumerate(image_names):
        img = cv2.imread(os.path.join(path, image_name))
        # OpenCV loads BGR; matplotlib expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.subplot(rows, cols, ind + 1)
        plt.imshow(img)
        plt.axis('off')

    plt.show()

In [None]:
# Six randomly chosen Monet JPEGs.
batch_visualization(MONET_PATH, 6, is_random=True, figsize=(16, 16))

In [None]:
# Six randomly chosen photo JPEGs.
batch_visualization(PHOTO_PATH, 6, is_random=True, figsize=(16, 16))

**Colour histograms for Monet and photo data**

In [None]:
def color_hist_visualization(image_path, figsize=(16, 4)):
    """Show an image next to the density histograms of its R, G and B channels.

    Args:
        image_path: path of the image file read with OpenCV.
        figsize: matplotlib figure size.
    """
    plt.figure(figsize=figsize)

    # OpenCV loads BGR; convert so channel 0/1/2 is red/green/blue.
    rgb = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    plt.subplot(1, 4, 1)
    plt.imshow(rgb)
    plt.axis('off')

    for channel, color in enumerate(('red', 'green', 'blue')):
        plt.subplot(1, 4, channel + 2)
        plt.hist(
            rgb[:, :, channel].reshape(-1),
            bins=25,
            alpha=0.5,
            color=color,
            density=True
        )
        plt.xlim(0, 255)
        plt.xticks([])
        plt.yticks([])
    plt.show()

In [None]:
# One Monet painting and one photo, compared through their colour histograms.
image_1 = '../input/gan-getting-started/monet_jpg/0260d15306.jpg'
image_2 = '../input/gan-getting-started/photo_jpg/0033c5f971.jpg'
color_hist_visualization(image_1)
color_hist_visualization(image_2)

**Channel visualization of monet and photo data**

In [None]:
def channels_visualization(image_path, figsize=(16, 4)):
    """Show an image next to three copies that each keep a single RGB channel.

    Args:
        image_path: path of the image file read with OpenCV.
        figsize: matplotlib figure size.
    """
    plt.figure(figsize=figsize)

    # OpenCV loads BGR; convert so channel 0/1/2 is red/green/blue.
    rgb = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    plt.subplot(1, 4, 1)
    plt.imshow(rgb)
    plt.axis('off')

    for channel in range(3):
        plt.subplot(1, 4, channel + 2)
        # Black image with only the current channel copied in.
        single_channel = np.zeros_like(rgb)
        single_channel[:, :, channel] = rgb[:, :, channel]
        plt.imshow(single_channel)
        plt.xlim(0, 255)
        plt.xticks([])
        plt.yticks([])
    plt.show()

In [None]:
# Per-channel view of one Monet painting.
img_path = '../input/gan-getting-started/monet_jpg/0bd913dbc7.jpg'
channels_visualization(img_path)

**The "downsample" function takes the number of filters, the kernel size and a flag saying whether instance normalization is applied, and returns a keras.Sequential object.**

In [None]:
def downsample(filters, size, apply_instancenorm=True):
    """Build a stride-2 convolution block: Conv2D -> [InstanceNorm] -> LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_instancenorm: insert instance normalization after the convolution.

    Returns:
        A ``keras.Sequential`` holding the layers.
    """
    kernel_init = tf.random_normal_initializer(0., 0.04)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block = keras.Sequential()
    block.add(
        layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False,
        )
    )
    if apply_instancenorm:
        block.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    block.add(layers.LeakyReLU())
    return block

**The "upsample" function takes the number of filters, the kernel size and a flag saying whether dropout is applied, and returns a keras.Sequential object.**

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block.

    Layer order: Conv2DTranspose -> InstanceNorm -> [Dropout(0.5)] -> ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: insert a 50% dropout layer after the normalization.

    Returns:
        A ``keras.Sequential`` holding the layers.
    """
    kernel_init = tf.random_normal_initializer(0., 0.04)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block = keras.Sequential()
    block.add(
        layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False,
        )
    )
    block.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    if apply_dropout:
        block.add(layers.Dropout(0.5))
    block.add(layers.ReLU())
    return block

**Build the Generator**

In [None]:
# Number of channels of the generated image (RGB).
OUTPUT_CHANNELS = 3

def Generator_PM():
    """Build the U-Net style generator mapping a 256x256x3 image to a 256x256x3 image.

    The encoder halves the resolution eight times down to 1x1; the decoder
    doubles it back, concatenating the matching encoder output (skip
    connection) after every upsampling step. The final transposed convolution
    uses ``tanh``, so outputs lie in [-1, 1].
    """
    image_in = layers.Input(shape=[256,256,3])

    # Encoder; shapes are the block outputs (bs = batch size).
    encoder = [downsample(64, 4, apply_instancenorm=False)]  # (bs, 128, 128, 64)
    encoder += [
        downsample(n_filters, 4)
        # (bs, 64, 64, 128) ... (bs, 1, 1, 512)
        for n_filters in (128, 256, 512, 512, 512, 512, 512)
    ]

    # Decoder; channel counts in the comments are after the skip concatenation.
    decoder = [upsample(512, 4, apply_dropout=True) for _ in range(3)]  # (bs, 2..8, 2..8, 1024)
    decoder += [
        upsample(n_filters, 4)
        # (bs, 16, 16, 1024), (bs, 32, 32, 512), (bs, 64, 64, 256), (bs, 128, 128, 128)
        for n_filters in (512, 256, 128, 64)
    ]

    last_init = tf.random_normal_initializer(0., 0.02)
    to_image = layers.Conv2DTranspose(
        OUTPUT_CHANNELS,
        7,
        strides=2,
        padding='same',
        kernel_initializer=last_init,
        activation='tanh',
    )  # (bs, 256, 256, 3)

    # Run the encoder, remembering every intermediate output.
    x = image_in
    encoder_outputs = []
    for block in encoder:
        x = block(x)
        encoder_outputs.append(x)

    # Skip connections: every encoder output except the bottleneck, deepest first.
    skip_outputs = encoder_outputs[-2::-1]

    for block, skip in zip(decoder, skip_outputs):
        x = block(x)
        x = layers.Concatenate()([x, skip])

    x = to_image(x)

    return keras.Model(inputs=image_in, outputs=x)

In [None]:
# Instantiate a generator and render its layer graph with output shapes.
generator = Generator_PM()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

**Build the discriminator**

In [None]:
def Discriminator_PM():
    """Build the patch discriminator for 256x256x3 images.

    The model returns a grid of unnormalized scores (no final activation),
    one per image patch, rather than a single real/fake value.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    image_in = layers.Input(shape=[256, 256, 3], name='input_image')

    # Three stride-2 blocks: (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256)
    x = image_in
    for n_filters, use_norm in ((64, False), (128, True), (256, True)):
        x = downsample(n_filters, 4, use_norm)(x)

    x = layers.ZeroPadding2D()(x)  # (bs, 34, 34, 256)
    x = layers.Conv2D(
        512,
        7,
        strides=2,
        kernel_initializer=kernel_init,
        use_bias=False,
    )(x)  # (bs, 14, 14, 512): 'valid' padding, kernel 7, stride 2
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = layers.LeakyReLU()(x)

    x = layers.ZeroPadding2D()(x)  # (bs, 16, 16, 512)
    patch_scores = layers.Conv2D(
        1,
        7,
        strides=2,
        kernel_initializer=kernel_init,
    )(x)  # (bs, 5, 5, 1)

    return tf.keras.Model(inputs=image_in, outputs=patch_scores)

In [None]:
# Instantiate a discriminator and render its layer graph with output shapes.
discriminator_y = Discriminator_PM()
tf.keras.utils.plot_model(discriminator_y, show_shapes=True, dpi=64)

**Generator part:**

Starting from the photo, a simulation of a Monet painting is generated and later from this simulation an attempt is made to generate the original photo

Starting from the monet, a photo simulation is generated and later from this simulation an attempt is made to generate the original monet

**Discriminator part:**

Discriminator so that the fake photo looks like a real photo
Discriminator so that the monet fake looks like a Monet painting

In [None]:
# Create the four networks inside the strategy scope so their variables belong to it.
with tpu_strategy.scope():
    monet_generator = Generator_PM() # transforms photos to Monet-esque paintings
    photo_generator = Generator_PM() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator_PM() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator_PM() # differentiates real photos and generated photos

**Define the CycleGan class, which inherits from keras.Model. This allows us to override the train_step function used by the fit method, so that performance can be maximized when running on the TPU.**

In [None]:
class CycleGan(keras.Model):
    """CycleGAN wrapper that trains two generators and two discriminators together.

    Overrides ``train_step`` so that ``fit`` runs one joint update of all four
    networks per batch of (Monet, photo) images.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=20,
    ):
        """Store the four networks and the loss weight.

        Args:
            monet_generator: model used to turn photos into Monet-style images.
            photo_generator: model used to turn Monet images into photo-style images.
            monet_discriminator: model scoring real vs. generated Monet images.
            photo_discriminator: model scoring real vs. generated photos.
            lambda_cycle: weight passed to the cycle and identity loss functions.
        """
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle
        
    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Attach one optimizer per network and the four loss functions.

        Args:
            m_gen_optimizer: optimizer for the Monet generator.
            p_gen_optimizer: optimizer for the photo generator.
            m_disc_optimizer: optimizer for the Monet discriminator.
            p_disc_optimizer: optimizer for the photo discriminator.
            gen_loss_fn: called as ``gen_loss_fn(disc_output_on_fake)``.
            disc_loss_fn: called as ``disc_loss_fn(disc_output_on_real, disc_output_on_fake)``.
            cycle_loss_fn: called as ``cycle_loss_fn(real, cycled, lambda_cycle)``.
            identity_loss_fn: called as ``identity_loss_fn(real, same, lambda_cycle)``.
        """
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        
    def train_step(self, batch_data):
        """Run one update of all four networks.

        Args:
            batch_data: tuple ``(real_monet, real_photo)`` of image batches.

        Returns:
            Dict with the total generator losses and the discriminator losses
            computed for this step.
        """
        real_monet, real_photo = batch_data
        
        # persistent=True because gradient() is called four times on this tape below.
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # each generator applied to an image already in its target domain (identity mapping)
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator outputs on real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator outputs on generated images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)
            
            # adversarial generator losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # total cycle consistency loss (both directions), shared by both generators
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # total generator loss = adversarial + cycle + identity
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)

            # discriminator losses
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Gradients of each loss with respect to its own network's variables only
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)
        
        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Apply each set of gradients with the matching optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))
        
        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

In [None]:
with tpu_strategy.scope():
    def discriminator_loss(real, generated):
        """Average of the discriminator's loss on real and on generated images.

        Labels: 1 for real, 0 for generated. Inputs are logits; the
        cross-entropy is left unreduced (``Reduction.NONE``).
        """
        bce = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)
        return (loss_on_real + loss_on_generated) * 0.5

    def generator_loss(generated):
        """Unreduced cross-entropy of the discriminator's logits on generated images against the label 1 (real)."""
        bce = losses.BinaryCrossentropy(from_logits=True, reduction=losses.Reduction.NONE)
        return bce(tf.ones_like(generated), generated)

    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Cycle consistency loss: LAMBDA times the mean absolute difference
        between an image and its twice-transformed version."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_error

    def identity_loss(real_image, same_image, LAMBDA):
        """Identity loss: half of LAMBDA times the mean absolute difference
        between an image and the output of its own-domain generator."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * mean_abs_error

In [None]:
# One Adam optimizer per network, all with learning rate 2e-4 and beta_1 = 0.5,
# created inside the strategy scope.
with tpu_strategy.scope():
    monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

**Compile the CycleGAN model**

In [None]:
with tpu_strategy.scope():
    # Wrap the four networks; lambda_cycle keeps its default value.
    cycle_gan_model = CycleGan(
        monet_generator, photo_generator, 
        monet_discriminator, photo_discriminator
    )

    # Wire each network to its optimizer and register the four loss functions.
    cycle_gan_model.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )

**Fit the CycleGAN model**

In [None]:
# Train for the configured number of epochs. The dataset repeats forever, so
# steps_per_epoch fixes the epoch length: enough batches to cover the larger
# of the two image sets once.
cycle_gan_model.fit(
    data,
    epochs=EPOCHS,  # use the notebook-level constant instead of a duplicated literal
    steps_per_epoch=(max(monet_jpg, photo_jpg)//BATCH_SIZE),
)

**Predict and save generated images**

In [None]:
# Import the submodule explicitly: a bare `import PIL` does not load
# PIL.Image, so PIL.Image.fromarray would only work if another library had
# already imported it.
import PIL.Image

def predict_and_save(input_ds, generator_model, output_path):
    """Run the generator on every batch and save the first output image as JPEG.

    Files are named ``<output_path>1.jpg``, ``<output_path>2.jpg``, ... —
    ``output_path`` is used as a plain string prefix, so pass a trailing
    slash to write into a directory.

    Args:
        input_ds: iterable of image batches with values in [-1, 1].
        generator_model: callable model invoked as ``generator_model(batch, training=False)``.
        output_path: filename prefix for the saved images.
    """
    for i, img in enumerate(input_ds, start=1):
        prediction = generator_model(img, training=False)[0].numpy()  # first image of the batch
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)    # [-1, 1] -> 0..255
        PIL.Image.fromarray(prediction).save(f'{output_path}{i}.jpg')

In [None]:
import os
# Folder for the generated images; exist_ok=True lets this cell be re-run
# without failing because the folder already exists.
os.makedirs('../images/', exist_ok=True)

# Translate every photo (one per batch) into a Monet-style image and save it.
predict_and_save(load_dataset(photos_tfr).batch(1), monet_generator, '../images/')

**Create a zip folder with the generated images**

In [None]:
import shutil
# Creates /kaggle/working/images.zip from the contents of /kaggle/images.
# NOTE(review): the images are saved to the relative path '../images/', which is
# /kaggle/images only when the working directory is /kaggle/working — confirm.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")

In [None]:
# Count the regular files (not sub-directories) in the output folder.
print(f"Number of generated samples: {len([name for name in os.listdir('../images/') if os.path.isfile(os.path.join('../images/', name))])}")

**Display the generated samples for the test data**

In [None]:
def display_generated_samples(ds, model, n_samples):
    """Show input/generated image pairs side by side, one figure per sample.

    Args:
        ds: iterable of image batches with values in [-1, 1]; the first image
            of each batch is shown.
        model: model whose ``predict`` produces the generated batch.
        n_samples: number of batches to draw from ``ds``.
    """
    batches = iter(ds)
    for _ in range(n_samples):
        source_batch = next(batches)
        generated_batch = model.predict(source_batch)

        panels = (
            (121, "Input image", source_batch),
            (122, "Generated image", generated_batch),
        )
        for position, title, batch in panels:
            plt.subplot(position)
            plt.title(title)
            # Map [-1, 1] back to [0, 1] for display.
            plt.imshow(batch[0] * 0.5 + 0.5)
            plt.axis('off')
        plt.show()

In [None]:
# Show the first 7 photos next to their Monet-style translations.
display_generated_samples(load_dataset(photos_tfr).batch(1), monet_generator, 7)