In [1]:
# Source for Images: https://www.kaggle.com/code/basu369victor/image-colorization-basic-implementation-with-cnn/input
# Source for Images: https://github.com/guilbera/colorizing
# Source for Code (Reference): https://anne-guilbert.medium.com/black-and-white-image-colorization-with-deep-learning-53855922cda6

# Import needed libraries
from keras.models import Sequential
from keras.layers import Conv2D, UpSampling2D
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb
from skimage.transform import resize
from skimage.io import imsave, imshow
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
plt.style.use("fivethirtyeight")

  "class": algorithms.Blowfish,


In [2]:
root_dir = "./datasets/training"

img_data_gen = ImageDataGenerator(rescale=(1.0 / 255))

training_images = img_data_gen.flow_from_directory(directory=root_dir, target_size=(224, 224), shuffle=False, batch_size=9294, class_mode=None)
training_images, type(training_images), len(training_images)

Found 9294 images belonging to 1 classes.


(<keras.src.preprocessing.image.DirectoryIterator at 0x1bc267f89d0>,
 keras.src.preprocessing.image.DirectoryIterator,
 1)

In [3]:
X, y = [], []
for image in training_images[0]:
    try:
        image_LAB = rgb2lab(image)
        X.append(image_LAB[:, :, 0])
        y.append((image_LAB[:,:, 1:]) / 128)
    except:
        print("Error in conversion or calculation")
X = np.array(X)
y = np.array(y)
X = X.reshape(X.shape + (1,))
X.shape, y.shape

((9294, 224, 224, 1), (9294, 224, 224, 2))

In [4]:
# Encoder
model = Sequential()
model.add(Conv2D(64, (3, 3), activation="relu", padding="same", strides=2, input_shape=(224, 224, 1)))
model.add(Conv2D(128, (3, 3), activation="relu", padding="same"))
model.add(Conv2D(128, (3, 3), activation="relu", padding="same", strides=2))
model.add(Conv2D(256, (3, 3), activation="relu", padding="same"))
model.add(Conv2D(256, (3, 3), activation="relu", padding="same", strides=2))
model.add(Conv2D(512, (3, 3), activation="relu", padding="same"))
model.add(Conv2D(512, (3, 3), activation="relu", padding="same"))
model.add(Conv2D(256, (3, 3), activation="relu", padding="same"))

In [5]:
# Decoder
model.add(Conv2D(128, (3, 3), activation="relu", padding="same"))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation="relu", padding="same"))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(32, (3, 3), activation="relu", padding="same"))
model.add(Conv2D(16, (3, 3), activation="relu", padding="same"))
model.add(Conv2D(2, (3, 3), activation="tanh", padding="same"))
model.add(UpSampling2D((2, 2)))
model.compile(optimizer="adam", loss="mse", metrics=["accuracy"])
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 112, 112, 64)      640       
                                                                 
 conv2d_1 (Conv2D)           (None, 112, 112, 128)     73856     
                                                                 
 conv2d_2 (Conv2D)           (None, 56, 56, 128)       147584    
                                                                 
 conv2d_3 (Conv2D)           (None, 56, 56, 256)       295168    
                                                                 
 conv2d_4 (Conv2D)           (None, 28, 28, 256)       590080    
                                                                 
 conv2d_5 (Conv2D)           (None, 28, 28, 512)       1180160   
                                                                 
 conv2d_6 (Conv2D)           (None, 28, 28, 512)       2

In [6]:
# tf.config.run_functions_eagerly(True)
# strategy = tf.distribute.MirroredStrategy()
# model.add(tf.keras.layers.BatchNormalization(synchronized=True))

model.fit(X, y, validation_split=0.1, epochs=300)
model.save("models/image_colorization_cnn.model")

Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300

KeyboardInterrupt: 

In [None]:
# Needed if you are loading in the model here from local file directory
model = tf.keras.models.load_model("models/image_colorization_cnn.model",
                                   custom_objects=None,
                                   compile=True)

In [None]:
test_img_file_path = "./datasets/testing/test_images/"
for index, file in enumerate(os.listdir(test_img_file_path)):
    print("{}: {}".format(index, test_img_file_path + file))
    test_img_arr = []
    test_img = load_img(test_img_file_path + file)
    test_img = img_to_array(test_img)
    test_img = resize(test_img, (224, 224), anti_aliasing=True)
    test_img_arr.append(test_img)

    test_img_arr = np.array(test_img_arr, dtype=float)
    test_img_arr = rgb2lab((1.0 / 255) * test_img_arr)[:,:,:,0]
    test_img_arr = test_img_arr.reshape(test_img_arr.shape + (1,))

    predict_img = model.predict(test_img_arr) 
    predict_img *= 128

    final_predict_color_img_result = np.zeros((224, 224, 3))
    final_predict_color_img_result[:,:,0] = test_img_arr[0][:,:,0]
    final_predict_color_img_result[:,:,1:] = predict_img[0]
    imshow((lab2rgb(final_predict_color_img_result) * 255).astype(np.uint8))
    imsave("result-{}.png".format(index), (lab2rgb(final_predict_color_img_result) * 255).astype(np.uint8))