<a href="https://colab.research.google.com/github/Jess-Lau/Real-Life-B-W-Video-Colorization-Project/blob/main/ColorizerImageNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ----------------------
# Image Colorization GAN
# ----------------------
import os
import time
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Sequential
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

# ------------------
# Configuration
# ------------------
IMAGE_SIZE = 64    # side length of the square images (pixels)
CHANNELS = 1       # generator input depth: the single L channel
EPOCHS = 1
BATCH_SIZE = 64
LAMBDA = 100       # multiplier applied to the L1 term in generator_loss
data_dir = "/content/ImageNet"
WORKDIR = "/content/colorization"
CHECKPOINT_DIR = os.path.join(WORKDIR, "checkpoints")
RESULTS_DIR = os.path.join(WORKDIR, "results")

# Fix the NumPy and TensorFlow seeds so repeated runs start identically
np.random.seed(42)
tf.random.set_seed(42)

# Ensure every output location exists before anything tries to write to it
for _out_dir in (WORKDIR, CHECKPOINT_DIR, RESULTS_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# ------------------
# ImageNet64 Data Pipeline
# ------------------
def unpickle(file):
    """Deserialize one pickled ImageNet64x64 batch file and return its contents.

    The file is read with ``encoding='bytes'``, so string keys come back as
    ``bytes`` objects (callers index with keys such as ``b'data'``).

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    with open(file, mode='rb') as handle:
        payload = pickle.load(handle, encoding='bytes')
    return payload

def load_imagenet_batch(data_dir, batch_file):
    """Read one batch file and return its images as a float32 NHWC array.

    Pixel values are divided by 255. When the file carries a ``b'mean'``
    entry, that mean (also divided by 255) is subtracted, so the result is
    mean-centred rather than confined to [0, 1].

    Args:
        data_dir: directory containing the batch files.
        batch_file: file name of the batch inside ``data_dir``.

    Returns:
        Array of shape (N, IMAGE_SIZE, IMAGE_SIZE, 3).
    """
    batch = unpickle(os.path.join(data_dir, batch_file))
    pixels = batch[b'data'].astype(np.float32) / 255.0
    if b'mean' in batch:
        pixels -= batch[b'mean'].astype(np.float32) / 255.0
    # Unflatten each row as channel-first, then move channels last (NHWC)
    channel_first = pixels.reshape(-1, 3, IMAGE_SIZE, IMAGE_SIZE)
    return channel_first.transpose(0, 2, 3, 1)

def load_dataset(data_dir, split='train'):
    """Load one split of ImageNet64 as a single NHWC array.

    Args:
        data_dir: directory containing the ImageNet64 batch files.
        split: ``'train'`` loads ``train_data_batch_1`` .. ``train_data_batch_10``;
            any other value loads ``val_data``.

    Returns:
        All images of the split concatenated along axis 0, in the
        normalization produced by ``load_imagenet_batch``.
    """
    if split == 'train':
        files = [f'train_data_batch_{i}' for i in range(1, 11)]
    else:
        files = ['val_data']

    # load_imagenet_batch already subtracts the mean stored in each file.
    # Subtracting it again here (as the previous version did) both
    # double-centred the data and tried to broadcast the flat mean vector
    # against an already reshaped NHWC batch, so no further centring is done.
    # NOTE(review): mean-centred values are not in [0, 1]; confirm this is
    # what the downstream RGB->LAB conversion should receive.
    batches = [load_imagenet_batch(data_dir, name) for name in files]
    return np.concatenate(batches, axis=0)

def prepare_datasets(data_dir):
    """Build the batched tf.data pipelines for training and validation.

    Both splits are loaded with ``load_dataset``, converted image by image
    to LAB, and split into an L input and an AB target (AB divided by 128).

    Args:
        data_dir: directory containing the ImageNet64 batch files.

    Returns:
        ``(train_dataset, val_dataset)``; each yields ``(L, AB)`` float32
        batches of size ``BATCH_SIZE``. Only the training set is shuffled
        (buffer of 1000) and prefetched.
    """
    rgb_train = load_dataset(data_dir, 'train')
    rgb_val = load_dataset(data_dir, 'val')

    def split_lab(rgb_image):
        """Return the (L, AB / 128) pair for a single RGB image."""
        lab = rgb2lab(rgb_image)
        lightness = lab[..., 0:1]        # slice keeps the channel axis
        chroma = lab[..., 1:] / 128.0    # scale a/b (nominally into [-1, 1])
        return lightness, chroma

    def process_batch(rgb_batch):
        """Convert every image and regroup into (all L, all AB)."""
        return zip(*[split_lab(image) for image in rgb_batch])

    def as_float32(planes):
        return np.array(planes, dtype=np.float32)

    print("Processing training data...")
    train_L, train_AB = process_batch(rgb_train)
    print("Processing validation data...")
    val_L, val_AB = process_batch(rgb_val)

    # Training pipeline: shuffle -> batch -> prefetch
    train_tensors = (as_float32(train_L), as_float32(train_AB))
    train_dataset = tf.data.Dataset.from_tensor_slices(train_tensors)
    train_dataset = train_dataset.shuffle(1000)
    train_dataset = train_dataset.batch(BATCH_SIZE)
    train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

    # Validation pipeline: deterministic order, batched only
    val_tensors = (as_float32(val_L), as_float32(val_AB))
    val_dataset = tf.data.Dataset.from_tensor_slices(val_tensors)
    val_dataset = val_dataset.batch(BATCH_SIZE)

    return train_dataset, val_dataset

# ------------------
# Model Architectures
# ------------------
def downsample(filters, size, apply_batchnorm=True):
    """Build an encoder stage: stride-2 Conv2D, optional BatchNorm, LeakyReLU.

    Args:
        filters: number of convolution output channels.
        size: convolution kernel size.
        apply_batchnorm: insert a BatchNormalization layer after the
            convolution when True.

    Returns:
        A ``Sequential`` that halves the spatial resolution of its input
        (stride 2 with 'same' padding).
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    result = Sequential()
    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
                            kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        result.add(layers.BatchNormalization())
    # The slope is passed positionally, matching build_discriminator: the
    # keyword was renamed from `alpha` to `negative_slope` in Keras 3, and
    # the positional form works under both names.
    result.add(layers.LeakyReLU(0.2))
    return result

def upsample(filters, size, apply_dropout=False):
    """Build a decoder stage: stride-2 Conv2DTranspose, BatchNorm, ReLU.

    Args:
        filters: number of transposed-convolution output channels.
        size: kernel size.
        apply_dropout: insert ``Dropout(0.5)`` between the normalization
            and the activation when True.

    Returns:
        A ``Sequential`` that doubles the spatial resolution of its input
        (stride 2 with 'same' padding).
    """
    weight_init = tf.random_normal_initializer(mean=0., stddev=0.02)
    stages = [
        layers.Conv2DTranspose(filters, size, strides=2,
                               padding='same',
                               kernel_initializer=weight_init,
                               use_bias=False),
        layers.BatchNormalization(),
    ]
    if apply_dropout:
        stages.append(layers.Dropout(0.5))
    stages.append(layers.ReLU())
    return Sequential(stages)

def build_generator():
    """Build the U-Net generator mapping an L image to its two AB channels.

    The encoder halves the resolution four times and the decoder doubles it
    four times, so the output has the same height and width as the input.
    The first three decoder stages are concatenated with the encoder feature
    map of matching resolution; the fourth has no counterpart and runs
    without a skip connection.

    Returns:
        A Keras ``Model`` taking ``(IMAGE_SIZE, IMAGE_SIZE, CHANNELS)`` and
        producing ``(IMAGE_SIZE, IMAGE_SIZE, 2)`` with tanh activation.
    """
    inputs = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, CHANNELS])

    # Encoder (64x64 -> 4x4 for the configured IMAGE_SIZE of 64)
    down_stack = [
        downsample(64, 4, apply_batchnorm=False),  # 32x32
        downsample(128, 4),                        # 16x16
        downsample(256, 4),                        # 8x8
        downsample(512, 4),                        # 4x4
    ]

    # Decoder (4x4 -> 64x64)
    up_stack = [
        upsample(512, 4, apply_dropout=True),      # 8x8
        upsample(256, 4),                          # 16x16
        upsample(128, 4),                          # 32x32
        upsample(64, 4),                           # 64x64
    ]

    # Output layer: stride-1 conv, so it does not change the resolution
    last = layers.Conv2D(2, 3, padding='same', activation='tanh')

    # Run the encoder, remembering every intermediate feature map
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck itself is not a skip; the rest are consumed deepest-first
    skips = reversed(skips[:-1])

    # There are three skips for four upsample stages. Zipping the full
    # up_stack would silently drop the last stage and leave the output at
    # half resolution, so only the first three stages are paired here.
    for up, skip in zip(up_stack[:-1], skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    # Final upsample back to the input resolution (no encoder feature left)
    x = up_stack[-1](x)

    x = last(x)
    return Model(inputs=inputs, outputs=x)

def build_discriminator():
    """Build the discriminator scoring a 3-channel image with one probability.

    Three stride-2 convolutions (64, 128, 256 filters) each followed by
    LeakyReLU(0.2), with BatchNormalization on the second and third, then a
    flatten and a single sigmoid unit.

    Returns:
        A Keras ``Model`` taking ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` and
        producing one sigmoid output per image.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)
    image = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3], name='input_image')

    # (filters, use_batchnorm) for each stride-2 stage; the first stage is
    # left un-normalized
    conv_plan = ((64, False), (128, True), (256, True))

    features = image
    for filters, use_batchnorm in conv_plan:
        features = layers.Conv2D(filters, 4, strides=2, padding='same',
                                 kernel_initializer=weight_init)(features)
        if use_batchnorm:
            features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU(0.2)(features)

    flat = layers.Flatten()(features)
    score = layers.Dense(1, activation='sigmoid')(flat)

    return Model(inputs=image, outputs=score)

# ------------------
# Training Setup
# ------------------
# Binary cross-entropy shared by discriminator_loss and generator_loss.
# from_logits=False because the discriminator's last layer applies a sigmoid,
# so its outputs are already probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: discriminator outputs for real images (target 1).
        fake_output: discriminator outputs for generated images (target 0).

    Returns:
        Sum of the cross-entropy on the real batch and on the fake batch.
    """
    want_real = tf.ones_like(real_output)
    loss_on_real = cross_entropy(want_real, real_output)

    want_fake = tf.zeros_like(fake_output)
    loss_on_fake = cross_entropy(want_fake, fake_output)

    return loss_on_real + loss_on_fake

def generator_loss(fake_output, generated_images, real_images):
    """Return the generator objective and its two components.

    Args:
        fake_output: discriminator outputs for the generated images.
        generated_images: generator predictions.
        real_images: ground-truth targets with the same shape as
            ``generated_images``.

    Returns:
        ``(total, adversarial, l1)`` where ``adversarial`` is the
        cross-entropy of ``fake_output`` against all-ones, ``l1`` is the mean
        absolute error, and ``total = adversarial + LAMBDA * l1``.
    """
    # The generator wants the discriminator to answer "real" (1) on fakes
    adversarial = cross_entropy(tf.ones_like(fake_output), fake_output)

    abs_error = tf.abs(real_images - generated_images)
    reconstruction = tf.reduce_mean(abs_error)

    total = adversarial + LAMBDA * reconstruction
    return total, adversarial, reconstruction

# Separate Adam optimizers; the discriminator uses half the generator's
# learning rate.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)

# Networks are built once at import time, generator first.
generator = build_generator()
discriminator = build_discriminator()

# Everything needed to resume training, including the epoch counter that
# train() reads as its starting epoch.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
    epoch=tf.Variable(0),
)

# Keeps only the three most recent checkpoints in CHECKPOINT_DIR.
manager = tf.train.CheckpointManager(
    checkpoint,
    directory=CHECKPOINT_DIR,
    max_to_keep=3,
)

# ------------------
# Training Loop
# ------------------
def generate_images(model, test_input, epoch):
    """Save a three-panel figure (input, ground truth, prediction) for one sample.

    Args:
        model: generator mapping a batch of L images to predicted AB channels.
        test_input: an ``(L, AB)`` pair of batched tensors, as yielded by the
            datasets built in ``prepare_datasets``.
        epoch: zero-based epoch index; the figure is written to
            ``RESULTS_DIR/epoch_{epoch+1}.png``.
    """
    # The caller passes a whole dataset element, so split it before use;
    # only the L batch is fed to the model.
    input_l, target_ab = test_input
    prediction = model(input_l, training=False)

    # Only the first sample of the batch is visualised
    sample_l = input_l[0].numpy()
    sample_ab = target_ab[0].numpy()
    predicted_ab = prediction[0].numpy()

    # Convert LAB to RGB (AB was stored divided by 128, so scale it back)
    def to_rgb(L, AB):
        return lab2rgb(np.concatenate([L, AB*128], axis=-1))

    # Zero AB channels render the L channel alone, i.e. the grayscale input
    panels = (
        ("Input", np.zeros_like(predicted_ab)),
        ("Ground Truth", sample_ab),
        ("Predicted", predicted_ab),
    )

    plt.figure(figsize=(12, 4))
    for position, (title, ab_channels) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(to_rgb(sample_l, ab_channels))
        plt.title(title)
        plt.axis('off')

    plt.savefig(os.path.join(RESULTS_DIR, f'epoch_{epoch+1}.png'))
    plt.close()

@tf.function
def train_step(input_l, target_ab):
    """Run one optimisation step for both the generator and the discriminator.

    Args:
        input_l: batch of L-channel inputs.
        target_ab: batch of ground-truth AB channels for the same images.

    Returns:
        ``(generator_total_loss, discriminator_loss)`` for this batch.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_ab = generator(input_l, training=True)

        # The discriminator always sees L stacked with AB (3 channels)
        real_pair = tf.concat([input_l, target_ab], axis=-1)
        fake_pair = tf.concat([input_l, generated_ab], axis=-1)

        real_output = discriminator(real_pair, training=True)
        fake_output = discriminator(fake_pair, training=True)

        gen_total_loss, _gan_term, _l1_term = generator_loss(
            fake_output, generated_ab, target_ab)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    # Each gradient tensor is clipped to norm 1.0 individually
    gen_grads = [tf.clip_by_norm(grad, 1.0)
                 for grad in gen_tape.gradient(gen_total_loss, gen_vars)]
    disc_grads = [tf.clip_by_norm(grad, 1.0)
                  for grad in disc_tape.gradient(disc_loss, disc_vars)]

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return gen_total_loss, disc_loss

def train(dataset, epochs):
    """Train the GAN, resuming from the epoch stored in the checkpoint.

    Every fifth epoch a checkpoint is saved and a sample figure is written
    with ``generate_images``.

    Args:
        dataset: iterable of ``(L, AB)`` batches.
        epochs: total number of epochs to reach; training runs from the
            checkpoint's stored epoch up to (not including) this value.
    """
    start_epoch = int(checkpoint.epoch.numpy())
    for epoch in range(start_epoch, epochs):
        start_time = time.time()

        # Training
        gen_losses = []
        disc_losses = []
        for batch, (input_l, target_ab) in enumerate(dataset):
            gen_loss, disc_loss = train_step(input_l, target_ab)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

            if batch % 100 == 0:
                print(f'Epoch {epoch+1} Batch {batch} | '
                      f'Gen Loss: {gen_loss:.4f} | Disc Loss: {disc_loss:.4f}')

        # Advance the persisted counter BEFORE saving: the checkpoint must
        # record the next epoch to run, otherwise a restored run would repeat
        # the epoch that had just finished.
        checkpoint.epoch.assign_add(1)

        # Save checkpoint and generate samples
        if (epoch + 1) % 5 == 0:
            manager.save()
            test_batch = next(iter(dataset))
            generate_images(generator, test_batch, epoch)

        # Print epoch statistics (total taken from the argument, not the
        # module-level EPOCHS, so the line is right for any caller)
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{epochs} | '
              f'Gen Loss: {np.mean(gen_losses):.4f} | '
              f'Disc Loss: {np.mean(disc_losses):.4f} | '
              f'Time: {epoch_time:.2f}s')
# ------------------
# Execution
# ------------------
if __name__ == "__main__":
    # Initialize datasets from the location configured at the top of the
    # file (`data_dir`) instead of a second, hard-coded placeholder path.
    DATA_DIR = data_dir
    train_dataset, val_dataset = prepare_datasets(DATA_DIR)

    # Restore checkpoints if available
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"Restored from {manager.latest_checkpoint}")

    # Start training
    train(train_dataset, EPOCHS)

    # Final evaluation: mean absolute error between predicted and true AB
    # channels, averaged over the per-batch means of the validation set
    test_losses = []
    for test_input, test_target in val_dataset:
        gen_output = generator(test_input, training=False)
        test_losses.append(tf.reduce_mean(tf.abs(test_target - gen_output)))
    print(f"Final Validation MAE: {np.mean(test_losses):.4f}")