# Model 1: Convolutional Neural Networks

In this notebook you will find the first model (after some development iterations).

## Project setup

### Library setup

In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import wandb
from tensorflow.keras import Model # type: ignore
from tensorflow.keras.layers import Conv2D, UpSampling2D, Input, BatchNormalization # type: ignore
from tensorflow.keras.metrics import MeanSquaredError, MeanAbsoluteError # type: ignore
from tensorflow.keras.models import load_model  # type: ignore
from tensorflow.keras.optimizers import Adam # type: ignore
from tensorflow.keras.preprocessing.image import img_to_array, load_img # type: ignore
from wandb.integration.keras import WandbMetricsLogger


print("Libraries loaded")


### Weights and Biases integration

In this step, the program will ask for an API key which is unique to each program or user.
To integrate this platform into this run, set your own `entity`, `project` and `name` in the cell below and then run it.

In [None]:
# Training hyperparameters. `batch_size` and `epochs` are also read as
# globals by `train_model` further down in the notebook.
batch_size = 32
epochs = 50
learning_rate = 0.0001

# Hyperparameters recorded alongside the W&B run.
run_config = dict(
    learning_rate=learning_rate,
    architecture="U-Net",
    dataset="human_detection_dataset",
    epochs=epochs,
    batch_size=batch_size,
)

wandb.init(
    entity="mehher_ghevandiani-american-university-of-armenia",
    project="Capstone",
    name="U-net model iteration 5",
    config=run_config,
)

print("WandB initiated")


### Utility functions

In [5]:
def preprocess_image(img_path, img_size=(256, 256)):
    """Read one image and split it into normalized LAB planes.

    Args:
        img_path: Path handed to ``cv2.imread``.
        img_size: Target size forwarded to ``cv2.resize``.

    Returns:
        A ``(l_channel, ab_channels)`` tuple of float32 arrays: the first LAB
        plane (kept as a trailing axis of length 1) divided by 255, and the
        remaining planes shifted by 128 and divided by 128.
        ``None`` if ``cv2.imread`` could not read the file; a message is
        printed in that case.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        print(f"Image at {img_path} could not be read.")
        return None

    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
    lab = cv2.resize(lab, img_size)

    lightness = lab[:, :, 0:1] / 255.0  # slice (not index) keeps the channel axis
    chroma = (lab[:, :, 1:] - 128) / 128.0
    return lightness, chroma

def preprocess_images(data_dir, img_size=(256, 256)):
    """Preprocess every readable image found under ``data_dir``.

    The directory tree is walked recursively and each file is passed to
    ``preprocess_image``. Files that cannot be read as images (for which
    ``preprocess_image`` returns ``None``) are skipped instead of aborting
    the whole run.

    Args:
        data_dir: Root directory to walk.
        img_size: Target size forwarded to ``preprocess_image``.

    Returns:
        ``(l_channels, ab_channels)``: two float32 arrays stacking the
        per-image results along a new leading axis. Both are empty when no
        readable image was found.
    """
    l_channels = []
    ab_channels = []

    for root, dirs, files in os.walk(data_dir):
        # Sorting makes the traversal order (and therefore the cached arrays
        # and the tail used by `validation_split`) reproducible across runs.
        dirs.sort()
        print("Processing directory:", root)
        for file in sorted(files):
            print("Processing file:", file)
            img_path = os.path.join(root, file)
            channels = preprocess_image(img_path, img_size)
            if channels is None:
                # Unreadable / non-image file: already reported by
                # preprocess_image, so just move on.
                continue
            l_channel, ab_channel = channels
            l_channels.append(l_channel)
            ab_channels.append(ab_channel)

    l_channels = np.array(l_channels, dtype=np.float32)
    ab_channels = np.array(ab_channels, dtype=np.float32)

    return l_channels, ab_channels

def lab_to_rgb(l_channel, ab_channels):
    """Convert normalized L and AB planes back into an 8-bit RGB image.

    This undoes the scaling applied in ``preprocess_image``: L is multiplied
    by 255, AB is multiplied by 128 and shifted by 128.

    Args:
        l_channel: Array with a trailing channel axis of length 1.
        ab_channels: Array with a trailing channel axis of length 2, e.g. the
            tanh output of the model.

    Returns:
        The uint8 image produced by ``cv2.cvtColor(..., cv2.COLOR_LAB2RGB)``.
    """
    # Round and clip before the uint8 cast: an AB value of exactly 1.0 maps to
    # 256, which would otherwise wrap around to 0, and plain truncation can
    # lose one level to floating-point error on a round trip.
    l_uint8 = np.clip(np.rint(l_channel * 255), 0, 255).astype(np.uint8)
    ab_uint8 = np.clip(np.rint(ab_channels * 128 + 128), 0, 255).astype(np.uint8)
    lab_image = np.concatenate((l_uint8, ab_uint8), axis=-1)
    rgb_image = cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
    return rgb_image

def prepare_data(data_dir):
    """Run ``preprocess_images`` on ``data_dir`` with progress messages.

    Args:
        data_dir: Directory forwarded to ``preprocess_images``.

    Returns:
        The ``(images, labels)`` pair returned by ``preprocess_images``.
    """
    print("Preprocessing images")
    l_inputs, ab_targets = preprocess_images(data_dir)
    print("Data Preprocessed")
    return (l_inputs, ab_targets)

### The model

Below you will also find the function responsible for training the model.

In [None]:
def cnn_model(input_shape=(256, 256, 1)):
    """Assemble the convolutional encoder-decoder used for colorization.

    The encoder is three 3x3 ReLU convolutions (32, 64, 128 filters; the last
    two with ``strides=2``), each followed by batch normalization. The decoder
    is two rounds of 2x up-sampling plus a 3x3 ReLU convolution (64, then 32
    filters), finished by a 3x3 ``tanh`` convolution with 2 output channels
    (A, B).

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.

    Returns:
        An uncompiled Keras ``Model`` mapping the input to the 2-channel output.
    """
    kernel = (3, 3)
    relu_same = {"activation": "relu", "padding": "same"}

    grayscale_in = Input(shape=input_shape)
    features = grayscale_in

    # Encoder
    encoder_spec = ((32, {}), (64, {"strides": 2}), (128, {"strides": 2}))
    for filters, extra in encoder_spec:
        features = Conv2D(filters, kernel, **relu_same, **extra)(features)
        features = BatchNormalization()(features)

    # Decoder
    for filters in (64, 32):
        features = UpSampling2D((2, 2))(features)
        features = Conv2D(filters, kernel, **relu_same)(features)

    # 2 output channels (A, B)
    ab_out = Conv2D(2, kernel, activation="tanh", padding="same")(features)

    return Model(grayscale_in, ab_out)

def train_model(images, labels, callbacks=None):
    """Build, compile, fit and save the colorization model.

    Reads the module-level ``batch_size``, ``epochs`` and ``learning_rate``
    defined in the configuration cell.

    Args:
        images: Model inputs passed to ``model.fit``.
        labels: Training targets passed to ``model.fit``.
        callbacks: Optional list of Keras callbacks forwarded to ``model.fit``.

    Returns:
        The fitted model. It is also saved to ``cnn_{epochs}_epochs_model.keras``
        in the current working directory.
    """
    model = cnn_model()

    # Build the optimizer explicitly so the configured `learning_rate` (the
    # value logged to W&B) is the one actually used; the string 'adam' would
    # silently fall back to the Keras default rate.
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='mae',
        metrics=[MeanSquaredError(), MeanAbsoluteError()],
    )

    print("Model Compiled")

    print("Fitting The Model")

    model.fit(
        images, labels,
        validation_split=0.2,
        batch_size=batch_size,
        epochs=epochs,
        callbacks=callbacks
    )

    model.save(f'cnn_{epochs}_epochs_model.keras')
    return model

## The training process

### Images and Labels array setup

This is the part where we load or build the images and labels arrays. If the files `images.npy` and `labels.npy` already exist, they are loaded; otherwise they are created by the `prepare_data()` function and saved for later runs.

In [None]:
dataset_dir = "../../train_data" # Refer to the readme if you haven't downloaded the dataset yet

# Use the cached arrays when both files exist; otherwise build them from the
# dataset directory and cache them. A plain if/else is required here:
# os.path.exists() returns False rather than raising FileNotFoundError, so an
# `except FileNotFoundError` fallback would never run and `images`/`labels`
# would be left undefined.
if os.path.exists("./images.npy") and os.path.exists("./labels.npy"):
    print("Loading images and labels from .npy files")
    images = np.load("./images.npy")
    labels = np.load("./labels.npy")
else:
    print("Couldn't find .npy files, preprocessing images from dataset directory")
    images, labels = prepare_data(dataset_dir)
    np.save("./images.npy", images)
    np.save("./labels.npy", labels)


In [None]:
# Caution: training takes a while. WandbMetricsLogger forwards the Keras
# training/validation metrics to the W&B run opened by wandb.init(); without
# it the run would record only its config.
model = train_model(images, labels, callbacks=[WandbMetricsLogger()])

## Testing the results

This part is responsible for plotting the visualisation results of the model.
For research purposes, every image in the dataset will be colorized. Feel free to interrupt the process (e.g. with `Ctrl+C` or by interrupting the kernel); in that case you will see a `KeyboardInterrupt` error.

In [None]:

# model = load_model(f'cnn_{epochs}_epochs_model.keras', compile=False) # Uncomment this line to load a pre-trained model

# Walk the dataset recursively, the same way preprocess_images() does, so
# images stored in sub-directories are visualised too.
image_extensions = (".png", ".jpg")

for root, _, files in os.walk(dataset_dir):
    for filename in files:
        # Case-insensitive match so e.g. ".JPG" files are not silently skipped.
        if not filename.lower().endswith(image_extensions):
            continue

        image_path = os.path.join(root, filename)

        # preprocess_image() returns None (after printing a message) for files
        # OpenCV cannot read; skip those instead of failing on the unpacking.
        channels = preprocess_image(image_path)
        if channels is None:
            continue
        l_channel, ab_channels = channels

        original_image = img_to_array(load_img(image_path)) / 255.0

        # The model expects a batch axis; take the single prediction back out.
        l_channel_input = np.expand_dims(l_channel, axis=0)
        predicted_ab_channels = model.predict(l_channel_input)[0]

        colorized_image = lab_to_rgb(l_channel, predicted_ab_channels)
        grayscale_image = l_channel[:, :, 0]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (f"Original Image: {filename}", original_image, None),
            ("Grayscale Image", grayscale_image, "gray"),
            ("Colorized Image", colorized_image, None),
        )
        for ax, (title, image, cmap) in zip(axes, panels):
            ax.set_title(title)
            ax.imshow(image, cmap=cmap)
            ax.axis("off")

        plt.show()
        # Release the figure so looping over the whole dataset does not
        # accumulate open figures.
        plt.close(fig)