<a href="https://colab.research.google.com/github/Jess-Lau/Real-Life-B-W-Video-Colorization-Project/blob/main/ColorizerGANTrueFinalVer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ----------------------
# Colorization GAN
# ----------------------
import os
import cv2
import time
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Sequential
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb



In [16]:
# ------------------
# Configuration
# ------------------
IMAGE_SIZE = 64
CHANNELS = 1
EPOCHS = 70
BATCH_SIZE = 512
LAMBDA = 100
DATA_DIR = "/content/drive/MyDrive/ImageNet"  # Update with your path
WORKDIR = "/content/drive/MyDrive/Colorization"
CHECKPOINT_DIR = os.path.join(WORKDIR, "checkpoints2")
RESULTS_DIR = os.path.join(WORKDIR, "results2")

# Enable mixed precision with proper policy
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)


# Create directories
os.makedirs(WORKDIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

In [10]:
!nvidia-smi


Sun Apr  6 20:52:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             53W /  400W |    6593MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
# ------------------
# Data Pipeline
# ------------------
def load_mean(data_dir):
    """Load mean image from first training batch"""
    with open(os.path.join(data_dir, 'train_data_batch_1'), 'rb') as f:
        data = pickle.load(f)
        mean = data['mean'].astype(np.float32) / 255.0
        return mean.reshape(3, IMAGE_SIZE, IMAGE_SIZE).transpose(1, 2, 0)

def data_generator(data_dir, split='train'):
    mean = load_mean(data_dir) if split == 'train' else None
    files = [f'train_data_batch_{i}' for i in range(1, 11)] if split == 'train' else ['val_data']

    for file in files:
        path = os.path.join(data_dir, file)
        try:
            with open(path, 'rb') as f:
                data = pickle.load(f)
                x = data['data'].astype(np.float32) / 255.0
                x = x.reshape(-1, 3, IMAGE_SIZE, IMAGE_SIZE).transpose(0, 2, 3, 1)

                if mean is not None:
                    x -= mean

                for i in range(0, x.shape[0], BATCH_SIZE):
                    batch_rgb = x[i:i+BATCH_SIZE]
                    batch_lab = np.array([rgb2lab(img) for img in batch_rgb])
                    L = batch_lab[..., 0:1].astype(np.float32)
                    AB = (batch_lab[..., 1:] / 128.0).astype(np.float32)
                    yield L, AB
        except Exception as e:
            print(f"❌ Failed to load {path}: {str(e)}")
            continue  # Skip problematic files

def create_dataset(data_dir, split='train'):
    return tf.data.Dataset.from_generator(
        lambda: data_generator(data_dir, split),
        output_signature=(  # ✅ Proper parentheses
            tf.TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 64, 64, 2), dtype=tf.float32)
        )
    ).prefetch(tf.data.AUTOTUNE)  # ✅ .prefetch() called on dataset

In [4]:
# ------------------
# Model Architectures
# ------------------
def downsample(filters, size, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    model = Sequential()
    model.add(layers.Conv2D(filters, size, strides=2, padding='same',
                          kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(0.2))
    return model

def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    model = Sequential()
    model.add(layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                    kernel_initializer=initializer, use_bias=False))
    model.add(layers.BatchNormalization())
    if apply_dropout:
        model.add(layers.Dropout(0.5))
    model.add(layers.ReLU())
    return model

def build_generator():
    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))

    # Encoder
    d1 = downsample(64, 4, False)(inputs)    # 32x32
    d2 = downsample(128, 4)(d1)              # 16x16
    d3 = downsample(256, 4)(d2)              # 8x8
    d4 = downsample(512, 4)(d3)              # 4x4

    # Decoder
    u1 = upsample(512, 4, True)(d4)          # 8x8
    u1 = layers.Concatenate()([u1, d3])
    u2 = upsample(256, 4)(u1)                # 16x16
    u2 = layers.Concatenate()([u2, d2])
    u3 = upsample(128, 4)(u2)                # 32x32
    u3 = layers.Concatenate()([u3, d1])
    u4 = upsample(64, 4)(u3)                 # 64x64

    output = layers.Conv2D(2, 3, padding='same', activation='tanh')(u4)
    return Model(inputs, output)

def build_discriminator():
    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

    x = layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(128, 4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(256, 4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Flatten()(x)
    x = layers.Dense(1, activation='sigmoid')(x)
    return Model(inputs, x)

In [5]:
# ------------------
# Training Setup
# ------------------
generator = build_generator()
discriminator = build_discriminator()

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
    epoch=tf.Variable(0)
)
manager = tf.train.CheckpointManager(checkpoint, CHECKPOINT_DIR, max_to_keep=3)

In [6]:
# ------------------
# Training Utilities
# ------------------
def generate_images(model, test_input, epoch):
    input_L = test_input[0]  # ✅ Extract L channel
    target_AB = test_input[1]  # Ground truth AB

    # Predict using only L
    prediction = model(input_L, training=False)[0].numpy()
    L = input_L[0].numpy()[..., 0]  # Use first sample in batch

    plt.figure(figsize=(12, 4))

    # Input (grayscale)
    plt.subplot(1, 3, 1)
    plt.imshow(L, cmap='gray')
    plt.title("Input")
    plt.axis('off')

    # Ground truth (colorized)
    plt.subplot(1, 3, 2)
    true_rgb = lab2rgb(np.dstack((L, target_AB[0].numpy() * 128)))  # ✅ Use target_AB
    plt.imshow(true_rgb)
    plt.title("Ground Truth")
    plt.axis('off')

    # Predicted (colorized)
    plt.subplot(1, 3, 3)
    pred_rgb = lab2rgb(np.dstack((L, prediction * 128)))
    plt.imshow(pred_rgb)
    plt.title("Predicted")
    plt.axis('off')

    plt.savefig(os.path.join(RESULTS_DIR, f'epoch_{epoch+1}.png'))
    plt.close()


# --- Add PSNR/SSIM Calculations ---
# Convert LAB to RGB for metrics
def lab_to_rgb(lab):
    """Convert LAB tensor to RGB tensor (0-255 range) with mixed precision support"""
    # Ensure LAB tensor is float32 for stable calculations
    lab = tf.cast(lab, tf.float32)

    # Denormalize LAB
    L = lab[..., 0] * 100.0          # L: [0,100]
    ab = lab[..., 1:] * 128.0        # ab: [-128, 127]

    # Convert LAB to XYZ
    y = (L + 16.0) / 116.0
    x = ab[..., 0] / 500.0 + y
    z = y - ab[..., 1] / 200.0

    xyz = tf.stack([x, y, z], axis=-1)
    xyz = tf.where(xyz > 0.2068966, xyz**3, (xyz - 16.0/116.0)/7.787)

    # D65 reference white (cast to float32)
    xyz = xyz * tf.constant([95.047, 100.0, 108.883], dtype=tf.float32)

    # XYZ to RGB matrix
    rgb = tf.tensordot(xyz, tf.constant([
        [3.2406, -1.5372, -0.4986],
        [-0.9689, 1.8758, 0.0415],
        [0.0557, -0.2040, 1.0570]
    ], dtype=tf.float32), axes=1)

    # Gamma correction
    rgb = tf.where(rgb > 0.0031308,
                    1.055 * (rgb ** (1/2.4)) - 0.055,
                    12.92 * rgb)

    # Final conversion to float16 if needed
    return tf.cast(tf.clip_by_value(rgb * 255.0, 0.0, 255.0), tf.float16)

def safe_psnr(real, fake, max_val, eps=1e-10):
    # Cast both tensors to float32 to ensure dtype consistency
    real = tf.cast(real, tf.float32)
    fake = tf.cast(fake, tf.float32)
    mse = tf.reduce_mean(tf.square(real - fake))
    return 20 * tf.math.log(max_val) / tf.math.log(10.0) - 10 * tf.math.log(mse + eps) / tf.math.log(10.0)

@tf.function
def train_step(input_L, input_AB):
    # Cast to mixed precision
    input_L = tf.cast(input_L, tf.float16)
    input_AB = tf.cast(input_AB, tf.float16)

    with tf.GradientTape(persistent=True) as tape:
        generated_AB = generator(input_L, training=True)

        # Create concatenated images
        real_images = tf.concat([input_L, input_AB], axis=-1)
        fake_images = tf.concat([input_L, generated_AB], axis=-1)

        # Discriminator outputs
        disc_real = discriminator(real_images, training=True)
        disc_fake = discriminator(fake_images, training=True)

        # Loss calculations
        gen_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(disc_fake), disc_fake) + LAMBDA * tf.reduce_mean(tf.abs(input_AB - generated_AB))
        disc_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(disc_real), disc_real) + tf.keras.losses.binary_crossentropy(
            tf.zeros_like(disc_fake), disc_fake)



    # Ground truth RGB (from original LAB)
    lab_real = tf.concat([input_L, input_AB], axis=-1)
    rgb_real = lab_to_rgb(lab_real)

    # Generated RGB
    lab_fake = tf.concat([input_L, generated_AB], axis=-1)
    rgb_fake = lab_to_rgb(lab_fake)

    # Calculate metrics
    psnr = safe_psnr(rgb_real, rgb_fake, max_val=255.0)
    ssim = tf.image.ssim(rgb_real, rgb_fake, max_val=255.0)

    # Apply gradient clipping
    gen_grads = tape.gradient(gen_loss, generator.trainable_variables)
    gen_grads = [tf.clip_by_norm(g, 1.0) for g in gen_grads]
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))

    disc_grads = tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_grads = [tf.clip_by_norm(g, 1.0) for g in disc_grads]
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    return tf.reduce_mean(gen_loss), tf.reduce_mean(disc_loss), psnr, ssim




In [7]:
from skimage.metrics import peak_signal_noise_ratio
import cv2
import numpy as np

def calculate_frame_psnr(original, colorized):
    """
    Calculate PSNR for a single frame pair.
    Args:
        original: Ground truth frame (BGR)
        colorized: Colorized frame (BGR)
    Returns:
        PSNR value in dB
    """
    # Convert to YCrCb for luminance comparison (optional)
    original_yuv = cv2.cvtColor(original, cv2.COLOR_BGR2YCrCb)
    colorized_yuv = cv2.cvtColor(colorized, cv2.COLOR_BGR2YCrCb)

    # Calculate PSNR for each channel
    psnr_y = peak_signal_noise_ratio(original_yuv[...,0], colorized_yuv[...,0], data_range=255)
    psnr_cr = peak_signal_noise_ratio(original_yuv[...,1], colorized_yuv[...,1], data_range=255)
    psnr_cb = peak_signal_noise_ratio(original_yuv[...,2], colorized_yuv[...,2], data_range=255)

    return np.mean([psnr_y, psnr_cr, psnr_cb])

def calculate_video_psnr(original_video_path, colorized_video_path):
    """
    Calculate average PSNR between two videos.
    Returns:
        Mean PSNR (dB), Frame-wise PSNR array
    """
    cap_orig = cv2.VideoCapture(original_video_path)
    cap_color = cv2.VideoCapture(colorized_video_path)

    psnr_values = []

    while True:
        ret_orig, frame_orig = cap_orig.read()
        ret_color, frame_color = cap_color.read()

        if not ret_orig or not ret_color:
            break

        # Resize if necessary (match dimensions)
        if frame_orig.shape != frame_color.shape:
            frame_color = cv2.resize(frame_color, (frame_orig.shape[1], frame_orig.shape[0]))

        psnr = calculate_frame_psnr(frame_orig, frame_color)
        psnr_values.append(psnr)

    cap_orig.release()
    cap_color.release()

    return np.mean(psnr_values), psnr_values

In [8]:
# ------------------
# Training Loop
# ------------------
def train():
    train_dataset = create_dataset(DATA_DIR, 'train')
    val_dataset = create_dataset(DATA_DIR, 'val')

    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"Resumed from epoch {checkpoint.epoch.numpy()}")

    # Initialize metrics
    psnr_metric = tf.keras.metrics.Mean(name='psnr')
    ssim_metric = tf.keras.metrics.Mean(name='ssim')

    for epoch in range(checkpoint.epoch.numpy(), EPOCHS):
        start = time.time()
        gen_losses, disc_losses = [], []
        # Reset metrics each epoch (CORRECTED METHOD NAME)
        psnr_metric.reset_state()
        ssim_metric.reset_state()

        # Training phase
        for batch, (L, AB) in enumerate(train_dataset):
            gen_loss, disc_loss, psnr, ssim = train_step(L, AB)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

            # Update metrics
            psnr_metric.update_state(psnr)
            ssim_metric.update_state(ssim)


            # Existing logging
            if batch % 100 == 0:
                gen_loss_val = gen_loss.numpy().item()
                disc_loss_val = disc_loss.numpy().item()
                print(f"Batch {batch} | PSNR: {psnr_metric.result():.2f} | SSIM: {ssim_metric.result():.3f}")
                print(f"Gen: {gen_loss_val:.2f} Disc: {disc_loss_val:.2f}")


        # In your training loop after the epoch's training phase:
        if (epoch + 1) % 5 == 0:
            manager.save()
            test_batch = next(iter(val_dataset))
            generate_images(generator, test_batch, epoch)

        # Validation phase
        val_psnr = []
        val_ssim = []
        for val_L, val_AB in val_dataset.take(10):
           val_gen_AB = generator(val_L, training=False)

        # Convert to float32 before concatenation
           val_real_lab = tf.concat([
               tf.cast(val_L, tf.float32),
               tf.cast(val_AB, tf.float32)
           ], axis=-1)
           val_real_rgb = lab_to_rgb(val_real_lab)

           val_fake_lab = tf.concat([
               tf.cast(val_L, tf.float32),
               tf.cast(val_gen_AB, tf.float32)
           ], axis=-1)
           val_fake_rgb = lab_to_rgb(val_fake_lab)

        # Calculate metrics
        batch_psnr = tf.reduce_mean(tf.image.psnr(val_real_rgb, val_fake_rgb, max_val=255))
        batch_ssim = tf.reduce_mean(tf.image.ssim(val_real_rgb, val_fake_rgb, max_val=255))

        val_psnr.append(batch_psnr.numpy())
        val_ssim.append(batch_ssim.numpy())

        # Epoch summary
        print(f"\nEpoch {epoch+1}")
        print(f"Time: {time.time()-start:.2f}s")
        print(f"Gen Loss: {np.mean(gen_losses):.4f}")
        print(f"Disc Loss: {np.mean(disc_losses):.4f}\n")
        print(f"Train PSNR: {psnr_metric.result():.2f} dB")
        print(f"Train SSIM: {ssim_metric.result():.4f}")
        print(f"Val PSNR: {np.mean(val_psnr):.2f} dB")
        print(f"Val SSIM: {np.mean(val_ssim):.4f}")

        checkpoint.epoch.assign_add(1)


In [None]:
if __name__ == "__main__":
    train()

Resumed from epoch 59
Batch 0 | PSNR: 24.42 | SSIM: 0.980
Gen: 9.77 Disc: 1.40
Batch 100 | PSNR: 24.98 | SSIM: 0.981
Gen: 9.94 Disc: 1.33
Batch 200 | PSNR: 24.89 | SSIM: 0.981
Gen: 10.27 Disc: 1.36
Batch 300 | PSNR: 24.92 | SSIM: 0.981
Gen: 9.96 Disc: 1.36
Batch 400 | PSNR: 24.94 | SSIM: 0.981
Gen: 10.21 Disc: 1.38
Batch 500 | PSNR: 24.94 | SSIM: 0.981
Gen: 10.34 Disc: 1.38
Batch 600 | PSNR: 24.96 | SSIM: 0.981
Gen: 10.09 Disc: 1.39
Batch 700 | PSNR: 24.97 | SSIM: 0.981
Gen: 10.32 Disc: 1.35
Batch 800 | PSNR: 24.98 | SSIM: 0.981
Gen: 10.57 Disc: 1.30
Batch 900 | PSNR: 24.98 | SSIM: 0.981
Gen: 10.14 Disc: 1.38
Batch 1000 | PSNR: 24.98 | SSIM: 0.981
Gen: 10.67 Disc: 1.33
Batch 1100 | PSNR: 24.98 | SSIM: 0.981
Gen: 10.02 Disc: 1.34
Batch 1200 | PSNR: 24.98 | SSIM: 0.981
Gen: 10.52 Disc: 1.42
Batch 1300 | PSNR: 24.98 | SSIM: 0.981
Gen: 10.12 Disc: 1.35
Batch 1400 | PSNR: 24.97 | SSIM: 0.981
Gen: 10.63 Disc: 1.33
Batch 1500 | PSNR: 24.97 | SSIM: 0.981
Gen: 9.86 Disc: 1.33
Batch 1600 | PSNR:

In [None]:
# ----------------------
# Inference Function
# ----------------------
def colorize_image(model, image_path, output_path, image_size=64):
    """
    Colorizes a single image using the trained generator.

    Args:
        model: Trained generator model
        image_path: Path to input grayscale/RGB image
        output_path: Path to save colorized image
        image_size: Size to resize image (must match model input)
    """
    # Load and preprocess image
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not load image at {image_path}")

    # Convert to RGB if needed
    if image.ndim == 2:  # Grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    else:  # BGR to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize and normalize
    image = cv2.resize(image, (image_size, image_size))
    image = image.astype(np.float32) / 255.0

    # Convert to LAB and extract L channel
    lab = rgb2lab(image)
    L = lab[:, :, 0:1]  # (H, W, 1)

    # Add batch dimension and predict
    L_batch = np.expand_dims(L, axis=0)  # (1, H, W, 1)
    AB_pred = model.predict(L_batch, verbose=0)[0]  # (H, W, 2)

    # Denormalize AB channels
    AB_pred = (AB_pred * 128.0).astype(np.float32)

    # Combine with L and convert to RGB
    colorized_lab = np.concatenate([L, AB_pred], axis=-1)
    colorized_rgb = lab2rgb(colorized_lab)

    # Clip and save
    colorized_rgb = np.clip(colorized_rgb, 0, 1)
    plt.imsave(output_path, colorized_rgb)
    print(f"Colorized image saved to {output_path}")

In [None]:
# After training, use like this:
colorize_image(
    generator,
    "/content/Test/rose.jpeg",  # Input path
    "/content/colorized_result.jpg"     # Output path
)

Colorized image saved to /content/colorized_result.jpg


In [None]:
# ----------------------
# Parallel Video Colorization with Temporal Consistency
# ----------------------
import cv2
import numpy as np
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from skimage.color import rgb2lab, lab2rgb

# ----------------------
# Video Colorizer with Content Directory Temp Files
# ----------------------
import cv2
import numpy as np
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from skimage.color import rgb2lab, lab2rgb

class VideoColorizer:
    def __init__(self, model, temporal_alpha=0.8, blend_factor=0.7):
        self.model = model
        self.temporal_alpha = temporal_alpha
        self.blend_factor = blend_factor
        # Optical flow parameters
        self.flow_params = {
            'pyr_scale': 0.5,
            'levels': 3,
            'winsize': 15,
            'iterations': 3,
            'poly_n': 5,
            'poly_sigma': 1.2,
            'flags': cv2.OPTFLOW_FARNEBACK_GAUSSIAN
        }
        # Create temp directory in Colab's content directory
        self.temp_dir = tempfile.TemporaryDirectory(
            dir='/content',
            prefix='colorizer_temp_'
        )
        os.makedirs(self.temp_dir.name, exist_ok=True)
        os.chmod(self.temp_dir.name, 0o777)  # Ensure write permissions
        print(f"Created temp directory at: {self.temp_dir.name}")

    def _process_chunk(self, frames, start_idx):
        """Process a chunk of frames in parallel"""
        print(f"Processing chunk starting at {start_idx} with {len(frames)} frames")

        for local_idx, frame in enumerate(frames):
            global_idx = start_idx + local_idx
            save_path = os.path.join(
                self.temp_dir.name,
                f"frame_{global_idx:06d}.npy"  # Consistent naming
            )

            # Debug: Verify frame content
            if frame is None or frame.size == 0:
                raise ValueError(f"Invalid frame at index {global_idx}")

            # Colorization pipeline
            resized_rgb = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), (64, 64))
            lab = rgb2lab(resized_rgb.astype(np.float32)/255.0)
            L = lab[:, :, 0:1]
            AB = self.model.predict(np.expand_dims(L, axis=0), verbose=0)[0]

            # Save data with verification
            data = {
                'frame': frame,
                'L': L,
                'AB': AB,
                'gray': cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            }
            np.save(save_path, data)
            print(f"Saved frame {global_idx} to {save_path}")

        # Immediate verification of saved files
        saved_files = [f for f in os.listdir(self.temp_dir.name)
                      if f.startswith(f"frame_{start_idx:06d}")]
        print(f"Saved {len(saved_files)} files in this chunk")

        return True

    def _temporal_smooth(self, prev_data, current_data):
        """Apply temporal consistency between frames"""
        if prev_data is None:
            return current_data['AB']

        # Compute optical flow at model resolution
        target_size = (64, 64)
        prev_gray = cv2.resize(prev_data['gray'], target_size)
        current_gray = cv2.resize(current_data['gray'], target_size)

        flow = cv2.calcOpticalFlowFarneback(
            prev_gray, current_gray,
            None, **self.flow_params
        )

        # Create normalized coordinate grid
        h, w = target_size
        x_map, y_map = np.meshgrid(np.arange(w), np.arange(h))
        flow_map = np.stack([
           (x_map + flow[..., 0]).astype(np.float32),
           (y_map + flow[..., 1]).astype(np.float32)
        ], axis=-1)

        # Ensure coordinates stay within image bounds
        flow_map[..., 0] = np.clip(flow_map[..., 0], 0, w-1)
        flow_map[..., 1] = np.clip(flow_map[..., 1], 0, h-1)



        # Warp previous AB channels
        warped_AB = cv2.remap(
            prev_data['AB'].astype(np.float32),  # Ensure float32 input
            flow_map,
            None,
            cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT
        )

        # Handle invalid regions (black borders from warping)
        mask = (warped_AB == 0).all(axis=-1, keepdims=True)
        blended_AB = np.where(mask, current_data['AB'],
                         self.blend_factor * current_data['AB'] +
                         (1 - self.blend_factor) * warped_AB)

        smoothed_AB = self.temporal_alpha * blended_AB + \
                     (1 - self.temporal_alpha) * warped_AB
        return smoothed_AB


    def colorize_video(self, input_path, output_path, batch_size=16, workers=8):
        # Open video once for all processing
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f"Couldn't open video {input_path}")

        # Get video properties from the first frame
        ret, first_frame = cap.read()
        if not ret:
            cap.release()
            raise ValueError("Couldn't read first frame")
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Width: {frame_width} | Height: {frame_height} | FPS: {fps}")
        # Rewind to beginning
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

        # Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height), isColor=True)

        # Process frames in a single pass
        executor = ThreadPoolExecutor(max_workers=workers)
        futures = []
        chunk = []
        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            chunk.append(frame)
            frame_count += 1
            if len(chunk) == batch_size:
                future = executor.submit(self._process_chunk, chunk.copy(), frame_count - len(chunk))
                futures.append(future)
                chunk = []

        # Process remaining frames
        if chunk:
            future = executor.submit(self._process_chunk, chunk, frame_count - len(chunk))
            futures.append(future)

        # Wait for all processing to finish
        for f in futures:
            f.result()

        cap.release()


        # 4. Reconstruct video with temporal smoothing
        saved_files = sorted(
            [f for f in os.listdir(self.temp_dir.name)
             if f.startswith('frame_') and f.endswith('.npy')],
            key=lambda x: int(x.split('_')[1].split('.')[0]))

        prev_data = None
        for i, filename in enumerate(saved_files):
            file_path = os.path.join(self.temp_dir.name, filename)
            data = np.load(file_path, allow_pickle=True).item()

            # Apply temporal smoothing
            smoothed_AB = self._temporal_smooth(prev_data, data)

            # Reconstruct frame
            rgb_resized = cv2.resize(
                cv2.cvtColor(data['frame'], cv2.COLOR_BGR2RGB),
                (64, 64))
            lab = rgb2lab(rgb_resized.astype(np.float32)/255.0)
            final_lab = np.concatenate([lab[..., 0:1], smoothed_AB], axis=-1)

            # Convert to output dimensions
            colorized_rgb = (lab2rgb(final_lab) * 255).astype(np.uint8)
            final_frame = cv2.resize(
            colorized_rgb,
            (frame_width, frame_height))
            final_bgr = cv2.cvtColor(final_frame, cv2.COLOR_RGB2BGR)

            writer.write(final_bgr)

            prev_data = {'AB': smoothed_AB, 'gray': data['gray']}

            if i % 10 == 0:
                print(f"Processed {i+1}/{len(saved_files)} frames")

        # 5. Final cleanup
        writer.release()
        self.temp_dir.cleanup()

        # Verify output
        if os.path.exists(output_path):
            print(f"\n✅ Success! Colorized video saved to: {output_path}")
            print(f"Resolution: {frame_width}x{frame_height} | Frames: {len(saved_files)}")
        else:
            print("\n❌ Video creation failed - check codec compatibility")




In [None]:
colorizer = VideoColorizer(generator, temporal_alpha=0.85)


colorizer.colorize_video(
    "/content/drive/MyDrive/Colorization/Video/Input.mp4",
    "/content/drive/MyDrive/Colorization/Video/Output.mp4",
    batch_size=32,
    workers=8,
)

Created temp directory at: /content/colorizer_temp_qa9brar1
Width: 1280 | Height: 720 | FPS: 29.97002997002997
Processing chunk starting at 0 with 32 frames
Processing chunk starting at 32 with 32 frames
Processing chunk starting at 64 with 32 frames
Processing chunk starting at 96 with 32 frames
Processing chunk starting at 128 with 32 frames
Processing chunk starting at 160 with 32 frames
Processing chunk starting at 192 with 32 frames
Processing chunk starting at 224 with 32 frames
Saved frame 224 to /content/colorizer_temp_qa9brar1/frame_000224.npy
Saved frame 160 to /content/colorizer_temp_qa9brar1/frame_000160.npy
Saved frame 0 to /content/colorizer_temp_qa9brar1/frame_000000.npy
Saved frame 96 to /content/colorizer_temp_qa9brar1/frame_000096.npy
Saved frame 32 to /content/colorizer_temp_qa9brar1/frame_000032.npy
Saved frame 192 to /content/colorizer_temp_qa9brar1/frame_000192.npy
Saved frame 64 to /content/colorizer_temp_qa9brar1/frame_000064.npy
Saved frame 128 to /content/colo

In [None]:
!ffprobe -v error -show_entries format=duration \
         -of default=noprint_wrappers=1:nokey=1 \
         "/content/drive/MyDrive/Colorization/Video/Input2.mp4"

8.600000


In [None]:
!ffprobe -v error -show_entries format=duration \
         -of default=noprint_wrappers=1:nokey=1 \
         "/content/drive/MyDrive/Colorization/Video/Output.mp4"

10.811000
