<a href="https://colab.research.google.com/github/Jess-Lau/Real-Life-B-W-Video-Colorization-Project/blob/main/ColorizerGANVXZ.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [66]:
# ----------------------
# Colorization GAN
# ----------------------
import os
import cv2
import time
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Sequential
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb



In [100]:
# ------------------
# Configuration
# ------------------
IMAGE_SIZE = 64
CHANNELS = 1
EPOCHS = 55
BATCH_SIZE = 256
LAMBDA = 100
DATA_DIR = "/content/drive/MyDrive/ImageNet"  # Update with your path
WORKDIR = "/content/drive/MyDrive/Colorization"
CHECKPOINT_DIR = os.path.join(WORKDIR, "checkpoints")
RESULTS_DIR = os.path.join(WORKDIR, "results")

# Enable mixed precision with proper policy


# Create directories
os.makedirs(WORKDIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

In [68]:
# ------------------
# Data Pipeline
# ------------------
def load_mean(data_dir):
    """Load mean image from first training batch"""
    with open(os.path.join(data_dir, 'train_data_batch_1'), 'rb') as f:
        data = pickle.load(f)
        mean = data['mean'].astype(np.float32) / 255.0
        return mean.reshape(3, IMAGE_SIZE, IMAGE_SIZE).transpose(1, 2, 0)

def data_generator(data_dir, split='train'):
    mean = load_mean(data_dir) if split == 'train' else None
    files = [f'train_data_batch_{i}' for i in range(1, 11)] if split == 'train' else ['val_data']

    for file in files:
        path = os.path.join(data_dir, file)
        try:
            with open(path, 'rb') as f:
                data = pickle.load(f)
                x = data['data'].astype(np.float32) / 255.0
                x = x.reshape(-1, 3, IMAGE_SIZE, IMAGE_SIZE).transpose(0, 2, 3, 1)

                if mean is not None:
                    x -= mean

                for i in range(0, x.shape[0], BATCH_SIZE):
                    batch_rgb = x[i:i+BATCH_SIZE]
                    batch_lab = np.array([rgb2lab(img) for img in batch_rgb])
                    batch_lab = np.nan_to_num(batch_lab, nan=0.0, posinf=100.0, neginf=0.0)
                    L = np.clip(batch_lab[..., 0:1], 0, 100).astype(np.float32)
                    AB = np.clip(batch_lab[..., 1:], -128, 127).astype(np.float32) / 128.0
                    yield L, AB
        except Exception as e:
            print(f"❌ Failed to load {path}: {str(e)}")
            continue  # Skip problematic files

def create_dataset(data_dir, split='train'):
    return tf.data.Dataset.from_generator(
        lambda: data_generator(data_dir, split),
        output_signature=(  # ✅ Proper parentheses
            tf.TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 64, 64, 2), dtype=tf.float32)
        )
    ).prefetch(tf.data.AUTOTUNE)  # ✅ .prefetch() called on dataset

In [86]:
# ------------------
# Custom Instance Normalization (using GroupNorm)
# ------------------
class InstanceNormalization(layers.Layer):
    def __init__(self, epsilon=1e-3):
        super(InstanceNormalization, self).__init__()
        self.epsilon = epsilon

    def build(self, input_shape):
        # Create variables in float32 for numerical stability
        self.gamma = self.add_weight(
            name='gamma',
            shape=(input_shape[-1],),
            initializer='ones',
            dtype=tf.float32  # Store as float32
        )
        self.beta = self.add_weight(
            name='beta',
            shape=(input_shape[-1],),
            initializer='zeros',
            dtype=tf.float32  # Store as float32
        )

    def call(self, inputs):
        input_dtype = inputs.dtype  # Preserve original dtype (float16)

        # Compute in float32 for numerical stability
        inputs_float32 = tf.cast(inputs, tf.float32)
        mean, variance = tf.nn.moments(inputs_float32, axes=[1, 2], keepdims=True)
        inv = tf.math.rsqrt(variance + self.epsilon)

        # Cast parameters to match computation dtype
        gamma = tf.cast(self.gamma, tf.float32)
        beta = tf.cast(self.beta, tf.float32)

        # Compute and cast back to original dtype
        output = gamma * (inputs_float32 - mean) * inv + beta
        return tf.cast(output, input_dtype)  # Return original dtype

# ------------------
# Model Architectures
# ------------------
def residual_block(filters):
    """Residual block with custom instance normalization"""
    block = Sequential()
    block.add(layers.Conv2D(filters, 3, padding='same'))
    block.add(InstanceNormalization())
    block.add(layers.ReLU())
    block.add(layers.Conv2D(filters, 3, padding='same'))
    block.add(InstanceNormalization())
    return block

def attention_block(skip, gate, filters):
    """Attention gate using core TF layers"""
    g = layers.Conv2D(filters, 1)(gate)
    x = layers.Conv2D(filters, 1)(skip)
    psi = layers.Activation('relu')(layers.Add()([g, x]))
    psi = layers.Conv2D(1, 1, activation='sigmoid')(psi)
    return layers.Multiply()([skip, psi])

def build_generator():
    inputs = layers.Input(shape=(64, 64, 1))

    # Encoder
    d1 = layers.Conv2D(64, 4, strides=2, padding='same')(inputs)  # 32x32
    d1 = InstanceNormalization()(d1)
    d1 = layers.LeakyReLU(0.2)(d1)

    d2 = layers.Conv2D(128, 4, strides=2, padding='same')(d1)     # 16x16
    d2 = InstanceNormalization()(d2)
    d2 = layers.LeakyReLU(0.2)(d2)

    d3 = layers.Conv2D(256, 4, strides=2, padding='same')(d2)     # 8x8
    d3 = InstanceNormalization()(d3)
    d3 = layers.LeakyReLU(0.2)(d3)

    d4 = layers.Conv2D(512, 4, strides=2, padding='same')(d3)     # 4x4
    d4 = InstanceNormalization()(d4)
    d4 = layers.LeakyReLU(0.2)(d4)

    # Bottleneck with residual
    res = residual_block(512)(d4)
    d4 = layers.Add()([d4, res])

    # Decoder with PROPER upsampling
    # Layer 1: 4x4 → 8x8
    u1 = layers.Conv2DTranspose(512, 4, strides=2, padding='same')(d4)
    u1 = InstanceNormalization()(u1)
    u1 = layers.ReLU()(u1)
    u1 = layers.Concatenate()([u1, d3])  # Skip connection from d3 (8x8)

    # Layer 2: 8x8 → 16x16
    u2 = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(u1)
    u2 = InstanceNormalization()(u2)
    u2 = layers.ReLU()(u2)
    u2 = layers.Concatenate()([u2, d2])  # Skip connection from d2 (16x16)

    # Layer 3: 16x16 → 32x32
    u3 = layers.Conv2DTranspose(128, 4, strides=2, padding='same')(u2)
    u3 = InstanceNormalization()(u3)
    u3 = layers.ReLU()(u3)
    u3 = layers.Concatenate()([u3, d1])  # Skip connection from d1 (32x32)

    # Layer 4: 32x32 → 64x64 (FIXED: Added final upsampling layer)
    u4 = layers.Conv2DTranspose(64, 4, strides=2, padding='same')(u3)
    u4 = InstanceNormalization()(u4)
    u4 = layers.ReLU()(u4)

    # Final output layer
    output = layers.Conv2D(2, 3, padding='same',
                          activation='tanh',
                          dtype='float32')(u4)
    return Model(inputs, output)

def build_discriminator():
    inputs = layers.Input(shape=(64, 64, 3))

    # Layer 1: 64x64 → 32x32
    x = layers.Conv2D(64, 4, strides=2, padding='same', dtype='float32')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    # Layer 2: 32x32 → 16x16
    x = layers.Conv2D(128, 4, strides=2, padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)

    # Layer 3: 16x16 → 8x8
    x = layers.Conv2D(256, 4, strides=2, padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)

    # Final layer: 8x8 → 8x8 (no striding)
    x = layers.Conv2D(1, 4, padding='same')(x)
    return Model(inputs, x)


In [87]:
# ------------------
# Training Setup
# ------------------
generator = build_generator()
generator.summary()  # Should show output shape (None, 64, 64, 2)
print(generator.output_shape)  # Should be (None, 64, 64, 2)


discriminator = build_discriminator()

generator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5, clipnorm=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-6, beta_1=0.5, clipnorm=0.5)

checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
    epoch=tf.Variable(0)
)
manager = tf.train.CheckpointManager(checkpoint, CHECKPOINT_DIR, max_to_keep=3)

(None, 64, 64, 2)


In [88]:
# ------------------
# Training Utilities
# ------------------
def generate_images(model, test_input, epoch):
    input_L = test_input[0]  # ✅ Extract L channel
    target_AB = test_input[1]  # Ground truth AB

    # Predict using only L
    prediction = model(input_L, training=False)[0].numpy()
    L = input_L[0].numpy()[..., 0]  # Use first sample in batch

    plt.figure(figsize=(12, 4))

    # Input (grayscale)
    plt.subplot(1, 3, 1)
    plt.imshow(L, cmap='gray')
    plt.title("Input")
    plt.axis('off')

    # Ground truth (colorized)
    plt.subplot(1, 3, 2)
    true_rgb = lab2rgb(np.dstack((L, target_AB[0].numpy() * 128)))  # ✅ Use target_AB
    plt.imshow(true_rgb)
    plt.title("Ground Truth")
    plt.axis('off')

    # Predicted (colorized)
    plt.subplot(1, 3, 3)
    pred_rgb = lab2rgb(np.dstack((L, prediction * 128)))
    plt.imshow(pred_rgb)
    plt.title("Predicted")
    plt.axis('off')

    plt.savefig(os.path.join(RESULTS_DIR, f'epoch_{epoch+1}.png'))
    plt.close()


# --- Add PSNR/SSIM Calculations ---
# Convert LAB to RGB for metrics
def lab_to_rgb(lab):
    """Convert LAB tensor to RGB tensor (0-255 range) with mixed precision support"""
    # Ensure LAB tensor is float32 for stable calculations
    lab = tf.cast(lab, tf.float32)

    # Denormalize LAB
    L = lab[..., 0] * 100.0          # L: [0,100]
    ab = lab[..., 1:] * 128.0        # ab: [-128, 127]

    # Convert LAB to XYZ
    y = (L + 16.0) / 116.0
    x = ab[..., 0] / 500.0 + y
    z = y - ab[..., 1] / 200.0

    xyz = tf.stack([x, y, z], axis=-1)
    xyz = tf.where(xyz > 0.2068966, xyz**3, (xyz - 16.0/116.0)/7.787)

    # D65 reference white (cast to float32)
    xyz = xyz * tf.constant([95.047, 100.0, 108.883], dtype=tf.float32)

    # XYZ to RGB matrix
    rgb = tf.tensordot(xyz, tf.constant([
        [3.2406, -1.5372, -0.4986],
        [-0.9689, 1.8758, 0.0415],
        [0.0557, -0.2040, 1.0570]
    ], dtype=tf.float32), axes=1)

    # Gamma correction
    rgb = tf.where(rgb > 0.0031308,
                    1.055 * (rgb ** (1/2.4)) - 0.055,
                    12.92 * rgb)

    # Final conversion to float16 if needed
    return tf.cast(tf.clip_by_value(rgb * 255.0, 0.0, 255.0), tf.float32)

# ------------------
# 1. Data Validation Layer
# ------------------
def validate_data(L, AB):
    """Ensure inputs are within valid numerical ranges"""
    # Check for NaN/Inf in inputs
    L = tf.debugging.assert_all_finite(L, "Invalid values in L channel")
    AB = tf.debugging.assert_all_finite(AB, "Invalid values in AB channels")

    # Clip LAB values to valid ranges
    L = tf.clip_by_value(L, 0.0, 100.0)
    AB = tf.clip_by_value(AB, -128.0, 127.0)

    return L, AB

# ------------------
# 2. Safe Loss Functions
# ------------------

def safe_generator_loss(fake_output, real_ab, gen_ab):
    # Check input shapes
    # For PatchGAN output (8x8 feature map)
    tf.debugging.assert_shapes([
        (fake_output, ('batch', 8, 8, 1)),  # Adjusted to match PatchGAN output
        (real_ab, ('batch', 'height', 'width', 2)),
        (gen_ab, ('batch', 'height', 'width', 2))
    ])

    # Adversarial loss with stability epsilon
    adv_loss = tf.reduce_mean(tf.square(fake_output - 1.0 + 1e-7))

    # L1 loss with clipping
    l1_diff = tf.clip_by_value(real_ab - gen_ab, -1.0, 1.0)
    l1_loss = tf.reduce_mean(tf.abs(l1_diff))

    # Numerical checks
    adv_loss = tf.debugging.check_numerics(adv_loss, "Generator adv_loss NaN/Inf")
    l1_loss = tf.debugging.check_numerics(l1_loss, "Generator l1_loss NaN/Inf")

    total_loss = adv_loss + LAMBDA * l1_loss
    return tf.debugging.check_numerics(total_loss, "Generator total_loss NaN/Inf")

def safe_discriminator_loss(real_output, fake_output):
    # Add small epsilon to prevent log(0)
    real_loss = tf.reduce_mean(tf.square(real_output - 1.0 + 1e-7))
    fake_loss = tf.reduce_mean(tf.square(fake_output + 1e-7))

    # Numerical checks
    real_loss = tf.debugging.check_numerics(real_loss, "Disc real_loss NaN/Inf")
    fake_loss = tf.debugging.check_numerics(fake_loss, "Disc fake_loss NaN/Inf")

    total_loss = 0.5 * (real_loss + fake_loss)
    return tf.debugging.check_numerics(total_loss, "Disc total_loss NaN/Inf")


# ------------------
# 4. Modified Training Step
# ------------------
@tf.function
def train_step(input_L, input_AB):
    # Cast inputs to float16 to match mixed precision policy
    input_L = tf.cast(input_L, tf.float16)
    input_AB = tf.cast(input_AB, tf.float16)

    with tf.GradientTape(persistent=True) as tape:
        generated_AB = generator(input_L, training=True)

        # Cast generator output to match discriminator input dtype
        generated_AB = tf.cast(generated_AB, tf.float16)

        # Ensure matching dtypes before concatenation
        real_images = tf.concat([
            tf.cast(input_L, tf.float16),
            tf.cast(input_AB, tf.float16)
        ], axis=-1)

        fake_images = tf.concat([
            tf.cast(input_L, tf.float16),
            generated_AB
        ], axis=-1)

        rgb_real = lab_to_rgb(real_images)
        rgb_fake = lab_to_rgb(fake_images)

        # Discriminator outputs
        disc_real = discriminator(real_images, training=True)
        disc_fake = discriminator(fake_images, training=True)

        # Calculate losses
        gen_loss = safe_generator_loss(disc_fake, input_AB, generated_AB)
        disc_loss = safe_discriminator_loss(disc_real, disc_fake)

    # Calculate and clip gradients
    gen_grads = tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = tape.gradient(disc_loss, discriminator.trainable_variables)

    # Gradient clipping and validation
    gen_grads = [tf.clip_by_norm(g, 1.0) for g in gen_grads]
    disc_grads = [tf.clip_by_norm(g, 1.0) for g in disc_grads]

    # Check gradients before applying
    for g in gen_grads + disc_grads:
        tf.debugging.check_numerics(g, "NaN/Inf in gradients")

    # Apply gradients
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    # Calculate metrics
    psnr = safe_psnr(rgb_real, rgb_fake)
    ssim = safe_ssim(rgb_real, rgb_fake)

    return gen_loss, disc_loss, psnr, ssim




In [89]:
from skimage.metrics import peak_signal_noise_ratio
import cv2
import numpy as np

def calculate_frame_psnr(original, colorized):
    """
    Calculate PSNR for a single frame pair.
    Args:
        original: Ground truth frame (BGR)
        colorized: Colorized frame (BGR)
    Returns:
        PSNR value in dB
    """
    # Convert to YCrCb for luminance comparison (optional)
    original_yuv = cv2.cvtColor(original, cv2.COLOR_BGR2YCrCb)
    colorized_yuv = cv2.cvtColor(colorized, cv2.COLOR_BGR2YCrCb)

    # Calculate PSNR for each channel
    psnr_y = peak_signal_noise_ratio(original_yuv[...,0], colorized_yuv[...,0], data_range=255)
    psnr_cr = peak_signal_noise_ratio(original_yuv[...,1], colorized_yuv[...,1], data_range=255)
    psnr_cb = peak_signal_noise_ratio(original_yuv[...,2], colorized_yuv[...,2], data_range=255)

    return np.mean([psnr_y, psnr_cr, psnr_cb])

def calculate_video_psnr(original_video_path, colorized_video_path):
    """
    Calculate average PSNR between two videos.
    Returns:
        Mean PSNR (dB), Frame-wise PSNR array
    """
    cap_orig = cv2.VideoCapture(original_video_path)
    cap_color = cv2.VideoCapture(colorized_video_path)

    psnr_values = []

    while True:
        ret_orig, frame_orig = cap_orig.read()
        ret_color, frame_color = cap_color.read()

        if not ret_orig or not ret_color:
            break

        # Resize if necessary (match dimensions)
        if frame_orig.shape != frame_color.shape:
            frame_color = cv2.resize(frame_color, (frame_orig.shape[1], frame_orig.shape[0]))

        psnr = calculate_frame_psnr(frame_orig, frame_color)
        psnr_values.append(psnr)

    cap_orig.release()
    cap_color.release()

    return np.mean(psnr_values), psnr_values

In [94]:
# ------------------
# Training Loop
# ------------------
def train():
    train_dataset = create_dataset(DATA_DIR, 'train')
    val_dataset = create_dataset(DATA_DIR, 'val')

    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"Resumed from epoch {checkpoint.epoch.numpy()}")

    # Initialize metrics
    psnr_metric = tf.keras.metrics.Mean(name='psnr')
    ssim_metric = tf.keras.metrics.Mean(name='ssim')

    for epoch in range(checkpoint.epoch.numpy(), EPOCHS):
        start = time.time()
        gen_losses, disc_losses = [], []
        # Reset metrics each epoch (CORRECTED METHOD NAME)
        psnr_metric.reset_state()
        ssim_metric.reset_state()

        # Training phase
        for batch, (L, AB) in enumerate(train_dataset):
            gen_loss, disc_loss, psnr, ssim = train_step(L, AB)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

            # Update metrics
            psnr_metric.update_state(psnr)
            ssim_metric.update_state(ssim)


            # Existing logging
            if batch % 100 == 0:
                gen_loss_val = gen_loss.numpy().item()
                disc_loss_val = disc_loss.numpy().item()
                print(f"Batch {batch} | PSNR: {psnr_metric.result():.2f} | SSIM: {ssim_metric.result():.3f}")
                print(f"Gen: {gen_loss_val:.2f} Disc: {disc_loss_val:.2f}")


        # In your training loop after the epoch's training phase:
        if (epoch + 1) % 5 == 0:
            manager.save()
            test_batch = next(iter(val_dataset))
            generate_images(generator, test_batch, epoch)

        # Validation phase
        val_psnr = []
        val_ssim = []
        for val_L, val_AB in val_dataset.take(10):
           val_gen_AB = generator(val_L, training=False)

        # Convert to float32 before concatenation
           val_real_lab = tf.concat([
               tf.cast(val_L, tf.float32),
               tf.cast(val_AB, tf.float32)
           ], axis=-1)
           val_real_rgb = lab_to_rgb(val_real_lab)

           val_fake_lab = tf.concat([
               tf.cast(val_L, tf.float32),
               tf.cast(val_gen_AB, tf.float32)
           ], axis=-1)
           val_fake_rgb = lab_to_rgb(val_fake_lab)

        # Calculate metrics
        batch_psnr = tf.reduce_mean(tf.image.psnr(val_real_rgb, val_fake_rgb, max_val=255))
        batch_ssim = tf.reduce_mean(tf.image.ssim(val_real_rgb, val_fake_rgb, max_val=255))

        val_psnr.append(batch_psnr.numpy())
        val_ssim.append(batch_ssim.numpy())

        # Epoch summary
        print(f"\nEpoch {epoch+1}")
        print(f"Time: {time.time()-start:.2f}s")
        print(f"Gen Loss: {np.mean(gen_losses):.4f}")
        print(f"Disc Loss: {np.mean(disc_losses):.4f}\n")
        print(f"Train PSNR: {psnr_metric.result():.2f} dB")
        print(f"Train SSIM: {ssim_metric.result():.4f}")
        print(f"Val PSNR: {np.mean(val_psnr):.2f} dB")
        print(f"Val SSIM: {np.mean(val_ssim):.4f}")

        checkpoint.epoch.assign_add(1)


In [91]:
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)


In [101]:
if __name__ == "__main__":
    train()

Resumed from epoch 49
Batch 0 | PSNR: 10.78 | SSIM: 0.598
Gen: 9.03 Disc: 0.25
Batch 100 | PSNR: 9.74 | SSIM: 0.573
Gen: 9.73 Disc: 0.25
Batch 200 | PSNR: 10.12 | SSIM: 0.579
Gen: 9.57 Disc: 0.25
Batch 300 | PSNR: 10.03 | SSIM: 0.578
Gen: 10.38 Disc: 0.25
Batch 400 | PSNR: 10.08 | SSIM: 0.581
Gen: 10.70 Disc: 0.25
Batch 500 | PSNR: 9.99 | SSIM: 0.579
Gen: 9.55 Disc: 0.25
Batch 600 | PSNR: 10.06 | SSIM: 0.580
Gen: 9.27 Disc: 0.25
Batch 700 | PSNR: 10.15 | SSIM: 0.582
Gen: 9.21 Disc: 0.25
Batch 800 | PSNR: 10.21 | SSIM: 0.583
Gen: 10.38 Disc: 0.25
Batch 900 | PSNR: 10.24 | SSIM: 0.585
Gen: 9.68 Disc: 0.25
Batch 1000 | PSNR: 10.23 | SSIM: 0.585
Gen: 11.65 Disc: 0.25
Batch 1100 | PSNR: 10.09 | SSIM: 0.579
Gen: 13.61 Disc: 0.25
Batch 1200 | PSNR: 10.08 | SSIM: 0.579
Gen: 9.30 Disc: 0.25
Batch 1300 | PSNR: 10.13 | SSIM: 0.581
Gen: 9.26 Disc: 0.25
Batch 1400 | PSNR: 10.18 | SSIM: 0.582
Gen: 9.93 Disc: 0.25
Batch 1500 | PSNR: 10.22 | SSIM: 0.584
Gen: 9.27 Disc: 0.25
Batch 1600 | PSNR: 10.27 | 

In [10]:
# ----------------------
# Inference Function
# ----------------------
def colorize_image(model, image_path, output_path, image_size=64):
    """
    Colorizes a single image using the trained generator.

    Args:
        model: Trained generator model
        image_path: Path to input grayscale/RGB image
        output_path: Path to save colorized image
        image_size: Size to resize image (must match model input)
    """
    # Load and preprocess image
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not load image at {image_path}")

    # Convert to RGB if needed
    if image.ndim == 2:  # Grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    else:  # BGR to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize and normalize
    image = cv2.resize(image, (image_size, image_size))
    image = image.astype(np.float32) / 255.0

    # Convert to LAB and extract L channel
    lab = rgb2lab(image)
    L = lab[:, :, 0:1]  # (H, W, 1)

    # Add batch dimension and predict
    L_batch = np.expand_dims(L, axis=0)  # (1, H, W, 1)
    AB_pred = model.predict(L_batch, verbose=0)[0]  # (H, W, 2)

    # Denormalize AB channels
    AB_pred = (AB_pred * 128.0).astype(np.float32)

    # Combine with L and convert to RGB
    colorized_lab = np.concatenate([L, AB_pred], axis=-1)
    colorized_rgb = lab2rgb(colorized_lab)

    # Clip and save
    colorized_rgb = np.clip(colorized_rgb, 0, 1)
    plt.imsave(output_path, colorized_rgb)
    print(f"Colorized image saved to {output_path}")

In [18]:
# After training, use like this:
colorize_image(
    generator,
    "/content/Test/rose.jpeg",  # Input path
    "/content/colorized_result.jpg"     # Output path
)

calculate_frame_psnr("/content/Test/rose.jpeg","/content/colorized_result.jpg")

Colorized image saved to /content/colorized_result.jpg


In [11]:
# ----------------------
# Parallel Video Colorization with Temporal Consistency
# ----------------------
import cv2
import numpy as np
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from skimage.color import rgb2lab, lab2rgb

# ----------------------
# Video Colorizer with Content Directory Temp Files
# ----------------------
import cv2
import numpy as np
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from skimage.color import rgb2lab, lab2rgb

class VideoColorizer:
    def __init__(self, model, temporal_alpha=0.8, blend_factor=0.7):
        self.model = model
        self.temporal_alpha = temporal_alpha
        self.blend_factor = blend_factor
        # Optical flow parameters
        self.flow_params = {
            'pyr_scale': 0.5,
            'levels': 3,
            'winsize': 15,
            'iterations': 3,
            'poly_n': 5,
            'poly_sigma': 1.2,
            'flags': cv2.OPTFLOW_FARNEBACK_GAUSSIAN
        }
        # Create temp directory in Colab's content directory
        self.temp_dir = tempfile.TemporaryDirectory(
            dir='/content',
            prefix='colorizer_temp_'
        )
        os.makedirs(self.temp_dir.name, exist_ok=True)
        os.chmod(self.temp_dir.name, 0o777)  # Ensure write permissions
        print(f"Created temp directory at: {self.temp_dir.name}")

    def _process_chunk(self, frames, start_idx):
        """Process a chunk of frames in parallel"""
        print(f"Processing chunk starting at {start_idx} with {len(frames)} frames")

        for local_idx, frame in enumerate(frames):
            global_idx = start_idx + local_idx
            save_path = os.path.join(
                self.temp_dir.name,
                f"frame_{global_idx:06d}.npy"  # Consistent naming
            )

            # Debug: Verify frame content
            if frame is None or frame.size == 0:
                raise ValueError(f"Invalid frame at index {global_idx}")

            # Colorization pipeline
            resized_rgb = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), (64, 64))
            lab = rgb2lab(resized_rgb.astype(np.float32)/255.0)
            L = lab[:, :, 0:1]
            AB = self.model.predict(np.expand_dims(L, axis=0), verbose=0)[0]

            # Save data with verification
            data = {
                'frame': frame,
                'L': L,
                'AB': AB,
                'gray': cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            }
            np.save(save_path, data)
            print(f"Saved frame {global_idx} to {save_path}")

        # Immediate verification of saved files
        saved_files = [f for f in os.listdir(self.temp_dir.name)
                      if f.startswith(f"frame_{start_idx:06d}")]
        print(f"Saved {len(saved_files)} files in this chunk")

        return True

    def _temporal_smooth(self, prev_data, current_data):
        """Apply temporal consistency between frames"""
        if prev_data is None:
            return current_data['AB']

        # Compute optical flow at model resolution
        target_size = (64, 64)
        prev_gray = cv2.resize(prev_data['gray'], target_size)
        current_gray = cv2.resize(current_data['gray'], target_size)

        flow = cv2.calcOpticalFlowFarneback(
            prev_gray, current_gray,
            None, **self.flow_params
        )

        # Create normalized coordinate grid
        h, w = target_size
        x_map, y_map = np.meshgrid(np.arange(w), np.arange(h))
        flow_map = np.stack([
           (x_map + flow[..., 0]).astype(np.float32),
           (y_map + flow[..., 1]).astype(np.float32)
        ], axis=-1)

        # Ensure coordinates stay within image bounds
        flow_map[..., 0] = np.clip(flow_map[..., 0], 0, w-1)
        flow_map[..., 1] = np.clip(flow_map[..., 1], 0, h-1)



        # Warp previous AB channels
        warped_AB = cv2.remap(
            prev_data['AB'].astype(np.float32),  # Ensure float32 input
            flow_map,
            None,
            cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT
        )

        # Handle invalid regions (black borders from warping)
        mask = (warped_AB == 0).all(axis=-1, keepdims=True)
        blended_AB = np.where(mask, current_data['AB'],
                         self.blend_factor * current_data['AB'] +
                         (1 - self.blend_factor) * warped_AB)

        smoothed_AB = self.temporal_alpha * blended_AB + \
                     (1 - self.temporal_alpha) * warped_AB
        return smoothed_AB


    def colorize_video(self, input_path, output_path, batch_size=16, workers=8):
        # Open video once for all processing
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f"Couldn't open video {input_path}")

        # Get video properties from the first frame
        ret, first_frame = cap.read()
        if not ret:
            cap.release()
            raise ValueError("Couldn't read first frame")
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Width: {frame_width} | Height: {frame_height} | FPS: {fps}")
        # Rewind to beginning
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

        # Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height), isColor=True)

        # Process frames in a single pass
        executor = ThreadPoolExecutor(max_workers=workers)
        futures = []
        chunk = []
        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            chunk.append(frame)
            frame_count += 1
            if len(chunk) == batch_size:
                future = executor.submit(self._process_chunk, chunk.copy(), frame_count - len(chunk))
                futures.append(future)
                chunk = []

        # Process remaining frames
        if chunk:
            future = executor.submit(self._process_chunk, chunk, frame_count - len(chunk))
            futures.append(future)

        # Wait for all processing to finish
        for f in futures:
            f.result()

        cap.release()


        # 4. Reconstruct video with temporal smoothing
        saved_files = sorted(
            [f for f in os.listdir(self.temp_dir.name)
             if f.startswith('frame_') and f.endswith('.npy')],
            key=lambda x: int(x.split('_')[1].split('.')[0]))

        prev_data = None
        for i, filename in enumerate(saved_files):
            file_path = os.path.join(self.temp_dir.name, filename)
            data = np.load(file_path, allow_pickle=True).item()

            # Apply temporal smoothing
            smoothed_AB = self._temporal_smooth(prev_data, data)

            # Reconstruct frame
            rgb_resized = cv2.resize(
                cv2.cvtColor(data['frame'], cv2.COLOR_BGR2RGB),
                (64, 64))
            lab = rgb2lab(rgb_resized.astype(np.float32)/255.0)
            final_lab = np.concatenate([lab[..., 0:1], smoothed_AB], axis=-1)

            # Convert to output dimensions
            colorized_rgb = (lab2rgb(final_lab) * 255).astype(np.uint8)
            final_frame = cv2.resize(
            colorized_rgb,
            (frame_width, frame_height))
            final_bgr = cv2.cvtColor(final_frame, cv2.COLOR_RGB2BGR)

            writer.write(final_bgr)

            prev_data = {'AB': smoothed_AB, 'gray': data['gray']}

            if i % 10 == 0:
                print(f"Processed {i+1}/{len(saved_files)} frames")

        # 5. Final cleanup
        writer.release()
        self.temp_dir.cleanup()

        # Verify output
        if os.path.exists(output_path):
            print(f"\n✅ Success! Colorized video saved to: {output_path}")
            print(f"Resolution: {frame_width}x{frame_height} | Frames: {len(saved_files)}")
        else:
            print("\n❌ Video creation failed - check codec compatibility")




In [None]:
colorizer = VideoColorizer(generator, temporal_alpha=0.85)


colorizer.colorize_video(
    "/content/drive/MyDrive/Colorization/Video/Input.mp4",
    "/content/drive/MyDrive/Colorization/Video/Output.mp4",
    batch_size=32,
    workers=8,
)

In [None]:
calculate_video_psnr("/content/drive/MyDrive/Colorization/Video/Input.mp4",
    "/content/drive/MyDrive/Colorization/Video/Output.mp4")

In [13]:
!ffprobe -v error -show_entries format=duration \
         -of default=noprint_wrappers=1:nokey=1 \
         "/content/drive/MyDrive/Colorization/Video/Input2.mp4"

8.600000


In [14]:
!ffprobe -v error -show_entries format=duration \
         -of default=noprint_wrappers=1:nokey=1 \
         "/content/drive/MyDrive/Colorization/Video/Output.mp4"

10.811000
