# Image Colorization Using Deep Convolutional Neural Networks

## Assignment 3 - Report and Implementation

**Author:** [Your Name]  
**Date:** December 2024

---

## Introduction

In this assignment, we implement image colorization using Deep Convolutional Neural Networks (DCNNs). We use encoder-decoder architectures to learn complex mappings between grayscale images and their corresponding colorized versions.

### Dataset: Intel Image Classification (Pre-processed)

We use a **pre-processed subset** of the Intel Image Classification Dataset:
- **8000 images** (128x128 resolution)
- **6 categories**: Buildings, Forest, Glacier, Mountain, Sea, Street
- **Pre-split**: 80% training (6400), 20% validation (1600)
- **Location**: `colorization_dataset/` folder in Google Drive

The dataset was prepared using the `prepare_dataset.py` script, which:
1. Selected 8000 random images from the original dataset
2. Resized all images to 128x128 (a power of 2, so each stride-2 encoder layer halves the size exactly and the decoder can double it back)
3. Split into train/val folders (80/20)
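
The `prepare_dataset.py` script itself is not reproduced in this report. As a rough sketch of steps 1 and 3 only, the sampling and 80/20 split could look like the following. The function name `sample_and_split`, the seed, and the placeholder file names are assumptions made for illustration, and the resizing in step 2 is left out.

```python
import random

def sample_and_split(filenames, num_images=8000, train_fraction=0.8, seed=42):
    """Pick `num_images` files at random and split them into train/val lists.

    Hypothetical helper: the real prepare_dataset.py may differ.
    """
    if num_images > len(filenames):
        raise ValueError("not enough source images to sample from")
    rng = random.Random(seed)                      # local RNG, global state untouched
    chosen = rng.sample(sorted(filenames), num_images)
    num_train = round(num_images * train_fraction)
    return chosen[:num_train], chosen[num_train:]

# Placeholder file names standing in for the original dataset
all_files = [f"img_{i:05d}.jpg" for i in range(14000)]
train_files, val_files = sample_and_split(all_files)
print(len(train_files), len(val_files))
```

Sampling without replacement and slicing a single shuffled list keeps the two splits disjoint.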

### L*a*b* Color Model

The L*a*b* color model (CIELAB) represents color as three components:
- **L***: Lightness (0 for black to 100 for white)
- **a***: Green-to-red spectrum
- **b***: Blue-to-yellow spectrum

This separation suits colorization well: the L* channel is the grayscale image itself, so the network takes L* as input and only has to predict the a* and b* channels.

### Table of Contents
1. Setup and Import Libraries (Google Colab)
2. Load Pre-processed Dataset
3. Color Space Conversion (RGB to L*a*b*)
4. PyTorch Dataset and DataLoaders
5. Baseline Encoder-Decoder Model
6. Training with L1/MSE Loss
7. Pretrained Global Feature Extractor (VGG16/ResNet50)
8. Improved Architecture (U-Net + Attention)
9. Perceptual Loss Implementation
10. PatchGAN Discriminator (Bonus)
11. Evaluation and Visualization

### Requirements
- Google Colab with GPU runtime
- Pre-processed dataset in Google Drive (`/MyDrive/colorization_dataset/`)

---

## Step 1: Setup and Import Libraries (Google Colab)

**Important:** 
1. Make sure you're using a GPU runtime: `Runtime → Change runtime type → GPU`
2. Upload the `colorization_dataset` folder to your Google Drive root (`/MyDrive/colorization_dataset/`)

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

import os
import sys
import random
import shutil
import glob
import numpy as np
from PIL import Image
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.utils import make_grid

from skimage import color
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Set random seeds
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Check GPU info
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
else:
    print("WARNING: GPU not available.")

In [None]:

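# NOTE: the text above places the dataset at /MyDrive/colorization_dataset/,
# while this cell reads it from 'Assignment 3/dataset'. Point DATASET_PATH at
# whichever folder actually contains the train/ and val/ subfolders.
DRIVE_BASE = '/content/drive/MyDrive'
DATASET_PATH = os.path.join(DRIVE_BASE, 'Assignment 3', 'dataset')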
TRAIN_PATH = os.path.join(DATASET_PATH, 'train')
VAL_PATH = os.path.join(DATASET_PATH, 'val')

OUTPUT_PATH = os.path.join(DRIVE_BASE, 'Assignment 3')
MODEL_PATH = os.path.join(OUTPUT_PATH, 'models')
RESULTS_PATH = os.path.join(OUTPUT_PATH, 'results')

os.makedirs(MODEL_PATH, exist_ok=True)
os.makedirs(RESULTS_PATH, exist_ok=True)

---

## Step 2: Configuration and Dataset Loading

The dataset has been pre-processed using `prepare_dataset.py`:
- **8000 images** from Intel Image Classification
- **128x128 resolution** (resized; a power of 2, so each stride-2 layer halves or doubles the size exactly)
- **Already split** into train (6400) and val (1600) folders

### Color Space Conversion

We convert each image from RGB to L*a*b* color space:
- **L channel** (0-100): Used as grayscale input
- **a* channel** (-128 to 127): Green-red color information
- **b* channel** (-128 to 127): Blue-yellow color information

For neural network training, we normalize all channels to roughly [-1, 1]:
- L: [0, 100] → [-1, 1] (divide by 50, then subtract 1)
- a, b: [-128, 127] → approximately [-1, 1] (divide by 128)
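
The scaling above can be checked with a small round trip. This is a minimal sketch using NumPy only; the helper names `normalize_lab` and `denormalize_lab` and the three sample pixels are illustrative and not part of the notebook's own classes.

```python
import numpy as np

def normalize_lab(lab):
    """Scale L from [0, 100] and a*, b* from about [-128, 127] into about [-1, 1]."""
    L = lab[..., 0:1] / 50.0 - 1.0
    ab = lab[..., 1:3] / 128.0
    return L, ab

def denormalize_lab(L, ab):
    """Inverse of normalize_lab: back to L*a*b* units."""
    return np.concatenate([(L + 1.0) * 50.0, ab * 128.0], axis=-1)

# Three illustrative L*a*b* pixels: black, mid-gray, and an extreme color
lab_pixels = np.array([[0.0, 0.0, 0.0],
                       [50.0, 0.0, 0.0],
                       [100.0, -128.0, 127.0]])

L_norm, ab_norm = normalize_lab(lab_pixels)
restored = denormalize_lab(L_norm, ab_norm)

print(L_norm.ravel())   # normalized lightness values
print(ab_norm)          # normalized a*, b* values
print(np.allclose(restored, lab_pixels))
```

Because the divisor is 128, an a* or b* value of 127 maps to 127/128, slightly below 1, which still fits inside the range of a `tanh` output.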

In [None]:
class Config:
    # Dataset settings
    IMAGE_SIZE = 128
    NUM_IMAGES = 8000
    
    # Training settings
    BATCH_SIZE = 32
    NUM_EPOCHS = 50
    LEARNING_RATE = 2e-4
    BETA1 = 0.5        # Adam optimizer beta1
    BETA2 = 0.999      # Adam optimizer beta2
    
    # Model settings
    LATENT_DIM = 512

config = Config()

config.TRAIN_PATH = TRAIN_PATH
config.VAL_PATH = VAL_PATH
config.MODEL_PATH = MODEL_PATH  
config.RESULTS_PATH = RESULTS_PATH

In [None]:
def visualize_dataset_samples(train_path, val_path, num_samples=6):
    fig, axes = plt.subplots(2, num_samples, figsize=(18, 6))
    
    # Training samples
    train_images = [f for f in os.listdir(train_path) if f.endswith('.jpg')][:num_samples]
    for idx, img_name in enumerate(train_images):
        img = Image.open(os.path.join(train_path, img_name))
        axes[0, idx].imshow(img)
        axes[0, idx].set_title(f'Train {idx+1}', fontsize=10)
        axes[0, idx].axis('off')
    
    # Validation samples
    val_images = [f for f in os.listdir(val_path) if f.endswith('.jpg')][:num_samples]
    for idx, img_name in enumerate(val_images):
        img = Image.open(os.path.join(val_path, img_name))
        axes[1, idx].imshow(img)
        axes[1, idx].set_title(f'Val {idx+1}', fontsize=10)
        axes[1, idx].axis('off')
    
    axes[0, 0].set_ylabel('Training', fontsize=12)
    axes[1, 0].set_ylabel('Validation', fontsize=12)
    
    plt.suptitle('Pre-processed Dataset Samples (128x128)', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize samples
print("Dataset samples:")
visualize_dataset_samples(TRAIN_PATH, VAL_PATH)

---

## Step 3: RGB to L*a*b* Color Space Conversion

Now we convert our RGB images to L*a*b* color space. This is the key preprocessing step that makes our colorization task natural:

- **Input**: L channel (grayscale luminance)
- **Output**: a* and b* channels (chrominance/color information)

The model learns to predict color (a*, b*) given brightness (L).

In [None]:
class ColorSpaceConverter:
    @staticmethod
    def rgb_to_lab(rgb_image):
        if rgb_image.max() > 1.0:
            rgb_image = rgb_image / 255.0
        
        lab_image = color.rgb2lab(rgb_image)
        return lab_image
    
    @staticmethod
    def lab_to_rgb(lab_image):
        rgb_image = color.lab2rgb(lab_image)
        return np.clip(rgb_image, 0, 1)
    
    @staticmethod
    def normalize_l(l_channel):
        return (l_channel / 50.0) - 1.0
    
    @staticmethod
    def denormalize_l(l_normalized):
        return (l_normalized + 1.0) * 50.0
    
    @staticmethod
    def normalize_ab(ab_channels):
        return ab_channels / 128.0
    
    @staticmethod
    def denormalize_ab(ab_normalized):
        return ab_normalized * 128.0

def visualize_lab_conversion(image_path, image_size=128):
    """Visualize RGB to L*a*b* conversion"""
    img = Image.open(image_path).convert('RGB')
    img = img.resize((image_size, image_size))
    img_array = np.array(img)
    
    # Convert to L*a*b*
    lab = ColorSpaceConverter.rgb_to_lab(img_array)
    
    # Extract channels
    L = lab[:, :, 0]
    a = lab[:, :, 1]
    b = lab[:, :, 2]
    
    # Visualize
    fig, axes = plt.subplots(1, 5, figsize=(20, 4))
    
    axes[0].imshow(img_array)
    axes[0].set_title('Original RGB')
    axes[0].axis('off')
    
    axes[1].imshow(L, cmap='gray')
    axes[1].set_title('L channel (Luminance)')
    axes[1].axis('off')
    
    axes[2].imshow(a, cmap='RdYlGn_r')
    axes[2].set_title('a channel (Green-Red)')
    axes[2].axis('off')
    
    axes[3].imshow(b, cmap='YlGnBu_r')
    axes[3].set_title('b channel (Blue-Yellow)')
    axes[3].axis('off')
    
    reconstructed = ColorSpaceConverter.lab_to_rgb(lab)
    axes[4].imshow(reconstructed)
    axes[4].set_title('Reconstructed RGB')
    axes[4].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return lab

In [None]:
# Visualize color space conversion with a sample image
sample_images = glob.glob(os.path.join(TRAIN_PATH, '*.jpg'))

sample_path = sample_images[0]
print(f"Visualizing color space conversion for: {os.path.basename(sample_path)}")
lab_example = visualize_lab_conversion(sample_path)

---

## Step 4: PyTorch Dataset and DataLoaders

Now we create a custom PyTorch Dataset class that:
1. Loads RGB images
2. Converts them to L*a*b* color space
3. Returns normalized L channel as input and a*b* channels as target
4. Applies data augmentation for training

In [None]:
class ColorizationDataset(Dataset):    
    def __init__(self, root_dir, image_size=128, transform=None, split='train'):
        self.root_dir = root_dir
        self.image_size = image_size
        self.transform = transform
        self.split = split

        self.image_files = [f for f in os.listdir(root_dir) 
                           if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        
        self.image_files.sort()
        
        # Define augmentation (only for training)
        if split == 'train':
            self.augment = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
            ])
        else:
            self.augment = None
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
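        img = Image.open(img_path).convert('RGB')

        # Resize defensively so every sample matches image_size and batches can be stacked
        if img.size != (self.image_size, self.image_size):
            img = img.resize((self.image_size, self.image_size), Image.BILINEAR)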
        
        # Apply augmentation (only for training)
        if self.augment is not None:
            img = self.augment(img)
        
        img_array = np.array(img).astype(np.float32) / 255.0
        lab = color.rgb2lab(img_array).astype(np.float32)
        
        # L channel: [0, 100] -> [-1, 1]
        L = lab[:, :, 0:1]
        L = (L / 50.0) - 1.0
        
        # a*b* channels: [-128, 127] -> [-1, 1]
        ab = lab[:, :, 1:3]
        ab = ab / 128.0
        
        L = torch.from_numpy(L.transpose(2, 0, 1))
        ab = torch.from_numpy(ab.transpose(2, 0, 1))
        
        return L, ab
    
    def get_sample_image(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        img = Image.open(img_path).convert('RGB')
        
        if img.size != (self.image_size, self.image_size):
            img = img.resize((self.image_size, self.image_size), Image.BILINEAR)
        
        img_array = np.array(img).astype(np.float32) / 255.0
        lab = color.rgb2lab(img_array).astype(np.float32)
        
        return img_array, lab

In [None]:
def create_dataloaders(train_path, val_path, image_size=128, batch_size=32):
    train_dataset = ColorizationDataset(
        root_dir=train_path,
        image_size=image_size,
        split='train'
    )
    
    val_dataset = ColorizationDataset(
        root_dir=val_path,
        image_size=image_size,
        split='val'
    )
    
    print(f"\nDataset info:")
    print(f"  Training images: {len(train_dataset)}")
    print(f"  Validation images: {len(val_dataset)}")
    print(f"  Total: {len(train_dataset) + len(val_dataset)}")
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    
    print(f"\nDataLoader info:")
    print(f"  Training batches: {len(train_loader)}")
    print(f"  Validation batches: {len(val_loader)}")
    print(f"  Batch size: {batch_size}")
    
    return train_loader, val_loader, train_dataset

train_loader, val_loader, train_dataset = create_dataloaders(
    TRAIN_PATH,
    VAL_PATH,
    image_size=config.IMAGE_SIZE,
    batch_size=config.BATCH_SIZE
)

In [None]:
def visualize_batch(dataloader, num_samples=4):
    L_batch, ab_batch = next(iter(dataloader))
    
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))
    
    for i in range(min(num_samples, L_batch.shape[0])):
        # Get L and ab
        L = L_batch[i].numpy()
        ab = ab_batch[i].numpy()
        
        # Denormalize
        L_denorm = (L[0] + 1) * 50  # [0, 100]
        a_denorm = ab[0] * 128      # [-128, 127]
        b_denorm = ab[1] * 128      # [-128, 127]
        
        # Reconstruct L*a*b* and convert to RGB
        lab = np.stack([L_denorm, a_denorm, b_denorm], axis=-1)
        rgb = color.lab2rgb(lab)
        rgb = np.clip(rgb, 0, 1)
        
        axes[i, 0].imshow(L_denorm, cmap='gray')
        axes[i, 0].set_title('L channel (Input)')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(a_denorm, cmap='RdYlGn_r')
        axes[i, 1].set_title('a* channel')
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(b_denorm, cmap='YlGnBu_r')
        axes[i, 2].set_title('b* channel')
        axes[i, 2].axis('off')
        
        axes[i, 3].imshow(rgb)
        axes[i, 3].set_title('Reconstructed RGB')
        axes[i, 3].axis('off')
    
    plt.tight_layout()
    plt.show()

print("Sample batch visualization:")
visualize_batch(train_loader, num_samples=4)

---

## Step 5: Baseline Encoder-Decoder Model Architecture

The baseline model follows an encoder-decoder architecture with:

1. **Low-Level Feature Extractor**: Captures local spatial features (edges, textures)
2. **Global Feature Extractor**: Captures semantic/contextual information
3. **Fusion Block**: Combines local and global features
4. **Decoder**: Upsamples and predicts a*b* channels

### Architecture Diagram

```
Input (L channel)
      │
      ├──────────────────────────────────┐
      ▼                                  ▼
┌─────────────────┐              ┌─────────────────┐
│   Low-Level     │              │    Global       │
│   Feature       │              │    Feature      │
│   Extractor     │              │    Extractor    │
└────────┬────────┘              └────────┬────────┘
         │                                │
         └───────────┬────────────────────┘
                     ▼
              ┌──────────────┐
              │   Fusion     │
              │   Block      │
              └──────┬───────┘
                     │
                     ▼
              ┌──────────────┐
              │   Decoder    │
              │ (Upsampling) │
              └──────┬───────┘
                     │
                     ▼
              Output (a*b* channels)
```
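
The spatial sizes along this pipeline follow from the standard convolution size formulas. The sketch below applies them to the layer settings used in this step (3x3 stride-2 convolutions with padding 1 in the low-level extractor, 4x4 stride-2 transposed convolutions with padding 1 in the decoder). The helper names `conv_out` and `deconv_out` are illustrative.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of nn.Conv2d along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel, stride, padding):
    """Output size of nn.ConvTranspose2d along one dimension (no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel

size = 128
encoder_sizes = [size]
for _ in range(3):          # three stride-2, 3x3, padding-1 convolutions
    size = conv_out(size, kernel=3, stride=2, padding=1)
    encoder_sizes.append(size)

decoder_sizes = [size]
for _ in range(3):          # three stride-2, 4x4, padding-1 transposed convolutions
    size = deconv_out(size, kernel=4, stride=2, padding=1)
    decoder_sizes.append(size)

print(encoder_sizes)  # [128, 64, 32, 16]
print(decoder_sizes)  # [16, 32, 64, 128]
```

With a 128x128 input every intermediate size is an integer, and the decoder output matches the input resolution.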

In [None]:
class LowLevelFeatureExtractor(nn.Module):
    def __init__(self, in_channels=1):
        super(LowLevelFeatureExtractor, self).__init__()
        
        # Layer 1: 128x128 -> 64x64
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Layer 2: 64x64 -> 32x32
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        # Layer 3: 32x32 -> 16x16
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Layer 4: 16x16 -> 16x16 (maintain spatial size)
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Layer 5: 16x16 -> 16x16
        self.conv5 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv1(x)  # (B, 64, 64, 64)
        x = self.conv2(x)  # (B, 128, 32, 32)
        x = self.conv3(x)  # (B, 256, 16, 16)
        x = self.conv4(x)  # (B, 256, 16, 16)
        x = self.conv5(x)  # (B, 512, 16, 16)
        return x

print("Testing LowLevelFeatureExtractor...")
test_input = torch.randn(1, 1, 128, 128)
low_level_net = LowLevelFeatureExtractor()
low_level_output = low_level_net(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {low_level_output.shape}")
print(f"Parameters: {sum(p.numel() for p in low_level_net.parameters()):,}")

In [None]:
class GlobalFeatureExtractor(nn.Module):
    def __init__(self, in_channels=1, feature_dim=512):
        super(GlobalFeatureExtractor, self).__init__()
        
        self.feature_dim = feature_dim
        
        # Encoder layers: 128 -> 64 -> 32 -> 16 -> 8 -> 4
        self.encoder = nn.Sequential(
            # 128x128 -> 64x64
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            # 64x64 -> 32x32
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            # 32x32 -> 16x16
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            # 16x16 -> 8x8
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            # 8x8 -> 4x4
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        
        # Global average pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, feature_dim),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.encoder(x)        # (B, 512, 4, 4)
        x = self.global_pool(x)    # (B, 512, 1, 1)
        x = x.view(x.size(0), -1)  # (B, 512)
        x = self.fc(x)             # (B, feature_dim)
        return x

print("Testing GlobalFeatureExtractor...")
global_net = GlobalFeatureExtractor()
global_output = global_net(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {global_output.shape}")
print(f"Parameters: {sum(p.numel() for p in global_net.parameters()):,}")

In [None]:
class FusionBlock(nn.Module):
    def __init__(self, local_channels=512, global_channels=512, output_channels=512):
        super(FusionBlock, self).__init__()
        
        self.local_channels = local_channels
        self.global_channels = global_channels
        
        # 1x1 convolution to process concatenated features
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(local_channels + global_channels, output_channels, 
                     kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
            
            # Additional refinement
            nn.Conv2d(output_channels, output_channels, 
                     kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, local_features, global_features):
        B, C, H, W = local_features.shape
        
        # Replicate global features spatially
        # (B, global_channels) -> (B, global_channels, H, W)
        global_features = global_features.unsqueeze(-1).unsqueeze(-1)
        global_features = global_features.expand(-1, -1, H, W)
        
        # Concatenate along channel dimension
        fused = torch.cat([local_features, global_features], dim=1)
        
        # Apply fusion convolution
        fused = self.fusion_conv(fused)
        
        return fused

print("Testing FusionBlock...")
fusion_block = FusionBlock()
fused_output = fusion_block(low_level_output, global_output)
print(f"Local features shape: {low_level_output.shape}")
print(f"Global features shape: {global_output.shape}")
print(f"Fused output shape: {fused_output.shape}")
print(f"Parameters: {sum(p.numel() for p in fusion_block.parameters()):,}")

In [None]:
class Decoder(nn.Module):    
    def __init__(self, in_channels=512, out_channels=2):
        super(Decoder, self).__init__()
        
        # Upsampling layers: 16x16 -> 32 -> 64 -> 128
        
        # 16x16 -> 32x32
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # 32x32 -> 64x64
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        # 64x64 -> 128x128
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Refinement layer
        self.refine = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        
        # Output layer: predict a* and b* channels
        self.output = nn.Sequential(
            nn.Conv2d(32, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # Output in [-1, 1] for normalized a*b*
        )
        
    def forward(self, x):
        x = self.up1(x)     # (B, 256, 32, 32)
        x = self.up2(x)     # (B, 128, 64, 64)
        x = self.up3(x)     # (B, 64, 128, 128)
        x = self.refine(x)  # (B, 32, 128, 128)
        x = self.output(x)  # (B, 2, 128, 128)
        return x

print("Testing Decoder...")
decoder = Decoder()
decoder_output = decoder(fused_output)
print(f"Input shape: {fused_output.shape}")
print(f"Output shape: {decoder_output.shape}")
print(f"Parameters: {sum(p.numel() for p in decoder.parameters()):,}")

In [None]:
class ColorizationModel(nn.Module):
    def __init__(self, feature_dim=512):
        super(ColorizationModel, self).__init__()
        
        # Feature extractors
        self.low_level_extractor = LowLevelFeatureExtractor(in_channels=1)
        self.global_extractor = GlobalFeatureExtractor(in_channels=1, feature_dim=feature_dim)
        
        # Fusion
        self.fusion = FusionBlock(
            local_channels=512, 
            global_channels=feature_dim, 
            output_channels=512
        )
        
        self.decoder = Decoder(in_channels=512, out_channels=2)
        self._initialize_weights()
        
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                    
    def forward(self, L):
        local_features = self.low_level_extractor(L)   # (B, 512, 16, 16)
        global_features = self.global_extractor(L)      # (B, 512)
        
        fused = self.fusion(local_features, global_features)  # (B, 512, 16, 16)
        
        ab = self.decoder(fused)  # (B, 2, 128, 128)
        
        return ab
    
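    def colorize(self, L):
        # Inference helper: switch to eval mode so BatchNorm uses running statistics
        # and Dropout is disabled, then predict without tracking gradients
        self.eval()
        with torch.no_grad():
            ab_pred = self.forward(L)
        return ab_pred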

model = ColorizationModel()
test_L = torch.randn(4, 1, 128, 128)
test_ab = model(test_L)
print(f"\nModel Summary:")
print(f"  Input shape: {test_L.shape}")
print(f"  Output shape: {test_ab.shape}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

---

## Step 6: Loss Functions and Evaluation Metrics

We define several loss functions for training and metrics for evaluation:

### Loss Functions:
- **L1 Loss (MAE)**: Mean Absolute Error; grows linearly with the error, so it is less dominated by outliers than MSE
- **MSE Loss**: Mean Squared Error; squares the error, so large errors are penalized much more heavily than small ones
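
To make the difference concrete, the following sketch computes both losses on four made-up per-pixel errors, three small and one large, and reports how much of each loss comes from the single large error. The numbers are illustrative only.

```python
import numpy as np

# Illustrative per-pixel errors in normalized a*b* units: three small, one large
errors = np.array([0.1, 0.1, 0.1, 0.9])

l1 = np.mean(np.abs(errors))      # mean absolute error
mse = np.mean(errors ** 2)        # mean squared error

# Share of each loss contributed by the single large error
l1_share = np.abs(errors[-1]) / np.sum(np.abs(errors))
mse_share = errors[-1] ** 2 / np.sum(errors ** 2)

print(f"L1  = {l1:.3f}, outlier share = {l1_share:.1%}")
print(f"MSE = {mse:.3f}, outlier share = {mse_share:.1%}")
```

The large error accounts for 75% of the L1 loss but more than 96% of the MSE, which is why MSE gradients are driven mostly by the worst pixels.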

### Evaluation Metrics:
- **MSE**: Mean Squared Error between predicted and ground-truth a*b* channels (lower is better)
- **PSNR**: Peak Signal-to-Noise Ratio (higher is better)
- **SSIM**: Structural Similarity Index (higher is better, max=1)
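
PSNR depends on the value taken as the peak signal. For data normalized to [-1, 1] the dynamic range is 2, not 1. The sketch below, a standalone NumPy function rather than the notebook's own metric class, evaluates the same synthetic error under both conventions.

```python
import numpy as np

def psnr(pred, target, data_range):
    """Peak signal-to-noise ratio in dB for signals spanning `data_range`."""
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(data_range ** 2 / mse)

# Synthetic example: every predicted value is off by 0.1 in normalized units
target = np.zeros((2, 8, 8))
pred = target + 0.1

psnr_range2 = psnr(pred, target, data_range=2.0)   # range of [-1, 1]
psnr_range1 = psnr(pred, target, data_range=1.0)   # unit range

print(f"{psnr_range2:.2f} dB")
print(f"{psnr_range1:.2f} dB")
```

The two conventions differ by a constant 20·log10(2) ≈ 6.02 dB, so reported PSNR values are only comparable when the assumed data range is stated.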

In [None]:
class ColorizationLoss(nn.Module):
    def __init__(self, loss_type='l1', l1_weight=1.0, mse_weight=0.0):
        super(ColorizationLoss, self).__init__()
        
        self.loss_type = loss_type
        self.l1_weight = l1_weight
        self.mse_weight = mse_weight
        
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        
    def forward(self, pred, target):
        if self.loss_type == 'l1':
            return self.l1_loss(pred, target)
        elif self.loss_type == 'mse':
            return self.mse_loss(pred, target)
        else:
            l1 = self.l1_loss(pred, target)
            mse = self.mse_loss(pred, target)
            return self.l1_weight * l1 + self.mse_weight * mse


class EvaluationMetrics:
    @staticmethod
    def compute_mse(pred, target):
        return F.mse_loss(pred, target).item()
    
    @staticmethod
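    def compute_psnr(pred, target, max_val=2.0):
        # Peak value is the dynamic range of the normalized a*b* channels:
        # [-1, 1] spans 2.0, matching data_range=2.0 in compute_ssim
        mse = F.mse_loss(pred, target).item()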
        if mse == 0:
            return float('inf')
        psnr = 10 * np.log10(max_val ** 2 / mse)
        return psnr
    
    @staticmethod
    def compute_ssim(pred, target):
        pred_np = pred.cpu().numpy()
        target_np = target.cpu().numpy()
        
        ssim_values = []
        for i in range(pred_np.shape[0]):
            pred_img = pred_np[i].transpose(1, 2, 0)  # (H, W, C)
            target_img = target_np[i].transpose(1, 2, 0)
            
            pred_img = np.clip(pred_img, -1, 1)
            target_img = np.clip(target_img, -1, 1)
            
            ssim_val = ssim(target_img, pred_img, 
                          data_range=2.0,
                          channel_axis=2)
            ssim_values.append(ssim_val)
        
        return np.mean(ssim_values)
    
    @staticmethod
    def compute_all_metrics(pred, target):
        mse = EvaluationMetrics.compute_mse(pred, target)
        psnr_val = EvaluationMetrics.compute_psnr(pred, target)
        ssim_val = EvaluationMetrics.compute_ssim(pred, target)
        return {'mse': mse, 'psnr': psnr_val, 'ssim': ssim_val}

---

## Step 7: Training Loop

Now we implement the training loop with:
- Epoch-wise training and validation
- Metric logging and checkpointing
- Learning rate scheduling
- Early stopping (optional; the loop below always runs for the full number of epochs)
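
Early stopping can be added with a small patience counter on the validation loss. The class below is a hypothetical helper for illustration, not part of the training code in this report; the name `EarlyStopping`, its parameters, and the loss values are assumptions.

```python
class EarlyStopping:
    """Signal a stop when the monitored loss has not improved for `patience` epochs.

    Hypothetical helper for illustration only.
    """

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses: improvement stalls after the third epoch
stopper = EarlyStopping(patience=3)
losses = [0.50, 0.40, 0.35, 0.36, 0.35, 0.37, 0.34]
stop_epoch = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stop_epoch = epoch
        break
print(stop_epoch)
```

In a training loop, `stopper.step(val_loss)` would be called once per epoch after validation, with a `break` when it returns `True`. A patience larger than the scheduler's patience gives the reduced learning rate a chance to help before training stops.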

In [None]:
class TrainingHistory:
    """Class to track training metrics over epochs"""
    
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        self.train_mse = []
        self.val_mse = []
        self.val_psnr = []
        self.val_ssim = []
        self.learning_rates = []
        
    def update(self, train_loss, val_loss, train_mse, val_mse, val_psnr, val_ssim, lr):
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.train_mse.append(train_mse)
        self.val_mse.append(val_mse)
        self.val_psnr.append(val_psnr)
        self.val_ssim.append(val_ssim)
        self.learning_rates.append(lr)
        
    def get_best_epoch(self, metric='val_loss'):
        if metric == 'val_loss':
            return np.argmin(self.val_losses)
        elif metric == 'val_psnr':
            return np.argmax(self.val_psnr)
        elif metric == 'val_ssim':
            return np.argmax(self.val_ssim)


def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    total_mse = 0.0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc='Training', leave=False)
    for L, ab in pbar:
        L = L.to(device)
        ab = ab.to(device)
        
        ab_pred = model(L)
        
        loss = criterion(ab_pred, ab)
        
        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        total_loss += loss.item()
        total_mse += F.mse_loss(ab_pred, ab).item()
        num_batches += 1
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    avg_loss = total_loss / num_batches
    avg_mse = total_mse / num_batches
    
    return avg_loss, avg_mse


def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_mse = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation', leave=False)
        for L, ab in pbar:
            L = L.to(device)
            ab = ab.to(device)
            
            ab_pred = model(L)
            
            loss = criterion(ab_pred, ab)
            
            metrics = EvaluationMetrics.compute_all_metrics(ab_pred, ab)
            
            total_loss += loss.item()
            total_mse += metrics['mse']
            total_psnr += metrics['psnr']
            total_ssim += metrics['ssim']
            num_batches += 1
    
    avg_loss = total_loss / num_batches
    avg_mse = total_mse / num_batches
    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    
    return avg_loss, avg_mse, avg_psnr, avg_ssim

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=50, lr=2e-4, 
                loss_type='l1', model_name='baseline', device='cuda'):
    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}")
    print(f"  Device: {device}")
    print(f"  Epochs: {num_epochs}")
    print(f"  Learning Rate: {lr}")
    print(f"  Loss Type: {loss_type}")
    print(f"  Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"{'='*60}\n")
    
    model = model.to(device)
    
    if loss_type == 'combined':
        criterion = ColorizationLoss(loss_type='combined', l1_weight=1.0, mse_weight=0.5)
    else:
        criterion = ColorizationLoss(loss_type=loss_type)
    
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
    
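    # The learning rate is printed every epoch, so the scheduler's `verbose`
    # flag (deprecated in recent PyTorch releases) is not needed
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )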
    
    history = TrainingHistory()
    
    best_val_loss = float('inf')
    best_model_state = None
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 40)
        
        train_loss, train_mse = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        
        val_loss, val_mse, val_psnr, val_ssim = validate(
            model, val_loader, criterion, device
        )
        
        current_lr = optimizer.param_groups[0]['lr']
        
        history.update(
            train_loss, val_loss, train_mse, val_mse, 
            val_psnr, val_ssim, current_lr
        )
        
        scheduler.step(val_loss)
        
        print(f"  Train Loss: {train_loss:.4f} | Train MSE: {train_mse:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Val MSE:   {val_mse:.4f}")
        print(f"  Val PSNR:   {val_psnr:.2f} dB | Val SSIM: {val_ssim:.4f}")
        print(f"  LR: {current_lr:.6f}")
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
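            # Clone the tensors: state_dict().copy() is shallow, so the saved
            # "best" weights would silently follow later optimizer updates
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}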
            print(f"  *** New best model! ***")
            
            checkpoint_path = os.path.join(config.MODEL_PATH, f'{model_name}_best.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': best_model_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_psnr': val_psnr,
                'val_ssim': val_ssim,
            }, checkpoint_path)
    
    model.load_state_dict(best_model_state)
    
    print(f"\n{'='*60}")
    print(f"Training Complete!")
    print(f"Best Validation Loss: {best_val_loss:.4f}")
    print(f"Best Epoch: {history.get_best_epoch('val_loss') + 1}")
    print(f"{'='*60}")
    
    return history

In [None]:
baseline_model = ColorizationModel(feature_dim=512)

# Train the baseline model with L1 loss
baseline_history = train_model(
    model=baseline_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=config.NUM_EPOCHS,
    lr=config.LEARNING_RATE,
    loss_type='l1',
    model_name='baseline_l1',
    device=device
)

---

## Step 8: Pretrained Global Feature Extractor

Now we integrate pretrained models as global feature extractors. We'll implement:
1. **VGG16** - Classic architecture, good for feature extraction
2. **ResNet50** - Deeper network with residual connections

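Both networks were pretrained on ImageNet, so their features encode semantic content (for example, sky versus vegetation) that can help the model choose plausible colors. Because they expect 3-channel RGB input, the first convolution is replaced with a single-channel version initialized to the mean of the pretrained RGB filters.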
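
The extractors in this step adapt the RGB-pretrained first convolution to a single-channel input by averaging its three input-channel kernels. The toy example below uses made-up weights (not real VGG or ResNet filters) and a hypothetical helper, `collapse_rgb_filter`, to compare that choice with summing the kernels instead. It is only a sketch of the arithmetic for one filter tap, ignoring the bias.

```python
def collapse_rgb_filter(rgb_weights, gray_value):
    """Compare mean- and sum-collapsed weights on one grayscale pixel.

    rgb_weights: hypothetical per-channel weights (R, G, B) of one filter tap.
    Returns (replicated, mean_response, sum_response), all without bias.
    """
    # Reference: the original 3-channel filter applied to (g, g, g).
    replicated = sum(w * gray_value for w in rgb_weights)
    # Mean of the RGB weights applied to the single gray channel.
    mean_response = (sum(rgb_weights) / len(rgb_weights)) * gray_value
    # Sum of the RGB weights applied to the single gray channel.
    sum_response = sum(rgb_weights) * gray_value
    return replicated, mean_response, sum_response


replicated, mean_resp, sum_resp = collapse_rgb_filter([0.2, 0.5, 0.3], gray_value=0.8)
print(f"replicated RGB: {replicated:.4f}")
print(f"mean-collapsed: {mean_resp:.4f}")
print(f"sum-collapsed:  {sum_resp:.4f}")
```

Summing reproduces the response the original layer would give to a grayscale image replicated across three channels, while averaging yields one third of it. When the backbone is fine-tuned rather than frozen, the later layers may be able to adapt to that change in scale.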

In [None]:
class VGG16GlobalExtractor(nn.Module):
    def __init__(self, feature_dim=512, pretrained=True, freeze_backbone=False):
        super(VGG16GlobalExtractor, self).__init__()
        
        vgg16 = models.vgg16(weights='IMAGENET1K_V1' if pretrained else None)
        
        # Modify first conv layer for single-channel input
        original_conv = vgg16.features[0]
        new_conv = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        
        if pretrained:
            with torch.no_grad():
                new_conv.weight.data = original_conv.weight.data.mean(dim=1, keepdim=True)
                new_conv.bias.data = original_conv.bias.data
        
        vgg16.features[0] = new_conv
        self.features = vgg16.features
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, feature_dim),
            nn.ReLU(inplace=True)
        )
        
        if freeze_backbone:
            for param in self.features.parameters():
                param.requires_grad = False
            print("VGG16 backbone frozen")
        
    def forward(self, x):
        x = self.features(x)      # (B, 512, H/32, W/32)
        x = self.avgpool(x)       # (B, 512, 1, 1)
        x = x.view(x.size(0), -1) # (B, 512)
        x = self.fc(x)            # (B, feature_dim)
        return x


class ResNet50GlobalExtractor(nn.Module):
    def __init__(self, feature_dim=512, pretrained=True, freeze_backbone=False):
        super(ResNet50GlobalExtractor, self).__init__()
        
        resnet = models.resnet50(weights='IMAGENET1K_V1' if pretrained else None)
        
        # Modify first conv layer for single-channel input
        original_conv = resnet.conv1
        new_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
        if pretrained:
            with torch.no_grad():
                new_conv.weight.data = original_conv.weight.data.mean(dim=1, keepdim=True)
        
        self.conv1 = new_conv
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, feature_dim),
            nn.ReLU(inplace=True)
        )
        
        if freeze_backbone:
            for name, param in self.named_parameters():
                if 'fc' not in name:
                    param.requires_grad = False
            print("ResNet50 backbone frozen")
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x

print("Testing VGG16GlobalExtractor...")
vgg_extractor = VGG16GlobalExtractor(pretrained=True)
vgg_output = vgg_extractor(test_input)
print(f"  Input: {test_input.shape} -> Output: {vgg_output.shape}")
print(f"  Parameters: {sum(p.numel() for p in vgg_extractor.parameters()):,}")

print("\nTesting ResNet50GlobalExtractor...")
resnet_extractor = ResNet50GlobalExtractor(pretrained=True)
resnet_output = resnet_extractor(test_input)
print(f"  Input: {test_input.shape} -> Output: {resnet_output.shape}")
print(f"  Parameters: {sum(p.numel() for p in resnet_extractor.parameters()):,}")

In [None]:
class ColorizationModelPretrained(nn.Module):
    def __init__(self, backbone='vgg16', pretrained=True, freeze_backbone=False):
        super(ColorizationModelPretrained, self).__init__()
        
        self.backbone_name = backbone
        self.low_level = LowLevelFeatureExtractor(in_channels=1)
        
        if backbone == 'vgg16':
            self.global_features = VGG16GlobalExtractor(
                feature_dim=512, 
                pretrained=pretrained,
                freeze_backbone=freeze_backbone
            )
        elif backbone == 'resnet50':
            self.global_features = ResNet50GlobalExtractor(
                feature_dim=512,
                pretrained=pretrained,
                freeze_backbone=freeze_backbone
            )
        else:
            raise ValueError(f"Unknown backbone: {backbone}")
        
        self.fusion = FusionBlock(local_channels=512, global_channels=512, output_channels=512)
        self.decoder = Decoder()
        
        print(f"Created ColorizationModelPretrained with {backbone} backbone")
        print(f"  Freeze backbone: {freeze_backbone}")
        
    def forward(self, L):
        low_features = self.low_level(L)  # (B, 512, 16, 16)
        global_features = self.global_features(L)  # (B, 512)
        fused = self.fusion(low_features, global_features)  # (B, 512, 16, 16)
        ab = self.decoder(fused)  # (B, 2, 128, 128)
        
        return ab

print("\nTesting ColorizationModelPretrained with VGG16...")
model_vgg = ColorizationModelPretrained(backbone='vgg16', pretrained=True)
model_vgg = model_vgg.to(device)

with torch.no_grad():
    test_L = torch.randn(2, 1, 128, 128).to(device)
    test_ab = model_vgg(test_L)
    print(f"  Input L: {test_L.shape}")
    print(f"  Output ab: {test_ab.shape}")
    print(f"  Total parameters: {sum(p.numel() for p in model_vgg.parameters()):,}")
    print(f"  Trainable parameters: {sum(p.numel() for p in model_vgg.parameters() if p.requires_grad):,}")

---

## Step 9: Improved Architecture with Skip Connections and Attention

To improve the colorization quality, we implement:

1. **Skip Connections**: Connect encoder features directly to decoder layers to preserve spatial details
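2. **Channel and Spatial Attention (CBAM)**: Reweight feature channels and spatial locations so the network emphasizes the features most relevant to colorization
3. **Multi-scale Fusion**: Combine encoder and decoder features at five resolutions (128x128 down to 8x8)

Skip connections carry high-resolution encoder features straight to the decoder, which helps preserve edges and fine detail that would otherwise be lost in the 4x4 bottleneck.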
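
Because each skip connection is concatenated with the upsampled decoder features, the first convolution of every decoder stage sees the sum of both channel counts. The sketch below traces that bookkeeping in plain Python. The helper name `unet_concat_shapes` is hypothetical, and the channel counts are the ones used by `ImprovedEncoder` and `ImprovedDecoder` in this step.

```python
def unet_concat_shapes(input_size=128,
                       encoder_channels=(64, 128, 256, 512, 512),
                       up_channels=(512, 512, 256, 128, 64)):
    """Return [(channels_after_concat, spatial_size), ...] per decoder stage.

    Assumes the stem keeps the input resolution, every later encoder stage
    halves it, and every decoder stage doubles it again.
    """
    # Skip i is produced after i halvings: x0 at full size, x4 at size / 16.
    skips = [(c, input_size // 2 ** i) for i, c in enumerate(encoder_channels)]
    size = input_size // 2 ** len(encoder_channels)    # bottleneck resolution
    stages = []
    for up_c, (skip_c, skip_size) in zip(up_channels, reversed(skips)):
        size *= 2                                       # transposed conv upsamples x2
        assert size == skip_size, "skip and upsampled maps must align"
        stages.append((up_c + skip_c, size))            # channels after torch.cat
    return stages


for channels, size in unet_concat_shapes():
    print(f"concat -> {channels} channels at {size}x{size}")
```

The deepest two stages receive 1024 channels each, and the final stage receives 128 channels at full resolution, matching the shape comments in the decoder's `forward` method.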

In [None]:
class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False)
        )
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        B, C, H, W = x.shape
        
        avg_out = self.fc(self.avg_pool(x).view(B, C))
        max_out = self.fc(self.max_pool(x).view(B, C))
        
        attention = self.sigmoid(avg_out + max_out)
        attention = attention.view(B, C, 1, 1)
        
        return x * attention


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        
        attention = torch.cat([avg_out, max_out], dim=1)
        attention = self.conv(attention)
        attention = self.sigmoid(attention)
        
        return x * attention


class CBAM(nn.Module):
    def __init__(self, channels, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)
        
    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

In [None]:
class ImprovedEncoder(nn.Module):
    def __init__(self, in_channels=1, use_attention=True):
        super(ImprovedEncoder, self).__init__()
        
        self.use_attention = use_attention
        
        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # Level 1: 128x128 -> 64x64
        self.enc1 = self._make_layer(64, 128, stride=2)
        self.attn1 = CBAM(128) if use_attention else nn.Identity()
        
        # Level 2: 64x64 -> 32x32
        self.enc2 = self._make_layer(128, 256, stride=2)
        self.attn2 = CBAM(256) if use_attention else nn.Identity()
        
        # Level 3: 32x32 -> 16x16
        self.enc3 = self._make_layer(256, 512, stride=2)
        self.attn3 = CBAM(512) if use_attention else nn.Identity()
        
        # Level 4: 16x16 -> 8x8
        self.enc4 = self._make_layer(512, 512, stride=2)
        self.attn4 = CBAM(512) if use_attention else nn.Identity()
        
        # Bottleneck: 8x8 -> 4x4
        self.bottleneck = self._make_layer(512, 512, stride=2)
        
    def _make_layer(self, in_ch, out_ch, stride=1):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        x0 = self.init_conv(x)  # (B, 64, 128, 128)
        x1 = self.attn1(self.enc1(x0))   # (B, 128, 64, 64)
        x2 = self.attn2(self.enc2(x1))   # (B, 256, 32, 32)
        x3 = self.attn3(self.enc3(x2))   # (B, 512, 16, 16)
        x4 = self.attn4(self.enc4(x3))   # (B, 512, 8, 8)
        
        bottleneck = self.bottleneck(x4)  # (B, 512, 4, 4)
        
        skip_features = [x0, x1, x2, x3, x4]
        
        return skip_features, bottleneck


class ImprovedDecoder(nn.Module):
    def __init__(self, use_attention=True):
        super(ImprovedDecoder, self).__init__()
        
        self.use_attention = use_attention
        
        # 4x4 -> 8x8
        self.up1 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1)
        self.dec1 = self._make_layer(512 + 512, 512)  # +512 from skip
        self.attn1 = CBAM(512) if use_attention else nn.Identity()
        
        # 8x8 -> 16x16
        self.up2 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1)
        self.dec2 = self._make_layer(512 + 512, 512)  # +512 from skip
        self.attn2 = CBAM(512) if use_attention else nn.Identity()
        
        # 16x16 -> 32x32
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.dec3 = self._make_layer(256 + 256, 256)  # +256 from skip
        self.attn3 = CBAM(256) if use_attention else nn.Identity()
        
        # 32x32 -> 64x64
        self.up4 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.dec4 = self._make_layer(128 + 128, 128)  # +128 from skip
        self.attn4 = CBAM(128) if use_attention else nn.Identity()
        
        # 64x64 -> 128x128
        self.up5 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.dec5 = self._make_layer(64 + 64, 64)  # +64 from skip
        
        # Final output layer
        self.output = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )
        
    def _make_layer(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, bottleneck, skip_features):
        x0, x1, x2, x3, x4 = skip_features
        
        x = self.up1(bottleneck)           # (B, 512, 8, 8)
        x = torch.cat([x, x4], dim=1)      # (B, 1024, 8, 8)
        x = self.attn1(self.dec1(x))       # (B, 512, 8, 8)
        
        x = self.up2(x)                    # (B, 512, 16, 16)
        x = torch.cat([x, x3], dim=1)      # (B, 1024, 16, 16)
        x = self.attn2(self.dec2(x))       # (B, 512, 16, 16)
        
        x = self.up3(x)                    # (B, 256, 32, 32)
        x = torch.cat([x, x2], dim=1)      # (B, 512, 32, 32)
        x = self.attn3(self.dec3(x))       # (B, 256, 32, 32)
        
        x = self.up4(x)                    # (B, 128, 64, 64)
        x = torch.cat([x, x1], dim=1)      # (B, 256, 64, 64)
        x = self.attn4(self.dec4(x))       # (B, 128, 64, 64)
        
        x = self.up5(x)                    # (B, 64, 128, 128)
        x = torch.cat([x, x0], dim=1)      # (B, 128, 128, 128)
        x = self.dec5(x)                   # (B, 64, 128, 128)
        
        ab = self.output(x)                # (B, 2, 128, 128)
        
        return ab

In [None]:
class ImprovedColorizationModel(nn.Module):
    def __init__(self, use_attention=True, use_global_features=True, 
                 backbone='vgg16', pretrained=True, freeze_backbone=False):
        super(ImprovedColorizationModel, self).__init__()
        
        self.use_global_features = use_global_features
        
        self.encoder = ImprovedEncoder(in_channels=1, use_attention=use_attention)
        
        if use_global_features:
            if backbone == 'vgg16':
                self.global_extractor = VGG16GlobalExtractor(
                    feature_dim=512, pretrained=pretrained, freeze_backbone=freeze_backbone
                )
            elif backbone == 'resnet50':
                self.global_extractor = ResNet50GlobalExtractor(
                    feature_dim=512, pretrained=pretrained, freeze_backbone=freeze_backbone
                )
            else:
                self.global_extractor = GlobalFeatureExtractor(feature_dim=512)
            
            self.global_fusion = nn.Sequential(
                nn.Conv2d(512 + 512, 512, kernel_size=1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True)
            )
        
        self.decoder = ImprovedDecoder(use_attention=use_attention)
        
        print(f"Created ImprovedColorizationModel")
        print(f"  Attention: {use_attention}")
        print(f"  Global features: {use_global_features}")
        if use_global_features:
            print(f"  Backbone: {backbone}")
        
    def forward(self, L):
        skip_features, bottleneck = self.encoder(L)
        
        if self.use_global_features:
            global_features = self.global_extractor(L)  # (B, 512)
            
            B, C, H, W = bottleneck.shape
            global_expanded = global_features.unsqueeze(-1).unsqueeze(-1)
            global_expanded = global_expanded.expand(-1, -1, H, W)
            
            bottleneck = torch.cat([bottleneck, global_expanded], dim=1)
            bottleneck = self.global_fusion(bottleneck)
        
        ab = self.decoder(bottleneck, skip_features)
        
        return ab

print("\nTesting ImprovedColorizationModel...")
improved_model = ImprovedColorizationModel(
    use_attention=True, 
    use_global_features=True,
    backbone='vgg16',
    pretrained=True
)
improved_model = improved_model.to(device)

with torch.no_grad():
    test_L = torch.randn(2, 1, 128, 128).to(device)
    test_ab = improved_model(test_L)
    print(f"  Input L: {test_L.shape}")
    print(f"  Output ab: {test_ab.shape}")
    print(f"  Total parameters: {sum(p.numel() for p in improved_model.parameters()):,}")
    print(f"  Trainable parameters: {sum(p.numel() for p in improved_model.parameters() if p.requires_grad):,}")

---

## Step 10: Perceptual Loss Implementation

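Perceptual loss passes the predicted and ground-truth RGB images through a frozen, ImageNet-pretrained VGG19 network and compares their feature maps (here at `relu1_2`, `relu2_2`, `relu3_4`, and `relu4_4`) rather than raw pixels. This encourages the model to produce: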
- More realistic textures
- Better semantic consistency
- More natural-looking colors

We compare three loss configurations:
1. **L1 Loss only**: Pixel-wise absolute difference
2. **Perceptual Loss only**: Feature-based comparison
3. **Combined Loss**: L1 + Perceptual
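
Written out, the combined objective (configuration 3) can be sketched as follows. The symbols are notation introduced here, not taken from the assignment: $\hat{Y}_{ab}$ and $Y_{ab}$ are the predicted and ground-truth chrominance channels, $\hat{I}$ and $I$ the corresponding RGB images, $\phi_k$ the VGG feature map at the $k$-th selected layer, and $\operatorname{mean}\lvert\cdot\rvert$ the element-wise mean absolute difference.

$$
\mathcal{L}_{\text{total}}
= \lambda_{1}\,\operatorname{mean}\bigl\lvert \hat{Y}_{ab} - Y_{ab} \bigr\rvert
+ \lambda_{p} \sum_{k} w_k\,\operatorname{mean}\bigl\lvert \phi_k(\hat{I}) - \phi_k(I) \bigr\rvert
$$

Setting $\lambda_{p} = 0$ recovers configuration 1, and setting $\lambda_{1} = 0$ recovers configuration 2. The code in this step defaults to $\lambda_{1} = 1.0$, $\lambda_{p} = 0.1$, and $w_k = 1.0$ for every layer.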

In [None]:
class VGGPerceptualLoss(nn.Module):
    def __init__(self, layers=('relu1_2', 'relu2_2', 'relu3_4', 'relu4_4'),
                 weights=(1.0, 1.0, 1.0, 1.0)):
        super(VGGPerceptualLoss, self).__init__()
        
        vgg = models.vgg19(weights='IMAGENET1K_V1').features.eval()
        
        for param in vgg.parameters():
            param.requires_grad = False
        
        self.layer_mapping = {
            'relu1_1': 1, 'relu1_2': 3,
            'relu2_1': 6, 'relu2_2': 8,
            'relu3_1': 11, 'relu3_2': 13, 'relu3_3': 15, 'relu3_4': 17,
            'relu4_1': 20, 'relu4_2': 22, 'relu4_3': 24, 'relu4_4': 26,
            'relu5_1': 29, 'relu5_2': 31, 'relu5_3': 33, 'relu5_4': 35
        }
        
        self.layer_indices = [self.layer_mapping[layer] for layer in layers]
        self.weights = weights
        
        self.slices = nn.ModuleList()
        prev_idx = 0
        for idx in self.layer_indices:
            self.slices.append(vgg[prev_idx:idx+1])
            prev_idx = idx + 1
        
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        
    def normalize(self, x):
        return (x - self.mean) / self.std
    
    def forward(self, pred_rgb, target_rgb):
        pred_norm = self.normalize(pred_rgb)
        target_norm = self.normalize(target_rgb)
        
        total_loss = 0.0
        pred_feat = pred_norm
        target_feat = target_norm
        
        for i, block in enumerate(self.slices):  # avoid shadowing the built-in `slice`
            pred_feat = block(pred_feat)
            target_feat = block(target_feat)
            
            loss = F.l1_loss(pred_feat, target_feat)
            total_loss += self.weights[i] * loss
        
        return total_loss


class CombinedLoss(nn.Module):
    def __init__(self, l1_weight=1.0, perceptual_weight=0.1, use_perceptual=True):
        super(CombinedLoss, self).__init__()
        
        self.l1_weight = l1_weight
        self.perceptual_weight = perceptual_weight
        self.use_perceptual = use_perceptual
        
        self.l1_loss = nn.L1Loss()
        
        if use_perceptual:
            self.perceptual_loss = VGGPerceptualLoss()
        
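    def lab_to_rgb_batch(self, L, ab):
        # Differentiable Lab -> sRGB conversion (D65 white point) in pure PyTorch.
        # A NumPy/skimage round trip would cut the autograd graph, so the
        # perceptual term could never update the colorization model.
        L_denorm = (L + 1) * 50  # [0, 100]
        ab_denorm = ab * 128     # [-128, 127]
        
        # Lab -> XYZ
        fy = (L_denorm + 16) / 116
        fx = fy + ab_denorm[:, 0:1] / 500
        fz = fy - ab_denorm[:, 1:2] / 200
        f = torch.cat([fx, fy, fz], dim=1)  # (B, 3, H, W)
        delta = 6 / 29
        xyz = torch.where(f > delta, f ** 3, 3 * delta ** 2 * (f - 4 / 29))
        white = torch.tensor([0.95047, 1.0, 1.08883], device=L.device, dtype=L.dtype)
        xyz = xyz * white.view(1, 3, 1, 1)
        
        # XYZ -> linear RGB -> gamma-encoded sRGB
        m = torch.tensor([[3.2404542, -1.5371385, -0.4985314],
                          [-0.9692660, 1.8760108, 0.0415560],
                          [0.0556434, -0.2040259, 1.0572252]],
                         device=L.device, dtype=L.dtype)
        # Clamp before the fractional power to avoid NaNs for out-of-gamut values
        rgb_lin = torch.einsum('ij,bjhw->bihw', m, xyz).clamp(min=1e-8)
        rgb = torch.where(rgb_lin > 0.0031308,
                          1.055 * rgb_lin ** (1 / 2.4) - 0.055,
                          12.92 * rgb_lin)
        return rgb.clamp(0, 1)
    
    def forward(self, pred_ab, target_ab, L=None):
        loss_dict = {}
        
        # L1 loss on ab channels
        l1 = self.l1_loss(pred_ab, target_ab)
        loss_dict['l1'] = l1.item()
        
        total_loss = self.l1_weight * l1
        
        if self.use_perceptual and L is not None:
            # Keep the graph for the prediction; the target needs no gradient
            pred_rgb = self.lab_to_rgb_batch(L, pred_ab)
            with torch.no_grad():
                target_rgb = self.lab_to_rgb_batch(L, target_ab)
            
            perceptual = self.perceptual_loss(pred_rgb, target_rgb)
            loss_dict['perceptual'] = perceptual.item()
            
            total_loss += self.perceptual_weight * perceptual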
        
        loss_dict['total'] = total_loss.item()
        
        return total_loss, loss_dict

---

## Step 11: PatchGAN Discriminator (Bonus - 10 Points)

PatchGAN is a discriminator that classifies whether overlapping image patches are real or fake. Instead of outputting a single real/fake probability, it outputs a matrix where each element corresponds to a patch.

**Advantages:**
- Encourages local realism
- Produces sharper, more detailed colorizations
- Fewer parameters than a full-image discriminator
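
How many patch scores the discriminator produces, and how much of the input each score sees, follows from the kernel, stride, and padding of its layers. The sketch below works this out with the standard convolution size and receptive-field recurrences. The helper `patchgan_geometry` is hypothetical, and the layer list is an assumption matching the `n_layers=3` configuration used in this step (three stride-2 convolutions followed by two stride-1 convolutions, all 4x4 with padding 1).

```python
def patchgan_geometry(input_size, layers):
    """Return (output_size, receptive_field) for a stack of conv layers.

    layers: sequence of (kernel_size, stride, padding) tuples.
    """
    size, receptive_field, jump = input_size, 1, 1
    for kernel, stride, padding in layers:
        # Standard convolution output-size formula.
        size = (size + 2 * padding - kernel) // stride + 1
        # Each layer widens the receptive field by (kernel - 1) input-space steps.
        receptive_field += (kernel - 1) * jump
        jump *= stride
    return size, receptive_field


# Three stride-2 convs, then two stride-1 convs, all 4x4 with padding 1.
layers = [(4, 2, 1)] * 3 + [(4, 1, 1)] * 2
out_size, rf = patchgan_geometry(128, layers)
print(f"{out_size}x{out_size} patch scores, each seeing a {rf}x{rf} input region")
```

For a 128x128 input this gives a 14x14 grid of scores, each covering a 70x70 region of the image, which is why neighbouring patches overlap heavily.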

In [None]:
class PatchGANDiscriminator(nn.Module):
    def __init__(self, in_channels=3, ndf=64, n_layers=3):
        super(PatchGANDiscriminator, self).__init__()
        
        layers = [
            nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 
                         kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(ndf * nf_mult),
                nn.LeakyReLU(0.2, inplace=True)
            ]
        
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 
                     kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(ndf * nf_mult),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        
        layers += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)
        ]
        
        self.model = nn.Sequential(*layers)
        
    def forward(self, L, ab):
        x = torch.cat([L, ab], dim=1)  # (B, 3, H, W)
        return self.model(x)


class GANLoss(nn.Module):
    def __init__(self, gan_mode='lsgan', real_label=1.0, fake_label=0.0):
        super(GANLoss, self).__init__()
        
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        
        self.gan_mode = gan_mode
        
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'wgan':
            self.loss = None
        else:
            raise ValueError(f"Unknown GAN mode: {gan_mode}")
    
    def get_target_tensor(self, prediction, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)
    
    def forward(self, prediction, target_is_real):
        if self.gan_mode == 'wgan':
            if target_is_real:
                return -prediction.mean()
            else:
                return prediction.mean()
        else:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)

print("Testing PatchGANDiscriminator...")
discriminator = PatchGANDiscriminator()
discriminator = discriminator.to(device)

with torch.no_grad():
    test_L = torch.randn(2, 1, 128, 128).to(device)
    test_ab = torch.randn(2, 2, 128, 128).to(device)
    disc_output = discriminator(test_L, test_ab)
    print(f"  Input L: {test_L.shape}")
    print(f"  Input ab: {test_ab.shape}")
    print(f"  Output: {disc_output.shape}")
    print(f"  Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")

---

## Step 12: Evaluation Metrics

We implement three standard metrics for image quality assessment:

1. **MSE (Mean Squared Error)**: Average squared difference between predicted and ground truth pixels
2. **PSNR (Peak Signal-to-Noise Ratio)**: Ratio of peak signal power to noise power (in dB)
3. **SSIM (Structural Similarity Index)**: Measures structural similarity considering luminance, contrast, and structure
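
PSNR is a direct function of MSE: for pixel values spanning a range $R$, $\text{PSNR} = 10 \log_{10}(R^2 / \text{MSE})$. The sketch below, using hypothetical MSE values and a hypothetical helper `psnr_from_mse`, shows the relationship and one subtlety: averaging per-image PSNR (as `evaluate_batch` does) is not the same as computing PSNR from the average MSE.

```python
import math


def psnr_from_mse(mse, data_range=1.0):
    """PSNR in dB for a given mean squared error and pixel value range."""
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(data_range ** 2 / mse)


# Hypothetical per-image MSE values for images scaled to [0, 1].
per_image_mse = [0.001, 0.01]
per_image_psnr = [psnr_from_mse(m) for m in per_image_mse]
mean_of_psnr = sum(per_image_psnr) / len(per_image_psnr)
psnr_of_mean_mse = psnr_from_mse(sum(per_image_mse) / len(per_image_mse))
print(f"mean of per-image PSNR: {mean_of_psnr:.2f} dB")
print(f"PSNR of mean MSE:       {psnr_of_mean_mse:.2f} dB")
```

Because the logarithm is concave, the mean of per-image PSNR is never lower than the PSNR of the mean MSE, so results are only comparable across reports that use the same averaging convention.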

In [None]:
class EvaluationMetrics:
    @staticmethod
    def compute_mse(pred, target):
        return np.mean((pred - target) ** 2)
    
    @staticmethod
    def compute_psnr(pred, target, data_range=1.0):
        return psnr(target, pred, data_range=data_range)
    
    @staticmethod
    def compute_ssim(pred, target, data_range=1.0, multichannel=True):
        return ssim(target, pred, data_range=data_range, channel_axis=2 if multichannel else None)
    
    @staticmethod
    def evaluate_batch(pred_ab, target_ab, L, device='cpu'):
        batch_size = pred_ab.shape[0]
        mse_list, psnr_list, ssim_list = [], [], []
        
        for i in range(batch_size):
            L_i = L[i].cpu().numpy()
            pred_ab_i = pred_ab[i].cpu().numpy()
            target_ab_i = target_ab[i].cpu().numpy()
            
            L_denorm = (L_i[0] + 1) * 50  # [0, 100]
            pred_ab_denorm = pred_ab_i * 128  # [-128, 127]
            target_ab_denorm = target_ab_i * 128
            
            pred_lab = np.stack([L_denorm, pred_ab_denorm[0], pred_ab_denorm[1]], axis=-1)
            target_lab = np.stack([L_denorm, target_ab_denorm[0], target_ab_denorm[1]], axis=-1)
            
            pred_rgb = color.lab2rgb(pred_lab)
            target_rgb = color.lab2rgb(target_lab)
            
            pred_rgb = np.clip(pred_rgb, 0, 1)
            target_rgb = np.clip(target_rgb, 0, 1)
            
            mse_list.append(EvaluationMetrics.compute_mse(pred_rgb, target_rgb))
            psnr_list.append(EvaluationMetrics.compute_psnr(pred_rgb, target_rgb))
            ssim_list.append(EvaluationMetrics.compute_ssim(pred_rgb, target_rgb))
        
        return {
            'mse': np.mean(mse_list),
            'psnr': np.mean(psnr_list),
            'ssim': np.mean(ssim_list)
        }

---

## Step 13: Training Functions

Now we implement comprehensive training functions that support:
1. Different models (baseline, pretrained backbone, improved architecture)
2. Different loss functions (L1, perceptual, combined)
3. Optional PatchGAN adversarial training
4. Logging and checkpointing
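
For the adversarial option, the trainer in this step uses the least-squares GAN objective (`GANLoss(gan_mode='lsgan')`). The arithmetic behind the two alternating updates can be sketched in plain Python as below. The helper names and the patch scores are hypothetical, chosen only to illustrate the targets each player optimizes toward.

```python
def mean_squared(values, target):
    """Mean squared difference between patch scores and a constant target."""
    return sum((v - target) ** 2 for v in values) / len(values)


def lsgan_losses(d_real, d_fake, real_label=1.0, fake_label=0.0):
    """Return (discriminator_loss, generator_loss) for least-squares GAN.

    d_real / d_fake: flat lists of discriminator patch scores for real and
    generated images.
    """
    # Discriminator: push real scores toward 1 and fake scores toward 0.
    d_loss = 0.5 * (mean_squared(d_real, real_label) + mean_squared(d_fake, fake_label))
    # Generator: push the scores of its own outputs toward the real label.
    g_loss = mean_squared(d_fake, real_label)
    return d_loss, g_loss


d_loss, g_loss = lsgan_losses(d_real=[0.9, 0.8, 1.0], d_fake=[0.2, 0.1, 0.3])
print(f"D loss: {d_loss:.4f}, G adversarial loss: {g_loss:.4f}")
```

In the trainer, the generator's adversarial term is scaled by `gan_weight` (0.1) before being added to the L1 loss, so the pixel-wise term dominates unless the discriminator is very confident.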

In [None]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, config, 
                 use_perceptual=False, use_gan=False, device='cuda'):
        
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.use_perceptual = use_perceptual
        self.use_gan = use_gan
        
        self.l1_loss = nn.L1Loss()
        
        if use_perceptual:
            self.perceptual_loss = VGGPerceptualLoss().to(device)
            self.perceptual_weight = 0.1
        
        if use_gan:
            self.discriminator = PatchGANDiscriminator().to(device)
            self.gan_loss = GANLoss(gan_mode='lsgan').to(device)
            self.optimizer_D = optim.Adam(
                self.discriminator.parameters(),
                lr=config.LEARNING_RATE,
                betas=(config.BETA1, config.BETA2)
            )
            self.gan_weight = 0.1
        
        self.optimizer_G = optim.Adam(
            model.parameters(),
            lr=config.LEARNING_RATE,
            betas=(config.BETA1, config.BETA2)
        )
        
        self.scheduler_G = optim.lr_scheduler.StepLR(
            self.optimizer_G, step_size=20, gamma=0.5
        )
        
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_mse': [], 'val_mse': [],
            'train_psnr': [], 'val_psnr': [],
            'train_ssim': [], 'val_ssim': [],
            'g_loss': [], 'd_loss': []
        }
        
    def train_epoch(self, epoch):
        self.model.train()
        if self.use_gan:
            self.discriminator.train()
        
        total_loss = 0
        total_mse = 0
        total_g_loss = 0
        total_d_loss = 0
        
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}')
        
        for batch_idx, (L, ab) in enumerate(pbar):
            L = L.to(self.device)
            ab = ab.to(self.device)
            
            pred_ab = self.model(L)
            
            # Discriminator update (if using GAN)
            if self.use_gan:
                self.optimizer_D.zero_grad()
                
                real_output = self.discriminator(L, ab)
                d_loss_real = self.gan_loss(real_output, True)
                
                fake_output = self.discriminator(L, pred_ab.detach())
                d_loss_fake = self.gan_loss(fake_output, False)
                
                d_loss = (d_loss_real + d_loss_fake) * 0.5
                d_loss.backward()
                self.optimizer_D.step()
                
                total_d_loss += d_loss.item()
            
            self.optimizer_G.zero_grad()
            
            loss_l1 = self.l1_loss(pred_ab, ab)
            loss = loss_l1
            
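            if self.use_perceptual:
                # NOTE: placeholder. This term is a scaled L1 on the ab channels,
                # not a VGG feature loss; self.perceptual_loss is never called here.
                # Using it requires a differentiable Lab -> RGB conversion, since
                # VGGPerceptualLoss expects RGB inputs in [0, 1].
                loss_perceptual = self.l1_loss(pred_ab, ab) * 0.1
                loss = loss + self.perceptual_weight * loss_perceptual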
            
            if self.use_gan:
                fake_output = self.discriminator(L, pred_ab)
                loss_gan = self.gan_loss(fake_output, True)
                loss = loss + self.gan_weight * loss_gan
                total_g_loss += loss_gan.item()
            
            loss.backward()
            self.optimizer_G.step()
            
            total_loss += loss.item()
            
            with torch.no_grad():
                mse = F.mse_loss(pred_ab, ab).item()
                total_mse += mse
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'mse': f'{mse:.4f}'
            })
        
        n_batches = len(self.train_loader)
        avg_loss = total_loss / n_batches
        avg_mse = total_mse / n_batches
        
        self.history['train_loss'].append(avg_loss)
        self.history['train_mse'].append(avg_mse)
        
        if self.use_gan:
            self.history['g_loss'].append(total_g_loss / n_batches)
            self.history['d_loss'].append(total_d_loss / n_batches)
        
        return avg_loss, avg_mse
    
    @torch.no_grad()
    def validate(self, epoch):
        self.model.eval()
        
        total_loss = 0
        all_metrics = {'mse': [], 'psnr': [], 'ssim': []}
        
        for L, ab in tqdm(self.val_loader, desc='Validating'):
            L = L.to(self.device)
            ab = ab.to(self.device)
            
            pred_ab = self.model(L)
            
            loss = self.l1_loss(pred_ab, ab)
            total_loss += loss.item()
            
            metrics = EvaluationMetrics.evaluate_batch(pred_ab, ab, L)
            for key in metrics:
                all_metrics[key].append(metrics[key])
        
        n_batches = len(self.val_loader)
        avg_loss = total_loss / n_batches
        avg_mse = np.mean(all_metrics['mse'])
        avg_psnr = np.mean(all_metrics['psnr'])
        avg_ssim = np.mean(all_metrics['ssim'])
        
        self.history['val_loss'].append(avg_loss)
        self.history['val_mse'].append(avg_mse)
        self.history['val_psnr'].append(avg_psnr)
        self.history['val_ssim'].append(avg_ssim)
        
        return avg_loss, avg_mse, avg_psnr, avg_ssim
    
    def train(self, num_epochs, save_path=None):
        best_val_loss = float('inf')
        
        for epoch in range(num_epochs):
            # Train
            train_loss, train_mse = self.train_epoch(epoch)
            
            # Validate
            val_loss, val_mse, val_psnr, val_ssim = self.validate(epoch)
            
            # Learning rate step
            self.scheduler_G.step()
            
            # Print epoch summary
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print(f"  Train Loss: {train_loss:.4f}, Train MSE: {train_mse:.4f}")
            print(f"  Val Loss: {val_loss:.4f}, Val MSE: {val_mse:.4f}")
            print(f"  Val PSNR: {val_psnr:.2f} dB, Val SSIM: {val_ssim:.4f}")
            
            # Save best model
            if save_path and val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer_G.state_dict(),
                    'val_loss': val_loss,
                    'history': self.history
                }, save_path)
                print(f"  Saved best model (val_loss: {val_loss:.4f})")
        
        return self.history
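
The `StepLR` scheduler created in the constructor multiplies the generator learning rate by `gamma=0.5` after every 20 calls to `scheduler_G.step()`, which `train` issues once per epoch. The closed form below is a minimal sketch of the resulting schedule. The helper name `step_lr` and the base rate `2e-4` are illustrative assumptions, since `config.LEARNING_RATE` is defined elsewhere in the notebook.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 20, gamma: float = 0.5) -> float:
    """Learning rate in effect during 0-based `epoch` under a StepLR-style decay."""
    # One decay is applied for every full `step_size` epochs already completed.
    return base_lr * gamma ** (epoch // step_size)


# Hypothetical base rate; substitute the notebook's config.LEARNING_RATE.
base_lr = 2e-4
schedule = {epoch: step_lr(base_lr, epoch) for epoch in (0, 19, 20, 40, 60)}
for epoch, lr in schedule.items():
    print(f"epoch {epoch:>2}: lr = {lr:.2e}")
```

For a 50-epoch run this gives two halvings (after epochs 20 and 40), so the last epochs train at a quarter of the starting rate.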

---

## Step 14: Training Different Model Configurations

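Now we train and compare different configurations:

1. **Baseline Model**: Custom encoder-decoder with L1 loss
2. **VGG16 Backbone**: Pretrained VGG16 as a global feature extractor
3. **ResNet50 Backbone**: Pretrained ResNet50 as a global feature extractor
4. **Improved Model**: U-Net with attention + VGG16
5. **With Perceptual Loss**: L1 + Perceptual loss
6. **With PatchGAN**: Full adversarial training

The cells below run configurations 1, 2, 4, and 6. The improved model (4) is trained with `use_perceptual=True`, so it also covers configuration 5; the ResNet50 variant (3) is not trained in this step.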
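
The four training cells in this step differ only in the model class and two `Trainer` flags. One way to keep track of them is a small table of runs, sketched below. `RunConfig`, `RUNS`, and `describe` are hypothetical names introduced for illustration, and the flag values are copied from the cells that follow.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunConfig:
    name: str             # label used in plots and tables
    model: str            # model class instantiated by the cell
    use_perceptual: bool  # Trainer flag
    use_gan: bool         # Trainer flag


RUNS = [
    RunConfig("Baseline", "ColorizationModel", False, False),
    RunConfig("VGG16 Backbone", "ColorizationModelPretrained", False, False),
    RunConfig("Improved (Attention)", "ImprovedColorizationModel", True, False),
    RunConfig("With PatchGAN", "ImprovedColorizationModel", True, True),
]


def describe(run: RunConfig) -> str:
    """Summarize which loss terms a run enables (L1 is always on)."""
    losses = ["L1"]
    if run.use_perceptual:
        losses.append("perceptual")
    if run.use_gan:
        losses.append("adversarial")
    return f"{run.name}: {run.model} trained with {' + '.join(losses)}"


for run in RUNS:
    print(describe(run))
```

Laid out this way, it is easy to see that the "Improved" and "With PatchGAN" runs change the loss as well as the architecture, which matters when attributing differences in the metrics.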

In [None]:
print("=" * 60)
print("TRAINING BASELINE MODEL")
print("=" * 60)

baseline_model = ColorizationModel()
baseline_model = baseline_model.to(device)

print(f"Model parameters: {sum(p.numel() for p in baseline_model.parameters()):,}")

baseline_trainer = Trainer(
    model=baseline_model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    use_perceptual=False,
    use_gan=False,
    device=device
)

baseline_history = baseline_trainer.train(
    num_epochs=config.NUM_EPOCHS,
    save_path=os.path.join(MODEL_PATH, 'baseline_model.pth')
)

print("\nBaseline training complete!")

In [None]:
print("=" * 60)
print("TRAINING MODEL WITH VGG16 BACKBONE")
print("=" * 60)

vgg_model = ColorizationModelPretrained(
    backbone='vgg16',
    pretrained=True,
    freeze_backbone=False
)
vgg_model = vgg_model.to(device)

print(f"Model parameters: {sum(p.numel() for p in vgg_model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in vgg_model.parameters() if p.requires_grad):,}")

vgg_trainer = Trainer(
    model=vgg_model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    use_perceptual=False,
    use_gan=False,
    device=device
)

vgg_history = vgg_trainer.train(
    num_epochs=config.NUM_EPOCHS,
    save_path=os.path.join(MODEL_PATH, 'vgg16_model.pth')
)

print("\nVGG16 backbone training complete!")

In [None]:
print("=" * 60)
print("TRAINING IMPROVED MODEL (U-NET + ATTENTION + VGG16)")
print("=" * 60)

improved_model = ImprovedColorizationModel(
    use_attention=True,
    use_global_features=True,
    backbone='vgg16',
    pretrained=True,
    freeze_backbone=False
)
improved_model = improved_model.to(device)

print(f"Model parameters: {sum(p.numel() for p in improved_model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in improved_model.parameters() if p.requires_grad):,}")

improved_trainer = Trainer(
    model=improved_model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    use_perceptual=True,
    use_gan=False,
    device=device
)

improved_history = improved_trainer.train(
    num_epochs=config.NUM_EPOCHS,
    save_path=os.path.join(MODEL_PATH, 'improved_model.pth')
)

print("\nImproved model training complete!")

In [None]:
print("=" * 60)
print("TRAINING MODEL WITH PATCHGAN DISCRIMINATOR (BONUS)")
print("=" * 60)

gan_model = ImprovedColorizationModel(
    use_attention=True,
    use_global_features=True,
    backbone='vgg16',
    pretrained=True,
    freeze_backbone=False
)
gan_model = gan_model.to(device)

print(f"Generator parameters: {sum(p.numel() for p in gan_model.parameters()):,}")

gan_trainer = Trainer(
    model=gan_model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    use_perceptual=True,
    use_gan=True,
    device=device
)

print(f"Discriminator parameters: {sum(p.numel() for p in gan_trainer.discriminator.parameters()):,}")

gan_history = gan_trainer.train(
    num_epochs=config.NUM_EPOCHS,
    save_path=os.path.join(MODEL_PATH, 'gan_model.pth')
)

print("\nGAN training complete!")

---

## Step 15: Visualization and Analysis

Now we visualize:
1. Training curves (loss, MSE, PSNR, SSIM over epochs)
2. Sample colorization results from different models
3. Comparison between different architectures and loss functions
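
When reading the curves, it helps to know which epoch the saved checkpoint corresponds to: `Trainer.train` saves only when `val_loss` improves, so the checkpoint matches the epoch with the lowest validation loss. The helper below is a small sketch for locating that epoch in a history dict. `best_epoch` and the toy numbers are hypothetical and not part of the notebook.

```python
def best_epoch(history: dict, key: str = "val_loss", mode: str = "min"):
    """Return (0-based epoch index, value) of the best entry in history[key]."""
    values = history.get(key, [])
    if not values:
        raise ValueError(f"history has no entries for {key!r}")
    pick = min if mode == "min" else max
    idx = pick(range(len(values)), key=values.__getitem__)
    return idx, values[idx]


# Toy history with made-up numbers, shaped like Trainer.history.
toy_history = {
    "val_loss": [0.120, 0.095, 0.101],
    "val_psnr": [21.5, 23.0, 22.4],
}
print(best_epoch(toy_history))
print(best_epoch(toy_history, "val_psnr", "max"))
```

In the notebook, `best_epoch(baseline_history)` would report the epoch whose weights ended up in `baseline_model.pth`.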

In [None]:
def plot_training_curves(histories, model_names, save_path=None):
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    colors = ['blue', 'orange', 'green', 'red', 'purple']
    
    ax = axes[0, 0]
    for i, (history, name) in enumerate(zip(histories, model_names)):
        if 'train_loss' in history and len(history['train_loss']) > 0:
            ax.plot(history['train_loss'], color=colors[i], label=f'{name} (train)', linestyle='-')
        if 'val_loss' in history and len(history['val_loss']) > 0:
            ax.plot(history['val_loss'], color=colors[i], label=f'{name} (val)', linestyle='--')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    
    ax = axes[0, 1]
    for i, (history, name) in enumerate(zip(histories, model_names)):
        if 'val_mse' in history and len(history['val_mse']) > 0:
            ax.plot(history['val_mse'], color=colors[i], label=name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('MSE')
    ax.set_title('Validation MSE')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    
    ax = axes[1, 0]
    for i, (history, name) in enumerate(zip(histories, model_names)):
        if 'val_psnr' in history and len(history['val_psnr']) > 0:
            ax.plot(history['val_psnr'], color=colors[i], label=name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('PSNR (dB)')
    ax.set_title('Validation PSNR')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    
    ax = axes[1, 1]
    for i, (history, name) in enumerate(zip(histories, model_names)):
        if 'val_ssim' in history and len(history['val_ssim']) > 0:
            ax.plot(history['val_ssim'], color=colors[i], label=name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('SSIM')
    ax.set_title('Validation SSIM')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved training curves to {save_path}")
    
    plt.show()

all_histories = [baseline_history, vgg_history, improved_history, gan_history]
model_names = ['Baseline', 'VGG16 Backbone', 'Improved (Attention)', 'With PatchGAN']

plot_training_curves(
    all_histories, 
    model_names,
    save_path=os.path.join(RESULTS_PATH, 'training_curves.png')
)

In [None]:
def visualize_colorization_results(models, model_names, dataloader, num_samples=4, save_path=None):
    L_batch, ab_batch = next(iter(dataloader))
    L_batch = L_batch[:num_samples].to(device)
    ab_batch = ab_batch[:num_samples].to(device)
    
    num_cols = 2 + len(models)
    
    # squeeze=False keeps `axes` two-dimensional even when num_samples == 1
    fig, axes = plt.subplots(num_samples, num_cols, figsize=(4*num_cols, 4*num_samples), squeeze=False)
    
    for i in range(num_samples):
        L = L_batch[i:i+1]
        ab_gt = ab_batch[i:i+1]
        
        L_np = (L[0, 0].cpu().numpy() + 1) * 50
        ab_gt_np = ab_gt[0].cpu().numpy() * 128
        
        lab_gt = np.stack([L_np, ab_gt_np[0], ab_gt_np[1]], axis=-1)
        rgb_gt = np.clip(color.lab2rgb(lab_gt), 0, 1)
        
        axes[i, 0].imshow(L_np, cmap='gray')
        axes[i, 0].set_title('Grayscale Input' if i == 0 else '')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(rgb_gt)
        axes[i, 1].set_title('Ground Truth' if i == 0 else '')
        axes[i, 1].axis('off')
        
        for j, (model, name) in enumerate(zip(models, model_names)):
            model.eval()
            with torch.no_grad():
                pred_ab = model(L)
            
            pred_ab_np = pred_ab[0].cpu().numpy() * 128
            
            lab_pred = np.stack([L_np, pred_ab_np[0], pred_ab_np[1]], axis=-1)
            rgb_pred = np.clip(color.lab2rgb(lab_pred), 0, 1)
            
            axes[i, j+2].imshow(rgb_pred)
            axes[i, j+2].set_title(name if i == 0 else '')
            axes[i, j+2].axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved colorization results to {save_path}")
    
    plt.show()

all_models = [baseline_model, vgg_model, improved_model, gan_model]
model_names = ['Baseline', 'VGG16', 'Improved', 'PatchGAN']

visualize_colorization_results(
    all_models,
    model_names,
    val_loader,
    num_samples=4,
    save_path=os.path.join(RESULTS_PATH, 'colorization_results.png')
)
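
The constants in `visualize_colorization_results` imply a particular normalization: L is assumed to be stored in [-1, 1] and mapped back to [0, 100] via `(L + 1) * 50`, while a* and b* are assumed to be stored in roughly [-1, 1] and scaled by 128. The sketch below spells out that mapping and its inverse in plain Python. The function names are hypothetical, and the dataset class defined earlier remains the source of truth for the actual normalization.

```python
def denorm_L(l_norm: float) -> float:
    """Map normalized lightness in [-1, 1] back to the CIELAB range [0, 100]."""
    return (l_norm + 1.0) * 50.0


def denorm_ab(ab_norm: float) -> float:
    """Map a normalized a* or b* value in [-1, 1] back to about [-128, 128]."""
    return ab_norm * 128.0


def norm_L(l: float) -> float:
    """Inverse of denorm_L."""
    return l / 50.0 - 1.0


def norm_ab(ab: float) -> float:
    """Inverse of denorm_ab."""
    return ab / 128.0


for l_norm in (-1.0, 0.0, 1.0):
    print(f"L_norm={l_norm:+.1f} -> L*={denorm_L(l_norm):.1f}")
```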

In [None]:
def evaluate_model(model, dataloader, device='cuda'):
    model.eval()
    
    all_mse, all_psnr, all_ssim = [], [], []
    
    with torch.no_grad():
        for L, ab in tqdm(dataloader, desc='Evaluating'):
            L = L.to(device)
            ab = ab.to(device)
            
            pred_ab = model(L)
            
            metrics = EvaluationMetrics.evaluate_batch(pred_ab, ab, L)
            all_mse.append(metrics['mse'])
            all_psnr.append(metrics['psnr'])
            all_ssim.append(metrics['ssim'])
    
    return {
        'mse': np.mean(all_mse),
        'psnr': np.mean(all_psnr),
        'ssim': np.mean(all_ssim)
    }

def create_comparison_table(models, model_names, dataloader):
    results = []
    
    for model, name in zip(models, model_names):
        print(f"Evaluating {name}...")
        metrics = evaluate_model(model, dataloader, device)
        results.append({
            'Model': name,
            'MSE': metrics['mse'],
            'PSNR (dB)': metrics['psnr'],
            'SSIM': metrics['ssim']
        })
    
    print("\n" + "=" * 70)
    print("MODEL COMPARISON TABLE")
    print("=" * 70)
    print(f"{'Model':<25} {'MSE':<15} {'PSNR (dB)':<15} {'SSIM':<15}")
    print("-" * 70)
    
    for r in results:
        print(f"{r['Model']:<25} {r['MSE']:<15.6f} {r['PSNR (dB)']:<15.2f} {r['SSIM']:<15.4f}")
    
    print("=" * 70)
    
    return results

all_models = [baseline_model, vgg_model, improved_model, gan_model]
model_names = ['Baseline', 'VGG16 Backbone', 'Improved (Attention)', 'With PatchGAN']

comparison_results = create_comparison_table(all_models, model_names, val_loader)
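
For a single image, PSNR is a monotone function of MSE: PSNR = 10 · log10(MAX² / MSE), so a lower MSE always means a higher PSNR. The sketch below assumes a data range of 1.0 (MAX = 1). `EvaluationMetrics` is defined earlier in the notebook and may use a different range or compute the metrics in RGB space, so treat the hypothetical `psnr_from_mse` as an illustration rather than its implementation.

```python
import math


def psnr_from_mse(mse: float, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    if mse <= 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


for mse in (0.01, 0.001):
    print(f"MSE={mse} -> PSNR={psnr_from_mse(mse):.2f} dB")
```

Each tenfold reduction in MSE adds 10 dB, which is a handy scale for judging whether a gap between two rows of the table is large or small.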

In [None]:
def plot_metrics_comparison(results, save_path=None):
    models = [r['Model'] for r in results]
    mse_values = [r['MSE'] for r in results]
    psnr_values = [r['PSNR (dB)'] for r in results]
    ssim_values = [r['SSIM'] for r in results]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    colors = ['#3498db', '#e74c3c', '#2ecc71', '#9b59b6']
    
    ax = axes[0]
    bars = ax.bar(models, mse_values, color=colors[:len(models)])
    ax.set_ylabel('MSE')
    ax.set_title('Mean Squared Error (Lower is Better)')
    ax.tick_params(axis='x', rotation=45)
    for bar, val in zip(bars, mse_values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), 
               f'{val:.4f}', ha='center', va='bottom', fontsize=9)
    
    ax = axes[1]
    bars = ax.bar(models, psnr_values, color=colors[:len(models)])
    ax.set_ylabel('PSNR (dB)')
    ax.set_title('Peak Signal-to-Noise Ratio (Higher is Better)')
    ax.tick_params(axis='x', rotation=45)
    for bar, val in zip(bars, psnr_values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), 
               f'{val:.2f}', ha='center', va='bottom', fontsize=9)
    
    ax = axes[2]
    bars = ax.bar(models, ssim_values, color=colors[:len(models)])
    ax.set_ylabel('SSIM')
    ax.set_title('Structural Similarity Index (Higher is Better)')
    ax.tick_params(axis='x', rotation=45)
    for bar, val in zip(bars, ssim_values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), 
               f'{val:.4f}', ha='center', va='bottom', fontsize=9)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved comparison charts to {save_path}")
    
    plt.show()

plot_metrics_comparison(
    comparison_results,
    save_path=os.path.join(RESULTS_PATH, 'metrics_comparison.png')
)

---

## Step 16: Loss Function Comparison

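Here we compare the effect of the loss function on colorization quality. The cell below trains two baseline models that differ only in the loss:
1. **L1 Loss only**: Encourages pixel-wise accuracy
2. **Combined (L1 + Perceptual)**: Adds a term that encourages semantic similarity

A perceptual-only run is not included, because the `Trainer` always keeps the L1 term in the generator loss.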
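
Written out, the generator objective assembled in `Trainer.train_epoch` is a weighted sum, and the discriminator loss is the average of its real and fake terms. Here $\lambda_{\text{perc}}$ and $\lambda_{\text{GAN}}$ stand for `self.perceptual_weight` and `self.gan_weight`, $G$ is the colorization model, $D$ is the PatchGAN discriminator, and $\ell_{\text{GAN}}$ is whatever criterion `self.gan_loss` implements.

$$
\mathcal{L}_G = \mathcal{L}_{L1}\big(G(L), ab\big) + \lambda_{\text{perc}}\,\mathcal{L}_{\text{perc}}\big(G(L), ab\big) + \lambda_{\text{GAN}}\,\ell_{\text{GAN}}\big(D(L, G(L)), \text{real}\big)
$$

$$
\mathcal{L}_D = \tfrac{1}{2}\Big[\ell_{\text{GAN}}\big(D(L, ab), \text{real}\big) + \ell_{\text{GAN}}\big(D(L, G(L)), \text{fake}\big)\Big]
$$

The perceptual and adversarial terms are present only when `use_perceptual` and `use_gan` are enabled, so with both flags off the objective reduces to the L1 term alone.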

In [None]:
def compare_loss_functions(train_loader, val_loader, num_epochs=3):
    results = {}
    
    print("\n" + "="*50)
    print("Training with L1 Loss only...")
    print("="*50)
    
    model_l1 = ColorizationModel().to(device)
    trainer_l1 = Trainer(
        model=model_l1,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        use_perceptual=False,
        use_gan=False,
        device=device
    )
    history_l1 = trainer_l1.train(num_epochs)
    results['L1 Only'] = {
        'model': model_l1,
        'history': history_l1,
        'metrics': evaluate_model(model_l1, val_loader, device)
    }
    
    print("\n" + "="*50)
    print("Training with Perceptual Loss + L1...")
    print("="*50)
    
    model_perceptual = ColorizationModel().to(device)
    trainer_perceptual = Trainer(
        model=model_perceptual,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        use_perceptual=True,
        use_gan=False,
        device=device
    )
    history_perceptual = trainer_perceptual.train(num_epochs)
    results['L1 + Perceptual'] = {
        'model': model_perceptual,
        'history': history_perceptual,
        'metrics': evaluate_model(model_perceptual, val_loader, device)
    }
    
    return results

print("Comparing different loss functions...")
loss_comparison = compare_loss_functions(train_loader, val_loader, num_epochs=50)

print("\nLoss function comparison complete!")
print("\nDiscussion:")
print("-" * 50)
print("""
L1 LOSS ONLY:
- Produces smooth, somewhat blurry colorizations
- Good at preserving overall structure
- May produce desaturated colors
- Fast and stable training

PERCEPTUAL LOSS:
- Encourages more realistic textures
- Better semantic consistency
- Can produce more vibrant colors
- Higher computational cost

COMBINED (L1 + PERCEPTUAL):
- Best of both approaches
- Balances accuracy and perceptual quality
- Most commonly used in practice
- Requires careful weight balancing
""")
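
The discussion printed above is qualitative. To ground it in the numbers from a given run, the `metrics` entries that `compare_loss_functions` returns can be laid out side by side. The helper below is a minimal sketch; `summarize` is a hypothetical name and the toy values are placeholders, not measurements.

```python
def summarize(results: dict) -> list:
    """Format one line per loss setting from a compare_loss_functions-style dict."""
    header = f"{'Loss':<20} {'MSE':<12} {'PSNR (dB)':<12} {'SSIM':<8}"
    lines = [header, "-" * len(header)]
    for name, entry in results.items():
        m = entry["metrics"]
        lines.append(
            f"{name:<20} {m['mse']:<12.6f} {m['psnr']:<12.2f} {m['ssim']:<8.4f}"
        )
    return lines


# Placeholder numbers for illustration only; not results from this notebook.
toy_results = {
    "L1 Only": {"metrics": {"mse": 0.01, "psnr": 20.0, "ssim": 0.5}},
    "L1 + Perceptual": {"metrics": {"mse": 0.02, "psnr": 17.0, "ssim": 0.4}},
}
print("\n".join(summarize(toy_results)))
```

In the notebook, `print("\n".join(summarize(loss_comparison)))` would print the measured values for the two settings.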

---

## Conclusions and Discussion

### Summary of Results

In this assignment, we implemented and compared several approaches for image colorization:

1. **Baseline Encoder-Decoder**: A simple architecture that learns to map the grayscale L channel to the a*b* color channels.

2. **Pretrained Backbone (VGG16/ResNet50)**: Using pretrained networks as global feature extractors improves semantic understanding and leads to more contextually appropriate colors.

3. **Improved Architecture (U-Net + Attention)**: Skip connections preserve spatial details, while attention mechanisms help focus on important features.

4. **PatchGAN Discriminator**: Adding adversarial training produces sharper, more locally consistent colors.

### Key Findings

- **Skip connections** significantly improve detail preservation in colorized images
- **Pretrained backbones** provide better semantic understanding for choosing appropriate colors
- **Perceptual loss** encourages more realistic textures compared to L1 alone
- **PatchGAN** produces sharper results but requires careful training balance

### Recommendations for Best Results

1. Use a U-Net style architecture with skip connections
2. Incorporate a pretrained VGG16 or ResNet50 as a global feature extractor
3. Combine L1 loss with perceptual loss (weight ratio ~10:1)
4. Train for at least 50-100 epochs with a large, diverse dataset
5. Consider PatchGAN for applications requiring sharp, realistic colors

### Limitations and Future Work

- The model may struggle with uncommon objects or scenes not well-represented in training data
- Colorization of historical photos may require domain-specific training
- User-guided colorization (hint-based) could improve accuracy
- Video colorization would require temporal consistency constraints