In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Concatenate, BatchNormalization, Activation
from tensorflow.keras.preprocessing.image import load_img, img_to_array, ImageDataGenerator
import random

# Dataset layout: <data_dir>/<train|validation>/<grayscale|colored>
data_dir = "cv_p3_images_split"
train_gray_dir, train_color_dir = (
    os.path.join(data_dir, subdir)
    for subdir in ("train/grayscale", "train/colored")
)
val_gray_dir, val_color_dir = (
    os.path.join(data_dir, subdir)
    for subdir in ("validation/grayscale", "validation/colored")
)

# Target resolution: images are resized to 256x256 when they are loaded.
IMG_HEIGHT = IMG_WIDTH = 256

# Utility to preprocess images
def preprocess_image(image_path, target_size):
    """Read one image file as an RGB array with values divided by 255.

    Args:
        image_path: Path of the image file to read.
        target_size: Size handed to ``load_img`` so the image is resized on
            load.

    Returns:
        The array from ``img_to_array`` divided by 255.0 (i.e. in [0, 1] for
        8-bit images).
    """
    pil_image = load_img(image_path, color_mode="rgb", target_size=target_size)
    pixels = img_to_array(pil_image)
    return pixels / 255.0

def load_images_from_folder(folder, target_size):
    """Load every image file found directly inside ``folder``.

    ``os.listdir`` returns entries in arbitrary order and also lists
    subdirectories and hidden files (for example ``.ipynb_checkpoints`` or
    ``.DS_Store``). Those are not images and would make the loader fail, so
    hidden entries and anything that is not a regular file are skipped, and
    the remaining names are processed in sorted order so the result is
    deterministic.

    Args:
        folder: Directory containing the image files.
        target_size: Size forwarded to ``preprocess_image`` for resizing.

    Returns:
        A NumPy array stacking the preprocessed images in sorted filename
        order. The array is empty when the folder has no usable files.
    """
    images = [
        preprocess_image(os.path.join(folder, filename), target_size)
        for filename in sorted(os.listdir(folder))
        if not filename.startswith(".")
        and os.path.isfile(os.path.join(folder, filename))
    ]
    return np.array(images)


def load_dataset(gray_folder, color_folder, target_size, fraction=1.0, seed=None):
    """Load paired grayscale/colour images that share a filename.

    The filenames are taken from ``gray_folder``; each one is expected to
    exist in ``color_folder`` as well. Hidden entries and non-files (such as
    ``.ipynb_checkpoints``) are ignored, and names are sorted before sampling
    because ``os.listdir`` order is arbitrary. With a sorted list, a fixed
    ``seed`` selects the same subset on every run.

    Args:
        gray_folder: Directory holding the grayscale inputs.
        color_folder: Directory holding the colour targets, with the same
            filenames as ``gray_folder``.
        target_size: Size forwarded to ``preprocess_image`` for resizing.
        fraction: Share of the files to load. Values below 1.0 draw a random
            sample of ``int(len(files) * fraction)`` files; values of 1.0 or
            more load everything.
        seed: Optional seed for the sampling. When ``None`` the module-level
            ``random`` generator is used, as before.

    Returns:
        ``(gray_images, color_images)`` as NumPy arrays. Each grayscale entry
        keeps only the first channel of the loaded image (trailing dimension
        of size 1); each colour entry is the full preprocessed image.

    Raises:
        ValueError: If ``fraction`` is negative.
    """
    if fraction < 0:
        raise ValueError(f"fraction must be non-negative, got {fraction!r}")

    filenames = sorted(
        name
        for name in os.listdir(gray_folder)
        if not name.startswith(".")
        and os.path.isfile(os.path.join(gray_folder, name))
    )
    if fraction < 1.0:
        rng = random if seed is None else random.Random(seed)
        filenames = rng.sample(filenames, int(len(filenames) * fraction))

    gray_images = []
    color_images = []
    for filename in filenames:
        gray = preprocess_image(os.path.join(gray_folder, filename), target_size)
        color = preprocess_image(os.path.join(color_folder, filename), target_size)
        # The grayscale file is read as RGB; keep one channel, preserving the axis.
        gray_images.append(gray[..., 0:1])
        color_images.append(color)

    return np.array(gray_images), np.array(color_images)


# Load a random 25% sample of the training pairs, then of the validation pairs.
train_gray, train_color = load_dataset(
    gray_folder=train_gray_dir,
    color_folder=train_color_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    fraction=0.25,
)
val_gray, val_color = load_dataset(
    gray_folder=val_gray_dir,
    color_folder=val_color_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    fraction=0.25,
)

# Define the colorization model
def build_model(loss="mse"):
    """Build the uncompiled encoder-decoder colorization network.

    The network takes a single-channel ``(IMG_HEIGHT, IMG_WIDTH, 1)`` input,
    downsamples it with two stride-2 convolutions, upsamples it twice by a
    factor of 2, and ends in a 3-channel sigmoid output.

    Args:
        loss: Not used inside this function. The model is returned
            uncompiled, so the loss has to be given to ``compile`` by the
            caller. The parameter is kept so existing call sites still work.

    Returns:
        A Keras ``Model`` from the grayscale input to the 3-channel output.
    """

    def conv_bn_relu(tensor, filters, **conv_kwargs):
        # 3x3 "same" convolution -> batch normalisation -> ReLU.
        tensor = Conv2D(filters, (3, 3), padding="same", **conv_kwargs)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    gray_input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1))

    # Encoder: each stage halves the spatial resolution.
    features = gray_input
    for filters in (32, 64):
        features = conv_bn_relu(features, filters, strides=2)

    # Decoder: each upsampling step doubles the spatial resolution.
    features = UpSampling2D((2, 2))(features)
    features = conv_bn_relu(features, 32)

    features = UpSampling2D((2, 2))(features)
    features = Conv2D(3, (3, 3), padding="same")(features)
    rgb_output = Activation("sigmoid")(features)

    return Model(gray_input, rgb_output)

def _train_and_save(loss, save_path):
    """Build, compile, train and save one colorization model.

    Args:
        loss: Keras loss identifier passed to ``build_model`` and ``compile``.
        save_path: File the trained model is written to.

    Returns:
        ``(model, run_history)``: the trained model and the ``History``
        object returned by ``fit``.
    """
    model = build_model(loss=loss)
    # NOTE(review): "accuracy" is kept so the logged metric names stay the
    # same, but it is a classification metric and is probably not meaningful
    # for per-pixel colour regression. Consider a pixel-error metric instead.
    model.compile(optimizer="adam", loss=loss, metrics=["accuracy"])
    run_history = model.fit(
        train_gray, train_color,
        validation_data=(val_gray, val_color),
        epochs=20,
        batch_size=16,
    )
    model.save(save_path)
    return model, run_history


# Train the same architecture with three losses. Every run's History is kept
# in `histories`; previously each run overwrote `history`, so the MSE and MAE
# training curves were lost.
histories = {}
model_mse, histories["mse"] = _train_and_save("mse", "colorization_model_mse.h5")
model_mae, histories["mae"] = _train_and_save("mae", "colorization_model_mae.h5")
model_bce, histories["bce"] = _train_and_save(
    "binary_crossentropy", "colorization_model_bce.h5"
)

# `history` still refers to the last (binary cross-entropy) run, as before.
history = histories["bce"]

Epoch 1/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 323ms/step - accuracy: 0.4909 - loss: 0.0194 - val_accuracy: 0.6255 - val_loss: 0.0597
Epoch 2/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 299ms/step - accuracy: 0.5813 - loss: 0.0106 - val_accuracy: 0.6262 - val_loss: 0.0493
Epoch 3/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 302ms/step - accuracy: 0.5674 - loss: 0.0101 - val_accuracy: 0.6248 - val_loss: 0.0370
Epoch 4/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 300ms/step - accuracy: 0.5700 - loss: 0.0096 - val_accuracy: 0.6227 - val_loss: 0.0275
Epoch 5/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 300ms/step - accuracy: 0.5720 - loss: 0.0100 - val_accuracy: 0.6238 - val_loss: 0.0139
Epoch 6/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 299ms/step - accuracy: 0.5891 - loss: 0.0096 - val_accuracy: 0.5784 - val_loss: 0.0114
Epoch 7/20
[1m93/93[



Epoch 1/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 302ms/step - accuracy: 0.5105 - loss: 0.0975 - val_accuracy: 0.6230 - val_loss: 0.2113
Epoch 2/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 303ms/step - accuracy: 0.5680 - loss: 0.0719 - val_accuracy: 0.6248 - val_loss: 0.1916
Epoch 3/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 307ms/step - accuracy: 0.5868 - loss: 0.0714 - val_accuracy: 0.5781 - val_loss: 0.1696
Epoch 4/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 304ms/step - accuracy: 0.5666 - loss: 0.0717 - val_accuracy: 0.6259 - val_loss: 0.1485
Epoch 5/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 301ms/step - accuracy: 0.5849 - loss: 0.0654 - val_accuracy: 0.6229 - val_loss: 0.1115
Epoch 6/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 303ms/step - accuracy: 0.5735 - loss: 0.0684 - val_accuracy: 0.6180 - val_loss: 0.0906
Epoch 7/20
[1m93/93[



Epoch 1/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 320ms/step - accuracy: 0.4787 - loss: 0.5781 - val_accuracy: 0.6260 - val_loss: 0.6523
Epoch 2/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 313ms/step - accuracy: 0.5545 - loss: 0.5449 - val_accuracy: 0.6260 - val_loss: 0.6388
Epoch 3/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 316ms/step - accuracy: 0.5808 - loss: 0.5441 - val_accuracy: 0.6256 - val_loss: 0.6067
Epoch 4/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 317ms/step - accuracy: 0.5763 - loss: 0.5420 - val_accuracy: 0.6248 - val_loss: 0.5917
Epoch 5/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 315ms/step - accuracy: 0.5784 - loss: 0.5403 - val_accuracy: 0.6221 - val_loss: 0.5587
Epoch 6/20
[1m93/93[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 317ms/step - accuracy: 0.5696 - loss: 0.5441 - val_accuracy: 0.6220 - val_loss: 0.5442
Epoch 7/20
[1m93/93[

