# Model 2: U-net

This notebook contains the second (and main) model, a U-Net, as it stands after several development iterations.

## Project setup

### Library setup

In [None]:
import numpy as np
from tensorflow.keras import layers, Model # type: ignore
from tensorflow.keras.metrics import MeanSquaredError, MeanAbsoluteError # type: ignore
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping # type: ignore
from tensorflow.keras.optimizers import Adam # type: ignore
from tensorflow.keras.models import load_model  # type: ignore
from tensorflow.keras.preprocessing.image import img_to_array, load_img # type: ignore
import wandb
from wandb.integration.keras import WandbMetricsLogger
import os
import cv2
import matplotlib.pyplot as plt


# Reaching this line means every import above resolved without raising.
print("Libraries loaded")

### Weights and Biases integration

In this step, the program will ask for an API key, which is unique to each program or user.
To log this run to the platform, set your own `entity`, `project` and `name` in the cell below and then run it.

In [None]:
# Training hyperparameters. They are module-level on purpose:
# `train_model` further down reads them as globals.
learning_rate = 0.0001
epochs = 50
batch_size = 32

# Start a Weights & Biases run and record the hyperparameters with it.
wandb.init(
    entity="mehher_ghevandiani-american-university-of-armenia",
    project="Capstone",
    name="U-net model iteration 5",
    config=dict(
        learning_rate=learning_rate,
        architecture="U-Net",
        dataset="combined dataset",
        epochs=epochs,
        batch_size=batch_size,
    ),
)

print("WandB initiated")

### Utility functions

In [None]:
def preprocess_image(img_path, img_size=(256, 256)):
    """
    Load one image from disk and split it into normalized LAB components.

    The file is read with OpenCV (BGR order), converted to LAB, cast to
    float32, resized to `img_size`, and then separated into the lightness
    channel and the two colour channels.

    Parameters:
        img_path (str): Path to the image file.
        img_size (tuple): Target size passed to `cv2.resize` as (width, height).
            Default is (256, 256).

    Returns:
        tuple | None:
            On success, `(l_channel, ab_channels)` where `l_channel` is the
            L channel divided by 255 (trailing channel axis kept, so the last
            dimension is 1) and `ab_channels` is the A and B channels shifted
            by 128 and divided by 128 (last dimension is 2).
            If the file cannot be read, an error message is printed and
            `None` is returned instead of a tuple.
    """
    bgr = cv2.imread(img_path)

    # Guard clause: nothing to convert when OpenCV gave us no pixel data.
    if bgr is None:
        print(f"Image at {img_path} could not be read.")
        return None

    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
    lab = cv2.resize(lab, img_size)

    # Slice with 0:1 (not 0) so the L channel keeps its channel axis.
    lightness = lab[:, :, 0:1] / 255.0
    chroma = (lab[:, :, 1:] - 128) / 128.0
    return lightness, chroma

def preprocess_images(data_dir, img_size=(256, 256)):
    """
    Traverses a directory to preprocess all images using LAB color space conversion.

    Every file found under `data_dir` is handed to `preprocess_image`. Files that
    `preprocess_image` cannot read (it returns None for those, after printing a
    message) are skipped, so a stray non-image file no longer aborts the whole run.

    Parameters:
        data_dir (str): Directory path containing image files (recursively searched).
        img_size (tuple): Target size for resizing each image. Default is (256, 256).

    Returns:
        l_channels (np.ndarray): float32 stack of normalized L channels for all readable images.
        ab_channels (np.ndarray): float32 stack of normalized AB channels for all readable images.
    """

    l_channels = []
    ab_channels = []

    for root, _, files in os.walk(data_dir):
        print("Processing directory:", root)
        for file in files:
            print("Processing file:", file)
            img_path = os.path.join(root, file)

            result = preprocess_image(img_path, img_size)
            if result is None:
                # Unreadable / non-image file: preprocess_image already reported it.
                continue

            l_channel, ab_channel = result
            l_channels.append(l_channel)
            ab_channels.append(ab_channel)

    l_channels = np.array(l_channels, dtype=np.float32)
    ab_channels = np.array(ab_channels, dtype=np.float32)

    return l_channels, ab_channels


def lab_to_rgb(l_channel, ab_channels):
    """
    Converts normalized LAB image data back into RGB format.

    The inputs are mapped back to the 8-bit LAB encoding (L * 255, AB * 128 + 128),
    rounded to the nearest integer and clipped to [0, 255] before the uint8 cast.
    Without the clip, an AB value of exactly 1.0 (reachable with a tanh output
    layer) becomes 256, which does not fit in uint8 and corrupts the colour;
    without the rounding, float error makes the cast truncate values one level low.

    Parameters:
        l_channel (np.ndarray): L channel with values in [0, 1].
        ab_channels (np.ndarray): AB channels with values in [-1, 1].

    Returns:
        rgb_image (np.ndarray): Image converted to RGB format.
    """

    l_channel = np.clip(np.rint(l_channel * 255), 0, 255).astype(np.uint8)
    ab_channels = np.clip(np.rint(ab_channels * 128 + 128), 0, 255).astype(np.uint8)
    lab_image = np.concatenate((l_channel, ab_channels), axis=-1)
    rgb_image = cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
    return rgb_image

def prepare_data(data_dir):
    """
    Build the training arrays for a dataset directory.

    Thin wrapper around `preprocess_images` (called with its default image
    size) that prints a message before and after the preprocessing step.

    Parameters:
        data_dir (str): Path to the directory containing image data.

    Returns:
        images (np.ndarray): Array of normalized L channel data (model input).
        labels (np.ndarray): Array of normalized AB channel data (model target).
    """
    print("Preprocessing images")
    l_data, ab_data = preprocess_images(data_dir)
    print("Data Preprocessed")
    return l_data, ab_data


### The model

Below you will also find the function responsible for training the model.

In [None]:
def _conv_block(x, filters):
    """Apply two 3x3 same-padded ReLU convolutions, each followed by batch normalization."""
    for _ in range(2):
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
    return x


def _up_block(x, skip, filters):
    """Upsample `x` by 2 with a transposed convolution, concatenate `skip`, then apply a conv block."""
    x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = layers.concatenate([x, skip])
    return _conv_block(x, filters)


def unet_model(input_shape=(256, 256, 1), rescale_input=False):
    """
    Builds the U-Net used for colorization: L channel in, two tanh-activated channels out.

    Architecture: four encoder stages (64, 128, 256, 512 filters) each followed by
    2x2 max pooling, a 1024-filter bottleneck, and four decoder stages
    (512, 256, 128, 64 filters) that upsample and concatenate the matching
    encoder output as a skip connection. A final 1x1 convolution with tanh
    produces the 2-channel output.

    Parameters:
        input_shape (tuple): Shape of one input sample. Default is (256, 256, 1).
        rescale_input (bool): If True, divide the input by 255 inside the model.
            Default is False because `preprocess_image` already divides the
            L channel by 255; rescaling a second time squeezes the input into
            roughly [0, 0.004]. Pass True only for raw 0-255 input or to rebuild
            the earlier architecture that always rescaled.

    Returns:
        model (tf.keras.Model): The uncompiled U-Net model.
    """
    inputs = layers.Input(shape=input_shape)
    x = layers.Rescaling(1./255)(inputs) if rescale_input else inputs

    # Encoder: keep each stage's pre-pooling output for the skip connections.
    skips = []
    for filters in (64, 128, 256, 512):
        x = _conv_block(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = _conv_block(x, 1024)

    # Decoder: walk the skip connections from deepest to shallowest.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = _up_block(x, skip, filters)

    # tanh keeps the predictions in [-1, 1], matching the normalized AB targets.
    outputs = layers.Conv2D(2, (1, 1), activation='tanh')(x)

    model = Model(inputs, outputs)
    return model

def train_model(images, labels, callbacks=None):
    """
    Build, compile, fit and save a U-Net on L-channel inputs and AB-channel targets.

    The model comes from `unet_model()`, is compiled with Adam and a mean
    absolute error loss (tracking MSE and MAE as metrics), and is fitted with
    20% of the data used as the validation split.

    Parameters:
        images (np.ndarray): Input data, typically L channel images with shape (num_samples, height, width, 1).
        labels (np.ndarray): Target data, typically AB channels with shape (num_samples, height, width, 2).
        callbacks (list, optional): Keras callbacks forwarded to `fit`. Default is None.

    Returns:
        model (tf.keras.Model): The trained U-Net model.

    Note:
        - Reads the module-level globals `learning_rate`, `batch_size` and `epochs`.
        - Saves the trained model to a file whose name contains the epoch count
          (e.g., 'u_net_50_epoch_iter_5.h5').
    """
    net = unet_model()

    optimizer = Adam(learning_rate=learning_rate)
    tracked_metrics = [MeanSquaredError(), MeanAbsoluteError()]
    net.compile(optimizer=optimizer, loss='mae', metrics=tracked_metrics)
    print("Model Compiled")

    print("Fitting The Model")

    fit_options = {
        "validation_split": 0.2,
        "batch_size": batch_size,
        "epochs": epochs,
        "callbacks": callbacks,
    }
    net.fit(images, labels, **fit_options)

    output_path = f'u_net_{epochs}_epoch_iter_5.h5'
    net.save(output_path)
    return net

## The training process

### Images and Labels array setup

This is the part where we load or create the images and labels arrays. If the files `images.npy` and `labels.npy` already exist, they are loaded; if not, they are created by the
`prepare_data()` function and saved for the next run.

In [None]:
dataset_dir = "../../train_data" # Refer to the readme if you haven't downloaded the dataset yet

# Use the cached arrays when both files are present; otherwise rebuild them
# from the dataset and cache the result. A plain if/else is required here:
# os.path.exists() returns False instead of raising, so an
# `except FileNotFoundError` fallback would never run and `images` / `labels`
# would be left undefined on the first run.
if os.path.exists("./images.npy") and os.path.exists("./labels.npy"):
    print("Loading images and labels from .npy files")
    images = np.load("./images.npy")
    labels = np.load("./labels.npy")
else:
    print("Couldn't find .npy files, preprocessing images from dataset directory")
    images, labels = prepare_data(dataset_dir)
    np.save("./images.npy", images)
    np.save("./labels.npy", labels)


In [None]:

# Keep only the best model (lowest validation loss) on disk while training.
checkpoint_callback = ModelCheckpoint(
    filepath='u_net_colorization_checkpoint.h5',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)

# Stop once validation loss has not improved for 10 epochs and roll the
# model back to its best weights.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=10,
    restore_best_weights=True
)

# NOTE(review): `early_stopping` was previously created but never passed to
# training, so it had no effect. It is now registered; remove it from this
# list if a fixed number of epochs is intended.
callbacks = [
    checkpoint_callback,
    early_stopping,
    WandbMetricsLogger(),
]

model = train_model(images, labels, callbacks=callbacks) # be cautious when running this: it will take a while


## Testing the results

This part is responsible for plotting the visualisation results of the model.
For research purposes, all of the images in the dataset will be colorized; feel free to interrupt the process with `Ctrl+C`. In this case you will see a `KeyboardInterrupt` error.

In [None]:
# model = load_model(f'u_net_{epochs}_epoch_iter_5.h5', compile=False) # Uncomment this line to load a pre-trained model

nighttime_dir = "../../nighttime_footage" # a sample dataset to test the model

# Single switch for the directory being visualised. The listing and the path
# join below both use it, so changing it here is enough.
test_dir = dataset_dir # change to nighttime_dir to test on the nighttime dataset

for filename in os.listdir(test_dir):
    if not filename.endswith((".png", ".jpg")):
        continue

    image_path = os.path.join(test_dir, filename)

    # preprocess_image returns None (after printing a message) when the file
    # cannot be read; skip such files instead of failing on the unpacking.
    preprocessed = preprocess_image(image_path)
    if preprocessed is None:
        continue
    l_channel, _ = preprocessed  # the ground-truth AB channels are not needed here

    original_image = load_img(image_path)
    original_image = img_to_array(original_image) / 255.0

    # The model expects a batch axis; take the single prediction back out.
    l_channel_input = np.expand_dims(l_channel, axis=0)
    predicted_ab_channels = model.predict(l_channel_input)[0]

    colorized_image = lab_to_rgb(l_channel, predicted_ab_channels)
    grayscale_image = l_channel[:, :, 0]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].set_title(f"Original Image: {filename}")
    axes[0].imshow(original_image)
    axes[0].axis("off")

    axes[1].set_title("Grayscale Image")
    axes[1].imshow(grayscale_image, cmap="gray")
    axes[1].axis("off")

    axes[2].set_title("Colorized Image")
    axes[2].imshow(colorized_image)
    axes[2].axis("off")

    plt.show()
    # Release the figure so looping over a whole dataset does not accumulate memory.
    plt.close(fig)