<!-- RGB Model -->
<div class="alert" style="background: linear-gradient(to right,rgb(255, 0, 0), rgb(0,255,0),rgb(0, 0, 255)); 
color:rgb(255, 255, 255);">

# **U-NET RGB Colorization Model**
***
### **Colorizes grayscale comic images into RGB with a U-Net.**
</div>


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import set_random_seed

In [None]:
# Load the processed train/test arrays. Important: run LoadComicData.ipynb first to create them.
X_train = np.load("prepared_data/comic_input_grayscale_train.npy")
y_train = np.load("prepared_data/comic_output_color_train.npy")
X_test  = np.load("prepared_data/comic_input_grayscale_test.npy")
y_test  = np.load("prepared_data/comic_output_color_test.npy")

# Fail fast with a readable message if the saved splits don't line up, instead of
# surfacing later as an opaque shape error inside model.fit().
assert X_train.ndim == X_test.ndim == y_train.ndim == y_test.ndim == 4, \
    "all arrays must be batched images of shape (N, H, W, C)"
assert len(X_train) == len(y_train), "train inputs and targets have different sample counts"
assert len(X_test) == len(y_test), "test inputs and targets have different sample counts"
# The model is built from X_train's per-sample shape and validated on the test split,
# so both input splits must share it.
assert X_train.shape[1:] == X_test.shape[1:], "train and test inputs must share one (H, W, C) shape"
assert X_train.shape[1:3] == y_train.shape[1:3] == y_test.shape[1:3], \
    "targets must have the same height/width as the inputs"
# The U-Net's output layer emits 3 channels, so targets must be RGB.
assert y_train.shape[-1] == y_test.shape[-1] == 3, "targets must be 3-channel RGB"

In [None]:
# Define the U-Net model
def build_unet(input_shape, base_filters=64, out_channels=3):
    """Build a small two-level U-Net mapping an input image to an ``out_channels`` image.

    Encoder: two stages of 3x3 ReLU conv followed by 2x2 max-pooling.
    Bottleneck: one 3x3 ReLU conv.
    Decoder: two stride-2 3x3 transposed convs (ReLU), each concatenated with the
    encoder feature map of matching resolution (skip connection).
    Head: a 1x1 conv with sigmoid, so outputs lie in (0, 1) and training targets
    should be scaled to that range.

    Args:
        input_shape: Per-sample shape ``(height, width, channels)``, e.g. ``X_train.shape[1:]``.
            Height and width must be divisible by 4 (two 2x poolings, undone by two 2x
            upsamplings); ``None`` is accepted for a variable spatial dimension.
        base_filters: Filters in the first encoder stage; the second stage and the
            bottleneck use 2x and 4x this. The default 64 gives the 64/128/256 layout.
        out_channels: Number of output channels (3 for RGB).

    Returns:
        An uncompiled Keras ``Model``.

    Raises:
        ValueError: If ``input_shape`` is not rank 3, or a known height/width is not
            divisible by 4 (otherwise the skip-connection concatenation fails with a
            less obvious shape mismatch).
    """
    input_shape = tuple(input_shape)
    if len(input_shape) != 3:
        raise ValueError(
            f"input_shape must be (height, width, channels), got {input_shape!r}"
        )
    for dim in input_shape[:2]:
        # MaxPooling2D floors odd sizes, but Conv2DTranspose(strides=2) exactly doubles,
        # so any size not divisible by 4 breaks the Concatenate shape match.
        if dim is not None and dim % 4 != 0:
            raise ValueError(
                f"input height and width must be divisible by 4, got {input_shape[:2]!r}"
            )

    inputs = layers.Input(shape=input_shape)

    # Encoder
    c1 = layers.Conv2D(base_filters, (3, 3), activation='relu', padding='same')(inputs)
    p1 = layers.MaxPooling2D((2, 2))(c1)

    c2 = layers.Conv2D(base_filters * 2, (3, 3), activation='relu', padding='same')(p1)
    p2 = layers.MaxPooling2D((2, 2))(c2)

    # Bottleneck
    bn = layers.Conv2D(base_filters * 4, (3, 3), activation='relu', padding='same')(p2)

    # Decoder: upsample 2x, then fuse with the same-resolution encoder features.
    u1 = layers.Conv2DTranspose(base_filters * 2, (3, 3), strides=2, padding='same', activation='relu')(bn)
    concat1 = layers.Concatenate()([u1, c2])

    u2 = layers.Conv2DTranspose(base_filters, (3, 3), strides=2, padding='same', activation='relu')(concat1)
    concat2 = layers.Concatenate()([u2, c1])

    # 1x1 projection to the output channels; sigmoid bounds each value to (0, 1).
    outputs = layers.Conv2D(out_channels, (1, 1), activation='sigmoid')(concat2)

    model = models.Model(inputs, outputs)
    return model

In [None]:
# Build and compile the model.
# Seed Python, NumPy and TensorFlow so weight initialisation and fit() shuffling are
# repeatable across Restart & Run All (some GPU kernels may still be nondeterministic).
SEED = 42
set_random_seed(SEED)

# Per-sample (H, W, C) shape; presumably (256, 256, 1) for this dataset -- confirm
# against LoadComicData.ipynb.
input_shape = X_train.shape[1:]
model = build_unet(input_shape)
model.compile(optimizer='adam', loss='mean_squared_error')
model.summary()

In [None]:
# Train the model.
# NOTE(review): the held-out test split doubles as the validation set, so val_loss is
# not an independent estimate if it is used to pick epochs or hyperparameters.
EPOCHS = 50
BATCH_SIZE = 16

history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
)

In [None]:
# Persist the trained model; the `.keras` extension selects the native Keras format.
MODEL_PATH = "unet_colorization_rgb_comics.keras"
model.save(MODEL_PATH)
print(f"Model saved as '{MODEL_PATH}'")