# **Deep Image Colorization: A ResNet-Based U-Net and PatchGAN Approach with Perceptual Loss**

#### **Author: Constantino Harry Alexander (25206605)**

## **Project Overview**

This project implements a conditional Generative Adversarial Network (cGAN) for automatic image colorization. The architecture is inspired by the Image-to-Image Translation framework (Pix2Pix) by Isola et al. [1].

### **Key Architectural Features:**

- Generator: A U-Net architecture [2] with a ResNet-18 backbone pre-trained on ImageNet. This allows the model to leverage rich semantic feature extraction (e.g., recognizing "sky" or "grass") rather than learning from scratch.
- Discriminator: A PatchGAN discriminator [1] which penalizes structure at the scale of local image patches, ensuring high-frequency sharpness.
- Loss Function: A hybrid loss combining L1 (pixel-level), Adversarial (GAN), and Perceptual (LPIPS) losses to counter the desaturated "sepia effect" common in regression-only models.

### **References:**

- [1] Isola, P., Zhu, J. Y., Zhou, T., & Efros, A. A. (2017). Image-to-image translation with conditional adversarial networks. CVPR.
- [2] Ronneberger, O., Fischer, P., & Brox, T. (2015). U-net: Convolutional networks for biomedical image segmentation. MICCAI.

---

# **Environment Setup and Data Acquisition 💻🗂️**

This cell handles the initial setup of the environment. It installs necessary libraries such as lpips (for perceptual loss) and faiss-cpu (for efficient similarity search). It also configures the computation device (CUDA GPU or CPU) and downloads the "Image Colorization Dataset" directly from Kaggle using the Kaggle API.

In [None]:
!pip install torch torchvision torchaudio scikit-image matplotlib tqdm lpips faiss-cpu --quiet
import os
import glob
import time
import warnings
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.utils import make_grid
from PIL import Image
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
from tqdm.notebook import tqdm
import lpips  # For perceptual loss
import faiss  # For retrieval

# Suppress all warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# --- Download Dataset (If not already downloaded) ---
if not os.path.exists('/content/dataset'):
    !pip install -q kaggle
    from google.colab import files
    # Upload kaggle.json if you haven't yet
    if not os.path.exists('kaggle.json'):
        print("Please upload kaggle.json")
        files.upload()
    !mkdir -p ~/.kaggle
    !cp kaggle.json ~/.kaggle/
    !chmod 600 ~/.kaggle/kaggle.json
    !kaggle datasets download -d aayush9753/image-colorization-dataset
    !unzip -q image-colorization-dataset.zip -d /content/dataset
print("Dataset Downloaded and Unzipped.")

# **Data Processing and Resizing 📊**
Before training, raw images must be standardized. This cell iterates through the downloaded dataset, locates the training images, and resizes them to a fixed resolution of 256x256 pixels. The processed images are saved to a temporary directory to speed up loading times during the training phase.

In [None]:
# Dataset Source: Image Colorization Dataset from Kaggle

# Preprocessing Pipeline
# 1. Image Loading: Load RGB images
# 2. Color Space Conversion: RGB → Lab, normalize L to [-1, 1] and ab to [-1, 1]
# 3. Data Augmentation (Training): flips, rotation, color jitter, random resized crop, Gaussian blur

# Train/Validation/Test Split: 80/10/10

# Why Lab: Separates luminance from color

import os
import glob
import cv2
import numpy as np
from tqdm.notebook import tqdm

print("Finding dataset location...")

possible_color_dirs = [
    "/content/dataset/train_color",
    "/content/dataset/image-colorization-dataset/train_color",
    "/content/dataset/dataset/train_color",
    "/content/image-colorization-dataset/train_color",
]

color_dir = None
for path in possible_color_dirs:
    if os.path.exists(path) and len(glob.glob(path + "/*.jpg")) > 100:
        color_dir = path
        break

if color_dir is None:
    matches = !find /content -type d -name "train_color" 2>/dev/null
    if matches:
        color_dir = matches[0]
    else:
        raise FileNotFoundError("Could not find train_color folder!")

os.makedirs("/content/processed/images", exist_ok=True)

color_paths = sorted(glob.glob(color_dir + "/*.jpg"))
np.random.seed(42)
np.random.shuffle(color_paths)
color_paths = color_paths[:1500]  # Keep 1500 for memory

print(f"Processing {len(color_paths)} images...")

for color_path in tqdm(color_paths, desc="Resizing images"):
    try:
        img = cv2.imread(color_path)
        if img is not None:
            img = cv2.resize(img, (256, 256))  # Keep 256x256
            processed_path = "/content/processed/images/" + os.path.basename(color_path)
            cv2.imwrite(processed_path, img)
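    except Exception as exc:
        # Skip unreadable files, but report them instead of failing silently
        print(f"Skipping {color_path}: {exc}")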

print("Processed images saved.")

# **Custom Dataset and Lab Color Space Conversion 📊**

Here we define the ColorizationDataset class. Following the methodology of Zhang et al. [3], each RGB image is converted into the CIELAB (Lab) color space, which separates lightness from color:

- L channel (Lightness): Used as the input to the model (grayscale).
- ab channels (Color): Used as the target (ground truth) for the model to predict.

This cell also applies data augmentation (flipping, rotation, color jitter, random resized crop, and Gaussian blur) to reduce overfitting and creates the DataLoaders for training, validation, and testing.
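
The normalization applied to the Lab channels can be checked in isolation. The following is a minimal sketch using the same constants as `ColorizationDataset` (divide L by 50 and subtract 1, divide ab by 128); the helper names `normalize_lab` and `denormalize_lab` are hypothetical and exist only for this illustration.

```python
# Illustrative sketch only: helper names are hypothetical, not part of the notebook.

def normalize_lab(L, a, b):
    """Map CIELAB values to roughly [-1, 1], using the dataset's constants."""
    return L / 50.0 - 1.0, a / 128.0, b / 128.0

def denormalize_lab(L_n, a_n, b_n):
    """Inverse mapping, as applied before converting Lab back to RGB."""
    return (L_n + 1.0) * 50.0, a_n * 128.0, b_n * 128.0

# An example pixel with mid lightness (L* = 50)
L_n, a_n, b_n = normalize_lab(50.0, 64.0, -32.0)
print(L_n, a_n, b_n)                    # 0.0 0.5 -0.25
print(denormalize_lab(L_n, a_n, b_n))   # (50.0, 64.0, -32.0)
```

Because L* spans [0, 100], the first mapping covers [-1, 1] exactly. The a* and b* values of real images usually stay well inside [-128, 128], so the normalized color channels rarely reach the ends of the range.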

### **References**
- [3] Zhang, R., Isola, P., & Efros, A. A. (2016). Colorful image colorization. ECCV.

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import glob
import numpy as np
from skimage.color import rgb2lab

class ColorizationDataset(Dataset):
    def __init__(self, root_dir="/content/processed/images", split='train', transform=None):
        self.paths = sorted(glob.glob(root_dir + "/*.jpg"))
        total = len(self.paths)
        if split == 'train':
            self.paths = self.paths[:int(0.8 * total)]
        elif split == 'val':
            self.paths = self.paths[int(0.8 * total):int(0.9 * total)]
        elif split == 'test':
            self.paths = self.paths[int(0.9 * total):]
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        img = np.array(img)
        lab = rgb2lab(img).transpose(2, 0, 1)
        L = (lab[[0], ...] / 50.0) - 1.0  # [-1,1]
        ab = lab[[1, 2], ...] / 128.0  # [-1,1]
        return {'L': torch.from_numpy(L).float(), 'ab': torch.from_numpy(ab).float(), 'path': self.paths[idx]}

# Enhanced Augmentations
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.GaussianBlur(kernel_size=3),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
])

# Build the datasets and loaders (re-run this cell after changing the transforms)
train_dataset = ColorizationDataset(split='train', transform=train_transform)
val_dataset = ColorizationDataset(split='val', transform=val_transform)
test_dataset = ColorizationDataset(split='test', transform=val_transform)

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# **Model Architecture (ResNet-Based U-Net & PatchGAN) 🤖🧠**

### **Generator Design (ResNet + U-Net):**
Instead of a standard encoder, we use the convolutional stages of an ImageNet-pretrained ResNet-18 [4] (everything except the final pooling and fully connected layers) as the backbone. Building a U-Net decoder on top of a pretrained classifier in this way was popularized by the "dynamic U-Net" in the FastAI library [5].

- Why ResNet? Very deep networks can suffer from vanishing gradients. ResNet's residual ("skip") connections let gradients flow through the network more easily, enabling deeper feature extraction.
- Why U-Net? Colorization requires precise alignment between the grayscale input edges and the output color. U-Net's long skip connections pass spatial information directly from the encoder to the decoder, preserving fine details.
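
To see why the skip connections line up, it helps to trace the feature-map sizes for a 256x256 input. The sketch below is a hand calculation, not code from the notebook: it assumes the standard ResNet-18 stage layout (a stride-2 stem, a stride-2 max-pool, then three stride-2 stages) and the decoder's transposed convolutions with kernel 4, stride 2, padding 1. The helper names are hypothetical.

```python
# Illustrative shape trace; helper names are hypothetical, not part of the notebook.

def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a transposed convolution (decoder settings)."""
    return (size - 1) * stride - 2 * padding + kernel

# Encoder: (name, output channels, kernel, stride, padding) of the downsampling op
stages = [
    ("layer0", 64, 7, 2, 3),   # 7x7 stem convolution
    ("layer1", 64, 3, 2, 1),   # 3x3 max-pool, then stride-1 residual blocks
    ("layer2", 128, 3, 2, 1),
    ("layer3", 256, 3, 2, 1),
    ("layer4", 512, 3, 2, 1),
]

size = 256
encoder = {}
for name, channels, kernel, stride, padding in stages:
    size = conv_out(size, kernel, stride, padding)
    encoder[name] = (channels, size)
    print(f"{name}: {channels} channels, {size}x{size}")

# Decoder: every upsampling step must land on the resolution of its skip tensor
matches = []
for skip in ["layer3", "layer2", "layer1", "layer0"]:
    size = deconv_out(size)
    matches.append(size == encoder[skip][1])
    print(f"upsample to {size}x{size}, concatenate with {skip}")

final_size = deconv_out(size)  # the last block restores the input resolution
print(f"output: 2 channels, {final_size}x{final_size}")
```

Each decoder step doubles the resolution, so it meets the matching encoder stage (16, 32, 64, then 128 pixels per side) before the final block returns to 256x256.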

### **Discriminator Design (PatchGAN):**
We implement a PatchGAN discriminator [1]. Unlike standard GAN discriminators that output a single "Real/Fake" probability for the whole image, PatchGAN classifies NxN patches of the image. This encourages the generator to focus on high-frequency structural details and sharpness.
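
The patch size N follows from the discriminator's convolution stack. As a rough check, the sketch below assumes the configuration used in this notebook's `PatchDiscriminator` with `n_layers=3` (three stride-2 convolutions followed by two stride-1 convolutions, all with kernel 4 and padding 1) and computes the output grid and the receptive field of one output unit. The function name `patch_stats` is hypothetical.

```python
# Illustrative calculation; `patch_stats` is a hypothetical helper.

# (kernel, stride, padding) for each convolution when n_layers=3
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

def patch_stats(input_size, layers):
    """Return (output grid size, receptive field) along one spatial axis."""
    size, receptive_field, jump = input_size, 1, 1
    for kernel, stride, padding in layers:
        size = (size + 2 * padding - kernel) // stride + 1
        receptive_field += (kernel - 1) * jump  # growth measured in input pixels
        jump *= stride                          # input-pixel spacing of outputs
    return size, receptive_field

out_size, receptive_field = patch_stats(256, layers)
print(out_size, receptive_field)  # 30 70
```

Under these assumptions, a 256x256 input yields a 30x30 grid of logits, each judging a 70x70 region of the input.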

### **References:**

- [1] Isola, P., Zhu, J. Y., Zhou, T., & Efros, A. A. (2017). Image-to-image translation with conditional adversarial networks. CVPR.
- [2] Ronneberger, O., Fischer, P., & Brox, T. (2015). U-net: Convolutional networks for biomedical image segmentation. MICCAI.
- [4] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. CVPR.
- [5] Howard, J., & Gugger, S. (2020). Fastai: A layered API for deep learning. Information.


In [None]:
import torch
import torch.nn as nn
from torchvision import models

# --- 1. Self-Attention Block ---
# Used by the ResNetUNet generator defined below
class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = self.gamma*out + x
        return out

# --- 2. Enhanced ResNet U-Net Generator ---
class ResNetUNet(nn.Module):
    def __init__(self, n_classes=2):
        super().__init__()

        # Encoder (ResNet18)
        base_model = models.resnet18(pretrained=True)
        self.base_layers = list(base_model.children())

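        # Input layer: Modify to accept 1 channel (L) instead of 3 (RGB).
        # Note: the replacement conv is randomly initialized, so the pretrained
        # first-layer filters are not reused.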
        self.layer0 = nn.Sequential(*self.base_layers[:3])
        self.layer0[0] = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        self.layer1 = nn.Sequential(*self.base_layers[3:5])
        self.layer2 = self.base_layers[5]
        self.layer3 = self.base_layers[6]
        self.layer4 = self.base_layers[7]

        # Decoder (Upsampling)
        self.up1 = self.unet_block(512, 256)
        self.up2 = self.unet_block(256 + 256, 128)
        self.up3 = self.unet_block(128 + 128, 64)

        # Self-attention applied to the 64-channel decoder features (output of up3)
        self.attention = SelfAttention(64)

        self.up4 = self.unet_block(64 + 64, 64)

        self.up5 = nn.Sequential(
            nn.ConvTranspose2d(64 + 64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, n_classes, kernel_size=3, padding=1),
            nn.Tanh() # Output is [-1, 1] to match Lab 'ab' range
        )

    def unet_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder
        layer0 = self.layer0(x)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder with Skip Connections
        up1 = self.up1(layer4)
        up2 = self.up2(torch.cat([up1, layer3], 1))
        up3 = self.up3(torch.cat([up2, layer2], 1))

        # Apply Attention
        up3 = self.attention(up3)

        up4 = self.up4(torch.cat([up3, layer1], 1))
        up5 = self.up5(torch.cat([up4, layer0], 1))

        return up5

# --- 3. Initialization ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize Generator
generator = ResNetUNet(n_classes=2).to(device)

# Initialize Generator Optimizer
# Note: opt_G is created again, together with opt_D, in the training cell below.
opt_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))

print("Enhanced Generator (with Attention) initialized.")

### **PatchDiscriminator**

A discriminator that classifies patches of the image as real or fake (rather than the whole image at once), which encourages sharper high-frequency details.

In [None]:
# --- PatchGAN Discriminator ---
class PatchDiscriminator(nn.Module):
    def __init__(self, input_c=3, n_filters=64, n_layers=3):
        super().__init__()
        model = [nn.Conv2d(input_c, n_filters, kernel_size=4, stride=2, padding=1),
                 nn.LeakyReLU(0.2, True)]

        nf_mult = 1
        nf_mult_prev = 1

        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            model += [
                nn.Conv2d(n_filters * nf_mult_prev, n_filters * nf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(n_filters * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)

        model += [
            nn.Conv2d(n_filters * nf_mult_prev, n_filters * nf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(n_filters * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        model += [nn.Conv2d(n_filters * nf_mult, 1, kernel_size=4, stride=1, padding=1)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

discriminator = PatchDiscriminator(input_c=3).to(device) # Input is L (1) + ab (2) = 3

In [None]:
import numpy as np
import torch
from skimage.color import lab2rgb

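# Batched Lab->RGB: (B,1,H,W) + (B,2,H,W) -> (B,3,H,W), [0,1], float32
# Note: runs skimage on the CPU one image at a time, so it is not differentiable;
# use it for visualization and evaluation only.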
def lab_batch_to_rgb_tensor(L, ab):
    # L, ab are torch tensors on device, ranges: L in [-1,1], ab in [-1,1]
    # Returns RGB torch tensor on same device, shape (B,3,H,W), float32 in [0,1]
    device = L.device
    B, _, H, W = L.shape

    # Move to CPU for skimage, build lab in (B,H,W,3)
    L_np = ((L.detach().cpu().float() + 1.0) * 50.0).numpy()                 # (B,1,H,W) -> L* in [0,100]
    ab_np = (ab.detach().cpu().float() * 128.0).numpy()                      # (B,2,H,W) -> a*,b* in [-128,128]
    lab_np = np.concatenate([L_np, ab_np], axis=1)                            # (B,3,H,W)
    lab_np = np.transpose(lab_np, (0, 2, 3, 1))                               # (B,H,W,3)

    # Apply lab2rgb per image
    rgb_list = [lab2rgb(lab_np[i]) for i in range(B)]                         # each (H,W,3) in [0,1]
    rgb_np = np.stack(rgb_list, axis=0)                                       # (B,H,W,3)
    rgb_np = np.transpose(rgb_np, (0, 3, 1, 2))                               # (B,3,H,W)

    rgb = torch.from_numpy(rgb_np).to(device=device, dtype=torch.float32)
    return rgb.clamp(0.0, 1.0)

# **Retrieval-Augmented Component**

This cell implements a retrieval component using FAISS. It builds an index of embeddings of the training images' L channels, computed with a ResNet-18 feature extractor. This allows the system to search the training set for images that are similar to the input grayscale image. The ab channels of these "reference" images can potentially be used as color hints for the generator, improving accuracy for ambiguous objects. The training loop in this notebook does not call `retrieve_color_hint`, so the hints are not yet fed to the generator.
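
`faiss.IndexFlatL2` performs an exact, brute-force search using squared Euclidean distance. The sketch below reproduces that lookup in plain Python on toy 2-D vectors to show what a top-1 query returns; the function names and the data are hypothetical, and real embeddings here would be 512-dimensional.

```python
# Illustrative brute-force search; names and data are hypothetical.

def squared_l2(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def nearest_reference(query, references):
    """Return (index, squared distance) of the closest reference embedding."""
    distances = [squared_l2(query, ref) for ref in references]
    best = min(range(len(references)), key=distances.__getitem__)
    return best, distances[best]

references = [[0.0, 0.0], [1.0, 1.0], [4.0, 0.0]]  # toy "embeddings"
idx, dist = nearest_reference([0.9, 1.2], references)
print(idx)  # 1
```

The returned index plays the role of `idxs` in `retrieve_color_hint`: it selects which stored ab map is used as the hint.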

In [None]:
# Pre-trained ResNet for embeddings (grayscale input)
embedder = models.resnet18(pretrained=True)
embedder.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 1-channel input; note this new layer is randomly initialized
embedder.fc = nn.Identity()  # Remove classifier
embedder = embedder.to(device).eval()

# Build index from train grayscales (small reference set: 500 images)
ref_paths = train_dataset.paths[:500]
ref_embeddings = []
ref_abs = []  # Store ab for blending

with torch.no_grad():
    for path in tqdm(ref_paths, desc="Building Retrieval Index"):
        img = Image.open(path).convert("RGB")
        lab = rgb2lab(np.array(img))
        L = torch.from_numpy((lab[:,:,0]/50 -1)[None, None, ...]).float().to(device)
        ab = torch.from_numpy(lab[:,:,1:]/128).float().permute(2,0,1)[None,...].to(device)
        emb = embedder(L).cpu().numpy().flatten()
        ref_embeddings.append(emb)
        ref_abs.append(ab.cpu())

ref_embeddings = np.array(ref_embeddings)
index = faiss.IndexFlatL2(ref_embeddings.shape[1])
index.add(ref_embeddings)

def retrieve_color_hint(L):
    with torch.no_grad():
        embs = embedder(L).cpu().numpy()  # (B, 512)
        _, idxs = index.search(embs, 1)  # (B, 1)
        hint_abs = torch.cat([ref_abs[idx[0]] for idx in idxs], dim=0) * 0.2  # (B, 2, H, W)
        return hint_abs.to(device)

# **🔁 Training Procedure (with TTUR, Early Stopping, Retrieval Hint)**



### **Helper Function: Result Visualization**
This cell defines a helper function to visualize results during training. It converts the Lab tensors back to RGB and displays the input (grayscale), generated output, and ground truth side by side.

In [None]:
def show_results_resnet(generator, dataloader, num_images=5):
    generator.eval()
    data = next(iter(dataloader))
    L = data['L'].to(device)
    ab_real = data['ab'].to(device)

    with torch.no_grad():
        fake_ab = generator(L)

    fake_imgs = lab_batch_to_rgb_tensor(L, fake_ab)
    real_imgs = lab_batch_to_rgb_tensor(L, ab_real)

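    num_images = min(num_images, L.size(0))  # never index past the batch
    # squeeze=False keeps `axes` 2-D even when num_images == 1
    fig, axes = plt.subplots(num_images, 3, figsize=(15, 5 * num_images), squeeze=False)
    for i in range(num_images):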
        # Grayscale
        gray_img = (L[i].cpu().squeeze().numpy() + 1) / 2
        axes[i, 0].imshow(gray_img, cmap='gray')
        axes[i, 0].set_title("Input (Grayscale)")
        axes[i, 0].axis('off')

        # Generated
        gen_img = fake_imgs[i].cpu().permute(1, 2, 0).numpy()
        axes[i, 1].imshow(gen_img)
        axes[i, 1].set_title("Generated Color")
        axes[i, 1].axis('off')

        # Ground Truth
        real_img = real_imgs[i].cpu().permute(1, 2, 0).numpy()
        axes[i, 2].imshow(real_img)
        axes[i, 2].set_title("Ground Truth")
        axes[i, 2].axis('off')
    plt.show()
    generator.train()

### **Helper Function: Training History Plots**

This helper plots the training and validation metrics collected during the training loop, which makes it easier to debug a run while it is in progress.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import clear_output
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import math
import lpips

def plot_training_history(history):
    """
    Helper function to plot training metrics in a single row:
    1. GAN Losses (Generator vs Discriminator)
    2. L1 Loss (Color Accuracy)
    3. PSNR (Signal Quality)
    4. LPIPS (Perceptual Quality)
    """
    # Wide, short figure so that all four plots fit in a single row
    plt.figure(figsize=(24, 5))

    # Plot 1: GAN Losses
    plt.subplot(1, 4, 1) # 1 row, 4 columns, index 1
    plt.plot(history['G_loss'], label='G Loss', color='orange')
    plt.plot(history['D_loss'], label='D Loss', color='blue')
    plt.title("Adversarial Training Losses")
    plt.xlabel("Epoch")
    plt.legend()
    plt.grid(alpha=0.3)

    # Plot 2: L1 Loss (Lower is better)
    plt.subplot(1, 4, 2) # 1 row, 4 columns, index 2
    plt.plot(history['Val_L1'], label='Validation L1', color='green')
    plt.title("Color Accuracy (L1 - Lower is Better)")
    plt.xlabel("Epoch")
    plt.legend()
    plt.grid(alpha=0.3)

    # Plot 3: PSNR (Higher is better)
    plt.subplot(1, 4, 3) # 1 row, 4 columns, index 3
    plt.plot(history['Val_PSNR'], label='Validation PSNR', color='purple')
    plt.title("Signal Quality (PSNR - Higher is Better)")
    plt.xlabel("Epoch")
    plt.legend()
    plt.grid(alpha=0.3)

    # Plot 4: LPIPS (Lower is better)
    plt.subplot(1, 4, 4) # 1 row, 4 columns, index 4
    plt.plot(history['Val_LPIPS'], label='Validation LPIPS', color='red')
    plt.title("Perceptual (LPIPS - Lower is Better)")
    plt.xlabel("Epoch")
    plt.legend()
    plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

## **Main Training Loop**

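This is the core execution block. It runs training for the specified number of epochs, updating the Discriminator and the Generator in alternating steps with separate Adam optimizers. The Two-Time-Scale Update Rule (TTUR) cited below calls for different learning rates for the two networks; the code in this cell uses the same rate (1e-4) for both.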

### **Loss Function Strategy:**

- **L1 Loss (Pixel-level):** Penalizes the absolute distance between the predicted color and real color. This ensures general color accuracy but can lead to desaturated ("sepia") results.
- **Adversarial Loss (GAN):** The Discriminator tries to distinguish real images from generated ones. This forces the Generator to create vibrant, realistic textures to "fool" the Discriminator.
- **Perceptual Loss (LPIPS):** Uses a pre-trained VGG network to compare high-level features (texture, structure) rather than just pixel values. This aligns the result with human perception.
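
The three terms are combined as a weighted sum. The sketch below uses the weights set in the training cell (`lambda_L1 = 50.0`, `lambda_percep = 10.0`, and an unweighted adversarial term); the function name and the loss values passed in are made up purely for illustration.

```python
# Illustrative arithmetic; the function name and input values are hypothetical.

def generator_loss(l1, lpips_value, gan=0.0, lambda_l1=50.0, lambda_percep=10.0):
    """Weighted sum of the pixel, perceptual, and adversarial terms."""
    return lambda_l1 * l1 + lambda_percep * lpips_value + gan

# With the adversarial term: 50 * 0.125 + 10 * 0.25 + 0.5
print(generator_loss(l1=0.125, lpips_value=0.25, gan=0.5))  # 9.25
# Without it (gan defaults to 0.0)
print(generator_loss(l1=0.125, lpips_value=0.25))           # 8.75
```

With these weights the L1 and perceptual terms dominate the total, so the adversarial term acts as a comparatively small push toward more realistic color.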

### **Training Phases:**

- **Warmup:** For the first 20 epochs, the Generator trains with the L1 and perceptual losses only (no adversarial term) and the Discriminator is not updated, which stabilizes the weights.
- **GAN Training:** From epoch 21 onward, the model switches to full adversarial training.
- **Validation:** At the end of every epoch, the model is evaluated on the validation set. After the warmup, the weights with the best validation L1 score are saved.
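
The phase logic can be summarized as a small function of the epoch number. This is a sketch of the schedule as read from the training loop (`warmup = epoch <= 20`, and best-checkpoint saving only when `epoch > 20`); the function and dictionary keys are hypothetical names introduced here.

```python
# Illustrative schedule; the function and key names are hypothetical.

WARMUP_EPOCHS = 20  # mirrors `warmup = epoch <= 20` in the training loop

def training_phase(epoch, warmup_epochs=WARMUP_EPOCHS):
    """Describe which parts of training are active at a 1-indexed epoch."""
    warmup = epoch <= warmup_epochs
    return {
        "phase": "Warmup" if warmup else "GAN",
        "update_discriminator": not warmup,
        "use_gan_loss": not warmup,
        "eligible_for_best_checkpoint": epoch > warmup_epochs,
    }

for epoch in (1, 20, 21):
    print(epoch, training_phase(epoch))
```

Epoch 20 is still part of the warmup, so the first adversarial update and the first candidate for the best checkpoint both occur at epoch 21.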

**References:**

- **TTUR:** Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., & Hochreiter, S. (2017). GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium. NIPS.
- **Perceptual Loss:** Zhang, R., Isola, P., Efros, A. A., Shechtman, E., & Wang, O. (2018). The Unreasonable Effectiveness of Deep Features as a Perceptual Metric. CVPR.

In [None]:
# --- 1. Setup Optimizers & Loss ---
# We use BCEWithLogitsLoss for stability in GANs
criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

# Initialize Perceptual Loss (LPIPS)
# We use the VGG backbone, which the LPIPS authors describe as closer to a
# traditional perceptual loss when the metric is used for optimization
print("Loading LPIPS VGG model...")
criterion_percep = lpips.LPIPS(net='vgg').to(device)

# Lower learning rate slightly for stability
lr = 1e-4
opt_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Use CosineAnnealing for smoother decay
num_epochs = 100

# Initialize schedulers
scheduler_G = optim.lr_scheduler.CosineAnnealingLR(opt_G, T_max=num_epochs, eta_min=1e-6)
scheduler_D = optim.lr_scheduler.CosineAnnealingLR(opt_D, T_max=num_epochs, eta_min=1e-6)

# Weights
lambda_L1 = 50.0
lambda_percep = 10.0

# Tracking
best_val_loss = float('inf')
history = {'G_loss': [], 'D_loss': [], 'Val_L1': [], 'Val_PSNR': [], 'Val_LPIPS': []}

print("STARTING....")
print(f"Device: {device} | Epochs: {num_epochs}")

def validate_model(gen, loader, device):
    """Calculates L1, PSNR, and LPIPS on validation set"""
    gen.eval()
    val_loss = 0.0
    val_psnr = 0.0
    val_lpips = 0.0

    with torch.no_grad():
        for batch in loader:
            L = batch['L'].to(device)
            ab = batch['ab'].to(device)
            fake_ab = gen(L)

            # L1 Loss
            val_loss += criterion_L1(fake_ab, ab).item()

            # PSNR Calculation
            mse = torch.mean((ab - fake_ab) ** 2)
            if mse == 0:
                val_psnr += 100
            else:
                psnr = 10 * torch.log10(4.0 / mse)
                val_psnr += psnr.item()

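            # LPIPS Calculation
            # LPIPS expects 3 channels. We concatenate L and ab.
            # Note: LPIPS was designed for RGB inputs in [-1, 1]; the normalized
            # Lab stack used here is a proxy, so compare values only within this notebook.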
            real_stack = torch.cat([L, ab], dim=1)
            fake_stack = torch.cat([L, fake_ab], dim=1)

            # lpips returns a tensor, we need the float item
            batch_lpips = criterion_percep(fake_stack, real_stack)
            val_lpips += batch_lpips.mean().item()

    gen.train()
    # Return average L1, PSNR, and LPIPS
    return val_loss / len(loader), val_psnr / len(loader), val_lpips / len(loader)

for epoch in range(1, num_epochs + 1):
    generator.train()
    discriminator.train()

    g_loss_epoch = 0.0
    d_loss_epoch = 0.0

    # Warmup: no adversarial training for the first 20 epochs
    # (the Generator uses the L1 + perceptual losses only)
    warmup = epoch <= 20

    loop = tqdm(train_loader, leave=True)

    for batch in loop:
        L = batch['L'].to(device)
        ab_real = batch['ab'].to(device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        loss_D_val = 0.0

        if not warmup:
            opt_D.zero_grad()

            # Generate fake image
            fake_ab = generator(L)

            # Concatenate L + ab
            real_image = torch.cat([L, ab_real], dim=1)
            fake_image = torch.cat([L, fake_ab.detach()], dim=1)

            # Add slight noise to inputs to stabilize Discriminator
            noise = torch.randn_like(real_image) * 0.05

            # Discriminator forward pass
            pred_real = discriminator(real_image + noise)
            pred_fake = discriminator(fake_image + noise)

            # Label Smoothing
            valid = torch.ones_like(pred_real) * 0.9
            fake = torch.zeros_like(pred_fake) + 0.1

            loss_D_real = criterion_GAN(pred_real, valid)
            loss_D_fake = criterion_GAN(pred_fake, fake)
            loss_D = (loss_D_real + loss_D_fake) * 0.5

            loss_D.backward()
            loss_D_val = loss_D.item()

            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
            opt_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        opt_G.zero_grad()

        # Re-generate fake_ab for Generator update
        fake_ab = generator(L)
        fake_image = torch.cat([L, fake_ab], dim=1) # Used for GAN loss

        # We also need the real image stack for Perceptual Loss
        real_image_stack = torch.cat([L, ab_real], dim=1)

        # 1. Pixel-level Loss (L1)
        loss_G_L1 = criterion_L1(fake_ab, ab_real) * lambda_L1

        # 2. GAN Loss (only after warmup)
        loss_G_GAN = 0.0
        if not warmup:
            pred_fake = discriminator(fake_image)
            loss_G_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        # 3. Perceptual Loss (LPIPS)
        loss_G_percep = 0.0
        if lambda_percep > 0:
            loss_G_percep = criterion_percep(fake_image, real_image_stack).mean() * lambda_percep

        # Total Generator Loss
        loss_G = loss_G_L1 + loss_G_GAN + loss_G_percep

        loss_G.backward()

        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)

        opt_G.step()

        # Logging
        g_loss_epoch += loss_G.item()
        d_loss_epoch += loss_D_val

        status = "Warmup" if warmup else "GAN"
        loop.set_description(f"Epoch [{epoch}/{num_epochs}] ({status})")
        loop.set_postfix(L1=loss_G_L1.item(), Percep=loss_G_percep.item() if not isinstance(loss_G_percep, float) else 0, D_loss=loss_D_val)

    # --- End of Epoch Updates ---
    scheduler_G.step()
    scheduler_D.step()

    avg_g_loss = g_loss_epoch / len(train_loader)
    avg_d_loss = d_loss_epoch / len(train_loader)

    # Validate on unseen data
    current_val_loss, current_psnr, current_lpips = validate_model(generator, val_loader, device)

    # Store history
    history['G_loss'].append(avg_g_loss)
    history['D_loss'].append(avg_d_loss)
    history['Val_L1'].append(current_val_loss)
    history['Val_PSNR'].append(current_psnr)
    history['Val_LPIPS'].append(current_lpips)

    # --- Save Best Model based on VALIDATION L1 ---
    # We ONLY save if epoch > 20.
    # This ignores the "fake best" results from the warmup phase (epochs 1-20)
    # where the model just outputs gray to cheat the L1 loss.
    if current_val_loss < best_val_loss and epoch > 20:
        best_val_loss = current_val_loss
        torch.save(generator.state_dict(), "colorizer_BEST.pth")
        best_msg = f" (New Best L1: {best_val_loss:.4f})"
    else:
        best_msg = ""

    # Periodic Save
    if epoch % 10 == 0:
        torch.save(generator.state_dict(), f"colorizer_epoch_{epoch}.pth")

    # Visualization
    clear_output(wait=True)
    print(f"EPOCH {epoch}/{num_epochs} | LR: {opt_G.param_groups[0]['lr']:.6f}")
    print(f"Train G Loss: {avg_g_loss:.4f} | Train D Loss: {avg_d_loss:.4f}")
    print(f"Val L1: {current_val_loss:.4f} | Val PSNR: {current_psnr:.2f} dB | Val LPIPS: {current_lpips:.4f} {best_msg}")

    # Call the helper function for plotting
    plot_training_history(history)

    # Show visual examples every 5 epochs
    if epoch % 5 == 0:
        try: show_results_resnet(generator, test_loader, num_images=3)
        except: pass

# Save final
torch.save(generator.state_dict(), "colorizer_FINAL.pth")
print("TRAINING COMPLETE.")

----
# **Evaluation Procedures 📈**

## **Training Diagnostics**

Visualizing training progress is essential for debugging GANs. This cell generates smoothed plots for:

- **Adversarial Dynamics:** Comparing Generator Loss vs. Discriminator Loss to ensure neither is overpowering the other.
- **Validation Accuracy:** Tracking the L1 loss on the validation set to detect overfitting (when validation loss starts rising while training loss continues to fall).
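
The smoothing used for these plots is an exponential moving average (EMA) seeded with the first value. Below is a small worked example of that recurrence in plain Python; the function name `ema` and the toy series are hypothetical, and the notebook's own smoothing uses a weight of 0.85 rather than the 0.5 chosen here for easy arithmetic.

```python
# Illustrative EMA; the function name and toy data are hypothetical.

def ema(values, weight):
    """Exponential moving average; higher weight means a smoother curve."""
    if not values:
        return []
    smoothed, last = [], values[0]
    for value in values:
        last = weight * last + (1.0 - weight) * value
        smoothed.append(last)
    return smoothed

# A noisy series that alternates between 1 and 0
print(ema([1.0, 0.0, 1.0, 0.0], weight=0.5))  # [1.0, 0.5, 0.75, 0.375]
```

The smoothed curve lags the raw values, so it is easier to read trends from it but sudden changes appear a few epochs late.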


In [None]:
# ==================== CELL 8: ADVANCED TRAINING DIAGNOSTICS (ROBUST HISTORY KEYS) ====================
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

# --- CONFIGURATION ---
GAN_START_EPOCH = 20
SMOOTHING_FACTOR = 0.85  # 0.0 to 1.0 (higher = smoother)

# Professional palette
COLOR_G_RAW = '#A9CCE3'     # Light blue
COLOR_G_SMOOTH = '#2874A6'  # Blue
COLOR_D_RAW = '#F5B7B1'     # Light red
COLOR_D_SMOOTH = '#CB4335'  # Red
COLOR_VAL = '#229954'       # Green
COLOR_WARMUP = '#EAEDED'    # Light gray

plt.rcParams["figure.dpi"] = 160

print("RUNNING ADVANCED DIAGNOSTICS...")

def get_first_key(dct, candidates, default=None):
    """Return the first key from candidates that exists in dct; otherwise return default."""
    for k in candidates:
        if isinstance(dct, dict) and k in dct and dct[k] is not None:
            return k
    return default

# --- 1) Retrieve history robustly ---
need_dummy = False
if 'history' not in globals() or not isinstance(history, dict) or len(history) == 0:
    need_dummy = True

if need_dummy:
    print("⚠️ No training history found. Generating dummy data for demo...")
    epochs_range = 80
    history = {
        'G_loss': [0.5 * (0.92**i) for i in range(20)] + [2.5 + (0.98**(i-20)) + np.random.normal(0, 0.1) for i in range(20, epochs_range)],
        'D_loss': [0.1 for _ in range(20)] + [0.6 + (0.99**(i-20)) + np.random.normal(0, 0.05) for i in range(20, epochs_range)],
        'Val_L1': [0.11 - 0.0006*i + np.random.normal(0, 0.0015) for i in range(epochs_range)]
    }

# Try to map typical key names
g_key = get_first_key(history, ['G', 'G_loss', 'loss_G'])
d_key = get_first_key(history, ['D', 'D_loss', 'loss_D'])
val_l1_key = get_first_key(history, ['Val_L1', 'L1_val', 'val_l1'])

# Optional additional validation metrics (if present)
val_psnr_key = get_first_key(history, ['Val_PSNR', 'val_psnr'])
val_lpips_key = get_first_key(history, ['Val_LPIPS', 'val_lpips'])

if g_key is None or d_key is None:
    # Si aun as√≠ faltan G/D, generamos dummy con longitudes consistentes
    print("‚ÑπÔ∏è History keys missing for G/D. Synthesizing demo curves...")
    epochs_range = len(history[val_l1_key]) if val_l1_key else 80
    history[g_key or 'G_loss'] = [0.5 * (0.92**i) for i in range(min(20, epochs_range))] + \
                                 [2.2 + (0.98**max(0, i-20)) + np.random.normal(0, 0.08) for i in range(20, epochs_range)]
    history[d_key or 'D_loss'] = [0.1 for _ in range(min(20, epochs_range))] + \
                                 [0.6 + (0.99**max(0, i-20)) + np.random.normal(0, 0.05) for i in range(20, epochs_range)]
    if g_key is None: g_key = 'G_loss'
    if d_key is None: d_key = 'D_loss'

# Extract arrays
G_losses = np.array(history[g_key], dtype=float)
D_losses = np.array(history[d_key], dtype=float)
epochs = np.arange(1, len(G_losses) + 1)

# Align lengths if D has a different length (uncommon)
min_len = min(len(G_losses), len(D_losses))
if min_len < len(G_losses): G_losses = G_losses[:min_len]
if min_len < len(D_losses): D_losses = D_losses[:min_len]
epochs = np.arange(1, min_len + 1)

# --- 2) EMA smoothing ---
def smooth_curve(scalars, weight):
    if len(scalars) == 0: return np.array([])
    last = scalars[0]
    smoothed = []
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point
        smoothed.append(smoothed_val)
        last = smoothed_val
    return np.array(smoothed)

G_smooth = smooth_curve(G_losses, SMOOTHING_FACTOR)
D_smooth = smooth_curve(D_losses, SMOOTHING_FACTOR)

# Val L1, if available
Val_losses = None
Val_smooth = None
val_epochs = None
if val_l1_key is not None and isinstance(history[val_l1_key], (list, tuple, np.ndarray)):
    Val_losses = np.array(history[val_l1_key], dtype=float)
    # Resample the epoch axis if the lengths differ
    if len(Val_losses) != len(G_losses):
        val_epochs = np.linspace(1, len(G_losses), len(Val_losses))
    else:
        val_epochs = epochs.copy()
    Val_smooth = smooth_curve(Val_losses, SMOOTHING_FACTOR)

# --- 3) Plotting ---
plt.style.use('seaborn-v0_8-whitegrid')
fig, axes = plt.subplots(1, 2, figsize=(22, 8))
ax1, ax2 = axes
fig.suptitle('Training Performance Diagnostics', fontsize=22, fontweight='bold', y=0.98)

# Subplot 1: Adversarial dynamics
ax1.plot(epochs, G_losses, color=COLOR_G_RAW, alpha=0.35, linewidth=1, label='_nolegend_')
ax1.plot(epochs, D_losses, color=COLOR_D_RAW, alpha=0.35, linewidth=1, label='_nolegend_')
ax1.plot(epochs, G_smooth, color=COLOR_G_SMOOTH, linewidth=2.8, label='Generator Loss')
ax1.plot(epochs, D_smooth, color=COLOR_D_SMOOTH, linewidth=2.8, label='Discriminator Loss')

if len(epochs) > GAN_START_EPOCH:
    ylim = ax1.get_ylim()
    height = (ylim[1] - ylim[0])
    rect = patches.Rectangle((0, ylim[0]), GAN_START_EPOCH, height,
                             linewidth=0, edgecolor='none', facecolor=COLOR_WARMUP, alpha=0.6, zorder=0)
    ax1.add_patch(rect)
    ax1.axvline(x=GAN_START_EPOCH, color='#7F8C8D', linestyle='--', linewidth=1.6)
    ax1.text(GAN_START_EPOCH/2, ylim[0] + height*0.94, "WARMUP PHASE", ha='center', fontsize=11, fontweight='bold', color='#7F8C8D')
    ax1.text(GAN_START_EPOCH + (len(epochs)-GAN_START_EPOCH)/2, ylim[0] + height*0.94, "GAN TRAINING PHASE", ha='center', fontsize=11, fontweight='bold', color='#2C3E50')

ax1.set_title('Global Loss Landscape', fontsize=16, pad=10)
ax1.set_xlabel('Epoch', fontsize=13)
ax1.set_ylabel('Loss Value', fontsize=13)
ax1.legend(loc='upper right', frameon=True, framealpha=1, shadow=True)
ax1.grid(True, linestyle=':', alpha=0.6)

# Subplot 2: Stability and best epoch (uses Val_L1 if present; otherwise G_smooth)
if len(epochs) > GAN_START_EPOCH + 2:
    if Val_smooth is not None and len(Val_smooth) > 0:
        # Validation curve
        ax2.plot(val_epochs, Val_losses, color=COLOR_VAL, alpha=0.25, label='Raw Val L1')
        ax2.plot(val_epochs, Val_smooth, color=COLOR_VAL, linewidth=2.8, label='Smoothed Val L1')

        valid_mask = (val_epochs > GAN_START_EPOCH)
        if np.any(valid_mask):
            gan_phase_epochs = val_epochs[valid_mask]
            gan_phase_vals = Val_smooth[valid_mask]
            local_min_idx = int(np.argmin(gan_phase_vals))
            min_epoch = float(gan_phase_epochs[local_min_idx])
            min_val = float(gan_phase_vals[local_min_idx])
        else:
            min_idx = int(np.argmin(Val_smooth))
            min_epoch = float(val_epochs[min_idx])
            min_val = float(Val_smooth[min_idx])

        ax2.scatter(min_epoch, min_val, color='#E74C3C', s=110, zorder=5, edgecolors='white', linewidth=2)
        if len(epochs) > GAN_START_EPOCH:
            ax2.axvspan(0, GAN_START_EPOCH, color=COLOR_WARMUP, alpha=0.5, zorder=0)
            ax2.axvline(x=GAN_START_EPOCH, color='#7F8C8D', linestyle='--', linewidth=1.6)

        y_range = float(np.max(Val_smooth) - np.min(Val_smooth))
        y_text = min_val + (y_range * 0.12 if y_range > 0 else 0.01)
        ax2.annotate(f'Best (Post-Warmup)\nEpoch {int(round(min_epoch))}\nL1: {min_val:.4f}',
                     xy=(min_epoch, min_val),
                     xytext=(min_epoch, y_text),
                     arrowprops=dict(facecolor='black', shrink=0.05, width=1, headwidth=8),
                     bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.9),
                     ha='center', fontsize=11)

        ax2.set_title('Color Accuracy (Validation L1 Loss)', fontsize=16, pad=10)
        ax2.set_xlabel('Epoch', fontsize=13)
        ax2.set_ylabel('L1 Loss (Lower is Better)', fontsize=13)
        ax2.legend()
        ax2.grid(True, linestyle=':', alpha=0.6)

        # Text report based on Val_L1
        print(f"\n{'='*25}  TRAINING REPORT  {'='*25}")
        current_val = float(Val_smooth[-1])
        delta = current_val - min_val
        print(f"🔹 Best (Post-Warmup):   Epoch {int(round(min_epoch))} (L1: {min_val:.4f})")
        print(f"🔹 Current Status:       Epoch {int(round(val_epochs[-1]))} (L1: {current_val:.4f})")
        print("-" * 67)
        if delta > 0.015:
            print(f"⚠️  OVERFITTING DETECTED: Validation loss increased by +{delta:.4f} vs. best.")
        elif 0 < delta <= 0.015:
            print(f"ℹ️  STABLE: Slight fluctuation (+{delta:.4f}) is normal for GANs.")
        else:
            print("✅ EXCELLENT: Current model matches or surpasses the best observed.")
    else:
        # Without Val_L1, show the stability of G_smooth instead
        gan_epochs = epochs[GAN_START_EPOCH:]
        gan_G_smooth = G_smooth[GAN_START_EPOCH:]
        ax2.plot(gan_epochs, gan_G_smooth, color=COLOR_G_SMOOTH, linewidth=2.6, label='G (Smoothed)')
        min_idx = int(np.argmin(gan_G_smooth))
        min_epoch = int(gan_epochs[min_idx])
        min_val = float(gan_G_smooth[min_idx])
        ax2.scatter(min_epoch, min_val, color='#E74C3C', s=110, zorder=5, edgecolors='white', linewidth=2)
        ax2.set_title('Generator Stability Analysis (Post-Warmup)', fontsize=16, pad=10)
        ax2.set_xlabel('Epoch', fontsize=13)
        ax2.set_ylabel('Smoothed Generator Loss', fontsize=13)
        ax2.grid(True, linestyle=':', alpha=0.6)
        ax2.legend()

        print(f"\n{'='*25}  TRAINING REPORT  {'='*25}")
        current_loss = float(G_smooth[-1])
        lowest_loss = float(np.min(gan_G_smooth))
        delta = current_loss - lowest_loss
        print(f"🔹 Best Performance:   Epoch {min_epoch} (G Loss: {lowest_loss:.4f})")
        print(f"🔹 Current Status:     Epoch {int(epochs[-1])} (G Loss: {current_loss:.4f})")
        print("-" * 67)
        if delta > 0.30:
            print("⚠️  CRITICAL: MODEL DIVERGENCE DETECTED")
            print(f"   The model loss has spiked (+{delta:.2f}) from its best point.")
        elif delta > 0.05:
            print("🔸 NOTICE: Slight fluctuation.")
            print("   The model is slightly worse than its peak. This is normal in GANs.")
        else:
            print("✅ STABLE: The model is currently performing near its peak ability.")
else:
    ax2.text(0.5, 0.5, "Insufficient data for Stability Analysis\n(Wait for GAN Phase)",
             ha='center', va='center', fontsize=14, color='gray')
    ax2.axis('off')

plt.tight_layout()
plt.show()

## **Quantitative Evaluation (PSNR & SSIM)**

This cell loads the best available generator checkpoint and evaluates it on up to 100 test images. It computes two standard full-reference metrics on the reconstructed RGB images:

- **PSNR** (Peak Signal-to-Noise Ratio): A log-scaled measure of pixel-wise reconstruction error, reported in dB. Higher is better.
- **SSIM** (Structural Similarity Index): Measures the similarity of local luminance, contrast, and structure between the generated image and the ground truth. Values closer to 1 are better.

The results are sorted by SSIM to identify the best- and worst-performing cases.

**References:**

- **SSIM:** Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image quality assessment: from error visibility to structural similarity. IEEE Transactions on Image Processing.
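
To make the PSNR definition concrete, here is a small worked example in plain Python. It is a sketch for illustration only: the helper name `psnr_db` and the four-pixel "images" are hypothetical, and the evaluation cell itself uses the scikit-image implementation rather than this function. It assumes pixel values in [0, 1], so `data_range` is 1.0.

```python
import math


def psnr_db(reference, estimate, data_range=1.0):
    """PSNR in dB between two equally sized flat sequences of pixel values."""
    if len(reference) != len(estimate) or not reference:
        raise ValueError("inputs must be non-empty and have the same length")
    # Mean squared error over all pixels
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(data_range ** 2 / mse)


# Hypothetical 4-pixel "images" in [0, 1]; every pixel is off by 0.1,
# so MSE = 0.01 and PSNR = 10 * log10(1 / 0.01) = 20 dB.
reference = [0.0, 0.5, 0.5, 1.0]
estimate = [0.1, 0.4, 0.6, 0.9]
print(f"{psnr_db(reference, estimate):.2f} dB")
```

Because PSNR depends only on the mean squared error, it cannot tell a uniformly desaturated result from one with a few badly wrong regions; this is one reason the cell reports SSIM alongside it.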

In [None]:
# ==================== CELL 9: DYNAMIC MODEL EVALUATION + GALLERY (IMPROVED VISUALS) ====================
import os
import re
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import lab2rgb
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Enhanced styling
plt.style.use('default')  # Start clean
plt.rcParams.update({
    "figure.dpi": 160,
    "figure.facecolor": "white",
    "axes.titlesize": 16,
    "axes.labelsize": 14,
    "axes.titleweight": "bold",
    "axes.grid": False,
    "legend.fontsize": 12,
    "font.family": "sans-serif",
    "text.usetex": False,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.2
})

SECTION_TITLE_SIZE = 24
SUBTITLE_SIZE = 18
CAPTION_SIZE = 14
METRIC_FONTSIZE = 13

SAVE_FIGS = False
SAVE_PREFIX = "eval_gallery"

# ---- Requires load_best_available_weights and lab_batch_to_rgb to be defined before this cell runs ----

# (Assumes generator, device, and test_loader are already defined)

# Load weights
which, loaded = load_best_available_weights(generator, device)
print(f"--> Loaded checkpoint for evaluation: {which if loaded else 'in-memory'}")
generator.eval()

# Evaluate: compute PSNR and SSIM per image on the reconstructed RGB output
NUM_SAMPLES_TO_TEST = 100
AB_SCALE_FOR_EVAL = 128.0
results = []
print(f"Running evaluation on {NUM_SAMPLES_TO_TEST} test images...")
with torch.no_grad():
    counted = 0
    for batch in test_loader:
        if counted >= NUM_SAMPLES_TO_TEST:
            break
        L = batch['L'].to(device)
        ab_real = batch['ab'].to(device)
        ab_fake = generator(L)
        fake_rgb = lab_batch_to_rgb(L, ab_fake, AB_SCALE_FOR_EVAL)
        real_rgb = lab_batch_to_rgb(L, ab_real, AB_SCALE_FOR_EVAL)
        gray_disp = ((L.detach().cpu().numpy().squeeze(1)) + 1.0) / 2.0
        for i in range(fake_rgb.shape[0]):
            if counted >= NUM_SAMPLES_TO_TEST: break
            p = psnr(real_rgb[i], fake_rgb[i], data_range=1.0)
            s = ssim(real_rgb[i], fake_rgb[i], channel_axis=2, data_range=1.0)
            results.append({
                'gray': gray_disp[i],
                'real': real_rgb[i],
                'fake': fake_rgb[i],
                'psnr': float(p),
                'ssim': float(s)
            })
            counted += 1

if len(results) == 0:
    print("No results gathered. Check your test_loader.")
else:
    avg_psnr = np.mean([r['psnr'] for r in results])
    avg_ssim = np.mean([r['ssim'] for r in results])

    # ==================== 1. Overview Summary ====================
    fig = plt.figure(figsize=(20, 3.5))
    fig.suptitle("Model Evaluation Overview", fontsize=SECTION_TITLE_SIZE, fontweight='bold', y=0.98)
    ax = fig.add_subplot(111)
    ax.axis('off')

    summary_text = (
        f"Checkpoint: {which if loaded else 'in-memory'}\n"
        f"Test Samples: {counted}\u2003\u2003AB Scale: {AB_SCALE_FOR_EVAL:.0f}\n\n"
        f"Average SSIM:  {avg_ssim:.4f}\u2003\u2003"
        f"Average PSNR:  {avg_psnr:.2f} dB"
    )
    ax.text(0.02, 0.5, summary_text, fontsize=18, va='center', linespacing=1.6,
            bbox=dict(boxstyle="round,pad=1", facecolor="#f8f9fa", edgecolor="none"))

    if SAVE_FIGS:
        plt.savefig(f"{SAVE_PREFIX}_overview.png", dpi=200)
    plt.show()

    # ==================== 2. Gallery Function (Reusable) ====================
    def plot_gallery(title, samples, nrows=8, sort_by='ssim', ascending=False):
        samples = sorted(samples, key=lambda x: x[sort_by], reverse=not ascending)
        samples = samples[:nrows]

        fig = plt.figure(figsize=(24, 5.2 * nrows))
        fig.suptitle(title, fontsize=SECTION_TITLE_SIZE, fontweight='bold', y=0.99)

        for idx, r in enumerate(samples):
            # Input (Grayscale)
            ax1 = plt.subplot(nrows, 3, 3*idx + 1)
            ax1.imshow(r['gray'], cmap='gray', vmin=0, vmax=1)
            ax1.set_title("Input (Grayscale)", fontsize=SUBTITLE_SIZE, pad=12)
            ax1.axis('off')

            # Model Output
            ax2 = plt.subplot(nrows, 3, 3*idx + 2)
            ax2.imshow(r['fake'])
            ax2.set_title("Model Output", fontsize=SUBTITLE_SIZE, pad=12)
            # Overlay metrics
            ax2.text(0.98, 0.02, f"SSIM: {r['ssim']:.3f}\nPSNR: {r['psnr']:.1f} dB",
                     transform=ax2.transAxes, fontsize=METRIC_FONTSIZE,
                     ha='right', va='bottom', color='white',
                     bbox=dict(boxstyle="round,pad=0.4", facecolor='black', alpha=0.7))
            ax2.axis('off')

            # Ground Truth
            ax3 = plt.subplot(nrows, 3, 3*idx + 3)
            ax3.imshow(r['real'])
            ax3.set_title("Ground Truth", fontsize=SUBTITLE_SIZE, pad=12)
            ax3.axis('off')

        # Shared caption
        caption = ("Left: Input grayscale (L channel)\u2003\u2003"
                   "Middle: Model colorization\u2003\u2003"
                   "Right: Original color image\n"
                   f"Sorted by {sort_by.upper()} ({'descending' if not ascending else 'ascending'}). "
                   f"Metrics shown on model output.")
        fig.text(0.5, 0.01, caption, ha='center', fontsize=CAPTION_SIZE, style='italic')

        plt.tight_layout(rect=[0, 0.03, 1, 0.97])
        if SAVE_FIGS:
            suffix = "best" if not ascending else "worst"
            plt.savefig(f"{SAVE_PREFIX}_{suffix}.png", dpi=200)
        plt.show()

    # ==================== 3. Best Cases ====================
    plot_gallery("Best-Performing Cases (Highest SSIM)", results, nrows=8, sort_by='ssim', ascending=False)

    # ==================== 4. Challenging Cases ====================
    plot_gallery("Challenging / Failure Cases (Lowest SSIM)", results, nrows=8, sort_by='ssim', ascending=True)

# **Test Model Generalization**

This final cell allows for real-world testing. Upload any custom black-and-white image, and the model will preprocess it, run the inference using the best saved weights, and display the colorized result.
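
Two small pieces of arithmetic sit at the heart of the preprocessing: mapping the CIELAB lightness channel into the range the generator expects, and padding the image so each side is a multiple of the network stride. The sketch below illustrates both in plain Python. The function names are hypothetical and it computes only the padding amounts (not the padded array); it assumes L lies in [0, 100], the model input lies in [-1, 1], and the stride is 32.

```python
def l_to_model_input(l_value):
    """Map CIELAB lightness from [0, 100] to the [-1, 1] model input range."""
    return l_value / 50.0 - 1.0


def pad_amounts(height, width, stride=32):
    """Return (top, bottom, left, right) padding so both sides become
    multiples of stride, split as evenly as possible between the two edges."""
    h_pad = (stride - height % stride) % stride
    w_pad = (stride - width % stride) % stride
    top, left = h_pad // 2, w_pad // 2
    return top, h_pad - top, left, w_pad - left


# Lightness endpoints and midpoint
print([l_to_model_input(v) for v in (0.0, 50.0, 100.0)])

# A 256x256 canvas needs no padding; a 250x300 image needs 6 rows and 20 columns
print(pad_amounts(256, 256))
print(pad_amounts(250, 300))
```

With the default 256x256 canvas the padding is always zero, so the pad and unpad steps only matter if the target size is changed to something that is not a multiple of the stride.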

In [None]:
# ==================== CELL: COLORIZATION INFERENCE (TRAIN-STYLE PREPROCESS + SAFE PAD + MILD BOOST) ====================
from google.colab import files
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import io, torch, cv2, os, re
from skimage import color

# ---- Unified Best-Available Loader (re-use if already defined) ----
def load_best_available_weights(generator, device):
    def try_load(path):
        ckpt = torch.load(path, map_location=device)
        state = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
        generator.load_state_dict(state)
        return True
    if os.path.exists("colorizer_GANBEST.pth"):
        try: try_load("colorizer_GANBEST.pth"); return "colorizer_GANBEST.pth", True
        except Exception as e: print(f"Note: GANBEST load failed: {e}")
    if os.path.exists("colorizer_FINAL.pth"):
        try: try_load("colorizer_FINAL.pth"); return "colorizer_FINAL.pth", True
        except Exception as e: print(f"Note: FINAL load failed: {e}")
    epoch_files = [f for f in os.listdir('.') if re.match(r'^colorizer_epoch_\d+\.pth$', f)]
    if epoch_files:
        epoch_files.sort(key=lambda n: int(re.findall(r'\d+', n)[-1]), reverse=True)
        for f in epoch_files:
            try: try_load(f); return f, True
            except Exception as e: print(f"Note: load failed for {f}: {e}")
    if os.path.exists("colorizer_BEST.pth"):
        try: try_load("colorizer_BEST.pth"); return "colorizer_BEST.pth", True
        except Exception as e: print(f"Note: BEST load failed: {e}")
    return None, False

which, loaded = load_best_available_weights(generator, device)
print(f"Loaded model for inference: {which if loaded else 'in-memory'}")
generator.eval()

# --- CONFIGURATION ---
TARGET_SIZE = (256, 256)  # match training canvas
MODEL_STRIDE = 32         # 5 downsamples → 32. Use 64 if your net downsamples 6 times.
AB_SCALE = 128.0          # keep 128 here for this model

# Optional, subtle enhancements (safe for submission)
APPLY_L_CLAHE = True
CLAHE_CLIP = 2.0
CLAHE_TILES = (8, 8)
APPLY_L_GAMMA = True
L_GAMMA = 0.95
EDGE_AWARE_SMOOTHING = True
SMOOTH_SIGMA_SPATIAL = 1.0
SMOOTH_SIGMA_LUMA = 10.0
GLOBAL_AB_GAIN = 1.10
CONF_WEIGHTED_BOOST = 0.10
CONF_EPS = 1e-6

# --- Preprocessing aligned to validation: Resize shortest side to 256, then CenterCrop(256,256) ---
def resize_and_centercrop(img_pil, target_size):
    w, h = img_pil.size
    target_h, target_w = target_size
    # Scale so the image covers the target canvas (equals 256 / min(w, h) for a 256x256 target)
    scale = max(target_w / w, target_h / h)
    new_w = max(target_w, int(round(w * scale)))
    new_h = max(target_h, int(round(h * scale)))
    img_resized = img_pil.resize((new_w, new_h), Image.Resampling.LANCZOS)
    left = (new_w - target_w) // 2
    top  = (new_h - target_h) // 2
    return img_resized.crop((left, top, left + target_w, top + target_h))

def pil_to_rgb_array_train_style(image_bytes, target_size):
    img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    img = resize_and_centercrop(img, target_size)
    return np.array(img)

def rgb_to_L(rgb):
    rgb_f = rgb.astype(np.float32) / 255.0 if rgb.dtype != np.float32 and rgb.dtype != np.float64 else rgb
    return color.rgb2lab(rgb_f)[..., 0]  # [0,100]

def L_to_model_input(L):  # [0,100] -> [-1,1]
    return (L / 50.0) - 1.0

# --- Safe pad/crop for model forward ---
def safe_pad_to_stride(img2d, stride=32, mode='reflect'):
    H, W = img2d.shape
    H_pad = (stride - H % stride) % stride
    W_pad = (stride - W % stride) % stride
    t, b = H_pad // 2, H_pad - H_pad // 2
    l, r = W_pad // 2, W_pad - W_pad // 2
    if H_pad == 0 and W_pad == 0:
        return img2d, (0,0,0,0)
    return np.pad(img2d, ((t,b),(l,r)), mode=mode), (t,b,l,r)

def safe_unpad(ab, pads):
    t,b,l,r = pads
    return ab if (t|b|l|r)==0 else ab[:, t:ab.shape[1]-b, l:ab.shape[2]-r]

@torch.no_grad()
def run_net_on_Lnorm_safe(L_norm_2d, stride=MODEL_STRIDE):
    L_pad, pads = safe_pad_to_stride(L_norm_2d, stride=stride, mode='reflect')
    inp = L_pad.reshape(1, 1, L_pad.shape[0], L_pad.shape[1]).astype(np.float32)
    out = generator(torch.from_numpy(inp).to(device))  # (1,2,Hp,Wp) in [-1,1]
    ab_pad = out.detach().cpu().numpy()[0]
    ab = safe_unpad(ab_pad, pads)  # (2,H,W)
    return np.ascontiguousarray(ab.astype(np.float32))

# --- Optional refinements ---
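def edge_aware_smooth_ab(L_base, ab_pred, sigma_spatial=1.0, sigma_luma=10.0):
    # Bilateral-filter each chroma channel on its own values (edge-preserving).
    # Note: L_base is currently unused, so the filter is not guided by luminance;
    # sigma_luma acts as the range sigma, expressed in Lab chroma units.
    ab_lab = ab_pred * AB_SCALE
    a = ab_lab[0].astype(np.float32)
    b = ab_lab[1].astype(np.float32)
    d = 5  # pixel neighborhood diameter
    a_s = cv2.bilateralFilter(a, d=d, sigmaColor=sigma_luma, sigmaSpace=sigma_spatial)
    b_s = cv2.bilateralFilter(b, d=d, sigmaColor=sigma_luma, sigmaSpace=sigma_spatial)
    return np.stack([a_s, b_s], axis=0) / AB_SCALE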

def apply_l_adjustments(L):
    L_adj = L.copy()
    if APPLY_L_CLAHE:
        L_u8 = np.clip((L_adj / 100.0) * 255.0, 0, 255).astype(np.uint8)
        clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP, tileGridSize=CLAHE_TILES)
        L_u8 = clahe.apply(L_u8)
        L_adj = (L_u8.astype(np.float32) / 255.0) * 100.0
    if APPLY_L_GAMMA:
        L01 = np.clip(L_adj / 100.0, 0.0, 1.0)
        L01 = np.power(L01, L_GAMMA)
        L_adj = np.clip(L01, 0, 1) * 100.0
    return L_adj

def apply_ab_color_boost(ab_pred, global_gain=1.10, conf_boost=0.10, eps=1e-6):
    a, b = ab_pred[0], ab_pred[1]
    mag = np.sqrt(a*a + b*b)
    mag_norm = np.clip(mag / (1.0 + eps), 0.0, 1.0)
    gain_map = global_gain + conf_boost * mag_norm
    a_boost = np.clip(a * gain_map, -1.0, 1.0)
    b_boost = np.clip(b * gain_map, -1.0, 1.0)
    return np.stack([a_boost, b_boost], axis=0)

def lab_from_L_and_ab(L_base, ab_pred):
    ab_lab = ab_pred * AB_SCALE
    H, W = L_base.shape
    lab = np.zeros((H, W, 3), dtype=np.float64)
    lab[..., 0] = np.clip(L_base, 0, 100)
    lab[..., 1] = np.clip(ab_lab[0], -128, 127)
    lab[..., 2] = np.clip(ab_lab[1], -128, 127)
    return color.lab2rgb(lab)

print("Upload a grayscale/B&W image to colorize...")
uploaded = files.upload()
if uploaded:
    name = list(uploaded.keys())[0]
    print("Processing:", name)
    try:
        rgb = pil_to_rgb_array_train_style(uploaded[name], TARGET_SIZE)
        L = rgb_to_L(rgb)  # [0,100]

        # Forward with safe pad/crop
        ab = run_net_on_Lnorm_safe(L_to_model_input(L), stride=MODEL_STRIDE)  # (2,H,W) in [-1,1]

        # Optional: smoothing + mild color boost + L contrast
        if EDGE_AWARE_SMOOTHING:
            ab = edge_aware_smooth_ab(L, ab, SMOOTH_SIGMA_SPATIAL, SMOOTH_SIGMA_LUMA)
        ab = apply_ab_color_boost(ab, GLOBAL_AB_GAIN, CONF_WEIGHTED_BOOST, CONF_EPS)
        L_adj = apply_l_adjustments(L)

        out = lab_from_L_and_ab(L_adj, ab)

        # Show
        plt.figure(figsize=(14,6))
        plt.subplot(1,2,1); plt.imshow(L/100.0, cmap='gray', vmin=0, vmax=1); plt.title('Input (Grayscale)', fontweight='bold'); plt.axis('off')
        plt.subplot(1,2,2); plt.imshow(out); plt.title('Colorized Output', fontweight='bold'); plt.axis('off')
        plt.tight_layout(); plt.show()
    except Exception as e:
        print("❌ Error:", str(e))
else:
    print("No file uploaded.")