<!-- HSV U-Net Model -->
<div class="alert" style="background: linear-gradient(to right, hsl(125, 100.00%, 50.00%), hsl(6, 100.00%, 50.00%), hsl(290, 100.00%, 53.10%)); color:white;">

# **U-NET HSV Colorization Model**
***
This notebook implements a U-Net model for image colorization in the HSV color space. The model takes grayscale images (the Value channel) as input and predicts the Hue and Saturation channels.

</div>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('../Src')
from unet_model import build_unet, get_callbacks
import tensorflow as tf

from tensorflow.keras import layers, models
from tensorflow.keras.models import load_model

In [None]:
# GPU Check & Mixed Precision Setup for TensorFlow
import tensorflow as tf
from tensorflow.python.client import device_lib

print(f"TensorFlow version: {tf.__version__}")
print(f"Built with CUDA: {tf.test.is_built_with_cuda()}")

gpus = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {gpus}")

# TensorFlow only accepts a memory-growth setting before the GPUs are
# initialized. device_lib.list_local_devices() below instantiates the local
# devices, so configure growth first; doing it afterwards may fail or be
# ignored.
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(f"Could not enable memory growth on {gpu.name}: {e}")

print("Device List:")
for device in device_lib.list_local_devices():
    print(f" - {device.name} ({device.device_type})")

# Enable mixed precision for performance if a GPU is available.
if gpus:
    try:
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy("mixed_float16")
        print("Mixed precision enabled (float16).")
    except Exception as e:  # best effort: fall back to the default policy
        print(f"Could not enable mixed precision: {e}")
else:
    print("No GPU detected. Running on CPU.")

In [None]:
# Load the prepared HSV dataset splits (grayscale inputs, colour targets).
DATA_DIR = "../Data/prepared_data/HSV"
x_train = np.load(f"{DATA_DIR}/comic_input_grayscale_train.npy")
y_train = np.load(f"{DATA_DIR}/comic_output_color_train.npy")
x_test  = np.load(f"{DATA_DIR}/comic_input_grayscale_test.npy")
y_test  = np.load(f"{DATA_DIR}/comic_output_color_test.npy")

# Feed Keras float32 arrays. Under the mixed_float16 policy the layers cast to
# float16 internally, so the inputs/targets stay float32 here.
# copy=False skips a redundant full copy when an array is already float32.
x_train = x_train.astype(np.float32, copy=False)
x_test = x_test.astype(np.float32, copy=False)
y_train = y_train.astype(np.float32, copy=False)
y_test = y_test.astype(np.float32, copy=False)

In [None]:
# Enable on-demand GPU memory allocation. TensorFlow only accepts this
# setting before the GPUs are initialized; if an earlier cell already
# initialized them, the RuntimeError is reported here instead of raised.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Memory growth enabled")
    except RuntimeError as e:
        print("Memory growth setup failed:", e)


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import tensorflow as tf

# Per-sample input shape = training array shape minus the leading sample axis
# (expected (256, 256, 1) for grayscale input — TODO confirm against the data).
input_shape = x_train.shape[1:]

# Build the U-Net defined in ../Src/unet_model.py
model = build_unet(input_shape)

# Combined MSE + MAE loss, computed in float32.
def combined_loss(y_true, y_pred, mse_weight=0.84, mae_weight=0.16):
    """Return ``mse_weight * MSE + mae_weight * MAE`` over all elements.

    Both tensors are cast to float32 *before* the error terms are computed.
    Under the ``mixed_float16`` policy the model output may be float16, and
    depending on the Keras version the built-in losses compute in
    ``y_pred``'s dtype. Casting only the final result therefore still left
    the error terms in float16.

    Args:
        y_true: Target tensor.
        y_pred: Model prediction with the same shape as ``y_true``.
        mse_weight: Weight of the mean-squared-error term.
        mae_weight: Weight of the mean-absolute-error term.

    Returns:
        Scalar float32 loss. Each term is a mean over every element, matching
        the default ``sum_over_batch_size`` reduction of the Keras
        MeanSquaredError / MeanAbsoluteError losses used previously.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    error = y_true - y_pred
    mse = tf.reduce_mean(tf.square(error))
    mae = tf.reduce_mean(tf.abs(error))
    return mse_weight * mse + mae_weight * mae

# Adam at a small learning rate; MAE is tracked alongside the combined loss.
model.compile(
    loss=combined_loss,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["mae"],
)

# Callbacks: keep only the checkpoint with the lowest validation loss on disk,
# and stop after 20 epochs without improvement, restoring the best weights.
checkpoint_cb = ModelCheckpoint(
    "U-NET_HSV_best_model_v2.keras",
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)
earlystop_cb = EarlyStopping(restore_best_weights=True, patience=20)

In [None]:
# Report array and model shapes to confirm they line up before training.
for label, array in (
    ("x_train", x_train),
    ("y_train", y_train),
    ("x_test", x_test),
    ("y_test", y_test),
):
    print(f"{label}:", array.shape)

print("Model input:", model.input_shape)
print("Model output:", model.output_shape)


In [None]:
# Show the data dtypes and the model output dtype (the latter reveals whether
# the output layer runs in float16 under the mixed-precision policy).
for label, array in (("x_train", x_train), ("y_train", y_train)):
    print(f"{label} dtype:", array.dtype)
print("Model output dtype:", model.output.dtype)


In [None]:
# Train.
# NOTE(review): the test split doubles as validation data, so checkpoint and
# early-stopping selection are tuned on the same data later used to judge the
# results; consider holding out a separate validation split.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=300,
    validation_data=(x_test, y_test),
    callbacks=[checkpoint_cb, earlystop_cb],
)

import pandas as pd

# Persist per-epoch metrics so they can be re-plotted without retraining.
history_df = pd.DataFrame(history.history)
history_df.to_csv("training_log_hsv.csv", index=False)

# The best checkpoint is written by checkpoint_cb; this saves the model as it
# stands after fit() returns.
model.save("U-NET_HSV_final_trained_v2.keras")
print("Saved best and final model.")


In [None]:
pd.DataFrame(history.history).plot(y=["loss", "val_loss"])  # train vs. validation loss

In [None]:
# Compare the value range of predictions and targets on the first 100 test
# samples (useful for spotting normalization mismatches).
n_check = 100
predictions = model.predict(x_test[:n_check])
gt_subset = y_test[:n_check]
print("Pred min/max:", predictions.min(), predictions.max())
print("GT min/max:", gt_subset.min(), gt_subset.max())

In [None]:
# Visualize some results without denormalization
def plot_results(model, x, y, num_samples=10, *, denormalize_pred=False):
    """Plot grayscale inputs, model predictions and ground truth side by side.

    Each row shows one sample: ``x[i]`` in grayscale, the model's prediction
    for it, and ``y[i]``. The ground truth is always mapped with
    ``y * 0.5 + 0.5`` and clipped to [0, 1]. This presumably undoes a
    [-1, 1] normalization — confirm against the data-preparation step.

    Args:
        model: Object with a Keras-style ``predict`` method.
        x: Model inputs, indexable per sample.
        y: Ground-truth targets aligned with ``x``.
        num_samples: Number of rows to draw; capped at ``len(x)``.
        denormalize_pred: If True, map predictions with ``* 0.5 + 0.5`` like
            the ground truth before display. If False (this cell's default),
            show them as predicted, clipped to [0, 1].

    Raises:
        ValueError: If there is no sample to plot.

    NOTE(review): ``imshow`` interprets 3-channel arrays as RGB. If these
    arrays hold HSV channels, convert them (e.g. with
    ``matplotlib.colors.hsv_to_rgb``) before display — confirm channel layout.
    """
    num_samples = min(num_samples, len(x))
    if num_samples < 1:
        raise ValueError("plot_results needs at least one sample to plot")

    predictions = model.predict(x[:num_samples])

    # squeeze=False keeps `axes` 2-D, so axes[i, j] also works for one row.
    _, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
    titles = ['Grayscale Input', 'Predicted Color', 'Ground Truth']

    for i in range(num_samples):
        pred = predictions[i]
        if denormalize_pred:
            pred = pred * 0.5 + 0.5
        axes[i, 0].imshow(x[i].squeeze(), cmap='gray')
        # imshow would clip out-of-range RGB floats anyway (with a warning);
        # float32 in case the model emits float16 under mixed precision.
        axes[i, 1].imshow(pred.clip(0, 1).astype(np.float32))
        axes[i, 2].imshow((y[i] * 0.5 + 0.5).clip(0, 1).astype(np.float32))

    for ax in axes.flat:
        ax.axis('off')
    for ax, title in zip(axes[0], titles):
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

# Plot test results
plot_results(model, x_test, y_test)

In [None]:
# Visualize same results with denormalization
def plot_results(model, x, y, num_samples=10, *, denormalize_pred=True):
    """Plot grayscale inputs, model predictions and ground truth side by side.

    Each row shows one sample: ``x[i]`` in grayscale, the model's prediction
    for it, and ``y[i]``. The ground truth is always mapped with
    ``y * 0.5 + 0.5`` and clipped to [0, 1]. This presumably undoes a
    [-1, 1] normalization — confirm against the data-preparation step.

    Args:
        model: Object with a Keras-style ``predict`` method.
        x: Model inputs, indexable per sample.
        y: Ground-truth targets aligned with ``x``.
        num_samples: Number of rows to draw; capped at ``len(x)``.
        denormalize_pred: If True (this cell's default), map predictions with
            ``* 0.5 + 0.5`` like the ground truth before display. If False,
            show them as predicted, clipped to [0, 1].

    Raises:
        ValueError: If there is no sample to plot.

    NOTE(review): ``imshow`` interprets 3-channel arrays as RGB. If these
    arrays hold HSV channels, convert them (e.g. with
    ``matplotlib.colors.hsv_to_rgb``) before display — confirm channel layout.
    """
    num_samples = min(num_samples, len(x))
    if num_samples < 1:
        raise ValueError("plot_results needs at least one sample to plot")

    predictions = model.predict(x[:num_samples])

    # squeeze=False keeps `axes` 2-D, so axes[i, j] also works for one row.
    _, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
    titles = ['Grayscale Input', 'Predicted Color', 'Ground Truth']

    for i in range(num_samples):
        pred = predictions[i]
        if denormalize_pred:
            pred = pred * 0.5 + 0.5
        axes[i, 0].imshow(x[i].squeeze(), cmap='gray')
        # float32 in case the model emits float16 under mixed precision.
        axes[i, 1].imshow(pred.clip(0, 1).astype(np.float32))
        axes[i, 2].imshow((y[i] * 0.5 + 0.5).clip(0, 1).astype(np.float32))

    for ax in axes.flat:
        ax.axis('off')
    for ax, title in zip(axes[0], titles):
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

# Plot test results
plot_results(model, x_test, y_test)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Reload the per-epoch metrics written after training.
log_path = "training_log_hsv.csv"
df = pd.read_csv(log_path)

# CSV rows are 0-indexed, but Keras numbers epochs from 1 ("Epoch 1/300");
# plot against 1-based epoch numbers so the x-axis matches the training log.
epoch_numbers = range(1, len(df) + 1)

# Plot loss and MAE
plt.figure(figsize=(14, 6))

plt.subplot(1, 2, 1)
plt.plot(epoch_numbers, df["loss"], label="Train Loss")
plt.plot(epoch_numbers, df["val_loss"], label="Val Loss")
plt.title("Loss Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epoch_numbers, df["mae"], label="Train MAE")
plt.plot(epoch_numbers, df["val_mae"], label="Val MAE")
plt.title("MAE Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Mean Absolute Error")
plt.legend()

plt.tight_layout()
plt.show()