<a href="https://colab.research.google.com/github/Jess-Lau/Real-Life-B-W-Video-Colorization-Project/blob/main/ImageColorizerGANV4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
pip install tensorflow scikit-image matplotlib



In [10]:
# ----------------------
# Image Colorization GAN
# ----------------------
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Sequential
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
from datetime import datetime

# ------------------
# Configuration
# ------------------
# Side length, in pixels, of the square images both networks are built for.
IMAGE_SIZE = 32
CHANNELS = 1  # L channel input
# Upper bound of the epoch range train() iterates over; the range starts at
# the epoch counter stored in the checkpoint, so this is a total, not an
# increment.
EPOCHS = 1
BATCH_SIZE = 128
LAMBDA = 100  # L1 loss weight
# Root folder for the checkpoints and result images written by this script.
WORKDIR = "/content/colorization"
CHECKPOINT_DIR = os.path.join(WORKDIR, "checkpoints")
RESULTS_DIR = os.path.join(WORKDIR, "results")

# Reproducibility: seed NumPy's and TensorFlow's global random generators.
np.random.seed(42)
tf.random.set_seed(42)

# Create directories (a no-op for any that already exist)
os.makedirs(WORKDIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# ------------------
# Data Pipeline
# ------------------
def load_cifar10():
    """Fetch CIFAR-10 through Keras and return both image splits rescaled.

    The class labels that ship with the dataset are discarded.

    Returns:
        A ``(train_images, test_images)`` pair, each array divided by 255.0
        (8-bit pixel values therefore end up in the [0, 1] range).
    """
    train_split, test_split = tf.keras.datasets.cifar10.load_data()
    train_pixels, test_pixels = train_split[0], test_split[0]
    return train_pixels / 255.0, test_pixels / 255.0

def rgb_to_lab(images, debug=False):
    """Convert RGB images to CIE-LAB and split lightness from colour.

    Args:
        images: Iterable of RGB images, each accepted by
            ``skimage.color.rgb2lab``.
        debug: Accepted for interface compatibility; not used.

    Returns:
        A ``zip`` iterator yielding two tuples: first the L plane of every
        image (trailing axis of size 1, left in LAB's own scale, nominally
        0-100), then the AB planes of every image divided by 128.0 so they
        fall roughly within [-1, 1].
    """
    def _split(rgb):
        lab = rgb2lab(rgb)
        lightness, chroma = lab[:, :, :1], lab[:, :, 1:]
        return lightness, chroma / 128.0

    return zip(*[_split(rgb) for rgb in images])

def create_dataset(images, batch_size=32, shuffle=True, shuffle_buffer=1000):
    """Build a batched ``tf.data`` pipeline of ``(L, AB)`` pairs from RGB images.

    Args:
        images: Iterable of RGB images; converted with ``rgb_to_lab``.
        batch_size: Number of examples per batch.
        shuffle: Whether to shuffle the examples. Defaults to ``True``, which
            matches the previous always-shuffle behaviour; pass ``False`` for
            evaluation data that should keep a stable order.
        shuffle_buffer: Size of the shuffle buffer. A buffer smaller than the
            dataset only shuffles locally; pass the dataset size for a full
            shuffle. Ignored when ``shuffle`` is ``False``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(L, AB)`` float32 batches, prefetched
        with ``tf.data.AUTOTUNE``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    planes = tuple(rgb_to_lab(images))
    if not planes:
        # rgb_to_lab yields nothing for empty input; fail with a clear message
        # instead of an opaque tuple-unpacking error.
        raise ValueError("create_dataset() requires at least one image")
    lightness, chroma = planes

    dataset = tf.data.Dataset.from_tensor_slices((
        np.array(lightness, dtype=np.float32),
        np.array(chroma, dtype=np.float32),
    ))
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Load and prepare data: both CIFAR-10 splits are converted to (L, AB) pairs
# and batched once, when this cell/module runs.
# NOTE: create_dataset shuffles by default, so test_dataset is shuffled too.
train_images, test_images = load_cifar10()
train_dataset = create_dataset(train_images, BATCH_SIZE)
test_dataset = create_dataset(test_images, BATCH_SIZE)

# ------------------
# Generator (U-Net)
# ------------------
def downsample(filters, size, apply_batchnorm=True):
    """Assemble one encoder stage: strided conv, optional batch norm, LeakyReLU.

    Args:
        filters: Number of convolution output channels.
        size: Convolution kernel size.
        apply_batchnorm: Whether to insert ``BatchNormalization`` after the
            convolution.

    Returns:
        A ``Sequential`` whose stride-2, ``'same'``-padded convolution halves
        the spatial resolution of its input.
    """
    stage = [
        layers.Conv2D(
            filters, size, strides=2, padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
    ]
    if apply_batchnorm:
        stage.append(layers.BatchNormalization())
    stage.append(layers.LeakyReLU())
    return Sequential(stage)

def upsample(filters, size, apply_dropout=False):
    """Assemble one decoder stage: transposed conv, batch norm, optional dropout, ReLU.

    Args:
        filters: Number of transposed-convolution output channels.
        size: Kernel size of the transposed convolution.
        apply_dropout: Whether to insert ``Dropout(0.5)`` after batch norm.

    Returns:
        A ``Sequential`` whose stride-2, ``'same'``-padded transposed
        convolution doubles the spatial resolution of its input.
    """
    stage = [
        layers.Conv2DTranspose(
            filters, size, strides=2, padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        layers.BatchNormalization(),
    ]
    if apply_dropout:
        stage.append(layers.Dropout(0.5))
    stage.append(layers.ReLU())
    return Sequential(stage)

def build_generator():
    """Build the U-Net style generator that predicts AB channels from L.

    The network takes an ``(IMAGE_SIZE, IMAGE_SIZE, CHANNELS)`` input, runs it
    through three ``downsample`` stages and three ``upsample`` stages, and
    finishes with a 2-channel ``tanh`` convolution. The outputs of the first
    two encoder stages are concatenated onto the outputs of the first two
    decoder stages; the last decoder stage has no skip connection.

    Returns:
        A Keras ``Model`` mapping the L input to a 2-channel output in [-1, 1].
    """
    inputs = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, CHANNELS])

    # Resolutions in the comments assume the 32x32 input set by IMAGE_SIZE.
    encoder = [
        downsample(64, 4, apply_batchnorm=False),  # -> 16x16
        downsample(128, 4),                        # -> 8x8
        downsample(256, 4),                        # -> 4x4 (bottleneck)
    ]
    decoder = [
        upsample(256, 4, apply_dropout=True),      # -> 8x8
        upsample(128, 4),                          # -> 16x16
        upsample(64, 4),                           # -> 32x32
    ]
    head = layers.Conv2D(2, 3, padding='same', activation='tanh')

    # Encoder pass, remembering every stage output.
    features = inputs
    encoder_outputs = []
    for stage in encoder:
        features = stage(features)
        encoder_outputs.append(features)

    # Drop the bottleneck and walk the remaining outputs deepest-first so
    # each one lines up with the decoder stage of matching resolution.
    skip_sources = encoder_outputs[-2::-1]

    # Every decoder stage except the last is followed by a skip concatenation.
    *skip_stages, final_stage = decoder
    for stage, skip in zip(skip_stages, skip_sources):
        features = stage(features)
        features = layers.Concatenate()([features, skip])

    features = final_stage(features)
    outputs = head(features)

    return Model(inputs=inputs, outputs=outputs)

# ------------------
# Discriminator
# ------------------
def build_discriminator():
    """Build the discriminator that scores a 3-channel (L + AB) image.

    Three stride-2 convolutions with 64, 128 and 256 filters, each followed by
    LeakyReLU (batch norm on the last two), feed a flattened dense layer with
    a sigmoid activation.

    Returns:
        A Keras ``Model`` mapping an ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` input to a
        single probability per image.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    inp = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3], name='input_image')

    # (filters, use batch norm) for each stride-2 stage; every stage halves
    # the spatial resolution.
    conv_plan = ((64, False), (128, True), (256, True))

    features = inp
    for filters, use_batchnorm in conv_plan:
        features = layers.Conv2D(filters, 4, strides=2, padding='same',
                                 kernel_initializer=initializer)(features)
        if use_batchnorm:
            features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU()(features)

    flat = layers.Flatten()(features)
    score = layers.Dense(1, activation='sigmoid')(flat)

    return Model(inputs=inp, outputs=score)

# ------------------
# Loss & Optimizers
# ------------------
# The discriminator ends in a sigmoid, so this loss receives probabilities
# rather than raw logits (hence from_logits=False).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator's binary cross-entropy loss.

    Args:
        real_output: Discriminator scores for real images (target label 1).
        fake_output: Discriminator scores for generated images (target label 0).

    Returns:
        The sum of the cross-entropy on the real and on the fake scores.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total = loss_on_real + loss_on_fake
    return total

def generator_loss(fake_output, generated_images, real_images):
    """Return the generator's combined adversarial + L1 loss and its parts.

    Args:
        fake_output: Discriminator scores for the generated images; the
            generator is rewarded when these are classified as real (label 1).
        generated_images: Generator output.
        real_images: Ground truth the output is compared against.

    Returns:
        ``(total, gan_loss, l1_loss)`` where
        ``total = gan_loss + LAMBDA * l1_loss`` and ``l1_loss`` is the mean
        absolute difference between ``real_images`` and ``generated_images``.
    """
    adversarial = cross_entropy(tf.ones_like(fake_output), fake_output)
    reconstruction = tf.reduce_mean(tf.abs(real_images - generated_images))
    total = adversarial + LAMBDA * reconstruction
    return total, adversarial, reconstruction

# Adam with beta_1=0.5 for both networks; the discriminator's learning rate
# (1e-4) is half the generator's (2e-4).
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

# ------------------
# Training Setup
# ------------------
# Instantiate the two networks that train_step() optimises.
generator = build_generator()
discriminator = build_discriminator()

# Track both networks, both optimizers and an epoch counter so a later run
# can restore them together and continue from the stored epoch.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
    epoch=tf.Variable(0)
)

# Checkpoints are written to CHECKPOINT_DIR; only the 3 most recent are kept.
manager = tf.train.CheckpointManager(
    checkpoint, CHECKPOINT_DIR, max_to_keep=3
)

# ------------------
# Training Loop
# ------------------

def generate_images(model, test_input, epoch, target_ab=None):
    """Render and save a colorization preview for the first image of a batch.

    The figure is written to ``RESULTS_DIR/epoch_<epoch>.png`` and shows the
    grayscale input, optionally the ground truth, and the model's prediction.

    Args:
        model: Generator, called as ``model(test_input, training=False)``. Its
            output is treated as AB channels scaled by 1/128.
        test_input: Batch of L channels with a trailing axis of size 1, in the
            scale produced by ``rgb_to_lab``. Only element 0 is plotted.
        epoch: Value embedded in the output file name.
        target_ab: Optional batch of ground-truth AB channels scaled by 1/128
            and aligned with ``test_input``. When omitted the "Ground Truth"
            panel is left out, because the L-only ``test_input`` carries no
            colour information to reconstruct it from.
    """
    prediction = model(test_input, training=False)

    # lab2rgb needs exactly three channels: keep L as (H, W, 1) and always
    # pair it with a two-channel AB array, undoing the 1/128 scaling.
    lightness = np.asarray(test_input[0]).astype(np.float64)[..., :1]
    predicted_ab = np.asarray(prediction[0]).astype(np.float64) * 128.0

    def _to_rgb(ab):
        """Combine the shared L plane with *ab* and convert to RGB."""
        return lab2rgb(np.concatenate([lightness, ab], axis=-1))

    # Zero AB means no colour, i.e. the grayscale rendering of the input.
    panels = [("Input", _to_rgb(np.zeros_like(predicted_ab)))]
    if target_ab is not None:
        truth_ab = np.asarray(target_ab[0]).astype(np.float64) * 128.0
        panels.append(("Ground Truth", _to_rgb(truth_ab)))
    panels.append(("Predicted", _to_rgb(predicted_ab)))

    # Plotting: 4 inches of width per panel (12x4 when all three are shown).
    plt.figure(figsize=(4 * len(panels), 4))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.title(title)
        plt.imshow(image)
        plt.axis('off')

    plt.savefig(os.path.join(RESULTS_DIR, f'epoch_{epoch}.png'))
    plt.close()

@tf.function
def train_step(input_l, target_ab):
    """Run one optimisation step for both the generator and the discriminator.

    Args:
        input_l: Batch of L channels fed to the generator.
        target_ab: Batch of ground-truth AB channels matching ``input_l``.

    Returns:
        ``(gen_total_loss, disc_loss, gen_gan_loss, gen_l1_loss)`` for this
        batch.
    """
    def _clip_each(gradients):
        # Every gradient tensor is clipped on its own to an L2 norm of 1.0.
        return [tf.clip_by_norm(grad, 1.0) for grad in gradients]

    # One tape per network so each loss can be differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_ab = generator(input_l, training=True)

        # The discriminator judges full 3-channel images: L joined with AB.
        real_score = discriminator(
            tf.concat([input_l, target_ab], axis=-1), training=True
        )
        fake_score = discriminator(
            tf.concat([input_l, fake_ab], axis=-1), training=True
        )

        total_g, adversarial_g, l1_g = generator_loss(
            fake_score, fake_ab, target_ab
        )
        total_d = discriminator_loss(real_score, fake_score)

    gen_vars = generator.trainable_variables
    gen_grads = _clip_each(gen_tape.gradient(total_g, gen_vars))

    disc_vars = discriminator.trainable_variables
    disc_grads = _clip_each(disc_tape.gradient(total_d, disc_vars))

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return total_g, total_d, adversarial_g, l1_g

def train(dataset, epochs):
    """Train both networks until ``epochs`` epochs have been completed.

    Training starts at the epoch counter stored in ``checkpoint.epoch``, so a
    run restored from a checkpoint continues instead of starting over. After
    every fifth epoch a checkpoint is saved and a preview image is written.

    Args:
        dataset: Iterable of ``(input_l, target_ab)`` batches.
        epochs: Total number of epochs to reach (not the number to add).
    """
    for epoch in range(checkpoint.epoch.numpy(), epochs):
        start = time.time()

        # Training
        gen_loss = []
        disc_loss = []
        for input_l, target_ab in dataset:
            losses = train_step(input_l, target_ab)
            gen_loss.append(losses[0])
            disc_loss.append(losses[1])

        # Mark the epoch as completed BEFORE saving. Saving first stored a
        # counter that was one behind the weights, so every resumed run
        # repeated an epoch that had already been trained.
        checkpoint.epoch.assign_add(1)

        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            manager.save()
            generate_images(generator, next(iter(test_dataset))[0], epoch+1)

        # Logging
        print(f'Epoch {epoch+1} | '
              f'Gen Loss: {np.mean(gen_loss):.4f} | '
              f'Disc Loss: {np.mean(disc_loss):.4f} | '
              f'Time: {time.time()-start:.2f}s')

# ------------------
# Execution
# ------------------
if __name__ == "__main__":
    # Restore checkpoints if available
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"Restored from {manager.latest_checkpoint}")

    # Start training
    train(train_dataset, EPOCHS)

    # Final evaluation: mean absolute error on the scaled AB channels.
    test_loss = []
    batch_sizes = []
    for test_input, test_target in test_dataset:
        gen_output = generator(test_input, training=False)
        test_loss.append(float(tf.reduce_mean(tf.abs(test_target - gen_output))))
        batch_sizes.append(int(test_input.shape[0]))

    if test_loss:
        # The last batch is usually smaller than the others, so a plain mean
        # of per-batch means would over-weight its samples; weight each batch
        # by its size to get the true per-element MAE.
        final_mae = np.average(test_loss, weights=batch_sizes)
        print(f"Final Test MAE: {final_mae:.4f}")
    else:
        print("Final Test MAE: n/a (test dataset is empty)")

Restored from /content/colorization/checkpoints/ckpt-1
Final Test MAE: 0.0786
