In [None]:
import numpy as np
import tensorflow as tf
from skimage.color import rgb2lab, lab2rgb
import matplotlib.pyplot as plt
from tensorflow.keras.applications import vgg19
from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D
from google.colab.patches import cv2_imshow
import zipfile
from tf_rgb_lab_formulation import *

In [None]:
"""
    Trained the model on ~32K images from the WIDER FACE detection dataset (train + test + val splits).
    Trained the model on Google Colab using its free GPU.
    Training took approx. 5 hrs, for only 3 epochs.
"""

In [None]:
# Extract the WIDER FACE *train* split. Extract explicitly into /content so the
# folder lands where the loader below reads it; a bare extractall() writes into
# the current working directory, which only matches when cwd happens to be /content.
# NOTE(review): the Drive folder name contains a space before "/" ("HUMAN FACES /");
# presumably that is the real folder name — confirm.
with zipfile.ZipFile("/content/drive/MyDrive/DATASETS/HUMAN FACES /WIDER_train.zip") as zfile:
    zfile.extractall("/content")

# Images are resized to 256x256 and batched in groups of 8.
datasets1 = tf.keras.preprocessing.image_dataset_from_directory("/content/WIDER_train/", batch_size=8, image_size=(256,256))

Found 12880 files belonging to 1 classes.


In [None]:
# Extract the WIDER FACE *test* split. Extract explicitly into /content so the
# folder lands where the loader below reads it; a bare extractall() writes into
# the current working directory, which only matches when cwd happens to be /content.
with zipfile.ZipFile("/content/drive/MyDrive/DATASETS/HUMAN FACES /WIDER_test.zip") as zfile:
    zfile.extractall("/content")

# Images are resized to 256x256 and batched in groups of 8.
datasets2 = tf.keras.preprocessing.image_dataset_from_directory("/content/WIDER_test/", batch_size=8, image_size=(256,256))

Found 16097 files belonging to 1 classes.


In [None]:
# Extract the WIDER FACE *validation* split. Extract explicitly into /content so the
# folder lands where the loader below reads it; a bare extractall() writes into
# the current working directory, which only matches when cwd happens to be /content.
with zipfile.ZipFile("/content/drive/MyDrive/DATASETS/HUMAN FACES /WIDER_val.zip") as zfile:
    zfile.extractall("/content")

# Images are resized to 256x256 and batched in groups of 8.
datasets3 = tf.keras.preprocessing.image_dataset_from_directory("/content/WIDER_val/", batch_size=8, image_size=(256,256))

Found 3226 files belonging to 1 classes.


In [None]:
# Chain the train, test and val splits into one stream, then unbatch so each
# element is a single image rather than a batch of 8.
ds = datasets1.concatenate(datasets2).concatenate(datasets3).unbatch()

# Keep only the image (the directory label is unused), divided by 255 so pixel
# values fall in [0, 1] as rgb_to_lab() expects.
norm_ds = ds.map(lambda image, _label: image / 255)

In [None]:
'''
    step 1 : Convert the normalized image from RGB to Lab
    step 2 : Split the Lab image into L and ab
             L  - acts as the grayscale image
             ab - acts as the ground truth the model learns to predict
    step 3 : The model expects a 3-channel input, so the single-channel L is concatenated three times
             (3-channel L - model input && ab - model output)
'''

In [None]:
# rgb_to_lab() (from tf_rgb_lab_formulation) expects RGB inputs in the range 0 to 1 and
# returns a Lab image with values [L: 0 to 100 and ab: -128 to 128].
def rgb_to_model_io(rgb):
    """Turn one normalized RGB image into the model's (input, target) pair.

    Args:
        rgb: a single RGB image (channels last) with values in [0, 1].

    Returns:
        ``(l_3ch, ab)``:
        - ``l_3ch`` is the L channel divided by 100 and repeated three times along
          the channel axis, forming the model's 3-channel input.
        - ``ab`` is the a/b channels divided by 128 (roughly [-1, 1], the range of
          the model's tanh output).
    """
    # Convert once, then slice: calling rgb_to_lab per slice would repeat the
    # full colour-space conversion for every channel group.
    lab = rgb_to_lab(rgb)
    lightness = lab[:, :, :1] / 100
    ab = lab[:, :, 1:] / 128
    return tf.concat((lightness, lightness, lightness), axis=-1), ab


processed_ds = norm_ds.map(rgb_to_model_io).batch(8)

In [None]:
def Colourise_model_with_VGGbase(input_shape=(256, 256, 3)):
    """Build the colourisation network on top of a frozen VGG19 feature extractor.

    Activations from several VGG19 conv layers (plus the raw input) are upsampled
    back to the input's spatial size and concatenated along the channel axis. The
    result goes through a stack of 1x1 convolutions that predicts the two ab
    channels per pixel.

    Args:
        input_shape: (height, width, channels) of the model input; defaults to
            (256, 256, 3). Height and width should be divisible by the
            downsampling factor of the deepest tapped layer (presumably 16 for
            block5_conv4). Otherwise the integer upsampling factors will not
            restore the exact input size.

    Returns:
        A ``tf.keras.Model`` mapping an input of ``input_shape`` to a
        (height, width, 2) tensor in [-1, 1] (tanh activation).
    """
    height, width = input_shape[0], input_shape[1]

    # NOTE(review): VGG19's pretrained weights are normally fed images processed
    # with vgg19.preprocess_input; here the input is used as-is — confirm intended.
    vgg = vgg19.VGG19(include_top=False, input_shape=input_shape)
    vgg.trainable = False  # VGG19 is used purely as a frozen feature extractor

    # VGG19 layers whose activations are tapped; the raw input is appended last.
    concat_layers_name = ['block1_conv2', 'block2_conv2', 'block3_conv4', 'block4_conv4', 'block5_conv4']
    outputs = {name: vgg.get_layer(name).output for name in concat_layers_name}
    outputs['inputs'] = vgg.layers[0].output

    # Upsample every feature map back to the input resolution. The factor is
    # derived per axis from input_shape rather than a hard-coded 256.
    up_samp = [
        UpSampling2D(size=(height // feat.shape[1], width // feat.shape[2]))(feat)
        for feat in outputs.values()
    ]
    concat_up_scaled = tf.keras.layers.Concatenate(axis=-1)(up_samp)

    # Per-pixel (1x1 conv) head mapping the stacked features to the 2 ab channels.
    X = Conv2D(512, (1, 1), padding='SAME')(concat_up_scaled)
    X = tf.keras.layers.LeakyReLU()(X)
    X = BatchNormalization(axis=-1)(X)
    X = Conv2D(256, (1, 1), activation='relu', padding='SAME')(X)
    X = Conv2D(128, (1, 1), padding='SAME')(X)
    X = tf.keras.layers.LeakyReLU()(X)
    X = Conv2D(64, (1, 1), padding='SAME', activation='relu')(X)
    # tanh keeps predictions in [-1, 1], the range of the ab targets scaled by 1/128.
    X = Conv2D(2, (1, 1), padding='SAME', activation='tanh')(X)

    return tf.keras.Model([vgg.input], X)

In [None]:
# Build the network for 256x256 RGB-shaped inputs, the same image_size used when
# loading the WIDER datasets.
MODEL = Colourise_model_with_VGGbase(input_shape=[256, 256, 3])

In [None]:
def dist_loss(true_ab, predict_ab):
    """Element-wise squared error between target and predicted ab channels.

    Args:
        true_ab: ground-truth ab tensor.
        predict_ab: model prediction with the same shape as ``true_ab``.

    Returns:
        The un-reduced tensor ``(true_ab - predict_ab) ** 2``; reducing it to a
        scalar is left to Keras when this is passed as the compile-time loss.
    """
    squared_error = tf.math.squared_difference(x=true_ab, y=predict_ab, name='distance_loss')
    return squared_error

In [None]:
# Adam with its default hyper-parameters; MSE and MAE are reported alongside the
# custom squared-difference loss.
MODEL.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=dist_loss,
    metrics=['mse', 'mae'],
)
MODEL.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 256, 256, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 256, 256, 64) 36928       block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_pool (MaxPooling2D)      (None, 128, 128, 64) 0           block1_conv2[0][0]               
______________________________________________________________________________________________

In [None]:
# Checkpoint callback: after each epoch, save the model to Drive only when the
# monitored training loss improves, so the directory holds the best model so far.
callback = tf.keras.callbacks.ModelCheckpoint(
    '/content/drive/MyDrive/Colourise/Colourise_v0_models/',
    monitor='loss',
    verbose=0,
    # Without save_best_only=True, `monitor` is ignored and every epoch overwrites the checkpoint.
    save_best_only=True,
)

In [None]:
EPOCHS = 3  # the reported run trained for only 3 epochs (~5 hrs on a Colab GPU)
MODEL.fit(x=processed_ds, epochs=EPOCHS, callbacks=[callback])

Epoch 1/3
INFO:tensorflow:Assets written to: /content/drive/MyDrive/Colourise/Colourise_v2_models/assets
Epoch 2/3
INFO:tensorflow:Assets written to: /content/drive/MyDrive/Colourise/Colourise_v2_models/assets
Epoch 3/3
INFO:tensorflow:Assets written to: /content/drive/MyDrive/Colourise/Colourise_v2_models/assets


<tensorflow.python.keras.callbacks.History at 0x7f700a5219b0>

In [None]:
# Save the entire trained model. The variable is `MODEL`; the original `Model.save`
# referenced an undefined name and raised NameError.
MODEL.save('/content/drive/MyDrive/Colourise/Colourise_v0_models/FinalModel')