In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.utils import img_to_array, array_to_img
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose, ReLU, BatchNormalization, Concatenate
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from google.colab import drive
import matplotlib.pyplot as plt
import os
from PIL import Image

In [None]:
# Mount Google Drive so the /content/drive/... dataset and checkpoint paths used below resolve.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
# Dataset and checkpoint locations on the mounted Google Drive.
# NOTE(review): `val_images_folder` is a subdirectory of `train_images_folder`, and
# `flow_from_directory` treats every subdirectory of its root as a class, so if
# `val/` is the single class it reports, validation images are being trained on.
# Confirm the folder layout.
train_images_folder, val_images_folder, model_path = (
    '/content/drive/MyDrive/coco_test2017',
    '/content/drive/MyDrive/coco_test2017/val',
    '/content/drive/MyDrive/image colorization/grayscale_to_rgb_model_coco_test_2017.keras',
)

In [None]:
layer_filters = 32  # base filter count; stage k uses layer_filters * k filters
image_size = 512
depth = 7  # number of stride-2 downsampling stages (and matching upsampling stages)

# Each stride-2 'same'-padded conv halves the spatial size and each stride-2
# transposed conv doubles it, so the skip concatenations only line up when
# image_size is divisible by 2**depth (512 -> 4 at the bottleneck for depth 7).
assert image_size % (2 ** depth) == 0, "image_size must be divisible by 2**depth"


def conv_bn_relu(x, layer_cls, filters, strides):
    """Apply `layer_cls` (Conv2D or Conv2DTranspose) -> BatchNormalization -> ReLU to `x`."""
    x = layer_cls(filters=filters, kernel_size=3, strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    return ReLU()(x)


# Inputs: single-channel (grayscale) images.
model_inputs = Input(shape=(image_size, image_size, 1))

# Encoder: keep every stage's output (and the raw input) for U-Net style skips.
skip_connections = [model_inputs]
x = model_inputs
for stage in range(1, depth + 1):
    x = conv_bn_relu(x, Conv2D, layer_filters * stage, strides=2)
    skip_connections.append(x)

# Bottleneck: stride-1 transposed conv at the lowest resolution, joined with the
# deepest encoder output.
x = conv_bn_relu(x, Conv2DTranspose, layer_filters * depth, strides=1)
x = Concatenate()([x, skip_connections.pop()])

# Decoder: upsample back to full resolution. Each stage concatenates the encoder
# output of matching size; the last stage concatenates the raw grayscale input.
for stage in range(depth, 0, -1):
    x = conv_bn_relu(x, Conv2DTranspose, layer_filters * stage, strides=2)
    x = Concatenate()([x, skip_connections.pop()])

# 3 output channels (RGB). NOTE(review): this layer has no activation, so outputs
# are unbounded, while the training targets are rescaled to [0, 1]. Consider
# activation='sigmoid'.
model_outputs = Conv2DTranspose(filters=3, kernel_size=3, strides=1, padding='same')(x)
model = tf.keras.Model(model_inputs, model_outputs)

In [None]:
# Per-pixel regression of RGB values. 'accuracy' is meaningless for continuous
# targets, so track mean absolute error alongside the MSE loss.
loss = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()
metrics = ['mae']

# No validation data is passed to training, so checkpoint on the training loss
# and keep only the best model seen so far.
cp_callback = tf.keras.callbacks.ModelCheckpoint(model_path, monitor='loss', save_best_only=True)

# `model` is built in the architecture cell; compile it here rather than
# constructing a duplicate Model over the same layers.
model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
model.summary()

To do:
* Read images from the directory
* Resize images to 512 × 512 × 1 (single-channel grayscale)
* Standard preprocessing: normalization and batching
* Train the model
* Make predictions
* Resize the predictions back to the original image dimensions

In [None]:
# Paired (grayscale input, RGB target) batches read from the same directory.
# Both flows must visit the files in the same order. Giving them the same
# `seed` makes them draw identical shuffles while they are advanced in lockstep,
# which `zip` guarantees. The original shuffle=False also kept the pairing, but it
# fed every epoch in exactly the same order.
batch_size = 32
shuffle_seed = 1024

image_datagen = ImageDataGenerator(rescale=1./255)  # scale pixel values to [0, 1]


def make_image_flow(color_mode):
    """Return a label-free, seeded, shuffled iterator over `train_images_folder` in `color_mode`."""
    return image_datagen.flow_from_directory(
        train_images_folder,
        target_size=(image_size, image_size),
        color_mode=color_mode,
        batch_size=batch_size,
        shuffle=True,
        seed=shuffle_seed,
        class_mode=None,
    )


rgb_generator = make_image_flow('rgb')
grayscale_generator = make_image_flow('grayscale')

train_generator = zip(grayscale_generator, rgb_generator)

Found 40669 images belonging to 1 classes.
Found 40669 images belonging to 1 classes.


In [None]:
# seed = 1024
# batch_size = 32

# train_images = tf.keras.utils.image_dataset_from_directory(
#     train_images_folder,
#     labels=None,
#     label_mode=None,
#     class_names=None,
#     color_mode='rgb',
#     batch_size=None,
#     image_size=(512, 512),
#     shuffle=False,
#     seed=seed,
#     validation_split=None,
#     subset=None,
#     interpolation='bilinear',
#     follow_links=False,
#     crop_to_aspect_ratio=False
# )

# val_images = tf.keras.utils.image_dataset_from_directory(
#     val_images_folder,
#     labels=None,
#     label_mode=None,
#     class_names=None,
#     color_mode='rgb',
#     batch_size=None,
#     image_size=(512, 512),
#     shuffle=False,
#     seed=seed,
#     validation_split=None,
#     subset=None,
#     interpolation='bilinear',
#     follow_links=False,
#     crop_to_aspect_ratio=False
# )

In [None]:
# AUTOTUNE = tf.data.AUTOTUNE
# def normalize_images(image):
#     return tf.cast(tf.image.rgb_to_grayscale(image), tf.float32)/255., tf.cast(image, tf.float32)/255.

# def preprocess_dataset(dataset):
#     return dataset.shuffle(seed).batch(batch_size, drop_remainder=True).prefetch(AUTOTUNE)

# train_images_normalized = train_images.map(normalize_images, num_parallel_calls=AUTOTUNE)
# preprocessed_train_dataset = preprocess_dataset(train_images_normalized)

# val_images_normalized = val_images.map(normalize_images, num_parallel_calls=AUTOTUNE)
# preprocessed_val_dataset = preprocess_dataset(val_images_normalized)

In [None]:
# One epoch = one full pass over the directory. len() of a Keras directory
# iterator is ceil(n_images / batch_size), so the final partial batch is included
# and the count follows the data instead of a hard-coded 40669/32 float.
steps_per_epoch = len(grayscale_generator)
# `Model.fit_generator` is deprecated; `Model.fit` accepts Python generators directly.
history = model.fit(train_generator, epochs=100, steps_per_epoch=steps_per_epoch, callbacks=[cp_callback])

  history = model.fit_generator(train_generator, epochs=100, steps_per_epoch=steps_per_epoch, callbacks=[cp_callback])


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100

Test data:
* Preprocess with a batch size of 1
* Keep the real RGB images so the predictions can be compared against them

In [None]:
# Scratch value; nothing in the visible notebook reads `ken`.
# The original np.array(2, 5) passed 5 as the `dtype` argument and raised
# TypeError, breaking Restart & Run All. A two-element array was presumably meant.
ken = np.array([2, 5])

In [None]:
# def normalize_val_images_grayscale(image):
#     return tf.cast(tf.image.rgb_to_grayscale(image), tf.float32)/255.

# def normalize_val_images_rgb(image):
#     return tf.cast(image, tf.float32)/255.

# def preprocess_val_dataset(dataset):
#     return dataset.batch(1).prefetch(AUTOTUNE)

In [None]:
# normalized_grayscale_val_images = val_images.map(normalize_val_images_grayscale, num_parallel_calls=AUTOTUNE)
# preprocessed_grayscale_val_dataset = preprocess_val_dataset(normalized_grayscale_val_images)

# normalized_rgb_val_images = val_images.map(normalize_val_images_rgb, num_parallel_calls=AUTOTUNE)
# preprocessd_rgb_val_dataset = preprocess_val_dataset(normalized_rgb_val_images)

In [None]:
# model = tf.keras.models.load_model(model_path)
# predictions = model.predict(preprocessed_grayscale_val_dataset)

In [None]:
# actual_grayscale_arrays = [np.array(image) for image in normalized_grayscale_val_images]
# actual_rgb_arrays = [np.array(image) for image in normalized_rgb_val_images]

In [None]:
# rand_int = int(np.random.randint(low=0, high=500, size=1))
# plt.imshow(array_to_img(actual_grayscale_arrays[rand_int]), cmap='gray')

In [None]:
# plt.imshow(array_to_img(predictions[rand_int]))

In [None]:
# plt.imshow(array_to_img(actual_rgb_arrays[rand_int]))