<!-- RGB Model -->
<div class="alert" style="background: linear-gradient(to right,rgb(255, 0, 0), rgb(0,255,0),rgb(0, 0, 255)); 
color:rgb(255, 255, 255);">

# **U-NET RGB Colorization Model - V2**
***
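### **U-NET RGB Colorization Model**

Trains a U-Net to colorize grayscale comic images into RGB, saving the best checkpoint and the final model as `.keras` files.

</div>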


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('../Src')
from unet_model import build_unet, get_callbacks
import tensorflow as tf

from tensorflow.keras import layers, models
from tensorflow.keras.models import load_model

In [None]:
# GPU Check & Mixed Precision Setup for TensorFlow
import tensorflow as tf
from tensorflow.python.client import device_lib

print(f"TensorFlow version: {tf.__version__}")
print(f"Built with CUDA: {tf.test.is_built_with_cuda()}")

# Query the physical GPUs once and reuse the result below. `tf.config.list_physical_devices`
# is the stable API; the `tf.config.experimental` alias used previously is deprecated.
gpus = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {gpus}")

# NOTE(review): `device_lib` lives in the private `tensorflow.python` namespace and may
# change between TF releases; it is used here only for an informational listing.
print("Device List:")
for device in device_lib.list_local_devices():
    print(f" - {device.name} ({device.device_type})")

# Enable mixed precision only when a GPU is present.
# NOTE(review): with `mixed_float16`, TF recommends the model's final output (and the loss)
# be computed in float32 — confirm `build_unet` does this for its last layer.
if gpus:
    try:
        from tensorflow.keras import mixed_precision
        mixed_precision.set_global_policy("mixed_float16")
        print("Mixed precision enabled (float16).")
    except Exception as e:
        # Best-effort: training still works in float32 if the policy cannot be set.
        print(f"Could not enable mixed precision: {e}")
else:
    print("No GPU detected. Running on CPU.")

In [None]:
# Show which Python executable backs this kernel (helps spot environment mismatches).
import sys

print(sys.executable)


In [None]:
# Peek at the on-disk dtype of the grayscale training inputs before any casting.
_gray_train_path = "../Data/prepared_data/RGB/comic_input_grayscale_train.npy"
X_train = np.load(_gray_train_path)
print(X_train.dtype)

In [None]:
# Inspect the raw grayscale training array, then preview a float32 rescaling to [-1, 1].
# NOTE(review): exploratory only — the data-loading cell reloads the arrays and does NOT
# apply this rescaling. The mapping below is only [0, 1] -> [-1, 1] if the stored values
# are already in [0, 1] (check the "Min/max" print); confirm which range the model expects.
x = np.load(
    "../Data/prepared_data/RGB/comic_input_grayscale_train.npy",
    mmap_mode=None,
)
print("Shape:", x.shape)
print("Min/max:", x.min(), x.max())
print("Type before:", x.dtype)

x = x.astype(np.float32)
print("Type after:", x.dtype)

x = 2.0 * (x - 0.5)
print("Final min/max:", x.min(), x.max())

In [None]:
# Confirm which TensorFlow build this kernel is using.
import tensorflow as tf

print(tf.__version__)

In [None]:
# Load and preprocess data
X_train = np.load("../Data/prepared_data/RGB/comic_input_grayscale_train.npy").astype(np.float32)
y_train = np.load("../Data/prepared_data/RGB/comic_output_color_train.npy").astype(np.float32)
X_test = np.load("../Data/prepared_data/RGB/comic_input_grayscale_test.npy").astype(np.float32)
y_test = np.load("../Data/prepared_data/RGB/comic_output_color_test.npy").astype(np.float32)

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Build model
input_shape = X_train.shape[1:]  # per-sample shape, e.g. (256, 256, 1) for grayscale
model = build_unet(input_shape)

# Relative weights of the two reconstruction terms in `combined_loss`.
MSE_WEIGHT = 0.84
MAE_WEIGHT = 0.16

# Loss objects are created once and reused, instead of being rebuilt on every call.
_mse_loss = tf.keras.losses.MeanSquaredError()
_mae_loss = tf.keras.losses.MeanAbsoluteError()


def combined_loss(y_true, y_pred):
    """Weighted MSE + MAE loss, evaluated in float32.

    Returns ``MSE_WEIGHT * MSE(y_true, y_pred) + MAE_WEIGHT * MAE(y_true, y_pred)``
    as a float32 scalar tensor.

    When the ``mixed_float16`` policy is active the model may emit float16
    predictions. Both inputs are cast to float32 *before* the errors are
    computed, so the squared/absolute differences and their means are not
    accumulated at half precision (casting only the final result cannot
    recover precision that was already lost).
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    return MSE_WEIGHT * _mse_loss(y_true, y_pred) + MAE_WEIGHT * _mae_loss(y_true, y_pred)


model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=['mae'],
)

# Callbacks: keep the checkpoint with the lowest validation loss, and stop after 10 epochs
# without improvement, restoring the best weights seen.
checkpoint_cb = ModelCheckpoint("U-NET_RGB_best_model_v2.keras", save_best_only=True, monitor='val_loss', mode='min')
earlystop_cb = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Train
# NOTE(review): the test split doubles as validation data, so checkpoint selection and
# early stopping are tuned on it; any later evaluation on X_test/y_test is therefore
# optimistically biased. Consider `validation_split=...` or a dedicated validation set.
history = model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    epochs=100,
    batch_size=32,
    callbacks=[checkpoint_cb, earlystop_cb],
)

# NOTE(review): reloading these files with `load_model` will likely require
# custom_objects={"combined_loss": combined_loss} since the loss is a custom function.
model.save("U-NET_RGB_final_trained_v2.keras")
print("Saved best and final model.")


In [None]:
# Visualize some results
def plot_results(model, X, y, num_samples=50):
    """Plot grayscale input, predicted color and ground truth side by side.

    Args:
        model: Object with a Keras-style ``predict`` method mapping a batch of
            inputs to color images.
        X: Grayscale inputs; each ``X[i]`` is squeezed before being shown with a
            gray colormap.
        y: Ground-truth color images paired with ``X``.
        num_samples: Maximum number of rows to draw; clamped to the number of
            available samples.

    Raises:
        ValueError: if there is nothing to plot (empty inputs or ``num_samples <= 0``).
    """
    num_samples = min(num_samples, len(X), len(y))
    if num_samples <= 0:
        raise ValueError("plot_results: no samples to plot")

    predictions = model.predict(X[:num_samples])

    # squeeze=False keeps `axes` 2-D even when only a single row is drawn.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
    titles = ['Grayscale Input', 'Predicted Color', 'Ground Truth']

    for i in range(num_samples):
        # imshow already clips float RGB data to [0, 1]; doing it explicitly (in float32)
        # avoids per-image clipping warnings and handles float16 output from mixed precision.
        pred_rgb = np.clip(np.asarray(predictions[i], dtype=np.float32), 0.0, 1.0)

        axes[i, 0].imshow(X[i].squeeze(), cmap='gray')
        axes[i, 1].imshow(pred_rgb)
        axes[i, 2].imshow(y[i])

        for j, ax in enumerate(axes[i]):
            ax.axis('off')
            if i == 0:
                ax.set_title(titles[j])

    plt.tight_layout()
    plt.show()

# Plot test results
plot_results(model, X_test, y_test)

In [None]:
# Check variables
# IPython `%whos` magic: lists the kernel's interactive variables with their types and
# a short summary. Debugging aid — consider removing before sharing the notebook.
%whos