# Deep Learning-Based Image Super-Resolution in LAB Color Space

This notebook trains a convolutional neural network to upscale images by a factor of 2 (single-image super-resolution). We work in the LAB color space, which separates lightness from color. That lets the model concentrate on the lightness channel, where most fine detail such as edges and texture lives.

In [None]:
# Configuration for training
config = {
    "batch_size": 16,
    "adam_lr": 1e-3,
    "loss_type": "mse",  # "mse" or "lpips"
    "train_batches_to_load": 64,  # batches per training pass (the dataset is streamed)
    "validation_batches_to_load": 16,  # batches per validation pass
    # The keys below are not read by any cell in this notebook.
    "resume_weights_path": None,  # path to a checkpoint to resume from
    "num_epochs": 10,
    "overfit_one_batch": False,
    "max_samples": None,  # set to an integer to limit training samples
    "visualize_every_n_epochs": 1,
}

## What is Image Super-Resolution?
Image super-resolution (SR) reconstructs a high-resolution image from a low-resolution input, aiming to recover the sharp edges and fine texture that simple interpolation such as bicubic upsampling blurs away. It is useful in fields such as medical imaging, security, and photography.

## Import Required Libraries
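Let's start by importing the libraries we need: PyTorch for the model and training loop, scikit-image for color-space conversion, Hugging Face `datasets` for streaming data, `lpips` for an optional perceptual loss, `torchinfo` for a model summary, and matplotlib for visualization.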

In [None]:
import warnings
from pathlib import Path
from io import BytesIO
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, IterableDataset
from torchvision import transforms
from skimage import color
from datasets import load_dataset
from PIL import Image
import lpips
from torchinfo import summary

## Why Use the LAB Color Space?
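The LAB color space separates lightness (L) from color (the A and B channels). Human vision is more sensitive to detail in lightness than in color, so the model only has to predict a sharper L channel, while A and B can simply be interpolated. The helpers below convert between RGB and a normalized LAB tensor: L (range 0 to 100) is divided by 100 to lie in [0, 1], and A and B (roughly -128 to 127) are divided by 128 to lie in about [-1, 1].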
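The conversion itself is delegated to scikit-image's `color.rgb2lab`. For intuition, here is a NumPy-only sketch of the textbook sRGB (D65 white point) to CIELAB formulas that such conversions are based on, followed by the same normalization `to_lab_tensor` applies. The constants are the standard published ones and may differ from scikit-image's in the last digits; `srgb_to_lab` and `normalize_lab` are illustrative names, not part of the notebook.

```python
import numpy as np

# Textbook sRGB -> XYZ matrix and D65 reference white (2-degree observer).
# Assumed standard constants; scikit-image uses very similar values.
SRGB_TO_XYZ = np.array([
    [0.4124564, 0.3575761, 0.1804375],
    [0.2126729, 0.7151522, 0.0721750],
    [0.0193339, 0.1191920, 0.9503041],
])
D65_WHITE = np.array([0.95047, 1.0, 1.08883])


def srgb_to_lab(rgb: np.ndarray) -> np.ndarray:
    """Convert sRGB values in [0, 1] with shape (..., 3) to CIELAB."""
    # 1. Undo the sRGB gamma curve to get linear RGB.
    linear = np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    # 2. Linear RGB -> XYZ, divided by the reference white.
    xyz = linear @ SRGB_TO_XYZ.T / D65_WHITE
    # 3. CIELAB nonlinearity: cube root, with a linear segment near zero.
    delta = 6 / 29
    f = np.where(xyz > delta**3, np.cbrt(xyz), xyz / (3 * delta**2) + 4 / 29)
    L = 116 * f[..., 1] - 16
    a = 500 * (f[..., 0] - f[..., 1])
    b = 200 * (f[..., 1] - f[..., 2])
    return np.stack([L, a, b], axis=-1)


def normalize_lab(lab: np.ndarray) -> np.ndarray:
    """Scale L to [0, 1] and A/B to roughly [-1, 1], as to_lab_tensor does."""
    return lab / np.array([100.0, 128.0, 128.0])


colors = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # white, black, red
print(np.round(srgb_to_lab(colors), 2))
print(np.round(normalize_lab(srgb_to_lab(colors)), 3))
```

White comes out at about (100, 0, 0), black at (0, 0, 0), and pure red at about (53.2, 80.1, 67.2), so dividing L by 100 and A/B by 128 keeps all three channels within roughly [-1, 1].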

In [None]:
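def to_lab_tensor(img: Image.Image) -> torch.Tensor:
    """PIL image -> normalized LAB tensor of shape (3, H, W)."""
    arr = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0  # RGB in [0, 1]
    lab = color.rgb2lab(arr)
    l_channel = lab[..., 0:1] / 100.0  # L: [0, 100] -> [0, 1]
    a = lab[..., 1:2] / 128.0  # A: about [-128, 127] -> about [-1, 1]
    b = lab[..., 2:3] / 128.0  # B: same scaling as A
    lab_norm = np.concatenate([l_channel, a, b], axis=-1)
    # HWC -> CHW, the layout PyTorch convolutions expect.
    return torch.from_numpy(lab_norm.transpose(2, 0, 1).copy()).float()


def to_numpy_img(lab_tensor: torch.Tensor) -> np.ndarray:
    """Normalized LAB tensor (3, H, W) -> RGB array (H, W, 3) in [0, 1] for plotting."""
    arr = lab_tensor.detach().cpu().numpy().transpose(1, 2, 0)
    l_channel = arr[..., 0] * 100.0
    a = arr[..., 1] * 128.0
    b = arr[..., 2] * 128.0
    lab = np.stack([l_channel, a, b], axis=-1)
    # Model outputs can fall outside the valid color range, which makes
    # lab2rgb emit warnings; silence them and clip the result instead.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        rgb = color.lab2rgb(lab)
    return np.clip(rgb, 0, 1)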

## Preparing the Dataset
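We train on ImageNet-1k. For each image we take a random square crop of `crop_size` pixels (first upscaling any image whose shorter side is smaller than the crop), convert it to normalized LAB, and downsample it with bicubic interpolation by the `scale` factor to create the low-resolution input. Each example is therefore a (low-res, high-res) pair of LAB patches.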
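A quick check of the size arithmetic, using the values set later in the notebook (`crop_size = 128`, `scale = 2`). The helper names are hypothetical; `resized_size` mirrors the upscaling rule in `preprocess_stream` below.

```python
def resized_size(width: int, height: int, crop_size: int) -> tuple[int, int]:
    """Size an image is upscaled to so that its shorter side reaches crop_size.

    Mirrors the rule in preprocess_stream: images whose shorter side is
    already at least crop_size are left unchanged.
    """
    short_side = min(width, height)
    if short_side >= crop_size:
        return width, height
    factor = crop_size / short_side
    return int(round(width * factor)), int(round(height * factor))


def lowres_size(crop_size: int, scale: int) -> int:
    """Side length of the low-resolution input patch."""
    return crop_size // scale


print(resized_size(100, 80, 128))   # (160, 128): shorter side 80 is scaled to 128
print(resized_size(500, 375, 128))  # (500, 375): already large enough
print(lowres_size(128, 2))          # 64
```

So each training pair is a `(3, 64, 64)` low-res LAB tensor and a `(3, 128, 128)` high-res LAB tensor.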

In [None]:
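def preprocess_stream(dataset, crop_size: int, scale: int):
    """Yield (lowres, highres) normalized LAB patch pairs from a stream of examples."""
    lowres_size = crop_size // scale
    crop = transforms.RandomCrop(crop_size)
    for example in dataset:
        img_data = example["image"]
        # The `datasets` Image feature normally yields a decoded PIL image;
        # otherwise treat the value as raw encoded bytes.
        img = (
            img_data
            if isinstance(img_data, Image.Image)
            else Image.open(BytesIO(img_data)).convert("RGB")
        )
        # Upscale images whose shorter side is smaller than the crop size.
        if min(img.size) < crop_size:
            scale_factor = crop_size / min(img.size)
            new_size = (
                int(round(img.size[0] * scale_factor)),
                int(round(img.size[1] * scale_factor)),
            )
            img = img.resize(new_size, resample=Image.Resampling.BICUBIC)
        img_patch = crop(img)
        lab_patch = to_lab_tensor(img_patch)  # high-res target, (3, crop_size, crop_size)
        # Low-res input via bicubic downsampling; antialias=True low-pass
        # filters first so the downsampled patch is not aliased.
        lab_patch_lowres = torch.nn.functional.interpolate(
            lab_patch.unsqueeze(0),
            size=(lowres_size, lowres_size),
            mode="bicubic",
            align_corners=False,
            antialias=True,
        ).squeeze(0)
        yield lab_patch_lowres, lab_patch


class SuperResStream(IterableDataset):
    """Wraps a streaming dataset as (lowres, highres) pairs for a DataLoader."""

    def __init__(self, dataset, crop_size: int, scale: int):
        self.dataset = dataset
        self.crop_size = crop_size
        self.scale = scale

    def __iter__(self):
        # A fresh generator on every iter() call, so the stream can be restarted.
        return preprocess_stream(self.dataset, self.crop_size, self.scale)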

## Load the Datasets
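Let's load the training and validation splits of ImageNet-1k with Hugging Face Datasets. Streaming mode reads examples on the fly instead of downloading the whole dataset (well over 100 GB) up front. ImageNet-1k is a gated dataset on the Hugging Face Hub, so you must accept its terms of access and authenticate (for example with `huggingface-cli login`) before this cell will run.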

In [None]:
train_dataset = load_dataset(
    "imagenet-1k",
    split="train",
    streaming=True,
    trust_remote_code=True,
)
val_dataset = load_dataset(
    "imagenet-1k",
    split="validation",
    streaming=True,
    trust_remote_code=True,
)

In [None]:
crop_size = 128
scale = 2
train_stream = SuperResStream(train_dataset, crop_size, scale)
val_stream = SuperResStream(val_dataset, crop_size, scale)

In [None]:
train_loader = DataLoader(
    train_stream,
    batch_size=config["batch_size"],
    num_workers=0,
    pin_memory=True,
)
val_loader = DataLoader(
    val_stream,
    batch_size=config["batch_size"],
    num_workers=0,
    pin_memory=True,
)
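Both loaders use `num_workers=0`. With more workers, PyTorch gives each worker process its own copy of an `IterableDataset`, so every worker would replay the same stream and batches would be duplicated. A common remedy is for each worker to keep only every `num_workers`-th example, using the worker id reported by `torch.utils.data.get_worker_info()`. The pure-Python sketch below shows only that striding logic; `shard_stream` is a hypothetical helper, not part of the notebook.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard_stream(stream: Iterable[T], worker_id: int, num_workers: int) -> Iterator[T]:
    """Yield every num_workers-th item of stream, starting at index worker_id.

    Inside SuperResStream.__iter__ one would read worker_id and num_workers
    from torch.utils.data.get_worker_info(), which returns None in the main
    process (i.e. when num_workers=0).
    """
    return islice(stream, worker_id, None, num_workers)


# Three workers splitting a stream of 10 examples: disjoint and complete.
shards = [list(shard_stream(range(10), w, 3)) for w in range(3)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Note that each worker still reads past the examples it skips, so with a remote stream this splits the preprocessing work but not the download.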

## Build the Deep Super-Resolution Model
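We'll use a deep convolutional network to predict the high-resolution image. It takes all three LAB channels of the low-resolution patch as input but outputs only the high-resolution L (lightness) channel, which carries most of the perceived sharpness; the A and B channels are upsampled with plain interpolation when the final image is assembled. The upsampling itself happens inside the network through sub-pixel convolution (`nn.PixelShuffle`).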
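`nn.PixelShuffle(r)` implements the rearrangement step of sub-pixel upsampling: a convolution first produces `r**2` times as many channels as needed, and the shuffle turns each group of `r**2` channels into an `r x r` block of neighbouring output pixels. The NumPy sketch below reproduces that rearrangement to make the shapes concrete; `pixel_shuffle` is our own illustrative function, written to match PyTorch's channel ordering.

```python
import numpy as np


def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Rearrange (N, C * r * r, H, W) -> (N, C, H * r, W * r), like nn.PixelShuffle(r)."""
    n, c_in, h, w = x.shape
    assert c_in % (r * r) == 0, "channel count must be divisible by r**2"
    c = c_in // (r * r)
    x = x.reshape(n, c, r, r, h, w)       # split channels into (C, r, r)
    x = x.transpose(0, 1, 4, 2, 5, 3)     # -> (N, C, H, r, W, r)
    return x.reshape(n, c, h * r, w * r)  # interleave into the larger grid


# Four channels at a single pixel become one channel holding a 2x2 block.
x = np.arange(4, dtype=np.float32).reshape(1, 4, 1, 1)
print(pixel_shuffle(x, 2)[0, 0])  # rows [0, 1] and [2, 3]

# A conv emitting 128 * 2**2 = 512 channels on a 64x64 low-res feature map
# becomes 128 channels at 128x128 after the shuffle.
print(pixel_shuffle(np.zeros((1, 512, 64, 64)), 2).shape)  # (1, 128, 128, 128)
```

Because the convolutions before the shuffle operate at low resolution, this is cheaper than running the whole network on a pre-upsampled image.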

In [None]:
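class DeepSuperResNet(nn.Module):
    """Predict the high-res L channel from a low-res normalized LAB patch.

    Input:  (N, 3, H, W) low-resolution LAB.
    Output: (N, 1, H * scale, W * scale) high-resolution L channel.
    """

    def __init__(self, scale: int = 2, features: int = 128, num_body_layers: int = 8):
        super().__init__()
        # Head: extract features with a wide 7x7 conv, then widen to `features`.
        self.head = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # Body: a plain stack of 3x3 conv + ReLU layers at low resolution.
        body_layers = []
        for _ in range(num_body_layers):
            body_layers.append(nn.Conv2d(features, features, kernel_size=3, padding=1))
            body_layers.append(nn.ReLU(inplace=True))
        self.body = nn.Sequential(*body_layers)
        # Sub-pixel upsampling: emit features * scale**2 channels, then let
        # PixelShuffle rearrange them into a map `scale` times larger per side.
        # One stage upsamples by exactly `scale`, matching the data pipeline.
        self.upsample = nn.Sequential(
            nn.Conv2d(features, features * scale**2, kernel_size=3, padding=1),
            nn.PixelShuffle(scale),
            nn.ReLU(inplace=True),
        )
        # Tail: project the features to a single L channel.
        self.tail = nn.Conv2d(features, 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.head(x)      # (N, features, H, W)
        x = self.body(x)      # (N, features, H, W)
        x = self.upsample(x)  # (N, features, H * scale, W * scale)
        return self.tail(x)   # (N, 1, H * scale, W * scale)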

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepSuperResNet().to(device)
print("Model summary:")
print(
    summary(
        model,
        input_size=(config["batch_size"], 3, 64, 64),
        col_names=("input_size", "output_size", "num_params"),
        depth=4,
        row_settings=("var_names",),
    )
)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=config.get("adam_lr", 1e-3))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=5,
    min_lr=1e-7,
)
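`ReduceLROnPlateau` only changes the learning rate when `scheduler.step(metric)` is called with the monitored value, usually the average validation loss once per epoch. The pure-Python sketch below imitates its default rule for `mode="min"` (relative threshold `1e-4`, no cooldown) with the settings above: after more than `patience` consecutive steps without improvement, the learning rate is multiplied by `factor`, never dropping below `min_lr`. It is a simplified illustration rather than PyTorch's code, and `PlateauSketch` is a hypothetical name.

```python
class PlateauSketch:
    """Simplified imitation of torch.optim.lr_scheduler.ReduceLROnPlateau (mode='min')."""

    def __init__(self, lr, factor=0.5, patience=5, min_lr=1e-7, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_steps = 0

    def step(self, metric: float) -> float:
        # "Improvement" means beating the best value by a relative margin.
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
        # Reduce only after MORE than `patience` non-improving steps.
        if self.num_bad_steps > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_steps = 0
        return self.lr


sched = PlateauSketch(lr=1e-3)
lrs = [sched.step(0.05) for _ in range(7)]  # the loss never improves after step 1
print(lrs)  # six steps at 0.001, then 0.0005
```

In this notebook, that would mean calling `scheduler.step(val_avg_loss)` after the validation loop.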

In [None]:
if config["loss_type"] == "lpips":
    lpips_loss_fn = lpips.LPIPS(net="vgg", spatial=False).to(device)

    def criterion(pred, target):
        # LPIPS compares 3-channel images scaled to [-1, 1]. The L channel is a
        # single channel in [0, 1], so rescale it and repeat it across 3 channels.
        pred = (2 * pred - 1).repeat(1, 3, 1, 1)
        target = (2 * target - 1).repeat(1, 3, 1, 1)
        return lpips_loss_fn(pred, target).mean()
elif config["loss_type"] == "mse":
    mse_loss_fn = nn.MSELoss()

    def criterion(pred, target):
        return mse_loss_fn(pred, target)
else:
    raise ValueError(f"Unknown loss_type: {config['loss_type']}")

## Training Loop: One Epoch
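Let's run one epoch of training. Because the data is streamed, an "epoch" here means a fixed number of batches (`train_batches_to_load`), not a full pass over ImageNet-1k. For each batch we predict the high-res L channel, compute the loss against the target, update the weights, and track the average loss.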
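The training and validation cells share a pattern: pull a fixed number of batches, and restart the iterator if the stream ends early. A hypothetical helper such as `take_cycled` below could factor that out; it works with any re-iterable object, including a `DataLoader`.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def take_cycled(iterable: Iterable[T], n: int) -> Iterator[T]:
    """Yield exactly n items, restarting iter(iterable) whenever it is exhausted."""
    iterator = iter(iterable)
    produced = 0
    while produced < n:
        try:
            item = next(iterator)
        except StopIteration:
            iterator = iter(iterable)  # start a fresh pass over the data
            try:
                item = next(iterator)
            except StopIteration:
                raise ValueError("iterable is empty") from None
        yield item
        produced += 1


print(list(take_cycled([1, 2, 3], 7)))  # [1, 2, 3, 1, 2, 3, 1]
```

With it, the batch loop could read `for batch_idx, (lowres, highres) in enumerate(take_cycled(train_loader, config["train_batches_to_load"])):`.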

In [None]:
model.train()
running_loss, batch_count = 0.0, 0
train_iterator = iter(train_loader)
last_lowres, last_highres = None, None
num_train_batches = config["train_batches_to_load"]
for batch_idx in range(num_train_batches):
    # Pull the next batch, restarting the stream if it runs out early.
    try:
        lowres, highres = next(train_iterator)
    except StopIteration:
        train_iterator = iter(train_loader)
        lowres, highres = next(train_iterator)
    lowres = lowres.to(device)
    highres = highres.to(device)
    # The network upsamples internally (PixelShuffle), so it takes the
    # low-res LAB patch directly rather than a bicubic-upsampled copy.
    output_L = model(lowres)
    highres_L = highres[:, 0:1]  # target: the high-res L channel only
    # Safety net in case the output size differs from the target size.
    if output_L.shape[-2:] != highres_L.shape[-2:]:
        output_L = torch.nn.functional.interpolate(
            output_L,
            size=highres_L.shape[-2:],
            mode="bicubic",
            align_corners=False,
        )
    loss = criterion(output_L, highres_L)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    batch_count += 1
    # Keep the last batch on the CPU for visualization.
    last_lowres, last_highres = lowres.detach().cpu(), highres.detach().cpu()
    print(f"Batch {batch_idx + 1}/{num_train_batches}, Loss: {loss.item():.4f}")
avg_loss = running_loss / batch_count if batch_count > 0 else float("nan")
print(f"Average training loss: {avg_loss:.4f}")

## Validation Loop: One Epoch
Now let's evaluate the model on `validation_batches_to_load` batches of the validation set, in evaluation mode and with gradients disabled.

In [None]:
model.eval()
val_running_loss, val_batch_count = 0.0, 0
val_iterator = iter(val_loader)
val_last_lowres, val_last_highres = None, None
num_val_batches = config["validation_batches_to_load"]
with torch.no_grad():
    for batch_idx in range(num_val_batches):
        try:
            val_lowres, val_highres = next(val_iterator)
        except StopIteration:
            val_iterator = iter(val_loader)
            val_lowres, val_highres = next(val_iterator)
        val_lowres = val_lowres.to(device)
        val_highres = val_highres.to(device)
        # Same forward pass as in training: low-res LAB in, high-res L out.
        val_output_L = model(val_lowres)
        val_highres_L = val_highres[:, 0:1]
        if val_output_L.shape[-2:] != val_highres_L.shape[-2:]:
            val_output_L = torch.nn.functional.interpolate(
                val_output_L,
                size=val_highres_L.shape[-2:],
                mode="bicubic",
                align_corners=False,
            )
        val_loss = criterion(val_output_L, val_highres_L)
        val_running_loss += val_loss.item()
        val_batch_count += 1
        val_last_lowres, val_last_highres = val_lowres.cpu(), val_highres.cpu()
        print(f"Val Batch {batch_idx + 1}/{num_val_batches}, Loss: {val_loss.item():.4f}")
val_avg_loss = (
    val_running_loss / val_batch_count if val_batch_count > 0 else float("nan")
)
print(f"Average validation loss: {val_avg_loss:.4f}")
# ReduceLROnPlateau only adjusts the learning rate when stepped with the monitored metric.
scheduler.step(val_avg_loss)

## Visualize Progress
Let's visualize the last batch from both training and validation. In each column, the top row shows the low-res input, the middle row the model output (predicted L combined with interpolated A and B), and the bottom row the high-res target.

In [None]:
def visualize_progress(model, lowres, highres, device, title="Output"):
    """Plot low-res input, model output, and high-res target for up to 6 samples."""
    model.eval()
    with torch.no_grad():
        # Predict the high-res L channel from the low-res LAB patch.
        pred_L = model(lowres.to(device)).cpu()
        # Upsample the low-res color channels (A, B) with plain interpolation.
        up_AB = torch.nn.functional.interpolate(
            lowres[:, 1:3],
            size=highres.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        if pred_L.shape[-2:] != up_AB.shape[-2:]:
            pred_L = torch.nn.functional.interpolate(
                pred_L,
                size=up_AB.shape[-2:],
                mode="bicubic",
                align_corners=False,
            )
        outputs = torch.cat([pred_L, up_AB], dim=1)  # reassemble a LAB image
    n = min(6, lowres.size(0))
    # squeeze=False keeps `axes` 2-D even when n == 1.
    fig, axes = plt.subplots(3, n, figsize=(2.5 * n, 8), squeeze=False)
    for i in range(n):
        axes[0, i].imshow(to_numpy_img(lowres[i]))
        axes[0, i].set_title("Low-res", fontsize=10)
        axes[1, i].imshow(to_numpy_img(outputs[i]))
        axes[1, i].set_title("Output", fontsize=10)
        axes[2, i].imshow(to_numpy_img(highres[i]))
        axes[2, i].set_title("High-res", fontsize=10)
        for row in range(3):
            axes[row, i].axis("off")
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

In [None]:
visualize_progress(
    model, last_lowres, last_highres, device, title="Training Batch Output"
)
visualize_progress(
    model, val_last_lowres, val_last_highres, device, title="Validation Batch Output"
)
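With `loss_type="mse"`, the average validation loss (`val_avg_loss`) is the mean squared error of the L channel in [0, 1], so it converts directly to peak signal-to-noise ratio, a common super-resolution metric: PSNR = 10 · log10(MAX² / MSE) dB with MAX = 1. This is a sketch: converting the batch-averaged MSE gives a slightly different number than averaging per-image PSNRs, and it measures lightness only, not color.

```python
import math


def psnr_from_mse(mse: float, max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    if mse <= 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_value**2 / mse)


print(psnr_from_mse(1e-3))  # about 30 dB
print(psnr_from_mse(1e-4))  # about 40 dB
```

For example, `print(f"Validation PSNR (L channel): {psnr_from_mse(val_avg_loss):.2f} dB")` after the validation cell reports the result in dB; each tenfold drop in MSE adds 10 dB.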

## Save Model Weights
Let's save the model weights after training.

In [None]:
weights_dir = Path("weights")
weights_dir.mkdir(exist_ok=True)
weights_path = weights_dir / "superres_model_final.pth"
torch.save(model.state_dict(), str(weights_path))
print(f"Model weights saved to {weights_path}")
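The config's `resume_weights_path` entry suggests reloading saved weights, although no cell here reads it. Reloading uses the same `state_dict` mechanism in reverse: construct the model, then call `load_state_dict` with the tensors returned by `torch.load`. The sketch below demonstrates the round trip on a tiny stand-in `nn.Linear` so it runs on its own; for this notebook you would construct `DeepSuperResNet()` instead and load from `weights_path` with `map_location=device`. The `weights_only=True` argument assumes PyTorch 1.13 or newer.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

# Stand-in model so the example is self-contained; the same calls apply
# to DeepSuperResNet.
original = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "checkpoint.pth"
    torch.save(original.state_dict(), path)

    # A freshly built model has different random weights until loaded.
    restored = nn.Linear(4, 2)
    state = torch.load(path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state)

same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Saving the `state_dict` rather than the whole module keeps the checkpoint independent of how the class is defined or imported.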

## Summary
In this notebook, you learned how to:
- Prepare low- and high-resolution training pairs in the LAB color space
- Build a deep convolutional network that predicts a sharper L (lightness) channel
- Train and validate the model on batches streamed from ImageNet-1k
- Visualize and evaluate the results

A single pass over `train_batches_to_load` batches is only a starting point. Try training for longer and see how well your model can sharpen low-resolution images!