In [None]:
import tensorflow as tf
import keras
from keras.api import layers
import numpy as np
from typing import Tuple
import matplotlib.pyplot as plt
import cv2
from keras.api.callbacks import ModelCheckpoint, EarlyStopping
# import time

# True: load previously trained weights; False: train the model from scratch.
training_completed = True
ds_path = '/DATA_128/'
ds_path1 = '/DATA_128/'  # directory for the test split (currently the same as ds_path)
BATCH_SIZE = 64

train_dataset = tf.data.Dataset.load(ds_path + 'train')
test_dataset = tf.data.Dataset.load(ds_path1 + 'test')
# Number of individual (unbatched) training examples; used as the shuffle buffer so the
# whole training set is shuffled.
tmp = train_dataset.cardinality()

# Order matters: cache the raw examples once, shuffle *examples* (reshuffled every epoch),
# then batch. Shuffling after batching would only permute whole batches, so every batch
# would contain the same examples in every epoch.
train_dataset = (
    train_dataset
    .cache()
    .shuffle(tmp)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
# The evaluation split is not shuffled.
test_dataset = test_dataset.batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)


In [None]:
# Model input: 128x128 images with a single channel.
SHAPE_ = (128, 128, 1)
# Kernel-regularization strengths; the `0.0 *` factor currently disables both penalties
# while keeping the previously used magnitudes visible.
L1_NORM = 0.0 * 1e-5
L2_NORM = 0.0 * 1e-6

def encoder_block(filters: int , kernel_size: Tuple[int, int], apply_batch_normalization = True, l1_reg = 0.0, l2_reg = 0.0):
    """Down-sampling block: stride-2 Conv2D -> optional BatchNormalization -> LeakyReLU.

    With 'same' padding and stride 2 the block halves height and width.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        apply_batch_normalization: if True, insert BatchNormalization after the conv.
        l1_reg: L1 penalty factor for the kernel regularizer.
        l2_reg: L2 penalty factor for the kernel regularizer.

    Returns:
        A keras Sequential model implementing the block.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        padding = 'same',
        strides = 2,
        kernel_regularizer=keras.regularizers.L1L2(l1=l1_reg, l2=l2_reg),
    )
    block_layers = [conv]
    if apply_batch_normalization:
        block_layers.append(layers.BatchNormalization())
    block_layers.append(keras.layers.LeakyReLU())
    return keras.models.Sequential(block_layers)

def decoder_block(filters: int, kernel_size: Tuple[int, int], dropout = False, l1_reg = 0.0, l2_reg = 0.0,
                  dropout_rate: float = 0.2):
    """Up-sampling block: stride-2 Conv2DTranspose -> optional Dropout -> LeakyReLU.

    With 'same' padding and stride 2 the block doubles height and width.

    Args:
        filters: number of output channels of the transposed convolution.
        kernel_size: transposed-convolution kernel size.
        dropout: if True, insert a Dropout layer after the transposed convolution.
        l1_reg: L1 penalty factor for the kernel regularizer.
        l2_reg: L2 penalty factor for the kernel regularizer.
        dropout_rate: fraction of units dropped when `dropout` is True
            (defaults to 0.2, the value previously hard-coded here).

    Returns:
        A keras Sequential model implementing the block.
    """
    upsample = keras.models.Sequential()
    upsample.add(layers.Conv2DTranspose(filters, kernel_size, padding = 'same', strides = 2,
                                        kernel_regularizer=keras.regularizers.L1L2(l1=l1_reg, l2=l2_reg)))
    if dropout:
        upsample.add(layers.Dropout(dropout_rate))
    upsample.add(keras.layers.LeakyReLU())
    return upsample

def build_colorizer(input_shape = SHAPE_, l2_reg = L2_NORM, l1_reg = L1_NORM):
    """Build the U-Net-style colorization network.

    Seven stride-2 encoder blocks (six skip levels plus a bottleneck) shrink the input
    by 2**7. A mirrored stack of decoder blocks upsamples it back, and each decoder
    output is concatenated with the encoder output of the same resolution. The last
    decoder output is concatenated with the raw input and mapped to two channels by a
    tanh Conv2D, so predictions lie in (-1, 1).

    Args:
        input_shape: (H, W, C) of the input. H and W should be divisible by 128 so
            that every skip concatenation sees matching spatial sizes.
        l2_reg: L2 kernel-regularization factor, applied to every encoder/decoder block.
        l1_reg: L1 kernel-regularization factor, applied to every encoder/decoder block.

    Returns:
        An uncompiled keras.Model producing a 2-channel map at the input resolution.
    """
    # Previously these factors were accepted but never forwarded, so regularization
    # silently had no effect whatever values were passed in.
    reg = {'l1_reg': l1_reg, 'l2_reg': l2_reg}
    inputs = layers.Input(shape=input_shape)

    # ENCODER: (filters, apply_batch_normalization) per level. Each level halves H and W
    # (/2, /4, ..., /64); outputs are kept as skip connections.
    encoder_spec = [(128, False), (128, False), (256, True), (512, True), (1024, True), (2048, True)]
    skips = []
    x = inputs
    for filters, use_bn in encoder_spec:
        x = encoder_block(filters, (3, 3), use_bn, **reg)(x)
        skips.append(x)

    # LATENT SPACE (/128)
    x = encoder_block(2048, (3, 3), True, **reg)(x)

    # DECODER: mirror of the encoder. Each level doubles H and W and is joined with the
    # encoder output of the same resolution, deepest level first.
    for (filters, _), skip in zip(reversed(encoder_spec), reversed(skips)):
        x = decoder_block(filters, (3, 3), False, **reg)(x)
        x = layers.concatenate([x, skip])

    outputs = decoder_block(2, (3, 3), False, **reg)(x)
    outputs = layers.concatenate([outputs, inputs])
    outputs = layers.Conv2D(2, (3, 3), padding='same', strides=1, activation='tanh',
                            kernel_initializer=keras.initializers.GlorotNormal())(outputs)

    return keras.Model(inputs, outputs)

In [None]:
# Re-declares the configuration constants with the same values as the model-definition
# cell: the input shape and the (currently disabled) L1/L2 penalty factors.
SHAPE_ = tuple([128, 128, 1])
L1_NORM, L2_NORM = 0.0 * 1e-5, 0.0 * 1e-6

In [None]:

# Build the model once, then either restore trained weights or train from scratch.
colorizer = build_colorizer(SHAPE_, L2_NORM)

if training_completed:
    # NOTE(review): this file name differs from the one ModelCheckpoint writes below
    # ('/MK/128_2048_checkpoint.weights.h5'); confirm which checkpoint is intended.
    _path = '/MK/5_128_2048_checkpoint.weights.h5'
    colorizer.load_weights(_path)
else:
    colorizer.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
                      loss=keras.losses.mean_squared_error)
    hist = colorizer.fit(
        train_dataset,
        epochs=300,
        validation_data=test_dataset,
        callbacks=[
            # restore_best_weights: leave the in-memory model at the best val_loss epoch
            # (matching the saved checkpoint) instead of the last epoch before stopping,
            # since the following cells predict with `colorizer` directly.
            EarlyStopping(monitor='val_loss', patience=10, min_delta=0.00005,
                          restore_best_weights=True),
            ModelCheckpoint('/MK/128_2048_checkpoint.weights.h5',
                            save_best_only=True, save_weights_only=True),
        ],
    )

In [None]:
plt.style.use('ggplot')

# `hist` only exists when the training branch of the previous cell ran; when weights
# are loaded from a checkpoint there is no history to plot.
_history = globals().get('hist')
if _history is None:
    print('No training history available: the training branch did not run, so `hist` is undefined.')
else:
    fig, ax = plt.subplots()
    ax.plot(_history.history['loss'], label=f'loss-{L2_NORM}')
    ax.plot(_history.history['val_loss'], label=f'val loss-{L2_NORM}')
    ax.set_title("Loss vs Val_Loss")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()

In [None]:
# TRAIN IMAGE
def show_dataset_example(dataset=None, num_batches: int = 5):
    """Plot input, ground-truth colour and predicted colour for a few training samples.

    For each of the first `num_batches` batches of `dataset`, takes the first image,
    runs `colorizer` on its L channel, and shows three panels: the L input, the real
    Lab->RGB image and the predicted Lab->RGB image.

    Args:
        dataset: batched dataset yielding (L, ab) pairs; defaults to `train_dataset`.
        num_batches: number of batches to draw one example from.
    """
    if dataset is None:
        dataset = train_dataset

    def to_uint8(values):
        # Clip before casting: an out-of-range float (e.g. exactly 1.0 * 256 = 256) would
        # otherwise wrap around in the uint8 cast and corrupt the Lab image.
        return np.clip(values, 0, 255).astype(np.uint8)

    for l_batch, ab_batch in dataset.take(num_batches):
        # presumably (H, W, 1) L in [0, 1] and (H, W, 2) ab in [-1, 1] — TODO confirm
        l_img = l_batch[0].numpy()
        ab_true = ab_batch[0].numpy()
        ab_pred = colorizer.predict(l_img[np.newaxis, ...])[0]

        # Same de-normalisation as used for training data: L * 256, ab * 128 + 128.
        l_u8 = to_uint8(l_img * 256.0)
        original_lab = np.concatenate([l_u8, to_uint8(ab_true * 128.0 + 128.0)], axis=-1)
        pred_lab = np.concatenate([l_u8, to_uint8(ab_pred * 128.0 + 128.0)], axis=-1)

        original_rgb = cv2.cvtColor(original_lab, cv2.COLOR_LAB2RGB)
        pred_rgb = cv2.cvtColor(pred_lab, cv2.COLOR_LAB2RGB)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [
            (l_u8.squeeze(), 'Input (L channel)', 'gray'),
            (original_rgb, 'Real', None),
            (pred_rgb, 'Prediction', None),
        ]
        for ax, (image, title, cmap) in zip(axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across loop iterations

show_dataset_example()

In [None]:
# TEST IMAGE

def show_dataset_example(dataset=None, num_batches: int = 5):
    """Plot input, ground-truth colour and predicted colour for a few test samples.

    For each of the first `num_batches` batches of `dataset`, takes the first image,
    runs `colorizer` on its L channel, and shows three panels: the L input, the real
    Lab->RGB image and the predicted Lab->RGB image.

    Args:
        dataset: batched dataset yielding (L, ab) pairs; defaults to `test_dataset`.
        num_batches: number of batches to draw one example from.
    """
    if dataset is None:
        dataset = test_dataset

    def to_uint8(values):
        # Clip before casting: an out-of-range float (e.g. exactly 1.0 * 256 = 256) would
        # otherwise wrap around in the uint8 cast and corrupt the Lab image.
        return np.clip(values, 0, 255).astype(np.uint8)

    for l_batch, ab_batch in dataset.take(num_batches):
        # presumably (H, W, 1) L in [0, 1] and (H, W, 2) ab in [-1, 1] — TODO confirm
        l_img = l_batch[0].numpy()
        ab_true = ab_batch[0].numpy()
        ab_pred = colorizer.predict(l_img[np.newaxis, ...])[0]

        # Same de-normalisation as used for training data: L * 256, ab * 128 + 128.
        l_u8 = to_uint8(l_img * 256.0)
        original_lab = np.concatenate([l_u8, to_uint8(ab_true * 128.0 + 128.0)], axis=-1)
        pred_lab = np.concatenate([l_u8, to_uint8(ab_pred * 128.0 + 128.0)], axis=-1)

        original_rgb = cv2.cvtColor(original_lab, cv2.COLOR_LAB2RGB)
        pred_rgb = cv2.cvtColor(pred_lab, cv2.COLOR_LAB2RGB)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [
            (l_u8.squeeze(), 'Input (L channel)', 'gray'),
            (original_rgb, 'Real', None),
            (pred_rgb, 'Prediction', None),
        ]
        for ax, (image, title, cmap) in zip(axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across loop iterations

show_dataset_example()

In [None]:
# Architecture inspection: build a fresh model, compile it, and show its summary/graph.
# NOTE(review): this rebinds `colorizer` to a new, untrained model, discarding the weights
# trained or loaded earlier; re-running the example cells after this one will show
# untrained predictions.
colorizer = build_colorizer(SHAPE_, L2_NORM)
# With momentum=0.0 the Nesterov flag has no effect, so it is left off rather than
# implying Nesterov momentum is in use.
colorizer.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1, weight_decay=1e-6, momentum=0.0, nesterov=False),
                  loss=keras.losses.mean_squared_error)

colorizer.summary()
keras.utils.plot_model(colorizer, show_shapes=True, show_trainable=True, dpi=64)