<a href="https://www.kaggle.com/victororlov/art-landscape-colorization-unet-cgan?scriptVersionId=89572082" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# from tensorflow import tfds
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
import math
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import lab2rgb
import seaborn as sns
# Prefer a TPU when one is attached; otherwise fall back to the default
# distribution strategy so the notebook still runs on CPU/GPU.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` (rather than a bare `except:`) lets KeyboardInterrupt and
    # SystemExit propagate, and the reason for the fallback is reported.
    print('TPU not available, using the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Lets tf.data pick the parallelism of `map` calls.
AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

In [None]:
!wget https://raw.githubusercontent.com/tensorflow/io/v0.20.0/tensorflow_io/python/experimental/color_ops.py

In [None]:
from color_ops import rgb_to_lab

In [None]:
!pip install segmentation_models

In [None]:
import segmentation_models

# Load in the data

In [None]:
# Colour representation used by the pipeline: 'lab' selects the L*a*b* branch
# of `load` / `get_image`; any other value selects their RGB branch.
COLOR_MODEL = 'lab'
# Side length, in pixels, that every image is resized to by `load`.
IMAGE_SIZE = 224

In [None]:
def load(image_file):
    """Read one JPEG file and split it into a (model input, target) pair.

    The file is decoded to 3 channels, converted to float32 and resized to
    IMAGE_SIZE x IMAGE_SIZE.

    Args:
        image_file: path of the JPEG file to read.

    Returns:
        In 'lab' mode: ``(lightness, color)`` where ``lightness`` is the first
        Lab channel as ``L / 50 - 1`` with a trailing channel axis, and
        ``color`` is the remaining two channels as ``ab / 100 * 1.3``.
        Otherwise: ``(grayscale, rgb * 2 - 1)``.
    """
    raw_bytes = tf.io.read_file(image_file)
    rgb = tf.image.decode_jpeg(raw_bytes, channels=3)
    rgb = tf.image.convert_image_dtype(rgb, dtype=tf.float32)
    rgb = tf.image.resize(rgb, [IMAGE_SIZE, IMAGE_SIZE])

    if COLOR_MODEL != 'lab':
        grayscale = tf.image.rgb_to_grayscale(rgb)
        return grayscale, rgb * 2 - 1

    lab = rgb_to_lab(rgb)
    # Slicing with `:1` keeps the channel axis, so no extra newaxis is needed.
    lightness = lab[:, :, :1] / 50 - 1
    color = lab[:, :, 1:] / 100 * 1.3
    return lightness, color

In [None]:
# Resolve the Google Cloud Storage path of the Kaggle dataset
# 'landscape-pictures'.
GCS_PATH = KaggleDatasets().get_gcs_path('landscape-pictures')
# GCS_PATH = 'gs://tfds-data/datasets/coco'

In [None]:
# List every JPEG directly under the dataset root. The files are plain JPEG
# images (not Monet TFRecords), so the printed label says so.
FILENAMES = tf.io.gfile.glob(GCS_PATH + '/*.jpg')
print('Landscape JPEG files:', len(FILENAMES))

All the images for the competition are already sized to 256x256. As these images are RGB images, set the channel to 3. Additionally, we need to scale the images to a [-1, 1] scale. Because we are building a generative model, we don't need the labels or the image id so we'll only return the image from the TFRecord.

Define the function to extract the image from the files.

In [None]:
# NOTE(review): not referenced by the visible code, which batches with the
# literal sizes 1 and 8 instead.
BATCH_SIZE=1

In [None]:
# with strategy.scope():
# Dataset of file names; each one is read and converted by `load` in parallel.
dataset = tf.data.Dataset.from_tensor_slices(FILENAMES)
dataset = dataset.map(load, num_parallel_calls=AUTOTUNE)

In [None]:
# Second dataset: ArtStation landscape thumbnails. Note that this rebinds
# GCS_PATH and FILENAMES from the earlier cells.
GCS_PATH = KaggleDatasets().get_gcs_path('artstation-landscape-thumbnails')
FILENAMES = tf.io.gfile.glob(GCS_PATH + '/images/*.jpeg')
# The files are plain JPEG images (not Monet TFRecords).
print('ArtStation JPEG files:', len(FILENAMES))

art_dataset = tf.data.Dataset.from_tensor_slices(FILENAMES)
art_dataset = art_dataset.map(load, num_parallel_calls=AUTOTUNE)
# The first 100 files form the validation split; the rest are for training.
art_dataset_train = art_dataset.skip(100)
art_dataset_val = art_dataset.take(100)

In [None]:
def get_image(input):
    """Turn one batched (lightness, color) pair back into a displayable image.

    Only the first element of the batch is used.

    Args:
        input: ``(lightness, color)`` pair shaped like the batched output of
            ``load`` (or a generator prediction in place of ``color``).

    Returns:
        ``(image, lightness)``: an RGB image for ``plt.imshow`` and the 2-D
        lightness/grayscale plane of the first batch element.
    """
    lightness, color = input
    lightness = np.asarray(lightness)
    color = np.asarray(color)

    if COLOR_MODEL == 'lab':
        lab = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3))
        # Inverse of `load`: L was stored as L / 50 - 1.
        lab[:, :, :1] = lightness[0, ...] * 50 + 50
        # Inverse of `load`: a/b were stored as ab / 100 * 1.3, so both
        # factors must be undone (previously the 1.3 was left in).
        lab[:, :, 1:] = color[0, ...] / 1.3 * 100
        return lab2rgb(lab), lightness[0, ..., 0]

    # RGB mode: `load` stores the target as rgb * 2 - 1, so the data is
    # already RGB and must not go through lab2rgb. Clip because a generator
    # prediction may slightly overshoot the valid range.
    rgb = np.clip((color[0, ...] + 1) / 2, 0, 1)
    return rgb, lightness[0, ..., 0]

def color_hist(color):
    """Draw one seaborn distribution plot per channel (last axis) of `color`."""
    channel_count = color.shape[-1]
    for channel in range(channel_count):
        channel_values = color[..., channel]
        sns.distplot(channel_values)

In [None]:
# Iterate over the first two validation batches (batch size 1); after the loop
# `example_input`, `example_target`, `light` and `color` hold the second one.
for example_input, example_target in art_dataset_val.batch(1).take(2):
    light = example_input[:1,...]
    color = example_target[:1,...]

def check_images(light, color):
    """Plot the lightness plane next to the reconstructed colour image.

    Args:
        light: batched lightness/grayscale input.
        color: batched colour channels (ground truth or generator output).
    """
    image, lightness = get_image((light, color))
    plt.figure(figsize=(10, 10))
    # (pixels, extra imshow options) for the left and right panel.
    panels = (
        (lightness, {'cmap': 'gray'}),
        (image, {}),
    )
    for position, (pixels, imshow_options) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(tf.squeeze(pixels), **imshow_options)
    
# Sanity check: display the example input next to its ground-truth colours.
check_images(light, color)

# Build the generator

We'll be using a UNET architecture for our CycleGAN. To build our generator, let's first define our `downsample` and `upsample` methods.

The `downsample`, as the name suggests, reduces the 2D dimensions, the width and height, of the image by the stride. The stride is the length of the step the filter takes. Since the stride is 2, the filter is applied to every other pixel, hence reducing the width and height by a factor of 2.

We'll be using an instance normalization instead of batch normalization. As the instance normalization is not standard in the TensorFlow API, we'll use the layer from TensorFlow Add-ons.

In [None]:
# Channels the generator predicts: 2 in 'lab' mode, 3 otherwise.
OUTPUT_CHANNELS = 2 if COLOR_MODEL == 'lab' else 3

In [None]:
# Make segmentation_models build its networks with tf.keras.
segmentation_models.set_framework('tf.keras')

In [None]:
def Generator():
    """Build the colourisation generator.

    A segmentation_models U-Net with a 'resnet18' encoder (ImageNet weights)
    and a tanh output of OUTPUT_CHANNELS channels. The encoder expects three
    input channels, so the single-channel input is concatenated with itself
    three times before being passed to the U-Net.

    Returns:
        A ``tf.keras.Model`` mapping ``(IMAGE_SIZE, IMAGE_SIZE, 1)`` inputs to
        the U-Net output.
    """
    backbone = segmentation_models.Unet(
        'resnet18',
        encoder_weights='imagenet',
        classes=OUTPUT_CHANNELS,
        activation='tanh',
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
        decoder_use_batchnorm=False,
    )
    gray_input = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 1], name='input')
    three_channel = layers.Concatenate()([gray_input, gray_input, gray_input])
    predicted = backbone(three_channel)
    return tf.keras.Model(inputs=gray_input, outputs=predicted)


# Create the generator inside the strategy scope so its variables are created
# under `strategy`.
with strategy.scope():
    generator = Generator()

In [None]:
# Generator().summary()
# Render the generator's layer graph, including tensor shapes.
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
# Run the generator in inference mode on the currently bound `example_input`
# and display the prediction next to the input.
gen_out = generator(example_input, training=False)
check_images(example_input, gen_out)

In [None]:
# Plot the per-channel value distribution of the generator output.
color_hist(gen_out)

In [None]:
# loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
# NOTE(review): `loss_object` is not passed to `compile` below, which uses the
# string loss 'mae'; the name is rebound later by the GAN loss cells.
loss_object = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
# Pre-training setup: Adam (lr 2e-4, beta_1 0.5) with mean absolute error.
generator.compile(optimizer=tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss='mae')

In [None]:
# Pre-train the generator with the L1 objective.
# `drop_remainder=True` keeps only full batches of 8 (the same elements the
# former `take(len // 8 * 8).batch(8)` selected) and gives the batch
# dimension a static size, without needing the dataset length.
generator.fit(
    art_dataset_train.batch(8, drop_remainder=True),
    epochs=20,
)

In [None]:
# Colourise every validation image (one per batch) and display each result.
# Rebinds the globals `example_input`, `example_target` and `gen_out`.
with strategy.scope():
    for example_input, example_target in art_dataset_val.batch(1):
        gen_out = generator(example_input, training=False)
        check_images(example_input, gen_out)

In [None]:
# Save/load options that set the I/O device to "/job:localhost".
# NOTE(review): presumably needed so that a TPU run writes to the notebook's
# local disk - confirm against the TensorFlow SaveOptions documentation.
localhost_save_option = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
localhost_load_option = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
# Persist the L1-pre-trained generator under 'generator_lab_deterministic'.
generator.save('generator_lab_deterministic', options=localhost_save_option)

In [None]:
# with strategy.scope():
#     generator = tf.keras.models.load_model('generator_lab_deterministic', options=localhost_load_option)

In [None]:
import shutil
# Zip the saved-model directory into 'generator_lab_deterministic.zip'.
shutil.make_archive('generator_lab_deterministic', 'zip', './generator_lab_deterministic')

# Build the discriminator

The discriminator takes in the input image and classifies it as real or fake (generated). Instead of outputting a single node, the discriminator outputs a smaller 2D image with higher pixel values indicating a real classification and lower values indicating a fake classification.

In [None]:
def downsample(filters, size, apply_instancenorm=True, name=None):
    """Build a stride-2 convolution block.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_instancenorm: when true, add instance normalisation after the
            convolution.
        name: optional name for the returned Sequential model.

    Returns:
        ``keras.Sequential`` of Conv2D (stride 2, 'same' padding, no bias),
        optional InstanceNormalization, and LeakyReLU.
    """
    block_layers = [
        layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
    ]

    if apply_instancenorm:
        gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        block_layers.append(
            tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)
        )

    block_layers.append(layers.LeakyReLU())
    return keras.Sequential(block_layers, name=name)


In [None]:
def upsample(filters, size, apply_dropout=False, name=None):
    """Build a stride-2 transposed-convolution block.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: when true, add ``Dropout(0.5)`` after normalisation.
        name: optional name for the returned Sequential model.

    Returns:
        ``keras.Sequential`` of Conv2DTranspose (stride 2, 'same' padding, no
        bias), InstanceNormalization, optional Dropout, and ReLU.
    """
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    block_layers = [
        layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        tfa.layers.InstanceNormalization(gamma_initializer=gamma_init),
    ]

    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))

    block_layers.append(layers.ReLU())
    return keras.Sequential(block_layers, name=name)



In [None]:

def Discriminator():
    """Build the patch discriminator.

    Takes the single-channel input image and a colour target (real or
    generated), concatenates them along the channel axis and reduces them to
    a one-channel map of per-patch logits.

    Returns:
        A ``tf.keras.Model`` with inputs ``[input, target]`` and the logit
        map as output.
    """
    # NOTE(review): stddev 5 (kernels) and 0.2 (gamma) are far larger than the
    # 0.02 used by `downsample`/`upsample` - confirm this is intentional.
    initializer = tf.random_normal_initializer(0., 5)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.2)

    inp = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 1], name='input')
    tar = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, OUTPUT_CHANNELS], name='target')
    x = layers.Concatenate()([tar, inp])

    # Shape comments assume IMAGE_SIZE = 224.
    down1 = downsample(64, 4, False)(x)   # (bs, 112, 112, 64)
    down2 = downsample(128, 4)(down1)     # (bs, 56, 56, 128)
    # 256 filters: the filter count is a network width and must not follow
    # the image size (it was previously written as IMAGE_SIZE).
    down3 = downsample(256, 4)(down2)     # (bs, 28, 28, 256)

    zero_pad1 = layers.ZeroPadding2D()(down3)  # (bs, 30, 30, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1)  # (bs, 27, 27, 512)

    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu)  # (bs, 29, 29, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2)  # (bs, 26, 26, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)

In [None]:
# from keras import backend as K 
# K.clear_session()

# Create the discriminator inside the strategy scope so its variables are
# created under `strategy`.
with strategy.scope():
    discriminator = Discriminator()

In [None]:
# Render the discriminator's layer graph, including tensor shapes.
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

In [None]:
# Score the current generator output and show the discriminator's output map
# for the first batch element (last channel) as a heat map.
disc_out = discriminator([example_input, gen_out], training=False)
plt.imshow(disc_out[0, ..., -1], cmap='RdBu_r')
plt.colorbar()

# Build the CycleGAN model

We will subclass a `tf.keras.Model` so that we can run `fit()` later to train our model. During the training step, the model transforms a photo to a Monet painting and then back to a photo. The difference between the original photo and the twice-transformed photo is the cycle-consistency loss. We want the original photo and the twice-transformed photo to be similar to one another.

The losses are defined in the next section.

In [None]:
class CycleGan(keras.Model):
    """Generator/discriminator pair that can be trained with ``fit``.

    Despite the name, ``train_step`` computes no cycle-consistency term: it
    runs the generator on the input, scores the real and generated targets
    with the discriminator, and updates both networks from the loss functions
    supplied to ``compile``.
    """

    def __init__(self, generator, discriminator, lambda_cycle=10):
        """Store the two networks.

        Args:
            generator: model mapping the input image to the predicted target.
            discriminator: model called as ``discriminator([input, target])``.
            lambda_cycle: stored on the instance; not read by ``train_step``.
        """
        super().__init__()
        self.gen = generator
        self.disc = discriminator
        self.lambda_cycle = lambda_cycle

    def compile(self, gen_optimizer, disc_optimizer, gen_loss_fn, disc_loss_fn):
        """Attach the optimizers and loss functions used by ``train_step``.

        Args:
            gen_optimizer: optimizer for the generator variables.
            disc_optimizer: optimizer for the discriminator variables.
            gen_loss_fn: called as ``fn(disc_generated_output, gen_output,
                target)``; must return ``(total, gan, l1)`` losses.
            disc_loss_fn: called as ``fn(disc_real_output,
                disc_generated_output)``; must return the discriminator loss.
        """
        super().compile()
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn

    def train_step(self, batch_data):
        """Run one optimisation step for both networks.

        Args:
            batch_data: ``(input_image, target)`` pair.

        Returns:
            Dict with 'gen_total_loss', 'gen_gan_loss', 'gen_l1_loss' and
            'disc_loss'.
        """
        input_image, target = batch_data
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_output = self.gen(input_image, training=True)

            # Discriminator scores for the real and the generated target.
            disc_real_output = self.disc([input_image, target], training=True)
            disc_generated_output = self.disc([input_image, gen_output], training=True)

            gen_total_loss, gen_gan_loss, gen_l1_loss = self.gen_loss_fn(
                disc_generated_output, gen_output, target)
            disc_loss = self.disc_loss_fn(disc_real_output, disc_generated_output)

        # Gradients for the generator and the discriminator.
        generator_gradients = gen_tape.gradient(
            gen_total_loss, self.gen.trainable_variables)
        discriminator_gradients = disc_tape.gradient(
            disc_loss, self.disc.trainable_variables)

        # Use the optimizers handed to `compile`, not module-level globals, so
        # the model trains with whatever optimizers the caller configured.
        self.gen_optimizer.apply_gradients(
            zip(generator_gradients, self.gen.trainable_variables))
        self.disc_optimizer.apply_gradients(
            zip(discriminator_gradients, self.disc.trainable_variables))

        return {
            'gen_total_loss': gen_total_loss,
            'gen_gan_loss': gen_gan_loss,
            'gen_l1_loss': gen_l1_loss,
            'disc_loss': disc_loss,
        }

# Define loss functions

The discriminator loss function below compares real images to a matrix of 1s and fake images to a matrix of 0s. The perfect discriminator will output all 1s for real images and all 0s for fake images. The discriminator loss outputs the sum of the real and generated losses.

In [None]:
with strategy.scope():
    # Binary cross-entropy on logits, summed over all elements.
    loss_object = tf.keras.losses.BinaryCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.SUM,
    )

    def discriminator_loss(disc_real_output, disc_generated_output):
        """Return the summed loss for real (label 1) and generated (label 0) scores."""
        real_labels = tf.ones_like(disc_real_output)
        fake_labels = tf.zeros_like(disc_generated_output)
        return (loss_object(real_labels, disc_real_output)
                + loss_object(fake_labels, disc_generated_output))

The generator wants to fool the discriminator into thinking the generated image is real. The perfect generator will have the discriminator output only 1s. Thus, it compares the generated image to a matrix of 1s to find the loss.

In [None]:
with strategy.scope():
    # Weight of the L1 term relative to the adversarial term.
    LAMBDA = 100000
    # Binary cross-entropy on logits, summed over all elements.
    loss_object = tf.keras.losses.BinaryCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.SUM,
    )

    def generator_loss(disc_generated_output, gen_output, target):
        """Return ``(total, gan, l1)`` generator losses.

        ``gan`` scores the discriminator output against all-ones labels,
        ``l1`` is the mean absolute error between ``target`` and
        ``gen_output``, and ``total`` is ``gan + LAMBDA * l1``.
        """
        all_real = tf.ones_like(disc_generated_output)
        gan_loss = loss_object(all_real, disc_generated_output)
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
        return gan_loss + (LAMBDA * l1_loss), gan_loss, l1_loss

We want our original photo and the twice transformed photo to be similar to one another. Thus, we can calculate the cycle consistency loss by finding the average of their difference.

The identity loss compares the image with its generator (i.e. photo with photo generator). If given a photo as input, we want it to generate the same image as the image was originally a photo. The identity loss compares the input with the output of the generator.

# Train the CycleGAN

Let's compile our model. Since we used `tf.keras.Model` to build our CycleGAN, we can just use the `fit` function to train our model.

In [None]:
# One Adam optimizer per network (learning rate 2e-4, beta_1 0.5).
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Colourise one image from `dataset` and display it. Rebinds the globals
# `example_input`, `example_target` and `gen_out`.
with strategy.scope():
    for example_input, example_target in dataset.batch(1).take(1):
        gen_out = generator(example_input, training=False)
        check_images(example_input, gen_out)

In [None]:
class CustomCallback(keras.callbacks.Callback):
    """Keras callback that previews the generator output after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Colourise the global `example_input` with `self.model.gen` and plot it."""
        with strategy.scope():
            preview = self.model.gen(example_input, training=False)
            check_images(example_input, preview)

In [None]:
# Wrap both networks in the trainable GAN model and attach the optimizers and
# loss functions used by its training step.
with strategy.scope():
    cycle_gan_model = CycleGan(generator, discriminator)
    cycle_gan_model.compile(
        gen_optimizer=generator_optimizer,
        disc_optimizer=discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
    )

In [None]:
# Adversarial training. The dataset is already batched, so `batch_size` is
# not passed to `fit` (Keras does not accept it for tf.data inputs).
# `drop_remainder=True` keeps only full batches of 8, the same elements the
# former `take(len // 8 * 8).batch(8)` selected.
cycle_gan_model.fit(
    art_dataset_train.batch(8, drop_remainder=True),
    epochs=40,
)

In [None]:
# Colourise every validation image with the GAN-trained generator and display
# the result. The prediction is multiplied by 1.5 before display.
# NOTE(review): presumably the 1.5 factor is a manual saturation boost - confirm.
with strategy.scope():
    for example_input, example_target in art_dataset_val.batch(1):
        gen_out = generator(example_input*1, training=False)
        check_images(example_input, gen_out*1.5)

In [None]:
# Persist the GAN-trained generator and zip the saved-model directory into
# 'generator_lab_gan_art.zip'.
generator.save('generator_lab_gan_art', options=localhost_save_option)
shutil.make_archive('generator_lab_gan_art', 'zip', './generator_lab_gan_art')