<a href="https://colab.research.google.com/github/Jess-Lau/Real-Life-B-W-Video-Colorization-Project/blob/main/ColorizerGANV5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
# ----------------------
# Image Colorization GAN
# ----------------------
import os
import cv2
import time
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Sequential
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb



In [7]:
# ------------------
# Configuration
# ------------------
IMAGE_SIZE = 64
CHANNELS = 1
EPOCHS = 70
BATCH_SIZE = 256
LAMBDA = 100
DATA_DIR = "/content/drive/MyDrive/ImageNet"  # Update with your path
WORKDIR = "/content/drive/MyDrive/Colorization"
CHECKPOINT_DIR = os.path.join(WORKDIR, "checkpoints")
RESULTS_DIR = os.path.join(WORKDIR, "results")

# Enable mixed precision
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Create directories
os.makedirs(WORKDIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

In [3]:
# ------------------
# Data Pipeline
# ------------------
def load_mean(data_dir):
    """Load mean image from first training batch"""
    with open(os.path.join(data_dir, 'train_data_batch_1'), 'rb') as f:
        data = pickle.load(f)
        mean = data['mean'].astype(np.float32) / 255.0
        return mean.reshape(3, IMAGE_SIZE, IMAGE_SIZE).transpose(1, 2, 0)

def data_generator(data_dir, split='train'):
    mean = load_mean(data_dir) if split == 'train' else None
    files = [f'train_data_batch_{i}' for i in range(1, 11)] if split == 'train' else ['val_data']

    for file in files:
        path = os.path.join(data_dir, file)
        try:
            with open(path, 'rb') as f:
                data = pickle.load(f)
                x = data['data'].astype(np.float32) / 255.0
                x = x.reshape(-1, 3, IMAGE_SIZE, IMAGE_SIZE).transpose(0, 2, 3, 1)

                if mean is not None:
                    x -= mean

                for i in range(0, x.shape[0], BATCH_SIZE):
                    batch_rgb = x[i:i+BATCH_SIZE]
                    batch_lab = np.array([rgb2lab(img) for img in batch_rgb])
                    L = batch_lab[..., 0:1].astype(np.float32)
                    AB = (batch_lab[..., 1:] / 128.0).astype(np.float32)
                    yield L, AB
        except Exception as e:
            print(f"❌ Failed to load {path}: {str(e)}")
            continue  # Skip problematic files

def create_dataset(data_dir, split='train'):
    return tf.data.Dataset.from_generator(
        lambda: data_generator(data_dir, split),
        output_signature=(  # ✅ Proper parentheses
            tf.TensorSpec(shape=(None, 64, 64, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 64, 64, 2), dtype=tf.float32)
        )
    ).prefetch(tf.data.AUTOTUNE)  # ✅ .prefetch() called on dataset

In [4]:
# ------------------
# Model Architectures
# ------------------
def downsample(filters, size, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    model = Sequential()
    model.add(layers.Conv2D(filters, size, strides=2, padding='same',
                          kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(0.2))
    return model

def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    model = Sequential()
    model.add(layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                    kernel_initializer=initializer, use_bias=False))
    model.add(layers.BatchNormalization())
    if apply_dropout:
        model.add(layers.Dropout(0.5))
    model.add(layers.ReLU())
    return model

def build_generator():
    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))

    # Encoder
    d1 = downsample(64, 4, False)(inputs)    # 32x32
    d2 = downsample(128, 4)(d1)              # 16x16
    d3 = downsample(256, 4)(d2)              # 8x8
    d4 = downsample(512, 4)(d3)              # 4x4

    # Decoder
    u1 = upsample(512, 4, True)(d4)          # 8x8
    u1 = layers.Concatenate()([u1, d3])
    u2 = upsample(256, 4)(u1)                # 16x16
    u2 = layers.Concatenate()([u2, d2])
    u3 = upsample(128, 4)(u2)                # 32x32
    u3 = layers.Concatenate()([u3, d1])
    u4 = upsample(64, 4)(u3)                 # 64x64

    output = layers.Conv2D(2, 3, padding='same', activation='tanh')(u4)
    return Model(inputs, output)

def build_discriminator():
    inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

    x = layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(128, 4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(256, 4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Flatten()(x)
    x = layers.Dense(1, activation='sigmoid')(x)
    return Model(inputs, x)

In [5]:
# ------------------
# Training Setup
# ------------------
generator = build_generator()
discriminator = build_discriminator()

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
    epoch=tf.Variable(0)
)
manager = tf.train.CheckpointManager(checkpoint, CHECKPOINT_DIR, max_to_keep=3)

In [6]:
# ------------------
# Training Utilities
# ------------------
def generate_images(model, test_input, epoch):
    input_L = test_input[0]  # ✅ Extract L channel
    target_AB = test_input[1]  # Ground truth AB

    # Predict using only L
    prediction = model(input_L, training=False)[0].numpy()
    L = input_L[0].numpy()[..., 0]  # Use first sample in batch

    plt.figure(figsize=(12, 4))

    # Input (grayscale)
    plt.subplot(1, 3, 1)
    plt.imshow(L, cmap='gray')
    plt.title("Input")
    plt.axis('off')

    # Ground truth (colorized)
    plt.subplot(1, 3, 2)
    true_rgb = lab2rgb(np.dstack((L, target_AB[0].numpy() * 128)))  # ✅ Use target_AB
    plt.imshow(true_rgb)
    plt.title("Ground Truth")
    plt.axis('off')

    # Predicted (colorized)
    plt.subplot(1, 3, 3)
    pred_rgb = lab2rgb(np.dstack((L, prediction * 128)))
    plt.imshow(pred_rgb)
    plt.title("Predicted")
    plt.axis('off')

    plt.savefig(os.path.join(RESULTS_DIR, f'epoch_{epoch+1}.png'))
    plt.close()

@tf.function
def train_step(input_L, input_AB):
    # Cast to mixed precision
    input_L = tf.cast(input_L, tf.float16)
    input_AB = tf.cast(input_AB, tf.float16)

    with tf.GradientTape(persistent=True) as tape:
        generated_AB = generator(input_L, training=True)

        # Create concatenated images
        real_images = tf.concat([input_L, input_AB], axis=-1)
        fake_images = tf.concat([input_L, generated_AB], axis=-1)

        # Discriminator outputs
        disc_real = discriminator(real_images, training=True)
        disc_fake = discriminator(fake_images, training=True)

        # Loss calculations
        gen_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(disc_fake), disc_fake) + LAMBDA * tf.reduce_mean(tf.abs(input_AB - generated_AB))
        disc_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(disc_real), disc_real) + tf.keras.losses.binary_crossentropy(
            tf.zeros_like(disc_fake), disc_fake)

    # Apply gradient clipping
    gen_grads = tape.gradient(gen_loss, generator.trainable_variables)
    gen_grads = [tf.clip_by_norm(g, 1.0) for g in gen_grads]
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))

    disc_grads = tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_grads = [tf.clip_by_norm(g, 1.0) for g in disc_grads]
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    return tf.reduce_mean(gen_loss), tf.reduce_mean(disc_loss)




In [8]:
# ------------------
# Training Loop
# ------------------
def train():
    train_dataset = create_dataset(DATA_DIR, 'train')
    val_dataset = create_dataset(DATA_DIR, 'val')

    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"Resumed from epoch {checkpoint.epoch.numpy()}")

    for epoch in range(checkpoint.epoch.numpy(), EPOCHS):
        start = time.time()
        gen_losses, disc_losses = [], []

        for batch, (L, AB) in enumerate(train_dataset):
            gen_loss, disc_loss = train_step(L, AB)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

            if batch % 100 == 0:
                gen_loss_val = gen_loss.numpy().item()
                disc_loss_val = disc_loss.numpy().item()
                print(f"Epoch {epoch+1} Batch {batch} | Gen: {gen_loss_val:.2f} Disc: {disc_loss_val:.2f}")
                tf.keras.backend.clear_session()

        if (epoch + 1) % 5 == 0:
            manager.save()
            test_batch = next(iter(val_dataset))
            generate_images(generator, test_batch, epoch)

        print(f"\nEpoch {epoch+1}/{EPOCHS}")
        print(f"Time: {time.time()-start:.2f}s")
        print(f"Gen Loss: {np.mean(gen_losses):.4f}")
        print(f"Disc Loss: {np.mean(disc_losses):.4f}\n")
        checkpoint.epoch.assign_add(1)

In [9]:
if __name__ == "__main__":
    train()

Resumed from epoch 69


  nparray = values.astype(dtype.as_numpy_dtype)


Epoch 70 Batch 0 | Gen: 9.85 Disc: 1.39
Epoch 70 Batch 100 | Gen: 10.69 Disc: 1.14
Epoch 70 Batch 200 | Gen: 9.63 Disc: 1.33
Epoch 70 Batch 300 | Gen: 9.67 Disc: 1.32
Epoch 70 Batch 400 | Gen: 9.84 Disc: 1.12
Epoch 70 Batch 500 | Gen: 10.32 Disc: 1.16
Epoch 70 Batch 600 | Gen: 9.74 Disc: 1.30
Epoch 70 Batch 700 | Gen: 10.37 Disc: 1.17
Epoch 70 Batch 800 | Gen: 9.73 Disc: 1.37
Epoch 70 Batch 900 | Gen: 9.32 Disc: 1.54
Epoch 70 Batch 1000 | Gen: 10.48 Disc: 1.23
Epoch 70 Batch 1100 | Gen: 9.77 Disc: 1.17
Epoch 70 Batch 1200 | Gen: 9.82 Disc: 1.30
Epoch 70 Batch 1300 | Gen: 10.13 Disc: 1.30
Epoch 70 Batch 1400 | Gen: 10.53 Disc: 1.03
Epoch 70 Batch 1500 | Gen: 9.41 Disc: 1.22
Epoch 70 Batch 1600 | Gen: 10.32 Disc: 1.33
Epoch 70 Batch 1700 | Gen: 10.17 Disc: 1.23
Epoch 70 Batch 1800 | Gen: 10.02 Disc: 1.28
Epoch 70 Batch 1900 | Gen: 10.55 Disc: 1.38
Epoch 70 Batch 2000 | Gen: 10.23 Disc: 1.25
Epoch 70 Batch 2100 | Gen: 10.34 Disc: 1.32
Epoch 70 Batch 2200 | Gen: 9.78 Disc: 1.13
Epoch 70 Ba

  pred_rgb = lab2rgb(np.dstack((L, prediction * 128)))



Epoch 70/70
Time: 1357.40s
Gen Loss: 10.1016
Disc Loss: 1.2832



In [11]:
# ----------------------
# Inference Function
# ----------------------
def colorize_image(model, image_path, output_path, image_size=64):
    """
    Colorizes a single image using the trained generator.

    Args:
        model: Trained generator model
        image_path: Path to input grayscale/RGB image
        output_path: Path to save colorized image
        image_size: Size to resize image (must match model input)
    """
    # Load and preprocess image
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not load image at {image_path}")

    # Convert to RGB if needed
    if image.ndim == 2:  # Grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    else:  # BGR to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize and normalize
    image = cv2.resize(image, (image_size, image_size))
    image = image.astype(np.float32) / 255.0

    # Convert to LAB and extract L channel
    lab = rgb2lab(image)
    L = lab[:, :, 0:1]  # (H, W, 1)

    # Add batch dimension and predict
    L_batch = np.expand_dims(L, axis=0)  # (1, H, W, 1)
    AB_pred = model.predict(L_batch, verbose=0)[0]  # (H, W, 2)

    # Denormalize AB channels
    AB_pred = (AB_pred * 128.0).astype(np.float32)

    # Combine with L and convert to RGB
    colorized_lab = np.concatenate([L, AB_pred], axis=-1)
    colorized_rgb = lab2rgb(colorized_lab)

    # Clip and save
    colorized_rgb = np.clip(colorized_rgb, 0, 1)
    plt.imsave(output_path, colorized_rgb)
    print(f"Colorized image saved to {output_path}")

In [None]:
# After training, use like this:
colorize_image(
    generator,
    "/content/my_grayscale_image.jpg",  # Input path
    "/content/colorized_result.jpg"     # Output path
)

In [57]:
# ----------------------
# Parallel Video Colorization with Temporal Consistency
# ----------------------
import cv2
import numpy as np
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from skimage.color import rgb2lab, lab2rgb

# ----------------------
# Video Colorizer with Content Directory Temp Files
# ----------------------
import cv2
import numpy as np
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from skimage.color import rgb2lab, lab2rgb

class VideoColorizer:
    def __init__(self, model, temporal_alpha=0.8, blend_factor=0.7):
        self.model = model
        self.temporal_alpha = temporal_alpha
        self.blend_factor = blend_factor
        # Optical flow parameters
        self.flow_params = {
            'pyr_scale': 0.5,
            'levels': 3,
            'winsize': 15,
            'iterations': 3,
            'poly_n': 5,
            'poly_sigma': 1.2,
            'flags': cv2.OPTFLOW_FARNEBACK_GAUSSIAN
        }
        # Create temp directory in Colab's content directory
        self.temp_dir = tempfile.TemporaryDirectory(
            dir='/content',
            prefix='colorizer_temp_'
        )
        print(f"Temporary directory created at: {self.temp_dir.name}")
        os.sync()  # Flush file buffers



    def _process_chunk(self, frames, start_idx):
        """Process a chunk of frames in parallel"""
        try:
            print(f"Processing chunk starting at index {start_idx}")
            os.makedirs(self.temp_dir.name, exist_ok=True)

            for idx, frame in enumerate(frames):
                frame_idx = start_idx + idx
                save_path = os.path.join(
                    self.temp_dir.name,
                    f"frame_{frame_idx:06d}.npy"  # Consistent "frame_" prefix
                )

                # Colorization pipeline
                resized_rgb = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), (64, 64))
                lab = rgb2lab(resized_rgb.astype(np.float32)/255.0)
                L = lab[:, :, 0:1]
                AB = self.model.predict(np.expand_dims(L, axis=0), verbose=0)[0]

                # Save with explicit path
                data = {
                    'frame': frame,
                    'L': L,
                    'AB': AB,
                    'gray': cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                }
                np.save(save_path, data)
                print(f"Saved frame {frame_idx} to {save_path}")

            return True
        except Exception as e:
            print(f"Error in chunk {start_idx}: {str(e)}")
            raise

    def _temporal_smooth(self, prev_data, current_data):
        """Apply temporal consistency between frames"""
        if prev_data is None:
            return current_data['AB']

        # Compute optical flow at model resolution
        target_size = (64, 64)
        prev_gray = cv2.resize(prev_data['gray'], target_size)
        current_gray = cv2.resize(current_data['gray'], target_size)

        flow = cv2.calcOpticalFlowFarneback(
            prev_gray, current_gray,
            None, **self.flow_params
        )

        # Create normalized coordinate grid
        h, w = target_size
        x_map, y_map = np.meshgrid(np.arange(w), np.arange(h))
        flow_map = np.stack([
           (x_map + flow[..., 0]).astype(np.float32),
           (y_map + flow[..., 1]).astype(np.float32)
        ], axis=-1)

        # Ensure coordinates stay within image bounds
        flow_map[..., 0] = np.clip(flow_map[..., 0], 0, w-1)
        flow_map[..., 1] = np.clip(flow_map[..., 1], 0, h-1)



        # Warp previous AB channels
        warped_AB = cv2.remap(
            prev_data['AB'].astype(np.float32),  # Ensure float32 input
            flow_map,
            None,
            cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT
        )

        # Handle invalid regions (black borders from warping)
        mask = (warped_AB == 0).all(axis=-1, keepdims=True)
        blended_AB = np.where(mask, current_data['AB'],
                         self.blend_factor * current_data['AB'] +
                         (1 - self.blend_factor) * warped_AB)

        smoothed_AB = self.temporal_alpha * blended_AB + \
                     (1 - self.temporal_alpha) * warped_AB
        return smoothed_AB

    def colorize_video(self, input_path, output_path,
                     batch_size=16, workers=8,
                     frame_limit=None):
        # Initialization checks
        print(f"Initializing video processing")
        print(f"Input: {input_path}")
        print(f"Output: {output_path}")

        """Main processing pipeline"""
        # Phase 1: Parallel colorization
        print(f"Temporary directory: {self.temp_dir.name}")
        if not os.path.exists(self.temp_dir.name):
            raise FileNotFoundError("Temp directory not created!")
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            print("❌ Failed to open input video")
            return

        executor = ThreadPoolExecutor(max_workers=workers)
        futures = []
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Total frames: {frame_count}")
        print(f"FPS: {cap.get(cv2.CAP_PROP_FPS)}")
        print(f"Resolution: {cap.get(cv2.CAP_PROP_FRAME_WIDTH)}x{cap.get(cv2.CAP_PROP_FRAME_HEIGHT)}")
        chunk = []

        while True:
            ret, frame = cap.read()
            if not ret or (frame_limit and frame_count >= frame_limit):
                break

            chunk.append(frame)
            frame_count += 1

            if len(chunk) == batch_size:
                future = executor.submit(self._process_chunk, chunk.copy(), frame_count - batch_size)
                futures.append(future)
                chunk = []

        # Process remaining frames
        if len(chunk) > 0:
            future = executor.submit(self._process_chunk, chunk, frame_count - len(chunk))
            futures.append(future)

        # Wait for all chunks to complete
        _ = [f.result() for f in futures]
        cap.release()

        # Phase 2: Temporal smoothing and output
        writer = cv2.VideoWriter(output_path,
                               cv2.VideoWriter_fourcc(*'mp4v'),
                               cap.get(cv2.CAP_PROP_FPS),
                               (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                               int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))

        prev_data = None
        saved_files = sorted([
            f for f in os.listdir(self.temp_dir.name)
            if f.startswith("frame_") and f.endswith(".npy")
        ], key=lambda x: int(x.split('_')[1].split('.')[0]))

        print(f"Found {len(saved_files)} processed frames")
        print("Sample files:", saved_files[:3])

        for i, filename in enumerate(saved_files):
            file_path = os.path.join(self.temp_dir.name, filename)
            data = np.load(file_path, allow_pickle=True).item()

        for i in range(frame_count):
            # Use zero-padded filenames
            file_path = os.path.join(self.temp_dir.name, f"{i:06d}.npy")
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Missing frame {i}: {file_path}")

            data = np.load(file_path, allow_pickle=True).item()

            # Apply temporal smoothing
            smoothed_AB = self._temporal_smooth(prev_data, data)

            # Generate final frame
            lab = np.concatenate([rgb2lab(cv2.resize(cv2.cvtColor(data['frame'],cv2.COLOR_BGR2RGB),(64, 64)))[..., 0:1],smoothed_AB],axis=-1)
            # Resize back to original dimensions
            rgb = (lab2rgb(lab) * 255).astype(np.uint8)
            final_frame = cv2.resize(rgb, (data['frame'].shape[1], data['frame'].shape[0]))
            print(f"First frame shape before write: {final_frame.shape}")
            writer.write(final_frame)
            print("Successfully wrote first frame")
            prev_data = {'AB': smoothed_AB, 'gray': data['gray']}

            if i % 10 == 0:
                print(f"Processed {i+1}/{frame_count} frames")
        print(f"Generated {len(os.listdir(self.temp_dir.name))} intermediate files")
        writer.release()
        self.temp_dir.cleanup()


In [58]:
colorizer = VideoColorizer(generator, temporal_alpha=0.85)


colorizer.colorize_video(
    "/content/Video/input.mp4",
    "/content/Video/output.mp4",
    batch_size=32,
    workers=8,
    frame_limit=300
)

Temporary directory created at: /content/colorizer_temp_8yti6knv
Initializing video processing
Input: /content/Video/input.mp4
Output: /content/Video/output.mp4
Temporary directory: /content/colorizer_temp_8yti6knv
Total frames: 324
FPS: 29.97002997002997
Resolution: 1280.0x720.0
Found 0 processed frames
Sample files: []


FileNotFoundError: Missing frame 0: /content/colorizer_temp_8yti6knv/000000.npy

In [47]:
tempfile.TemporaryDirectory(
            dir='/content',
            prefix='colorizer_temp_'
        )

<TemporaryDirectory '/content/colorizer_temp_9sxw912w'>

In [None]:
# Find optimal batch size for your GPU
!nvidia-smi --loop=1  # Monitor GPU memory usage