# Improved U-Net RGB Colorization Model
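This notebook implements an enhanced version of the U-Net model for comic colorization: it trains on grayscale comic images paired with their colored versions, saves the trained model, and visualizes predictions on the test set.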

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import load_model

from unet_model import build_unet, get_callbacks

In [None]:
# Load the prepared arrays: grayscale comic inputs (X_*) and their colour targets (y_*).
X_train, y_train, X_test, y_test = (
    np.load(npy_path)
    for npy_path in (
        "prepared_data/comic_input_grayscale_train.npy",
        "prepared_data/comic_output_color_train.npy",
        "prepared_data/comic_input_grayscale_test.npy",
        "prepared_data/comic_output_color_test.npy",
    )
)

# Rescale only the inputs: x -> (x - 0.5) * 2, i.e. [0, 1] -> [-1, 1]
# (assumes the prepared inputs are in [0, 1] -- TODO confirm in the data-prep step;
# the intent noted here is better training with a tanh activation).
# Targets are left in their stored range.
X_train, X_test = ((split - 0.5) * 2 for split in (X_train, X_test))

In [None]:
# Build the U-Net for whatever per-sample shape the loaded inputs have.
input_shape = X_train.shape[1:]
model = build_unet(input_shape)

# Relative weights of the two pixel-wise loss terms (they sum to 1).
MSE_WEIGHT = 0.84
MAE_WEIGHT = 0.16


def combined_loss(y_true, y_pred):
    """Weighted blend of mean squared error and mean absolute error.

    Returns ``MSE_WEIGHT * MSE(y_true, y_pred) + MAE_WEIGHT * MAE(y_true, y_pred)``,
    where both terms are computed with the Keras loss classes using their
    default reduction.
    """
    squared_term = tf.keras.losses.MeanSquaredError()(y_true, y_pred)
    absolute_term = tf.keras.losses.MeanAbsoluteError()(y_true, y_pred)
    return MSE_WEIGHT * squared_term + MAE_WEIGHT * absolute_term


optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=optimizer, loss=combined_loss, metrics=['mae'])

In [None]:
# Rebuild and recompile from scratch so re-running this cell always trains
# fresh weights. The input shape is taken from the loaded data rather than
# hard-coded, so the model always matches the prepared arrays.
input_shape = X_train.shape[1:]
model = build_unet(input_shape)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=combined_loss,
              metrics=['mae'])

# Fraction of the training arrays held out for validation. No separate
# validation arrays are prepared in this notebook, and the test set is kept
# untouched for final evaluation. Keras takes the *last* samples (before
# shuffling) -- if the prepared data is ordered, shuffle it once beforehand.
VALIDATION_SPLIT = 0.1

# Callback to save the best model (based on val_loss)
checkpoint_cb = ModelCheckpoint("best_unet_model_rgb.keras", save_best_only=True, monitor='val_loss', mode='min')

# Stop after 10 epochs without val_loss improvement and restore the best weights.
earlystop_cb = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Start training. Targets are `y_train` (lowercase), as defined when the data is loaded.
history = model.fit(
    X_train, y_train,
    epochs=100,
    batch_size=32,
    validation_split=VALIDATION_SPLIT,
    callbacks=[checkpoint_cb, earlystop_cb],
)

# Save the model as it stands after fit() returns.
model.save("final_trained_unet_rgb.keras")
print("Saved both best and final model.")


In [None]:
# Save an extra copy of the current model under a descriptive file name.
manual_save_path = "improved_unet_colorization_rgb_comics.keras"
model.save(manual_save_path)
print(f"Model saved as '{manual_save_path}'")

In [None]:
# Visualize some results
def plot_results(model, X, y, num_samples=3):
    """Show grayscale inputs, model predictions and ground truth side by side.

    Draws one row per sample and three columns:
    - the input, rendered with a grey colormap;
    - the model's predicted colour image;
    - the target image.

    Args:
        model: Object with a Keras-style ``predict`` method.
        X: Input images. Each ``X[i]`` is squeezed before display, so a
            trailing channel axis of size 1 is fine.
        y: Ground-truth colour images aligned index-by-index with ``X``.
        num_samples: Number of leading samples to show. It is capped at the
            number of available samples.

    Raises:
        ValueError: If there is nothing to plot (``num_samples < 1`` or empty
            ``X``/``y``).
    """
    # Never ask for more rows than there are samples (avoids IndexError below).
    n_rows = min(num_samples, len(X), len(y))
    if n_rows < 1:
        raise ValueError("plot_results needs at least one sample to plot")

    predictions = model.predict(X[:n_rows])

    # squeeze=False keeps `axes` 2-D even when n_rows == 1, so axes[i, j] always works.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    titles = ['Grayscale Input', 'Predicted Color', 'Ground Truth']

    for i in range(n_rows):
        axes[i, 0].imshow(X[i].squeeze(), cmap='gray')
        # imshow clips float RGB data to [0, 1] anyway (emitting a warning);
        # clip explicitly so out-of-range predictions render the same, silently.
        axes[i, 1].imshow(np.clip(predictions[i], 0.0, 1.0))
        axes[i, 2].imshow(y[i])

    for ax in axes.flat:
        ax.axis('off')
    for ax, title in zip(axes[0], titles):
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

# Visual check on held-out data: first three test samples (input, prediction, target).
plot_results(model, X_test, y_test, num_samples=3)

In [None]:
# Check variables: list the variables currently defined in the kernel namespace.
# Interactive debugging aid only; nothing later in the notebook depends on it.
%whos