In [None]:
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras.models import model_from_json, Sequential
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.layers import BatchNormalization, Activation, Dense, Dropout, Flatten, InputLayer, Conv2D, UpSampling2D
from tensorflow.keras.utils import img_to_array, load_img
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from skimage.color import rgb2lab, lab2rgb

In [None]:


# Data generator
def data_generator(directory, batch_size):
    """Yield (L, ab) training batches built from the images in `directory`, forever.

    Every pass re-lists and shuffles the directory, then walks it in chunks of
    `batch_size` files (the last chunk may be smaller). Each image is loaded at
    256x256, randomly augmented (shear/zoom/rotation/horizontal flip), scaled to
    [0, 1] and converted to CIE Lab.

    Args:
        directory: Folder containing the training images.
        batch_size: Maximum number of images per yielded batch (>= 1).

    Yields:
        (X, Y): X is the L channel with shape (n, 256, 256, 1); Y holds the a/b
        channels divided by 128, shape (n, 256, 256, 2).

    Raises:
        ValueError: if `batch_size` < 1 or `directory` is empty -- in either case
        the loop below could never yield a batch.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=20,
        horizontal_flip=True
    )
    while True:
        file_list = os.listdir(directory)
        if not file_list:
            raise ValueError(f"No training images found in {directory!r}")
        random.shuffle(file_list)
        for batch_start in range(0, len(file_list), batch_size):
            batch_files = file_list[batch_start:batch_start + batch_size]
            batch_images = []
            for filename in batch_files:
                img = load_img(os.path.join(directory, filename), target_size=(256, 256))
                # Apply the random augmentation configured on `datagen` to the RGB
                # image before the Lab split, so inputs and targets stay aligned.
                batch_images.append(datagen.random_transform(img_to_array(img)))
            batch_images = np.array(batch_images, dtype=float)
            lab_batch = rgb2lab(batch_images / 255.0)
            X_batch = lab_batch[:, :, :, 0]
            # Scale ab to roughly [-1, 1] to match the model's tanh output.
            Y_batch = lab_batch[:, :, :, 1:] / 128.0
            yield (X_batch.reshape(X_batch.shape + (1,)), Y_batch)


In [None]:

# Set paths
train_data_dir = 'Dataset/Train/'
test_data_dir = 'Dataset/Test/'
save_model_dir = 'Path/To/Save/Model/'  # Change this to the desired path for saving your model
# Checkpoints, model.json and model.h5 are all written into save_model_dir below;
# create it up front because open(..., "w") does not create missing folders.
os.makedirs(save_model_dir, exist_ok=True)


In [None]:

# Parameters
batch_size = 10
# data_generator yields ceil(n / batch_size) batches per pass (the last one may be
# partial), so use ceiling division: one Keras epoch then equals one full pass over
# the training files, and steps_per_epoch stays >= 1 when n < batch_size.
num_train_files = len(os.listdir(train_data_dir))
steps_per_epoch = -(-num_train_files // batch_size)
epochs = 500


In [None]:

# CNN model: three stride-2 convolutions shrink the 256x256 single-channel (L)
# input to 32x32; three 2x upsamplings restore 256x256, and the final tanh
# convolution predicts two channels (a/b) in (-1, 1).
model = Sequential()

# Encoder
model.add(Conv2D(64, (3, 3), input_shape=(256, 256, 1), activation='relu', padding='same'))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same', strides=2))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', strides=2))

# Bottleneck at 32x32
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))

# Decoder
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))
model.add(UpSampling2D((2, 2)))


In [None]:

# Compile the CNN: mean-squared-error regression on the predicted a/b channels.
# NOTE(review): 'accuracy' is not meaningful for a continuous regression target;
# it is kept only because the accuracy-plot cell reads history.history['accuracy'].
model.compile(
    loss='mse',
    optimizer='rmsprop',
    metrics=['accuracy'],
)


In [None]:

# Set up callbacks: TensorBoard logging plus a weights-only checkpoint every 10 epochs.
tensorboard = TensorBoard(log_dir="Dataset/output/beta_run")
# `period` is deprecated in tf.keras and rejected by Keras 3. `save_freq` counts
# batches, so 10 epochs == 10 * steps_per_epoch batches. Keras 3 requires the
# `.weights.h5` suffix for weights-only checkpoints; tf.keras 2 still writes
# HDF5 for any `.h5` name.
checkpoint = ModelCheckpoint(
    filepath=os.path.join(save_model_dir, 'model-{epoch:02d}.weights.h5'),
    save_weights_only=True,
    save_freq=10 * steps_per_epoch,
)


In [None]:
# Train the model on (L, ab) batches streamed by data_generator; the returned
# History object feeds the plots below.
train_generator = data_generator(train_data_dir, batch_size)
history = model.fit(
    train_generator,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=[tensorboard, checkpoint],
)


In [None]:

# Summarize history for model accuracy (per-epoch training curve).
acc_fig, acc_ax = plt.subplots()
acc_ax.plot(history.history['accuracy'])
acc_ax.set_title('Model Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend(['Train'], loc='upper left')
plt.show()


In [None]:


# Summarize history for model loss (per-epoch training MSE).
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(history.history['loss'])
loss_ax.set_title('Model Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend(['Train'], loc='upper left')
plt.show()


In [None]:

# Save model architecture (JSON) and weights side by side in save_model_dir.
model_json = model.to_json()
architecture_path = os.path.join(save_model_dir, "model.json")
with open(architecture_path, "w") as json_file:
    json_file.write(model_json)

weights_path = os.path.join(save_model_dir, "model.h5")
model.save_weights(weights_path)


In [None]:

# Load json and create model: rebuild the architecture from model.json.
# The `with` block closes the file even if read() raises.
with open(os.path.join(save_model_dir, 'model.json'), 'r') as json_file:
    loaded_model_json = json_file.read()
loaded_model = model_from_json(loaded_model_json)

In [None]:

# Load weights into new model, then print its layer summary as a sanity check.
weights_path = os.path.join(save_model_dir, "model.h5")
loaded_model.load_weights(weights_path)
loaded_model.summary()


In [None]:

# Test images: compile the reloaded model with the training loss/optimizer so
# it can be evaluated below.
# NOTE(review): 'accuracy' is not meaningful for this regression target.
loaded_model.compile(
    loss='mse',
    optimizer='rmsprop',
    metrics=['accuracy'],
)

def preprocess_test_images(directory):
    """Load every readable image in `directory` as model inputs and Lab targets.

    Each image is loaded at 256x256, scaled to [0, 1] and converted to CIE Lab,
    exactly like the batches produced by data_generator (without augmentation).

    Args:
        directory: Folder containing the test images.

    Returns:
        (X, Y): X is the L channel with shape (n, 256, 256, 1); Y holds the a/b
        channels divided by 128, shape (n, 256, 256, 2).

    Raises:
        ValueError: if no file in `directory` could be loaded as an image.
    """
    images = []
    for filename in os.listdir(directory):
        try:
            img = load_img(os.path.join(directory, filename), target_size=(256, 256))
        except OSError:
            # Skip stray non-image entries (e.g. .DS_Store, sub-folders) instead of
            # aborting the whole evaluation; Pillow's UnidentifiedImageError is an OSError.
            print(f"Couldn't read image {filename}. Skipping.")
            continue
        images.append(img_to_array(img))
    if not images:
        raise ValueError(f"No readable images found in {directory!r}")
    images = np.array(images, dtype=float)
    lab_images = rgb2lab(images / 255.0)
    X = lab_images[:, :, :, 0]
    Y = lab_images[:, :, :, 1:] / 128.0
    return X.reshape(X.shape + (1,)), Y

Xtest, Ytest = preprocess_test_images(test_data_dir)
# evaluate() returns the loss followed by each compiled metric.
test_scores = loaded_model.evaluate(Xtest, Ytest, batch_size=batch_size)
print(test_scores)


In [None]:

# Display test results: column 0 of the figure shows each original test image;
# column 1 is filled with the model's colorization by a later cell.
print('Output of the Model')

test_images = []
for filename in os.listdir(test_data_dir):
    img = cv2.imread(os.path.join(test_data_dir, filename))

    if img is None:
        print(f"Couldn't read image {filename}. Skipping.")
        continue

    # OpenCV decodes to BGR; convert once so both imshow and rgb2lab receive RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    test_images.append(cv2.resize(img, (256, 256)))

if not test_images:
    raise ValueError(f"No readable images found in {test_data_dir!r}")

# One row per image actually read (a fixed 24-row grid overflowed with more images);
# squeeze=False keeps `ax` 2-D so ax[row, col] also works for a single image.
fig, ax = plt.subplots(len(test_images), 2, figsize=(16, 4 * len(test_images)), squeeze=False)
for row, img_resized in enumerate(test_images):
    # Already RGB: converting BGR->RGB again would swap red and blue back.
    ax[row, 0].imshow(img_resized, interpolation='nearest')

# Model input: the L channel of each test image, shape (n, 256, 256, 1).
colorize = np.array(test_images, dtype=float)
colorize = rgb2lab(1.0 / 255 * colorize)[:, :, :, 0]
colorize = colorize.reshape(colorize.shape + (1,))


In [None]:

# Test model: predict the a/b channels for the test L channels and undo the
# /128 scaling that was applied to the training targets.
output = loaded_model.predict(colorize) * 128

# Row cursor consumed by the colorization cell that follows.
row = 0


In [None]:

# Output colorizations: rebuild each Lab image from the test L channel plus the
# predicted a/b channels, convert it to RGB and draw it next to the original.
for i, ab_channels in enumerate(output):
    cur = np.zeros((256, 256, 3))
    cur[:, :, 0] = colorize[i][:, :, 0]
    cur[:, :, 1:] = ab_channels
    resImage = lab2rgb(cur)

    # Index rows by the loop variable rather than a `row` counter carried over
    # from another cell, so re-running this cell cannot run past the last row.
    ax[i, 1].imshow(resImage, interpolation='nearest')

plt.show()
# Jupyter's inline backend closes figures at the end of every cell, so the
# figure created in an earlier cell is no longer managed by pyplot and
# plt.show() above has nothing to draw; evaluating `fig` renders it.
fig
