In [4]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Concatenate, BatchNormalization, Activation
from tensorflow.keras.preprocessing.image import load_img, img_to_array, ImageDataGenerator
import random
from sklearn.model_selection import KFold
from keras.saving import register_keras_serializable
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import RMSprop




# Register custom functions (loss functions in this case)
@register_keras_serializable()
def mse(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))

@register_keras_serializable()
def mae(y_true, y_pred):
    return tf.reduce_mean(tf.abs(y_true - y_pred))

# Define paths to your dataset
data_dir = "..//cv_p3_images_split"
train_gray_dir = os.path.join(data_dir, "train/grayscale")
train_color_dir = os.path.join(data_dir, "train/colored")
val_gray_dir = os.path.join(data_dir, "validation/grayscale")
val_color_dir = os.path.join(data_dir, "validation/colored")

# Image dimensions for resizing
IMG_HEIGHT, IMG_WIDTH = 256, 256  # Resize all images to 256x256

# Utility function to preprocess images (loading and resizing)
def preprocess_image(image_path, target_size):
    image = load_img(image_path, target_size=target_size, color_mode="rgb")
    image = img_to_array(image) / 255.0  # Normalize to [0, 1]
    return image

# Function to load a dataset from a folder (grayscale and color images)
def load_dataset(gray_folder, color_folder, target_size, fraction=1.0):
    gray_images = []
    color_images = []
    
    filenames = os.listdir(gray_folder)
    if fraction < 1.0:
        filenames = random.sample(filenames, int(len(filenames) * fraction))  # Sample fraction of filenames
    
    for filename in filenames:
        gray_path = os.path.join(gray_folder, filename)
        color_path = os.path.join(color_folder, filename)
        gray_images.append(preprocess_image(gray_path, target_size)[..., 0:1])  # Extract grayscale channel
        color_images.append(preprocess_image(color_path, target_size))
    
    return np.array(gray_images), np.array(color_images)

# Load train and validation datasets
train_gray, train_color = load_dataset(train_gray_dir, train_color_dir, (IMG_HEIGHT, IMG_WIDTH), 0.25)
val_gray, val_color = load_dataset(val_gray_dir, val_color_dir, (IMG_HEIGHT, IMG_WIDTH), 0.25)

# Define the colorization model architecture
def build_model():
    inputs = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1))
    
    # Encoder
    x = Conv2D(64, (3, 3), padding="same", strides=1)(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(128, (3, 3), padding="same", strides=2)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(256, (3, 3), padding="same", strides=2)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    # Decoder with skip connections
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(128, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = UpSampling2D((2, 2))(x)
    x = Conv2D(64, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    # Final layer
    x = Conv2D(3, (3, 3), padding="same")(x)
    outputs = Activation("sigmoid")(x)

    return Model(inputs, outputs)


# Define cross-validation setup with KFold (5 folds)
num_folds = 5
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Create an EarlyStopping callback to prevent overfitting
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Perform cross-validation and train the model
for fold, (train_index, val_index) in enumerate(kf.split(train_gray)):
    print(f'Fold {fold + 1}/{num_folds}')

    # Split the data into training and validation for this fold
    fold_train_gray, fold_val_gray = train_gray[train_index], train_gray[val_index]
    fold_train_color, fold_val_color = train_color[train_index], train_color[val_index]

    # Build and compile the model
    model = build_model()
    
    # Compile the model using RMSProp optimizer
    model.compile(
        optimizer=RMSprop(learning_rate=0.001),  # RMSProp optimizer with learning rate 0.001
        loss="mse",  # Mean squared error loss function
        metrics=['accuracy']  # You can add other metrics if needed
    )

    # Train the model
    history = model.fit(
        fold_train_gray, fold_train_color,
        validation_data=(fold_val_gray, fold_val_color),
        epochs=15,
        steps_per_epoch=300,
        batch_size=16,
        callbacks=[early_stopping]  # Use early stopping to prevent overfitting
    )

    # Save the model after training
    model.save(f'colorization_model_fold_{fold + 1}.h5')

    # Optionally, you can save the training history and plot results
    # Save the training and validation loss plot
    plt.figure(figsize=(12, 6))
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title(f'Training and Validation Loss - Fold {fold + 1}')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(f'..//raport_sources/training_validation_loss_fold_{fold + 1}.png')  # Save plot
    plt.close()

    # Optionally, you can save predictions for further analysis
    # predictions = model.predict(fold_val_gray)
    # Save or process the predictions as required.

print("Cross-validation completed.") 

Fold 1/5
Epoch 1/15
[1m 74/300[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m18:27[0m 5s/step - accuracy: 0.4722 - loss: 0.0208

  self.gen.throw(value)


[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m380s[0m 1s/step - accuracy: 0.5077 - loss: 0.0156 - val_accuracy: 0.6516 - val_loss: 0.0649
Epoch 2/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m382s[0m 1s/step - accuracy: 0.5549 - loss: 0.0108 - val_accuracy: 0.6507 - val_loss: 0.0588
Epoch 3/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m369s[0m 1s/step - accuracy: 0.5572 - loss: 0.0102 - val_accuracy: 0.6515 - val_loss: 0.0547
Epoch 4/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m363s[0m 1s/step - accuracy: 0.5690 - loss: 0.0101 - val_accuracy: 0.6525 - val_loss: 0.0434
Epoch 5/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m366s[0m 1s/step - accuracy: 0.5707 - loss: 0.0096 - val_accuracy: 0.6529 - val_loss: 0.0345
Epoch 6/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m363s[0m 1s/step - accuracy: 0.5677 - loss: 0.0096 - val_accuracy: 0.6530 - val_loss: 0.0279
Epoch 7/15
[1m300/300[0m [32m━



Fold 2/5
Epoch 1/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m364s[0m 1s/step - accuracy: 0.5155 - loss: 0.0158 - val_accuracy: 0.6083 - val_loss: 0.0653
Epoch 2/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m362s[0m 1s/step - accuracy: 0.5663 - loss: 0.0107 - val_accuracy: 0.6084 - val_loss: 0.0586
Epoch 3/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m360s[0m 1s/step - accuracy: 0.5741 - loss: 0.0104 - val_accuracy: 0.6057 - val_loss: 0.0534
Epoch 4/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m361s[0m 1s/step - accuracy: 0.5735 - loss: 0.0103 - val_accuracy: 0.6022 - val_loss: 0.0425
Epoch 5/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m361s[0m 1s/step - accuracy: 0.5896 - loss: 0.0103 - val_accuracy: 0.5998 - val_loss: 0.0362
Epoch 6/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m361s[0m 1s/step - accuracy: 0.5734 - loss: 0.0097 - val_accuracy: 0.6075 - val_loss: 0.0273
Epoch 7/15
[



Fold 3/5
Epoch 1/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m376s[0m 1s/step - accuracy: 0.5048 - loss: 0.0167 - val_accuracy: 0.6333 - val_loss: 0.0657
Epoch 2/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m374s[0m 1s/step - accuracy: 0.5552 - loss: 0.0107 - val_accuracy: 0.6333 - val_loss: 0.0614
Epoch 3/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m373s[0m 1s/step - accuracy: 0.5610 - loss: 0.0105 - val_accuracy: 0.6333 - val_loss: 0.0519
Epoch 4/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m375s[0m 1s/step - accuracy: 0.5778 - loss: 0.0099 - val_accuracy: 0.6333 - val_loss: 0.0440
Epoch 5/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m372s[0m 1s/step - accuracy: 0.5805 - loss: 0.0100 - val_accuracy: 0.6336 - val_loss: 0.0373
Epoch 6/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m374s[0m 1s/step - accuracy: 0.5804 - loss: 0.0096 - val_accuracy: 0.6319 - val_loss: 0.0257
Epoch 7/15
[



Fold 4/5
Epoch 1/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m378s[0m 1s/step - accuracy: 0.5253 - loss: 0.0156 - val_accuracy: 0.6120 - val_loss: 0.0642
Epoch 2/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m375s[0m 1s/step - accuracy: 0.5584 - loss: 0.0108 - val_accuracy: 0.6095 - val_loss: 0.0566
Epoch 3/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m376s[0m 1s/step - accuracy: 0.5763 - loss: 0.0107 - val_accuracy: 0.6136 - val_loss: 0.0494
Epoch 4/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m371s[0m 1s/step - accuracy: 0.5799 - loss: 0.0102 - val_accuracy: 0.6136 - val_loss: 0.0434
Epoch 5/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m372s[0m 1s/step - accuracy: 0.5801 - loss: 0.0100 - val_accuracy: 0.6136 - val_loss: 0.0329
Epoch 6/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m373s[0m 1s/step - accuracy: 0.5891 - loss: 0.0100 - val_accuracy: 0.6138 - val_loss: 0.0227
Epoch 7/15
[



Fold 5/5
Epoch 1/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m377s[0m 1s/step - accuracy: 0.5027 - loss: 0.0161 - val_accuracy: 0.6215 - val_loss: 0.0653
Epoch 2/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m374s[0m 1s/step - accuracy: 0.5455 - loss: 0.0108 - val_accuracy: 0.6231 - val_loss: 0.0613
Epoch 3/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m375s[0m 1s/step - accuracy: 0.5655 - loss: 0.0104 - val_accuracy: 0.6236 - val_loss: 0.0518
Epoch 4/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m376s[0m 1s/step - accuracy: 0.5686 - loss: 0.0097 - val_accuracy: 0.6233 - val_loss: 0.0437
Epoch 5/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m375s[0m 1s/step - accuracy: 0.5721 - loss: 0.0097 - val_accuracy: 0.6191 - val_loss: 0.0349
Epoch 6/15
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m376s[0m 1s/step - accuracy: 0.5749 - loss: 0.0098 - val_accuracy: 0.6111 - val_loss: 0.0245
Epoch 7/15
[



Cross-validation completed.
