# CSE 465 Project

In [None]:
# !pip install -U tensorflow==2.2.0

In [None]:
import tensorflow as tf
from keras.callbacks import ModelCheckpoint
from keras.models import Sequential, load_model

import os
import time

from matplotlib import pyplot as plt

# Change PATH variable to absolute/ relative path to the images directory on your machine which contains the train and val folders
PATH = '../input/anime-sketch-colorization-pair/data'

# Change these variables as per your need
EPOCHS = 100
# BUFFER_SIZE is the size of the buffer used by `Dataset.shuffle` while training.
# A larger buffer gives a more thorough shuffle of the training samples, but the
# buffer keeps that many elements in memory, so very large values get expensive.
# NOTE(review): 14224 is presumably the number of training images (a full shuffle) — confirm.
BUFFER_SIZE = 1000 #14224
BATCH_SIZE = 10
# Each half of a source pair is 512x512; for faster training we work at 256X256
IMG_WIDTH = 256
IMG_HEIGHT = 256

In [None]:
# Show the installed TensorFlow version (the commented-out pip cell above pins 2.2.0).
tf.__version__

In [None]:
def load(image_file):
    """Read one side-by-side training file and split it into its two halves.

    Each dataset file is a single PNG holding two images next to each other
    (512X1024, height*width). The left half is used as the target
    (``real_image``) and the right half as the model input (the sketch).

    Args:
        image_file: path to the PNG file (Python string or scalar string tensor).

    Returns:
        (input_image, real_image): the right and left halves, both float32
        with pixel values still in the [0, 255] range.
    """
    # Read the file into a string tensor and decode it. channels=3 forces RGB so
    # that a PNG saved with an alpha channel (or as greyscale) still yields the
    # 3 channels that random_crop and the 3-channel model inputs require.
    image = tf.io.read_file(image_file)
    image = tf.image.decode_png(image, channels=3)

    # Split point between the two halves along the width axis.
    half_width = tf.shape(image)[1] // 2

    real_image = image[:, :half_width, :]   # (height, width/2, channel) - left half
    input_image = image[:, half_width:, :]  # (height, width/2, channel) - right half

    # Work in float32 from here on (resize / normalize expect floats).
    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    return input_image, real_image

# Preprocessing

As the images below show, each training pair goes through random jittering.
Random jittering, as described in the pix2pix paper, consists of:

1. Resizing both images to a larger height and width (286 × 286 here)
2. Randomly cropping them back to the target size (256 × 256)
3. Randomly flipping them horizontally

In [None]:
inp, re = load(PATH+'/train/1005024.png')
# `load` returns float32 pixels in [0, 255]; dividing by 255 rescales them to
# [0, 1], the range matplotlib expects for float RGB images.
plt.figure()
plt.imshow(inp/255.0)
plt.figure()
plt.imshow(re/255.0)

In [None]:
def resize(input_image, real_image, height, width):
    """Resize both images of a training pair to ``height`` x ``width``.

    Nearest-neighbour sampling is used, and both images get the identical
    transform so the sketch and the target stay pixel-aligned.

    Args:
        input_image: the sketch image tensor, rank 3 (H, W, C).
        real_image: the target image tensor, rank 3 (H, W, C).
        height: output height in pixels.
        width: output width in pixels.

    Returns:
        (input_image, real_image) resized to (height, width, C).
    """
    size = [height, width]
    method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    input_image = tf.image.resize(input_image, size, method=method)
    real_image = tf.image.resize(real_image, size, method=method)

    return input_image, real_image

def random_crop(input_image, real_image):
    """Crop the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

    Stacking the pair before cropping makes ``tf.image.random_crop`` choose a
    single offset for both, so sketch and target remain aligned.

    Args:
        input_image: sketch image, shape (H, W, 3) with H >= IMG_HEIGHT and
            W >= IMG_WIDTH.
        real_image: target image with the same shape as ``input_image``.

    Returns:
        (input_image, real_image), each of shape (IMG_HEIGHT, IMG_WIDTH, 3).
    """
    stacked_image = tf.stack([input_image, real_image], axis=0)  # (2, H, W, 3)
    cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

    return cropped_image[0], cropped_image[1]

# normalizing the images to [-1, 1]
def normalize(input_image, real_image):
    """Map both images from the [0, 255] pixel range to [-1, 1].

    [-1, 1] matches the output range of the generator's final ``tanh`` layer,
    so generated and real images are directly comparable.

    Args:
        input_image: sketch image (tensor or array) with values in [0, 255].
        real_image: target image with values in [0, 255].

    Returns:
        (input_image, real_image) rescaled to [-1, 1].
    """
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image

@tf.function() # Compile to a TensorFlow graph; AutoGraph turns the tensor-valued `if` below into a graph conditional.
def random_jitter(input_image, real_image):
    """Apply random jitter to an aligned (sketch, target) pair.

    Steps: upscale both images to 286x286, take the same random
    IMG_HEIGHT x IMG_WIDTH crop from both, then mirror both horizontally
    with probability 0.5. The random ops run in exactly this order.

    Returns:
        (input_image, real_image) after jittering.
    """
    # resizing to 286 x 286 x 3
    input_image, real_image = resize(input_image, real_image, 286, 286)
    # randomly cropping to 256 x 256 x 3
    input_image, real_image = random_crop(input_image, real_image)

    # random mirroring: a single draw decides for both images so they stay aligned
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image

In [None]:
# for testing the settings: show four independent random jitters of the sample
# loaded above (only the input/sketch side is plotted)
plt.figure(figsize=(8, 8))
for i in range(4):
    rj_inp, rj_re = random_jitter(inp, re)
    plt.subplot(2, 2, i+1)
    # pixels are still in [0, 255] here (no normalize), so scale to [0, 1] for imshow
    plt.imshow(rj_inp/255.0)
    plt.axis('off')
plt.show()

# Loading the Train & Test Data

In [None]:
def load_image_train(image_file):
    """Load one training pair, apply random jitter and scale it to [-1, 1].

    Args:
        image_file: path to a side-by-side training PNG.

    Returns:
        (input_image, real_image), each (IMG_HEIGHT, IMG_WIDTH, 3) float32 in [-1, 1].
    """
    input_image, real_image = load(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize(input_image, real_image)

    return input_image, real_image


# Shuffle the (cheap) file paths *before* decoding rather than the decoded
# images: a BUFFER_SIZE-element buffer of (256x256x3 float32) image pairs
# would otherwise hold roughly 1.5 GB in memory.
train_dataset = tf.data.Dataset.list_files(PATH + '/train/*.png')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) # automatically tune performance knobs
train_dataset = train_dataset.batch(BATCH_SIZE)
# Prepare the next batches while the current one is being trained on.
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
def load_image_test(image_file):
    """Load one validation pair without any augmentation.

    Both halves are resized directly to IMG_HEIGHT x IMG_WIDTH (no random
    jitter) and then rescaled to [-1, 1].
    """
    sketch, target = load(image_file)
    sketch, target = resize(sketch, target, IMG_HEIGHT, IMG_WIDTH)
    return normalize(sketch, target)


# Validation pipeline: every val PNG, preprocessed deterministically, batched.
test_dataset = (
    tf.data.Dataset.list_files(PATH+'/val/*.png')
    .map(load_image_test)
    .batch(BATCH_SIZE)
)

# Building the Generator Model

## Build the Generator
  * The architecture of generator is a modified U-Net.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder (as in U-Net).
  * U-Net is a convolutional neural network that was developed for biomedical image segmentation at the Computer Science Department of the University of Freiburg, Germany.

In [None]:
OUTPUT_CHANNELS = 3  # the generator emits RGB images


# Encoder building block, used to compress the image.
def downsample(filters, size, apply_batchnorm=True):
    """Build an encoder block: Conv2D (stride 2) -> [BatchNorm] -> LeakyReLU.

    The stride-2, 'same'-padded convolution halves the spatial size.

    Args:
        filters: number of output channels of the convolution.
        size: convolution kernel size.
        apply_batchnorm: insert BatchNormalization after the convolution
            (turned off for the first block of the generator and discriminator).

    Returns:
        A tf.keras.Sequential implementing the block.
    """
    blocks = [
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        )
    ]
    if apply_batchnorm:
        blocks.append(tf.keras.layers.BatchNormalization())
    # LeakyReLU passes a small slope for negative inputs, so units cannot get
    # stuck outputting zero ("dying ReLU"), which keeps the gradients coming
    # back from the discriminator flowing into the generator.
    blocks.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(blocks)

In [None]:
# Smoke test: one downsample block (stride 2) halves the spatial size of the sample.
# expand_dims adds the leading batch dimension the layers expect.
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)

In [None]:
# Decoder building block: enlarges the feature map back towards the original image size.
def upsample(filters, size, apply_dropout=False):
    """Build a decoder block: Conv2DTranspose (stride 2) -> BatchNorm -> [Dropout] -> ReLU.

    The stride-2, 'same'-padded transposed convolution doubles the spatial size.

    Args:
        filters: number of output channels of the transposed convolution.
        size: kernel size.
        apply_dropout: insert Dropout(0.5) after batch norm (used for the
            first three decoder blocks of the generator).

    Returns:
        A tf.keras.Sequential implementing the block.
    """
    blocks = [
        tf.keras.layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        blocks.append(tf.keras.layers.Dropout(0.5))
    # ReLU outputs max(x, 0). Its known drawback is "dying" units that only
    # ever output zero; the encoder uses LeakyReLU for that reason.
    blocks.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(blocks)

In [None]:
# Smoke test: one upsample block (stride 2) doubles the spatial size of down_result again.
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)

In [None]:
def buildGenerator():
    """Build the U-Net generator that maps a 256x256 RGB sketch to a colour image.

    Encoder: 8 ``downsample`` blocks (256 -> 1 spatially). Decoder: 7
    ``upsample`` blocks, each followed by concatenation with the encoder
    output of the same resolution (skip connection), then a final stride-2
    transposed convolution with ``tanh`` producing values in [-1, 1].

    Returns:
        tf.keras.Model mapping (batch, 256, 256, 3) to
        (batch, 256, 256, OUTPUT_CHANNELS).
    """
    inputs = tf.keras.layers.Input(shape=[256, 256, 3])

    # bs -> Batch Size
    down_stack = [
        downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    # Shapes below are after concatenation with the matching skip connection.
    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)

    # tanh squashes the output to (-1, 1), the same range as the normalized targets.
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                           strides=2,
                                           padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh') # (bs, 256, 256, 3)

    x = inputs

    # Encoder: keep every block's output for the skip connections.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The deepest (1x1) output is the bottleneck `x` itself, so it is not
    # reused as a skip; the remaining ones are consumed deepest-first.
    skips = reversed(skips[:-1])

    # Decoder: upsample and concatenate the encoder output of equal resolution.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)


generator = buildGenerator()
generator.summary()

In [None]:
# Plot the architecture of the generator built in the previous cell. The model
# is deliberately not rebuilt here: re-running a `generator = buildGenerator()`
# line after training would silently replace the trained model with fresh weights.
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=80)

# Building the Discriminator Model

## Build the Discriminator
  * The Discriminator is a PatchGAN.
  * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU)
  * The shape of the output after the last layer is (batch_size, 30, 30, 1)
  * Each element of the 30x30 output classifies a 70x70 patch of the input image (such an architecture is called a PatchGAN).
  * Discriminator receives 2 inputs.
    * Input image and the target image, which it should classify as real.
    * Input image and the generated image (output of generator), which it should classify as fake.
    * We concatenate these 2 inputs together along the channel axis in the code (`tf.keras.layers.concatenate([inp, tar])`)

* PatchGAN is a type of discriminator for generative adversarial networks which only penalizes structure at the scale of local image patches. The PatchGAN discriminator tries to classify whether each $N \times N$ patch in an image is real or fake. This discriminator is run convolutionally across the image, averaging all responses to provide the ultimate output of the discriminator $D$. Such a discriminator effectively models the image as a Markov random field, assuming independence between pixels separated by more than a patch diameter. It can be understood as a type of texture/style loss.

In [None]:
def buildDiscriminator():
    """Build the conditional PatchGAN discriminator.

    The sketch (``input_image``) and a candidate colour image
    (``target_image``) are concatenated along channels and mapped to a
    (batch, 30, 30, 1) grid of raw logits, one real/fake score per patch.

    Returns:
        tf.keras.Model with inputs [input_image, target_image].
    """
    init = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

    # Condition on the sketch by stacking it with the candidate image.
    x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, channels*2)

    x = downsample(64, 4, False)(x)   # (bs, 128, 128, 64)
    x = downsample(128, 4)(x)         # (bs, 64, 64, 128)
    x = downsample(256, 4)(x)         # (bs, 32, 32, 256)

    x = tf.keras.layers.ZeroPadding2D()(x)  # (bs, 34, 34, 256)
    x = tf.keras.layers.Conv2D(512, 4, strides=1,
                               kernel_initializer=init,
                               use_bias=False)(x)  # (bs, 31, 31, 512)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)

    x = tf.keras.layers.ZeroPadding2D()(x)  # (bs, 33, 33, 512)

    # No activation: the losses use from_logits=True.
    patch_logits = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=init)(x)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=patch_logits)


discriminator = buildDiscriminator()
discriminator.summary()

In [None]:
# Plot the architecture of the discriminator built in the previous cell. The
# model is deliberately not rebuilt here: re-running a
# `discriminator = buildDiscriminator()` line after training would silently
# replace the trained model with fresh weights.
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=80)

# Loss Functions for the Models

* **Generator loss**
  * It is the sigmoid cross-entropy loss between the discriminator's output on the generated images and an **array of ones**.
  * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.
  * This allows the generated image to become structurally similar to the target image.
  * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004).

In [None]:
# Sigmoid cross-entropy applied directly to the discriminator's raw logits.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Weight of the L1 reconstruction term relative to the adversarial term.
LAMBDA = 100


def generator_loss(disc_generated_output, gen_output, target):
    """Compute the pix2pix generator loss.

    Args:
        disc_generated_output: discriminator logits for (sketch, generated) pairs.
        gen_output: generated images.
        target: ground-truth images.

    Returns:
        (total_gen_loss, gan_loss, l1_loss) where
        total_gen_loss = gan_loss + LAMBDA * l1_loss.
    """
    # Adversarial term: the generator wants every patch judged real (label 1).
    wanted_labels = tf.ones_like(disc_generated_output)
    gan_loss = loss_object(wanted_labels, disc_generated_output)

    # Reconstruction term: mean absolute error against the ground truth.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    return gan_loss + LAMBDA * l1_loss, gan_loss, l1_loss

**Discriminator loss**
  * The discriminator loss function takes 2 inputs: the discriminator's outputs on the **real images** and on the **generated images**.
  * real_loss is the sigmoid cross-entropy loss between the output on the **real images** and an **array of ones (since these are the real images)**
  * generated_loss is the sigmoid cross-entropy loss between the output on the **generated images** and an **array of zeros (since these are the fake images)**
  * Then the total_loss is the sum of real_loss and generated_loss


In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Compute the discriminator loss.

    Args:
        disc_real_output: logits for (sketch, ground-truth) pairs; labelled 1.
        disc_generated_output: logits for (sketch, generated) pairs; labelled 0.

    Returns:
        Sum of the real and generated cross-entropy terms.
    """
    loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_fake = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return loss_on_real + loss_on_fake

# Optimizers

In [None]:
# Adam with learning rate 2e-4 for both networks.
# beta_1 is the exponential decay rate for the first-moment (running mean of the
# gradient) estimates; by default it is 0.9, here it is lowered to 0.5.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Creating Checkpoints

In [None]:
# The checkpoint tracks both models and both optimizers so that training can be
# resumed; fit() saves it under checkpoint_prefix, and the restore cell reloads
# the latest one from checkpoint_dir.
checkpoint_dir = './Sketch2Color_training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# Displaying Output Images

## Generate Images

Write a function to plot some images during training.

* We pass images from the test dataset to the generator.
* The generator will then translate the input image into the output.
* Last step is to plot the predictions and **voila!**

Note: The `training=True` is intentional here since
we want the batch statistics while running the model
on the test dataset. If we use training=False, we will get
the accumulated statistics learned from the training dataset
(which we don't want)

In [None]:
def generate_images(model, test_input, tar):
    """Plot input sketch, ground truth and prediction for the first image of a batch.

    The model is run with training=True on purpose (batch statistics, see the
    note above). Images in [-1, 1] are mapped to [0, 1] for display.

    Args:
        model: the generator.
        test_input: batch of normalized sketches.
        tar: batch of normalized ground-truth images.
    """
    print("Displaying Output Images")
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))

    panels = (
        ('Input Image', test_input[0]),
        ('Ground Truth', tar[0]),
        ('Predicted Image', prediction[0]),
    )
    for slot, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        plt.title(caption)
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# Logging the Losses

In [None]:
import datetime

# Root directory for TensorBoard logs, named consistently with the
# './Sketch2Color_training_checkpoints' checkpoint directory.
log_dir = "Sketch2Color_logs/"

# One timestamped sub-directory per run so that successive runs don't mix.
summary_writer = tf.summary.create_file_writer(
    os.path.join(log_dir, "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))

# Train Step

## Training

* For each example input generate an output.
* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.
* Next, we calculate the generator and the discriminator loss.
* Then, we calculate the gradients of each loss with respect to the trainable variables (weights) of the generator and of the discriminator, and apply them with the corresponding optimizer.
* Then log the losses to TensorBoard.

In [None]:
# A fixed input_signature means the step is traced only once. Without it, each
# new Python-int `epoch` value would trigger a fresh trace (a graph rebuild) at
# the start of every epoch; with it, the value is converted to an int64 tensor.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, IMG_HEIGHT, IMG_WIDTH, 3], dtype=tf.float32),
    tf.TensorSpec(shape=[None, IMG_HEIGHT, IMG_WIDTH, 3], dtype=tf.float32),
    tf.TensorSpec(shape=[], dtype=tf.int64),
])
def train_step(input_image, target, epoch):
    """Run one optimisation step for both the generator and the discriminator.

    Args:
        input_image: batch of normalized sketches, (batch, IMG_HEIGHT, IMG_WIDTH, 3).
        target: batch of normalized ground-truth images, same shape.
        epoch: current epoch index; used as the TensorBoard step for the losses.
    """
    # gen_tape-> compute gradients for generator
    # disc_tape-> compute gradients for discriminator
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        # giving the discriminator sketch image and target image
        disc_real_output = discriminator([input_image, target], training=True)
        # giving the discriminator sketch image and the colorized image by generator
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

    with summary_writer.as_default():
        tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)
        tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)
        tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)
        tf.summary.scalar('disc_loss', disc_loss, step=epoch)


# Define the fit() function

The actual training loop:

* Iterates over the number of epochs.
* On each epoch it runs `generate_images` on one test batch to show its progress.
* On each epoch it iterates over the training dataset, printing a '.' for each batch (with a line break every 100 batches).
* It saves a checkpoint every 20 epochs and at the end of training.

In [None]:
def fit(train_ds, epochs, test_ds):
    """Train the generator/discriminator pair for ``epochs`` epochs.

    At the start of every epoch one test batch is visualised with
    ``generate_images``. A checkpoint is saved every 20 epochs and once at
    the end, unless the final epoch has just been saved already.

    Args:
        train_ds: batched dataset of (input_image, target) training pairs.
        epochs: number of epochs to run.
        test_ds: batched dataset of (input_image, target) pairs for previews.
    """
    saved_after_last_epoch = False
    for epoch in range(epochs):
        start = time.time()

        for example_input, example_target in test_ds.take(1):
            generate_images(generator, example_input, example_target)
        print("Epoch: ", epoch)

        # Pass the epoch as a tensor: a Python int argument to a tf.function
        # would force a new trace of train_step for every new epoch value.
        epoch_step = tf.constant(epoch, dtype=tf.int64)

        # Train
        for n, (input_image, target) in train_ds.enumerate():
            print('.', end="")
            if (n + 1) % 100 == 0:
                print()
            train_step(input_image, target, epoch_step)
        print()

        # saving (checkpoint) the model every 20 epochs
        saved_after_last_epoch = (epoch + 1) % 20 == 0
        if saved_after_last_epoch:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time() - start))

    # Final save, skipped when the periodic save above already stored this state
    # (otherwise two identical checkpoints would be written back to back).
    if not saved_after_last_epoch:
        checkpoint.save(file_prefix=checkpoint_prefix)

# TensorBoard Graph

In [None]:
# Launch TensorBoard inside the notebook, reading the summaries written under log_dir.
%load_ext tensorboard
%tensorboard --logdir {log_dir}

# Now start the training

In [None]:
# Train for EPOCHS epochs; a sample test prediction is plotted at the start of each epoch.
fit(train_dataset, EPOCHS, test_dataset)

# Restore Latest Checkpoint or Load the saved H5 file

In [None]:
from tensorflow import keras
# Alternative: load a previously exported generator instead of the checkpoint:
# model = keras.models.load_model('../input/h5-model/2d_Skatch_Colorization_Model.h5')

latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if latest_ckpt is None:
    # Checkpoint.restore(None) does nothing, which would silently leave the
    # models with whatever weights they currently have; say so explicitly.
    print(f'No checkpoint found in {checkpoint_dir}; models keep their current weights.')
else:
    checkpoint.restore(latest_ckpt)

# Testing Outputs

In [None]:
# Plot predictions for the first image of each of 5 test batches.
for example_input, example_target in test_dataset.take(5):
    generate_images(generator, example_input, example_target)

# Saving the Model

In [None]:
# Export the trained generator (architecture + weights) to a single .h5 file.
generator.save('2d_Skatch_Colorization_Model.h5')

# Generate Colorized Images from Your Own Sketches Using the H5 File

### Load the model

In [None]:
from tensorflow import keras
# Load a previously exported generator. NOTE: this path points at an uploaded
# copy under '../input/h5-model/', not at the file written by the save cell above.
model = keras.models.load_model('../input/h5-model/2d_Skatch_Colorization_Model.h5')

### Pre-process the images

In [None]:
def load(image_file):
    """Read a single user sketch for inference and return it as float32.

    NOTE: this redefines the earlier training-pair ``load`` (which returned two
    halves); re-run the training preprocessing cells before using them again.

    Args:
        image_file: path to a JPEG sketch (string or scalar string tensor).

    Returns:
        The full image as float32, shape (H, W, 3), values in [0, 255].
    """
    image = tf.io.read_file(image_file)
    # channels=3: a greyscale (single-channel) JPEG would otherwise decode with
    # 1 channel and not match the generator's 3-channel input.
    image = tf.image.decode_jpeg(image, channels=3)

    # The whole image is the input; unlike the training files, there is no
    # second half to split off.
    return tf.cast(image, tf.float32)

def resize(input_image, height, width):
    """Resize one image to (height, width) with nearest-neighbour sampling.

    NOTE: this single-image version replaces the earlier pair-wise ``resize``.
    """
    return tf.image.resize(
        input_image,
        [height, width],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )


def normalize(input_image):
    """Scale pixel values from [0, 255] to [-1, 1] (single-image version)."""
    return input_image / 127.5 - 1

In [None]:
def load_image_test2(image_file):
    """Load, resize to IMG_HEIGHT x IMG_WIDTH and normalize one user sketch."""
    return normalize(resize(load(image_file), IMG_HEIGHT, IMG_WIDTH))


# Every .jpg in the sketch folder, preprocessed and batched for the generator.
test_dataset = tf.data.Dataset.list_files('../input/skatch-images'+'/*.jpg')
test_dataset = test_dataset.map(load_image_test2).batch(BATCH_SIZE)

### Generate Images

In [None]:
def generate_images2(model, test_input):
    """Plot the first sketch of a batch next to the model's colourised output.

    As in generate_images, the model runs with training=True, and images in
    [-1, 1] are mapped to [0, 1] for display.

    Args:
        model: the loaded generator.
        test_input: batch of normalized sketches.
    """
    print("Displaying Output Images")
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))

    panels = (('Input Image', test_input[0]), ('Predicted Image', prediction[0]))
    for slot, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.title(caption)
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
# NOTE(review): the outer loop variable `i` is unused, so this plots the first
# image of each of (up to) 10 batches, repeated 10 times (up to 100 figures).
# Repeated passes are presumably meant to show variation: generate_images2
# runs the model with training=True, and the file order may be reshuffled on
# each pass. Confirm whether the repetition is intended or a single pass was meant.
for i in range(10):
    for example_input in test_dataset.take(10):
        generate_images2(model, example_input)