# Image Colorization Using Deep Convolutional Neural Networks
## BBM 418 - Assignment 3

**Student Name:** Ahmet Oğuzhan Kökülü  
**Student Number:** b2220356053

**Demo Video Link:** (to be added later)

**Dataset Drive Link:** https://drive.google.com/drive/folders/12lfa_UkMO9aBJGw989s1h1UDF-NbrDa9?usp=sharing

---

## Overview
This notebook implements an image colorization system using Deep Convolutional Neural Networks (DCNNs) with an encoder-decoder architecture. The model takes grayscale images (L channel from L*a*b* color space) as input and predicts the corresponding color channels (a* and b*).

## 1. Import Required Libraries

In [1]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from skimage import color, io
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()
print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

GPU memory allocated: 0.00 GB


## 2. Introduction

Image colorization is the process of adding realistic colors to grayscale images.

In this assignment, we use an encoder-decoder architecture with Deep Convolutional Neural Networks (DCNNs) to learn the mapping between grayscale images and their colorized versions.

### L*a*b* Color Space

We use the L*a*b* color space (CIELAB) which has three components:
- **L***: Lightness (0 for black to 100 for white)
- **a***: Green-to-red spectrum
- **b***: Blue-to-yellow spectrum

This color space suits colorization because it separates lightness (L*) from color information (a*, b*): the grayscale input already fixes L*, so the network only has to predict the two color channels. We use:
- **Input**: L channel (grayscale image)
- **Target**: a and b channels (color information)
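To make that separation concrete, the sketch below converts single sRGB pixels to L*a*b* in plain Python. It assumes the D65 reference white (the default of `skimage.color.rgb2lab`, which the dataset class uses) and the commonly published sRGB constants, so values may differ from skimage's in the last decimals. The helper name `srgb_to_lab` is illustrative only.

```python
def srgb_to_lab(r, g, b):
    """Convert one sRGB pixel (components in [0, 1]) to CIE L*a*b* (D65 white point)."""
    # 1. Undo the sRGB gamma to get linear RGB
    def linearize(c):
        return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4
    rl, gl, bl = linearize(r), linearize(g), linearize(b)

    # 2. Linear RGB -> CIE XYZ (standard sRGB matrix)
    x = 0.4124 * rl + 0.3576 * gl + 0.1805 * bl
    y = 0.2126 * rl + 0.7152 * gl + 0.0722 * bl
    z = 0.0193 * rl + 0.1192 * gl + 0.9505 * bl

    # 3. XYZ -> L*a*b*, relative to the D65 reference white
    xn, yn, zn = 0.95047, 1.00000, 1.08883
    delta = 6 / 29
    def f(t):
        return t ** (1 / 3) if t > delta ** 3 else t / (3 * delta ** 2) + 4 / 29
    fx, fy, fz = f(x / xn), f(y / yn), f(z / zn)
    L = 116 * fy - 16          # lightness: depends on Y (luminance) only
    a = 500 * (fx - fy)        # green (-) to red (+)
    b_star = 200 * (fy - fz)   # blue (-) to yellow (+)
    return L, a, b_star

for name, rgb in [("black", (0, 0, 0)), ("gray", (0.5, 0.5, 0.5)),
                  ("white", (1, 1, 1)), ("red", (1, 0, 0))]:
    L, a, b = srgb_to_lab(*rgb)
    print(f"{name:5s}  L*={L:6.2f}  a*={a:7.2f}  b*={b:7.2f}")
```

Neutral colors (black, gray, white) come out with a* ≈ b* ≈ 0 and differ only in L*, while pure red gets large positive a* and b*. Since L* depends only on luminance, the grayscale input determines it completely and all remaining color information sits in a* and b*.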

## 3. Dataset Preparation

### 3.1 Configuration and Paths

In [2]:
# Local dataset configuration
DATASET_PATH = './dataset'  # Local path to dataset folder
OUTPUT_PATH = './outputs'  # Folder for saving outputs (images, plots)
MODEL_PATH = './models'    # Folder for saving trained models

# Create directories if they don't exist
import os
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(MODEL_PATH, exist_ok=True)

# Image configuration
IMG_SIZE = 256  # Architecture is hardcoded for 256x256
BATCH_SIZE = 1  # Larger batches ran out of memory on the 6 GB GPU
NUM_WORKERS = 2

# Training configuration
NUM_EPOCHS = 30
LEARNING_RATE = 0.001
TRAIN_SPLIT = 0.8

print(f"Model path: {MODEL_PATH}")
print(f"Image size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Training split: {TRAIN_SPLIT * 100}%")

Model path: ./models
Image size: 256x256
Batch size: 1
Training split: 80.0%


### 3.2 Prepare Dataset Split

This function shuffles the image paths with a fixed seed and splits them into training and test lists.

In [3]:
def prepare_dataset_split(color_images_dir, train_split=0.8):
    """
    Prepare train/test split from color images in local directory.
    Returns lists of image file paths.
    
    Args:
        color_images_dir: Directory containing color images locally
        train_split: Ratio of training images
    
    Returns:
        train_files: List of training image paths
        test_files: List of test image paths
    """
    print(f"Scanning images from: {color_images_dir}")
    
    # Get all image files
    image_files = []
    for ext in ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']:
        image_files.extend(list(Path(color_images_dir).glob(ext)))
    
    print(f"Found {len(image_files)} color images")
    
    if len(image_files) == 0:
        raise ValueError(f"No images found in {color_images_dir}. Please check the path.")
    
    # Shuffle with fixed seed for reproducibility
    random.seed(SEED)
    random.shuffle(image_files)
    
    # Split into train and test
    split_idx = int(len(image_files) * train_split)
    train_files = image_files[:split_idx]
    test_files = image_files[split_idx:]
    
    print(f"\nDataset split:")
    print(f"  Training images: {len(train_files)}")
    print(f"  Test images: {len(test_files)}")
    print(f"  Train/Test ratio: {train_split:.1f}/{1 - train_split:.1f}")
    
    return train_files, test_files

# Prepare the dataset split from local folder
train_image_files, test_image_files = prepare_dataset_split(DATASET_PATH, TRAIN_SPLIT)

Scanning images from: ./dataset
Found 5000 color images

Dataset split:
  Training images: 4000
  Test images: 1000
  Train/Test ratio: 0.8/0.2
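A variant of the split that leaves the global `random` state alone is sketched below. It uses a private `random.Random(seed)` instance instead of calling `random.seed(SEED)`, and it sorts the paths first, because `Path.glob` yields files in filesystem order, which can differ between machines. The function name and the file names are hypothetical placeholders.

```python
import random

def split_paths(paths, train_split=0.8, seed=42):
    """Sort, shuffle with a private RNG, and split paths into (train, test) lists."""
    # Sorting first makes the result independent of the order Path.glob returns files in
    shuffled = sorted(paths)
    # A private Random instance leaves the module-level random state untouched
    random.Random(seed).shuffle(shuffled)
    split_idx = int(len(shuffled) * train_split)
    return shuffled[:split_idx], shuffled[split_idx:]

# Hypothetical file names standing in for the 5000 dataset images
files = [f"img_{i:04d}.jpg" for i in range(5000)]
train, test = split_paths(files)
print(len(train), len(test))   # 4000 1000
```

Calling it again, or with the input list in a different order, returns the same two lists.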


### 3.3 Dataset Class

In [4]:
class ColorizationDataset(Dataset):
    """
    Dataset class for image colorization.
    Loads color images from the local dataset folder and converts them to L*a*b* on the fly.
    Returns L channel as input and ab channels as target.
    """
    def __init__(self, image_files, img_size=256, transform=None):
        """
        Args:
            image_files: List of Path objects pointing to local color images
            img_size: Size to resize images to
            transform: Optional transforms to apply
        """
        self.image_files = image_files
        self.img_size = img_size
        self.transform = transform
        
        print(f"Initialized dataset with {len(self.image_files)} images")
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        # Load the color image from disk
        img_path = str(self.image_files[idx])
        
        try:
            # Read RGB image
            img_rgb = io.imread(img_path)
            
            # Convert to RGB if grayscale or RGBA
            if len(img_rgb.shape) == 2:
                img_rgb = np.stack([img_rgb] * 3, axis=-1)
            elif img_rgb.shape[2] == 4:
                img_rgb = img_rgb[:, :, :3]
            
            # Resize
            img_pil = Image.fromarray(img_rgb)
            img_pil = img_pil.resize((self.img_size, self.img_size), Image.LANCZOS)
            img_rgb = np.array(img_pil)
            
            # Convert RGB to L*a*b* color space
            img_lab = color.rgb2lab(img_rgb).astype(np.float32)
            
            # Split into L and ab channels
            L = img_lab[:, :, 0]  # Lightness channel
            ab = img_lab[:, :, 1:]  # Color channels (a*, b*)
            
            # Normalize L to [0, 1] (L is in range [0, 100])
            L = L / 100.0
            
            # Normalize ab to [-1, 1] (ab is approximately in range [-128, 127])
            ab = ab / 128.0
            
            # Convert to tensors
            L = torch.FloatTensor(L).unsqueeze(0)  # Shape: (1, H, W)
            ab = torch.FloatTensor(ab).permute(2, 0, 1)  # Shape: (2, H, W)
            
            if self.transform:
                # Apply same transform to both L and ab
                seed = np.random.randint(2147483647)
                random.seed(seed)
                torch.manual_seed(seed)
                L = self.transform(L)
                random.seed(seed)
                torch.manual_seed(seed)
                ab = self.transform(ab)
            
            return L, ab, img_lab  # Also return the unnormalized L*a*b* image for visualization
            
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image as fallback
            L = torch.zeros(1, self.img_size, self.img_size)
            ab = torch.zeros(2, self.img_size, self.img_size)
            lab_img = np.zeros((self.img_size, self.img_size, 3), dtype=np.float32)
            return L, ab, lab_img
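The normalization in `__getitem__` and its inverse can be written as a pair of helpers, sketched below in plain Python (the helper names are hypothetical). Dividing a* and b* by 128 means the decoder's `tanh` output in (-1, 1) maps back to (-128, 128). For sRGB images, a* stays within roughly [-86, 98] and b* within roughly [-108, 95] (approximate gamut extremes), so the targets sit well inside the range `tanh` can produce.

```python
L_SCALE = 100.0   # L* range is [0, 100]
AB_SCALE = 128.0  # a*, b* are divided by 128, as in __getitem__

def normalize_lab(L, a, b):
    """L*a*b* units -> network units: L in [0, 1], a and b roughly in [-1, 1]."""
    return L / L_SCALE, a / AB_SCALE, b / AB_SCALE

def denormalize_lab(L_n, a_n, b_n):
    """Network units (e.g. tanh outputs for a, b) -> L*a*b* units."""
    return L_n * L_SCALE, a_n * AB_SCALE, b_n * AB_SCALE

# Approximate a*/b* extremes reachable from sRGB colors
A_RANGE = (-86.2, 98.2)
B_RANGE = (-107.9, 94.5)
print([round(v / AB_SCALE, 3) for v in A_RANGE + B_RANGE])  # all well inside (-1, 1)
```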

### 3.4 Create Data Loaders

In [5]:
# Create datasets from the local train/test image lists prepared above
train_dataset = ColorizationDataset(train_image_files, img_size=IMG_SIZE)
test_dataset = ColorizationDataset(test_image_files, img_size=IMG_SIZE)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                         num_workers=NUM_WORKERS, pin_memory=True)
# Using test_loader as val_loader (validation) since assignment uses train/test split
val_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True)
test_loader = val_loader  # Alias for compatibility

print(f"Training set: {len(train_dataset)} images, {len(train_loader)} batches")
print(f"Validation/Test set: {len(test_dataset)} images, {len(val_loader)} batches")

Initialized dataset with 4000 images
Initialized dataset with 1000 images
Training set: 4000 images, 4000 batches
Validation/Test set: 1000 images, 1000 batches
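With `BATCH_SIZE = 1` every image is its own batch, which is why the loaders report 4000 and 1000 batches. For a `DataLoader` with the default `drop_last=False`, `len(loader)` is the number of images divided by the batch size, rounded up. A small helper (hypothetical, for illustration) makes the arithmetic explicit:

```python
import math

def num_batches(num_images, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch, i.e. len(loader)."""
    if drop_last:
        return num_images // batch_size           # incomplete last batch is dropped
    return math.ceil(num_images / batch_size)     # incomplete last batch is kept

print(num_batches(4000, 1), num_batches(1000, 1))    # 4000 1000
print(num_batches(4000, 16), num_batches(1000, 16))  # 250 63
```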


### 3.5 Visualize Sample Images and L*a*b* Channels

In [6]:
def visualize_lab_channels(dataset, num_samples=3):
    """Visualize L, a, b channels separately for sample images."""
    fig, axes = plt.subplots(num_samples, 5, figsize=(20, 4 * num_samples))
    
    for i in range(num_samples):
        L, ab, lab_img = dataset[i]
        
        # Convert back from tensors
        L_img = (L.squeeze().numpy() * 100.0).astype(np.float32)
        ab_img = (ab.permute(1, 2, 0).numpy() * 128.0).astype(np.float32)
        
        # Reconstruct full LAB image
        lab_reconstructed = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
        lab_reconstructed[:, :, 0] = L_img
        lab_reconstructed[:, :, 1:] = ab_img
        
        # Convert to RGB
        rgb_img = color.lab2rgb(lab_reconstructed)
        
        # Plot
        axes[i, 0].imshow(L_img, cmap='gray')
        axes[i, 0].set_title(f'Sample {i+1}: L (Lightness)')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(ab_img[:, :, 0], cmap='RdYlGn_r')
        axes[i, 1].set_title('a* (Green-Red)')
        axes[i, 1].axis('off')
        
        axes[i, 2].imshow(ab_img[:, :, 1], cmap='YlGnBu_r')
        axes[i, 2].set_title('b* (Blue-Yellow)')
        axes[i, 2].axis('off')
        
        axes[i, 3].imshow(rgb_img)
        axes[i, 3].set_title('RGB (Ground Truth)')
        axes[i, 3].axis('off')
        
        # Show grayscale version
        axes[i, 4].imshow(L_img, cmap='gray')
        axes[i, 4].set_title('Grayscale Input')
        axes[i, 4].axis('off')
    
    plt.tight_layout()
    plt.show()

# visualize_lab_channels(train_dataset, num_samples=3)

## 4. Model Architecture

### 4.1 Low-Level Feature Extraction Network

This network extracts local spatial features from the input grayscale image.

In [7]:
class LowLevelFeatureExtractor(nn.Module):
    """
    Low-level feature extraction network.
    Follows the architecture from Figure 3:
    Conv1-6 with progressive downsampling from 256x256 -> 128x128 -> 64x64 -> 32x32
    """
    def __init__(self):
        super(LowLevelFeatureExtractor, self).__init__()
        
        # Conv1-2: 256x256 -> 128x128
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)  # 256->128
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        
        # Conv3-4: 128x128 -> 64x64
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)  # 128->64
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        
        # Conv5-6: 64x64 -> 32x32
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)  # 64->32
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(512)
        
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        # Input: (B, 1, 256, 256)
        x = self.relu(self.bn1(self.conv1(x)))  # (B, 64, 128, 128)
        x = self.relu(self.bn2(self.conv2(x)))  # (B, 128, 128, 128)
        
        x = self.relu(self.bn3(self.conv3(x)))  # (B, 128, 64, 64)
        x = self.relu(self.bn4(self.conv4(x)))  # (B, 256, 64, 64)
        
        x = self.relu(self.bn5(self.conv5(x)))  # (B, 256, 32, 32)
        x = self.relu(self.bn6(self.conv6(x)))  # (B, 512, 32, 32)
        
        return x  # Output: (B, 512, 32, 32)
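The shape comments follow from the `nn.Conv2d` output-size formula in the PyTorch documentation, `out = (in + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1`. With a 3×3 kernel and padding 1, stride 2 halves the spatial size and stride 1 keeps it. A plain-Python trace of Conv1 to Conv6 (the helper name is illustrative):

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output size of nn.Conv2d along one spatial dimension (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

size = 256
for name, stride in [("conv1", 2), ("conv2", 1), ("conv3", 2),
                     ("conv4", 1), ("conv5", 2), ("conv6", 1)]:
    size = conv2d_out(size, kernel_size=3, stride=stride, padding=1)
    print(f"{name}: {size}x{size}")
```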

### 4.2 Global Feature Extraction Network (Baseline)

This network extracts global context features. We'll start with a simple baseline, then integrate pretrained models.

In [8]:
class GlobalFeatureExtractor(nn.Module):
    """
    Global feature extraction network (baseline).
    Extracts global context with Conv7-8 to produce a 1000-dimensional feature vector.
    """
    def __init__(self):
        super(GlobalFeatureExtractor, self).__init__()
        
        # Conv7-8: operate directly on the 1-channel L input (independent of the low-level branch)
        self.conv7 = nn.Conv2d(1, 512, kernel_size=3, stride=2, padding=1)  # 256 -> 128
        self.bn7 = nn.BatchNorm2d(512)
        self.conv8 = nn.Conv2d(512, 256, kernel_size=3, stride=2, padding=1)
        self.bn8 = nn.BatchNorm2d(256)
        
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Fully connected layers to produce 1000-dim feature vector
        self.fc1 = nn.Linear(256, 1024)
        self.fc2 = nn.Linear(1024, 1000)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        # Input: (B, 1, H, W) - original grayscale image
        x = self.relu(self.bn7(self.conv7(x)))
        x = self.relu(self.bn8(self.conv8(x)))
        
        x = self.pool(x)  # (B, 256, 1, 1)
        x = x.view(x.size(0), -1)  # (B, 256)
        
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)  # (B, 1000)
        
        return x

### 4.3 Fusion Block

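Combines the low-level feature map with the global feature vector: the 1000-dimensional vector is projected to a 256×32×32 map and concatenated with the low-level features along the channel dimension.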

In [9]:
class FusionBlock(nn.Module):
    """
    Fusion block that combines low-level features (B, 512, 32, 32) 
    with global features (B, 1000) into a fused representation (B, 256, 32, 32).
    """
    def __init__(self):
        super(FusionBlock, self).__init__()
        
        # Project global features to spatial dimensions
        self.fc = nn.Linear(1000, 256 * 32 * 32)
        
        # Combine low-level (512 channels) + global (256 channels) = 768 channels
        self.conv_fuse = nn.Conv2d(512 + 256, 256, kernel_size=1)
        self.bn = nn.BatchNorm2d(256)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, low_level_features, global_features):
        # low_level_features: (B, 512, 32, 32)
        # global_features: (B, 1000)
        
        batch_size = low_level_features.size(0)
        
        # Project global features to spatial map
        global_spatial = self.fc(global_features)  # (B, 256*32*32)
        global_spatial = global_spatial.view(batch_size, 256, 32, 32)  # (B, 256, 32, 32)
        
        # Concatenate along channel dimension
        fused = torch.cat([low_level_features, global_spatial], dim=1)  # (B, 768, 32, 32)
        
        # Reduce channels
        fused = self.relu(self.bn(self.conv_fuse(fused)))  # (B, 256, 32, 32)
        
        return fused
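Projecting the 1000-dimensional global vector to a full 256×32×32 map makes `self.fc` by far the largest layer in the model. The count below is plain arithmetic (weights plus biases, 4 bytes per fp32 value). The Adam figure assumes one gradient and two moment buffers per parameter and ignores activations and allocator overhead, so it is a rough lower bound; it suggests this single layer accounts for much of the memory pressure behind `BATCH_SIZE = 1` on a 6 GB GPU.

```python
def linear_params(in_features, out_features, bias=True):
    """Trainable parameters of nn.Linear(in_features, out_features)."""
    return in_features * out_features + (out_features if bias else 0)

def conv2d_params(in_ch, out_ch, k, bias=True):
    """Trainable parameters of nn.Conv2d(in_ch, out_ch, kernel_size=k)."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)

fc = linear_params(1000, 256 * 32 * 32)   # FusionBlock.fc
conv = conv2d_params(512 + 256, 256, 1)   # FusionBlock.conv_fuse
GiB = 1024 ** 3
print(f"fc: {fc:,} params ({fc * 4 / GiB:.2f} GiB in fp32)")
print(f"conv_fuse: {conv:,} params")
# Rough training footprint of fc alone with Adam: weights + grads + 2 moment buffers
print(f"fc with grads and Adam state: ~{fc * 4 * 4 / GiB:.2f} GiB")
```

A lighter alternative, not used in this notebook, is to shrink the global vector to a few hundred channels and tile (replicate) it over the 32×32 grid before concatenation, which removes the spatial projection entirely.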

### 4.4 Decoder Network

The decoder progressively upsamples the fused features to predict the ab channels.

In [10]:
class Decoder(nn.Module):
    """
    Decoder network that upsamples fused features from 32x32 to 256x256
    and predicts the ab channels.
    Follows the architecture from Figure 3: 32x32 -> 64x64 -> 128x128 -> 256x256
    """
    def __init__(self):
        super(Decoder, self).__init__()
        
        # Upsampling path
        # 32x32 -> 64x64
        self.upconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv1 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(128)
        
        # 64x64 -> 128x128
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(64)
        
        # 128x128 -> 256x256
        self.upconv3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(32)
        
        # Final convolution to produce 2 channels (ab)
        self.final_conv = nn.Conv2d(32, 2, kernel_size=3, padding=1)
        
        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()  # Output in range [-1, 1]
    
    def forward(self, x):
        # Input: (B, 256, 32, 32)
        
        # 32x32 -> 64x64
        x = self.relu(self.bn1(self.upconv1(x)))  # (B, 128, 64, 64)
        x = self.relu(self.bn1_2(self.conv1(x)))
        
        # 64x64 -> 128x128
        x = self.relu(self.bn2(self.upconv2(x)))  # (B, 64, 128, 128)
        x = self.relu(self.bn2_2(self.conv2(x)))
        
        # 128x128 -> 256x256
        x = self.relu(self.bn3(self.upconv3(x)))  # (B, 32, 256, 256)
        x = self.relu(self.bn3_2(self.conv3(x)))
        
        # Final output
        x = self.tanh(self.final_conv(x))  # (B, 2, 256, 256)
        
        return x
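Each `ConvTranspose2d` with `kernel_size=4, stride=2, padding=1` exactly doubles the spatial size, while the 3×3 convolutions with padding 1 keep it. The PyTorch documentation gives the transposed-convolution output size as `(in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1`; a plain-Python trace of the three upsampling steps (helper name illustrative):

```python
def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Output size of nn.ConvTranspose2d along one spatial dimension (PyTorch docs formula)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

size = 32
for name in ["upconv1", "upconv2", "upconv3"]:
    size = conv_transpose2d_out(size, kernel_size=4, stride=2, padding=1)
    print(f"{name}: {size}x{size}")
```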

### 4.5 Complete Baseline Colorization Model

In [11]:
class ColorizationModel(nn.Module):
    """
    Complete colorization model combining all components.
    """
    def __init__(self, use_pretrained_global=None):
        super(ColorizationModel, self).__init__()
        
        self.low_level_extractor = LowLevelFeatureExtractor()
        
        if use_pretrained_global is None:
            # Use baseline global feature extractor
            self.global_extractor = GlobalFeatureExtractor()
        else:
            # Use a caller-supplied global extractor (e.g. PretrainedGlobalExtractor)
            self.global_extractor = use_pretrained_global
        
        self.fusion = FusionBlock()
        self.decoder = Decoder()
    
    def forward(self, x):
        # x: (B, 1, 256, 256) - grayscale L channel
        
        # Extract features
        low_level_feat = self.low_level_extractor(x)  # (B, 512, 32, 32)
        global_feat = self.global_extractor(x)  # (B, 1000)
        
        # Fuse features
        fused = self.fusion(low_level_feat, global_feat)  # (B, 256, 32, 32)
        
        # Decode to ab channels
        ab_pred = self.decoder(fused)  # (B, 2, 256, 256)
        
        return ab_pred

# # Test model instantiation
# model = ColorizationModel().to(device)
# print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

# # Test forward pass
# test_input = torch.randn(2, 1, 256, 256).to(device)
# test_output = model(test_input)
# print(f"Input shape: {test_input.shape}")
# print(f"Output shape: {test_output.shape}")

### 4.6 Pretrained Global Feature Extractors

We'll integrate pretrained models (VGG, ResNet, EfficientNet) as global feature extractors.

In [12]:
class PretrainedGlobalExtractor(nn.Module):
    """
    Global feature extractor using pretrained models.
    Supports VGG16, ResNet50, EfficientNet-B0.
    """
    def __init__(self, backbone='vgg16', freeze_backbone=False):
        super(PretrainedGlobalExtractor, self).__init__()
        
        self.backbone_name = backbone
        
        # `weights=` is the torchvision >= 0.13 API (older versions use `pretrained=True`)
        if backbone == 'vgg16':
            # Load pretrained VGG16
            vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
            # Replace the first (RGB) conv layer with a 1-channel one for grayscale input
            self.features = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=3, padding=1),  # Modified for grayscale
                *list(vgg.features.children())[1:]  # Rest of VGG features
            )
            feature_dim = 512

        elif backbone == 'resnet50':
            # Load pretrained ResNet50
            resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            # Replace the first (RGB) conv layer with a 1-channel one for grayscale input
            self.features = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
                resnet.bn1,
                resnet.relu,
                resnet.maxpool,
                resnet.layer1,
                resnet.layer2,
                resnet.layer3,
                resnet.layer4
            )
            feature_dim = 2048

        elif backbone == 'efficientnet_b0':
            # Load pretrained EfficientNet-B0
            effnet = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
            # features[0] is the stem (Conv2d -> BatchNorm2d -> SiLU). Replace only its
            # conv so that the stem's BatchNorm and activation are kept.
            effnet.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
            self.features = effnet.features
            feature_dim = 1280

        else:
            raise ValueError(f"Unknown backbone: {backbone}")
        
        # Freeze backbone if requested
        if freeze_backbone:
            for param in self.features.parameters():
                param.requires_grad = False
        
        # Pooling and FC layers
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(feature_dim, 1024)
        self.fc2 = nn.Linear(1024, 1000)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        # Input: (B, 1, H, W)
        x = self.features(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)  # (B, 1000)
        
        return x

# Test pretrained extractors
# print("Testing pretrained global extractors...")
# for backbone in ['vgg16', 'resnet50', 'efficientnet_b0']:
#     try:
#         extractor = PretrainedGlobalExtractor(backbone=backbone, freeze_backbone=True)
#         test_in = torch.randn(2, 1, 256, 256)
#         test_out = extractor(test_in)
#         print(f"{backbone}: Input {test_in.shape} -> Output {test_out.shape}")
#     except Exception as e:
#         print(f"{backbone}: Error - {e}")
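In every branch above, the first RGB convolution is replaced by a freshly initialized 1-channel `Conv2d`, so the pretrained stem weights are not reused (and with `freeze_backbone=True` that random stem is frozen too). A common alternative, not what this notebook does, is to initialize the 1-channel kernel with the pretrained kernel summed over its three input channels, e.g. `old_conv.weight.sum(dim=1, keepdim=True)` in PyTorch. This is exactly equivalent to feeding the grayscale image replicated into R, G and B, as this toy plain-Python check with hypothetical weights shows:

```python
# One output value of a 1x3 "convolution" over a 3-channel input is a dot product
# summed over channels. Toy pretrained RGB kernel (hypothetical values):
rgb_kernel = [
    [0.2, -0.1, 0.4],   # R-channel weights
    [0.5, 0.3, -0.2],   # G-channel weights
    [-0.3, 0.6, 0.1],   # B-channel weights
]
gray_patch = [0.1, 0.7, 0.4]   # one patch of the grayscale (L) input

# (1) Apply the RGB kernel to the grayscale patch replicated into R, G and B
out_replicated = sum(w * x
                     for channel_w in rgb_kernel
                     for w, x in zip(channel_w, gray_patch))

# (2) Collapse the kernel to one channel by summing over R, G, B, then apply it once
gray_kernel = [sum(col) for col in zip(*rgb_kernel)]
out_collapsed = sum(w * x for w, x in zip(gray_kernel, gray_patch))

print(out_replicated, out_collapsed)   # equal up to floating-point rounding
```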

## 5. Loss Functions

### 5.1 L1 Loss (Baseline)

In [13]:
class L1Loss(nn.Module):
    """L1 loss for pixel-wise comparison."""
    def __init__(self):
        super(L1Loss, self).__init__()
        self.loss = nn.L1Loss()
    
    def forward(self, pred, target):
        return self.loss(pred, target)
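A commonly cited reason for an L1 pixel loss in colorization: when the same grayscale input appears with several plausible colors, the constant prediction that minimizes L1 is the median of the targets, while L2 (MSE) is minimized by their mean, which tends to be pulled toward a* ≈ 0, i.e. a washed-out gray. A toy plain-Python illustration with hypothetical normalized a* targets:

```python
# Three training pixels share the same gray level but have different a* targets
# (normalized units): two are strongly red, one is strongly green.
targets = [-0.6, 0.5, 0.6]
candidates = [i / 1000 for i in range(-1000, 1001)]   # constant predictions in [-1, 1]

def l1(c):
    return sum(abs(t - c) for t in targets) / len(targets)

def l2(c):
    return sum((t - c) ** 2 for t in targets) / len(targets)

best_l1 = min(candidates, key=l1)   # the median of the targets
best_l2 = min(candidates, key=l2)   # the mean of the targets (rounded to the grid)
print(best_l1, best_l2)
```

Here L1 keeps a saturated red (0.5), while L2 settles near 0.17, much closer to neutral.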

### 5.2 Perceptual Loss

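Compares VGG16 activations (relu1_2, relu2_2, relu3_3) computed from the predicted and target ab channels instead of raw pixel values, with the aim of producing more realistic colors.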

In [14]:
class PerceptualLoss(nn.Module):
    """
    Perceptual loss using VGG16 features.
    Compares feature representations rather than pixel values.
    """
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        
        # Load pretrained VGG16
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features  # torchvision >= 0.13 API
        
        # Use specific layers for feature extraction
        self.slice1 = nn.Sequential(*list(vgg.children())[:4])   # relu1_2
        self.slice2 = nn.Sequential(*list(vgg.children())[4:9])  # relu2_2
        self.slice3 = nn.Sequential(*list(vgg.children())[9:16]) # relu3_3
        
        # Freeze VGG parameters
        for param in self.parameters():
            param.requires_grad = False
        
        self.mse_loss = nn.MSELoss()
    
    def forward(self, pred_ab, target_ab, L_channel):
        """
        Args:
            pred_ab: Predicted ab channels (B, 2, H, W)
            target_ab: Ground truth ab channels (B, 2, H, W)
            L_channel: L channel (B, 1, H, W)
        """
        # VGG16 expects RGB input, but no L*a*b* -> RGB conversion is done here.
        # As an approximation, the ab channels are arranged as a 3-channel
        # pseudo-RGB input [a*, b*, a*]; L_channel is not used.
        pred_input = torch.cat([pred_ab, pred_ab[:, :1, :, :]], dim=1)
        target_input = torch.cat([target_ab, target_ab[:, :1, :, :]], dim=1)
        
        # Extract features at different layers
        pred_feat1 = self.slice1(pred_input)
        target_feat1 = self.slice1(target_input)
        
        pred_feat2 = self.slice2(pred_feat1)
        target_feat2 = self.slice2(target_feat1)
        
        pred_feat3 = self.slice3(pred_feat2)
        target_feat3 = self.slice3(target_feat2)
        
        # Compute perceptual loss across multiple layers
        loss1 = self.mse_loss(pred_feat1, target_feat1)
        loss2 = self.mse_loss(pred_feat2, target_feat2)
        loss3 = self.mse_loss(pred_feat3, target_feat3)
        
        return loss1 + loss2 + loss3

# # Test perceptual loss
# perceptual_loss = PerceptualLoss()
# test_L = torch.randn(2, 1, 256, 256)
# test_ab_pred = torch.randn(2, 2, 256, 256)
# test_ab_target = torch.randn(2, 2, 256, 256)
# perc_loss_val = perceptual_loss(test_ab_pred, test_ab_target, test_L)
# print(f"Perceptual loss test: {perc_loss_val.item():.4f}")
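VGG16 was trained on RGB images, whereas this loss feeds it the a*/b* channels arranged as [a*, b*, a*], so the compared features only approximate perceptual similarity. For reference, the per-pixel L*a*b* → sRGB conversion that a faithful version would apply first is sketched below in plain Python. It assumes the D65 white point and standard sRGB constants, so results should agree with `skimage.color.lab2rgb` only to a few decimals; the function name is illustrative. A differentiable PyTorch port would apply the same steps elementwise to tensors, using `torch.where` in place of the `if` branches.

```python
def lab_to_srgb(L, a, b):
    """Convert one CIE L*a*b* pixel (D65 white) to sRGB components clipped to [0, 1]."""
    delta = 6 / 29
    # 1. L*a*b* -> XYZ (inverse of the f() used in the forward conversion)
    fy = (L + 16) / 116
    fx = fy + a / 500
    fz = fy - b / 200
    def f_inv(t):
        return t ** 3 if t > delta else 3 * delta ** 2 * (t - 4 / 29)
    x, y, z = 0.95047 * f_inv(fx), 1.0 * f_inv(fy), 1.08883 * f_inv(fz)

    # 2. XYZ -> linear RGB (inverse sRGB matrix)
    r = 3.2406 * x - 1.5372 * y - 0.4986 * z
    g = -0.9689 * x + 1.8758 * y + 0.0415 * z
    bl = 0.0557 * x - 0.2040 * y + 1.0570 * z

    # 3. Clip to the displayable range, then apply the sRGB gamma
    def gamma(c):
        c = min(max(c, 0.0), 1.0)
        return 12.92 * c if c <= 0.0031308 else 1.055 * c ** (1 / 2.4) - 0.055
    return gamma(r), gamma(g), gamma(bl)

print(lab_to_srgb(100, 0, 0))   # close to white (1, 1, 1)
print(lab_to_srgb(50, 0, 0))    # a mid gray: R = G = B
```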

## 6. Evaluation Metrics

In [15]:
def calculate_metrics(pred_ab, target_ab, L_channel):
    """
    Calculate MSE, PSNR, and SSIM metrics for colorization.
    
    Args:
        pred_ab: Predicted ab channels (B, 2, H, W) in range [-1, 1]
        target_ab: Target ab channels (B, 2, H, W) in range [-1, 1]
        L_channel: L channel (B, 1, H, W) in range [0, 1]
    
    Returns:
        Dictionary with MSE, PSNR, and SSIM values
    """
    batch_size = pred_ab.size(0)
    
    mse_sum = 0.0
    psnr_sum = 0.0
    ssim_sum = 0.0
    
    for i in range(batch_size):
        # Convert to numpy
        pred_ab_np = pred_ab[i].cpu().detach().numpy().transpose(1, 2, 0) * 128.0
        target_ab_np = target_ab[i].cpu().detach().numpy().transpose(1, 2, 0) * 128.0
        L_np = L_channel[i, 0].cpu().detach().numpy() * 100.0
        
        # Reconstruct LAB images
        pred_lab = np.zeros((pred_ab_np.shape[0], pred_ab_np.shape[1], 3), dtype=np.float32)
        pred_lab[:, :, 0] = L_np
        pred_lab[:, :, 1:] = pred_ab_np
        
        target_lab = np.zeros((target_ab_np.shape[0], target_ab_np.shape[1], 3), dtype=np.float32)
        target_lab[:, :, 0] = L_np
        target_lab[:, :, 1:] = target_ab_np
        
        # Convert to RGB
        pred_rgb = color.lab2rgb(pred_lab)
        target_rgb = color.lab2rgb(target_lab)
        
        # Ensure values are in [0, 1] range
        pred_rgb = np.clip(pred_rgb, 0, 1)
        target_rgb = np.clip(target_rgb, 0, 1)
        
        # Calculate MSE
        mse = np.mean((pred_rgb - target_rgb) ** 2)
        mse_sum += mse
        
        # Calculate PSNR
        if mse > 0:
            psnr_val = psnr(target_rgb, pred_rgb, data_range=1.0)
            psnr_sum += psnr_val
        
        # Calculate SSIM
        ssim_val = ssim(target_rgb, pred_rgb, channel_axis=2, data_range=1.0)
        ssim_sum += ssim_val
    
    return {
        'mse': mse_sum / batch_size,
        'psnr': psnr_sum / batch_size,
        'ssim': ssim_sum / batch_size
    }
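PSNR is a log-scale restatement of the RGB MSE computed here: PSNR = 10·log10(data_range² / MSE), with `data_range=1.0` for RGB values in [0, 1]. Each tenfold reduction in MSE therefore adds 10 dB. A plain-Python version (the helper name is hypothetical):

```python
import math

def psnr_from_mse(mse, data_range=1.0):
    """PSNR in dB; data_range is the span of pixel values (1.0 for RGB in [0, 1])."""
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(data_range ** 2 / mse)

for mse in (0.01, 0.001, 0.0001):
    print(f"MSE {mse:g} -> PSNR {psnr_from_mse(mse):.1f} dB")
```

Note that `calculate_metrics` skips the PSNR term when the MSE is exactly 0 but still divides by `batch_size`, so a perfect reconstruction lowers the reported average rather than being counted as infinite or capped.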

## 7. Training Functions

In [16]:
def train_epoch(model, train_loader, criterion, optimizer, device, use_perceptual=False, perceptual_weight=0.1):
    """Train for one epoch."""
    model.train()
    epoch_loss = 0.0
    epoch_metrics = {'mse': 0.0, 'psnr': 0.0, 'ssim': 0.0}
    
    if use_perceptual:
        perceptual_loss_fn = PerceptualLoss().to(device)
    
    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, (L, ab, _) in enumerate(pbar):
        L, ab = L.to(device), ab.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        ab_pred = model(L)
        
        # Calculate loss
        pixel_loss = criterion(ab_pred, ab)
        
        if use_perceptual:
            perc_loss = perceptual_loss_fn(ab_pred, ab, L)
            total_loss = pixel_loss + perceptual_weight * perc_loss
        else:
            total_loss = pixel_loss
        
        # Backward pass
        total_loss.backward()
        optimizer.step()
        
        epoch_loss += total_loss.item()
        
        # Calculate metrics
        with torch.no_grad():
            metrics = calculate_metrics(ab_pred, ab, L)
            for key in metrics:
                epoch_metrics[key] += metrics[key]
        
        pbar.set_postfix({'loss': total_loss.item()})
    
    # Average over batches
    num_batches = len(train_loader)
    epoch_loss /= num_batches
    for key in epoch_metrics:
        epoch_metrics[key] /= num_batches
    
    return epoch_loss, epoch_metrics


def validate(model, val_loader, criterion, device, use_perceptual=False, perceptual_weight=0.1):
    """Validate the model."""
    model.eval()
    epoch_loss = 0.0
    epoch_metrics = {'mse': 0.0, 'psnr': 0.0, 'ssim': 0.0}
    
    if use_perceptual:
        perceptual_loss_fn = PerceptualLoss().to(device)
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for L, ab, _ in pbar:
            L, ab = L.to(device), ab.to(device)
            
            # Forward pass
            ab_pred = model(L)
            
            # Calculate loss
            pixel_loss = criterion(ab_pred, ab)
            
            if use_perceptual:
                perc_loss = perceptual_loss_fn(ab_pred, ab, L)
                total_loss = pixel_loss + perceptual_weight * perc_loss
            else:
                total_loss = pixel_loss
            
            epoch_loss += total_loss.item()
            
            # Calculate metrics
            metrics = calculate_metrics(ab_pred, ab, L)
            for key in metrics:
                epoch_metrics[key] += metrics[key]
            
            pbar.set_postfix({'loss': total_loss.item()})
    
    # Average over batches
    num_batches = len(val_loader)
    epoch_loss /= num_batches
    for key in epoch_metrics:
        epoch_metrics[key] /= num_batches
    
    return epoch_loss, epoch_metrics


def train_model(model, train_loader, val_loader, num_epochs, lr=0.001, 
                use_perceptual=False, perceptual_weight=0.1, model_name='model'):
    """
    Complete training loop.
    
    Args:
        model: The colorization model
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Number of epochs to train
        lr: Learning rate
        use_perceptual: Whether to use perceptual loss
        perceptual_weight: Weight for perceptual loss
        model_name: Name for saving the model
    """
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
    train_losses = []
    val_losses = []
    train_metrics_history = {'mse': [], 'psnr': [], 'ssim': []}
    val_metrics_history = {'mse': [], 'psnr': [], 'ssim': []}
    
    best_val_loss = float('inf')
    
    print(f"Training {model_name}...")
    print(f"Using perceptual loss: {use_perceptual}")
    if use_perceptual:
        print(f"Perceptual loss weight: {perceptual_weight}")
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        # Train
        train_loss, train_metrics = train_epoch(model, train_loader, criterion, optimizer, 
                                                device, use_perceptual, perceptual_weight)
        
        # Validate
        val_loss, val_metrics = validate(model, val_loader, criterion, device, 
                                        use_perceptual, perceptual_weight)
        
        # Update learning rate
        scheduler.step(val_loss)
        
        # Record history
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        for key in train_metrics:
            train_metrics_history[key].append(train_metrics[key])
            val_metrics_history[key].append(val_metrics[key])
        
        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        print(f"Train - MSE: {train_metrics['mse']:.4f}, PSNR: {train_metrics['psnr']:.2f}, SSIM: {train_metrics['ssim']:.4f}")
        print(f"Val   - MSE: {val_metrics['mse']:.4f}, PSNR: {val_metrics['psnr']:.2f}, SSIM: {val_metrics['ssim']:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model_save_path = os.path.join(MODEL_PATH, f'{model_name}_best.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            }, model_save_path)
            print(f"Saved best model to {model_save_path} with val_loss: {val_loss:.4f}")
    
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_metrics': train_metrics_history,
        'val_metrics': val_metrics_history
    }
    
    return history
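`train_model` halves the learning rate through `ReduceLROnPlateau(mode='min', factor=0.5, patience=5)`. The pure-Python sketch below shows when that happens for a given validation-loss curve. It re-implements the scheduler's default decision rule (relative `threshold=1e-4`, `cooldown=0`, `min_lr=0`) for illustration only; it is not PyTorch's code, and `plateau_schedule` is a hypothetical helper name.

```python
def plateau_schedule(val_losses, lr=1e-3, factor=0.5, patience=5, threshold=1e-4):
    """Learning rate in effect after each scheduler.step(val_loss) call.

    Illustrative re-implementation of ReduceLROnPlateau's default rule for
    mode='min', threshold_mode='rel', cooldown=0, min_lr=0 (not PyTorch's code).
    """
    best = float('inf')
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        # An epoch only counts as an improvement if it beats `best` by a relative margin
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce once the number of consecutive non-improving epochs exceeds `patience`
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs
```

With the settings used in `train_model` (`lr=0.001`, `patience=5`), a validation loss that last improves at epoch 10 triggers the first halving at epoch 16, the sixth consecutive non-improving epoch.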

## 8. Visualization Functions

In [17]:
def plot_training_curves(history, model_name='Model'):
    """Plot training and validation curves."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss
    axes[0, 0].plot(history['train_losses'], label='Train Loss')
    axes[0, 0].plot(history['val_losses'], label='Val Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title(f'{model_name} - Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    # MSE
    axes[0, 1].plot(history['train_metrics']['mse'], label='Train MSE')
    axes[0, 1].plot(history['val_metrics']['mse'], label='Val MSE')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('MSE')
    axes[0, 1].set_title(f'{model_name} - MSE')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    # PSNR
    axes[1, 0].plot(history['train_metrics']['psnr'], label='Train PSNR')
    axes[1, 0].plot(history['val_metrics']['psnr'], label='Val PSNR')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('PSNR (dB)')
    axes[1, 0].set_title(f'{model_name} - PSNR')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
    
    # SSIM
    axes[1, 1].plot(history['train_metrics']['ssim'], label='Train SSIM')
    axes[1, 1].plot(history['val_metrics']['ssim'], label='Val SSIM')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('SSIM')
    axes[1, 1].set_title(f'{model_name} - SSIM')
    axes[1, 1].legend()
    axes[1, 1].grid(True)
    
    plt.tight_layout()
    plt.show()


def visualize_results(model, dataset, device, num_samples=5):
    """
    Visualize colorization results.
    Shows: Input grayscale, Ground truth color, Predicted color
    """
    model.eval()
    
    # squeeze=False keeps axes 2-D so axes[i, j] also works when num_samples == 1
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)
    
    with torch.no_grad():
        for i in range(num_samples):
            L, ab_true, _ = dataset[i]
            
            # Add batch dimension
            L_batch = L.unsqueeze(0).to(device)
            
            # Predict
            ab_pred = model(L_batch)
            
            # Convert to numpy
            L_np = (L.squeeze().numpy() * 100.0).astype(np.float32)
            ab_true_np = (ab_true.permute(1, 2, 0).numpy() * 128.0).astype(np.float32)
            ab_pred_np = (ab_pred[0].cpu().permute(1, 2, 0).numpy() * 128.0).astype(np.float32)
            
            # Reconstruct LAB images
            lab_true = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
            lab_true[:, :, 0] = L_np
            lab_true[:, :, 1:] = ab_true_np
            
            lab_pred = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
            lab_pred[:, :, 0] = L_np
            lab_pred[:, :, 1:] = ab_pred_np
            
            # Convert to RGB
            rgb_true = color.lab2rgb(lab_true)
            rgb_pred = color.lab2rgb(lab_pred)
            
            # Plot
            axes[i, 0].imshow(L_np, cmap='gray')
            axes[i, 0].set_title(f'Sample {i+1}: Grayscale Input')
            axes[i, 0].axis('off')
            
            axes[i, 1].imshow(rgb_true)
            axes[i, 1].set_title('Ground Truth')
            axes[i, 1].axis('off')
            
            axes[i, 2].imshow(rgb_pred)
            axes[i, 2].set_title('Predicted')
            axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.show()


def compare_models(models_dict, dataset, device, num_samples=3):
    """
    Compare results from multiple models side by side.
    
    Args:
        models_dict: Dictionary of {model_name: model}
        dataset: Test dataset
        device: Device to run on
        num_samples: Number of samples to compare
    """
    num_models = len(models_dict)
    # squeeze=False keeps axes 2-D so axes[i, j] also works when num_samples == 1
    fig, axes = plt.subplots(num_samples, num_models + 2, figsize=(4 * (num_models + 2), 4 * num_samples), squeeze=False)
    
    for model in models_dict.values():
        model.eval()
    
    with torch.no_grad():
        for i in range(num_samples):
            L, ab_true, _ = dataset[i]
            L_batch = L.unsqueeze(0).to(device)
            
            # Convert ground truth
            L_np = (L.squeeze().numpy() * 100.0).astype(np.float32)
            ab_true_np = (ab_true.permute(1, 2, 0).numpy() * 128.0).astype(np.float32)
            
            lab_true = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
            lab_true[:, :, 0] = L_np
            lab_true[:, :, 1:] = ab_true_np
            rgb_true = color.lab2rgb(lab_true)
            
            # Plot input and ground truth
            axes[i, 0].imshow(L_np, cmap='gray')
            axes[i, 0].set_title(f'Sample {i+1}: Input')
            axes[i, 0].axis('off')
            
            axes[i, 1].imshow(rgb_true)
            axes[i, 1].set_title('Ground Truth' if i == 0 else '')
            axes[i, 1].axis('off')
            
            # Plot predictions from each model
            for j, (model_name, model) in enumerate(models_dict.items()):
                ab_pred = model(L_batch)
                ab_pred_np = (ab_pred[0].cpu().permute(1, 2, 0).numpy() * 128.0).astype(np.float32)
                
                lab_pred = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
                lab_pred[:, :, 0] = L_np
                lab_pred[:, :, 1:] = ab_pred_np
                rgb_pred = color.lab2rgb(lab_pred)
                
                axes[i, j + 2].imshow(rgb_pred)
                axes[i, j + 2].set_title(model_name if i == 0 else '')
                axes[i, j + 2].axis('off')
    
    plt.tight_layout()
    plt.show()
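`visualize_results`, `compare_models` and `find_failure_cases` each rebuild an L*a*b* image by hand before calling `color.lab2rgb`. A small helper could factor that step out. The sketch below assumes, as the scaling in those functions implies, that the dataset stores L divided by 100 and a*b* divided by 128. `denormalize_lab` is a hypothetical name, and its result would still be passed to `skimage.color.lab2rgb`.

```python
import numpy as np


def denormalize_lab(L, ab):
    """Rebuild an (H, W, 3) float32 L*a*b* image from normalized network data.

    Assumes (as the plotting code above implies) L stored as L*/100 with shape
    (1, H, W) or (H, W), and ab stored as (a*, b*)/128 with shape (2, H, W).
    Expects NumPy arrays; convert tensors with `.detach().cpu().numpy()` first.
    """
    ab = np.asarray(ab, dtype=np.float32)                       # (2, H, W), roughly [-1, 1]
    L = np.asarray(L, dtype=np.float32).reshape(ab.shape[1:])   # (H, W), [0, 1]
    lab = np.empty(ab.shape[1:] + (3,), dtype=np.float32)
    lab[..., 0] = L * 100.0                                     # L* back to [0, 100]
    lab[..., 1:] = np.transpose(ab, (1, 2, 0)) * 128.0          # a*, b* back to about [-128, 128]
    return lab
```

Inside the plotting loops this would read, for example, `rgb_pred = color.lab2rgb(denormalize_lab(L.numpy(), ab_pred[0].cpu().numpy()))`.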

## 9. Bonus: PatchGAN Discriminator

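A PatchGAN discriminator outputs a grid of real/fake scores instead of a single score per image. Each score judges whether one local patch of the (L, ab) image looks realistic. Penalizing unrealistic patches helps the generator produce sharper and more locally consistent colors.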
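The patch that each score judges is the receptive field of one output unit. The sketch below walks a list of `(kernel, stride, padding)` convolutions and reports the output map size and receptive field, using the standard formulas `out = (n + 2p - k) // s + 1` and `rf += (k - 1) * jump`. The stack shown mirrors `PatchGANDiscriminator` below: four 4x4 stride-2 blocks with padding 1, followed by a 4x4 stride-1 convolution with padding 1. `conv_stack_geometry` is a hypothetical helper name.

```python
def conv_stack_geometry(layers, input_size):
    """Return (output_size, receptive_field) for a stack of square 2D convolutions.

    layers: list of (kernel, stride, padding) tuples, applied in order.
    """
    size, rf, jump = input_size, 1, 1  # jump = input-pixel distance between adjacent units
    for kernel, stride, padding in layers:
        size = (size + 2 * padding - kernel) // stride + 1
        rf += (kernel - 1) * jump
        jump *= stride
    return size, rf


# Layers of PatchGANDiscriminator: four (4, 2, 1) blocks, then a (4, 1, 1) output conv
patchgan_layers = [(4, 2, 1)] * 4 + [(4, 1, 1)]
print(conv_stack_geometry(patchgan_layers, 256))  # (15, 94)
```

A 256x256 input therefore yields a 15x15 grid of logits, matching the patch map that `train_with_patchgan` trains against, and each logit depends on a 94x94 input window. Setting the fourth block's stride to 1 would give the 70x70 receptive field (and 30x30 output) of the original pix2pix PatchGAN.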

In [18]:
class PatchGANDiscriminator(nn.Module):
    """
    PatchGAN discriminator for colorization.
    Each output logit has a 94x94 receptive field, i.e. it judges whether a 94x94 patch looks realistic.
    
    Input: Concatenated L and ab channels (3 channels total)
    Output: Patch-wise predictions
    """
    def __init__(self, input_channels=3):
        super(PatchGANDiscriminator, self).__init__()
        
        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *discriminator_block(input_channels, 64, normalization=False),  # 256 -> 128
            *discriminator_block(64, 128),   # 128 -> 64
            *discriminator_block(128, 256),  # 64 -> 32
            *discriminator_block(256, 512),  # 32 -> 16
            nn.Conv2d(512, 1, 4, padding=1)  # 16 -> 15 (15x15 map of patch logits)
        )
    
    def forward(self, img):
        # img: (B, 3, 256, 256) - L + ab channels
        return self.model(img)


def train_with_patchgan(generator, discriminator, train_loader, val_loader, num_epochs, 
                        lr_g=0.0002, lr_d=0.0002, lambda_l1=100, model_name='gan_model'):
    """
    Train colorization model with PatchGAN discriminator.
    
    Args:
        generator: Colorization model (generator)
        discriminator: PatchGAN discriminator
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Number of epochs
        lr_g: Generator learning rate
        lr_d: Discriminator learning rate
        lambda_l1: Weight for L1 loss
        model_name: Name for saving
    """
    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()
    
    optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))
    
    train_g_losses = []
    train_d_losses = []
    val_losses = []
    val_metrics_history = {'mse': [], 'psnr': [], 'ssim': []}
    
    best_val_loss = float('inf')
    
    print(f"Training {model_name} with PatchGAN...")
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        generator.train()
        discriminator.train()
        
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        
        pbar = tqdm(train_loader, desc='Training')
        for L, ab_real, _ in pbar:
            L, ab_real = L.to(device), ab_real.to(device)
            
            # ============================================
            # Train Generator
            # ============================================
            optimizer_G.zero_grad()
            
            # Generate fake ab channels
            ab_fake = generator(L)
            
            # Reconstruct LAB images for discriminator
            # L in range [0, 1], ab in range [-1, 1]
            # Normalize L to [-1, 1] for consistency
            L_norm = (L - 0.5) * 2
            fake_image = torch.cat([L_norm, ab_fake], dim=1)
            
            # Adversarial loss
            pred_fake = discriminator(fake_image)
            
            # Real and fake labels shaped like the patch map ((B, 1, 15, 15) for 256x256 inputs),
            # so they stay correct if IMG_SIZE changes
            real_label = torch.ones_like(pred_fake)
            fake_label = torch.zeros_like(pred_fake)
            loss_GAN = criterion_GAN(pred_fake, real_label)
            
            # L1 loss
            loss_L1 = criterion_L1(ab_fake, ab_real)
            
            # Total generator loss
            loss_G = loss_GAN + lambda_l1 * loss_L1
            
            loss_G.backward()
            optimizer_G.step()
            
            # ============================================
            # Train Discriminator
            # ============================================
            optimizer_D.zero_grad()
            
            # Real images
            real_image = torch.cat([L_norm, ab_real], dim=1)
            pred_real = discriminator(real_image.detach())
            loss_D_real = criterion_GAN(pred_real, real_label)
            
            # Fake images
            pred_fake = discriminator(fake_image.detach())
            loss_D_fake = criterion_GAN(pred_fake, fake_label)
            
            # Total discriminator loss
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            
            loss_D.backward()
            optimizer_D.step()
            
            epoch_g_loss += loss_G.item()
            epoch_d_loss += loss_D.item()
            
            pbar.set_postfix({'G_loss': loss_G.item(), 'D_loss': loss_D.item()})
        
        # Average losses
        epoch_g_loss /= len(train_loader)
        epoch_d_loss /= len(train_loader)
        train_g_losses.append(epoch_g_loss)
        train_d_losses.append(epoch_d_loss)
        
        # Validate
        generator.eval()
        val_loss = 0.0
        val_metrics = {'mse': 0.0, 'psnr': 0.0, 'ssim': 0.0}
        
        with torch.no_grad():
            for L, ab_real, _ in val_loader:
                L, ab_real = L.to(device), ab_real.to(device)
                
                ab_fake = generator(L)
                loss = criterion_L1(ab_fake, ab_real)
                val_loss += loss.item()
                
                metrics = calculate_metrics(ab_fake, ab_real, L)
                for key in metrics:
                    val_metrics[key] += metrics[key]
        
        val_loss /= len(val_loader)
        for key in val_metrics:
            val_metrics[key] /= len(val_loader)
        
        val_losses.append(val_loss)
        for key in val_metrics:
            val_metrics_history[key].append(val_metrics[key])
        
        print(f"G Loss: {epoch_g_loss:.4f}, D Loss: {epoch_d_loss:.4f}, Val Loss: {val_loss:.4f}")
        print(f"Val - MSE: {val_metrics['mse']:.4f}, PSNR: {val_metrics['psnr']:.2f}, SSIM: {val_metrics['ssim']:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model_save_path = os.path.join(MODEL_PATH, f'{model_name}_best.pth')
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'val_loss': val_loss,
            }, model_save_path)
            print(f"Saved best model to {model_save_path} with val_loss: {val_loss:.4f}")
    
    history = {
        'train_g_losses': train_g_losses,
        'train_d_losses': train_d_losses,
        'val_losses': val_losses,
        'val_metrics': val_metrics_history
    }
    
    return history

# # Test PatchGAN discriminator
# discriminator = PatchGANDiscriminator(input_channels=3).to(device)
# test_img = torch.randn(2, 3, 256, 256).to(device)
# test_disc_out = discriminator(test_img)
# print(f"Discriminator test - Input: {test_img.shape}, Output: {test_disc_out.shape}")
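For reference, the per-batch objectives that `train_with_patchgan` optimizes can be written as follows. Here $[\cdot,\cdot]$ is channel concatenation, $\tilde{L} = 2L - 1$ is the rescaled lightness (`L_norm`), $D$ returns the map of patch logits, $G(L)$ is the predicted ab map, $\operatorname{sg}$ marks the `.detach()` stop-gradient, $\sigma$ is the sigmoid, $\ell$ is `BCEWithLogitsLoss` averaged over all patch logits, and $\lambda$ is `lambda_l1` (100 by default):

$$
\begin{aligned}
\mathcal{L}_G &= \ell\big(D([\tilde{L},\, G(L)]),\ \mathbf{1}\big) + \lambda \operatorname{mean}\big\lvert G(L) - ab \big\rvert \\
\mathcal{L}_D &= \tfrac{1}{2}\Big[\ell\big(D([\tilde{L},\, ab]),\ \mathbf{1}\big) + \ell\big(D([\tilde{L},\, \operatorname{sg}(G(L))]),\ \mathbf{0}\big)\Big] \\
\ell(z, y) &= -\operatorname{mean}\big[\,y \log \sigma(z) + (1 - y) \log\big(1 - \sigma(z)\big)\big]
\end{aligned}
$$

Giving the fakes target $\mathbf{1}$ in $\mathcal{L}_G$ is the usual non-saturating generator loss. The large $\lambda$ keeps predictions anchored to the ground-truth colors, while the adversarial term rewards locally realistic patches.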

## 10. Training Experiments

Now we'll train different model configurations and compare them:

1. **Baseline Model**: Low-level + baseline global extractor + L1 loss
2. **Pretrained Backbone**: Low-level + VGG16/ResNet50 + L1 loss  
3. **Perceptual Loss**: Best model + perceptual loss
4. **PatchGAN (Bonus)**: Best model + PatchGAN discriminator

### 10.1 Train Baseline Model

In [19]:
print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
# Train baseline model with L1 loss
baseline_model = ColorizationModel().to(device)
baseline_history = train_model(
    baseline_model, 
    train_loader, 
    val_loader, 
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    use_perceptual=False,
    model_name='baseline_model'
)

# Plot training curves
plot_training_curves(baseline_history, 'Baseline Model')

# Visualize results
visualize_results(baseline_model, test_dataset, device, num_samples=5)

GPU memory allocated: 0.00 GB
Training baseline_model...
Using perceptual loss: False

Epoch 1/30


Training:  42%|████▏     | 1661/4000 [04:30<06:21,  6.14it/s, loss=0.0566] 



KeyboardInterrupt: 

### 10.2 Train Model with Pretrained Global Extractor (VGG16)

In [None]:
# Train model with VGG16 backbone
# vgg_extractor = PretrainedGlobalExtractor(backbone='vgg16', freeze_backbone=True).to(device)
# vgg_model = ColorizationModel(use_pretrained_global=vgg_extractor).to(device)

# vgg_history = train_model(
#     vgg_model, 
#     train_loader, 
#     val_loader, 
#     num_epochs=NUM_EPOCHS,
#     lr=LEARNING_RATE,
#     use_perceptual=False,
#     model_name='vgg16_model'
# )

# # Plot training curves
# plot_training_curves(vgg_history, 'VGG16 Model')

# # Visualize results
# visualize_results(vgg_model, test_dataset, device, num_samples=5)

### 10.3 Train Model with Perceptual Loss

In [None]:
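# Train model with perceptual loss (using the best backbone from the previous experiments).
# A fresh extractor instance avoids sharing (and further updating) the module inside vgg_model.
# perceptual_extractor = PretrainedGlobalExtractor(backbone='vgg16', freeze_backbone=True).to(device)
# perceptual_model = ColorizationModel(use_pretrained_global=perceptual_extractor).to(device)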

# perceptual_history = train_model(
#     perceptual_model, 
#     train_loader, 
#     val_loader, 
#     num_epochs=NUM_EPOCHS,
#     lr=LEARNING_RATE,
#     use_perceptual=True,
#     perceptual_weight=0.1,
#     model_name='perceptual_model'
# )

# # Plot training curves
# plot_training_curves(perceptual_history, 'Model with Perceptual Loss')

# # Visualize results
# visualize_results(perceptual_model, test_dataset, device, num_samples=5)

### 10.4 Train Model with PatchGAN (Bonus)

In [None]:
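# Train model with PatchGAN discriminator (fresh VGG16 extractor, not shared with earlier models)
# gan_extractor = PretrainedGlobalExtractor(backbone='vgg16', freeze_backbone=True).to(device)
# gan_generator = ColorizationModel(use_pretrained_global=gan_extractor).to(device)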
# gan_discriminator = PatchGANDiscriminator(input_channels=3).to(device)

# gan_history = train_with_patchgan(
#     gan_generator,
#     gan_discriminator,
#     train_loader,
#     val_loader,
#     num_epochs=NUM_EPOCHS,
#     lr_g=0.0002,
#     lr_d=0.0002,
#     lambda_l1=100,
#     model_name='patchgan_model'
# )

# # Plot training curves
# fig, axes = plt.subplots(1, 2, figsize=(15, 5))
# axes[0].plot(gan_history['train_g_losses'], label='Generator Loss')
# axes[0].plot(gan_history['train_d_losses'], label='Discriminator Loss')
# axes[0].set_xlabel('Epoch')
# axes[0].set_ylabel('Loss')
# axes[0].set_title('PatchGAN Training Losses')
# axes[0].legend()
# axes[0].grid(True)

# axes[1].plot(gan_history['val_losses'], label='Validation Loss')
# axes[1].set_xlabel('Epoch')
# axes[1].set_ylabel('Loss')
# axes[1].set_title('PatchGAN Validation Loss')
# axes[1].legend()
# axes[1].grid(True)
# plt.tight_layout()
# plt.show()

# # Visualize results
# visualize_results(gan_generator, test_dataset, device, num_samples=5)

## 11. Results and Comparison

### 11.1 Compare All Models

In [None]:
# Compare all trained models
# models_to_compare = {
#     'Baseline': baseline_model,
#     'VGG16': vgg_model,
#     'Perceptual': perceptual_model,
#     'PatchGAN': gan_generator
# }

# compare_models(models_to_compare, test_dataset, device, num_samples=5)

### 11.2 Quantitative Results Summary

In [None]:
# Create summary table of all models
# def evaluate_model_on_testset(model, test_loader, device):
#     """Evaluate model on entire test set."""
#     model.eval()
#     total_metrics = {'mse': 0.0, 'psnr': 0.0, 'ssim': 0.0}
    
#     with torch.no_grad():
#         for L, ab, _ in test_loader:
#             L, ab = L.to(device), ab.to(device)
#             ab_pred = model(L)
            
#             metrics = calculate_metrics(ab_pred, ab, L)
#             for key in metrics:
#                 total_metrics[key] += metrics[key]
    
#     # Average over batches
#     for key in total_metrics:
#         total_metrics[key] /= len(test_loader)
    
#     return total_metrics

# # Evaluate all models
# results = {}
# for model_name, model in models_to_compare.items():
#     print(f"Evaluating {model_name}...")
#     metrics = evaluate_model_on_testset(model, test_loader, device)
#     results[model_name] = metrics

# # Create comparison table
# results_df = pd.DataFrame(results).T
# results_df = results_df.round(4)
# print("\n" + "="*60)
# print("QUANTITATIVE RESULTS SUMMARY")
# print("="*60)
# print(results_df)
# print("="*60)

# # Highlight best values
# print("\nBest MSE (lower is better):", results_df['mse'].min(), "->", results_df['mse'].idxmin())
# print("Best PSNR (higher is better):", results_df['psnr'].max(), "->", results_df['psnr'].idxmax())
# print("Best SSIM (higher is better):", results_df['ssim'].max(), "->", results_df['ssim'].idxmax())

# # Visualize metrics comparison
# fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# results_df['mse'].plot(kind='bar', ax=axes[0], color='skyblue')
# axes[0].set_title('MSE Comparison (Lower is Better)')
# axes[0].set_ylabel('MSE')
# axes[0].grid(True, alpha=0.3)

# results_df['psnr'].plot(kind='bar', ax=axes[1], color='lightgreen')
# axes[1].set_title('PSNR Comparison (Higher is Better)')
# axes[1].set_ylabel('PSNR (dB)')
# axes[1].grid(True, alpha=0.3)

# results_df['ssim'].plot(kind='bar', ax=axes[2], color='lightcoral')
# axes[2].set_title('SSIM Comparison (Higher is Better)')
# axes[2].set_ylabel('SSIM')
# axes[2].grid(True, alpha=0.3)

# plt.tight_layout()
# plt.show()
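`evaluate_model_on_testset` divides the summed batch metrics by `len(test_loader)`, so a smaller final batch gets the same weight as a full one. Assuming `calculate_metrics` returns per-batch means (its body is not shown in this section), weighting each batch by its size recovers the per-image average. `weighted_mean_metrics` is a hypothetical helper name.

```python
def weighted_mean_metrics(batch_metrics, batch_sizes):
    """Average per-batch metric dicts, weighting each batch by its number of images.

    batch_metrics: list of dicts like {'mse': ..., 'psnr': ..., 'ssim': ...}
    batch_sizes:   list of ints of the same length (e.g. L.size(0) for each batch)
    """
    if not batch_sizes or len(batch_metrics) != len(batch_sizes):
        raise ValueError("need a non-empty list of batch sizes matching the metrics")
    total = sum(batch_sizes)
    keys = batch_metrics[0].keys()
    return {k: sum(m[k] * n for m, n in zip(batch_metrics, batch_sizes)) / total for k in keys}
```

In the evaluation loop, this would mean collecting `calculate_metrics(ab_pred, ab, L)` and `L.size(0)` for each batch and calling the helper once at the end.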

### 11.3 Failure Cases and Analysis

In [None]:
# Identify and visualize failure cases (images with lowest SSIM)
# def find_failure_cases(model, dataset, device, num_cases=5):
#     """Find images where the model performs poorly."""
#     model.eval()
    
#     ssim_scores = []
#     with torch.no_grad():
#         for i in range(len(dataset)):
#             L, ab, _ = dataset[i]
#             L_batch = L.unsqueeze(0).to(device)
#             ab_batch = ab.unsqueeze(0).to(device)
            
#             ab_pred = model(L_batch)
#             metrics = calculate_metrics(ab_pred, ab_batch, L_batch)
#             ssim_scores.append((i, metrics['ssim']))
    
#     # Sort by SSIM (ascending - worst first)
#     ssim_scores.sort(key=lambda x: x[1])
    
#     # Get worst cases
#     worst_indices = [idx for idx, _ in ssim_scores[:num_cases]]
    
#     # Visualize
#     fig, axes = plt.subplots(num_cases, 3, figsize=(12, 4 * num_cases))
    
#     with torch.no_grad():
#         for plot_idx, data_idx in enumerate(worst_indices):
#             L, ab_true, _ = dataset[data_idx]
#             L_batch = L.unsqueeze(0).to(device)
            
#             ab_pred = model(L_batch)
            
#             # Convert to numpy
#             L_np = (L.squeeze().numpy() * 100.0).astype(np.float32)
#             ab_true_np = (ab_true.permute(1, 2, 0).numpy() * 128.0).astype(np.float32)
#             ab_pred_np = (ab_pred[0].cpu().permute(1, 2, 0).numpy() * 128.0).astype(np.float32)
            
#             # Reconstruct LAB images
#             lab_true = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
#             lab_true[:, :, 0] = L_np
#             lab_true[:, :, 1:] = ab_true_np
            
#             lab_pred = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
#             lab_pred[:, :, 0] = L_np
#             lab_pred[:, :, 1:] = ab_pred_np
            
#             # Convert to RGB
#             rgb_true = color.lab2rgb(lab_true)
#             rgb_pred = color.lab2rgb(lab_pred)
            
#             # Calculate SSIM for this case
#             ssim_val = ssim(rgb_true, rgb_pred, channel_axis=2, data_range=1.0)
            
#             # Plot
#             axes[plot_idx, 0].imshow(L_np, cmap='gray')
#             axes[plot_idx, 0].set_title(f'Failure Case {plot_idx+1}: Input')
#             axes[plot_idx, 0].axis('off')
            
#             axes[plot_idx, 1].imshow(rgb_true)
#             axes[plot_idx, 1].set_title('Ground Truth')
#             axes[plot_idx, 1].axis('off')
            
#             axes[plot_idx, 2].imshow(rgb_pred)
#             axes[plot_idx, 2].set_title(f'Predicted (SSIM: {ssim_val:.3f})')
#             axes[plot_idx, 2].axis('off')
    
#     plt.suptitle('Failure Cases - Worst Colorization Results', fontsize=16, y=1.001)
#     plt.tight_layout()
#     plt.show()

# # Find failure cases for best model
# # find_failure_cases(vgg_model, test_dataset, device, num_cases=5)
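`find_failure_cases` sorts every `(index, ssim)` pair to pick the worst `num_cases`. Only the k lowest scores are needed, so `heapq.nsmallest` can select them in O(n log k) time with the same result as sorting and slicing. The sketch below assumes the same list of `(index, ssim)` tuples; `worst_k` is a hypothetical helper name.

```python
import heapq


def worst_k(scores, k):
    """Return the dataset indices of the k lowest-SSIM samples, worst first.

    scores: iterable of (index, ssim) tuples, as collected in find_failure_cases.
    """
    return [idx for idx, _ in heapq.nsmallest(k, scores, key=lambda pair: pair[1])]
```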

## 12. Discussion

### 12.1 Model Performance Analysis

**Write your analysis here after training. Consider:**

1. **Baseline Model Performance:**
   - How well did the baseline encoder-decoder work?
   - What were the typical MSE, PSNR, and SSIM values?
   - Visual quality of colorizations

2. **Effect of Pretrained Global Extractors:**
   - Which pretrained backbone (VGG16/ResNet50/EfficientNet) performed best?
   - Did freezing the backbone help or hurt performance?
   - How much improvement over baseline?
   - Training time comparison

3. **Impact of Perceptual Loss:**
   - Did perceptual loss improve visual quality?
   - How did metrics compare to L1-only models?
   - Were colors more realistic and natural?
   - Any trade-offs observed?

4. **PatchGAN Discriminator (Bonus):**
   - Did adversarial training improve results?
   - Were colors sharper and more consistent?
   - Training stability and convergence
   - Comparison with non-GAN models

5. **Computational Cost:**
   - Training time for each model configuration
   - Memory requirements
   - Inference speed
   - Practical considerations for deployment
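When comparing the MSE and PSNR columns, it helps to recall the standard PSNR definition for images with data range $R$. For RGB in $[0, 1]$, which is the range passed to `ssim` in the failure-case code, $R = 1$. Whether `calculate_metrics` uses the same range is an assumption here.

$$
\mathrm{PSNR} = 10 \log_{10} \frac{R^2}{\mathrm{MSE}}
$$

Because $-\log$ is convex, averaging per-image (or per-batch) PSNR values gives a number at least as large as the PSNR of the averaged MSE. As a result, the two columns need not rank models identically.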

### 12.2 Failure Cases Discussion

**Discuss common failure patterns:**
- Types of images that are hard to colorize (e.g., unusual objects, ambiguous scenes)
- Color bleeding or inconsistencies
- Grayscale regions that stayed gray
- Over-saturation or under-saturation issues

### 12.3 Loss Function Comparison

**Compare different loss configurations:**
- L1 loss only
- Perceptual loss only  
- L1 + Perceptual loss combined
- Effect of perceptual loss weight

### 12.4 Architecture Improvements

**Discuss architectural modifications you made:**
- Skip connections (if added)
- Number of layers
- Fusion strategy changes
- Activation functions
- Normalization techniques

## 13. Conclusion and Future Work

### Summary

In this assignment, we implemented an image colorization system based on Deep Convolutional Neural Networks (DCNNs) with an encoder-decoder architecture. Key achievements include:

1. **Dataset Creation**: Built a dataset of 5000+ images with proper L*a*b* color space conversion
2. **Baseline Model**: Implemented the encoder-decoder architecture from Figure 3 with low-level and global feature extractors
3. **Pretrained Backbones**: Integrated VGG16/ResNet50/EfficientNet as global feature extractors
4. **Loss Functions**: Experimented with L1 loss, perceptual loss, and their combinations
5. **Evaluation**: Used MSE, PSNR, and SSIM metrics for quantitative assessment
6. **PatchGAN (Bonus)**: Implemented adversarial training with PatchGAN discriminator

### Best Performing Model

**[Fill in after training]**
- Model configuration: ...
- Final metrics: MSE=..., PSNR=..., SSIM=...
- Why it performed best: ...

### Future Improvements

1. **Architecture Enhancements:**
   - Add U-Net style skip connections for better feature propagation
   - Implement attention mechanisms to focus on important regions
   - Use dilated convolutions for larger receptive fields

2. **Training Improvements:**
   - Data augmentation (flips, rotations, color jittering on input)
   - Progressive training (start with low resolution, increase gradually)
   - Ensemble multiple models for better results

3. **Loss Functions:**
   - Experiment with other perceptual loss layers
   - Add style loss for texture preservation
   - Use focal loss for hard examples

4. **Dataset:**
   - Increase dataset size (10K+ images)
   - Include diverse image categories
   - Balance color distribution

5. **Post-processing:**
   - Apply color smoothing filters
   - Histogram matching for better color consistency
   - Edge-aware filtering

### Lessons Learned

1. L*a*b* color space is well-suited for colorization tasks
2. Pretrained models provide strong features for global understanding
3. Perceptual loss helps generate more realistic colors
4. Adversarial training can improve local consistency
5. Proper evaluation requires both quantitative metrics and visual inspection

## 14. References

1. Zhang, R., Isola, P., & Efros, A. A. (2016). Colorful image colorization. In European conference on computer vision (pp. 649-666). Springer, Cham.

2. Iizuka, S., Simo-Serra, E., & Ishikawa, H. (2016). Let there be color! Joint end-to-end learning of global and local image priors for automatic image colorization with simultaneous classification. ACM Transactions on Graphics (ToG), 35(4), 1-11.

3. Isola, P., Zhu, J. Y., Zhou, T., & Efros, A. A. (2017). Image-to-image translation with conditional adversarial networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 1125-1134).

4. Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.

5. He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).

6. Johnson, J., Alahi, A., & Fei-Fei, L. (2016). Perceptual losses for real-time style transfer and super-resolution. In European conference on computer vision (pp. 694-711). Springer, Cham.

---

**End of Notebook**

### Instructions for Completion:

1. **Create Your Dataset**:
   - Collect at least 5000 color images
   - Run the dataset creation function
   - Save to Google Drive and share with instructor

2. **Train Models**:
   - Uncomment and run training cells
   - Start with baseline model
   - Then train with pretrained backbones
   - Experiment with perceptual loss
   - (Optional) Train with PatchGAN

3. **Analyze Results**:
   - Fill in the discussion sections with your observations
   - Create comparison tables and visualizations
   - Identify and analyze failure cases

4. **Create Demo Video**:
   - Record a 5-minute video walking through your notebook
   - Show code execution, results, and comparisons
   - Upload to Drive/YouTube and add link at the top

5. **Submit**:
   - Export notebook as .ipynb
   - Create ZIP: `b2220356053.zip` containing the notebook
   - Submit via Hacettepe Submit system