<a href="https://colab.research.google.com/github/Jess-Lau/Real-Life-B-W-Video-Colorization-Project/blob/main/ImageColorizerGANV2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ----------------------
# Image Colorization GAN
# ----------------------
import os
import time
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Sequential
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb



In [2]:
# ------------------
# Configuration
# ------------------
IMAGE_SIZE = 64
CHANNELS = 1
EPOCHS = 5
BATCH_SIZE = 128
LAMBDA = 100
DATA_DIR = "/content/drive/MyDrive/ImageNet"  # Update with your path
WORKDIR = "/content/colorization"
CHECKPOINT_DIR = os.path.join(WORKDIR, "checkpoints")
RESULTS_DIR = os.path.join(WORKDIR, "results")

# Enable mixed precision
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Create directories
os.makedirs(WORKDIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

In [3]:
# ------------------
# Data Pipeline
# ------------------
def load_mean(data_dir):
    """Load mean image from first training batch"""
    with open(os.path.join(data_dir, 'train_data_batch_1'), 'rb') as f:
        data = pickle.load(f)
        mean = data['mean'].astype(np.float32) / 255.0
        return mean.reshape(3, IMAGE_SIZE, IMAGE_SIZE).transpose(1, 2, 0)

def data_generator(data_dir, split='train'):
    """Memory-efficient data generator"""
    mean = load_mean(data_dir) if split == 'train' else None
    files = [f'train_data_batch_{i}' for i in range(1, 11)] if split == 'train' else ['val_data']

    for file in files:
        path = os.path.join(data_dir, file)
        with open(path, 'rb') as f:
            data = pickle.load(f)
            x = data['data'].astype(np.float32) / 255.0
            x = x.reshape(-1, 3, IMAGE_SIZE, IMAGE_SIZE).transpose(0, 2, 3, 1)

            if mean is not None:
                x -= mean

            for i in range(0, x.shape[0], BATCH_SIZE):
                batch_rgb = x[i:i+BATCH_SIZE]
                batch_lab = np.array([rgb2lab(img) for img in batch_rgb])
                L = batch_lab[..., 0:1].astype(np.float32)  # (B,64,64,1)
                AB = (batch_lab[..., 1:] / 128.0).astype(np.float32)  # (B,64,64,2)
                yield L, AB

def create_dataset(data_dir, split='train'):
    return tf.data.Dataset.from_generator(
        lambda: data_generator(data_dir, split),
        output_signature=(  # ✅ Proper parentheses
            tf.TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 64, 64, 2), dtype=tf.float32)
        )
    ).prefetch(tf.data.AUTOTUNE)  # ✅ .prefetch() called on dataset

In [4]:
# ------------------
# Model Architectures
# ------------------
def downsample(filters, size, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    model = Sequential()
    model.add(layers.Conv2D(filters, size, strides=2, padding='same',
                          kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(0.2))
    return model

def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    model = Sequential()
    model.add(layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                    kernel_initializer=initializer, use_bias=False))
    model.add(layers.BatchNormalization())
    if apply_dropout:
        model.add(layers.Dropout(0.5))
    model.add(layers.ReLU())
    return model

def build_generator():
    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))

    # Encoder
    d1 = downsample(64, 4, False)(inputs)    # 32x32
    d2 = downsample(128, 4)(d1)              # 16x16
    d3 = downsample(256, 4)(d2)              # 8x8
    d4 = downsample(512, 4)(d3)              # 4x4

    # Decoder
    u1 = upsample(512, 4, True)(d4)          # 8x8
    u1 = layers.Concatenate()([u1, d3])
    u2 = upsample(256, 4)(u1)                # 16x16
    u2 = layers.Concatenate()([u2, d2])
    u3 = upsample(128, 4)(u2)                # 32x32
    u3 = layers.Concatenate()([u3, d1])
    u4 = upsample(64, 4)(u3)                 # 64x64

    output = layers.Conv2D(2, 3, padding='same', activation='tanh')(u4)
    return Model(inputs, output)

def build_discriminator():
    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

    x = layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(128, 4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(256, 4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Flatten()(x)
    x = layers.Dense(1, activation='sigmoid')(x)
    return Model(inputs, x)

In [5]:
# ------------------
# Training Setup
# ------------------
generator = build_generator()
discriminator = build_discriminator()

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
    epoch=tf.Variable(0)
)
manager = tf.train.CheckpointManager(checkpoint, CHECKPOINT_DIR, max_to_keep=3)

In [6]:
# ------------------
# Training Utilities
# ------------------
def generate_images(model, test_input, epoch):
    prediction = model(test_input, training=False)[0].numpy()
    L = test_input[0].numpy()[..., 0]

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(L, cmap='gray')
    plt.title("Input")
    plt.axis('off')

    plt.subplot(1, 3, 2)
    true_rgb = lab2rgb(np.dstack((L, test_input[1].numpy()[0] * 128)))
    plt.imshow(true_rgb)
    plt.title("Ground Truth")
    plt.axis('off')

    plt.subplot(1, 3, 3)
    pred_rgb = lab2rgb(np.dstack((L, prediction * 128)))
    plt.imshow(pred_rgb)
    plt.title("Predicted")
    plt.axis('off')

    plt.savefig(os.path.join(RESULTS_DIR, f'epoch_{epoch+1}.png'))
    plt.close()

@tf.function
def train_step(input_L, input_AB):
    # Cast to mixed precision
    input_L = tf.cast(input_L, tf.float16)
    input_AB = tf.cast(input_AB, tf.float16)

    with tf.GradientTape(persistent=True) as tape:
        generated_AB = generator(input_L, training=True)

        # Create concatenated images
        real_images = tf.concat([input_L, input_AB], axis=-1)
        fake_images = tf.concat([input_L, generated_AB], axis=-1)

        # Discriminator outputs
        disc_real = discriminator(real_images, training=True)
        disc_fake = discriminator(fake_images, training=True)

        # Loss calculations
        gen_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(disc_fake), disc_fake) + LAMBDA * tf.reduce_mean(tf.abs(input_AB - generated_AB))
        disc_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(disc_real), disc_real) + tf.keras.losses.binary_crossentropy(
            tf.zeros_like(disc_fake), disc_fake)

    # Apply gradient clipping
    gen_grads = tape.gradient(gen_loss, generator.trainable_variables)
    gen_grads = [tf.clip_by_norm(g, 1.0) for g in gen_grads]
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))

    disc_grads = tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_grads = [tf.clip_by_norm(g, 1.0) for g in disc_grads]
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    return tf.reduce_mean(gen_loss), tf.reduce_mean(disc_loss)





In [7]:
# ------------------
# Training Loop
# ------------------
def train():
    train_dataset = create_dataset(DATA_DIR, 'train')
    val_dataset = create_dataset(DATA_DIR, 'val')

    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"Resumed from epoch {checkpoint.epoch.numpy()}")

    for epoch in range(checkpoint.epoch.numpy(), EPOCHS):
        start = time.time()
        gen_losses, disc_losses = [], []

        for batch, (L, AB) in enumerate(train_dataset):
            gen_loss, disc_loss = train_step(L, AB)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

            if batch % 100 == 0:
                gen_loss_val = gen_loss.numpy().item()
                disc_loss_val = disc_loss.numpy().item()
                print(f"Epoch {epoch+1} Batch {batch} | Gen: {gen_loss_val:.2f} Disc: {disc_loss_val:.2f}")
                tf.keras.backend.clear_session()

        if (epoch + 1) % 5 == 0:
            manager.save()
            test_batch = next(iter(val_dataset))
            generate_images(generator, test_batch, epoch)

        print(f"\nEpoch {epoch+1}/{EPOCHS}")
        print(f"Time: {time.time()-start:.2f}s")
        print(f"Gen Loss: {np.mean(gen_losses):.4f}")
        print(f"Disc Loss: {np.mean(disc_losses):.4f}\n")
        checkpoint.epoch.assign_add(1)

In [8]:
if __name__ == "__main__":
    train()

  nparray = values.astype(dtype.as_numpy_dtype)


Epoch 1 Batch 0 | Gen: 50.56 Disc: 1.57
Epoch 1 Batch 100 | Gen: 75.50 Disc: 1.41
Epoch 1 Batch 200 | Gen: 71.19 Disc: 1.40
Epoch 1 Batch 300 | Gen: 69.12 Disc: 1.39
Epoch 1 Batch 400 | Gen: 68.12 Disc: 1.39
Epoch 1 Batch 500 | Gen: 67.81 Disc: 1.39
Epoch 1 Batch 600 | Gen: 67.81 Disc: 1.39
Epoch 1 Batch 700 | Gen: 67.06 Disc: 1.39
Epoch 1 Batch 800 | Gen: 66.81 Disc: 1.39
Epoch 1 Batch 900 | Gen: 66.62 Disc: 1.39
Epoch 1 Batch 1000 | Gen: 65.62 Disc: 1.39
Epoch 1 Batch 1100 | Gen: 65.44 Disc: 1.39
Epoch 1 Batch 1200 | Gen: 65.56 Disc: 1.39
Epoch 1 Batch 1300 | Gen: 64.50 Disc: 1.39
Epoch 1 Batch 1400 | Gen: 65.56 Disc: 1.39
Epoch 1 Batch 1500 | Gen: 70.00 Disc: 1.39
Epoch 1 Batch 1600 | Gen: 73.50 Disc: 1.39
Epoch 1 Batch 1700 | Gen: 73.81 Disc: 1.39
Epoch 1 Batch 1800 | Gen: 75.62 Disc: 1.39
Epoch 1 Batch 1900 | Gen: 69.31 Disc: 1.39
Epoch 1 Batch 2000 | Gen: 67.19 Disc: 1.39
Epoch 1 Batch 2100 | Gen: 61.78 Disc: 1.39
Epoch 1 Batch 2200 | Gen: 59.53 Disc: 1.39
Epoch 1 Batch 2300 | Ge

UnknownError: {{function_node __wrapped__IteratorGetNext_output_types_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/ImageNet/val_data'
Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/ops/script_ops.py", line 269, in __call__
    ret = func(*args)
          ^^^^^^^^^^^

  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/from_generator_op.py", line 198, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "<ipython-input-3-ab932882313d>", line 18, in data_generator
    with open(path, 'rb') as f:
         ^^^^^^^^^^^^^^^^

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/ImageNet/val_data'


	 [[{{node PyFunc}}]] [Op:IteratorGetNext] name: 