In [None]:
import os

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Spatial size (in pixels) of every image; a 3-channel axis is appended where
# the full tensor shape is needed.
image_shape = [256, 256]

## Accelerator setup (TPU, with default-strategy fallback)

In [None]:
# Use a TPU when one can be resolved and initialised; otherwise fall back to
# TensorFlow's default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, and report the reason
    # instead of silently hiding it.
    print('TPU not available, using the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Lets tf.data choose the parallelism level for `map` calls.
AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

# Load Data

In [None]:
# Root folder of the input data; the TFRecord shards are matched by the two
# glob patterns built below.
data_path = '/kaggle/input/gan-getting-started'

monet_pattern = data_path + '/monet_tfrec/*.tfrec'
photo_pattern = data_path + '/photo_tfrec/*.tfrec'

monet_files = tf.io.gfile.glob(monet_pattern)
print('Monet TFRecord Files: ', len(monet_files))

photo_files = tf.io.gfile.glob(photo_pattern)
print('Photo TFRecord Files:', len(photo_files))

In [None]:
def print_tfrecord_content(example_proto):
    """Decode one serialized record into a ``tf.train.Example`` message.

    Despite its name this function prints nothing; it returns the parsed
    message so the caller can inspect it.

    Args:
        example_proto: an eager tensor holding the serialized record bytes
            (it must support ``.numpy()``).

    Returns:
        The ``tf.train.Example`` parsed from those bytes.
    """
    serialized = example_proto.numpy()
    return tf.train.Example.FromString(serialized)


# Peek at the first record of the first Monet shard to list its feature keys.
first_monet_shard = monet_files[0]
dataset = tf.data.TFRecordDataset([first_monet_shard])

for record in dataset.take(1):
    example = print_tfrecord_content(record)
    feature_names = example.features.feature.keys()
    print(feature_names)

In [None]:
def load_tfr(file):
    """Parse one serialized TFRecord example and return its decoded image.

    Args:
        file: a scalar string tensor containing one serialized example with
            the string features ``image_name``, ``image`` and ``target``.

    Returns:
        A float32 tensor of shape ``[*image_shape, 3]`` computed as
        ``pixel / 127.5 - 1`` from the decoded JPEG.
    """
    feature_spec = {
        'image_name': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(file, feature_spec)

    # Only the JPEG bytes are used; the name and target features are ignored.
    pixels = tf.image.decode_jpeg(parsed['image'], channels=3)
    pixels = tf.cast(pixels, tf.float32)
    pixels = (pixels / 127.5) - 1

    # Pin the static shape so downstream Keras layers know the input size.
    return tf.reshape(pixels, [*image_shape, 3])

# Both datasets decode their records with `load_tfr` and emit batches that
# contain a single image.
dataset_monet = (
    tf.data.TFRecordDataset(monet_files)
    .map(load_tfr, num_parallel_calls=AUTOTUNE)
    .batch(1)
)

dataset_photo = (
    tf.data.TFRecordDataset(photo_files)
    .map(load_tfr, num_parallel_calls=AUTOTUNE)
    .batch(1)
)

In [None]:
# Show the first Monet painting (left) next to the first photo (right).
# `x * 0.5 + 0.5` maps the stored [-1, 1] values back to [0, 1] for imshow.
for subplot_code, sample_ds in ((121, dataset_monet), (122, dataset_photo)):
    for sample_batch in sample_ds.take(1):
        plt.subplot(subplot_code)
        plt.imshow(sample_batch[0] * 0.5 + 0.5)
plt.show()

# CycleGAN

## Generator

In [None]:
def Res_block(in_x, f):
    """Apply two 3x3 conv + instance-norm stages and join the result with the input.

    NOTE: the skip connection is a channel-wise *concatenation*, not an
    addition, so the output has ``f`` more channels than ``in_x`` and the
    channel count grows with every stacked block.

    Args:
        in_x: input feature map (Keras tensor).
        f: number of filters of each of the two convolutions.

    Returns:
        The tensor ``Concatenate([branch, in_x])``.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)

    def conv3x3(tensor):
        # Stride-1 'same' convolution: keeps the spatial size unchanged.
        return layers.Conv2D(
            f, (3, 3), strides=1, padding='same', kernel_initializer=kernel_init
        )(tensor)

    branch = conv3x3(in_x)
    branch = tfa.layers.InstanceNormalization(axis=-1)(branch)
    branch = layers.Activation('relu')(branch)

    # The second stage has no activation before the skip is attached.
    branch = conv3x3(branch)
    branch = tfa.layers.InstanceNormalization(axis=-1)(branch)

    return layers.Concatenate()([branch, in_x])


def make_generator(res=9):
    """Build the image-to-image generator network.

    Layout: a 7x7 stem convolution, two stride-2 downsampling convolutions,
    ``res`` ``Res_block`` stages, two stride-2 transposed convolutions and a
    final 7x7 transposed convolution producing 3 channels.

    Args:
        res: number of ``Res_block`` stages applied at the bottleneck.

    Returns:
        A ``keras.Model`` taking a ``[*image_shape, 3]`` image and returning a
        3-channel image whose values lie in [-1, 1].
    """
    inp = layers.Input(shape=[*image_shape, 3])

    init = tf.random_normal_initializer(0., 0.02)

    # Stem: full-resolution 7x7 convolution.
    x = layers.Conv2D(64, (7, 7), strides=1, padding='same', kernel_initializer=init)(inp)
    x = tfa.layers.InstanceNormalization(axis=-1)(x)
    x = layers.Activation('relu')(x)

    # Encoder: each stride-2 convolution halves the spatial size.
    for filters in (128, 256):
        x = layers.Conv2D(filters, (3, 3), strides=2, padding='same', kernel_initializer=init)(x)
        x = tfa.layers.InstanceNormalization(axis=-1)(x)
        x = layers.Activation('relu')(x)

    # Bottleneck transformation blocks.
    for _ in range(res):
        x = Res_block(x, 256)

    # Decoder: each stride-2 transposed convolution doubles the spatial size.
    for filters in (128, 64):
        x = layers.Conv2DTranspose(filters, (3, 3), strides=2, padding='same', kernel_initializer=init)(x)
        x = tfa.layers.InstanceNormalization(axis=-1)(x)
        x = layers.Activation('relu')(x)

    x = layers.Conv2DTranspose(3, (7, 7), padding='same', kernel_initializer=init)(x)
    x = tfa.layers.InstanceNormalization(axis=-1)(x)
    # The training images are scaled to [-1, 1], so the output must be able to
    # go negative. The previous 'relu' clamped everything to >= 0 and left the
    # upper end unbounded; 'tanh' matches the data range.
    x_out = layers.Activation('tanh')(x)

    model = keras.Model(inputs=inp, outputs=x_out)
    return model
    

## Discriminator

In [None]:
def make_discriminator():
    """Build the convolutional discriminator.

    Five 4x4 stride-2 convolution stages (each followed by instance
    normalisation and LeakyReLU) are followed by a 4x4 convolution with a
    single filter. No final activation is applied.

    Returns:
        A ``keras.Model`` mapping a ``[*image_shape, 3]`` image to a
        one-channel score map.
    """
    image_in = layers.Input(shape=[*image_shape, 3])
    kernel_init = tf.random_normal_initializer(0, 0.02)

    features = image_in
    # Every stage halves the spatial resolution.
    for filters in (64, 128, 256, 512, 512):
        features = layers.Conv2D(
            filters, (4, 4), strides=2, padding='same', kernel_initializer=kernel_init
        )(features)
        features = tfa.layers.InstanceNormalization(axis=-1)(features)
        features = layers.LeakyReLU(alpha=0.2)(features)

    # Raw (unbounded) scores, one per spatial location.
    scores = layers.Conv2D(1, (4, 4), padding='same', kernel_initializer=kernel_init)(features)

    return keras.Model(inputs=image_in, outputs=scores)
    

In [None]:
# Create all four networks inside the strategy scope so their variables are
# placed according to the selected distribution strategy.
with strategy.scope():
    # Generators (created first, in the same order as before).
    monet_generator, photo_generator = make_generator(), make_generator()

    # Discriminators, one per image domain.
    monet_discriminator, photo_discriminator = make_discriminator(), make_discriminator()

In [None]:
# Grab the first single-image batch of each dataset (photo first, then Monet)
# to use as fixed samples.
pic_test, monet_test = (
    next(iter(source)) for source in (dataset_photo, dataset_monet)
)

In [None]:
# Run the generator once on the sample photo and show input and output side by
# side; `x * 0.5 + 0.5` rescales [-1, 1] values to [0, 1] for display.
to_monet = monet_generator(pic_test)

preview_panels = (("Original", pic_test), ("Monet-esque", to_monet))
for panel_index, (panel_title, panel_batch) in enumerate(preview_panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.title(panel_title)
    plt.imshow(panel_batch[0] * 0.5 + 0.5)
plt.show()

# Complete model

### Losses

In [None]:
with strategy.scope():

    def disc_loss(real, fake):
        """Least-squares discriminator loss.

        Pushes the discriminator's scores for real images towards 1 and its
        scores for generated images towards 0. The previous formulation,
        ``mean((real - fake)**2)``, was minimised when the discriminator gave
        real and fake images the *same* score, i.e. it trained the
        discriminator not to discriminate.

        Args:
            real: discriminator output for real images.
            fake: discriminator output for generated images.

        Returns:
            Scalar ``0.5 * (mean((1 - real)**2) + mean(fake**2))``.
        """
        real_term = tf.reduce_mean((tf.ones_like(real) - real) ** 2)
        fake_term = tf.reduce_mean(fake ** 2)
        # Halve the sum so the discriminator objective is on the same scale
        # as a single least-squares term.
        return 0.5 * (real_term + fake_term)

    def gen_loss(gen):
        """Least-squares generator loss.

        Args:
            gen: discriminator output for generated images.

        Returns:
            Scalar ``mean((1 - gen)**2)``; it is zero when every score is 1.
        """
        return tf.reduce_mean((tf.ones_like(gen) - gen)**2)

    def cycle_loss(real, cycled):
        """Mean absolute difference between an image and its round-trip reconstruction."""
        return tf.reduce_mean(tf.abs(real - cycled))

    def id_loss(og, gen):
        """Mean absolute difference between an image and the generator output for it."""
        return tf.reduce_mean(tf.abs(og - gen))

In [None]:
class CycleGAN(tf.keras.Model):
    """Couples two generators and two discriminators into one trainable model.

    ``monet_g`` is applied to photos and ``image_g`` to Monet paintings;
    ``monet_d`` scores Monet-domain images and ``image_d`` scores photo-domain
    images. Training is implemented in ``train_step`` and driven by ``fit``.
    """

    def __init__(self, monet_g, image_g, monet_d, image_d, input_shape=[256, 256, 3]):
        """Store the four sub-networks.

        Args:
            monet_g: generator applied to photos to produce Monet-style images.
            image_g: generator applied to Monet paintings to produce photos.
            monet_d: discriminator for Monet-domain images.
            image_d: discriminator for photo-domain images.
            input_shape: image shape; only stored on the instance.
        """
        super(CycleGAN, self).__init__()
        self.monet_g = monet_g
        self.image_g = image_g
        self.monet_d = monet_d
        self.image_d = image_d
        # Copy so that instances never share the mutable default list.
        self.in_shape = list(input_shape)

    def compile(self, monet_g_opt, image_g_opt, monet_d_opt, image_d_opt, cycle_loss, id_loss, g_loss, d_loss):
        """Attach one optimizer per network and the loss callables.

        Args:
            monet_g_opt, image_g_opt: optimizers for the two generators.
            monet_d_opt, image_d_opt: optimizers for the two discriminators.
            cycle_loss: callable ``(real, cycled) -> scalar``.
            id_loss: callable ``(original, generated) -> scalar``.
            g_loss: callable ``(disc_output_on_fake) -> scalar``.
            d_loss: callable accepting keywords ``real=`` and ``fake=``.
        """
        super(CycleGAN, self).compile()

        self.monet_g_opt = monet_g_opt
        self.image_g_opt = image_g_opt
        self.monet_d_opt = monet_d_opt
        self.image_d_opt = image_d_opt
        self.g_loss = g_loss
        self.d_loss = d_loss
        self.cycle_loss = cycle_loss
        self.id_loss = id_loss

    def train_step(self, batch):
        """Run one optimisation step on a ``(monet_batch, photo_batch)`` pair.

        Args:
            batch: tuple ``(real_monet, real_image)``.

        Returns:
            Dict with the total loss of each generator and the loss of each
            discriminator for this step.
        """
        real_monet, real_image = batch

        # Persistent: four separate gradient computations read this tape.
        with tf.GradientTape(persistent=True) as tp:

            # Translations into the other domain. `training=True` is passed
            # explicitly so any mode-dependent layer runs in training mode.
            gen_monet = self.monet_g(real_image, training=True)
            gen_image = self.image_g(real_monet, training=True)

            monet_d_real = self.monet_d(real_monet, training=True)
            monet_d_fake = self.monet_d(gen_monet, training=True)

            image_d_real = self.image_d(real_image, training=True)
            image_d_fake = self.image_d(gen_image, training=True)

            # Round trips back to the source domain.
            cycled_m = self.monet_g(gen_image, training=True)
            cycled_image = self.image_g(gen_monet, training=True)

            # Each generator applied to an image already in its target domain.
            id_m = self.monet_g(real_monet, training=True)
            id_i = self.image_g(real_image, training=True)

            # Losses

            # Adversarial
            monet_g_loss = self.g_loss(monet_d_fake)
            image_g_loss = self.g_loss(image_d_fake)

            # Cycle (weight 10 per direction, shared by both generators)
            cycle_loss = 10 * self.cycle_loss(real_image, cycled_image) + 10 * self.cycle_loss(real_monet, cycled_m)

            # Identity (weight 5)
            monet_id_loss = 5 * self.id_loss(real_monet, id_m)
            image_id_loss = 5 * self.id_loss(real_image, id_i)

            # Total G loss
            monet_g_total = monet_g_loss + cycle_loss + monet_id_loss
            image_g_total = image_g_loss + cycle_loss + image_id_loss

            # Disc loss
            monet_d_loss = self.d_loss(real=monet_d_real, fake=monet_d_fake)
            image_d_loss = self.d_loss(real=image_d_real, fake=image_d_fake)

        # Gradients
        monet_g_grad = tp.gradient(monet_g_total, self.monet_g.trainable_variables)
        image_g_grad = tp.gradient(image_g_total, self.image_g.trainable_variables)

        monet_d_grad = tp.gradient(monet_d_loss, self.monet_d.trainable_variables)
        image_d_grad = tp.gradient(image_d_loss, self.image_d.trainable_variables)

        # A persistent tape keeps its recorded resources until it is dropped.
        del tp

        self.monet_g_opt.apply_gradients(zip(monet_g_grad, self.monet_g.trainable_variables))
        self.image_g_opt.apply_gradients(zip(image_g_grad, self.image_g.trainable_variables))
        self.monet_d_opt.apply_gradients(zip(monet_d_grad, self.monet_d.trainable_variables))
        self.image_d_opt.apply_gradients(zip(image_d_grad, self.image_d.trainable_variables))

        return {'monet_g_loss': monet_g_total,
                'image_g_loss': image_g_total,
                'monet_d_loss': monet_d_loss,
                'image_d_loss': image_d_loss}
            
        
        

In [None]:
with strategy.scope():

    # Four independent Adam optimizers with identical settings, one per
    # network, created in the order: Monet G, photo G, Monet D, photo D.
    (
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
    ) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

    cycle_gan_model = CycleGAN(
        monet_g=monet_generator,
        image_g=photo_generator,
        monet_d=monet_discriminator,
        image_d=photo_discriminator,
    )

    # Wire the optimizers and the loss functions into the combined model.
    compile_settings = dict(
        monet_g_opt=monet_generator_optimizer,
        image_g_opt=photo_generator_optimizer,
        monet_d_opt=monet_discriminator_optimizer,
        image_d_opt=photo_discriminator_optimizer,
        g_loss=gen_loss,
        d_loss=disc_loss,
        cycle_loss=cycle_loss,
        id_loss=id_loss,
    )
    cycle_gan_model.compile(**compile_settings)

In [None]:
# Pair the datasets as (monet, photo), the order that `train_step` unpacks.
paired_batches = tf.data.Dataset.zip((dataset_monet, dataset_photo))
cycle_gan_model.fit(paired_batches, epochs=100)

In [None]:
# Persist the weights of the combined model to an HDF5 file in the working directory.
weights_file = 'Cycle_GAN_weights.h5'
cycle_gan_model.save_weights(weights_file)

In [None]:
# Show five photos (left column) next to their generated Monet-style versions
# (right column).
_, ax = plt.subplots(5, 2, figsize=(12, 12))
for i, img in enumerate(dataset_photo.take(5)):
    prediction = monet_generator(img, training=False)[0].numpy()
    # Clip before casting: a value outside [0, 255] would otherwise wrap
    # around when converted to uint8 and show up as speckle artefacts.
    prediction = np.clip(prediction * 127.5 + 127.5, 0, 255).astype(np.uint8)
    img = np.clip(img[0].numpy() * 127.5 + 127.5, 0, 255).astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 0].set_title("Input Photo")
    ax[i, 1].set_title("Monet-esque")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
plt.show()