In [None]:
import os
import numpy as np
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import torchvision.models as models
from torchvision.models import VGG19_Weights
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from skimage import color
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

In [None]:
# Pick the compute device once for the whole notebook: CUDA when torch can see
# a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cpu


In [None]:
# DATASET CLASS
class ColorizationDataset(Dataset):
    """Image-folder dataset that yields normalized CIELAB ``(L, AB)`` pairs.

    Each item is an image loaded as RGB, resized (and optionally augmented),
    converted to LAB with scikit-image and split into:

    * ``L``  -- lightness, shape ``(1, size, size)``, rescaled from [0, 100] to [-1, 1]
    * ``AB`` -- chroma, shape ``(2, size, size)``, divided by 128

    Args:
        img_dir: Folder scanned (non-recursively) for .jpg/.png/.jpeg files.
        size: Side length of the square images returned.
        augment: When True, apply random crop, horizontal flip and colour jitter.
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, img_dir, size=256, augment=True):
        self.img_paths = self._list_image_paths(img_dir)
        self.size = size
        self.augment = augment

        # Transforms with augmentation
        if augment:
            self.transform = T.Compose([
                # Resize slightly larger so the random crop has room to move.
                T.Resize((size + 32, size + 32)),
                T.RandomCrop((size, size)),
                T.RandomHorizontalFlip(p=0.5),
                T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
                T.ToTensor()
            ])
        else:
            self.transform = T.Compose([
                T.Resize((size, size)),
                T.ToTensor()
            ])

    @staticmethod
    def _list_image_paths(img_dir):
        """Return the image paths in ``img_dir`` in sorted order.

        ``os.listdir`` returns entries in arbitrary order, so the result is
        sorted to make the index -> image mapping reproducible across runs.
        """
        return sorted(
            os.path.join(img_dir, f) for f in os.listdir(img_dir)
            if f.lower().endswith(ColorizationDataset.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(L, AB)`` float tensors.

        Any failure while loading or converting is reported with ``print`` and
        an all-zero pair of the right shape is returned instead, so a single
        bad file does not abort an epoch. Note that such a sample is
        indistinguishable from real data downstream.
        """
        try:
            # Load and transform image; ToTensor gives CHW floats in [0, 1]
            img = Image.open(self.img_paths[idx]).convert('RGB')
            img = self.transform(img)
            img = img.numpy().transpose((1,2,0))  # CHW -> HWC for skimage

            # Convert to LAB
            img_lab = color.rgb2lab(img).astype(np.float32)
            L = img_lab[:,:,0]  # [0,100]
            AB = img_lab[:,:,1:]  # nominally [-128,127]

            # Normalize both parts to roughly [-1, 1]
            L = L / 100.0 * 2.0 - 1.0
            AB = AB / 128.0

            # Convert to tensors (channel-first)
            L_tensor = torch.from_numpy(L).unsqueeze(0)
            AB_tensor = torch.from_numpy(AB.transpose((2,0,1)))

            return L_tensor.float(), AB_tensor.float()
        except Exception as e:
            print(f"Error loading image {self.img_paths[idx]}: {e}")
            # Return a dummy tensor if image fails to load
            L_dummy = torch.zeros(1, self.size, self.size)
            AB_dummy = torch.zeros(2, self.size, self.size)
            return L_dummy, AB_dummy

In [None]:
# U-NET MODEL
class ColorizationModel(nn.Module):
    """U-Net style colorizer: ResNet-50 encoder, transposed-conv decoder.

    Input is a single-channel lightness map ``(B, 1, H, W)``; output is the
    predicted chroma ``(B, 2, H, W)`` squashed to [-1, 1] by ``tanh``.

    The encoder halves the resolution five times, and each decoder block
    doubles it again, so ``H`` and ``W`` should be multiples of 32 for the
    skip connections to line up.
    """

    def __init__(self):
        super().__init__()

        # Encoder: ImageNet-pretrained ResNet-50
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        stages = list(resnet.children())

        # Slice the backbone into five stages; each output is reused as a skip
        self.encoder1 = nn.Sequential(*stages[:3])   # conv1+bn+relu   -> 64 channels
        self.encoder2 = nn.Sequential(*stages[3:5])  # maxpool+layer1  -> 256 channels
        self.encoder3 = nn.Sequential(*stages[5:6])  # layer2          -> 512 channels
        self.encoder4 = nn.Sequential(*stages[6:7])  # layer3          -> 1024 channels
        self.encoder5 = nn.Sequential(*stages[7:8])  # layer4          -> 2048 channels

        # Decoder with skip connections (U-Net style)
        self.decoder5 = self._make_decoder_block(2048, 1024)
        self.decoder4 = self._make_decoder_block(1024 + 1024, 512)  # +1024 from skip
        self.decoder3 = self._make_decoder_block(512 + 512, 256)    # +512 from skip
        self.decoder2 = self._make_decoder_block(256 + 256, 128)    # +256 from skip
        self.decoder1 = self._make_decoder_block(128 + 64, 64)      # +64 from skip

        # Final layer: reduce to the two chroma channels
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 2, 3, padding=1),
            nn.Tanh()
        )

    def _make_decoder_block(self, in_channels, out_channels):
        """Build one decoder stage: 2x transposed-conv upsample, then a 3x3 conv."""
        return nn.Sequential(
            # kernel 4 / stride 2 / padding 1 exactly doubles height and width
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        """Predict AB chroma for a batch of L maps.

        Args:
            x: Tensor of shape ``(B, 1, H, W)``.

        Returns:
            Tensor of shape ``(B, 2, H, W)`` with values in [-1, 1].
        """
        # Remember the input resolution so the output can be matched to it
        target_size = x.shape[-2:]

        # The ResNet stem expects three channels: replicate the L channel
        x = x.repeat(1, 3, 1, 1)

        # Encoder (shapes shown for a 256x256 input)
        e1 = self.encoder1(x)     # 64 x 128 x 128
        e2 = self.encoder2(e1)    # 256 x 64 x 64
        e3 = self.encoder3(e2)    # 512 x 32 x 32
        e4 = self.encoder4(e3)    # 1024 x 16 x 16
        e5 = self.encoder5(e4)    # 2048 x 8 x 8

        # Decoder with skip connections
        d5 = self.decoder5(e5)                          # 1024 x 16 x 16
        d4 = self.decoder4(torch.cat([d5, e4], dim=1))  # 512 x 32 x 32
        d3 = self.decoder3(torch.cat([d4, e3], dim=1))  # 256 x 64 x 64
        d2 = self.decoder2(torch.cat([d3, e2], dim=1))  # 128 x 128 x 128
        d1 = self.decoder1(torch.cat([d2, e1], dim=1))  # 64 x 256 x 256

        # Match the input resolution instead of a fixed 256x256, so the
        # prediction always lines up pixel-for-pixel with the L channel.
        # For inputs that are multiples of 32 this is already the case.
        if d1.shape[-2:] != target_size:
            d1 = F.interpolate(d1, size=tuple(target_size), mode='bilinear',
                               align_corners=False)

        # Final convolution
        out = self.final_conv(d1)
        return out

In [None]:
# PERCEPTUAL LOSS
class PerceptualLoss(nn.Module):
    """Pixel MSE in LAB plus a VGG19 feature (perceptual) MSE in RGB.

    ``forward`` returns ``pixel_loss + 0.1 * perceptual_loss``.

    The LAB -> RGB conversion is done with torch operations so the perceptual
    term stays attached to the autograd graph and actually contributes
    gradients to the prediction. (Converting through NumPy, as a CPU round
    trip would, detaches the tensor and reduces training to the pixel term.)
    """

    # Linear sRGB -> CIE XYZ matrix and D65 / 2-degree white point; these are
    # the constants scikit-image uses for rgb2lab / lab2rgb.
    _XYZ_FROM_RGB = ((0.412453, 0.357580, 0.180423),
                     (0.212671, 0.715160, 0.072169),
                     (0.019334, 0.119193, 0.950227))
    _D65_WHITE = (0.95047, 1.0, 1.08883)

    # Per-channel statistics the ImageNet-pretrained torchvision VGG expects.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        # Frozen slice of VGG19's convolutional stack used as feature extractor
        vgg = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features[:16]
        self.vgg = vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

        self.mse = nn.MSELoss()

    def forward(self, pred_lab, target_lab):
        """Compute the combined loss.

        Args:
            pred_lab: ``(B, 3, H, W)`` normalized LAB prediction (L in [-1, 1],
                AB divided by 128).
            target_lab: Ground truth in the same layout.

        Returns:
            Scalar tensor. If the perceptual branch raises, the error is
            printed and the pixel loss alone is returned.
        """
        try:
            # Convert LAB to RGB for VGG; only the prediction needs gradients
            pred_rgb = self.lab_to_rgb_tensor(pred_lab)
            with torch.no_grad():
                target_rgb = self.lab_to_rgb_tensor(target_lab)
                target_features = self.vgg(self._imagenet_normalize(target_rgb))

            # VGG features of the prediction (graph kept for backprop)
            pred_features = self.vgg(self._imagenet_normalize(pred_rgb))

            # Perceptual loss
            perceptual_loss = self.mse(pred_features, target_features)

            # Pixel loss
            pixel_loss = self.mse(pred_lab, target_lab)

            return pixel_loss + 0.1 * perceptual_loss
        except Exception as e:
            print(f"Perceptual loss error: {e}")
            # Fallback to pixel loss only
            return self.mse(pred_lab, target_lab)

    def _imagenet_normalize(self, rgb):
        """Standardize ``(B, 3, H, W)`` RGB in [0, 1] with ImageNet mean/std."""
        mean = rgb.new_tensor(self._IMAGENET_MEAN).view(1, 3, 1, 1)
        std = rgb.new_tensor(self._IMAGENET_STD).view(1, 3, 1, 1)
        return (rgb - mean) / std

    def lab_to_rgb_tensor(self, lab_tensor):
        """Convert normalized LAB to sRGB in [0, 1], differentiably.

        Follows the same steps as ``skimage.color.lab2rgb`` (D65 white point):
        LAB -> XYZ -> linear RGB -> gamma-encoded sRGB, clipped to [0, 1].

        Args:
            lab_tensor: ``(B, 3, H, W)`` tensor, L in [-1, 1], AB divided by 128.

        Returns:
            float32 tensor ``(B, 3, H, W)`` on the same device as the input.
        """
        lab = lab_tensor.float()

        # Denormalize and clamp to the valid LAB ranges
        L = ((lab[:, 0] + 1.0) / 2.0 * 100.0).clamp(0.0, 100.0)
        a = (lab[:, 1] * 128.0).clamp(-128.0, 127.0)
        b = (lab[:, 2] * 128.0).clamp(-128.0, 127.0)

        # LAB -> XYZ
        fy = (L + 16.0) / 116.0
        fx = a / 500.0 + fy
        fz = (fy - b / 200.0).clamp(min=0.0)  # negative z is invalid; zero it
        f = torch.stack([fx, fy, fz], dim=1)
        xyz = torch.where(f > 0.2068966, f ** 3, (f - 16.0 / 116.0) / 7.787)
        xyz = xyz * f.new_tensor(self._D65_WHITE).view(1, 3, 1, 1)

        # XYZ -> linear RGB (invert in float64 for accuracy, then cast)
        rgb_from_xyz = torch.linalg.inv(
            torch.tensor(self._XYZ_FROM_RGB, dtype=torch.float64)
        ).to(device=lab.device, dtype=lab.dtype)
        rgb_lin = torch.einsum('ij,bjhw->bihw', rgb_from_xyz, xyz)

        # Linear RGB -> sRGB gamma. The power branch is fed a clamped copy so
        # out-of-gamut negatives cannot produce NaN values or NaN gradients.
        safe = rgb_lin.clamp(min=0.0031308)
        rgb = torch.where(rgb_lin > 0.0031308,
                          1.055 * safe ** (1.0 / 2.4) - 0.055,
                          rgb_lin * 12.92)
        return rgb.clamp(0.0, 1.0)

In [None]:
# EVALUATION FUNCTION
def evaluate_model(model, val_loader, num_samples=20):
    """Report PSNR and SSIM of colorized images against the ground truth.

    The model is put in eval mode and run on batches from ``val_loader``
    (moved to the notebook-level ``device``) until ``num_samples`` images have
    been scored. Each image is rebuilt in LAB from the input L and the
    true / predicted AB channels, converted to RGB and compared.

    Args:
        model: Network mapping an L batch to an AB batch.
        val_loader: Iterable of ``(L, AB)`` batches in normalized LAB.
        num_samples: Maximum number of images to score.

    Returns:
        ``(psnr_scores, ssim_scores)`` -- two lists of per-image values. A
        sample that raises during scoring is reported and skipped.
    """
    model.eval()
    psnr_scores = []
    ssim_scores = []

    with torch.no_grad():
        sample_count = 0
        for L, AB in val_loader:
            if sample_count >= num_samples:
                break

            L, AB = L.to(device), AB.to(device)
            pred_AB = model(L)

            # Convert to numpy for metrics
            for j in range(L.shape[0]):
                if sample_count >= num_samples:
                    break

                try:
                    # Reconstruct images
                    l_np = L[j,0].cpu().numpy()
                    ab_true = AB[j].cpu().numpy()
                    ab_pred = pred_AB[j].cpu().numpy()

                    # Denormalize
                    l_denorm = (l_np + 1.0) / 2.0 * 100.0
                    ab_true_denorm = ab_true * 128.0
                    ab_pred_denorm = ab_pred * 128.0

                    # Clamp values to valid LAB ranges
                    l_denorm = np.clip(l_denorm, 0, 100)
                    ab_true_denorm = np.clip(ab_true_denorm, -128, 127)
                    ab_pred_denorm = np.clip(ab_pred_denorm, -128, 127)

                    # Combine LAB (same L for both, so only chroma differs)
                    lab_true = np.stack([l_denorm, ab_true_denorm[0], ab_true_denorm[1]], axis=2)
                    lab_pred = np.stack([l_denorm, ab_pred_denorm[0], ab_pred_denorm[1]], axis=2)

                    # Convert to RGB
                    rgb_true = color.lab2rgb(lab_true)
                    rgb_pred = color.lab2rgb(lab_pred)

                    # Ensure valid range [0,1]
                    rgb_true = np.clip(rgb_true, 0, 1)
                    rgb_pred = np.clip(rgb_pred, 0, 1)

                    psnr_val = psnr(rgb_true, rgb_pred, data_range=1.0)

                    # SSIM needs a window that fits inside the image
                    min_dim = min(rgb_true.shape[0], rgb_true.shape[1])

                    if min_dim >= 3:  # Only calculate SSIM if image is large enough
                        win_size = min(7, min_dim)
                        # skimage rejects even window sizes, so round a
                        # 4- or 6-pixel limit down to the next odd value.
                        if win_size % 2 == 0:
                            win_size -= 1
                        ssim_val = ssim(rgb_true, rgb_pred,
                                      channel_axis=2,  # Specify channel axis
                                      data_range=1.0,
                                      win_size=win_size)
                        ssim_scores.append(ssim_val)

                    psnr_scores.append(psnr_val)
                    sample_count += 1

                except Exception as e:
                    print(f"Error evaluating sample {sample_count}: {e}")
                    continue

    if psnr_scores:
        print(f"PSNR: {np.mean(psnr_scores):.2f} ± {np.std(psnr_scores):.2f}")
    if ssim_scores:
        print(f"SSIM: {np.mean(ssim_scores):.3f} ± {np.std(ssim_scores):.3f}")
    else:
        print("SSIM: Could not calculate (images too small)")

    return psnr_scores, ssim_scores

In [None]:
# VIDEO COLORIZER
class VideoColorizer:
    """Wrap a ``ColorizationModel`` for colorizing single frames and videos.

    Frames are always processed at 256x256 and returned as RGB uint8.
    """

    def __init__(self, model_path):
        """Build the model on ``device`` and load weights from ``model_path``.

        If the file does not exist a warning is printed and the freshly
        constructed model is used as-is.
        """
        self.model = ColorizationModel().to(device)
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            print(f"Warning: Model file {model_path} not found. Using untrained model.")
        self.model.eval()

    def colorize_frame(self, gray_frame):
        """Colorize a single grayscale frame.

        Args:
            gray_frame: 2-D grayscale image with values in 0..255.

        Returns:
            ``(256, 256, 3)`` RGB uint8 image. If anything fails, the error is
            printed and the resized grayscale frame is returned as RGB.
        """
        try:
            # Resize and scale to [0, 1]
            frame_resized = cv2.resize(gray_frame, (256, 256))
            gray01 = frame_resized.astype(np.float32) / 255.0

            # Derive L the same way the training data does (rgb2lab on the
            # image). Gray level is not linear in L, so simply scaling it to
            # [0, 100] would feed the model a shifted lightness and alter
            # the luminance of the output.
            gray_rgb = np.stack([gray01, gray01, gray01], axis=-1)
            L = color.rgb2lab(gray_rgb)[:, :, 0].astype(np.float32)  # [0, 100]
            L_norm = L / 100.0 * 2.0 - 1.0  # [-1, 1]

            # Convert to tensor with batch and channel dimensions
            L_tensor = torch.from_numpy(L_norm).unsqueeze(0).unsqueeze(0).to(device)

            # Predict AB channels -> (2, 256, 256)
            with torch.no_grad():
                pred_AB = self.model(L_tensor).cpu().squeeze(0).numpy()

            # Denormalize and clamp to valid LAB ranges
            AB_denorm = np.clip(pred_AB * 128.0, -128, 127)
            L_denorm = np.clip(L, 0, 100)

            # Combine LAB
            lab_frame = np.zeros((256, 256, 3))
            lab_frame[:,:,0] = L_denorm
            lab_frame[:,:,1:] = AB_denorm.transpose(1,2,0)

            # Convert to RGB
            rgb_frame = color.lab2rgb(lab_frame)
            rgb_frame = np.clip(rgb_frame * 255, 0, 255).astype(np.uint8)

            return rgb_frame

        except Exception as e:
            print(f"Error colorizing frame: {e}")
            # Return grayscale frame as RGB fallback
            rgb_fallback = cv2.cvtColor(cv2.resize(gray_frame, (256, 256)), cv2.COLOR_GRAY2RGB)
            return rgb_fallback

    def colorize_video(self, input_path, output_path, batch_size=4):
        """Colorize an entire video file frame by frame.

        Each frame is converted to grayscale, colorized at 256x256, resized
        back to the source resolution and written with the ``mp4v`` codec.

        Args:
            input_path: Video to read.
            output_path: Video file to write.
            batch_size: Currently unused (frames are processed one at a time);
                kept so existing callers keep working.
        """
        # Open input video
        cap = cv2.VideoCapture(input_path)

        if not cap.isOpened():
            print(f"Error: Could not open video {input_path}")
            return

        # Get video properties. Keep the frame rate as a float: truncating
        # e.g. 29.97 to 29 would change the playback speed of the output.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not np.isfinite(fps) or fps <= 0:
            fps = 30.0  # container did not report a usable rate
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Setup video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            print(f"Error: Could not open output video {output_path}")
            cap.release()
            out.release()
            return

        print(f"Processing {total_frames} frames at {fps} FPS...")

        try:
            # Some containers report no frame count; tqdm then runs open-ended
            progress_total = total_frames if total_frames > 0 else None
            with tqdm(total=progress_total, desc="Colorizing video") as pbar:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    # Convert to grayscale
                    gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

                    # Colorize at 256x256 then resize back
                    colorized_256 = self.colorize_frame(gray_frame)
                    colorized_full = cv2.resize(colorized_256, (width, height))

                    # Convert RGB to BGR for OpenCV
                    bgr_frame = cv2.cvtColor(colorized_full, cv2.COLOR_RGB2BGR)
                    out.write(bgr_frame)
                    pbar.update(1)
        finally:
            # Always release the handles so the output file is finalized
            cap.release()
            out.release()

        print(f"Video saved to {output_path}")

In [None]:
# TRAINING FUNCTION
def train_model(data_dir='data/coco/images/val2017'):
    """Train the colorization model on the images in ``data_dir``.

    The folder is split 80/20 at random into train and validation indices.
    Training samples are augmented; validation samples are read through a
    second, non-augmenting dataset so the validation loss is measured on
    deterministic inputs. The weights with the lowest validation loss are
    saved to ``best_colorization_model.pth``.

    Args:
        data_dir: Folder of .jpg/.png/.jpeg images.

    Returns:
        ``(model, train_losses, val_losses)``. ``model`` holds the weights of
        the last epoch (not necessarily the best one). On a missing folder, an
        empty dataset or any training error, ``(None, [], [])`` is returned.
    """
    from torch.utils.data import Subset

    # Check if data directory exists
    if not os.path.exists(data_dir):
        print(f"Data directory {data_dir} not found!")
        print("Please ensure you have downloaded the COCO dataset.")
        return None, [], []

    try:
        # Augmenting dataset used for the training indices
        full_dataset = ColorizationDataset(data_dir, size=256, augment=True)

        if len(full_dataset) == 0:
            print("No images found in dataset!")
            return None, [], []

        # Non-augmenting view of the same files for validation. Sharing the
        # path list guarantees index i refers to the same image in both.
        eval_dataset = ColorizationDataset(data_dir, size=256, augment=False)
        eval_dataset.img_paths = full_dataset.img_paths

        train_size = int(0.8 * len(full_dataset))
        shuffled = torch.randperm(len(full_dataset)).tolist()
        train_ds = Subset(full_dataset, shuffled[:train_size])
        val_ds = Subset(eval_dataset, shuffled[train_size:])

        # Model and loss
        model = ColorizationModel().to(device)
        criterion = PerceptualLoss().to(device)

        # Different learning rates for encoder and decoder
        encoder_params = []
        decoder_params = []

        for name, param in model.named_parameters():
            if 'encoder' in name:
                encoder_params.append(param)
            else:
                decoder_params.append(param)

        optimizer = optim.Adam([
            {'params': encoder_params, 'lr': 1e-5},  # Lower LR for pretrained encoder
            {'params': decoder_params, 'lr': 1e-4}   # Higher LR for decoder
        ])

        # Halve the learning rates when validation loss stalls for 5 epochs
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        # Data loaders; pinned memory only helps host-to-GPU copies
        use_pinned = device.type == 'cuda'
        train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=2, pin_memory=use_pinned)
        val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=2, pin_memory=use_pinned)

        # Training loop
        num_epochs = 100
        best_val_loss = float('inf')
        train_losses, val_losses = [], []

        for epoch in range(1, num_epochs + 1):
            # Training
            model.train()
            running_loss = 0.0
            num_batches = 0

            for L, AB in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
                try:
                    L, AB = L.to(device), AB.to(device)

                    optimizer.zero_grad()
                    pred_AB = model(L)

                    # Combine L and AB for perceptual loss
                    pred_LAB = torch.cat([L, pred_AB], dim=1)
                    target_LAB = torch.cat([L, AB], dim=1)

                    loss = criterion(pred_LAB, target_LAB)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    num_batches += 1
                except Exception as e:
                    # A failing batch is reported and skipped, not fatal
                    print(f"Training batch error: {e}")
                    continue

            if num_batches > 0:
                train_loss = running_loss / num_batches
                train_losses.append(train_loss)
            else:
                # No batch succeeded this epoch
                train_losses.append(float('inf'))

            # Validation
            model.eval()
            val_loss = 0.0
            num_val_batches = 0

            with torch.no_grad():
                for L, AB in val_loader:
                    try:
                        L, AB = L.to(device), AB.to(device)
                        pred_AB = model(L)

                        pred_LAB = torch.cat([L, pred_AB], dim=1)
                        target_LAB = torch.cat([L, AB], dim=1)

                        loss = criterion(pred_LAB, target_LAB)
                        val_loss += loss.item()
                        num_val_batches += 1
                    except Exception as e:
                        print(f"Validation batch error: {e}")
                        continue

            if num_val_batches > 0:
                val_loss /= num_val_batches
                val_losses.append(val_loss)
            else:
                val_losses.append(float('inf'))

            print(f"Epoch {epoch}: Train={train_losses[-1]:.4f}, Val={val_losses[-1]:.4f}")

            # Save best model
            if val_losses[-1] < best_val_loss:
                best_val_loss = val_losses[-1]
                torch.save(model.state_dict(), 'best_colorization_model.pth')
                print(f"New best model saved with val_loss: {best_val_loss:.4f}")

            scheduler.step(val_losses[-1])

        return model, train_losses, val_losses

    except Exception as e:
        print(f"Training error: {e}")
        return None, [], []

In [None]:
# QUICK TEST FUNCTION
def quick_test_colorization(model_path, image_path):
    """Colorize one image file and plot original, grayscale and result.

    Prints a message and returns without plotting when the model file or
    the image file is missing, or when OpenCV cannot decode the image.
    """
    # Guard clauses: both files must exist before any model is built.
    model_available = os.path.exists(model_path)
    if not model_available:
        print(f"Model file {model_path} not found!")
        return

    image_available = os.path.exists(image_path)
    if not image_available:
        print(f"Image file {image_path} not found!")
        return

    frame_colorizer = VideoColorizer(model_path)

    # OpenCV returns None (rather than raising) for unreadable files.
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        print(f"Could not load image {image_path}")
        return

    gray_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
    colorized_image = frame_colorizer.colorize_frame(gray_image)

    # One (picture, caption, imshow options) entry per subplot, left to right.
    panels = (
        (cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB), 'Original', {}),
        (gray_image, 'Grayscale', {'cmap': 'gray'}),
        (colorized_image, 'Colorized', {}),
    )

    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (picture, caption, imshow_options) in zip(axes, panels):
        axis.imshow(picture, **imshow_options)
        axis.set_title(caption)
        axis.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# STEP 1: TRAIN THE MODEL

# Download COCO dataset (if not already done)
!mkdir -p data/coco/images
!wget -O data/coco/val2017.zip http://images.cocodataset.org/zips/val2017.zip
!unzip -q data/coco/val2017.zip -d data/coco/images/

# Train the improved model
print("Starting training...")
model, train_losses, val_losses = train_model()

# Plot improved training curves
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', linewidth=2)
plt.plot(val_losses, label='Val Loss', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(val_losses, 'r-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Validation Loss')
plt.title('Validation Loss Trend')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

--2025-06-27 17:56:50--  http://images.cocodataset.org/zips/val2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.217.236.145, 52.216.147.172, 3.5.25.90, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|52.217.236.145|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 815585330 (778M) [application/zip]
Saving to: ‘data/coco/val2017.zip’


2025-06-27 17:56:58 (95.6 MB/s) - ‘data/coco/val2017.zip’ saved [815585330/815585330]

Starting training...


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 81.5MB/s]
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:07<00:00, 73.3MB/s]
  rgb_img = color.lab2rgb(lab_img)
  rgb_img = color.lab2rgb(lab_img)
  rgb_img = color.lab2rgb(lab_img)
  rgb_img = color.lab2rgb(lab_img)
  rgb_img = color.lab2rgb(lab_img)
  rgb_img = color.lab2rgb(lab_img)
  rgb_img = color.lab2rgb(lab_img)
  rgb_img = color.lab2rgb(lab_img)
  rgb_img = color.lab2rgb(lab_img)


In [None]:
# STEP 2: EVALUATE MODEL PERFORMANCE

# Load data for evaluation (no augmentation).
# NOTE: this reads every image in the folder, which is the same folder
# train_model() splits 80/20 by default, so most of these images were also
# used for training. Treat the scores as optimistic.
val_dataset = ColorizationDataset('data/coco/images/val2017', size=256, augment=False)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# Evaluate with PSNR and SSIM
print("\n Evaluating model performance...")
if model is None:
    # train_model() returns None for the model when training could not run
    print("No trained model available; run the training cell successfully first.")
    psnr_scores, ssim_scores = [], []
else:
    psnr_scores, ssim_scores = evaluate_model(model, val_loader, num_samples=50)


NameError: name 'ColorizationDataset' is not defined

In [None]:
# STEP 3: VIDEO COLORIZATION DEMO

# Initialize video colorizer
colorizer = VideoColorizer('best_colorization_model.pth')

# Demo 1: Process uploaded video
print("\nVideo Colorization Demo")
print("Upload a grayscale video file:")

# Colab-only helper for browser upload/download
from google.colab import files
uploaded = files.upload()

for filename in uploaded.keys():
    # Only recognized video containers are processed; other uploads are ignored
    if not filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        continue

    input_video = filename
    output_video = f"colorized_{filename}"

    print(f"Processing: {filename}")

    # Colorize the video
    colorizer.colorize_video(input_video, output_video, batch_size=4)

    # colorize_video returns without writing when the input cannot be opened
    if not os.path.exists(output_video):
        print(f"No output was written for {filename}; skipping preview and download.")
        continue

    # Display first few frames for preview
    cap = cv2.VideoCapture(input_video)
    cap_out = cv2.VideoCapture(output_video)

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    fig.suptitle('Video Colorization Preview (First 4 Frames)', fontsize=16)

    # Hide every axis up front so clips shorter than four frames do not
    # leave empty, ticked panels behind.
    for ax in axes.ravel():
        ax.axis('off')

    try:
        for i in range(4):
            ret_in, frame_in = cap.read()
            ret_out, frame_out = cap_out.read()

            if not (ret_in and ret_out):
                break  # one of the videos ran out of frames

            # Input is shown as grayscale, output converted BGR -> RGB
            frame_gray = cv2.cvtColor(frame_in, cv2.COLOR_BGR2GRAY)
            frame_out_rgb = cv2.cvtColor(frame_out, cv2.COLOR_BGR2RGB)

            axes[0, i].imshow(frame_gray, cmap='gray')
            axes[0, i].set_title(f'Frame {i+1} (Grayscale)')

            axes[1, i].imshow(frame_out_rgb)
            axes[1, i].set_title(f'Frame {i+1} (Colorized)')
    finally:
        cap.release()
        cap_out.release()

    plt.tight_layout()
    plt.show()

    # Download colorized video
    files.download(output_video)

In [None]:
# STEP 4: BATCH IMAGE PROCESSING

def batch_colorize_images(image_folder, output_folder, model_path):
    """Colorize all images in a folder.

    Every .jpg/.png/.jpeg file in ``image_folder`` is read, converted to
    grayscale, colorized, resized back to its original resolution and saved
    to ``output_folder`` as ``colorized_<original name>``. Files OpenCV cannot
    read are skipped.

    Args:
        image_folder: Folder containing the input images.
        output_folder: Destination folder (created if missing).
        model_path: Weights file passed to ``VideoColorizer``.
    """
    os.makedirs(output_folder, exist_ok=True)
    colorizer = VideoColorizer(model_path)

    # Sorted so the processing order is the same on every run
    image_files = sorted(f for f in os.listdir(image_folder)
                         if f.lower().endswith(('.jpg', '.png', '.jpeg')))

    for img_file in tqdm(image_files, desc="Colorizing images"):
        img_path = os.path.join(image_folder, img_file)

        # Load image
        img = cv2.imread(img_path)
        if img is None:
            continue

        # Convert to grayscale
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Colorize (result is 256x256 RGB)
        colorized = colorizer.colorize_frame(gray)

        # Resize back to original size; cv2.resize takes (width, height)
        colorized_resized = cv2.resize(colorized,
                                     (img.shape[1], img.shape[0]))

        # Save (OpenCV writes BGR)
        output_path = os.path.join(output_folder, f"colorized_{img_file}")
        colorized_bgr = cv2.cvtColor(colorized_resized, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(output_path, colorized_bgr):
            # imwrite signals failure through its return value only
            print(f"Warning: could not write {output_path}")

    print(f"Batch processing complete! Check {output_folder}")

In [None]:
# STEP 5: REAL-TIME COLORIZATION

def real_time_colorization(model_path, camera_id=0):
    """Real-time colorization from webcam.

    Reads frames from the camera, colorizes a downscaled grayscale copy,
    scales the result back to the camera resolution and shows it in an
    OpenCV window until 'q' is pressed or the camera stops delivering frames.

    NOTE(review): this needs a local desktop session; hosted notebooks such
    as Colab generally cannot open OpenCV windows or reach a webcam this way.

    Args:
        model_path: Weights file passed to ``VideoColorizer``.
        camera_id: OpenCV capture index of the camera.
    """
    colorizer = VideoColorizer(model_path)
    cap = cv2.VideoCapture(camera_id)

    if not cap.isOpened():
        print(f"Error: Could not open camera {camera_id}")
        cap.release()
        return

    print("🎥 Real-time colorization started. Press 'q' to quit.")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Convert to grayscale
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

            # Colorize (resize for speed)
            small_gray = cv2.resize(gray, (128, 128))
            colorized_small = colorizer.colorize_frame(small_gray)

            # Back to the camera resolution; cv2.resize takes (width, height)
            colorized = cv2.resize(colorized_small, (frame.shape[1], frame.shape[0]))

            # colorize_frame returns RGB, OpenCV displays BGR
            cv2.imshow('Colorized', cv2.cvtColor(colorized, cv2.COLOR_RGB2BGR))

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Free the camera and close the window even if a frame fails
        cap.release()
        cv2.destroyAllWindows()

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Save model to Drive.
# `model` holds the weights from the final training epoch; the lowest
# validation-loss checkpoint is the separate file best_colorization_model.pth.
if model is None:
    # train_model() returns None for the model when training could not run
    print("No trained model to save; run the training cell successfully first.")
else:
    torch.save(model.state_dict(), '/content/drive/MyDrive/colorization_model.pth')
    print("Model is successfully Saved to drive")

# Load from Drive in future sessions
# model.load_state_dict(torch.load('/content/drive/MyDrive/colorization_model.pth', map_location=device))