# Introduction and Setup

Our problem can be framed as image-to-image translation: the goal is to learn a mapping $G: X \to Y$ that converts an image $x \in X$ into an image $y \in Y$. We build an architecture similar to CycleGAN's, using a U-Net architecture for the generator.

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import seaborn as sns
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
import os, random, json, PIL, shutil, re, imageio, glob
from tensorflow.keras.callbacks import Callback
from glob import glob
import cv2 

# Use a TPU distribution strategy when one is reachable, otherwise fall back to
# the default (single-device) strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as exc:
    # Best-effort fallback. Catch Exception rather than using a bare
    # ``except:`` so KeyboardInterrupt/SystemExit still propagate, and report
    # why the TPU path was skipped instead of failing silently.
    print('TPU not available, using default strategy:', exc)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Seed TensorFlow's global RNG to make runs more reproducible.
tf.random.set_seed(0)

print(tf.__version__)

# Load in the data

First, load the Monet paintings (stored as TFRecords).

In [None]:
# Resolve the GCS bucket that backs the competition dataset (needed for TPU reads).
GCS_PATH = KaggleDatasets().get_gcs_path(
    'monet-gan-getting-started'
)
def count_data_items(filenames):
    """Return the total number of records across TFRecord shards.

    Each shard's base name is expected to encode its record count as
    ``-<count>.`` (e.g. ``monet00-60.tfrec`` holds 60 records).

    Args:
        filenames: iterable of shard paths (local paths or ``gs://`` URLs).

    Returns:
        The summed record count as an ``int`` (``0`` for an empty input).

    Raises:
        ValueError: if a file name carries no ``-<count>.`` marker.
    """
    total = 0
    for filename in filenames:
        # Match only the base name so hyphens in bucket/directory names can't
        # be mistaken for the count; require at least one digit.
        match = re.search(r"-([0-9]+)\.", os.path.basename(filename))
        if match is None:
            raise ValueError(f"cannot read record count from file name: {filename!r}")
        total += int(match.group(1))
    return total

# All Monet TFRecord shards in the bucket, and the total number of records they hold.
MONET_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/monet_tfrec/*.tfrec')
n_monet_samples = count_data_items(MONET_FILENAMES)

In [None]:
# Image geometry: all images are handled as 256x256 RGB.
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMAGE_SIZE = [IMG_HEIGHT, IMG_WIDTH]

# Input-pipeline and training settings.
BUFFER_SIZE = 1000
BATCH_SIZE = 4
EPOCHS = 15
EPOCHS_NUM = EPOCHS  # legacy alias; training uses EPOCHS

def decode_image(image):
    """Decode a JPEG-encoded string tensor into a float32 image in [-1, 1].

    Args:
        image: scalar ``tf.string`` tensor holding JPEG bytes.

    Returns:
        A float32 tensor reshaped to ``[*IMAGE_SIZE, 3]``.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)  # uint8, 3 channels
    # Map uint8 [0, 255] to [-1, 1].
    scaled = tf.cast(pixels, tf.float32) / 127.5 - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized ``tf.train.Example`` and return its decoded image.

    The record must contain the string features ``image_name``, ``image``
    (JPEG bytes) and ``target``; only ``image`` is decoded and returned.

    Args:
        example: scalar string tensor holding one serialized Example.

    Returns:
        The float32 image produced by ``decode_image``.
    """
    feature_spec = {
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.string),
    }
    # parse_single_example returns a dict mapping feature keys to tensors.
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Return a dataset of decoded images read from TFRecord file(s).

    Args:
        filenames: path or list of paths accepted by ``tf.data.TFRecordDataset``.
        labeled: accepted for API compatibility; currently unused.
        ordered: accepted for API compatibility; currently unused.

    Returns:
        A ``tf.data.Dataset`` whose elements are the images from ``read_tfrecord``.
    """
    # num_parallel_calls=AUTOTUNE lets tf.data run the parse/decode map in parallel.
    return tf.data.TFRecordDataset(filenames).map(
        read_tfrecord, num_parallel_calls=AUTOTUNE
    )

In [None]:
# Unbatched Monet images, used to build the training pipeline.
monet_file = load_dataset(filenames=MONET_FILENAMES, labeled=True)
# The same images in batches of one, for previews and evaluation plots.
monet_ds = monet_file.batch(1)

Then we load the Van Gogh paintings (JPGs).

In [None]:
main_path = '/kaggle/input/cyclegan-model/van-gogh-paintings/'
# Collect every file inside each sub-directory of main_path.
style_img_paths = [
    img_path
    for class_name in os.listdir(main_path)
    for img_path in glob(os.path.join(main_path, class_name) + "/*")
]

print("There are {} style images in Van Gogh Paintings Dataset".format(len(style_img_paths)))

In [None]:
# Load every Van Gogh image as a 256x256 RGB uint8 array.
style_images = []
for style_path in style_img_paths:
    img = cv2.imread(style_path)
    if img is None:
        # cv2.imread returns None (instead of raising) for unreadable or
        # non-image files; skip them rather than crash inside cv2.resize.
        print(f"Skipping unreadable image: {style_path}")
        continue
    img = cv2.resize(img, (256, 256))
    # OpenCV decodes to BGR; convert to RGB to match tf.image.decode_jpeg output.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    style_images.append(img)

n_van_samples = len(style_images)
n_van_samples

In [None]:
# Convert to float32 and rescale pixels from [0, 255] to [-1, 1].
style_images = np.array(style_images, dtype=np.float32) / 127.5 - 1
# Unbatched Van Gogh images for training, plus a batch-of-one view for plots.
van_file = tf.data.Dataset.from_tensor_slices(style_images)
van_ds = van_file.batch(1)

In [None]:
def data_augment(image):
    """Randomly jitter, rotate and mirror a single image (applied via Dataset.map).

    Three independent uniform draws in [0, 1) decide which augmentations run:
      - p_crop > 0.5: resize to 286x286 and random-crop back to 256x256;
        if also p_crop > 0.9, repeat with a 300x300 resize.
      - p_rotate in (0.5, 0.7]: rot90 once; (0.7, 0.9]: twice; > 0.9: three
        times (tf.image.rot90 rotates counter-clockwise).
      - p_spatial > 0.6: random left-right and up-down flips; if also
        p_spatial > 0.9, transpose the image.

    When traced by Dataset.map, AutoGraph turns these tensor ``if``s into graph
    conditionals. NOTE(review): with a global seed set, op-level random seeds
    depend on op creation order, so reordering these statements would change
    the random streams.

    Args:
        image: a single image tensor; the crop sizes assume a 256x256x3 (HWC)
            input -- TODO confirm for every caller.

    Returns:
        The augmented image (256x256x3 for a square 256x256x3 input).
    """
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    # Apply jitter: upscale, then random-crop back to 256x256
    if p_crop > .5:
        image = tf.image.resize(image, [286, 286])
        image = tf.image.random_crop(image, size=[256, 256, 3])
        if p_crop > .9:
            # Stronger jitter: a second, larger upscale + crop
            image = tf.image.resize(image, [300, 300])
            image = tf.image.random_crop(image, size=[256, 256, 3])
    
    # Random rotation (at most one of these branches runs)
    if p_rotate > .9:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .7:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=1) # rotate 90º
    
    # Random mirroring
    if p_spatial > .6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial > .9:
            # Swap height and width axes
            image = tf.image.transpose(image)
    
    return image

In [None]:
def get_gan_dataset(monet_file, van_file, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Build the zipped ``(monet_batch, van_batch)`` training dataset.

    Args:
        monet_file: unbatched dataset of Monet images.
        van_file: unbatched dataset of Van Gogh images.
        augment: optional per-image function applied with ``Dataset.map``.
        repeat: if true, repeat each source dataset indefinitely.
        shuffle: if true, shuffle each source with a 1024-element buffer,
            reshuffling on every iteration.
        batch_size: batch size for both streams; incomplete batches are dropped.

    Returns:
        A ``tf.data.Dataset`` yielding ``(monet_batch, van_batch)`` pairs.
    """

    def _prepare(ds):
        """Apply cache/augment/repeat/shuffle/batch/prefetch to one source stream."""
        # Cache the raw images once, before augment/repeat. Caching after
        # repeat() would target an infinite stream: that cache never
        # completes, keeps every produced batch in memory and grows without
        # bound during training. Caching here also keeps the augmentation
        # fresh on every pass.
        ds = ds.cache()
        if augment:
            ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
        if repeat:
            ds = ds.repeat()
        if shuffle:
            ds = ds.shuffle(1024, reshuffle_each_iteration=True)
        ds = ds.batch(batch_size, drop_remainder=True)
        return ds.prefetch(AUTOTUNE)

    return tf.data.Dataset.zip((_prepare(monet_file), _prepare(van_file)))

In [None]:
# Paired, augmented, shuffled training stream of (Monet, Van Gogh) batches.
full_dataset = get_gan_dataset(
    monet_file,
    van_file,
    augment=data_augment,
    repeat=True,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
# Pull one (Monet batch, Van Gogh batch) pair to sanity-check the pipeline.
_sample_iter = iter(full_dataset)
example_monet, example_van_goph = next(_sample_iter)
del _sample_iter

Let's visualize some Monet and Van Gogh examples.

In [None]:
def view_image(ds, nrows=1, ncols=5):
    """Plot the first ``nrows * ncols`` batches of ``ds`` in a grid.

    Only the first image of each batch is shown. Pixel values are assumed to
    lie in [-1, 1] and are mapped to [0, 1] for display.

    Args:
        ds: dataset yielding image batches; it must yield at least
            ``nrows * ncols`` batches.
        nrows: number of grid rows.
        ncols: number of grid columns.
    """
    fig = plt.figure(figsize=(25, nrows * 5.05))  # (width, height)
    batches = iter(ds)
    for slot in range(1, nrows * ncols + 1):
        batch = next(batches).numpy()
        axis = fig.add_subplot(nrows, ncols, slot, xticks=[], yticks=[])
        axis.imshow(batch[0] * 0.5 + .5)

In [None]:
# 2 x 5 grid of Monet paintings.
view_image(monet_ds, nrows=2, ncols=5)

In [None]:
# 2 x 5 grid of Van Gogh paintings.
view_image(van_ds, nrows=2, ncols=5)

# Build the generator

Our problem can be framed as image-to-image translation: the goal is to learn a mapping $G: X \to Y$ that converts an image $x \in X$ into an image $y \in Y$. We build an architecture similar to CycleGAN's, using a U-Net architecture for the generator.
Our goal is to train two generator models using the CycleGAN architecture: one translates images from the input domain to the target domain, and the other translates them back. We also train two discriminator models, one per domain, to distinguish real images from generated ones. The ultimate goal is for the generators to produce images that the discriminators cannot distinguish from real ones. In addition, we enforce cycle consistency: an image translated to the target style and back again should match the original, i.e. $x \to G_X(x) \to G_Y(G_X(x)) \approx x$. Here $x$ is an image in the input domain, $G_X$ is the generator that produces a target-domain image from an input image, and $G_Y$ is the reverse generator that maps a target-domain image back to the input domain. The cycle-consistency loss encourages the final output, obtained by applying both generator transformations, to match the input image.


In [None]:
OUTPUT_CHANNELS = 3  # channel count of the generator's final layer

def downsample(filters, size, stride=2, apply_instancenorm=True):
    """Return a Conv2D -> (InstanceNormalization) -> LeakyReLU block.

    Args:
        filters: number of output channels of the convolution.
        size: convolution kernel size.
        stride: convolution stride; with 'same' padding, 2 halves height/width.
        apply_instancenorm: whether to insert InstanceNormalization after the conv.

    Returns:
        A ``keras.Sequential`` model holding the block's layers.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2D(filters, size, strides=stride, padding='same',
                      kernel_initializer=kernel_init, use_bias=False),
    ]
    if apply_instancenorm:
        block_layers.append(
            tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init))
    block_layers.append(layers.LeakyReLU())

    return keras.Sequential(block_layers)

In [None]:
def upsample(filters, size, stride=2, apply_dropout=False):
    """Return a Conv2DTranspose -> InstanceNorm -> (Dropout 0.5) -> ReLU block.

    Args:
        filters: number of output channels of the transposed convolution.
        size: kernel size.
        stride: stride; with 'same' padding, 2 doubles height/width.
        apply_dropout: whether to insert Dropout(0.5) before the ReLU.

    Returns:
        A ``keras.Sequential`` model holding the block's layers.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2DTranspose(filters, size, strides=stride,
                               padding='same',
                               kernel_initializer=kernel_init,
                               use_bias=False),
        tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init),
    ]
    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.ReLU())

    return keras.Sequential(block_layers)

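For the generator, we apply eight convolution layers for downsampling (a 7×7 kernel in the first layer and 4×4 kernels in the rest). Upsampling uses seven 4×4 transposed convolution layers, followed by a final 7×7 transposed convolution that produces the 256×256×3 output. Downsampling and upsampling outputs of the same spatial resolution are concatenated (skip connections) at the four highest-resolution decoder stages (16×16 up to 128×128).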

In [None]:
def Generator():
    """Build the U-Net-style generator (256x256x3 input -> 256x256x3 tanh output).

    The encoder is eight stride-2 ``downsample`` blocks (256 -> 1 spatially).
    The decoder is seven stride-2 ``upsample`` blocks (1 -> 128), followed by a
    7x7 stride-2 transposed convolution to 256x256 with ``OUTPUT_CHANNELS``
    channels and a tanh activation.

    Encoder activations are concatenated only onto the last four decoder
    stages (16x16, 32x32, 64x64, 128x128). The three deepest decoder stages
    (2x2, 4x4, 8x8) get no skip connection.
    """
    inputs = layers.Input(shape=[256, 256, 3])

    # bs = batch size; shapes are each block's output.
    encoder = [
        downsample(64, 7, apply_instancenorm=False),  # (bs, 128, 128, 64)
        downsample(128, 4),  # (bs, 64, 64, 128)
        downsample(256, 4),  # (bs, 32, 32, 256)
        downsample(512, 4),  # (bs, 16, 16, 512)
        downsample(512, 4),  # (bs, 8, 8, 512)
        downsample(512, 4),  # (bs, 4, 4, 512)
        downsample(512, 4),  # (bs, 2, 2, 512)
        downsample(512, 4),  # (bs, 1, 1, 512) bottleneck
    ]

    # Shapes after the (optional) skip concatenation.
    decoder = [
        upsample(512, 4, apply_dropout=True),  # (bs, 2, 2, 512)   no skip
        upsample(512, 4, apply_dropout=True),  # (bs, 4, 4, 512)   no skip
        upsample(512, 4, apply_dropout=True),  # (bs, 8, 8, 512)   no skip
        upsample(512, 4),  # (bs, 16, 16, 1024)
        upsample(256, 4),  # (bs, 32, 32, 512)
        upsample(128, 4),  # (bs, 64, 64, 256)
        upsample(64, 4),  # (bs, 128, 128, 128)
    ]

    output_layer = layers.Conv2DTranspose(
        OUTPUT_CHANNELS, 7,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        activation='tanh',
    )  # (bs, 256, 256, 3)

    # Encode, remembering every intermediate activation.
    h = inputs
    encoder_outputs = []
    for block in encoder:
        h = block(h)
        encoder_outputs.append(h)

    # Skip candidates: every encoder output except the bottleneck, deepest
    # first, so index k matches decoder stage k's spatial resolution.
    skip_candidates = encoder_outputs[-2::-1]
    first_skip_stage = 3  # decoder stages 0-2 receive no skip connection

    for stage, (block, skip) in enumerate(zip(decoder, skip_candidates)):
        h = block(h)
        if stage >= first_skip_stage:
            h = layers.Concatenate()([h, skip])

    return keras.Model(inputs=inputs, outputs=output_layer(h))

# Build the discriminator

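An overly powerful discriminator can overwhelm the generators and hurt the performance of the whole model. We therefore keep the discriminator a simple convolutional neural network that outputs a 30×30 grid of real/fake scores for each image.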

In [None]:
def Discriminator():
    """Build the convolutional discriminator: 256x256x3 image -> 30x30x1 score map.

    The final convolution has no activation, so the outputs are raw scores.
    """
    conv_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    image_in = layers.Input(shape=[256, 256, 3], name='input_image')

    # Three stride-2 downsampling blocks: 256 -> 128 -> 64 -> 32.
    h = downsample(64, 4, stride=2, apply_instancenorm=False)(image_in)  # (bs, 128, 128, 64)
    h = downsample(128, 4, stride=2)(h)  # (bs, 64, 64, 128)
    h = downsample(256, 4, stride=2)(h)  # (bs, 32, 32, 256)

    h = layers.ZeroPadding2D()(h)  # (bs, 34, 34, 256)
    h = layers.Conv2D(512, 4, strides=1,
                      kernel_initializer=conv_init,
                      use_bias=False)(h)  # (bs, 31, 31, 512)
    h = tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init)(h)
    h = layers.LeakyReLU()(h)

    h = layers.ZeroPadding2D()(h)  # (bs, 33, 33, 512)
    scores = layers.Conv2D(1, 4, strides=1,
                           kernel_initializer=conv_init)(h)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=image_in, outputs=scores)

In [None]:
# Build all four networks inside the distribution strategy's scope.
with strategy.scope():
    # Generators
    monet_generator = Generator()  # Van Gogh -> Monet style
    van_generator = Generator()  # Monet -> Van Gogh style

    # Discriminators: real vs. generated images of each domain
    monet_discriminator = Discriminator()  # Monet domain
    van_discriminator = Discriminator()  # Van Gogh domain

# Build the CycleGAN model



In [None]:
class CycleGan(keras.Model):
    """Keras model that jointly trains two generators and two discriminators.

    ``m_gen`` maps Van Gogh images to the Monet domain and ``p_gen`` maps
    Monet images to the Van Gogh domain. ``m_disc`` / ``p_disc`` score images
    of the Monet / Van Gogh domains. ``lambda_cycle`` is passed as the weight
    to the cycle-consistency and identity loss callables.
    """

    def __init__(
        self,
        monet_generator,
        van_generator,
        monet_discriminator,
        van_discriminator,
        lambda_cycle=10,
    ):
        super().__init__()
        self.m_gen = monet_generator
        self.p_gen = van_generator
        self.m_disc = monet_discriminator
        self.p_disc = van_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Store the four optimizers and the loss callables used by ``train_step``."""
        super().compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one joint update on a ``(real_monet, real_van)`` batch.

        Returns:
            Dict with the total Monet/Van Gogh generator losses and both
            discriminator losses, as produced by the loss callables.
        """
        real_monet, real_van = batch_data
        lam = self.lambda_cycle

        # persistent=True: the same tape is queried for four gradients below.
        with tf.GradientTape(persistent=True) as tape:
            # Van Gogh -> Monet -> Van Gogh
            fake_monet = self.m_gen(real_van, training=True)
            cycled_van = self.p_gen(fake_monet, training=True)

            # Monet -> Van Gogh -> Monet
            fake_van = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_van, training=True)

            # Identity mappings: each generator fed an image of its own output domain
            same_monet = self.m_gen(real_monet, training=True)
            same_van = self.p_gen(real_van, training=True)

            # Discriminator scores on real images, then on generated ones
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_van = self.p_disc(real_van, training=True)
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_van = self.p_disc(fake_van, training=True)

            # Adversarial terms for the generators
            monet_adv_loss = self.gen_loss_fn(disc_fake_monet)
            van_adv_loss = self.gen_loss_fn(disc_fake_van)

            # Cycle-consistency over both directions, shared by both generators
            total_cycle_loss = (
                self.cycle_loss_fn(real_monet, cycled_monet, lam)
                + self.cycle_loss_fn(real_van, cycled_van, lam)
            )

            total_monet_gen_loss = (
                monet_adv_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_monet, same_monet, lam)
            )
            total_van_gen_loss = (
                van_adv_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_van, same_van, lam)
            )

            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            van_disc_loss = self.disc_loss_fn(disc_real_van, disc_fake_van)

        # (loss, network, optimizer) for each of the four networks.
        updates = [
            (total_monet_gen_loss, self.m_gen, self.m_gen_optimizer),
            (total_van_gen_loss, self.p_gen, self.p_gen_optimizer),
            (monet_disc_loss, self.m_disc, self.m_disc_optimizer),
            (van_disc_loss, self.p_disc, self.p_disc_optimizer),
        ]
        # Compute every gradient first, then apply them all.
        all_grads = [tape.gradient(loss, net.trainable_variables)
                     for loss, net, _ in updates]
        for (_, net, optimizer), grads in zip(updates, all_grads):
            optimizer.apply_gradients(zip(grads, net.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "van_gen_loss": total_van_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "van_disc_loss": van_disc_loss
        }

# Loss functions

Because CycleGAN trains on unpaired data, there is no guarantee that an input $x$ and a target $y$ correspond to each other during training. The authors of the original paper therefore add a cycle-consistency loss (the reconstructed image should be close to the original input) so that the network learns a meaningful mapping.

For the generator's adversarial loss, we compare the discriminator's output on generated images with the "real" label (1), which pushes the generator to produce images the discriminator accepts as real. For the discriminator, we compare its output on real images with the real label and its output on generated images with the "fake" label (0), so that it learns to distinguish real from generated images.

We use binary cross-entropy (on logits) for these adversarial losses and Adam as the optimizer.


In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Half the sum of BCE(real scores -> 1) and BCE(generated scores -> 0).

        Both inputs are raw discriminator scores (``from_logits=True``).
        Reduction is disabled, so BCE averages only over the last axis and the
        result is not reduced to a scalar.
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        real_term = bce(tf.ones_like(real), real)
        fake_term = bce(tf.zeros_like(generated), generated)
        return (real_term + fake_term) * 0.5

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Adversarial generator loss: BCE between discriminator scores and "real" (1).

        ``generated`` are raw discriminator scores on generated images. The loss
        is unreduced (BCE averages only over the last axis).
        """
        real_labels = tf.ones_like(generated)
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        return bce(real_labels, generated)

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Cycle-consistency loss: ``LAMBDA`` times the mean absolute (L1) error."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_error

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Identity loss: ``LAMBDA * 0.5`` times the mean absolute (L1) error."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * mean_abs_error

# Train the CycleGAN



Following the original paper, we keep the learning rate constant for most of training and then decay it linearly to zero, since a small learning rate toward the end of training is helpful.

In [None]:
@tf.function
def linear_schedule_with_warmup(step):
    """Return the learning rate for ``step``: warmup, constant hold, then linear decay.

    - step < warmup_steps: linear ramp from lr_start to lr_max (a no-op with
      the current constants, since lr_start == lr_max);
    - next 80% of total_steps: constant lr_max;
    - afterwards: linear decay that reaches 0 at total_steps, clamped below
      at lr_min.

    total_steps = EPOCHS * (max(n_monet_samples, n_van_samples) // BATCH_SIZE),
    read from module globals when the function is traced.
    """
    lr_start = 2e-4
    lr_max = 2e-4
    lr_min = 0.

    steps_per_epoch = int(max(n_monet_samples, n_van_samples) // BATCH_SIZE)
    total_steps = EPOCHS * steps_per_epoch
    warmup_steps = 1
    hold_max_steps = total_steps * 0.8
    decay_steps = total_steps - warmup_steps - hold_max_steps

    if step < warmup_steps:
        lr = (lr_max - lr_start) / warmup_steps * step + lr_start
    elif step < warmup_steps + hold_max_steps:
        lr = lr_max
    else:
        lr = lr_max * ((total_steps - step) / decay_steps)
        lr = tf.math.maximum(lr_min, lr)

    return lr

# Plot the learning-rate schedule, sampled every 50 steps.
steps_per_epoch = int(max(n_monet_samples, n_van_samples)//BATCH_SIZE)
total_steps = EPOCHS * steps_per_epoch
rng = list(range(0, total_steps, 50))
# Feed float32 scalar tensors rather than Python ints:
# - a @tf.function retraces for every new Python scalar, which would build one
#   graph per plotted point;
# - the optimizers call the schedule with float32 steps, so this plots exactly
#   the values used in training.
y = [float(linear_schedule_with_warmup(tf.constant(step, dtype=tf.float32)))
     for step in rng]

sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(20, 6))
ax.plot(rng, y)
print(f'{EPOCHS} total epochs and {steps_per_epoch} steps per epoch')
print(f'Learning rate schedule: {y[0]:.3g} to {max(y):.3g} to {y[-1]:.3g}')

In [None]:
# Callbacks
class GANMonitor(Callback):
    """Keras callback that saves sample translations and periodic model checkpoints.

    After every epoch it:
    - translates the first ``num_img`` batches of ``van_ds`` with
      ``monet_generator`` and writes the PNGs to ``monet_path``;
    - translates the first ``num_img`` batches of ``monet_ds`` with
      ``van_generator`` and writes the PNGs to ``van_path``.

    Every ``N``-th call (counting from 0), it also saves all four networks as
    ``.h5`` files in the working directory.

    NOTE: networks and datasets are read from module-level globals, not from
    ``self.model``.
    """

    def __init__(self, num_img=1, monet_path='monet', van_path='van'):
        # Let keras Callback initialise its own attributes (e.g. ``model``)
        # before adding ours.
        super().__init__()
        self.num_img = num_img
        self.monet_path = monet_path
        self.van_path = van_path
        self.epoch = 0  # number of on_epoch_end calls seen so far
        self.N = 10  # checkpoint period, in epochs
        # Create the output directories; exist_ok avoids a check-then-create race.
        os.makedirs(self.monet_path, exist_ok=True)
        os.makedirs(self.van_path, exist_ok=True)

    @staticmethod
    def _save_predictions(generator, source_ds, num_img, out_dir, epoch):
        """Translate the first ``num_img`` batches of ``source_ds`` and save them as PNGs."""
        # ``import PIL`` alone does not guarantee the PIL.Image submodule is loaded.
        from PIL import Image

        for i, img in enumerate(source_ds.take(num_img)):
            prediction = generator(img, training=False)[0].numpy()
            # Map values in [-1, 1] back to 8-bit pixel values.
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            Image.fromarray(prediction).save(f'{out_dir}/generated_{i}_{epoch+1}.png')

    def on_epoch_end(self, epoch, logs=None):
        """Save this epoch's sample translations and, every ``N`` epochs, the models."""
        # Monet-style images generated from Van Gogh inputs
        self._save_predictions(monet_generator, van_ds, self.num_img, self.monet_path, epoch)
        # Van Gogh-style images generated from Monet inputs
        self._save_predictions(van_generator, monet_ds, self.num_img, self.van_path, epoch)

        if self.epoch % self.N == 0:
            checkpoints = (
                (monet_generator, 'monet_generator%02d.h5'),
                (van_generator, 'van_generator%02d.h5'),
                (monet_discriminator, 'monet_discriminator%02d.h5'),
                (van_discriminator, 'van_discriminator%02d.h5'),
            )
            for network, name_pattern in checkpoints:
                network.save(name_pattern % self.epoch)
        self.epoch += 1

In [None]:
with strategy.scope():
    # Learning rates are zero-argument callables. Each looks up its optimizer
    # by global name when called, so it may be defined before that optimizer
    # is assigned just below.
    def lr_monet_gen():
        """Scheduled learning rate at the Monet generator optimizer's current step."""
        return linear_schedule_with_warmup(tf.cast(monet_generator_optimizer.iterations, tf.float32))

    def lr_van_gen():
        """Scheduled learning rate at the Van Gogh generator optimizer's current step."""
        return linear_schedule_with_warmup(tf.cast(van_generator_optimizer.iterations, tf.float32))

    monet_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_monet_gen, beta_1=0.5)
    van_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_van_gen, beta_1=0.5)

    # Discriminator optimizers use the same schedule.
    def lr_monet_disc():
        """Scheduled learning rate at the Monet discriminator optimizer's current step."""
        return linear_schedule_with_warmup(tf.cast(monet_discriminator_optimizer.iterations, tf.float32))

    def lr_van_disc():
        """Scheduled learning rate at the Van Gogh discriminator optimizer's current step."""
        return linear_schedule_with_warmup(tf.cast(van_discriminator_optimizer.iterations, tf.float32))

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_monet_disc, beta_1=0.5)
    van_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_van_disc, beta_1=0.5)

In [None]:
with strategy.scope():
    # Wire the four networks together and attach optimizers and losses.
    cycle_gan_model = CycleGan(
        monet_generator=monet_generator,
        van_generator=van_generator,
        monet_discriminator=monet_discriminator,
        van_discriminator=van_discriminator,
    )

    cycle_gan_model.compile(
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=van_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=van_discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )

In [None]:
with strategy.scope():
    # One "epoch" is long enough to cover the larger of the two datasets.
    history = cycle_gan_model.fit(
        full_dataset,
        callbacks=[GANMonitor()],
        epochs=EPOCHS,
        steps_per_epoch=(max(n_monet_samples, n_van_samples) // BATCH_SIZE),
        verbose=1,
    ).history

In [None]:
# Show the loss names recorded in the training history (the keys returned by CycleGan.train_step).
history.keys()

# Plot Loss

In [None]:
def _epoch_mean_losses(hist, key):
    """Average each epoch's entry for ``key`` over every non-epoch axis.

    Works for scalar per-epoch losses (shape ``(epochs,)``) as well as
    unreduced per-patch losses (e.g. ``(epochs, bs, 30, 30)``).
    """
    values = np.asarray(hist[key])
    return values.mean(axis=tuple(range(1, values.ndim)))


def _loss_curve(key):
    """Per-epoch mean of ``key`` for this run, followed by ``temphistoy``'s if it exists."""
    curve = _epoch_mean_losses(history, key)
    # ``temphistoy`` (presumably a history dict from another training session)
    # is not defined in the visible notebook; referencing it unconditionally
    # raised NameError, so only append it when it is present.
    previous = globals().get('temphistoy')
    if previous is not None:
        curve = np.append(curve, _epoch_mean_losses(previous, key))
    return curve


plt.plot(_loss_curve('monet_gen_loss'))
plt.plot(_loss_curve('van_gen_loss'))
plt.title('Generator Loss')
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.legend(['Van Gogh', 'photo'], loc='upper left')
plt.show()


plt.plot(_loss_curve('monet_disc_loss'))
plt.plot(_loss_curve('van_disc_loss'))
plt.title('Discriminator Loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['Van Gogh', 'photo'], loc='upper left')
plt.show()

# Visualize the Training Result

In [None]:
def evaluate_cycle(ds, generator_a, generator_b, n_samples=1):
    """Plot input, A-translation and A-then-B reconstruction for ``n_samples`` batches.

    Args:
        ds: dataset of image batches with values in [-1, 1]; the first image of
            each batch is shown. It must yield at least ``n_samples`` batches.
        generator_a: model applied to each input batch.
        generator_b: model applied to ``generator_a``'s output.
        n_samples: number of rows (one batch per row).
    """
    fig, axes = plt.subplots(n_samples, 3, figsize=(22, n_samples * 6))
    axes = axes.flatten()
    titles = ('Input image', 'Generated image', 'Cycled image')

    batches = iter(ds)
    for row in range(n_samples):
        source = next(batches)
        translated = generator_a.predict(source)
        cycled = generator_b.predict(translated)

        for col, (title, images) in enumerate(zip(titles, (source, translated, cycled))):
            panel = axes[row * 3 + col]
            panel.set_title(title, fontsize=18)
            panel.imshow(images[0] * 0.5 + 0.5)  # [-1, 1] -> [0, 1]
            panel.axis('off')

    plt.show()

In [None]:
# Van Gogh -> Monet -> Van Gogh on five samples.
evaluate_cycle(van_ds.take(5), generator_a=monet_generator, generator_b=van_generator, n_samples=5)

In [None]:
# Monet -> Van Gogh -> Monet on five samples.
evaluate_cycle(monet_ds.take(5), generator_a=van_generator, generator_b=monet_generator, n_samples=5)

# Saving the Model

In [None]:
# Save the four trained networks to the working directory.
for _net, _h5_name in (
    (monet_generator, 'vanphoto_generator.h5'),
    (van_generator, 'photovan_generator.h5'),
    (monet_discriminator, 'vanphoto_discriminator.h5'),
    (van_discriminator, 'photovan_discriminator.h5'),
):
    _net.save(_h5_name)

# Visualize Trained Model for the Final Report

Below we load our previously trained models and use them to visualize the results.

In [None]:
with strategy.scope():
    _model_dir = '/kaggle/input/cyclegan-model/'
    # Monet-side models
    monet_generator = keras.models.load_model(_model_dir + 'monet_generator30.h5')
    photo_monet_generator = keras.models.load_model(_model_dir + 'photo_generator30.h5')
    monet_discriminator = keras.models.load_model(_model_dir + 'monet_discriminator30.h5')
    photo_monet_discriminator = keras.models.load_model(_model_dir + 'photo_discriminator30.h5')
    # Van Gogh-side models
    van_generator = keras.models.load_model(_model_dir + 'vanphoto_generator1.h5')
    photo_van_generator = keras.models.load_model(_model_dir + 'photovan_generator1.h5')
    van_discriminator = keras.models.load_model(_model_dir + 'vanphoto_discriminator1.h5')
    photo_van_monet_discriminator = keras.models.load_model(_model_dir + 'photovan_discriminator1.h5')

In [None]:
# 25 Monet samples through the loaded photo_monet / photo_van generators.
evaluate_cycle(monet_ds.take(25), generator_a=photo_monet_generator, generator_b=photo_van_generator, n_samples=25)

In [None]:
# 25 Van Gogh samples through the loaded van / monet generators.
evaluate_cycle(van_ds.take(25), generator_a=van_generator, generator_b=monet_generator, n_samples=25)