# Efficient Diffusion Models for Image Super-Resolution

##### Neural Networks - Master in Artificial Intelligence and Robotics, Sapienza University of Rome

---

### Authors:
> 1986191: Leonardo Mariut \
> 2190452: Mohamed Zakaria Benjelloun Tuimy

---

### Aim:

This project improves ResShift for single-image super resolution, focusing on image quality while keeping model complexity and inference time comparable. Instead of directly predicting the full HR image, the network predicts residuals on top of a bicubic upsampled base. This residual formulation reduces redundancy, accelerates convergence, and lets the model specialize on recovering high-frequency details. The core idea is to learn features both in the spatial domain and in complementary frequency and wavelet domains, and to combine those feature so the network recovers fine textures and sharp edges efficiently.

---

### Papers:

The work builds on a few recent papers and classic ideas. 

See Implicit Diffusion Models for Continuous Super-Resolution for diffusion conditioning ideas, arbitrary-steps Image Super-resolution via Diffusion Inversion for inversion and sampling strategies, 

Dual-domain Modulation Network for Lightweight Image Super-Resolution for guidance on fusing spatial and frequency streams, and NeRF for the intuition behind using frequency positional encodings. 

The links in the repo point to each paper:

> [Implicit Diffusion Models for Continuous Super-Resolution](https://arxiv.org/abs/2303.16491) \
> [Arbitrary-steps Image Super-resolution via Diffusion Inversion](https://arxiv.org/html/2412.09013v1) \
> [Dual-domain Modulation Network for Lightweight Image Super-Resolution](https://arxiv.org/abs/2503.10047) 

---

### Key concepts:

**Fourier transform**: Fourier transforms decompose image content into frequency components so amplitude and phase information becomes explicit. High frequency coefficients capture edges and fine texture, low frequency coefficients capture coarse structure. Working with frequency representations lets the model pay direct attention to texture and edge fidelity that spatial convolutions can miss.

**Conditioned diffusion model**: The model is a conditional reverse Markov process that starts from a noisy initial state informed by the bicubic upsampled low resolution image. Conditioning on that base image gives a much faster and more stable inference path than starting from pure noise, and it biases the sampler to preserve global structure while refining high frequency detail. The diffusion model predicts residuals relative to the bicubic input rather than the absolute HR image.

**Laplacian pyramids**: A Laplacian pyramid is a multi-scale image representation where an image is decomposed into progressively lower-resolution approximations and the band-pass residuals between them. For super-resolution, this means the network can reconstruct fine details in a coarse-to-fine manner: at each level, the model predicts residuals that correct and refine the upsampled lower-resolution estimate.

**UNet backbone**: The network uses a standard encoder-decoder topology with skip connections. Time conditioning is injected via a sinusoidal positional embedding fed through a small MLP so the network knows the current reverse step. The UNet is a convenient backbone because it mixes local and contextual information at multiple scales while keeping spatial resolution alignment for skip connections.

**Dual-domain fusion**: The model processes features in three parallel branches: a spatial branch for standard convolutional processing, a DCT branch that explicitly filters and learns in the frequency domain, and a DWT branch that captures multi-resolution wavelet coefficients. Outputs from the three branches are concatenated and fused, letting the network integrate texture and structure cues from complementary representations.

**Eta schedule**: The noise schedule eta(t) controls how the noisy mixture between x0 and y0 evolves over the forward process. We parametrize a smoothly decaying schedule that keeps early steps relatively noisy while letting later steps converge, which stabilizes training and produces better sampling behavior when reversing the chain.

**Perceptual loss**: An optional perceptual loss uses VGG features to compare predicted and ground truth images in feature space rather than only pixel space. Perceptual loss helps with texture realism and avoids some of the blurring effects that pure L1 or L2 losses introduce, while still being lightweight when used with a small weight.

**Evaluation metrics**: We report PSNR and SSIM on the validation set. PSNR measures mean squared error in log scale and captures overall pixel fidelity. SSIM measures structural similarity and better correlates with perceived image quality for small structural differences. Both metrics are useful together because PSNR penalizes global error while SSIM focuses on local structural preservation.

**Frequency positional encodings**: Inspired by NeRF, we add Fourier-style positional encodings to some inputs so the network has access to a richer set of sinusoidal basis functions of different frequencies. This gives the model an easy, spatially aware mechanism to represent high frequency variations and to modulate predictions according to local spatial phase, which is helpful when reconstructing textures and fine edges at high upsampling factors. When used with Laplacian pyramids, frequency encodings can be applied at each pyramid level with scale-dependent frequency ranges. This ensures residuals at finer pyramid levels receive richer high-frequency modulation, which is critical for textures like grass, hair, or fabric.

---

### The dataset:

We use [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K) as the primary dataset. The dataset was downloaded from the DIV2K distribution and contains 800 high resolution images for training and 200 high resolution images for validation in our setup. Training images are cropped into random patches to build the training samples, while validation uses original HR patches so metrics are measured on true full-resolution content. For the baseline comparison the low resolution input is produced by lossy bicubic downsampling followed by bicubic upsampling to the target size (4x), which also serves as the y0 base image the model conditions on.

---

# Implementation

In [1]:
!pip install torch torchvision torch_dct pytorch-wavelets



## 1. Imports

In [2]:
import os, math, random, re
from glob import glob
from typing import Tuple

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

# external packages required: 
from torch_dct import dct_2d, idct_2d
from pytorch_wavelets import DWTForward, DWTInverse

## 2. Globals and hyperparameters

In [None]:
# flags, paths, model and training hyperparams
DEBUG_RUN = False

# Dataset
DATA_ROOT = "datasets/DIV2K"
TRAIN_HR_DIR = os.path.join(DATA_ROOT, "train/HR")
VALID_HR_DIR = os.path.join(DATA_ROOT, "valid/HR")
SAVE_DIR = "checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

# Markov chain and network
SCALE = 4
T_STEPS = 15
KAPPA = 2.0
ETA_1_FIXED = 0.008
P_SCHEDULE = 0.3
NUM_CHANNELS = 64 # UNet channels

# Training
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 380
FREQ_LOSS_WEIGHT = 2.0
SAVE_INTERVAL = 10 # Epochs save weights interval
EVAL_INTERVAL = 10 # Epochs evaluate interval

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 144
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
print(f"Device: {DEVICE}")

# Additional setup - positional encodings and perceptual loss
USE_FOURIER = True
NUM_POS_FREQ = 6
USE_AMPLITUDE = True
AMP_KERNEL = 7
MULTISCALE_SCALES = (1, 2, 4)

USE_VGG_PERCEPTUAL = False
PERCEPTUAL_WEIGHT = 0.05

POS_CHANNELS = (4 * NUM_POS_FREQ) if USE_FOURIER else 0
AMP_CHANNELS = 1 if USE_AMPLITUDE else 0
BASE_IN_CHANNELS = 6 + POS_CHANNELS + AMP_CHANNELS

# Run few epochs mostly to test code
if DEBUG_RUN:
    print("DEBUG RUN: small dataset/epochs for quick test")
    NUM_EPOCHS = 2
    BATCH_SIZE = 4
    EVAL_INTERVAL = 1
    SAVE_INTERVAL = 1

Device: cuda


## 3. Utils

Positional encodings and amplitude helpers

In [4]:
# Fourier positional encoding and local amplitude map
def get_2d_fourier_pos_enc(H: int, W: int, n_freqs: int = NUM_POS_FREQ, base: float = 2.0, device=None) -> torch.Tensor:
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ys = torch.linspace(0.0, 1.0, H, device=device)
    xs = torch.linspace(0.0, 1.0, W, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')  # [H,W]
    coords = torch.stack([grid_x, grid_y], dim=-1)           # [H,W,2]
    freqs = (base ** torch.arange(n_freqs, device=device)) * math.pi
    encs = []
    for f in freqs:
        encs.append(torch.sin(coords * f))
        encs.append(torch.cos(coords * f))
    enc = torch.cat(encs, dim=-1)   # [H,W,4*n_freqs]
    return enc.permute(2, 0, 1).unsqueeze(0)  # [1,C,H,W]


def compute_local_amplitude(x: torch.Tensor, kernel_size: int = AMP_KERNEL, eps: float = 1e-6) -> torch.Tensor:
    # x in [0,1], returns [B,1,H,W] normalized amplitude map
    mu = F.avg_pool2d(x, kernel_size, stride=1, padding=kernel_size//2)
    mu2 = F.avg_pool2d(x * x, kernel_size, stride=1, padding=kernel_size//2)
    var = (mu2 - mu * mu).clamp(min=0.0)
    amp = torch.sqrt(var.mean(dim=1, keepdim=True) + eps)
    B = amp.shape[0]
    flat = amp.view(B, -1)
    amin = flat.min(dim=1)[0].view(B,1,1,1)
    amax = flat.max(dim=1)[0].view(B,1,1,1)
    amp = (amp - amin) / (amax - amin + eps)
    return amp


def multiscale_loss(pred: torch.Tensor, target: torch.Tensor, scales=MULTISCALE_SCALES, spatial_loss_fn=None, freq_weight: float = FREQ_LOSS_WEIGHT, spatial_weight: float = 1.0) -> torch.Tensor:
    # combined L1 spatial + DCT L1 across scales, expects pred/target in [-1,1]
    if spatial_loss_fn is None:
        spatial_loss_fn = nn.L1Loss()
    total = 0.0
    total_weight = 0.0
    pred01 = (pred + 1.0) / 2.0
    target01 = (target + 1.0) / 2.0
    for s in scales:
        if s == 1:
            p_s, t_s = pred01, target01
        else:
            p_s = F.interpolate(pred01, scale_factor=1.0/s, mode='bilinear', align_corners=False)
            t_s = F.interpolate(target01, scale_factor=1.0/s, mode='bilinear', align_corners=False)
        l_spatial = spatial_loss_fn(p_s, t_s)
        p_dct = dct_2d(p_s, norm='ortho')
        t_dct = dct_2d(t_s, norm='ortho')
        l_freq = spatial_loss_fn(p_dct, t_dct)
        weight = 1.0 / float(s)
        total += weight * (spatial_weight * l_spatial + freq_weight * l_freq)
        total_weight += weight
    return total / (total_weight + 1e-12)


VGG perceptual loss helpers

In [5]:
# load VGG features if requested
def load_vgg_features(model_name: str = "vgg19_bn", pretrained: bool = True, device: torch.device = torch.device("cpu")):
    try:
        if model_name == "vgg19_bn":
            from torchvision.models import vgg19_bn, VGG19_BN_Weights
            weights = VGG19_BN_Weights.DEFAULT if pretrained else None
            model = vgg19_bn(weights=weights).features.to(device).eval()
        else:
            from torchvision.models import vgg16_bn, VGG16_BN_Weights
            weights = VGG16_BN_Weights.DEFAULT if pretrained else None
            model = vgg16_bn(weights=weights).features.to(device).eval()
    except Exception:
        if model_name == "vgg19_bn":
            from torchvision.models import vgg19_bn
            model = vgg19_bn(pretrained=pretrained).features.to(device).eval()
        else:
            from torchvision.models import vgg16_bn
            model = vgg16_bn(pretrained=pretrained).features.to(device).eval()
    for p in model.parameters():
        p.requires_grad = False
    return model


if USE_VGG_PERCEPTUAL:
    vgg = load_vgg_features(model_name="vgg19_bn", pretrained=True, device=DEVICE)
    def vgg_features(x):
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1,3,1,1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1,3,1,1)
        x_norm = (x - mean) / std
        return vgg(x_norm)
    def perceptual_loss(pred, target):
        p = (pred + 1.0) / 2.0
        t = (target + 1.0) / 2.0
        return F.mse_loss(vgg_features(p), vgg_features(t))
else:
    def perceptual_loss(pred, target): return 0.0

## 4. Data

In [None]:
# DIV2K dataset that returns HR and upsampled LR (y0)
class DIV2KDataset(Dataset):
    def __init__(self, hr_dir: str, crop_size: int = 256, scale: int = SCALE):
        self.paths = sorted(glob(os.path.join(hr_dir, "*.png")))
        self.crop_size = crop_size
        self.scale = scale
        self.transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5]*3, std=[0.5]*3)])  # [0,1] -> [-1,1]
    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        hr_image = Image.open(self.paths[idx]).convert('RGB')
        w,h = hr_image.size
        if w >= self.crop_size and h >= self.crop_size:
            x = random.randint(0, w - self.crop_size)
            y = random.randint(0, h - self.crop_size)
            hr_patch = hr_image.crop((x,y,x+self.crop_size,y+self.crop_size))
        else:
            hr_patch = hr_image.resize((self.crop_size,self.crop_size), Image.BICUBIC)
        lr_w, lr_h = self.crop_size // self.scale, self.crop_size // self.scale
        lr_patch = hr_patch.resize((lr_w, lr_h), Image.BICUBIC)
        y0_patch = lr_patch.resize((self.crop_size, self.crop_size), Image.BICUBIC)
        hr_tensor = self.transform(hr_patch)
        y0_tensor = self.transform(y0_patch)
        return hr_tensor, y0_tensor


# Dataloaders and positional encoding cache
train_dataset = DIV2KDataset(TRAIN_HR_DIR, crop_size=256, scale=SCALE)
valid_dataset = DIV2KDataset(VALID_HR_DIR, crop_size=256, scale=SCALE)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
if DEBUG_RUN:
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

POS_ENC_CACHE = None
def ensure_pos_enc(H: int, W: int):
    global POS_ENC_CACHE
    if POS_ENC_CACHE is None or POS_ENC_CACHE.shape[-2:] != (H,W):
        POS_ENC_CACHE = get_2d_fourier_pos_enc(H, W, n_freqs=NUM_POS_FREQ, device=DEVICE)
    return POS_ENC_CACHE

## 5. Network

Simple UNet with Spatial and Freqeuncy/Wavelet domain blocks

In [7]:
# Network blocks
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim:int):
        super().__init__()
        self.dim = dim
    def forward(self, t: torch.Tensor):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -emb)
        vals = t[:, None] * freqs[None, :]
        return torch.cat((vals.sin(), vals.cos()), dim=-1)


class ConvBlock(nn.Module):
    def __init__(self, in_channels:int, out_channels:int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.conv(x)


class DCTBranch(nn.Module):
    def __init__(self, channels:int):
        super().__init__()
        self.process = nn.Sequential(nn.Conv2d(channels, channels, 1), nn.ReLU(inplace=True), nn.Conv2d(channels, channels, 1))
    def forward(self, x):
        x_dct = dct_2d(x, norm='ortho')
        x_dct = self.process(x_dct)
        return idct_2d(x_dct, norm='ortho')


class DWTBranch(nn.Module):
    def __init__(self, channels:int):
        super().__init__()
        self.dwt = DWTForward(J=1, wave='haar', mode='reflect')
        self.idwt = DWTInverse(wave='haar', mode='reflect')
        self.ll_process = nn.Conv2d(channels, channels, 1)
        self.lh_process = nn.Conv2d(channels, channels, 1)
        self.hl_process = nn.Conv2d(channels, channels, 1)
        self.hh_process = nn.Conv2d(channels, channels, 1)
    def forward(self, x):
        yl, yh = self.dwt(x)
        yl = self.ll_process(yl)
        lh_proc = self.lh_process(yh[0][:, :, 0, :, :])
        hl_proc = self.hl_process(yh[0][:, :, 1, :, :])
        hh_proc = self.hh_process(yh[0][:, :, 2, :, :])
        yh[0] = torch.stack([lh_proc, hl_proc, hh_proc], dim=2)
        return self.idwt((yl, yh))


class DualDomainBlock(nn.Module):
    def __init__(self, channels:int):
        super().__init__()
        self.spatial_branch = ConvBlock(channels, channels)
        self.dct_branch = DCTBranch(channels)
        self.dwt_branch = DWTBranch(channels)
        self.fusion = nn.Conv2d(channels * 3, channels, kernel_size=1)
    def forward(self, x):
        x_spatial = self.spatial_branch(x)
        x_dct = self.dct_branch(x)
        x_dwt = self.dwt_branch(x)
        combined = torch.cat([x_spatial, x_dct, x_dwt], dim=1)
        fused = self.fusion(combined)
        return x + fused


# UNet with time conditioning
class DualDomainUNet(nn.Module):
    def __init__(self, in_channels:int = BASE_IN_CHANNELS, base_channels:int = NUM_CHANNELS, time_dim:int = 128):
        super().__init__()
        self.time_mlp = nn.Sequential(SinusoidalPosEmb(time_dim), nn.Linear(time_dim, time_dim * 4), nn.GELU(), nn.Linear(time_dim * 4, base_channels))
        self.inc = ConvBlock(in_channels, base_channels)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DualDomainBlock(base_channels))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DualDomainBlock(base_channels))
        self.bot = DualDomainBlock(base_channels)
        self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(base_channels, base_channels, 1))
        self.dec1 = DualDomainBlock(base_channels * 2)
        self.up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(base_channels * 2, base_channels, 1))
        self.dec2 = DualDomainBlock(base_channels * 2)
        self.outc = nn.Conv2d(base_channels * 2, 3, kernel_size=1)
    def forward(self, xt: torch.Tensor, y0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        if t.dim() > 1: t = t.squeeze(1)
        t = t.to(xt.device).float()
        t_emb = self.time_mlp(t).unsqueeze(-1).unsqueeze(-1)
        x_in = torch.cat([xt, y0], dim=1)
        x1 = self.inc(x_in)
        x1 = x1 + t_emb
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x_bot = self.bot(x3)
        x = self.up1(x_bot)
        x = self.dec1(torch.cat([x, x2], dim=1))
        x = self.up2(x)
        x = self.dec2(torch.cat([x, x1], dim=1))
        return self.outc(x)

In [8]:
# Noise schedule
def make_eta_schedule(T: int, p: float = P_SCHEDULE, kappa: float = KAPPA) -> np.ndarray:
    eta = np.zeros(T + 1, dtype=np.float32)
    eta[1] = ETA_1_FIXED
    b0 = np.exp(0.5 / (T - 1) * np.log(0.999 / eta[1]))
    for t in range(2, T + 1):
        exponent = ((t-1) / (T-1)) ** p
        eta[t] = (np.sqrt(eta[1]) * (b0 ** ((T - 1) * exponent)))**2
    return eta
eta = torch.from_numpy(make_eta_schedule(T_STEPS, p=P_SCHEDULE, kappa=KAPPA)).to(DEVICE)
alpha = torch.diff(eta, prepend=torch.tensor([0.0], device=DEVICE))


# Augment y0
def augment_y0(y0: torch.Tensor) -> torch.Tensor:
    B,C,H,W = y0.shape
    extras = []
    if USE_FOURIER:
        pos = ensure_pos_enc(H, W).to(y0.device)
        extras.append(pos.repeat(B,1,1,1))
    if USE_AMPLITUDE:
        y0_01 = (y0 + 1.0) / 2.0
        amp = compute_local_amplitude(y0_01, kernel_size=AMP_KERNEL)
        extras.append(amp.to(y0.device))
    return torch.cat([y0] + extras, dim=1) if extras else y0


# Sampling helpers
def sample_xt(x0: torch.Tensor, y0: torch.Tensor, t_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    eta_t = eta[t_idx]
    mean = (1 - eta_t) * x0 + eta_t * y0
    noise = torch.randn_like(x0)
    xt = mean + KAPPA * torch.sqrt(eta_t) * noise
    return xt, noise


@torch.no_grad()
def reverse_sample_augmented(y0: torch.Tensor) -> Tuple[torch.Tensor, list]:
    model.eval()
    chain_steps = []
    x_t = y0 + KAPPA * torch.sqrt(eta[T_STEPS]) * torch.randn_like(y0)
    chain_steps.append(x_t)

    for t in range(T_STEPS, 0, -1):
        t_tensor = torch.full((y0.shape[0],), float(t), device=DEVICE)
        y0_aug = augment_y0(y0)
        predicted_residual = model(x_t, y0_aug, t_tensor)
        predicted_x0 = (y0 + predicted_residual).clamp(-1.0, 1.0)  # learn residuals
        eta_t, eta_t_1, alpha_t = eta[t], eta[t-1], alpha[t]
        term1 = (eta_t_1 / eta_t) * x_t
        term2 = (alpha_t / eta_t) * (1 - eta_t_1) * predicted_x0
        term3 = (alpha_t * eta_t_1 / eta_t) * y0
        mu = term1 + term2 + term3
        variance = (KAPPA ** 2) * (eta_t_1 / eta_t) * alpha_t
        x_t = mu + torch.sqrt(variance) * torch.randn_like(x_t) if t > 1 else mu
        if t == (T_STEPS // 2) + 1:
            chain_steps.append(x_t)

    final_sr = (x_t.clamp(-1,1) + 1) / 2.0
    processed_chain = [(step.clamp(-1,1) + 1)/2.0 for step in chain_steps]
    return final_sr, processed_chain

# Evaluation and plotting
@torch.no_grad()
def evaluate_and_plot(model, valid_loader, epoch, n_examples=3):
    model.eval()
    total_psnr_sr, total_psnr_bic, total_ssim = 0.0, 0.0, 0.0
    plot_data = []

    for i, (hr_img, y0_img) in enumerate(valid_loader):
        hr, y0 = hr_img.to(DEVICE), y0_img.to(DEVICE)
        sr_tensor, _ = reverse_sample_augmented(y0)

        hr_np = (hr.squeeze(0).cpu().permute(1,2,0).numpy() + 1)/2.0
        y0_np = (y0.squeeze(0).cpu().permute(1,2,0).numpy() + 1)/2.0
        sr_np = sr_tensor.squeeze(0).cpu().permute(1,2,0).numpy()

        psnr_sr = compare_psnr(hr_np, sr_np, data_range=1.0)
        psnr_bic = compare_psnr(hr_np, y0_np, data_range=1.0)

        win_size = min(7, min(hr_np.shape[:2]))
        ssim = compare_ssim(hr_np, sr_np, data_range=1.0, channel_axis=2, win_size=win_size)

        total_psnr_sr += psnr_sr
        total_psnr_bic += psnr_bic
        total_ssim += ssim

        if i < n_examples:
            plot_data.append({'hr':hr_np, 'lr':y0_np, 'sr':sr_np, 'psnr_sr':psnr_sr, 'ssim':ssim, 'psnr_bic':psnr_bic})

    n = len(valid_loader)
    print(f"Eval Epoch {epoch}: PSNR(SR)={total_psnr_sr/n:.4f}, PSNR(Bic)={total_psnr_bic/n:.4f}, SSIM={total_ssim/n:.4f}")

    fig, axes = plt.subplots(len(plot_data), 3, figsize=(15,5*len(plot_data)))
    if len(plot_data) == 1:
        axes = [axes]

    for i, d in enumerate(plot_data):
        ax_row = axes[i]
        ax_row[0].imshow(d['lr']); ax_row[0].set_title("LR (bic) {:.2f} dB".format(d["psnr_bic"]))
        ax_row[1].imshow(d['sr']); ax_row[1].set_title("SR (final) PSNR: {:.2f} dB, SSIM: {:.4f}".format(d["psnr_sr"], d["ssim"]))
        ax_row[2].imshow(d['hr']); ax_row[2].set_title("HR (Ground Truth)")
        for ax in ax_row:
            ax.set_xticks([]); ax.set_yticks([])

    plt.tight_layout()
    out_path = os.path.join(SAVE_DIR, "eval_epoch_{}.png".format(epoch))
    plt.savefig(out_path); plt.close()
    print("Saved eval plot to", out_path)

## 6. Training

Instantiate model, optimizer, scheduler

In [9]:
model     = DualDomainUNet(in_channels=BASE_IN_CHANNELS, base_channels=NUM_CHANNELS).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-7)

Ceckpoint loading (from last saved weights)

In [None]:
RESUME_FROM = None  # or set to path to resume - wieghts name

def save_checkpoint(epoch: int):
    ck = os.path.join(SAVE_DIR, "resshift_epoch_{}.pt".format(epoch))
    torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict(), 'scheduler_state': scheduler.state_dict()}, ck)
    print("Saved checkpoint:", ck)


def find_latest_checkpoint() -> str | None:
    candidates = glob(os.path.join(SAVE_DIR, "resshift_epoch_*.pt"))
    if not candidates:
        return None
    best = None
    best_epoch = -1
    for p in candidates:
        m = re.search(r"resshift_epoch_(\d+)\.pt$", p)
        if m:
            e = int(m.group(1))
            if e > best_epoch:
                best_epoch = e
                best = p
    return best


def _move_optimizer_state_to_device(opt_state, device):
    # move optimizer state tensors to device
    for state in list(opt_state.values()):
        for k, v in list(state.items()):
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)


def load_checkpoint(path: str):
    ck = torch.load(path, map_location=DEVICE)
    mfile = re.search(r"resshift_epoch_(\d+)\.pt$", path)
    epoch_from_filename = int(mfile.group(1)) if mfile else None

    start_epoch = 1
    if isinstance(ck, dict):
        model_state = None
        if 'model_state' in ck:
            model_state = ck['model_state']
        elif 'model_state_dict' in ck:
            model_state = ck['model_state_dict']
        elif 'state_dict' in ck:
            model_state = ck['state_dict']

        if model_state is not None:
            model.load_state_dict(model_state)
            if 'optimizer_state' in ck:
                try:
                    optimizer.load_state_dict(ck['optimizer_state'])
                    _move_optimizer_state_to_device(optimizer.state, DEVICE)
                except Exception as e:
                    print("Warning: couldn't load optimizer_state cleanly:", e)
            if 'scheduler_state' in ck:
                try:
                    scheduler.load_state_dict(ck['scheduler_state'])
                except Exception as e:
                    print("Warning: couldn't load scheduler_state cleanly:", e)
            if 'epoch' in ck:
                start_epoch = int(ck.get('epoch', 0)) + 1
            elif epoch_from_filename is not None:
                start_epoch = epoch_from_filename + 1
            else:
                start_epoch = 1
        else:
            # maybe ck is a raw state_dict
            try:
                model.load_state_dict(ck)
                start_epoch = epoch_from_filename + 1 if epoch_from_filename is not None else 1
            except Exception as e:
                raise RuntimeError(f"Unrecognized checkpoint format for '{path}': {e}")
    else:
        try:
            model.load_state_dict(ck)
            start_epoch = epoch_from_filename + 1 if epoch_from_filename is not None else 1
        except Exception as e:
            raise RuntimeError(f"Unrecognized checkpoint format for '{path}': {e}")

    scheduler.last_epoch = max(0, start_epoch - 1)
    print(f"Loaded checkpoint '{path}'. Resuming at epoch {start_epoch}.")
    return start_epoch


start_epoch = 1
latest = RESUME_FROM if RESUME_FROM else find_latest_checkpoint()
if latest is not None:
    try:
        start_epoch = load_checkpoint(latest)
    except Exception as e:
        print("Failed to load checkpoint:", e)
        print("Starting from scratch.")
        start_epoch = 1
else:
    print("No checkpoint found. Starting from scratch.")

Loaded checkpoint 'checkpoints/resshift_epoch_370.pt'. Resuming at epoch 371.


Training loop

In [11]:
def run_training():
    print("Starting training...")
    global start_epoch
    for epoch in range(start_epoch, NUM_EPOCHS+1):
        model.train()
        total_loss = 0.0
        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")

        for batch_idx, (x0_batch, y0_batch) in enumerate(pbar):
            x0, y0 = x0_batch.to(DEVICE), y0_batch.to(DEVICE)
            t_indices = torch.randint(1, T_STEPS + 1, (x0.shape[0],), device=DEVICE)

            xt_list, noise_list = [], []
            for i, t in enumerate(t_indices):
                xt, noise = sample_xt(x0[i:i+1], y0[i:i+1], t.item())
                xt_list.append(xt)
                noise_list.append(noise)
            xt = torch.cat(xt_list, dim=0)

            y0_aug = augment_y0(y0)

            # model predicts residual (x0 - y0)
            predicted_residual = model(xt, y0_aug, t_indices.float())
            predicted_x0 = y0 + predicted_residual

            loss_ms = multiscale_loss(predicted_x0, x0, scales=MULTISCALE_SCALES, spatial_loss_fn=nn.L1Loss(), freq_weight=FREQ_LOSS_WEIGHT)
            loss_perc = PERCEPTUAL_WEIGHT * perceptual_loss(predicted_x0, x0) if USE_VGG_PERCEPTUAL else 0.0
            loss = loss_ms + loss_perc

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch} done. Avg loss: {avg_loss:.5f}")

        scheduler.step()

        if epoch % EVAL_INTERVAL == 0 and epoch > 0:
            evaluate_and_plot(model, valid_dataloader, epoch)

        if epoch % SAVE_INTERVAL == 0 and epoch > 0:
            save_checkpoint(epoch)

    final_ck = os.path.join(SAVE_DIR, "resshift_final.pt")
    torch.save({
        'epoch': NUM_EPOCHS,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
    }, final_ck)
    print("Training finished. Final model saved to", final_ck)

## 7. Run training and evaluate progressively 

In [None]:
print("Ready. Model summary:")
with torch.no_grad():
    xt_test = torch.randn(1,3,256,256, device=DEVICE)
    y0_test = torch.randn(1,3,256,256, device=DEVICE)
    ensure_pos_enc(256, 256)
    y0_aug_test = augment_y0(y0_test)
    t_test = torch.tensor([1.0], device=DEVICE)
    out = model(xt_test, y0_aug_test, t_test)
    print("Forward pass OK. Output shape:", out.shape)

run_training()

Ready. Model summary:
Forward pass OK. Output shape: torch.Size([1, 3, 256, 256])
Starting training...


Epoch 371/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 371 done. Avg loss: 0.06581


Epoch 372/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 372 done. Avg loss: 0.06866


Epoch 373/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 373 done. Avg loss: 0.06735


Epoch 374/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 374 done. Avg loss: 0.06663


Epoch 375/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 375 done. Avg loss: 0.06684


Epoch 376/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 376 done. Avg loss: 0.06781


Epoch 377/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 377 done. Avg loss: 0.06705


Epoch 378/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 378 done. Avg loss: 0.06889


Epoch 379/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 379 done. Avg loss: 0.06805


Epoch 380/380:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 380 done. Avg loss: 0.06918
Eval Epoch 380: PSNR(SR)=28.7265, PSNR(Bic)=28.1391, SSIM=0.7826
Saved eval plot to checkpoints/eval_epoch_380.png
Saved checkpoint: checkpoints/resshift_epoch_380.pt
Training finished. Final model saved to checkpoints/resshift_final.pt


---

### Limitations and reflections

1. Early training often shows color shifts or wrong channel mapping during the first few epochs. This typically results from inconsistent normalization or from accidentally changing channel order between data loading, augmentation and model input.

2. Checkerboard artifacts and ringing appear when high frequency terms are learned aggressively or when upsampling is implemented with transposed convolutions.

3. Training stability is a recurring problem. runs have shown exploding gradients in early stages and stalled or vanishing gradients in later stages where the model appears not to learn.

4. The plain L1 objective can bias the model to reproduce the bicubic upsample because that solution is already close in pixel space.

5. Dataset composition is a limiting factor. DIV2K contains many low-texture or blurred regions that make selective sharpening difficult and that increase variance in evaluation.

6. Memory, batch size and throughput are practical constraints on a 16 GB GPU. small batches limit optimizer behavior and reduce the number of unique samples seen per epoch.

7. Matching spatial and frequency or wavelet features is difficult in practice. fused features may conflict and the network can favor one domain over the other, causing inconsistent restorations.

8. Artifacts observed include jagged edges, blurred textures, edge ringing and mismatched textures in uniform-color regions.

9. The eta schedule for the Markov chain and the kappa coefficient are sensitive hyperparameters. small changes can noticeably alter sampling dynamics and final outputs.

10. The model capacity chosen here (base channels = 64) is small and can struggle to generalize to complex textures while still being prone to underfitting in some regions.

11. Small differences between training and evaluation implementations can produce large gaps in results. common mismatches include normalization and channel order, different up and down sampling kernels, inconsistent padding modes for convolutions and wavelet transforms, running networks in train mode at evaluation time, and precision differences between mixed precision training and full precision evaluation.

12. Paper reproducibility often fails on minor implementation details that are not explicitly stated. Other difficulties include implemented the stated details.

13. PSNR and SSIM do not always correlate with perceived image quality. PSNR measures mean squared error in log scale and SSIM measures local structural similarity, but both can favor overly smooth outputs that lack realistic high frequency detail; perceptual quality can degrade even when these metrics improve. Good examples are good SR images getting lower scores compared to their LR bicubic upsampled equivalent. 

14. Final super-resolved images sometimes remain blurry or exhibit jagged edges despite reasonable numeric metrics.

---

### Practical fixes and experiment ideas

1. Implement a training schedule that starts with only L1 for the first N epochs, then gradually adds perceptual and frequency losses. 

2. Test reducing frequency loss to 0.0 at early steps and linearly increasing it. 

3. For upsampling replace transposed conv with nn.Upsample followed by Conv2d. add anti-alias filter in the downsampling path to avoid aliasing.

---

### Further references:

**Dataset**: [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K) \
**VGG Perceptual Loss**: [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) \
**NeRF Positional Encodings**: [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](https://arxiv.org/abs/2003.08934) \
**UNet**: [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)

---

### Instructions:

**Reproducibility**: Download the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K) dataset and create the folder **datasets/DIV2K** next to the notebook. Put **training HR** images into **datasets/DIV2K/train/HR/** and **validation HR** images into **datasets/DIV2K/valid/HR/**. The notebook expects PNG files in those folders. The training loop writes checkpoints under checkpoints/ and will resume automatically if it finds compatible files. The notebook has been developed with an **Nvidia RTX 5060 Ti with 16 GB VRAM** in mind for the reported timings and memory behavior. If you use a different GPU, reduce batch size and base channel width accordingly to avoid out of memory errors.

To run the notebook, open the notebook and execute the cells in order. If you want a shorter test use the debug flag to reduce epochs and batch size. To restore a training run, place a checkpoint file named like **resshift_epoch_{N}.pt** in checkpoints/ then rerun the notebook; it will auto-detect the latest one and resume. For evaluation, run the provided evaluation cell which will generate comparison images showing LR (bicubic), SR, and HR and save them to the checkpoints folder.

Project structure example:

```
. \
├── efficient_diffusion_super_res.ipynb                 # main notebook with training and evaluation \
├── efficient_diffusion_super_res_reformatted.py        # clean script version of notebook for headless runs \
├── checkpoints/                                        # saved checkpoints and eval images \
├── datasets/ \
│   └── DIV2K/ \
│       ├── train/HR/                                   # 800 HR images \
│       └── valid/HR/                                   # 200 HR images \
└── README.md 
```