In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Concatenate, BatchNormalization, Activation
from tensorflow.keras.preprocessing.image import load_img, img_to_array, ImageDataGenerator
import random
from sklearn.model_selection import KFold

# Define paths
data_dir = "cv_p3_images_split"
train_gray_dir = os.path.join(data_dir, "train/grayscale")
train_color_dir = os.path.join(data_dir, "train/colored")
val_gray_dir = os.path.join(data_dir, "validation/grayscale")
val_color_dir = os.path.join(data_dir, "validation/colored")

# Image dimensions
IMG_HEIGHT, IMG_WIDTH = 256, 256  # Resize all images to 256x256

# Utility to preprocess images
def preprocess_image(image_path, target_size):
    image = load_img(image_path, target_size=target_size, color_mode="rgb")
    image = img_to_array(image) / 255.0  # Normalize to [0, 1]
    return image

def load_images_from_folder(folder, target_size):
    images = []
    for filename in os.listdir(folder):
        img_path = os.path.join(folder, filename)
        images.append(preprocess_image(img_path, target_size))
    return np.array(images)


def load_dataset(gray_folder, color_folder, target_size, fraction=1.0):
    gray_images = []
    color_images = []
    
    filenames = os.listdir(gray_folder)
    if fraction < 1.0:
        filenames = random.sample(filenames, int(len(filenames) * fraction))  # Sample fraction of filenames
    
    for filename in filenames:
        gray_path = os.path.join(gray_folder, filename)
        color_path = os.path.join(color_folder, filename)
        gray_images.append(preprocess_image(gray_path, target_size)[..., 0:1])  # Extract grayscale channel
        color_images.append(preprocess_image(color_path, target_size))
    
    return np.array(gray_images), np.array(color_images)


train_gray, train_color = load_dataset(train_gray_dir, train_color_dir, (IMG_HEIGHT, IMG_WIDTH), 0.25)
val_gray, val_color = load_dataset(val_gray_dir, val_color_dir, (IMG_HEIGHT, IMG_WIDTH), 0.25)

# Define the colorization model
def build_model(loss="mse"):
    inputs = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1))

    # Encoder
    x = Conv2D(32, (3, 3), padding="same", strides=2)(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(64, (3, 3), padding="same", strides=2)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    # Decoder
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(32, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = UpSampling2D((2, 2))(x)
    x = Conv2D(3, (3, 3), padding="same")(x)
    outputs = Activation("sigmoid")(x)

    return Model(inputs, outputs)

# Define the number of folds for cross-validation
num_folds = 5
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Prepare the data
gray_images, color_images = load_dataset(train_gray_dir, train_color_dir, (IMG_HEIGHT, IMG_WIDTH))

# Perform cross-validation
for fold, (train_index, val_index) in enumerate(kf.split(gray_images)):
    print(f'Fold {fold + 1}/{num_folds}')
    train_gray, val_gray = gray_images[train_index], gray_images[val_index]
    train_color, val_color = color_images[train_index], color_images[val_index]

    # Build and compile the model
    model = build_model()
    model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])

    # Train the model
    history = model.fit(
        train_gray, train_color,
        validation_data=(val_gray, val_color),
        epochs=20,
        batch_size=16
    )

    # Save the model for each fold
    model.save(f'colorization_model_fold_{fold + 1}.h5')

# Optionally, you can average the results or perform further analysis on the cross-validation results





Fold 1/5
Epoch 1/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m108s[0m 359ms/step - accuracy: 0.5321 - loss: 0.0157 - val_accuracy: 0.6256 - val_loss: 0.0371
Epoch 2/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 479ms/step - accuracy: 0.5866 - loss: 0.0098 - val_accuracy: 0.5990 - val_loss: 0.0095
Epoch 3/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 326ms/step - accuracy: 0.5857 - loss: 0.0093 - val_accuracy: 0.5394 - val_loss: 0.0077
Epoch 4/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 307ms/step - accuracy: 0.5776 - loss: 0.0092 - val_accuracy: 0.6123 - val_loss: 0.0091
Epoch 5/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m89s[0m 301ms/step - accuracy: 0.6006 - loss: 0.0086 - val_accuracy: 0.5728 - val_loss: 0.0075
Epoch 6/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m90s[0m 302ms/step - accuracy: 0.5833 - loss: 0.0088 - val_accuracy: 0.6032 - val_loss: 0.0078



Fold 2/5
Epoch 1/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 318ms/step - accuracy: 0.5329 - loss: 0.0168 - val_accuracy: 0.6146 - val_loss: 0.0362
Epoch 2/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 317ms/step - accuracy: 0.5735 - loss: 0.0098 - val_accuracy: 0.5892 - val_loss: 0.0101
Epoch 3/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 316ms/step - accuracy: 0.5787 - loss: 0.0094 - val_accuracy: 0.5850 - val_loss: 0.0086
Epoch 4/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 318ms/step - accuracy: 0.5801 - loss: 0.0093 - val_accuracy: 0.5670 - val_loss: 0.0082
Epoch 5/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 314ms/step - accuracy: 0.5869 - loss: 0.0087 - val_accuracy: 0.5793 - val_loss: 0.0083
Epoch 6/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 320ms/step - accuracy: 0.5931 - loss: 0.0088 - val_accuracy: 0.5868 - val_loss: 0.0081
E



Fold 3/5
Epoch 1/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 319ms/step - accuracy: 0.5362 - loss: 0.0173 - val_accuracy: 0.6209 - val_loss: 0.0358
Epoch 2/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 319ms/step - accuracy: 0.5829 - loss: 0.0098 - val_accuracy: 0.5921 - val_loss: 0.0105
Epoch 3/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 319ms/step - accuracy: 0.5911 - loss: 0.0095 - val_accuracy: 0.6146 - val_loss: 0.0084
Epoch 4/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 319ms/step - accuracy: 0.5921 - loss: 0.0091 - val_accuracy: 0.6103 - val_loss: 0.0080
Epoch 5/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 319ms/step - accuracy: 0.5944 - loss: 0.0091 - val_accuracy: 0.6129 - val_loss: 0.0078
Epoch 6/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m94s[0m 318ms/step - accuracy: 0.5989 - loss: 0.0086 - val_accuracy: 0.5476 - val_loss: 0.0080
E



Fold 4/5
Epoch 1/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 328ms/step - accuracy: 0.5103 - loss: 0.0164 - val_accuracy: 0.6237 - val_loss: 0.0334
Epoch 2/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 323ms/step - accuracy: 0.5789 - loss: 0.0098 - val_accuracy: 0.5214 - val_loss: 0.0094
Epoch 3/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 320ms/step - accuracy: 0.5841 - loss: 0.0097 - val_accuracy: 0.5356 - val_loss: 0.0079
Epoch 4/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 320ms/step - accuracy: 0.5862 - loss: 0.0089 - val_accuracy: 0.5538 - val_loss: 0.0081
Epoch 5/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 322ms/step - accuracy: 0.5846 - loss: 0.0090 - val_accuracy: 0.5661 - val_loss: 0.0075
Epoch 6/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 320ms/step - accuracy: 0.5970 - loss: 0.0087 - val_accuracy: 0.6187 - val_loss: 0.0082
E



Fold 5/5
Epoch 1/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m97s[0m 324ms/step - accuracy: 0.5247 - loss: 0.0163 - val_accuracy: 0.6319 - val_loss: 0.0354
Epoch 2/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 324ms/step - accuracy: 0.5786 - loss: 0.0099 - val_accuracy: 0.6269 - val_loss: 0.0115
Epoch 3/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 325ms/step - accuracy: 0.5898 - loss: 0.0091 - val_accuracy: 0.6160 - val_loss: 0.0085
Epoch 4/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 325ms/step - accuracy: 0.5922 - loss: 0.0089 - val_accuracy: 0.6143 - val_loss: 0.0082
Epoch 5/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 324ms/step - accuracy: 0.5866 - loss: 0.0089 - val_accuracy: 0.6214 - val_loss: 0.0078
Epoch 6/20
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m97s[0m 329ms/step - accuracy: 0.5974 - loss: 0.0084 - val_accuracy: 0.6273 - val_loss: 0.0078
E

