# Imports

In [None]:
import os
import pickle

import numpy as np
import pandas as pd
import tensorflow as tf

from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from tensorflow.keras import Model, Sequential, Input
from tensorflow.keras.layers import Conv2D, UpSampling2D, Reshape, RepeatVector, Concatenate, Activation, BatchNormalization, MaxPooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.initializers import RandomNormal

from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import cv2
from PIL import Image

%matplotlib inline

# Load Dataset

In [None]:
def load_data(data_dir='../input/image-colorization', train_size=0.9, random_state=42):
    """Load the L (lightness) and ab (colour) arrays and split off a hold-out set.

    Args:
        data_dir: Root of the image-colorization dataset. The files read are
            ``{data_dir}/l/gray_scale.npy`` and ``{data_dir}/ab/ab/ab1.npy`` ..
            ``ab3.npy``.
        train_size: Fraction of samples kept for training, passed to
            ``train_test_split``.
        random_state: Seed for the split so it is reproducible across runs.

    Returns:
        ``(l_train, l_holdout, ab_train, ab_holdout)`` as returned by
        ``train_test_split``.

    Raises:
        ValueError: If the L and ab arrays hold a different number of samples.
    """
    l = np.load(f'{data_dir}/l/gray_scale.npy')
    # The ab targets are stored as three shards. They are concatenated in shard
    # order; this assumes the shards follow the same sample order as
    # gray_scale.npy (TODO confirm against the dataset's documentation).
    ab = np.concatenate([np.load(f'{data_dir}/ab/ab/ab{k}.npy') for k in (1, 2, 3)])
    if len(l) != len(ab):
        raise ValueError(
            f'L and ab arrays disagree on sample count: {len(l)} vs {len(ab)}')
    return train_test_split(l, ab, train_size=train_size, random_state=random_state)
    

l_train, l_test, ab_train, ab_test = load_data()
# Halve the hold-out set returned by load_data: one half for testing, the other for validation.
l_test, l_valid, ab_test, ab_valid = train_test_split(
    l_test, ab_test, test_size=0.5, random_state=42)

# Keras convolutions need an explicit single-channel axis on the grayscale inputs.
l_train, l_valid, l_test = (split.reshape((-1, 224, 224, 1)) for split in (l_train, l_valid, l_test))

l_train.shape, l_valid.shape, l_test.shape, ab_train.shape, ab_valid.shape, ab_test.shape

In [None]:
BATCH_SIZE = 64
SHUFFLE_BUFFER = 1000  # elements held in memory while shuffling; trades RAM for randomness
SHUFFLE_SEED = 42


def make_dataset(l, ab, shuffle=False):
    """Pair L inputs with ab targets in a batched, prefetched ``tf.data`` pipeline.

    Only the training split should be shuffled. The test pipeline must keep the
    order of ``l_test``/``ab_test`` so that ``predictions[i]`` lines up with
    ``l_test[i]`` in the comparison grid.
    """
    dataset = tf.data.Dataset.from_tensor_slices((l, ab))
    if shuffle:
        # Shuffle before batching and re-shuffle every epoch so each epoch sees
        # a different batch composition and order.
        dataset = dataset.shuffle(SHUFFLE_BUFFER, seed=SHUFFLE_SEED, reshuffle_each_iteration=True)
    return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_dataset = make_dataset(l_train, ab_train, shuffle=True)
valid_dataset = make_dataset(l_valid, ab_valid)
test_dataset = make_dataset(l_test, ab_test)

# Model

In [None]:
# Frozen ImageNet Inception-ResNet-v2, used only as a global-context feature extractor.
# classifier_activation=None exposes the 1000 raw class scores instead of softmax probabilities.
inception = InceptionResNetV2(weights='imagenet', include_top=True, classifier_activation=None)
inception.trainable = False

In [None]:
def ModelBuilder():
    """Build the fusion colourisation network (L channel in, ab channels out).

    A convolutional encoder downsamples the 224x224x1 input 8x through three
    2x max-pools. In parallel, the frozen module-level ``inception`` model embeds
    the whole image into a 1000-d vector. That vector is tiled over every
    encoder position and concatenated to the encoder features. A 1x1
    convolution fuses the two, and the decoder upsamples 8x back to 224x224,
    producing a 2-channel ``tanh`` output in (-1, 1).

    Returns:
        An uncompiled ``Model`` mapping ``(None, 224, 224, 1)`` to
        ``(None, 224, 224, 2)``.
    """
    def _conv_bn_relu(seq, filters):
        # 3x3 conv -> batch norm -> ReLU, the unit shared by encoder and decoder.
        seq.add(Conv2D(filters, (3, 3), padding="same",
                       kernel_initializer=RandomNormal(stddev=0.02)))
        seq.add(BatchNormalization())
        seq.add(Activation('relu'))

    # (filters, max-pool after this stage?) for each encoder stage, in order.
    encoder_plan = [(64, True), (128, False), (128, True), (256, False),
                    (256, True), (512, False), (512, False), (256, False)]
    # (filters, upsample after this stage?) for each decoder stage before the output head.
    decoder_plan = [(128, True), (64, False), (64, True), (32, False)]

    def _encoder():
        encoder = Sequential(name="encoder")
        for filters, pool in encoder_plan:
            _conv_bn_relu(encoder, filters)
            if pool:
                encoder.add(MaxPooling2D())
        return encoder

    def _decoder():
        decoder = Sequential(name="decoder")
        for filters, upsample in decoder_plan:
            _conv_bn_relu(decoder, filters)
            if upsample:
                decoder.add(UpSampling2D((2, 2)))
        # Output head: two ab channels squashed by tanh, then the final 2x upsample.
        decoder.add(Conv2D(2, (3, 3), activation="tanh", padding="same",
                           kernel_initializer=RandomNormal(stddev=0.02)))
        decoder.add(UpSampling2D((2, 2)))
        return decoder

    def _inception_embedding(gray):
        # Inception takes 299x299 three-channel input: resize, then replicate the single channel.
        rgb = tf.image.grayscale_to_rgb(tf.image.resize(gray, [299, 299]))
        return inception(rgb, training=False)

    def _fusion(features, embedding):
        rows, cols = features.shape[1:3]
        # Tile the global embedding so every spatial position of the encoder output receives it.
        tiled = RepeatVector(rows * cols)(embedding)
        tiled = Reshape((rows, cols, 1000))(tiled)
        return Concatenate(axis=3)([features, tiled])

    inputs = Input(shape=(224, 224, 1))
    encoded = _encoder()(inputs)
    embedded = _inception_embedding(inputs)
    fused = _fusion(encoded, embedded)
    x = Conv2D(256, (1, 1), kernel_initializer=RandomNormal(stddev=0.02))(fused)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    outputs = _decoder()(x)
    return Model(inputs, outputs)

In [None]:
# Instantiate the colourisation network and print its layer-by-layer summary.
model = ModelBuilder()
model.summary()

# Utility Functions

In [None]:
def lab_to_rgb(l, ab):
    """Combine an L channel and ab channels into a displayable RGB PIL image.

    The stacked array is converted with ``cv2.COLOR_LAB2RGB`` on ``uint8`` data,
    so inputs are expected in OpenCV's 8-bit Lab encoding (each channel 0..255).

    Args:
        l: Lightness, shape ``(H, W, 1)`` or ``(H, W)``.
        ab: Colour channels, shape ``(H, W, 2)``. Anything ``np.asarray``
            accepts works, e.g. an eager ``tf.Tensor`` of predictions.

    Returns:
        ``PIL.Image.Image`` built from the RGB array.
    """
    l = np.asarray(l)
    ab = np.asarray(ab)
    if l.ndim == 2:
        # Accept a bare (H, W) lightness map by adding the channel axis.
        l = l[..., np.newaxis]
    lab = np.concatenate((l, ab), axis=2)
    # Clip before narrowing so out-of-range floats saturate instead of wrapping around in uint8.
    lab = np.clip(lab, 0, 255).astype("uint8")
    rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    return Image.fromarray(rgb)

In [None]:
def scale(x):
    """Cast to float32 and apply Inception-ResNet-v2's ``preprocess_input`` (maps pixel values into [-1, 1])."""
    return preprocess_input(tf.cast(x, dtype=tf.float32))

In [None]:
def unscale(x):
    """Invert ``scale``: map values in [-1, 1] back to uint8 values in [0, 255].

    The values are rounded to the nearest integer and clipped before the cast.
    A bare float->uint8 ``tf.cast`` truncates, so float round-off that leaves an
    integer just below itself (e.g. 2.9999998) would otherwise come back as 2.
    Clipping keeps anything outside [-1, 1] from wrapping around in uint8.
    """
    x = x * 127.5 + 127.5
    x = tf.clip_by_value(tf.round(x), 0.0, 255.0)
    return tf.cast(x, tf.uint8)

# Dataset Preprocessing

In [None]:
# Bring both the L inputs and the ab targets into [-1, 1], the range the tanh
# decoder output and the PSNR/SSIM metrics work in.
def _scale_pair(l, ab):
    return scale(l), scale(ab)


preprocessed_train_dataset = train_dataset.map(_scale_pair)
preprocessed_valid_dataset = valid_dataset.map(_scale_pair)
preprocessed_test_dataset = test_dataset.map(_scale_pair)

# Metrics

In [None]:
def ssim(y_true, y_pred):
    """Structural similarity between target and prediction, both given in [-1, 1].

    Both tensors are shifted to [0, 1] first so that ``max_val=1.0`` matches their dynamic range.
    """
    def _to_unit(t):
        return (t + 1) / 2.0

    return tf.image.ssim(_to_unit(y_true), _to_unit(y_pred), max_val=1.0)

In [None]:
def psnr(y_true, y_pred):
    """Peak signal-to-noise ratio between target and prediction, both given in [-1, 1].

    Both tensors are shifted to [0, 1] first so that ``max_val=1.0`` matches their dynamic range.
    """
    def _to_unit(t):
        return (t + 1.0) / 2.0

    return tf.image.psnr(_to_unit(y_true), _to_unit(y_pred), max_val=1.0)

# Train

In [None]:
# Per-epoch metric curves, accumulated across training runs.
training_history = {key: [] for key in ('loss', 'val_loss', 'psnr', 'val_psnr', 'ssim', 'val_ssim')}

# fit() takes one optimizer step per batch, so decay_steps=steps_per_epoch with
# staircase=True multiplies the learning rate by 0.9 once per epoch.
steps_per_epoch = len(preprocessed_train_dataset)
lr_schedule = ExponentialDecay(
    initial_learning_rate=1e-4,
    decay_steps=steps_per_epoch,
    decay_rate=0.90,
    staircase=True,
)
model.compile(optimizer=Adam(learning_rate=lr_schedule), loss='mse', metrics=[psnr, ssim])

In [None]:
# Resume from a previous session's checkpoint when one exists. On a fresh kernel
# or first run the files are absent, so training starts from scratch with the
# empty history initialised in the compile cell instead of crashing.
WEIGHTS_PATH = 'model_weights.h5'
HISTORY_PATH = 'history.pkl'

if os.path.exists(WEIGHTS_PATH):
    model.load_weights(WEIGHTS_PATH)
if os.path.exists(HISTORY_PATH):
    # pickle.load can execute arbitrary code: only read a history file this notebook wrote itself.
    with open(HISTORY_PATH, 'rb') as file:
        training_history = pickle.load(file)
# NOTE(review): only layer weights are restored, not optimizer state, so Adam's moments and
# the ExponentialDecay step counter restart from zero on resume. TODO confirm this is intended.

In [None]:
EPOCHS = 15
history = model.fit(preprocessed_train_dataset, validation_data=preprocessed_valid_dataset, epochs=EPOCHS)

# Append this run's per-epoch values so the curves cover every training session.
for metric in ('loss', 'val_loss', 'psnr', 'val_psnr', 'ssim', 'val_ssim'):
    training_history[metric].extend(history.history[metric])

In [None]:
# Training vs. validation MSE per epoch.
fig, ax = plt.subplots()
for key, label in (('loss', 'Training Loss'), ('val_loss', 'Validation Loss')):
    ax.plot(training_history[key], label=label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
plt.show()

In [None]:
# Training vs. validation PSNR per epoch (higher is better).
fig, ax = plt.subplots()
for key, label in (('psnr', 'Training PSNR'), ('val_psnr', 'Validation PSNR')):
    ax.plot(training_history[key], label=label)
ax.set_xlabel('Epoch')
ax.set_ylabel('PSNR')
ax.set_title('Training and Validation PSNR')
ax.legend()
plt.show()

In [None]:
# Training vs. validation SSIM per epoch (higher is better).
fig, ax = plt.subplots()
for key, label in (('ssim', 'Training SSIM'), ('val_ssim', 'Validation SSIM')):
    ax.plot(training_history[key], label=label)
ax.set_xlabel('Epoch')
ax.set_ylabel('SSIM')
ax.set_title('Training and Validation SSIM')
ax.legend()
plt.show()

In [None]:
# Persist the weights and the accumulated metric history so a later session can resume.
model.save_weights('model_weights.h5')
with open('history.pkl', 'wb') as history_file:
    pickle.dump(training_history, history_file, protocol=pickle.DEFAULT_PROTOCOL)

# Evaluation

In [None]:
# return_dict=True labels each test-set value with its metric name (loss, psnr, ssim)
# instead of returning an unlabelled list.
model.evaluate(preprocessed_test_dataset, return_dict=True)

In [None]:
# Feed only the scaled L inputs (the ab targets are dropped), then map the
# tanh outputs back to uint8 ab values for display.
test_inputs = preprocessed_test_dataset.map(lambda l, ab: l)
predictions = unscale(model.predict(test_inputs))

In [None]:
N_SHOW = 50       # number of test images to compare side by side
ROW_HEIGHT = 2.4  # figure inches per row (50 rows -> a 120-inch-tall figure)

# Never ask for more rows than the test set has. squeeze=False keeps `axes`
# 2-D even when only one row is drawn.
n_show = min(N_SHOW, len(l_test))
fig, axes = plt.subplots(n_show, 2, figsize=(5, ROW_HEIGHT * n_show), squeeze=False)

for i in range(n_show):
    axes[i, 0].imshow(lab_to_rgb(l_test[i], predictions[i]))
    axes[i, 0].axis('off')
    axes[i, 1].imshow(lab_to_rgb(l_test[i], ab_test[i]))
    axes[i, 1].axis('off')

# Column headings on the first row only.
axes[0, 0].set_title('Prediction', fontsize=12)
axes[0, 1].set_title('Ground Truth', fontsize=12)

plt.tight_layout()
plt.show()