# In [3]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, GlobalAveragePooling2D, Dense, Flatten
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import array_to_img

# Image geometry fed to the ResNet50 backbone, and the training mini-batch size
IMG_HEIGHT = 224
IMG_WIDTH = 224
BATCH_SIZE = 32

# Dataset root; each split holds paired "grayscale" and "colored" sub-directories
DATA_DIR = "cv_p3_images_split"
TRAIN_GRAY_DIR, TRAIN_COLOR_DIR = (
    os.path.join(DATA_DIR, "train/grayscale"),
    os.path.join(DATA_DIR, "train/colored"),
)
VAL_GRAY_DIR, VAL_COLOR_DIR = (
    os.path.join(DATA_DIR, "validation/grayscale"),
    os.path.join(DATA_DIR, "validation/colored"),
)
TEST_GRAY_DIR, TEST_COLOR_DIR = (
    os.path.join(DATA_DIR, "test/grayscale"),
    os.path.join(DATA_DIR, "test/colored"),
)

# Read a single image file and return it as a scaled array
def preprocess_image(image_path, target_size=(IMG_HEIGHT, IMG_WIDTH), is_grayscale=False):
    """Load one image, resize it, and divide its pixel values by 255.

    Args:
        image_path: Path of the image file, passed to ``load_img``.
        target_size: Size passed to ``load_img`` as ``target_size``.
        is_grayscale: When true the file is read with
            ``color_mode="grayscale"``; otherwise with ``color_mode="rgb"``.

    Returns:
        The array produced by ``img_to_array`` divided by ``255.0``.
    """
    mode = "rgb"
    if is_grayscale:
        mode = "grayscale"

    pixels = img_to_array(load_img(image_path, target_size=target_size, color_mode=mode))

    # Grayscale and RGB images share the same normalisation.
    return pixels / 255.0

# Load dataset
def load_data(gray_dir, color_dir):
    """Load paired grayscale / colour images from two parallel directories.

    Every entry of ``gray_dir`` is expected to have a file of the same name
    in ``color_dir``; both are read through ``preprocess_image``.

    The output arrays are allocated once and filled in place. Collecting the
    images in Python lists and converting them with ``np.array`` afterwards
    keeps two full copies of the dataset alive at the same time, which roughly
    doubles peak memory on large datasets.

    Args:
        gray_dir: Directory containing the grayscale inputs.
        color_dir: Directory containing the colour targets.

    Returns:
        A ``(gray_images, color_images)`` tuple of NumPy arrays whose first
        axis indexes the samples, ordered by sorted file name. When
        ``gray_dir`` is empty, two empty arrays are returned.

    Raises:
        Whatever ``os.listdir`` or ``preprocess_image`` raise, e.g. when a
        directory or a colour counterpart is missing.
    """
    # os.listdir returns entries in arbitrary order; sort so the sample order
    # is reproducible across runs and platforms.
    filenames = sorted(os.listdir(gray_dir))

    gray_images = None
    color_images = None

    for index, filename in enumerate(filenames):
        gray = preprocess_image(os.path.join(gray_dir, filename), is_grayscale=True)
        color = preprocess_image(os.path.join(color_dir, filename))

        if gray_images is None:
            # Size the buffers from the first sample so shape and dtype match
            # exactly what preprocess_image produces.
            gray_images = np.empty((len(filenames),) + gray.shape, dtype=gray.dtype)
            color_images = np.empty((len(filenames),) + color.shape, dtype=color.dtype)

        gray_images[index] = gray
        color_images[index] = color

    if gray_images is None:
        # Nothing to load: keep returning empty arrays, as before.
        return np.array([]), np.array([])

    return gray_images, color_images

# Read the training and validation splits into in-memory arrays
train_gray, train_color = load_data(gray_dir=TRAIN_GRAY_DIR, color_dir=TRAIN_COLOR_DIR)
val_gray, val_color = load_data(gray_dir=VAL_GRAY_DIR, color_dir=VAL_COLOR_DIR)

def build_colorization_model():
    """Build the grayscale-to-RGB colorization network.

    The model takes a ``(IMG_HEIGHT, IMG_WIDTH, 1)`` grayscale image, maps it
    to three channels with a learned convolution, encodes it with a frozen
    ImageNet-pretrained ResNet50, and decodes the features back to a
    ``(IMG_HEIGHT, IMG_WIDTH, 3)`` image with values in ``[0, 1]`` (sigmoid).

    Returns:
        An uncompiled ``Model`` mapping the grayscale input to the RGB output.

    Raises:
        ValueError: If the decoder output does not have the same height and
            width as the input, which would make the output incomparable
            with the colour targets.
    """
    # Grayscale input
    input_layer = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1))

    # Convert grayscale to 3 channels for pretrained model compatibility
    x = Conv2D(3, (3, 3), padding="same", activation="relu")(input_layer)

    # Pretrained ResNet50 as the base model
    base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))
    base_model.trainable = False  # Freeze base model initially
    x = base_model(x)

    # ResNet50 without its top reduces the spatial size by a factor of 32
    # (224 -> 7), so five 2x upsampling stages are needed to get back to the
    # input resolution. Each stage first shrinks the channel count so the
    # feature maps stay affordable as the resolution grows.
    for filters in (256, 128, 64, 32, 32):
        x = Conv2D(filters, (3, 3), activation="relu", padding="same")(x)
        x = UpSampling2D((2, 2))(x)

    x = Conv2D(64, (3, 3), activation="relu", padding="same")(x)
    x = Conv2D(3, (3, 3), activation="sigmoid", padding="same")(x)  # Final output layer

    # Fail early with a clear message instead of a shape error inside fit().
    output_size = tuple(x.shape[1:3])
    if output_size != (IMG_HEIGHT, IMG_WIDTH):
        raise ValueError(
            f"Decoder output size {output_size} does not match the input size "
            f"{(IMG_HEIGHT, IMG_WIDTH)}; input dimensions must be multiples of 32."
        )

    return Model(inputs=input_layer, outputs=x)


# Build and compile model
model = build_colorization_model()
# Colorization is a per-pixel regression problem, so track mean absolute error.
# "accuracy" is a classification metric and says nothing about how close the
# predicted colours are to the targets.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss="mse",
    metrics=["mae"],
)

# Training: stop once validation loss has not improved for 5 epochs and roll
# back to the best weights seen.
early_stopping = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)
history = model.fit(
    train_gray,
    train_color,
    validation_data=(val_gray, val_color),
    batch_size=BATCH_SIZE,
    epochs=20,
    callbacks=[early_stopping],
)

# Evaluate on test set
test_gray, test_color = load_data(TEST_GRAY_DIR, TEST_COLOR_DIR)
loss, mae = model.evaluate(test_gray, test_color)
print(f"Test Loss: {loss:.4f}, Test MAE: {mae:.4f}")

# Visualize Predictions
def visualize_predictions(model, gray_images, color_images, num_samples=5):
    """Show predicted, grayscale and ground-truth images for a few samples.

    Args:
        model: Object with a ``predict`` method that accepts a slice of
            ``gray_images`` and returns one prediction per input.
        gray_images: Sequence of grayscale inputs.
        color_images: Sequence of ground-truth colour images, aligned with
            ``gray_images``.
        num_samples: Maximum number of samples to display. Fewer are shown
            when fewer images are available; nothing is shown (and the model
            is not called) when the effective count is zero.

    Returns:
        None. Images are opened through ``show`` on the objects returned by
        ``array_to_img`` and a "Sample N" line is printed per sample.
    """
    # Clamp to what is actually available so a short dataset does not raise
    # IndexError, and a non-positive request does not hit the model at all.
    count = max(0, min(num_samples, len(gray_images), len(color_images)))
    if count == 0:
        return

    predictions = model.predict(gray_images[:count])
    samples = zip(predictions, gray_images[:count], color_images[:count])
    for number, (predicted, gray, truth) in enumerate(samples, start=1):
        predicted_img = array_to_img(predicted)
        gray_img = array_to_img(gray)
        true_color_img = array_to_img(truth)

        print("Sample", number)
        predicted_img.show(title="Predicted Image")
        gray_img.show(title="Grayscale Image")
        true_color_img.show(title="Ground Truth")

# Display the first five test samples
visualize_predictions(
    model=model,
    gray_images=test_gray,
    color_images=test_color,
    num_samples=5,
)


# Notebook output:
# MemoryError: Unable to allocate 13.3 GiB for an array with shape (23648, 224, 224, 3) and data type float32