<!-- RGB Model -->
<div class="alert" style="background: linear-gradient(to right,rgb(255, 0, 0), rgb(0,255,0),rgb(0, 0, 255)); 
color:rgb(255, 255, 255);">

# **U-NET RGB Colorization Model - V2**
***
### **U-Net model that colorizes grayscale comic images into RGB.**

</div>


In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import load_model

# The project-local ``Src`` directory must be on sys.path before importing from it.
sys.path.append('../Src')
from unet_model import build_unet, get_callbacks

In [None]:
# Read the prepared grayscale inputs (X_*) and colour targets (y_*) from disk.
X_train, y_train = (
    np.load("../Data/prepared_data/comic_input_grayscale_train.npy"),
    np.load("../Data/prepared_data/comic_output_color_train.npy"),
)
X_test, y_test = (
    np.load("../Data/prepared_data/comic_input_grayscale_test.npy"),
    np.load("../Data/prepared_data/comic_output_color_test.npy"),
)

# Rescale the inputs only (targets are left untouched): (x - 0.5) * 2 sends
# 0 -> -1 and 1 -> +1, the range intended for training with tanh.
# NOTE(review): this presumes the stored inputs are already in [0, 1] — confirm.
X_train, X_test = (2 * (split - 0.5) for split in (X_train, X_test))

In [None]:
# Loss first, then the model that is compiled with it.
def combined_loss(y_true, y_pred):
    """Blend of mean-squared and mean-absolute error.

    Both terms are computed with the stock Keras loss objects and mixed with
    fixed weights: 0.84 for the squared error and 0.16 for the absolute error.

    Args:
        y_true: Ground-truth tensor handed in by Keras.
        y_pred: Model output tensor of matching shape.

    Returns:
        The weighted sum ``0.84 * MSE + 0.16 * MAE``.
    """
    squared_term = tf.keras.losses.MeanSquaredError()(y_true, y_pred)
    absolute_term = tf.keras.losses.MeanAbsoluteError()(y_true, y_pred)
    return 0.84 * squared_term + 0.16 * absolute_term


# The network input shape is the per-sample shape of the training inputs,
# i.e. everything after the leading sample axis.
input_shape = X_train.shape[1:]
model = build_unet(input_shape)

# Adam with a small fixed step size; MAE is additionally tracked as a metric.
model.compile(
    loss=combined_loss,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=['mae'],
)

In [8]:
# Train the U-Net, keeping both the best checkpoint and the final model.

# Seed Python, NumPy and TensorFlow so weight initialisation and batch
# shuffling are repeatable when the notebook is re-run from a fresh kernel.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Rebuild and recompile here so that re-running this cell always starts from
# freshly initialised weights instead of continuing a previous run.
# The shape is derived from the data rather than hard-coded to (256, 256, 1),
# so the model always matches the arrays that are fed to fit().
# NOTE(review): assumes X_train carries an explicit channel axis — confirm.
input_shape = X_train.shape[1:]
model = build_unet(input_shape)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=combined_loss,
              metrics=['mae'])

# Callback to save the best model (lowest val_loss seen so far).
checkpoint_cb = ModelCheckpoint(
    "best_unet_model_rgb_V2.keras",
    save_best_only=True,
    monitor='val_loss',
    mode='min',
)

# Stop after 10 epochs without val_loss improvement. The monitor is spelled
# out so it visibly matches the checkpoint callback above.
earlystop_cb = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# NOTE(review): the test split doubles as validation data, so checkpoint
# selection and early stopping are tuned on the same data that is later used
# to judge the model. Consider holding out a separate validation split.
history = model.fit(
    X_train, y_train,
    epochs=100,
    batch_size=32,
    validation_data=(X_test, y_test),
    callbacks=[checkpoint_cb, earlystop_cb]
)

# Save the model as it stands after fit() returns. Because the early-stopping
# callback uses restore_best_weights=True, these may be the restored best
# weights rather than the last epoch's weights.
model.save("final_trained_unet_rgb_v2.keras")
print("Saved both best and final model.")


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m177s[0m 22s/step - loss: 0.0645 - mae: 0.1540 - val_loss: 0.1230 - val_mae: 0.2684
Epoch 3/100
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m174s[0m 22s/step - loss: 0.0467 - mae: 0.1256 - val_loss: 0.1218 - val_mae: 0.2667
Epoch 4/100
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m177s[0m 22s/step - loss: 0.0390 - mae: 0.1114 - val_loss: 0.1213 - val_mae: 0.2660
Epoch 5/100
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m175s[0m 22s/step - loss: 0.0344 - mae: 0.1028 - val_loss: 0.1203 - val_mae: 0.2645
Epoch 6/100
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m182s[0m 23s/step - loss: 0.0333 - mae: 0.1015 - val_loss: 0.1189 - val_mae: 0.2625
Epoch 7/100
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m190s[0m 24s/step - loss: 0.0291 - mae: 0.0915 - val_loss: 0.1172 - val_mae: 0.2600
Epoch 8/100
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m178s[0m 22s/step - loss: 0.0279 - mae:

In [None]:
# Write the model currently held in memory to its own file, independent of
# the files produced by the training cell.
model.save(filepath="improved_unet_colorization_rgb_comics.keras")
print("Model saved as 'improved_unet_colorization_rgb_comics.keras'")

In [None]:
# Visualize some results
def plot_results(model, X, y, num_samples=3):
    """Show input, prediction and ground truth side by side for leading samples.

    One row is drawn per sample with three columns: the squeezed input
    rendered with a gray colormap, the model prediction, and the target.

    Args:
        model: Object exposing ``predict``; called as
            ``model.predict(X[:num_samples])``.
        X: Sliceable batch of model inputs; ``X[i].squeeze()`` must be a 2-D
            image.
        y: Indexable batch of target images, drawn unchanged.
        num_samples: Number of leading samples to plot (default 3). Capped at
            ``len(X)`` so short batches do not index past the end.

    Raises:
        ValueError: If there is no sample to plot (``num_samples`` < 1 or
            ``X`` is empty).
    """
    num_samples = min(num_samples, len(X))
    if num_samples < 1:
        raise ValueError("plot_results needs at least one sample to plot")

    predictions = np.asarray(model.predict(X[:num_samples]))
    # imshow clips float RGB data to [0, 1] anyway (with a warning per image);
    # clipping here keeps the picture identical and the output quiet.
    if np.issubdtype(predictions.dtype, np.floating):
        predictions = np.clip(predictions, 0.0, 1.0)

    # squeeze=False keeps ``axes`` two-dimensional even for a single row, so
    # ``axes[i, j]`` indexing also works when num_samples == 1.
    fig, axes = plt.subplots(
        num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False
    )
    titles = ['Grayscale Input', 'Predicted Color', 'Ground Truth']

    for i in range(num_samples):
        # 2-D data is autoscaled by imshow, so inputs need no de-normalising.
        axes[i, 0].imshow(X[i].squeeze(), cmap='gray')
        axes[i, 1].imshow(predictions[i])
        axes[i, 2].imshow(y[i])

        for j in range(3):
            axes[i, j].axis('off')
            # Column headers only on the first row.
            if i == 0:
                axes[i, j].set_title(titles[j])

    plt.tight_layout()
    plt.show()

# Render the first three test-set samples with the model currently in memory.
plot_results(model, X_test, y_test, num_samples=3)

In [None]:
# Check variables
# NOTE(review): %whos is an IPython magic that lists the variables in the
# interactive namespace; it is a debugging aid and can be dropped from the
# finished notebook.
%whos