<a href="https://www.kaggle.com/victororlov/photo-landscape-colorization-with-u-net-cgan?scriptVersionId=89476382" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Table of Contents

* [1. Preparation](#1)
    * [1.1 Import modules](#1_1)
    * [1.2 TPU setup](#1_2)
    * [1.3 Load and preprocess the data](#1_3)
* [2. Generator](#2)
    * [2.1 Make a generator](#2_1)
    * [2.2 Pretrain the generator](#2_2)
    * [2.3 Standalone generator results](#2_3)
    * [2.4 Save pretrained generator](#2_4)
* [3. Discriminator](#3)
* [4. cGAN](#4)
    * [4.1 Build the cGAN model](#4_1)
    * [4.2 Define cGAN loss functions](#4_2)
    * [4.3 Train the cGAN](#4_3)
    * [4.4 Final results](#4_4)    
    * [4.5 Save the final generator](#4_5)
* [5. References](#another_cell)

I took the core idea implemented here from [this article](https://towardsdatascience.com/colorizing-black-white-images-with-u-net-and-conditional-gan-a-tutorial-81b2df111cd8) on colorizing black-and-white photos. <br>
The author uses PyTorch and fast.ai, but I decided to go with TensorFlow, so I found the [official TensorFlow tutorial on the pix2pix cGAN implementation](https://www.tensorflow.org/tutorials/generative/pix2pix?hl=en) useful.<br>
I started my work by forking Amy Jang's brilliant [Monet CycleGAN Tutorial](https://www.kaggle.com/amyjang/monet-cyclegan-tutorial).

# 1. Preparation  <a class="anchor" id="1"></a>

## 1.1. Import modules  <a class="anchor" id="1_1"></a>

For compatibility reasons I decided to download the source file containing `rgb_to_lab` from GitHub instead of importing the `tensorflow_io` module:

In [None]:
!wget https://raw.githubusercontent.com/tensorflow/io/v0.20.0/tensorflow_io/python/experimental/color_ops.py

Install the [Segmentation Models](https://github.com/qubvel/segmentation_models) module. It makes it easy to build a U-Net on top of different pretrained backbone models such as ResNet18:

In [None]:
!pip install segmentation_models

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import math
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import lab2rgb
import seaborn as sns
from color_ops import rgb_to_lab

## 1.2. TPU setup  <a class="anchor" id="1_2"></a>

In [None]:
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTOTUNE = tf.data.experimental.AUTOTUNE
    
print(tf.__version__)

## 1.3. Load and preprocess the data <a class="anchor" id="1_3"></a>

In [None]:
GCS_PATH = KaggleDatasets().get_gcs_path('landscape-pictures')
print(GCS_PATH)

In [None]:
FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/*.jpg'))
print('Landscape Files:', len(FILENAMES))

In [None]:
COLOR_MODEL = 'lab'
BATCH_SIZE=1
IMAGE_SIZE = 256

As mentioned in [[1]](#another_cell), the "Lab" color space is the most popular choice for the colorization problem. <br>
The `L` channel defines perceptual lightness, and the `a` and `b` channels define color.
As our task is to predict the "color" image given the "colorless" image, it comes down to predicting the `a` and `b` channels given the `L` channel. <br>
Let's write a function to load the image in Lab space and normalize it to the [-1, 1] range:

In [None]:
def load(image_file):
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    
    image = rgb_to_lab(image)
    lightness = image[:,:,0]
    lightness = lightness/50-1
    lightness = lightness[...,tf.newaxis]
    color = image[:,:,1:]/100
    return lightness, color
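
To make the scaling in `load()` concrete, here is a small NumPy sketch of the same arithmetic and its inverse. It assumes the usual CIELAB ranges (`L` in [0, 100], `a` and `b` roughly within [-100, 100]). The helper names `normalize_lab` and `denormalize_lab` are hypothetical and not part of the notebook.

```python
import numpy as np

def normalize_lab(lab):
    """Scale a (H, W, 3) Lab array the same way load() does (hypothetical helper)."""
    lightness = lab[..., :1] / 50.0 - 1.0   # L: [0, 100] -> [-1, 1]
    color = lab[..., 1:] / 100.0            # a, b: roughly [-100, 100] -> [-1, 1]
    return lightness, color

def denormalize_lab(lightness, color):
    """Invert normalize_lab (hypothetical helper)."""
    return np.concatenate([lightness * 50.0 + 50.0, color * 100.0], axis=-1)

# A single-pixel "image" with L=75, a=20, b=-40 (made-up values)
lab = np.array([[[75.0, 20.0, -40.0]]])
l_norm, ab_norm = normalize_lab(lab)
restored = denormalize_lab(l_norm, ab_norm)
```

The round trip recovers the original Lab values, which is what the display code relies on before calling `lab2rgb`.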

Similarly, let's make a function to get readable grayscale and color images from a dataset instance:

In [None]:
def get_image_and_grayscale(input):
    l, ab = input
    color_image = np.zeros((IMAGE_SIZE,IMAGE_SIZE,3))
    color_image[:,:,:1] = l[0,...]*50+50
    color_image[:,:,1:] = ab[0,...]*100
    color_image = lab2rgb(color_image)
    grayscale = np.array(l[0,...,0])

    return color_image, grayscale

def color_hist(image):
    for i in range(image.shape[-1]):
        sns.distplot(image[...,i])

Now it's time to make a dataset from our filenames, map them with our `load()` function and split the result into train and validation datasets:

In [None]:
dataset = tf.data.Dataset.from_tensor_slices(FILENAMES)
dataset = dataset.map(load, num_parallel_calls=AUTOTUNE)
dataset_train = dataset.skip(100)
dataset_val = dataset.take(100)
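
`take(100)` keeps the first 100 elements and `skip(100)` drops them, so the two subsets do not overlap. Here is a plain-Python sketch of the same semantics, using `itertools.islice` as a stand-in for the `tf.data` methods. The file names are made up for illustration.

```python
from itertools import islice

# Hypothetical stand-in for FILENAMES
filenames = [f"img_{i:04d}.jpg" for i in range(250)]

VAL_SIZE = 100
val = list(islice(filenames, VAL_SIZE))          # like dataset.take(100)
train = list(islice(filenames, VAL_SIZE, None))  # like dataset.skip(100)

# No file should appear in both subsets
overlap = set(val) & set(train)
```

Note that the split follows the order of the file list; nothing is shuffled here.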

Let's check what's inside our dataset:

In [None]:
for example_input, example_target in dataset.batch(1).take(1):
    light = example_input[:1,...]
    color = example_target[:1,...]

def check_images(l, ab):
    color_image, grayscale = get_image_and_grayscale((l, ab))
    plt.figure(figsize=(10,10))
    plt.subplot(1,2,1)
    plt.imshow(tf.squeeze(grayscale), cmap='gray')
    plt.subplot(1,2,2)
    plt.imshow(tf.squeeze(color_image))
    
check_images(light, color)
input = (light, color)

# 2. Generator <a class="anchor" id="2"></a>

As described in [[1]](#another_cell), to avoid the “the blind leading the blind” problem we use an already pretrained classification model (in our case ResNet18) as the downsampling path of our U-Net generator. <br>
Then we pretrain the generator using just the L1 loss. <br>
Finally, we train the whole cGAN. <br>
So we train the smallest component first (in our case it is already pretrained), then the part that incorporates it, and so on.

## 2.1. Make a generator  <a class="anchor" id="2_1"></a>
To build a generator from ResNet18 in Keras, I'll use the [Segmentation Models](https://github.com/qubvel/segmentation_models) library.

In [None]:
import segmentation_models
segmentation_models.set_framework('tf.keras')

In [None]:
OUTPUT_CHANNELS = 2

In [None]:
def Generator():
    unet = segmentation_models.Unet('resnet18', encoder_weights='imagenet', classes=OUTPUT_CHANNELS, activation='tanh', input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), decoder_use_batchnorm=False)
    inp = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 1], name='input')
    x = layers.Concatenate()([inp, inp, inp])
    x = unet(x)
    model = tf.keras.Model(inputs=inp, outputs=x)
    return model

with strategy.scope():
    generator = Generator()
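
The U-Net is built with a three-channel `input_shape` and ImageNet encoder weights, which is presumably why `Generator()` repeats the single `L` channel three times before passing it on. In NumPy terms the `Concatenate` layer does the following (a shapes-only sketch with random data):

```python
import numpy as np

IMAGE_SIZE = 256
# A batch with one normalized L-channel image, shape (batch, H, W, 1)
l_channel = np.random.uniform(-1.0, 1.0, size=(1, IMAGE_SIZE, IMAGE_SIZE, 1))

# Concatenate along the last (channel) axis, like layers.Concatenate()([inp, inp, inp])
stacked = np.concatenate([l_channel, l_channel, l_channel], axis=-1)
```

All three channels of `stacked` hold identical values, so the backbone sees a gray image in RGB form.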

In [None]:
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

Let's check that the generator works (although its output will be nonsense at this point):

In [None]:
gen_out = generator(example_input, training=False)
check_images(example_input, gen_out)

In [None]:
color_hist(gen_out)

## 2.2. Pretrain the generator <a class="anchor" id="2_2"></a>

Now we should compile the model. The authors of [[2]](#another_cell) use the L1 objective (Mean Absolute Error), so I'll stick with it for the standalone generator training as well.

In [None]:
loss_object = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
generator.compile(optimizer=tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss='mae')
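
For reference, the L1 objective is the mean of the absolute differences between the predicted and target `ab` values. A tiny NumPy sketch with made-up numbers (the function name is hypothetical):

```python
import numpy as np

def mean_absolute_error(target, prediction):
    """L1 objective: average absolute difference over all elements."""
    return np.mean(np.abs(target - prediction))

# Made-up normalized ab values for four pixels
target = np.array([0.2, -0.4, 0.0, 0.5])
prediction = np.array([0.1, -0.1, 0.0, 0.9])
mae = mean_absolute_error(target, prediction)
```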

In [None]:
generator.fit(
    # the dataset is already batched, so no batch_size argument is needed
    dataset_train.batch(8, drop_remainder=True),
    epochs=20
)

## 2.3. Standalone generator results <a class="anchor" id="2_3"></a>
Let's look at the intermediate results of the generator pre-training:

In [None]:
with strategy.scope():
    for example_input, example_target in dataset_val.batch(1).take(5):
        gen_out = generator(example_input, training=False)
        check_images(example_input, gen_out)

## 2.4. Save pretrained generator <a class="anchor" id="2_4"></a>

In [None]:
localhost_save_option = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
localhost_load_option = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
generator.save('generator_lab_deterministic', options=localhost_save_option)

In [None]:
import shutil
shutil.make_archive('generator_lab_deterministic', 'zip', './generator_lab_deterministic')

# 3. Discriminator <a class="anchor" id="3"></a>

The discriminator takes the grayscale input together with a color (`ab`) image and classifies the pair as real or fake (generated). Instead of outputting a single node, the discriminator outputs a smaller 2D map, with higher values indicating a real classification and lower values indicating a fake one.

In [None]:
def downsample(filters, size, apply_instancenorm=True, name=None):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential(name=name)
    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

    if apply_instancenorm:
        result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    result.add(layers.LeakyReLU())

    return result


In [None]:
def upsample(filters, size, apply_dropout=False, name=None):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential(name=name)
    result.add(layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

    result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    if apply_dropout:
        result.add(layers.Dropout(0.5))

    result.add(layers.ReLU())

    return result



In [None]:

def Discriminator():
    initializer = tf.random_normal_initializer(0., 5)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.2)

    inp = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 1], name='input')
    tar = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, OUTPUT_CHANNELS], name='target')
#     tar_d = layers.Lambda(color_decoder)(tar)
#     tar_d = color_decoder(tar)
    x = layers.Concatenate()([tar, inp])
#     x = inp

    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256); 256 is the filter count, not the image size

    zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)
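
The shape comments in `Discriminator()` can be checked with a little arithmetic. The sketch below traces only the spatial size, assuming the standard output-size formulas for `'same'` and `'valid'` convolutions; the helper names are hypothetical.

```python
import math

def same_conv(size, stride):
    """Output size of a convolution with padding='same'."""
    return math.ceil(size / stride)

def valid_conv(size, kernel, stride=1):
    """Output size of a convolution with padding='valid' (the Keras default)."""
    return (size - kernel) // stride + 1

size = 256
trace = [size]
for _ in range(3):              # three stride-2 downsample blocks
    size = same_conv(size, 2)
    trace.append(size)
size += 2                       # ZeroPadding2D adds 1 pixel on each side
trace.append(size)
size = valid_conv(size, 4)      # Conv2D(512, 4, strides=1)
trace.append(size)
size += 2                       # second ZeroPadding2D
trace.append(size)
size = valid_conv(size, 4)      # final Conv2D(1, 4, strides=1)
trace.append(size)
# trace == [256, 128, 64, 32, 34, 31, 33, 30]
```

So a 256x256 input ends up as a 30x30 map, with each value judging one patch of the image.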

In [None]:
# from keras import backend as K 
# K.clear_session()

with strategy.scope():
    discriminator = Discriminator()

In [None]:
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

In [None]:
disc_out = discriminator([example_input, gen_out], training=False)
plt.imshow(disc_out[0, ..., -1], cmap='RdBu_r')
plt.colorbar()

# 4. cGAN <a class="anchor" id="4"></a>

##  4.1. Build the cGAN model <a class="anchor" id="4_1"></a>

We will subclass a `tf.keras.Model` so that we can run `fit()` later to train our model. 

In [None]:
class CycleGan(tf.keras.Model):
    def __init__(
                self,
                generator,
                discriminator,
                lambda_cycle=10,
                ):
        super(CycleGan, self).__init__()
        self.gen = generator
        self.disc = discriminator
        self.lambda_cycle = lambda_cycle
        
    def compile(
                self,
                gen_optimizer,
                disc_optimizer,
                gen_loss_fn,
                disc_loss_fn,
                ):
        super(CycleGan, self).compile()
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        
    def train_step(self, batch_data):
        input_image, target = batch_data
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:

            gen_output = self.gen(input_image, training=True)

            # run the discriminator on the real (input, target) pair
            disc_real_output = self.disc([input_image, target], training=True)
            # run the discriminator on the fake (input, generated) pair
            disc_generated_output = self.disc([input_image, gen_output], training=True)

            # evaluates generator loss
            gen_total_loss, gen_gan_loss, gen_l1_loss = self.gen_loss_fn(disc_generated_output, gen_output, target)
            # evaluates discriminator loss
            disc_loss = self.disc_loss_fn(disc_real_output, disc_generated_output)

        # Calculate the gradients for the generator and the discriminator
        generator_gradients = gen_tape.gradient(gen_total_loss,
                                              self.gen.trainable_variables)

        discriminator_gradients = disc_tape.gradient(disc_loss,
                                                   self.disc.trainable_variables)

        # Apply the gradients using the optimizers passed to compile(),
        # rather than relying on global variables defined elsewhere
        self.gen_optimizer.apply_gradients(zip(generator_gradients,
                                               self.gen.trainable_variables))

        self.disc_optimizer.apply_gradients(zip(discriminator_gradients,
                                                self.disc.trainable_variables))
        return {
                'gen_total_loss': gen_total_loss,
                'gen_gan_loss': gen_gan_loss,
                'gen_l1_loss': gen_l1_loss,
                'disc_loss': disc_loss
                }

## 4.2. Define cGAN loss functions <a class="anchor" id="4_2"></a>

The discriminator loss function below compares the discriminator's output for real images to a matrix of 1s and its output for fake images to a matrix of 0s. A perfect discriminator would classify every patch of a real image as real and every patch of a fake image as fake. The total discriminator loss is the sum of the real and generated losses.

In [None]:
with strategy.scope():
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
    def discriminator_loss(disc_real_output, disc_generated_output):
        real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
        generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
        total_disc_loss = real_loss + generated_loss
        return total_disc_loss
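
Because the discriminator's last layer has no activation, its output is treated as logits (`from_logits=True`). Below is a NumPy sketch of the same sum-reduced loss for an output whose last axis has size 1, as here. It assumes the standard numerically stable form of sigmoid cross-entropy; `bce_from_logits` is a hypothetical helper, not a TensorFlow function.

```python
import numpy as np

def bce_from_logits(labels, logits):
    """Element-wise sigmoid cross-entropy, summed over all elements."""
    per_element = (np.maximum(logits, 0.0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))
    return per_element.sum()

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = bce_from_logits(np.ones_like(disc_real_output), disc_real_output)
    generated_loss = bce_from_logits(np.zeros_like(disc_generated_output),
                                     disc_generated_output)
    return real_loss + generated_loss

# Toy 2x2 "patch" outputs: zero logits mean the discriminator is undecided
undecided = np.zeros((1, 2, 2, 1))
loss_undecided = discriminator_loss(undecided, undecided)

# Confident and correct: large positive logits on real, large negative on fake
loss_confident = discriminator_loss(np.full((1, 2, 2, 1), 10.0),
                                    np.full((1, 2, 2, 1), -10.0))
```

An undecided discriminator pays `log(2)` per patch value, while a confident and correct one pays almost nothing.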

The generator wants to fool the discriminator into thinking the generated image is real. A perfect generator would make the discriminator classify every patch as real. Thus, the adversarial part of the loss compares the discriminator's output for the generated image to a matrix of 1s, and an L1 term weighted by `LAMBDA` is added to it.

In [None]:
with strategy.scope():
    LAMBDA = 100000
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
    def generator_loss(disc_generated_output, gen_output, target):
        gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
        # Mean absolute error
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
        total_gen_loss = gan_loss + (LAMBDA * l1_loss)

        return total_gen_loss, gan_loss, l1_loss
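
The generator objective adds the adversarial term to the L1 term weighted by `LAMBDA`. The sketch below plugs in made-up loss values to show the arithmetic only; `total_generator_loss` is a hypothetical helper.

```python
LAMBDA = 100000  # weight of the L1 term, as in the notebook

def total_generator_loss(gan_loss, l1_loss, lam=LAMBDA):
    """Combine the adversarial and L1 terms the same way generator_loss() does."""
    return gan_loss + lam * l1_loss

# Made-up values for illustration only
gan_loss = 5000.0   # sum-reduced cross-entropy over all patch logits
l1_loss = 0.05      # mean absolute error on the normalized ab channels
total = total_generator_loss(gan_loss, l1_loss)

# Fraction of the total that comes from the weighted L1 term
l1_share = (LAMBDA * l1_loss) / total
```

Because the adversarial term is summed while the L1 term is averaged, their relative scale depends on the batch size and on the number of patch values.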

## 4.3. Train the cGAN <a class="anchor" id="4_3"></a>

In [None]:
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
with strategy.scope():
    cycle_gan_model = CycleGan(generator, discriminator)
    cycle_gan_model.compile(
        gen_optimizer = generator_optimizer,
        disc_optimizer = discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss)

In [None]:
cycle_gan_model.fit(
    # train only on the training split so the validation images stay unseen
    dataset_train.batch(8, drop_remainder=True),
    epochs=20,
)

## 4.4. Final results <a class="anchor" id="4_4"></a>

In [None]:
with strategy.scope():
    for example_input, example_target in dataset_val.batch(1):
        gen_out = generator(example_input, training=False)
        check_images(example_input, gen_out)

## 4.5. Save the final generator <a class="anchor" id="4_5"></a>

In [None]:
localhost_save_option = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
generator.save('generator_lab_gan', options=localhost_save_option)

In [None]:
import shutil
shutil.make_archive('generator_lab_gan', 'zip', './generator_lab_gan')

<a id='another_cell'></a>
# 5. References
1. Moein Shariatnia (2020). [Colorizing black & white images with U-Net and conditional GAN — A Tutorial](https://towardsdatascience.com/colorizing-black-white-images-with-u-net-and-conditional-gan-a-tutorial-81b2df111cd8)
2. Isola, P., Zhu, J. Y., Zhou, T., & Efros, A. A. (2017). [Image-to-image translation with conditional adversarial networks.](https://arxiv.org/abs/1611.07004) In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 1125-1134).
