# Automatic Image Colorization — Baseline CNN (STL-10, 96×96)

This notebook implements a **baseline encoder–decoder CNN** for **automatic image colorization** on the **STL-10 dataset**.  
The task is formulated as a supervised regression problem:

- **Input:** grayscale image `L` (1 channel, 96×96)
- **Target:** RGB color image (3 channels, 96×96)

## 1) Dataset pipeline (STL10GrayColor)
We build a custom `Dataset` that returns `(gray, color)` pairs:
- Load STL-10 RGB images
- (Train only) apply **random horizontal flip**
- Resize to **96×96**
- Convert to:
  - grayscale tensor `[1, H, W]` in `[0,1]`
  - RGB tensor `[3, H, W]` in `[0,1]`

## 2) Model: CNN Encoder–Decoder (96×96)
The network is a standard **bottleneck architecture**:
- **Encoder:** conv blocks + max pooling  
  `1×96×96 → 64×48×48 → 128×24×24 → 256×12×12 → 512×6×6`
- **Decoder:** transposed convolutions (upsampling) back to 96×96  
  `512×6×6 → ... → 3×96×96`
- Final `Sigmoid` ensures predictions stay in `[0,1]`.

## 3) Training & evaluation
- **Loss:** Mean Squared Error (**MSE**) on RGB pixels
- **Optimizer:** Adam
- **Metric:** **PSNR**, computed from the MSE on the held-out split (the STL-10 `test` split serves as the validation set here):
  \[
  \mathrm{PSNR} = 10 \log_{10}\left(\frac{1}{\mathrm{MSE}}\right)
  \]
  (assuming images are normalized in `[0,1]`)

We also provide visualization utilities:
- training curves (MSE + PSNR)
- qualitative results: **top-k best predictions** (lowest per-image MSE)

In [None]:
# ============================================================
# Automatic Image Colorization (CNN Encoder–Decoder, STL-10 96x96)
# ============================================================
import os
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.datasets import STL10


# Create the output directory for saved figures up front (no error if it already exists).
os.makedirs("resultsSTL", exist_ok=True)


# ============================================================
# 1) Dataset: grayscale input -> RGB target
# ============================================================
class STL10GrayColor(Dataset):
    """
    STL-10 wrapper that yields (grayscale, RGB) pairs for colorization.

    Each item is derived from a single STL-10 image; its class label is discarded.

    Returns:
        gray:  Tensor [1, H, W] in [0, 1]
        color: Tensor [3, H, W] in [0, 1]
    """

    def __init__(self, root: str = "./data", split: str = "train", download: bool = True, image_size: int = 96):
        target_size = (image_size, image_size)

        self.base = STL10(root=root, split=split, download=download)
        self.split = split
        self.image_size = image_size

        # RGB path: resize the PIL image, then convert it to a tensor.
        self.resize_color = transforms.Resize(target_size)
        self.to_tensor = transforms.ToTensor()

        # Grayscale path: resize (to the same target size), collapse to one channel, convert to a tensor.
        gray_steps = [
            transforms.Resize(target_size),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ]
        self.gray_transform = transforms.Compose(gray_steps)

        # Horizontal-flip augmentation is enabled for the "train" split only.
        if split == "train":
            self.augment = transforms.RandomHorizontalFlip(p=0.5)
        else:
            self.augment = None

    def __len__(self):
        """Number of images in the wrapped STL-10 split."""
        return len(self.base)

    def __getitem__(self, idx):
        """Return the `(gray, color)` pair built from image `idx`."""
        rgb_image, _label = self.base[idx]  # PIL RGB image (label not used)

        # Flip before deriving both outputs so gray and color stay aligned.
        flip = self.augment
        if flip is not None:
            rgb_image = flip(rgb_image)

        resized = self.resize_color(rgb_image)

        color = self.to_tensor(resized)
        gray = self.gray_transform(resized)
        return gray, color


# ============================================================
# 2) CNN Encoder–Decoder for 96x96
# ============================================================
class ColorizationCNN96(nn.Module):
    """
    Baseline encoder-decoder CNN: 1-channel input -> 3-channel output in [0, 1].

    The encoder halves the spatial size four times (96 -> 48 -> 24 -> 12 -> 6);
    the decoder doubles it four times back to 96 and ends with a Sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Encoder: four stages of conv3x3 -> BN -> ReLU -> 2x2 max-pool.
        down_layers = []
        for in_ch, out_ch in ((1, 64), (64, 128), (128, 256), (256, 512)):
            down_layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
        self.encoder = nn.Sequential(*down_layers)

        # Decoder: four stages of stride-2 transposed conv -> BN -> ReLU,
        # followed by a 3x3 conv head squashed into [0, 1].
        up_layers = []
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64), (64, 32)):
            up_layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        up_layers += [
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),  # outputs in [0, 1]
        ]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode `x` down to the bottleneck and decode it back to 3 channels."""
        return self.decoder(self.encoder(x))


# ============================================================
# 3) Training / evaluation utilities
# ============================================================
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimization pass over `dataloader`.

    Args:
        model: Network called as ``model(gray)``; switched to training mode.
        dataloader: Iterable yielding ``(gray, color)`` batches.
        criterion: Callable ``criterion(pred, color)`` returning a scalar loss tensor.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to.

    Returns:
        The batch losses averaged with each batch weighted by its size.
    """
    model.train()
    loss_sum, seen = 0.0, 0

    for gray_batch, color_batch in dataloader:
        inputs = gray_batch.to(device)
        targets = color_batch.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller last batch does not skew the average
        # (assumes `criterion` returns a per-batch mean).
        batch_size = inputs.size(0)
        loss_sum += batch_loss.item() * batch_size
        seen += batch_size

    return loss_sum / seen


def evaluate(model, dataloader, criterion, device, eps: float = 1e-12):
    """
    Compute the size-weighted mean loss over `dataloader` and the PSNR derived from it.

    Args:
        model: Network called as ``model(gray)``; switched to eval mode.
        dataloader: Iterable yielding ``(gray, color)`` batches.
        criterion: Callable ``criterion(pred, color)`` returning a scalar loss tensor.
            The PSNR formula treats this value as an MSE on data in [0, 1].
        device: Device each batch is moved to.
        eps: Lower bound applied to the loss before taking the logarithm.

    Returns:
        ``(mse, psnr)`` where ``psnr = 10 * log10(1 / max(mse, eps))``.
    """
    model.eval()
    loss_sum, seen = 0.0, 0

    with torch.no_grad():
        for gray_batch, color_batch in dataloader:
            inputs, targets = gray_batch.to(device), color_batch.to(device)
            batch_loss = criterion(model(inputs), targets)

            batch_size = inputs.size(0)
            loss_sum += batch_loss.item() * batch_size
            seen += batch_size

    mse = loss_sum / seen
    # Clamp so a perfect reconstruction (MSE == 0) gives a finite PSNR
    # instead of a division by zero.
    psnr = 10 * math.log10(1.0 / max(mse, eps))
    return mse, psnr


def show_samples(model, dataloader, device, n=5):
    """
    Plot grayscale input, prediction and ground truth for samples of the first batch.

    Args:
        model: Network called as ``model(gray)``; switched to eval mode.
        dataloader: Iterable yielding ``(gray, color)`` batches; only the first batch is used.
        device: Device the batch is moved to before the forward pass.
        n: Maximum number of rows to draw; clamped to the size of the first batch.

    Raises:
        ValueError: If there is nothing to draw (``n < 1`` or an empty first batch).
    """
    model.eval()
    gray, color = next(iter(dataloader))

    # Clamp to the batch size: slicing silently returns fewer samples than
    # requested, and indexing row i >= batch size would raise IndexError.
    n = min(n, gray.size(0))
    if n < 1:
        raise ValueError("show_samples needs at least one sample to display")

    gray = gray.to(device)
    color = color.to(device)

    with torch.no_grad():
        pred = model(gray)

    gray = gray[:n].cpu()
    color = color[:n].cpu()
    pred = pred[:n].cpu()

    # squeeze=False keeps `axes` two-dimensional even when a single row is drawn.
    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(9, 3 * n), squeeze=False)

    for i in range(n):
        # Channel-first tensors are converted to H x W (gray) or H x W x 3 (RGB) for imshow.
        panels = (
            (gray[i].squeeze(0).numpy(), "Input (grayscale)", "gray"),
            (np.transpose(pred[i].numpy(), (1, 2, 0)), "Prediction", None),
            (np.transpose(color[i].numpy(), (1, 2, 0)), "Ground truth", None),
        )
        for ax, (image, title, cmap) in zip(axes[i], panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()


# ============================================================
# 4) Ranking best samples by per-image MSE
# ============================================================
def get_top_k_samples(model, dataloader, device, k=10):
    """
    Return the `k` samples with the lowest per-image MSE, best first.

    Only the current best `k` samples (plus one batch) are kept in memory,
    instead of every input, target and prediction of the whole dataset.

    Args:
        model: Network called as ``model(gray)``; switched to eval mode.
        dataloader: Iterable yielding ``(gray, color)`` batches.
        device: Device each batch is moved to for the forward pass.
        k: Number of samples to keep. If the data holds fewer than `k`
            samples, all of them are returned.

    Returns:
        ``(gray, pred, color, errors)`` CPU tensors sorted by ascending error.

    Raises:
        ValueError: If `k` is negative.
        RuntimeError: If `dataloader` yields no batch.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    model.eval()
    best = None  # running (errors, gray, pred, color), sorted by ascending error

    with torch.no_grad():
        for gray, color in dataloader:
            gray = gray.to(device)
            color = color.to(device)

            pred = model(gray)
            # Mean squared error per image, averaged over all channels and pixels.
            err = ((pred - color) ** 2).flatten(start_dim=1).mean(dim=1)

            candidates = (err.cpu(), gray.cpu(), pred.cpu(), color.cpu())
            if best is not None:
                candidates = tuple(torch.cat(pair) for pair in zip(best, candidates))

            # Keep only the k smallest errors seen so far.
            keep = torch.argsort(candidates[0])[:k]
            best = tuple(t[keep] for t in candidates)

    if best is None:
        raise RuntimeError("get_top_k_samples received an empty dataloader")

    errors, gray_k, pred_k, color_k = best
    return gray_k, pred_k, color_k, errors


def show_top_k_samples(model, dataloader, device, k=10, filename="resultsSTL/topk_best_colorization.png"):
    """
    Plot and save the `k` best colorizations (lowest per-image MSE).

    Each row shows the grayscale input, the prediction (titled with its MSE)
    and the ground truth.

    Args:
        model: Network forwarded to `get_top_k_samples`.
        dataloader: Iterable yielding ``(gray, color)`` batches.
        device: Device used for the forward passes.
        k: Number of samples to display.
        filename: Path of the saved figure; its parent directory is created if missing.

    Raises:
        ValueError: If no sample is available to display.
    """
    gray_k, pred_k, color_k, err_k = get_top_k_samples(model, dataloader, device, k=k)

    n = gray_k.size(0)
    if n == 0:
        # Fail with a clear message instead of an opaque matplotlib error on nrows=0.
        raise ValueError("no samples to display (k must be >= 1 and the dataloader non-empty)")

    # Make sure the target directory exists so savefig does not fail for custom paths.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # squeeze=False keeps `axes` two-dimensional even when a single row is drawn.
    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(9, 3 * n), squeeze=False)

    for i in range(n):
        # Channel-first tensors are converted to H x W (gray) or H x W x 3 (RGB) for imshow.
        panels = (
            (gray_k[i].squeeze(0).numpy(), "Input (grayscale)", "gray"),
            (np.transpose(pred_k[i].numpy(), (1, 2, 0)), f"Prediction\nMSE={err_k[i].item():.4f}", None),
            (np.transpose(color_k[i].numpy(), (1, 2, 0)), "Ground truth", None),
        )
        for ax, (image, title, cmap) in zip(axes[i], panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    fig.savefig(filename, dpi=300)
    plt.show()


# ============================================================
# 5) Main
# ============================================================
def main():
    """Train the baseline CNN on STL-10, plot the learning curves and save the 50 best test colorizations."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # The "train" split is used for fitting; the "test" split doubles as the validation set.
    train_dataset = STL10GrayColor(split="train", download=True, image_size=96)
    test_dataset = STL10GrayColor(split="test", download=True, image_size=96)

    loader_kwargs = dict(batch_size=32, num_workers=2)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    model = ColorizationCNN96().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    num_epochs = 25
    history = {"train_mse": [], "val_mse": [], "val_psnr": []}

    for epoch in range(1, num_epochs + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_psnr = evaluate(model, test_loader, criterion, device)

        history["train_mse"].append(train_loss)
        history["val_mse"].append(val_loss)
        history["val_psnr"].append(val_psnr)

        print(
            f"[Epoch {epoch:02d}/{num_epochs}] "
            f"Train MSE: {train_loss:.4f} | Val MSE: {val_loss:.4f} | Val PSNR: {val_psnr:.2f} dB"
        )

    def plot_curves(curves, ylabel, title):
        """Draw one figure with a line per `(values, label)` pair, indexed by epoch."""
        plt.figure()
        for values, label in curves:
            plt.plot(values, label=label)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.show()

    plot_curves(
        [(history["train_mse"], "Train MSE"), (history["val_mse"], "Val MSE")],
        ylabel="MSE",
        title="Training / Validation Loss (STL-10, 96x96)",
    )
    plot_curves(
        [(history["val_psnr"], "Val PSNR (dB)")],
        ylabel="PSNR (dB)",
        title="Validation PSNR (STL-10, 96x96)",
    )

    show_top_k_samples(model, test_loader, device, k=50, filename="resultsSTL/top50_best_colorization.png")


# Run the full training/evaluation pipeline only when this code is executed as a script.
if __name__ == "__main__":
    main()


# Automatic Image Colorization — U-Net (STL-10, 96×96)

This section upgrades the baseline encoder–decoder to a **U-Net** for colorization at **96×96**.  
The key change is the use of **skip connections**: encoder feature maps are concatenated with decoder features at matching resolutions, helping preserve **edges and fine textures** that are typically lost through the bottleneck.

## Model: ColorizationUNet96
- **Encoder:** 4 downsampling stages (conv blocks + max pool)
- **Bottleneck:** conv block at 6×6
- **Decoder:** transposed convolutions + **concat skips** + conv blocks
- Output head predicts **RGB** in `[0,1]` via `Sigmoid`.

## Training objective
We train with a **mixed regression loss** to reduce color averaging:
- **Train loss:** weighted combination of **L1 + MSE**
- **Evaluation:** PSNR is computed from **MSE only** on the validation/test set (pixel fidelity metric)

## Outputs
As in the baseline section, we log:
- training curves (loss + PSNR)
- qualitative examples (top-k lowest per-image MSE), saved under `resultsSTL_UNET/`


In [None]:
# ============================================================
# Automatic Image Colorization (U-Net, STL-10 96x96)
# ============================================================
import os
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.datasets import STL10


# Create the U-Net output directory for saved figures up front (no error if it already exists).
os.makedirs("resultsSTL_UNET", exist_ok=True)


# ============================================================
# 1) Dataset: grayscale input -> RGB target
# ============================================================
class STL10GrayColor(Dataset):
    """
    Colorization dataset built on STL-10: every item is a (grayscale, RGB) pair.

    The STL-10 class label is read but never used.

    Returns:
        gray:  Tensor [1, H, W] in [0, 1]
        color: Tensor [3, H, W] in [0, 1]
    """

    def __init__(self, root: str = "./data", split: str = "train", download: bool = True, image_size: int = 96):
        hw = (image_size, image_size)

        self.base = STL10(root=root, split=split, download=download)
        self.split = split
        self.image_size = image_size

        # Target (RGB) pipeline pieces.
        self.resize_color = transforms.Resize(hw)
        self.to_tensor = transforms.ToTensor()

        # Input (grayscale) pipeline: resize to the same size, reduce to one channel, tensorize.
        self.gray_transform = transforms.Compose(
            [
                transforms.Resize(hw),
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
            ]
        )

        # Random horizontal flips are applied to the "train" split only.
        is_train = split == "train"
        self.augment = transforms.RandomHorizontalFlip(p=0.5) if is_train else None

    def __len__(self):
        """Number of images in the wrapped STL-10 split."""
        return len(self.base)

    def __getitem__(self, idx):
        """Return the `(gray, color)` pair built from image `idx`."""
        picture, _unused_label = self.base[idx]  # PIL RGB image (label not used)

        # Augment first so both outputs come from the same (possibly flipped) image.
        if self.augment is not None:
            picture = self.augment(picture)

        picture = self.resize_color(picture)

        color = self.to_tensor(picture)
        gray = self.gray_transform(picture)
        return gray, color


# ============================================================
# 2) U-Net model for 96x96 colorization
# ============================================================
def conv_block(in_ch, out_ch):
    """
    Build a (Conv3x3 -> BatchNorm -> ReLU) x 2 stack (standard U-Net block).

    The first conv maps `in_ch` to `out_ch` channels, the second keeps `out_ch`.
    Both use padding=1, so the spatial size is unchanged.
    """
    layers = []
    for channels_in in (in_ch, out_ch):
        layers.append(nn.Conv2d(channels_in, out_ch, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


class ColorizationUNet96(nn.Module):
    """
    U-Net for colorization: 1-channel input -> 3-channel output in [0, 1].

    Four encoder stages (conv block + 2x2 max-pool), a bottleneck conv block,
    and four decoder stages that upsample, concatenate the matching encoder
    output along the channel axis and apply a conv block.
    """

    def __init__(self):
        super().__init__()

        def upsample(in_ch, out_ch):
            # Stride-2 transposed conv: doubles the spatial size.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)

        # Encoder (spatial size: 96 -> 48 -> 24 -> 12 -> 6 after the pools)
        self.enc1, self.pool1 = conv_block(1, 64), nn.MaxPool2d(2)
        self.enc2, self.pool2 = conv_block(64, 128), nn.MaxPool2d(2)
        self.enc3, self.pool3 = conv_block(128, 256), nn.MaxPool2d(2)
        self.enc4, self.pool4 = conv_block(256, 512), nn.MaxPool2d(2)

        # Bottleneck at 6x6
        self.bottleneck = conv_block(512, 512)

        # Decoder: each conv block takes upsampled + skip channels (concatenated).
        self.up4, self.dec4 = upsample(512, 512), conv_block(512 + 512, 512)  # 6 -> 12, skip: enc4
        self.up3, self.dec3 = upsample(512, 256), conv_block(256 + 256, 256)  # 12 -> 24, skip: enc3
        self.up2, self.dec2 = upsample(256, 128), conv_block(128 + 128, 128)  # 24 -> 48, skip: enc2
        self.up1, self.dec1 = upsample(128, 64), conv_block(64 + 64, 64)      # 48 -> 96, skip: enc1

        # 1x1 conv head followed by a Sigmoid -> outputs in [0, 1]
        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)
        self.final_act = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder, bottleneck and skip-connected decoder on `x`."""
        # Contracting path: keep each stage's pre-pool output for the skips.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        skip4 = self.enc4(self.pool3(skip3))

        feat = self.bottleneck(self.pool4(skip4))

        # Expanding path: upsample, concatenate [upsampled, skip] on channels, refine.
        feat = self.dec4(torch.cat([self.up4(feat), skip4], dim=1))
        feat = self.dec3(torch.cat([self.up3(feat), skip3], dim=1))
        feat = self.dec2(torch.cat([self.up2(feat), skip2], dim=1))
        feat = self.dec1(torch.cat([self.up1(feat), skip1], dim=1))

        return self.final_act(self.final_conv(feat))


# ============================================================
# 3) Training / evaluation utilities
# ============================================================
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Train `model` for one pass over `dataloader`.

    Args:
        model: Network called as ``model(gray)``; switched to training mode.
        dataloader: Iterable yielding ``(gray, color)`` batches.
        criterion: Callable ``criterion(pred, color)`` returning a scalar loss tensor.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to.

    Returns:
        The batch losses averaged with each batch weighted by its size.
    """
    model.train()
    weighted_loss = 0.0
    sample_count = 0

    for batch in dataloader:
        gray, color = batch
        gray, color = gray.to(device), color.to(device)

        optimizer.zero_grad()
        step_loss = criterion(model(gray), color)
        step_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller last batch does not skew the average
        # (assumes `criterion` returns a per-batch mean).
        current = gray.size(0)
        weighted_loss += step_loss.item() * current
        sample_count += current

    return weighted_loss / sample_count


def evaluate(model, dataloader, criterion, device, eps: float = 1e-12):
    """
    Compute the size-weighted mean loss over `dataloader` and the PSNR derived from it.

    Args:
        model: Network called as ``model(gray)``; switched to eval mode.
        dataloader: Iterable yielding ``(gray, color)`` batches.
        criterion: Callable ``criterion(pred, color)`` returning a scalar loss tensor.
            The PSNR formula treats this value as an MSE on data in [0, 1].
        device: Device each batch is moved to.
        eps: Lower bound applied to the loss before taking the logarithm.

    Returns:
        ``(mse, psnr)`` where ``psnr = 10 * log10(1 / max(mse, eps))``.
    """
    model.eval()
    weighted_loss = 0.0
    sample_count = 0

    with torch.no_grad():
        for batch in dataloader:
            gray, color = batch
            gray, color = gray.to(device), color.to(device)

            step_loss = criterion(model(gray), color)

            current = gray.size(0)
            weighted_loss += step_loss.item() * current
            sample_count += current

    mse = weighted_loss / sample_count
    # Clamp so a perfect reconstruction (MSE == 0) gives a finite PSNR
    # instead of a division by zero.
    clamped = max(mse, eps)
    return mse, 10 * math.log10(1.0 / clamped)


def show_samples(model, dataloader, device, n=5):
    """
    Plot grayscale input, prediction and ground truth for samples of the first batch.

    Args:
        model: Network called as ``model(gray)``; switched to eval mode.
        dataloader: Iterable yielding ``(gray, color)`` batches; only the first batch is used.
        device: Device the batch is moved to before the forward pass.
        n: Maximum number of rows to draw; clamped to the size of the first batch.

    Raises:
        ValueError: If there is nothing to draw (``n < 1`` or an empty first batch).
    """
    model.eval()
    gray, color = next(iter(dataloader))

    # Clamp to the batch size: slicing silently returns fewer samples than
    # requested, and indexing row i >= batch size would raise IndexError.
    n = min(n, gray.size(0))
    if n < 1:
        raise ValueError("show_samples needs at least one sample to display")

    gray = gray.to(device)
    color = color.to(device)

    with torch.no_grad():
        pred = model(gray)

    gray = gray[:n].cpu()
    color = color[:n].cpu()
    pred = pred[:n].cpu()

    # squeeze=False keeps `axes` two-dimensional even when a single row is drawn.
    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(9, 3 * n), squeeze=False)

    for i in range(n):
        # Channel-first tensors are converted to H x W (gray) or H x W x 3 (RGB) for imshow.
        panels = (
            (gray[i].squeeze(0).numpy(), "Input (grayscale)", "gray"),
            (np.transpose(pred[i].numpy(), (1, 2, 0)), "Prediction", None),
            (np.transpose(color[i].numpy(), (1, 2, 0)), "Ground truth", None),
        )
        for ax, (image, title, cmap) in zip(axes[i], panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()


# ============================================================
# 4) Ranking best samples by per-image MSE
# ============================================================
def get_top_k_samples(model, dataloader, device, k=10):
    """
    Return the `k` samples with the lowest per-image MSE, best first.

    Only the current best `k` samples (plus one batch) are kept in memory,
    instead of every input, target and prediction of the whole dataset.

    Args:
        model: Network called as ``model(gray)``; switched to eval mode.
        dataloader: Iterable yielding ``(gray, color)`` batches.
        device: Device each batch is moved to for the forward pass.
        k: Number of samples to keep. If the data holds fewer than `k`
            samples, all of them are returned.

    Returns:
        ``(gray, pred, color, errors)`` CPU tensors sorted by ascending error.

    Raises:
        ValueError: If `k` is negative.
        RuntimeError: If `dataloader` yields no batch.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    model.eval()
    best = None  # running (errors, gray, pred, color), sorted by ascending error

    with torch.no_grad():
        for gray, color in dataloader:
            gray = gray.to(device)
            color = color.to(device)

            pred = model(gray)
            # Mean squared error per image, averaged over all channels and pixels.
            err = ((pred - color) ** 2).flatten(start_dim=1).mean(dim=1)

            candidates = (err.cpu(), gray.cpu(), pred.cpu(), color.cpu())
            if best is not None:
                candidates = tuple(torch.cat(pair) for pair in zip(best, candidates))

            # Keep only the k smallest errors seen so far.
            keep = torch.argsort(candidates[0])[:k]
            best = tuple(t[keep] for t in candidates)

    if best is None:
        raise RuntimeError("get_top_k_samples received an empty dataloader")

    errors, gray_k, pred_k, color_k = best
    return gray_k, pred_k, color_k, errors


def show_top_k_samples(model, dataloader, device, k=10, filename="resultsSTL_UNET/topk_best_colorization.png"):
    """
    Plot and save the `k` best colorizations (lowest per-image MSE).

    Each row shows the grayscale input, the prediction (titled with its MSE)
    and the ground truth.

    Args:
        model: Network forwarded to `get_top_k_samples`.
        dataloader: Iterable yielding ``(gray, color)`` batches.
        device: Device used for the forward passes.
        k: Number of samples to display.
        filename: Path of the saved figure; its parent directory is created if missing.

    Raises:
        ValueError: If no sample is available to display.
    """
    gray_k, pred_k, color_k, err_k = get_top_k_samples(model, dataloader, device, k=k)

    n = gray_k.size(0)
    if n == 0:
        # Fail with a clear message instead of an opaque matplotlib error on nrows=0.
        raise ValueError("no samples to display (k must be >= 1 and the dataloader non-empty)")

    # Make sure the target directory exists so savefig does not fail for custom paths.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # squeeze=False keeps `axes` two-dimensional even when a single row is drawn.
    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(9, 3 * n), squeeze=False)

    for i in range(n):
        # Channel-first tensors are converted to H x W (gray) or H x W x 3 (RGB) for imshow.
        panels = (
            (gray_k[i].squeeze(0).numpy(), "Input (grayscale)", "gray"),
            (np.transpose(pred_k[i].numpy(), (1, 2, 0)), f"Prediction\nMSE={err_k[i].item():.4f}", None),
            (np.transpose(color_k[i].numpy(), (1, 2, 0)), "Ground truth", None),
        )
        for ax, (image, title, cmap) in zip(axes[i], panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    fig.savefig(filename, dpi=300)
    plt.show()


# ============================================================
# 5) Main
# ============================================================
def main():
    """Train the U-Net on STL-10 with an L1 + MSE loss, plot the curves and save the 50 best test colorizations."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Using device:", device)

    # The "train" split is used for fitting; the "test" split doubles as the validation set.
    train_dataset = STL10GrayColor(split="train", download=True, image_size=96)
    test_dataset = STL10GrayColor(split="test", download=True, image_size=96)

    loader_kwargs = dict(batch_size=32, num_workers=2)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    model = ColorizationUNet96().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()

    def combined_loss(pred, target):
        # Weighted combination: encourages sharper colors (L1) while keeping pixel fidelity (MSE)
        l1_term = l1_loss(pred, target)
        mse_term = mse_loss(pred, target)
        return 0.7 * l1_term + 0.3 * mse_term

    num_epochs = 20
    history = {"train": [], "val_mse": [], "val_psnr": []}

    for epoch in range(1, num_epochs + 1):
        train_loss = train_one_epoch(model, train_loader, combined_loss, optimizer, device)

        # Validation: PSNR is computed from MSE, so evaluate with MSE on the val set
        val_loss, val_psnr = evaluate(model, test_loader, mse_loss, device)

        history["train"].append(train_loss)
        history["val_mse"].append(val_loss)
        history["val_psnr"].append(val_psnr)

        print(
            f"[Epoch {epoch:02d}/{num_epochs}] "
            f"Train loss: {train_loss:.4f} | Val MSE: {val_loss:.4f} | Val PSNR: {val_psnr:.2f} dB"
        )

    def plot_curves(curves, ylabel, title):
        """Draw one figure with a line per `(values, label)` pair, indexed by epoch."""
        plt.figure()
        for values, label in curves:
            plt.plot(values, label=label)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.show()

    plot_curves(
        [(history["train"], "Train loss (combined)"), (history["val_mse"], "Val MSE")],
        ylabel="Loss",
        title="Training / Validation Loss (STL-10, 96x96)",
    )
    plot_curves(
        [(history["val_psnr"], "Val PSNR (dB)")],
        ylabel="PSNR (dB)",
        title="Validation PSNR (STL-10, 96x96)",
    )

    show_top_k_samples(model, test_loader, device, k=50, filename="resultsSTL_UNET/top50_best_colorization.png")


# Run the full U-Net training/evaluation pipeline only when this code is executed as a script.
if __name__ == "__main__":
    main()
