# HNDSR: A Hybrid Neural Operator–Diffusion Model for Continuous-Scale Satellite Image Super-Resolution

**Kaggle Training Notebook — Extended Epochs & Optimised Pipeline**

> **Authors:** Adil Khan, Rakshit Modanwal, Harsh Vardhan, Piyush Jain, Yash Vikram  
> **Institution:** Indian Institute of Information Technology, Nagpur  
> **Dataset:** [4× Satellite Image Super-Resolution](https://www.kaggle.com/datasets/cristobaltudela/4x-satellite-image-super-resolution)

---

### Overview

HNDSR is a hybrid super-resolution framework that fuses **Neural Operators** (for continuous-scale awareness) with **Latent Diffusion Models** (for high-fidelity texture generation). The architecture trains in three sequential stages:

| Stage | Component | Purpose | Loss |
|------:|-----------|---------|------|
| 1 | **Latent Autoencoder** | Learn a compact latent space from HR images | L1 (reconstruction) |
| 2 | **Fourier Neural Operator** | Map LR → HR latents with scale-invariance | MSE (latent matching) |
| 3 | **Diffusion UNet** | Refine high-frequency details via iterative denoising | MSE (noise prediction) |

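$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{AE}} + \lambda_{\text{NO}} \cdot \mathcal{L}_{\text{NO}} + \lambda_{\text{diff}} \cdot \mathcal{L}_{\text{diff}}$$

In practice the three terms are minimised one stage at a time, with earlier components frozen (see Section 7).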

### Key Results (4× upscaling)

| Method | PSNR ↑ | SSIM ↑ | LPIPS ↓ |
|--------|--------|--------|---------|
| Bicubic | 24.53 | 0.71 | 0.35 |
| EDSR | 26.81 | 0.79 | 0.28 |
| ESRGAN | 27.14 | 0.81 | 0.24 |
| E²DiffSR | 28.72 | 0.85 | 0.18 |
| **HNDSR (Ours)** | **29.40** | **0.87** | **0.16** |

### Notebook Features
- Saves all checkpoints to `/kaggle/working/` (persists as notebook output)
- Configurable epoch counts per stage — defaults are **2× the original**
- **Gradient clipping**, **warmup + cosine annealing**, **early stopping**
- **EMA (Exponential Moving Average)** on diffusion weights
- Extended data augmentation (horizontal + vertical flip, 90° rotations)
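- Resume support after a kernel restart: set `START_FROM_STAGE` to skip completed stages, and keep `AUTO_RESUME = True` to continue a stage from its last saved epoch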

---
## 1. Environment Setup

Install all dependencies and verify GPU availability. Kaggle provides free **P100 (16 GB)** or **T4 (16 GB)** GPUs.

In [None]:
!pip install -q torch torchvision lpips timm einops scikit-image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image

import numpy as np
import math
import os
import gc
import random
import time
import copy
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import lpips
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt

# ---- Reproducibility ----
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False   # deterministic but slightly slower

set_seed(42)

# Allocator settings must be in place before the first CUDA allocation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
torch.cuda.empty_cache()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device       : {device}")
if torch.cuda.is_available():
    print(f"GPU          : {torch.cuda.get_device_name(0)}")
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory   : {gpu_mem:.2f} GB")
    print(f"CUDA version : {torch.version.cuda}")
print(f"PyTorch      : {torch.__version__}")

---
## 2. Training Configuration

All hyper-parameters in one place. Edit the epoch counts below to train longer.

> **Tip:** Kaggle gives ~12 h of GPU per session. With `BATCH_SIZE=2` the full 130-epoch
> pipeline fits within that budget on this dataset.
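
As a rough sanity check on the tip above, the average wall-clock budget per epoch is the session length divided by the total epoch count. This sketch assumes a 12-hour session and ignores evaluation, checkpointing and data-loading overhead, so the real per-epoch allowance is somewhat smaller:

```python
# Rough per-epoch time budget (assumption: 12 h session, no other overhead)
SESSION_HOURS = 12
epochs = {"stage1": 40, "stage2": 30, "stage3": 60}   # defaults from the config cell

total_epochs = sum(epochs.values())
budget_min = SESSION_HOURS * 60 / total_epochs
print(f"{total_epochs} epochs -> {budget_min:.1f} min per epoch on average")
```

If a measured epoch takes noticeably longer than this, reduce the epoch counts or rely on `AUTO_RESUME` to spread training across sessions.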

In [None]:
# =================================================================
#  TRAINING CONFIGURATION  —  EDIT THESE VALUES
# =================================================================

# Epoch counts per stage  (original → new default)
STAGE1_EPOCHS = 40       # Autoencoder       (was 20)
STAGE2_EPOCHS = 30       # Neural Operator   (was 15)
STAGE3_EPOCHS = 60       # Diffusion Model   (was 30)

# Optimiser & data
BATCH_SIZE    = 2        # reduce to 1 if OOM on 16 GB GPU
PATCH_SIZE    = 64       # HR crop size (must be divisible by 8)
LEARNING_RATE = 1e-4     # peak LR (after warmup)
WEIGHT_DECAY  = 1e-4
NUM_WORKERS   = 2        # Kaggle max = 2

# Gradient clipping  (stabilises early training)
MAX_GRAD_NORM = 1.0

# Early stopping  (patience = epochs with no val improvement)
EARLY_STOP_PATIENCE = 10  # set to 999 to disable

# Warmup  (linear LR warmup before cosine decay)
WARMUP_EPOCHS = 3

# EMA  (Exponential Moving Average for diffusion weights)
EMA_DECAY = 0.999

# Resume from a specific stage  (1 = train all stages from scratch)
START_FROM_STAGE = 1

# Auto-resume: automatically pick up from the last saved epoch
# within a stage. Set to True when re-running after a session timeout.
AUTO_RESUME = True

# =================================================================
#  PATHS  —  auto-configured for Kaggle
# =================================================================
SAVE_DIR             = '/kaggle/working'
AUTOENCODER_PATH     = os.path.join(SAVE_DIR, 'autoencoder_best.pth')
NEURAL_OPERATOR_PATH = os.path.join(SAVE_DIR, 'neural_operator_best.pth')
DIFFUSION_PATH       = os.path.join(SAVE_DIR, 'diffusion_best.pth')
COMPLETE_MODEL_PATH  = os.path.join(SAVE_DIR, 'hndsr_complete.pth')
EVAL_RESULTS_DIR     = os.path.join(SAVE_DIR, 'evaluation_results')

# Epoch-level resume checkpoints (saved every epoch so you never
# lose more than 1 epoch of work if the session is interrupted)
STAGE1_RESUME_PATH   = os.path.join(SAVE_DIR, 'stage1_resume.pth')
STAGE2_RESUME_PATH   = os.path.join(SAVE_DIR, 'stage2_resume.pth')
STAGE3_RESUME_PATH   = os.path.join(SAVE_DIR, 'stage3_resume.pth')

# =================================================================
print(f"Stage 1 epochs : {STAGE1_EPOCHS}")
print(f"Stage 2 epochs : {STAGE2_EPOCHS}")
print(f"Stage 3 epochs : {STAGE3_EPOCHS}")
print(f"Total epochs   : {STAGE1_EPOCHS + STAGE2_EPOCHS + STAGE3_EPOCHS}")
print(f"Batch size     : {BATCH_SIZE}")
print(f"Grad clip norm : {MAX_GRAD_NORM}")
print(f"Early-stop     : {EARLY_STOP_PATIENCE} epochs patience")
print(f"EMA decay      : {EMA_DECAY}")
print(f"Auto-resume    : {AUTO_RESUME}")
print(f"Save directory : {SAVE_DIR}")

---
## 3. Dataset Detection

**Dataset:** [cristobaltudela/4x-satellite-image-super-resolution](https://www.kaggle.com/datasets/cristobaltudela/4x-satellite-image-super-resolution)

Before running this notebook, **add the dataset** to your Kaggle notebook:
1. Click **+ Add Data** (right panel) → search for `4x-satellite-image-super-resolution` → Add
2. It will appear under `/kaggle/input/`

The cell below auto-detects the HR and LR sub-directories.

In [None]:
def find_dataset():
    """Auto-detect HR / LR directories under /kaggle/input/"""
    kaggle_input = Path('/kaggle/input')
    if not kaggle_input.exists():
        print("Not running on Kaggle — set HR_DIR / LR_DIR manually below.")
        return None, None

    datasets = sorted([d for d in kaggle_input.iterdir() if d.is_dir()])
    print(f"Datasets found: {[d.name for d in datasets]}")

    # Try to find a dataset with relevant keywords
    sr_datasets = [d for d in datasets
                   if any(k in d.name.lower() for k in ('super', 'resolution', 'satellite', '4x'))]
    dataset_path = sr_datasets[0] if sr_datasets else (datasets[0] if datasets else None)

    if dataset_path is None:
        print("No dataset folder found — please add the dataset first.")
        return None, None

    print(f"Using: {dataset_path}")

    hr_dir = lr_dir = None
    img_exts = {'.png', '.jpg', '.jpeg', '.tif', '.tiff'}

    for item in sorted(dataset_path.rglob('*')):
        if not item.is_dir():
            continue
        name_lower = item.name.lower()
        imgs = [f for f in item.iterdir() if f.suffix.lower() in img_exts]
        if not imgs:
            continue
        if ('hr' in name_lower or 'high' in name_lower) and hr_dir is None:
            hr_dir = str(item)
            print(f"  HR directory : {item}  ({len(imgs)} images)")
        elif ('lr' in name_lower or 'low' in name_lower) and lr_dir is None:
            lr_dir = str(item)
            print(f"  LR directory : {item}  ({len(imgs)} images)")

    return hr_dir, lr_dir

HR_DIR, LR_DIR = find_dataset()

# ── MANUAL OVERRIDE (uncomment if auto-detect fails) ──────────────
# HR_DIR = '/kaggle/input/4x-satellite-image-super-resolution/HR_0.5m'
# LR_DIR = '/kaggle/input/4x-satellite-image-super-resolution/LR_2m'
# ───────────────────────────────────────────────────────────────────

assert HR_DIR and LR_DIR, (
    "Dataset paths not found.  Please:\n"
    "  1. Add 'cristobaltudela/4x-satellite-image-super-resolution' via + Add Data\n"
    "  2. Or uncomment the MANUAL OVERRIDE lines above and set the correct paths."
)
print(f"\nHR_DIR = {HR_DIR}")
print(f"LR_DIR = {LR_DIR}")

---
## 4. Dataset & Augmentation

The dataset contains paired satellite images at two GSD (Ground Sampling Distance) levels:

| Folder | Resolution | GSD |
|--------|-----------|-----|
| `HR_0.5m/` | High resolution | 0.5 m/pixel |
| `LR_2m/`   | Low resolution  | 2.0 m/pixel |

The GSD ratio (2.0 m ÷ 0.5 m) gives a natural **4× scale factor**.

**Augmentations** applied during training:
- Random crop (64×64 HR patch ↔ 16×16 LR patch)
- Horizontal flip (p = 0.5)
- Vertical flip (p = 0.5)
- 90° rotation (p = 0.5)
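
These paired augmentations rely on one invariant. An LR window at $(x, y)$ corresponds to the HR window at $(4x, 4y)$, and flips and 90° rotations commute with downsampling. The NumPy sketch below checks this. For illustration only, it assumes that the LR image is a 4×4 box average of the HR image; the helper `box_downsample` is hypothetical. The real dataset ships both resolutions and is not synthesised this way.

```python
import numpy as np

def box_downsample(img, s):
    """Hypothetical degradation for the demo: average non-overlapping s×s blocks."""
    h, w = img.shape[:2]
    return img.reshape(h // s, s, w // s, s, -1).mean(axis=(1, 3))

rng = np.random.default_rng(0)
scale, patch = 4, 64                         # HR patch 64×64 ↔ LR patch 16×16
hr = rng.random((256, 256, 3))
lr = box_downsample(hr, scale)               # 64×64×3

# Aligned crop: choose the LR corner, multiply by `scale` for the HR corner
lr_crop = patch // scale
x, y = 10, 37                                # fixed here; random in the notebook
lr_patch = lr[y:y + lr_crop, x:x + lr_crop]
hr_patch = hr[y * scale:(y + lr_crop) * scale, x * scale:(x + lr_crop) * scale]
assert np.allclose(box_downsample(hr_patch, scale), lr_patch)

# Paired flips / 90° rotation stay aligned because they commute with box pooling
for op in (np.fliplr, np.flipud, np.rot90):
    assert np.allclose(box_downsample(op(hr_patch), scale), op(lr_patch))
print("aligned:", hr_patch.shape, lr_patch.shape)
```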

In [None]:
class SatelliteDataset(Dataset):
    """Paired satellite SR dataset with augmentation."""

    def __init__(self, hr_dir, lr_dir, patch_size=64, training=True):
        self.hr_dir    = Path(hr_dir)
        self.lr_dir    = Path(lr_dir)
        self.patch_size = patch_size
        self.training   = training

        # Gather images
        exts = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff',
                '*.PNG', '*.JPG', '*.JPEG', '*.TIF', '*.TIFF']
        self.hr_images, self.lr_images = [], []
        for ext in exts:
            self.hr_images.extend(self.hr_dir.glob(ext))
            self.lr_images.extend(self.lr_dir.glob(ext))
        self.hr_images = sorted(self.hr_images)
        self.lr_images = sorted(self.lr_images)

        if not self.hr_images or not self.lr_images:
            raise ValueError(f"No images found!  HR={hr_dir}  LR={lr_dir}")

        # Match by filename stem (fall back to positional pairing)
        hr_map = {img.stem: img for img in self.hr_images}
        lr_map = {img.stem: img for img in self.lr_images}
        common = sorted(set(hr_map) & set(lr_map))

        if common:
            self.hr_images = [hr_map[n] for n in common]
            self.lr_images = [lr_map[n] for n in common]
        else:
            n = min(len(self.hr_images), len(self.lr_images))
            self.hr_images = self.hr_images[:n]
            self.lr_images = self.lr_images[:n]

        print(f"{'Train' if training else 'Eval'} dataset: {len(self.hr_images)} pairs")

        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3),  # → [-1, 1]
        ])

    def __len__(self):
        return len(self.hr_images)

    def _augment(self, hr_pil, lr_pil):
        """Apply paired random augmentation."""
        # Horizontal flip
        if random.random() > 0.5:
            hr_pil = hr_pil.transpose(Image.FLIP_LEFT_RIGHT)
            lr_pil = lr_pil.transpose(Image.FLIP_LEFT_RIGHT)
        # Vertical flip
        if random.random() > 0.5:
            hr_pil = hr_pil.transpose(Image.FLIP_TOP_BOTTOM)
            lr_pil = lr_pil.transpose(Image.FLIP_TOP_BOTTOM)
        # 90° rotation
        if random.random() > 0.5:
            hr_pil = hr_pil.transpose(Image.ROTATE_90)
            lr_pil = lr_pil.transpose(Image.ROTATE_90)
        return hr_pil, lr_pil

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_images[idx]).convert('RGB')
        lr_img = Image.open(self.lr_images[idx]).convert('RGB')

        # Integer upscaling factor inferred from the pair (4 for this dataset)
        scale   = max(hr_img.size[0] // lr_img.size[0], 1)
        lr_crop = self.patch_size // scale

        if self.training:
            lr_w, lr_h = lr_img.size

            # Random aligned crop: LR window at (x, y) ↔ HR window at (x*scale, y*scale)
            if lr_w >= lr_crop and lr_h >= lr_crop:
                x = random.randint(0, lr_w - lr_crop)
                y = random.randint(0, lr_h - lr_crop)
                lr_img = lr_img.crop((x, y, x + lr_crop, y + lr_crop))
                hr_img = hr_img.crop((x*scale, y*scale,
                                      (x+lr_crop)*scale, (y+lr_crop)*scale))

            hr_img, lr_img = self._augment(hr_img, lr_img)
        else:
            # Deterministic centre crop for validation / evaluation
            hr_img = transforms.CenterCrop(self.patch_size)(hr_img)
            lr_img = transforms.CenterCrop(lr_crop)(lr_img)

        return {
            'lr':    self.to_tensor(lr_img),
            'hr':    self.to_tensor(hr_img),
            'scale': scale,
        }

---
## 5. Model Architecture

### 5.1 Stage 1 — Latent Autoencoder

The autoencoder learns a **compressed latent representation** of HR satellite images.
It uses an encoder-decoder structure with **residual blocks** and a spatial
down-sampling ratio of 8× — mapping a 64×64 HR patch to an 8×8 latent tensor
with 128 channels.

$$z = E_\theta(x_\text{HR}), \quad \hat{x} = D_\theta(z), \quad \mathcal{L}_\text{AE} = \| \hat{x} - x_\text{HR} \|_1$$
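
The 64×64 → 8×8 mapping follows from the standard convolution output-size formulas. The pure-Python sketch below traces spatial size and channel count through the layer settings used in `LatentAutoencoder`. The down-samplers are stride-2, kernel-4, padding-1 convolutions and the up-samplers are matching transposed convolutions. Here `stem=128` mirrors `HNDSR(ae_latent_dim=128)`. The helper names are illustrative.

```python
import math

def conv_out(n, k, s, p):
    """Output size of nn.Conv2d along one axis (dilation 1)."""
    return (n + 2 * p - k) // s + 1

def convT_out(n, k, s, p):
    """Output size of nn.ConvTranspose2d along one axis (dilation 1, output_padding 0)."""
    return (n - 1) * s - 2 * p + k

def trace_autoencoder(size=64, stem=128, downsample_ratio=8, max_ch=128):
    num_downs = int(math.log2(downsample_ratio))
    ch, enc = stem, [(size, stem)]               # 3×3 stem conv keeps the size
    for _ in range(num_downs):                   # encoder: halve size, cap channels at 128
        size, ch = conv_out(size, 4, 2, 1), min(ch * 2, max_ch)
        enc.append((size, ch))
    dec = [(size, ch)]
    for _ in range(num_downs):                   # decoder: double size, halve channels
        size, ch = convT_out(size, 4, 2, 1), ch // 2
        dec.append((size, ch))
    return enc, dec

enc, dec = trace_autoencoder()
print("encoder:", enc)   # last entry is the 8×8×128 latent
print("decoder:", dec)   # back to 64×64 before the final 3-channel conv
```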

### 5.2 Stage 2 — Fourier Neural Operator (FNO)

Neural Operators learn mappings between **function spaces** rather than finite
vectors, enabling continuous-scale super-resolution. The core primitive is the
**Spectral Convolution** layer:

$$(\mathcal{K}v)(x) = \mathcal{F}^{-1}\!\left( R \cdot \mathcal{F}(v) \right)(x)$$

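where $\mathcal{F}$ is the 2-D FFT and $R$ is a learnable complex weight tensor that
mixes channels on the lowest $k$ Fourier modes (positive and negative frequencies
along the first axis) and zeroes all higher modes. Four such layers are stacked;
each is summed with a pointwise 1×1 convolution (local bypass) and passed through
a GELU activation.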
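
To make the spectral convolution concrete, here is a single-sample NumPy sketch of the same operation as `SpectralConv2d`. It takes `rfft2`, multiplies the lowest `modes` frequencies (positive and negative along the first axis) by complex weights, leaves every other mode at zero, and inverts. The function name and random weights are illustrative. The weights live on Fourier modes rather than pixels, so the same weights apply to inputs of different resolutions.

```python
import numpy as np

def spectral_conv2d(x, w_pos, w_neg):
    """x: (C_in, H, W) real; w_pos, w_neg: (C_in, C_out, m1, m2) complex."""
    c_in, h, w = x.shape
    _, c_out, m1, m2 = w_pos.shape
    x_ft = np.fft.rfft2(x)                                    # (C_in, H, W//2+1)
    out_ft = np.zeros((c_out, h, w // 2 + 1), dtype=complex)  # truncated modes stay 0
    m1, m2 = min(m1, h), min(m2, w // 2 + 1)
    # Channel mixing per retained mode: out[o,x,y] = sum_i x_ft[i,x,y] * W[i,o,x,y]
    out_ft[:, :m1, :m2] = np.einsum('ixy,ioxy->oxy', x_ft[:, :m1, :m2], w_pos[:, :, :m1, :m2])
    out_ft[:, -m1:, :m2] = np.einsum('ixy,ioxy->oxy', x_ft[:, -m1:, :m2], w_neg[:, :, :m1, :m2])
    return np.fft.irfft2(out_ft, s=(h, w))                   # back to (C_out, H, W)

rng = np.random.default_rng(0)
c_in, c_out, modes = 4, 6, 8
shape = (c_in, c_out, modes, modes)
w_pos = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
w_neg = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# Same weights, two resolutions: the layer is defined on functions, not on a fixed grid
for size in (32, 64):
    y = spectral_conv2d(rng.standard_normal((c_in, size, size)), w_pos, w_neg)
    print(size, y.shape)
```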

A **scale map**, a constant channel filled with $s/4$ for target scale factor $s$,
is concatenated to the input so the operator learns scale-aware features.

### 5.3 Implicit Amplification

A small MLP predicts per-channel gains $\gamma_c \in [0, 1]$ from the scale
factor and modulates the latent:

$$z' = z \odot (1 + \gamma)$$

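Because $\gamma_c \in [0, 1]$, each latent channel is scaled by a factor between 1 and 2; the intent is that the MLP learns larger gains for channels carrying high-frequency detail at larger upscaling ratios.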
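
The NumPy sketch below illustrates this modulation. It assumes a random, untrained MLP with the same layer shapes as `ImplicitAmplification` (1 → 256 → 256 → 128, biases omitted for brevity), so the learned behaviour is not reproduced. It shows only the structural property: the sigmoid keeps every $\gamma_c$ in $[0, 1]$, so each channel's magnitude is multiplied by a factor between 1 and 2.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(0)
latent_dim, hidden = 128, 256
# Hypothetical random weights with the notebook MLP's shapes (biases omitted)
W1, W2, W3 = (rng.standard_normal(s) * 0.1
              for s in [(1, hidden), (hidden, hidden), (hidden, latent_dim)])

def amplify(z, scale_factor):
    """z: (C, H, W) latent; returns z * (1 + gamma(scale)) and gamma."""
    s = np.array([[float(scale_factor)]])          # (1, 1) scalar input
    h = np.maximum(s @ W1, 0)                      # ReLU
    h = np.maximum(h @ W2, 0)                      # ReLU
    gamma = sigmoid(h @ W3).reshape(-1, 1, 1)      # (C, 1, 1), each in [0, 1]
    return z * (1 + gamma), gamma

z = rng.standard_normal((latent_dim, 8, 8))
for s in (2, 4):
    z_amp, gamma = amplify(z, s)
    print(f"scale {s}: per-channel gain in [{1 + gamma.min():.3f}, {1 + gamma.max():.3f}]")
```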

### 5.4 Stage 3 — Latent Diffusion UNet

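A lightweight UNet iteratively denoises a Gaussian noise sample $z_T$ into the
target HR latent $z_0$, conditioned on the Neural Operator prior via
**cross-attention**. Sampling uses deterministic DDIM over a strided subsequence
of timesteps $t_1 > t_2 > \dots > t_N$ (50 of the 1000 training steps by default):

$$z_{t_{i+1}} = \text{DDIM}\big(z_{t_i}, \epsilon_\theta(z_{t_i}, t_i, c)\big), \quad c = \text{GAP}(z'_\text{NO})$$

where GAP is global average pooling over the spatial dimensions of the amplified prior $z'_\text{NO}$.

The UNet uses **sinusoidal positional time embeddings**, GroupNorm, and SiLU
activations, following the Stable Diffusion convention.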
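
The deterministic DDIM update ($\eta = 0$) first predicts the clean latent. It then re-noises that prediction to the *next* timestep $t'$ of the sampling schedule. With strided sampling, $t'$ is not $t-1$:

$$\hat z_0 = \frac{z_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}}, \qquad z_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat z_0 + \sqrt{1-\bar\alpha_{t'}}\,\epsilon_\theta$$

The NumPy sketch below uses the same linear β schedule as `DDPMScheduler`'s defaults and a 50-step strided schedule. It makes one assumption: the true injected noise stands in for $\epsilon_\theta$ (a perfect predictor). Under that assumption, every jump stays on the forward marginal and the loop recovers $z_0$. The helper `ddim_step` and its `t_next = -1` ("clean") convention are hypothetical.

```python
import numpy as np

# Linear-beta schedule matching DDPMScheduler's defaults
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas_cumprod = np.cumprod(1.0 - betas)

def ddim_step(z_t, eps, t, t_next):
    """Deterministic DDIM jump from timestep t to t_next (< t); t_next = -1 means 'clean'."""
    a_t = alphas_cumprod[t]
    a_next = alphas_cumprod[t_next] if t_next >= 0 else 1.0
    z0_pred = (z_t - np.sqrt(1 - a_t) * eps) / np.sqrt(a_t)
    return np.sqrt(a_next) * z0_pred + np.sqrt(1 - a_next) * eps

rng = np.random.default_rng(0)
z0 = rng.standard_normal((128, 8, 8))               # stand-in clean latent
eps = rng.standard_normal(z0.shape)                 # oracle noise replaces eps_theta
timesteps = np.linspace(T - 1, 0, 50).astype(int)   # 999, 978, ..., 0

a0 = alphas_cumprod[timesteps[0]]
z = np.sqrt(a0) * z0 + np.sqrt(1 - a0) * eps        # start on the forward marginal
for i, t in enumerate(timesteps):
    t_next = timesteps[i + 1] if i + 1 < len(timesteps) else -1
    z = ddim_step(z, eps, t, t_next)

print("max |z - z0| after sampling:", np.abs(z - z0).max())
```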

In [None]:
# =====================================================================
#  Stage 1 — Residual Block & Latent Autoencoder
# =====================================================================

class ResidualBlock(nn.Module):
    """Two-conv residual block used inside the autoencoder."""
    def __init__(self, channels, use_bn=False):
        super().__init__()
        layers = [nn.Conv2d(channels, channels, 3, padding=1),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(channels, channels, 3, padding=1)]
        if use_bn:
            layers.insert(1, nn.BatchNorm2d(channels))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.block(x)


class LatentAutoencoder(nn.Module):
    """Encoder-Decoder that maps 3-ch images to a compact latent z and back.

    Default: 64→128 channels, 3 down-samples (8× spatial reduction).
    """
    def __init__(self, in_channels=3, latent_dim=64, num_res_blocks=4, downsample_ratio=8):
        super().__init__()
        self.latent_dim = latent_dim
        self.downsample_ratio = downsample_ratio
        num_downs = int(math.log2(downsample_ratio))

        # Encoder
        enc = [nn.Conv2d(in_channels, latent_dim, 3, padding=1)]
        ch = latent_dim
        for _ in range(num_downs):
            out_ch = min(ch * 2, 128)
            enc += [nn.Conv2d(ch, out_ch, 4, stride=2, padding=1), nn.ReLU(True)]
            ch = out_ch
        for _ in range(num_res_blocks):
            enc.append(ResidualBlock(ch))
        self.encoder = nn.Sequential(*enc)

        # Decoder
        dec = [ResidualBlock(ch) for _ in range(num_res_blocks)]
        for _ in range(num_downs):
            out_ch = ch // 2
            dec += [nn.ConvTranspose2d(ch, out_ch, 4, stride=2, padding=1), nn.ReLU(True)]
            ch = out_ch
        dec += [nn.Conv2d(ch, in_channels, 3, padding=1), nn.Tanh()]
        self.decoder = nn.Sequential(*dec)

    def encode(self, x):  return self.encoder(x)
    def decode(self, z):  return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z), z


# =====================================================================
#  Stage 2 — Fourier Neural Operator
# =====================================================================

class SpectralConv2d(nn.Module):
    """2-D spectral convolution — applies learnable weights in Fourier space.

    Forces float32 FFT to avoid cuFFT half-precision power-of-2 requirement.
    """
    def __init__(self, in_ch, out_ch, modes1, modes2):
        super().__init__()
        self.in_channels  = in_ch
        self.out_channels = out_ch
        self.modes1 = modes1
        self.modes2 = modes2
        scale = 1.0 / (in_ch * out_ch)
        self.weights1 = nn.Parameter(scale * torch.rand(in_ch, out_ch, modes1, modes2, 2))
        self.weights2 = nn.Parameter(scale * torch.rand(in_ch, out_ch, modes1, modes2, 2))

    def forward(self, x):
        x_dtype = x.dtype
        x = x.float()
        x_ft = torch.fft.rfft2(x)

        out_ft = torch.zeros(x.shape[0], self.out_channels, x.size(-2),
                             x.size(-1)//2+1, dtype=torch.cfloat, device=x.device)

        m1 = min(self.modes1, x.size(-2))
        m2 = min(self.modes2, x.size(-1)//2+1)

        if m1 > 0 and m2 > 0:
            w1 = torch.view_as_complex(self.weights1[:, :, :m1, :m2])
            w2 = torch.view_as_complex(self.weights2[:, :, :m1, :m2])
            out_ft[:, :, :m1,  :m2] = torch.einsum('bixy,ioxy->boxy', x_ft[:, :, :m1,  :m2], w1)
            out_ft[:, :, -m1:, :m2] = torch.einsum('bixy,ioxy->boxy', x_ft[:, :, -m1:, :m2], w2)

        x_out = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
        return x_out.to(x_dtype) if x_dtype != torch.float32 else x_out


class NeuralOperator(nn.Module):
    """FNO with 4 spectral layers and scale-map conditioning."""
    def __init__(self, in_channels=3, out_channels=128, modes=8, width=32):
        super().__init__()
        self.fc0  = nn.Conv2d(in_channels + 1, width, 1)
        self.convs = nn.ModuleList([SpectralConv2d(width, width, modes, modes) for _ in range(4)])
        self.ws    = nn.ModuleList([nn.Conv2d(width, width, 1) for _ in range(4)])
        self.fc1  = nn.Conv2d(width, 64, 1)
        self.fc2  = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, scale_factor):
        b, c, h, w = x.shape
        scale_map = torch.ones(b, 1, h, w, device=x.device) * (scale_factor / 4.0)
        x = self.fc0(torch.cat([x, scale_map], 1))
        for conv, w_conv in zip(self.convs, self.ws):
            x = F.gelu(conv(x) + w_conv(x))
        return self.fc2(F.gelu(self.fc1(x)))


# =====================================================================
#  Implicit Amplification MLP
# =====================================================================

class ImplicitAmplification(nn.Module):
    """Scale-conditioned channel gain predictor."""
    def __init__(self, latent_dim=128, hidden_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_dim), nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(True),
            nn.Linear(hidden_dim, latent_dim), nn.Sigmoid(),
        )

    def forward(self, latent, scale_factor):
        b = latent.shape[0]
        # Accept a Python number, a 0-d / 1-element tensor, or one scale per sample
        s = torch.as_tensor(scale_factor, dtype=torch.float32, device=latent.device)
        s = s.reshape(-1, 1).expand(b, 1)
        return latent * (1 + self.mlp(s).view(b, -1, 1, 1))

In [None]:
# =====================================================================
#  Stage 3 — Diffusion UNet Components
# =====================================================================

class SinusoidalPositionEmbeddings(nn.Module):
    """Maps integer timestep t → sinusoidal embedding (same as Transformer PE)."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=time.device) * -emb)
        emb = time[:, None] * emb[None, :]
        return torch.cat([emb.sin(), emb.cos()], dim=-1)


class AttentionBlock(nn.Module):
    """Spatial self-attention (defined for completeness; not used by DiffusionUNet below)."""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(min(8, channels), channels)
        self.qkv  = nn.Conv2d(channels, channels * 3, 1)
        self.proj  = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        res = x
        x = self.norm(x)
        q, k, v = self.qkv(x).chunk(3, 1)
        q = q.view(b, c, -1).transpose(1, 2)
        k = k.view(b, c, -1).transpose(1, 2)
        v = v.view(b, c, -1).transpose(1, 2)
        attn = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * (c ** -0.5), dim=-1)
        out = torch.bmm(attn, v).transpose(1, 2).view(b, c, h, w)
        return self.proj(out) + res


class CrossAttentionBlock(nn.Module):
    """Cross-attention: UNet features attend to Neural-Operator context vector."""
    def __init__(self, channels, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(min(8, channels), channels)
        self.q  = nn.Conv2d(channels, channels, 1)
        self.kv = nn.Linear(context_dim, channels * 2)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x, context):
        b, c, h, w = x.shape
        res = x
        x = self.norm(x)
        q = self.q(x).view(b, c, -1).transpose(1, 2)
        kv = self.kv(context)
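        k, v = kv.chunk(2, 1)
        # The pooled context is a single token, so k and v have sequence length 1.
        # The softmax over one key is identically 1, hence every spatial position
        # receives v: the block acts as a context-dependent per-channel bias.
        k, v = k.unsqueeze(1), v.unsqueeze(1)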
        attn = torch.softmax(torch.bmm(q, k.transpose(1, 2)) * (c ** -0.5), dim=-1)
        out = torch.bmm(attn, v).transpose(1, 2).view(b, c, h, w)
        return self.proj(out) + res


class ResidualBlockWithTime(nn.Module):
    """ResBlock conditioned on a timestep embedding (added after first conv)."""
    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(min(8, in_ch), in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_emb = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, out_ch))
        self.norm2 = nn.GroupNorm(min(8, out_ch), out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.time_emb(t_emb)[:, :, None, None]
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.shortcut(x)


class DiffusionUNet(nn.Module):
    """Simplified UNet for latent-space denoising.

    Architecture:  input_proj → Down → Mid (cross-attn) → Up → output
    Skip connection from input_proj to Up via concatenation.
    """
    def __init__(self, in_channels=128, model_channels=64,
                 out_channels=128, context_dim=128):
        super().__init__()
        t_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(model_channels),
            nn.Linear(model_channels, t_dim), nn.SiLU(),
            nn.Linear(t_dim, t_dim),
        )

        self.input_proj = nn.Conv2d(in_channels, model_channels, 3, padding=1)

        # ---- Down ----
        self.down1 = ResidualBlockWithTime(model_channels, model_channels * 2, t_dim)
        self.down2 = nn.Conv2d(model_channels * 2, model_channels * 2, 3, stride=2, padding=1)

        # ---- Mid ----
        self.mid1     = ResidualBlockWithTime(model_channels * 2, model_channels * 2, t_dim)
        self.mid_attn = CrossAttentionBlock(model_channels * 2, context_dim)
        self.mid2     = ResidualBlockWithTime(model_channels * 2, model_channels * 2, t_dim)

        # ---- Up ----
        self.up1 = nn.ConvTranspose2d(model_channels * 2, model_channels * 2,
                                      4, stride=2, padding=1)
        # After concat with skip:  model_channels*2 + model_channels = model_channels*3
        self.up2 = ResidualBlockWithTime(model_channels * 3, model_channels, t_dim)

        self.out = nn.Sequential(
            nn.GroupNorm(8, model_channels), nn.SiLU(),
            nn.Conv2d(model_channels, out_channels, 3, padding=1),
        )

    def forward(self, x, t, context):
        t_emb = self.time_embed(t)
        h  = self.input_proj(x)    # skip source
        h0 = h
        h  = self.down2(self.down1(h, t_emb))
        h  = self.mid2(self.mid_attn(self.mid1(h, t_emb), context), t_emb)
        h  = self.up2(torch.cat([self.up1(h), h0], dim=1), t_emb)
        return self.out(h)

In [None]:
# =====================================================================
#  DDPM Noise Scheduler
# =====================================================================

class DDPMScheduler:
    """Linear-beta DDPM scheduler with DDIM sampling."""
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, 0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def add_noise(self, x0, noise, t):
        """q(x_t | x_0) forward diffusion."""
        t_cpu = t.cpu()
        a = self.sqrt_alphas_cumprod[t_cpu].to(x0.device)
        b = self.sqrt_one_minus_alphas_cumprod[t_cpu].to(x0.device)
        while a.dim() < x0.dim(): a = a.unsqueeze(-1)
        while b.dim() < x0.dim(): b = b.unsqueeze(-1)
        return a * x0 + b * noise

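    def ddim_sample(self, eps_pred, t, x_t):
        """Deterministic DDIM reverse step from t to t-1.

        Uses alphas_cumprod_prev[t] (i.e. ᾱ_{t-1}), which is only correct when
        every timestep is visited; strided sampling must instead use ᾱ of the
        next timestep in its schedule.
        """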
        t_val = t.item() if isinstance(t, torch.Tensor) and t.numel() == 1 else t
        a_t    = self.alphas_cumprod[t_val].to(x_t.device)
        a_prev = (self.alphas_cumprod_prev[t_val].to(x_t.device)
                  if t_val > 0 else torch.tensor(1.0, device=x_t.device))
        x0_pred = (x_t - torch.sqrt(1 - a_t) * eps_pred) / torch.sqrt(a_t)
        return torch.sqrt(a_prev) * x0_pred + torch.sqrt(1 - a_prev) * eps_pred


# =====================================================================
#  EMA (Exponential Moving Average) Helper
# =====================================================================

class EMA:
    """Maintains an exponential moving average of model parameters."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()}

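    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if self.shadow[k].device != v.device:
                self.shadow[k] = self.shadow[k].to(v.device)
            if v.dtype.is_floating_point:
                # shadow ← decay · shadow + (1 − decay) · current
                self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)
            else:
                # Integer buffers (e.g. BatchNorm counters) are copied, not averaged
                self.shadow[k].copy_(v)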

    def apply(self, model):
        model.load_state_dict(self.shadow)

    def state_dict(self):
        return self.shadow


# =====================================================================
#  Complete HNDSR Model
# =====================================================================

class HNDSR(nn.Module):
    """Hybrid Neural Operator–Diffusion Super-Resolution."""
    def __init__(self, ae_latent_dim=128, ae_downsample_ratio=8,
                 no_width=32, no_modes=8, diffusion_channels=64,
                 num_timesteps=1000):
        super().__init__()
        self.ae_downsample_ratio = ae_downsample_ratio
        self.autoencoder     = LatentAutoencoder(3, ae_latent_dim, 4, ae_downsample_ratio)
        self.neural_operator = NeuralOperator(3, ae_latent_dim, no_modes, no_width)
        self.implicit_amp    = ImplicitAmplification(ae_latent_dim, 256)
        self.diffusion_unet  = DiffusionUNet(ae_latent_dim, diffusion_channels,
                                             ae_latent_dim, ae_latent_dim)
        self.scheduler       = DDPMScheduler(num_timesteps)

    # ---- Convenience wrappers ----
    def encode_hr(self, hr):
        _, z = self.autoencoder(hr)
        return z

    def decode_latent(self, z):
        return self.autoencoder.decode(z)

    def get_no_prior(self, lr, scale):
        up   = F.interpolate(lr, scale_factor=scale, mode='bicubic', align_corners=False)
        feat = self.neural_operator(up, scale)
        s    = up.shape[-1] // self.ae_downsample_ratio
        return F.interpolate(feat, size=(s, s), mode='bilinear', align_corners=False)

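    @torch.no_grad()
    def super_resolve(self, lr, scale_factor=4, num_inference_steps=50):
        """Full inference: LR → SR image via strided deterministic DDIM sampling."""
        b = lr.shape[0]
        no_prior = self.implicit_amp(self.get_no_prior(lr, scale_factor), scale_factor)
        context  = F.adaptive_avg_pool2d(no_prior, 1).view(b, -1)
        z_t = torch.randn(no_prior.shape, device=lr.device)

        acp = self.scheduler.alphas_cumprod.to(lr.device)
        timesteps = torch.linspace(self.scheduler.num_timesteps - 1, 0,
                                   num_inference_steps).long().tolist()
        for i, t in enumerate(tqdm(timesteps, desc='Sampling', leave=False)):
            t_batch = torch.full((b,), t, device=lr.device, dtype=torch.long)
            eps = self.diffusion_unet(z_t, t_batch, context)
            # Jump to the NEXT timestep in the strided schedule (not t-1);
            # after the last step the target is the clean latent (ᾱ = 1).
            a_t    = acp[t]
            a_prev = acp[timesteps[i + 1]] if i + 1 < len(timesteps) else acp.new_tensor(1.0)
            z0_pred = (z_t - torch.sqrt(1 - a_t) * eps) / torch.sqrt(a_t)
            z_t = torch.sqrt(a_prev) * z0_pred + torch.sqrt(1 - a_prev) * eps
        return self.decode_latent(z_t)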

# ---- Quick summary ----
_tmp = HNDSR()
total_params = sum(p.numel() for p in _tmp.parameters())
print(f"HNDSR total parameters: {total_params:,}")
for name, module in [('Autoencoder', _tmp.autoencoder),
                     ('Neural Operator', _tmp.neural_operator),
                     ('Implicit Amp', _tmp.implicit_amp),
                     ('Diffusion UNet', _tmp.diffusion_unet)]:
    n = sum(p.numel() for p in module.parameters())
    print(f"  {name:20s}: {n:>10,}  ({100*n/total_params:.1f}%)")
del _tmp

---
## 6. Evaluation Metrics

We evaluate with three standard super-resolution metrics:

| Metric | Full Name | Measures | Better |
|--------|-----------|----------|--------|
| **PSNR** | Peak Signal-to-Noise Ratio | Pixel-level accuracy | Higher ↑ |
| **SSIM** | Structural Similarity Index | Structural preservation | Higher ↑ |
| **LPIPS** | Learned Perceptual Similarity | Perceptual quality (deep features) | Lower ↓ |
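
For images scaled to $[0, 1]$ (the `data_range=1.0` used in the code below), PSNR reduces to $10\log_{10}(1/\text{MSE})$. Below is a minimal NumPy sketch of that formula. The helper name `psnr_01` is hypothetical, and the sketch does not depend on scikit-image.

```python
import numpy as np

def psnr_01(a, b):
    """PSNR in dB for arrays in [0, 1]: 10 * log10(MAX^2 / MSE) with MAX = 1."""
    mse = np.mean((np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)) ** 2)
    return float("inf") if mse == 0 else 10.0 * np.log10(1.0 / mse)

ref = np.full((64, 64, 3), 0.5)
noisy = ref + 0.1                 # uniform error of 0.1, so MSE = 0.01
print(psnr_01(ref, noisy))        # ≈ 20 dB
```

Each +10 dB corresponds to a 10× reduction in MSE, which is why small PSNR differences between methods can still reflect meaningful error reductions.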

In [None]:
def to_numpy_01(t):
    """Convert [-1,1] tensor → [0,1] numpy, shape (B, H, W, C), clipped to the valid range."""
    return np.clip((t.detach().float().cpu().numpy().transpose(0, 2, 3, 1) + 1.0) / 2.0, 0.0, 1.0)

def calculate_psnr(img1, img2):
    a, b = to_numpy_01(img1), to_numpy_01(img2)
    return np.mean([psnr(a[i], b[i], data_range=1.0) for i in range(a.shape[0])])

def calculate_ssim(img1, img2):
    a, b = to_numpy_01(img1), to_numpy_01(img2)
    return np.mean([ssim(a[i], b[i], data_range=1.0, channel_axis=2)
                    for i in range(a.shape[0])])

---
## 7. Training Functions (3-Stage Pipeline)

Each stage is trained sequentially; earlier stages are **frozen** before training later ones.

| Feature | Details |
|---|---|
| **Optimizer** | AdamW (decoupled weight decay) |
| **LR Schedule** | Linear warmup from 1% of the base LR over `WARMUP_EPOCHS` epochs → cosine annealing to 0 |
| **Gradient Clipping** | `MAX_GRAD_NORM` — prevents exploding gradients in spectral / diffusion layers |
| **Early Stopping** | Patience = `EARLY_STOP_PATIENCE` epochs without val-loss improvement |
| **EMA** | Exponential Moving Average of diffusion UNet weights (Stage 3 only) |
| **Timing** | Per-epoch wall-clock time logged |
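
As a sanity check on the schedule built by `_make_scheduler` (defined in the code cell further down), the LR used in each epoch should approximately follow the closed form below. This is a sketch under an assumption: recent PyTorch versions of `SequentialLR` restart the cosine phase at the milestone, so epoch `WARMUP_EPOCHS` lands back on the base LR. `expected_lr` is a hypothetical helper, and the settings in the usage loop are illustrative rather than the notebook's actual config.

```python
import math

def expected_lr(epoch: int, base_lr: float, warmup_epochs: int, total_epochs: int,
                start_factor: float = 0.01) -> float:
    """Hypothetical helper: approximate LR during `epoch` (0-indexed) for warmup + cosine."""
    if epoch < warmup_epochs:
        # LinearLR: multiplier ramps linearly from start_factor to 1.0 over warmup_epochs
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # CosineAnnealingLR (eta_min = 0) over the remaining epochs
    t_max = max(1, total_epochs - warmup_epochs)
    progress = (epoch - warmup_epochs) / t_max
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Illustrative settings: base LR 1e-4, 5 warmup epochs, 40 epochs in total
for e in [0, 2, 5, 20, 39]:
    print(e, f"{expected_lr(e, 1e-4, 5, 40):.2e}")
```

If the logged `history['lr']` values drift far from this shape, the warmup/milestone wiring is worth re-checking.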

### Loss Functions
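- **Stage 1 (AE):** $\mathcal{L}_{\text{AE}} = \frac{1}{N}\|x - \hat{x}\|_1$ (`nn.L1Loss`, averaged over the $N$ elements; L1 for sharper reconstructions)
- **Stage 2 (NO):** $\mathcal{L}_{\text{NO}} = \frac{1}{N}\|z_{\text{pred}} - z_{\text{target}}\|_2^2$ (MSE in latent space; $z_{\text{target}}$ is `model.encode_hr(x)` with the autoencoder frozen)
- **Stage 3 (Diff):** $\mathcal{L}_{\text{Diff}} = \frac{1}{N}\|\epsilon - \epsilon_\theta(z_t, t, c)\|_2^2$ (noise-prediction objective; $z_t$ is the noised HR latent at timestep $t$, and $c$ is the globally pooled neural-operator prior used as context)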
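
In Stage 3 the regression target is the injected noise itself. Assuming `model.scheduler.add_noise` implements the standard DDPM forward process $z_t = \sqrt{\bar\alpha_t}\,z_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, the NumPy sketch below shows why predicting $\epsilon$ is sufficient: given $(z_t, t)$, a perfect noise prediction recovers $z_0$ exactly. The linear β schedule (1e-4 to 0.02 over 1000 steps) and the latent shape are illustrative assumptions, not values read from the notebook's scheduler.

```python
import numpy as np

# Illustrative linear beta schedule (assumption, not the notebook's scheduler)
num_timesteps = 1000
betas = np.linspace(1e-4, 0.02, num_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)  # \bar{alpha}_t for t = 0 .. T-1

def add_noise(z0, noise, t):
    """DDPM forward process q(z_t | z_0) = sqrt(a_bar_t) * z0 + sqrt(1 - a_bar_t) * noise."""
    a_bar = alphas_cumprod[t].reshape(-1, *([1] * (z0.ndim - 1)))  # one value per sample
    return np.sqrt(a_bar) * z0 + np.sqrt(1.0 - a_bar) * noise

rng = np.random.default_rng(0)
z0 = rng.standard_normal((4, 8, 16, 16))    # stand-in for encoded HR latents
eps = rng.standard_normal(z0.shape)         # the UNet's regression target
t = rng.integers(0, num_timesteps, size=4)  # one random timestep per sample
z_t = add_noise(z0, eps, t)

# With a perfect epsilon prediction, z0 can be recovered from (z_t, t)
eps_hat = eps
a_bar = alphas_cumprod[t].reshape(-1, 1, 1, 1)
z0_rec = (z_t - np.sqrt(1.0 - a_bar) * eps_hat) / np.sqrt(a_bar)
print(np.allclose(z0_rec, z0))
```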

In [None]:
# ═══════════════════════════════════════════════════════════
#  Helper: build warmup + cosine scheduler
# ═══════════════════════════════════════════════════════════
def _make_scheduler(optimizer, warmup_epochs, total_epochs):
    """Linear warmup for `warmup_epochs`, then cosine annealing to 0."""
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_epochs)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, total_epochs - warmup_epochs))
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])


# ═══════════════════════════════════════════════════════════
#  Stage 1: Autoencoder  (L1 reconstruction)
# ═══════════════════════════════════════════════════════════
def train_autoencoder(model, train_loader, val_loader, num_epochs, lr=1e-4,
                      device='cuda', resume_path=None):
    print("\n" + "="*60)
    print(f"STAGE 1: Training Autoencoder  ({num_epochs} epochs)")
    print("="*60)

    model.autoencoder.to(device)
    optimizer = torch.optim.AdamW(model.autoencoder.parameters(),
                                  lr=lr, weight_decay=WEIGHT_DECAY)
    scheduler = _make_scheduler(optimizer, WARMUP_EPOCHS, num_epochs)
    loss_fn = nn.L1Loss()
    best_val = float('inf')
    patience_counter = 0
    start_epoch = 0

    history = {'train': [], 'val': [], 'psnr': [], 'lr': []}

    # ---- Resume from mid-stage checkpoint ----
    if resume_path and os.path.exists(resume_path):
        ckpt = torch.load(resume_path, map_location=device, weights_only=False)
        model.autoencoder.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        scheduler.load_state_dict(ckpt['scheduler_state'])
        start_epoch    = ckpt['epoch'] + 1
        best_val       = ckpt['best_val']
        patience_counter = ckpt['patience_counter']
        history        = ckpt['history']
        print(f"  ↻ Resumed from epoch {start_epoch}/{num_epochs}  "
              f"(best_val={best_val:.4f}, patience={patience_counter})")

    for epoch in range(start_epoch, num_epochs):
        t0 = time.time()
        model.autoencoder.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"AE {epoch+1}/{num_epochs}", leave=False):
            hr = batch['hr'].to(device)
            optimizer.zero_grad(set_to_none=True)
            recon, z = model.autoencoder(hr)
            loss = loss_fn(recon, hr)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.autoencoder.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            train_losses.append(loss.item())
            del hr, recon, z, loss

        # Validation
        model.autoencoder.eval()
        val_losses, val_psnrs = [], []
        with torch.no_grad():
            for batch in val_loader:
                hr = batch['hr'].to(device)
                recon, z = model.autoencoder(hr)
                val_losses.append(loss_fn(recon, hr).item())
                val_psnrs.append(calculate_psnr(recon, hr))
                del hr, recon, z
        torch.cuda.empty_cache()

        tl, vl, vp = np.mean(train_losses), np.mean(val_losses), np.mean(val_psnrs)
        cur_lr = optimizer.param_groups[0]['lr']
        history['train'].append(tl); history['val'].append(vl)
        history['psnr'].append(vp); history['lr'].append(cur_lr)

        elapsed = time.time() - t0
        print(f"  Epoch {epoch+1:3d} | Train {tl:.4f} | Val {vl:.4f} | "
              f"PSNR {vp:.2f} dB | LR {cur_lr:.1e} | {elapsed:.0f}s", end='')

        if vl < best_val:
            best_val = vl
            patience_counter = 0
            torch.save(model.autoencoder.state_dict(), AUTOENCODER_PATH)
            print("  *saved*", end='')
        else:
            patience_counter += 1
        print()
        scheduler.step()

        # ---- Save resume checkpoint every epoch ----
        torch.save({
            'epoch': epoch,
            'model_state': model.autoencoder.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'best_val': best_val,
            'patience_counter': patience_counter,
            'history': history,
        }, STAGE1_RESUME_PATH)

        if patience_counter >= EARLY_STOP_PATIENCE:
            print(f"  Early stopping triggered (patience {EARLY_STOP_PATIENCE})")
            break

    # Reload best weights & clean up resume file
    model.autoencoder.load_state_dict(torch.load(AUTOENCODER_PATH, map_location=device, weights_only=False))
    if os.path.exists(STAGE1_RESUME_PATH):
        os.remove(STAGE1_RESUME_PATH)
        print("  Removed Stage 1 resume checkpoint (training complete)")
    print("Stage 1 complete — best val loss: {:.4f}".format(best_val))
    return model, history

In [None]:
# ═══════════════════════════════════════════════════════════
#  Stage 2: Neural Operator  (MSE in latent space)
# ═══════════════════════════════════════════════════════════
def train_neural_operator(model, train_loader, val_loader, num_epochs, lr=1e-4,
                          device='cuda', resume_path=None):
    print("\n" + "="*60)
    print(f"STAGE 2: Training Neural Operator  ({num_epochs} epochs)")
    print("="*60)

    for p in model.autoencoder.parameters(): p.requires_grad = False
    model.autoencoder.eval()
    model.neural_operator.to(device)

    optimizer = torch.optim.AdamW(model.neural_operator.parameters(),
                                  lr=lr, weight_decay=WEIGHT_DECAY)
    scheduler = _make_scheduler(optimizer, WARMUP_EPOCHS, num_epochs)
    loss_fn = nn.MSELoss()
    best_val = float('inf')
    patience_counter = 0
    start_epoch = 0

    history = {'train': [], 'val': [], 'lr': []}

    # ---- Resume from mid-stage checkpoint ----
    if resume_path and os.path.exists(resume_path):
        ckpt = torch.load(resume_path, map_location=device, weights_only=False)
        model.neural_operator.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        scheduler.load_state_dict(ckpt['scheduler_state'])
        start_epoch    = ckpt['epoch'] + 1
        best_val       = ckpt['best_val']
        patience_counter = ckpt['patience_counter']
        history        = ckpt['history']
        print(f"  ↻ Resumed from epoch {start_epoch}/{num_epochs}  "
              f"(best_val={best_val:.4f}, patience={patience_counter})")

    for epoch in range(start_epoch, num_epochs):
        t0 = time.time()
        model.neural_operator.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"NO {epoch+1}/{num_epochs}", leave=False):
            lr_img = batch['lr'].to(device)
            hr_img = batch['hr'].to(device)
            scale  = batch['scale'][0].item()

            optimizer.zero_grad(set_to_none=True)
            with torch.no_grad():
                target = model.encode_hr(hr_img)
            pred = model.get_no_prior(lr_img, scale)
            loss = loss_fn(pred, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.neural_operator.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            train_losses.append(loss.item())
            del lr_img, hr_img, target, pred, loss

        model.neural_operator.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                lr_img = batch['lr'].to(device)
                hr_img = batch['hr'].to(device)
                scale  = batch['scale'][0].item()
                target = model.encode_hr(hr_img)
                pred = model.get_no_prior(lr_img, scale)
                val_losses.append(loss_fn(pred, target).item())
                del lr_img, hr_img, target, pred
        torch.cuda.empty_cache()

        tl, vl = np.mean(train_losses), np.mean(val_losses)
        cur_lr = optimizer.param_groups[0]['lr']
        history['train'].append(tl); history['val'].append(vl)
        history['lr'].append(cur_lr)

        elapsed = time.time() - t0
        print(f"  Epoch {epoch+1:3d} | Train {tl:.4f} | Val {vl:.4f} | "
              f"LR {cur_lr:.1e} | {elapsed:.0f}s", end='')

        if vl < best_val:
            best_val = vl
            patience_counter = 0
            torch.save(model.neural_operator.state_dict(), NEURAL_OPERATOR_PATH)
            print("  *saved*", end='')
        else:
            patience_counter += 1
        print()
        scheduler.step()

        # ---- Save resume checkpoint every epoch ----
        torch.save({
            'epoch': epoch,
            'model_state': model.neural_operator.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'best_val': best_val,
            'patience_counter': patience_counter,
            'history': history,
        }, STAGE2_RESUME_PATH)

        if patience_counter >= EARLY_STOP_PATIENCE:
            print(f"  Early stopping triggered (patience {EARLY_STOP_PATIENCE})")
            break

    # Reload best weights & clean up resume file
    model.neural_operator.load_state_dict(torch.load(NEURAL_OPERATOR_PATH, map_location=device, weights_only=False))
    if os.path.exists(STAGE2_RESUME_PATH):
        os.remove(STAGE2_RESUME_PATH)
        print("  Removed Stage 2 resume checkpoint (training complete)")
    print("Stage 2 complete — best val loss: {:.4f}".format(best_val))
    return model, history

In [None]:
# ═══════════════════════════════════════════════════════════
#  Stage 3: Diffusion Model  (ε-prediction with EMA)
# ═══════════════════════════════════════════════════════════
def train_diffusion(model, train_loader, val_loader, num_epochs, lr=1e-4,
                    device='cuda', ema=None, resume_path=None):
    print("\n" + "="*60)
    print(f"STAGE 3: Training Diffusion Model  ({num_epochs} epochs)")
    if ema is not None:
        print(f"  EMA enabled (decay={EMA_DECAY})")
    print("="*60)

    for p in model.autoencoder.parameters(): p.requires_grad = False
    for p in model.neural_operator.parameters(): p.requires_grad = False
    for p in model.implicit_amp.parameters(): p.requires_grad = False
    model.autoencoder.eval()
    model.neural_operator.eval()
    model.implicit_amp.eval()
    model.diffusion_unet.to(device)

    optimizer = torch.optim.AdamW(model.diffusion_unet.parameters(),
                                  lr=lr, weight_decay=WEIGHT_DECAY)
    scheduler = _make_scheduler(optimizer, WARMUP_EPOCHS, num_epochs)
    loss_fn = nn.MSELoss()
    best_val = float('inf')
    patience_counter = 0
    start_epoch = 0

    history = {'train': [], 'val': [], 'lr': []}

    # ---- Resume from mid-stage checkpoint ----
    if resume_path and os.path.exists(resume_path):
        ckpt = torch.load(resume_path, map_location=device, weights_only=False)
        model.diffusion_unet.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        scheduler.load_state_dict(ckpt['scheduler_state'])
        start_epoch    = ckpt['epoch'] + 1
        best_val       = ckpt['best_val']
        patience_counter = ckpt['patience_counter']
        history        = ckpt['history']
        if ema is not None and 'ema_shadow' in ckpt:
            ema.shadow = ckpt['ema_shadow']
        print(f"  ↻ Resumed from epoch {start_epoch}/{num_epochs}  "
              f"(best_val={best_val:.4f}, patience={patience_counter})")

    for epoch in range(start_epoch, num_epochs):
        t0 = time.time()
        model.diffusion_unet.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"Diff {epoch+1}/{num_epochs}", leave=False):
            lr_img = batch['lr'].to(device)
            hr_img = batch['hr'].to(device)
            scale  = batch['scale'][0].item()

            optimizer.zero_grad(set_to_none=True)
            with torch.no_grad():
                target = model.encode_hr(hr_img)
                no_prior = model.get_no_prior(lr_img, scale)
                b = lr_img.shape[0]
                context = F.adaptive_avg_pool2d(no_prior, 1).view(b, -1)

            timesteps = torch.randint(0, model.scheduler.num_timesteps, (b,), device=device).long()
            noise = torch.randn_like(target)
            noisy = model.scheduler.add_noise(target, noise, timesteps)
            pred  = model.diffusion_unet(noisy, timesteps, context)
            loss  = loss_fn(pred, noise)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.diffusion_unet.parameters(), MAX_GRAD_NORM)
            optimizer.step()

            # EMA update
            if ema is not None:
                ema.update(model.diffusion_unet)

            train_losses.append(loss.item())
            del lr_img, hr_img, target, no_prior, context, timesteps, noise, noisy, pred, loss

        model.diffusion_unet.eval()
        val_losses = []
        # Fixed-seed generator: every epoch is validated on the same timesteps and noise,
        # so val-loss changes (and early stopping) reflect the model, not sampling noise.
        val_gen = torch.Generator(device=device).manual_seed(0)
        with torch.no_grad():
            for batch in val_loader:
                lr_img = batch['lr'].to(device)
                hr_img = batch['hr'].to(device)
                scale  = batch['scale'][0].item()
                target = model.encode_hr(hr_img)
                no_prior = model.get_no_prior(lr_img, scale)
                b = lr_img.shape[0]
                context = F.adaptive_avg_pool2d(no_prior, 1).view(b, -1)
                timesteps = torch.randint(0, model.scheduler.num_timesteps, (b,),
                                          generator=val_gen, device=device)
                noise = torch.randn(target.shape, generator=val_gen,
                                    device=device, dtype=target.dtype)
                noisy = model.scheduler.add_noise(target, noise, timesteps)
                pred  = model.diffusion_unet(noisy, timesteps, context)
                val_losses.append(loss_fn(pred, noise).item())
                del lr_img, hr_img, target, no_prior, context, timesteps, noise, noisy, pred
        torch.cuda.empty_cache()

        tl, vl = np.mean(train_losses), np.mean(val_losses)
        cur_lr = optimizer.param_groups[0]['lr']
        history['train'].append(tl); history['val'].append(vl)
        history['lr'].append(cur_lr)

        elapsed = time.time() - t0
        print(f"  Epoch {epoch+1:3d} | Train {tl:.4f} | Val {vl:.4f} | "
              f"LR {cur_lr:.1e} | {elapsed:.0f}s", end='')

        if vl < best_val:
            best_val = vl
            patience_counter = 0
            save_dict = {'diffusion_unet': model.diffusion_unet.state_dict()}
            if ema is not None:
                save_dict['ema_shadow'] = ema.shadow
            torch.save(save_dict, DIFFUSION_PATH)
            print("  *saved*", end='')
        else:
            patience_counter += 1
        print()
        scheduler.step()

        # ---- Save resume checkpoint every epoch ----
        resume_dict = {
            'epoch': epoch,
            'model_state': model.diffusion_unet.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'best_val': best_val,
            'patience_counter': patience_counter,
            'history': history,
        }
        if ema is not None:
            resume_dict['ema_shadow'] = ema.shadow
        torch.save(resume_dict, STAGE3_RESUME_PATH)

        if patience_counter >= EARLY_STOP_PATIENCE:
            print(f"  Early stopping triggered (patience {EARLY_STOP_PATIENCE})")
            break

    # Reload best weights & clean up resume file
    ckpt = torch.load(DIFFUSION_PATH, map_location=device, weights_only=False)
    model.diffusion_unet.load_state_dict(ckpt['diffusion_unet'])
    if ema is not None and 'ema_shadow' in ckpt:
        ema.shadow = ckpt['ema_shadow']
    if os.path.exists(STAGE3_RESUME_PATH):
        os.remove(STAGE3_RESUME_PATH)
        print("  Removed Stage 3 resume checkpoint (training complete)")
    print("Stage 3 complete — best val loss: {:.4f}".format(best_val))
    return model, history

---
## 8. Checkpoint & Resume Helpers

### How resume works
Each training function saves a **resume checkpoint every epoch** (`stageN_resume.pth`) containing the model weights, optimizer state, scheduler state, epoch counter, best validation loss, patience counter, and full training history (plus the EMA shadow weights in Stage 3, when EMA is enabled).

**If a Kaggle session times out mid-training:**
1. Your output files are preserved (they're in `/kaggle/working/`)
2. When you start a new session, add your **previous notebook output** as a dataset
3. Copy the checkpoint files back: `!cp /kaggle/input/your-output-dataset/* /kaggle/working/`
4. **Just re-run all cells.** With `AUTO_RESUME = True`, the notebook automatically detects which stage was interrupted and continues from the epoch after the last completed one (progress inside a partially finished epoch is lost)

The resume file is deleted once a stage finishes successfully, so it only exists for interrupted stages.
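
The detection logic that `AUTO_RESUME` relies on is not shown in this section, so the sketch below is only one plausible way to derive it from the files on disk. `detect_resume_point` and `STAGE_FILES` are hypothetical names; the file names are taken from the output table in Section 13. The rule is: a pending resume file means the stage was interrupted; a missing best checkpoint means the stage has not produced a result yet.

```python
import os

# Hypothetical mapping, using the file names listed in the output table (Section 13)
STAGE_FILES = {
    1: ("autoencoder_best.pth", "stage1_resume.pth"),
    2: ("neural_operator_best.pth", "stage2_resume.pth"),
    3: ("diffusion_best.pth", "stage3_resume.pth"),
}

def detect_resume_point(save_dir):
    """Hypothetical helper: return (stage_to_run, resume_path or None).

    A returned stage of len(STAGE_FILES) + 1 means every stage has finished.
    """
    for stage, (best_name, resume_name) in STAGE_FILES.items():
        resume_path = os.path.join(save_dir, resume_name)
        if os.path.exists(resume_path):
            return stage, resume_path           # interrupted mid-stage
        if not os.path.exists(os.path.join(save_dir, best_name)):
            return stage, None                  # stage not started (or no epoch saved)
    return len(STAGE_FILES) + 1, None
```

Checking the resume file before the best checkpoint matters: an interrupted stage usually has both files, and it should be resumed rather than treated as finished.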

In [None]:
def load_trained_stages(model, up_to_stage, device='cuda'):
    """Load previously saved stage checkpoints up to `up_to_stage` (inclusive).
    Returns (model, last_loaded_stage).
    """
    if up_to_stage >= 1 and os.path.exists(AUTOENCODER_PATH):
        model.autoencoder.load_state_dict(
            torch.load(AUTOENCODER_PATH, map_location=device, weights_only=True))
        model.autoencoder.to(device)
        print(f"  ✓ Stage 1 loaded: {AUTOENCODER_PATH}")
    elif up_to_stage >= 1:
        print(f"  ✗ Stage 1 NOT found at {AUTOENCODER_PATH}")
        return model, 0

    if up_to_stage >= 2 and os.path.exists(NEURAL_OPERATOR_PATH):
        model.neural_operator.load_state_dict(
            torch.load(NEURAL_OPERATOR_PATH, map_location=device, weights_only=True))
        model.neural_operator.to(device)
        print(f"  ✓ Stage 2 loaded: {NEURAL_OPERATOR_PATH}")
    elif up_to_stage >= 2:
        print(f"  ✗ Stage 2 NOT found at {NEURAL_OPERATOR_PATH}")
        return model, 1

    if up_to_stage >= 3 and os.path.exists(DIFFUSION_PATH):
        ckpt = torch.load(DIFFUSION_PATH, map_location=device, weights_only=True)
        model.diffusion_unet.load_state_dict(ckpt['diffusion_unet'])
        model.diffusion_unet.to(device)
        print(f"  ✓ Stage 3 loaded: {DIFFUSION_PATH}")
    elif up_to_stage >= 3:
        print(f"  ✗ Stage 3 NOT found at {DIFFUSION_PATH}")
        return model, 2

    return model, up_to_stage

---
## 9. Run Training Pipeline

> **Kaggle tip — multi-session training:**
> 1. Run the notebook with **Save & Run All** and let it run overnight
> 2. If the 12-hour session expires mid-training, go to the notebook's **Output** tab and click **"New Notebook"** or add the output as a dataset input
> 3. Copy checkpoint files back to `/kaggle/working/` and **re-run all cells** — training auto-resumes from the last completed epoch
> 4. Repeat until all stages are done

In [None]:
# ---- Build dataloaders ----
full_dataset = SatelliteDataset(HR_DIR, LR_DIR, patch_size=PATCH_SIZE, training=True)

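train_size = int(0.9 * len(full_dataset))
val_size   = len(full_dataset) - train_size
# Fixed seed so the train/val split is identical across Kaggle sessions; otherwise a
# resumed run would validate on images it previously trained on.
split_gen = torch.Generator().manual_seed(42)
train_ds, val_ds = torch.utils.data.random_split(full_dataset, [train_size, val_size],
                                                 generator=split_gen)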

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)

print(f"Train: {len(train_ds)}  |  Val: {len(val_ds)}")

In [None]:
# Apply EMA shadow weights to the diffusion UNet for the final model
# (EMA-averaged weights are usually smoother and generalise better).
# `ema` only exists if Stage 3 ran in this kernel session; otherwise
# (e.g. after a kernel restart) the best-checkpoint weights are used as-is.
if 'ema' in globals() and ema is not None:
    ema.apply(model.diffusion_unet)
    print("Applied EMA shadow weights to diffusion UNet.")
else:
    print("EMA object not in memory (Stage 3 did not run in this session). "
          "Using best-checkpoint weights directly.")

---
## 10. Training Curves

Loss curves for each stage, plus per-stage learning rate schedule.
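
Because the per-stage histories live only in memory and in the resume checkpoints (which are deleted when a stage finishes), curves from a multi-session run are lost unless they are saved separately. A minimal sketch of doing that with JSON follows; `save_histories`, `load_histories`, and the `training_histories.json` file name are hypothetical, not part of the notebook.

```python
import json

def save_histories(histories, path):
    """Hypothetical helper: write {stage: {metric: [values]}} to JSON.

    float() converts numpy scalars (e.g. from np.mean) into JSON-serialisable floats.
    """
    serialisable = {stage: {key: [float(v) for v in values] for key, values in h.items()}
                    for stage, h in histories.items()}
    with open(path, "w") as f:
        json.dump(serialisable, f, indent=2)

def load_histories(path):
    """Hypothetical helper: read histories written by save_histories."""
    with open(path) as f:
        return json.load(f)
```

For example, calling `save_histories(all_histories, os.path.join(SAVE_DIR, 'training_histories.json'))` after each stage would keep the file in `/kaggle/working/` alongside the checkpoints. A later session could then reload it with `load_histories` before the plotting cell.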

In [None]:
if 'all_histories' not in dir() or all_histories is None:
    print("all_histories not available (training completed in a prior session "
          "and resume checkpoints — which store history — are cleaned up after each stage).")
    print("Skipping training-curve plot.  If you need curves, re-run training "
          "from Stage 1 in a single session or save histories separately.")
else:
    n_plots = len(all_histories)
    if n_plots == 0:
        print("No training histories recorded (did you resume past all stages?).")
    else:
        fig, axes = plt.subplots(2, n_plots, figsize=(6*n_plots, 9))
        if n_plots == 1:
            axes = axes.reshape(2, 1)

        titles = {'stage1': 'Stage 1: Autoencoder (L1)',
                  'stage2': 'Stage 2: Neural Operator (MSE)',
                  'stage3': 'Stage 3: Diffusion (ε-pred)'}

        for col, (key, h) in enumerate(all_histories.items()):
            # Row 0: Loss curves
            ax = axes[0, col]
            ax.plot(h['train'], label='Train', linewidth=1.5)
            ax.plot(h['val'],   label='Val',   linewidth=1.5)
            ax.set_title(titles.get(key, key), fontsize=12, fontweight='bold')
            ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')
            ax.legend(); ax.grid(True, alpha=0.3)

            # Row 1: LR schedule
            ax2 = axes[1, col]
            if 'lr' in h:
                ax2.plot(h['lr'], color='tab:green', linewidth=1.5)
                ax2.set_ylabel('Learning Rate')
            elif 'psnr' in h:
                ax2.plot(h['psnr'], color='tab:orange', linewidth=1.5)
                ax2.set_ylabel('PSNR (dB)')
            ax2.set_xlabel('Epoch'); ax2.grid(True, alpha=0.3)
            ax2.set_title('LR Schedule' if 'lr' in h else 'Validation PSNR')

        plt.tight_layout()
        plt.savefig(os.path.join(SAVE_DIR, 'training_curves.png'), dpi=150, bbox_inches='tight')
        plt.show()
        print("Saved training_curves.png")

---
## 11. Evaluation

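Runs super-resolution on up to 50 randomly sampled patches using 50 DDIM sampling steps (`num_inference_steps=50`). These patches are drawn from the same `HR_DIR` / `LR_DIR` as the training data (with `training=False`), so they are not a strictly held-out test set. Computes **PSNR** and **SSIM**, plus **LPIPS** (perceptual distance) when the `lpips` package is available, and saves LR / SR / HR image triplets for the first `max_vis` samples for visual inspection.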
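
`super_resolve` does not step through every training timestep. It builds an evenly spaced, descending grid with `torch.linspace(num_timesteps - 1, 0, num_inference_steps)` and applies one DDIM update per grid point. The NumPy sketch below reproduces that grid under two assumptions: a 1000-step training schedule, and truncation when casting to integers (PyTorch's exact rounding may differ slightly).

```python
import numpy as np

num_timesteps = 1000       # assumption: length of the training noise schedule
num_inference_steps = 50   # default used by evaluate_model

# Descending grid from T-1 to 0, cast to integer timesteps (truncation assumed)
timesteps = np.linspace(num_timesteps - 1, 0, num_inference_steps).astype(np.int64)
print(timesteps[:5], "...", timesteps[-3:])
print("average stride:", (num_timesteps - 1) / (num_inference_steps - 1))
```

With these settings each update jumps about 20 timesteps. Evaluation therefore costs 50 UNet forward passes per image, not the 1000 a full DDPM chain would need.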

In [None]:
@torch.no_grad()
def evaluate_model(model, test_loader, device='cuda', save_dir=None, max_vis=10,
                   num_inference_steps=50):
    """Full evaluation with PSNR, SSIM, LPIPS."""
    if save_dir:
        for sub in ['lr', 'sr', 'hr']:
            Path(os.path.join(save_dir, sub)).mkdir(parents=True, exist_ok=True)

    try:
        lpips_fn = lpips.LPIPS(net='alex').to(device)
        use_lpips = True
        print("LPIPS metric enabled (AlexNet backbone)")
    except Exception:
        use_lpips = False
        print("LPIPS not available — reporting PSNR & SSIM only")

    model.autoencoder.eval(); model.neural_operator.eval()
    model.implicit_amp.eval(); model.diffusion_unet.eval()

    psnr_vals, ssim_vals, lpips_vals = [], [], []

    for idx, batch in enumerate(tqdm(test_loader, desc='Evaluating')):
        lr_img = batch['lr'].to(device)
        hr_img = batch['hr'].to(device)
        scale  = batch['scale'][0].item()

        sr_img = model.super_resolve(lr_img, scale_factor=scale,
                                     num_inference_steps=num_inference_steps)

        p = calculate_psnr(sr_img, hr_img)
        s = calculate_ssim(sr_img, hr_img)
        psnr_vals.append(p); ssim_vals.append(s)
        if use_lpips:
            # LPIPS expects inputs in [-1, 1]; clamp SR output to that range
            lpips_vals.append(lpips_fn(sr_img.clamp(-1, 1), hr_img).mean().item())

        if save_dir and idx < max_vis:
            save_image((lr_img+1)/2, f"{save_dir}/lr/sample_{idx:03d}.png")
            save_image((sr_img+1)/2, f"{save_dir}/sr/sample_{idx:03d}.png")
            save_image((hr_img+1)/2, f"{save_dir}/hr/sample_{idx:03d}.png")

        del lr_img, hr_img, sr_img
        torch.cuda.empty_cache()

    results = {
        'psnr_mean': np.mean(psnr_vals), 'psnr_std': np.std(psnr_vals),
        'ssim_mean': np.mean(ssim_vals), 'ssim_std': np.std(ssim_vals),
        'psnr_values': psnr_vals, 'ssim_values': ssim_vals,
    }
    if use_lpips:
        results['lpips_mean'] = np.mean(lpips_vals)
        results['lpips_std']  = np.std(lpips_vals)
        results['lpips_values'] = lpips_vals

    print("\n" + "="*55)
    print("  EVALUATION RESULTS")
    print("="*55)
    print(f"  PSNR  : {results['psnr_mean']:.2f} ± {results['psnr_std']:.2f} dB")
    print(f"  SSIM  : {results['ssim_mean']:.4f} ± {results['ssim_std']:.4f}")
    if use_lpips:
        print(f"  LPIPS : {results['lpips_mean']:.4f} ± {results['lpips_std']:.4f}")
    print(f"  Samples evaluated: {len(psnr_vals)}")
    print("="*55)

    if save_dir:
        torch.save(results, os.path.join(save_dir, 'evaluation_results.pth'))
        print(f"  Saved to {save_dir}/evaluation_results.pth")

    return results

In [None]:
# ---- Rebuild model from checkpoints if not already in memory ----
if 'model' not in dir() or model is None:
    print("model not in memory — rebuilding from saved checkpoints...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HNDSR().to(device)
    model, loaded_up_to = load_trained_stages(model, up_to_stage=3, device=device)
    if loaded_up_to < 3:
        raise RuntimeError(f"Only loaded up to stage {loaded_up_to}; "
                           "need all 3 stage checkpoints for evaluation.")
    model.eval()
    print(f"Model rebuilt on {device} with all 3 stages.\n")

# ---- Run evaluation ----
test_dataset = SatelliteDataset(HR_DIR, LR_DIR, patch_size=PATCH_SIZE, training=False)
num_eval = min(50, len(test_dataset))
if len(test_dataset) > num_eval:
    indices = np.random.choice(len(test_dataset), num_eval, replace=False)
    test_dataset = torch.utils.data.Subset(test_dataset, indices)

test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

Path(EVAL_RESULTS_DIR).mkdir(parents=True, exist_ok=True)
results = evaluate_model(model, test_loader, device, save_dir=EVAL_RESULTS_DIR)

---
## 12. Visual Comparison

Side-by-side LR → SR → HR for qualitative assessment. Look for sharpness recovery in building edges, road markings, and vegetation texture.

In [None]:
# Show side-by-side comparisons
sr_files = sorted(Path(EVAL_RESULTS_DIR, 'sr').glob('*.png'))
n_show = min(5, len(sr_files))

if n_show == 0:
    print("No evaluation images found — run the evaluation cell first.")
else:
    fig, axes = plt.subplots(n_show, 3, figsize=(15, 5 * n_show))
    if n_show == 1: axes = [axes]

    for i in range(n_show):
        lr = Image.open(f"{EVAL_RESULTS_DIR}/lr/sample_{i:03d}.png")
        sr = Image.open(f"{EVAL_RESULTS_DIR}/sr/sample_{i:03d}.png")
        hr = Image.open(f"{EVAL_RESULTS_DIR}/hr/sample_{i:03d}.png")

        axes[i][0].imshow(lr); axes[i][0].set_title('LR Input (2 m/px)', fontsize=11)
        axes[i][1].imshow(sr); axes[i][1].set_title('HNDSR Output',      fontsize=11)
        axes[i][2].imshow(hr); axes[i][2].set_title('HR Ground Truth (0.5 m/px)', fontsize=11)
        for ax in axes[i]: ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, 'visual_comparison.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved visual_comparison.png")

In [None]:
# PSNR & SSIM distribution
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.hist(results['psnr_values'], bins=15, edgecolor='black', alpha=0.7, color='steelblue')
ax1.axvline(results['psnr_mean'], color='red', linestyle='--', linewidth=2,
            label=f"Mean: {results['psnr_mean']:.2f} dB")
ax1.set_title('PSNR Distribution', fontsize=12, fontweight='bold')
ax1.set_xlabel('PSNR (dB)'); ax1.set_ylabel('Count'); ax1.legend()

ax2.hist(results['ssim_values'], bins=15, edgecolor='black', alpha=0.7, color='seagreen')
ax2.axvline(results['ssim_mean'], color='red', linestyle='--', linewidth=2,
            label=f"Mean: {results['ssim_mean']:.4f}")
ax2.set_title('SSIM Distribution', fontsize=12, fontweight='bold')
ax2.set_xlabel('SSIM'); ax2.set_ylabel('Count'); ax2.legend()

plt.tight_layout()
plt.savefig(os.path.join(SAVE_DIR, 'metric_distributions.png'), dpi=150, bbox_inches='tight')
plt.show()
print("Saved metric_distributions.png")

---
## 13. Download Model

All files saved to `/kaggle/working/` appear in the **Output** tab of your Kaggle notebook.

| File | Description |
|---|---|
| `hndsr_complete.pth` | Full trained model (all 4 sub-networks) |
| `autoencoder_best.pth` | Stage 1 best checkpoint |
| `neural_operator_best.pth` | Stage 2 best checkpoint |
| `diffusion_best.pth` | Stage 3 best checkpoint (includes the EMA shadow weights when EMA is enabled) |
| `stageN_resume.pth` | Mid-stage resume checkpoint (only if interrupted) |
| `evaluation_results/` | Metrics `.pth` + LR/SR/HR images |
| `training_curves.png` | Loss & LR schedule plots |
| `visual_comparison.png` | Side-by-side SR examples |
| `metric_distributions.png` | PSNR & SSIM histograms |

### To continue training in a new session
1. Save the current notebook output as a **dataset** (Output tab → "New Dataset")
2. In the new session, add that dataset as input
3. Copy files: `!cp /kaggle/input/<dataset-name>/*.pth /kaggle/working/`
4. Re-run all cells — the notebook auto-resumes
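
Step 3 can also be done from Python, which reports exactly which files were copied. A minimal sketch follows. `restore_checkpoints` is a hypothetical helper, and `/kaggle/input/<dataset-name>` remains a placeholder for your own dataset. If the dataset nests the files in a subfolder, point `src_dir` at that folder.

```python
import glob
import os
import shutil

def restore_checkpoints(src_dir, dst_dir="/kaggle/working", pattern="*.pth"):
    """Hypothetical helper: copy checkpoint files from an attached dataset into dst_dir.

    Returns the list of destination paths that were written.
    """
    os.makedirs(dst_dir, exist_ok=True)
    copied = []
    for src in sorted(glob.glob(os.path.join(src_dir, pattern))):
        dst = os.path.join(dst_dir, os.path.basename(src))
        shutil.copy2(src, dst)  # copy2 also preserves file timestamps
        copied.append(dst)
    return copied
```

For example, `restore_checkpoints('/kaggle/input/<dataset-name>')` is roughly equivalent to the `!cp` command above. Run it before the training cells so the resume files are in place.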

In [None]:
# List all output files with sizes
print("=" * 50)
print("  OUTPUT FILES")
print("=" * 50)
total_mb = 0
for p in sorted(Path(SAVE_DIR).rglob('*')):
    if p.is_file():
        size = p.stat().st_size / (1024*1024)
        total_mb += size
        # str() is required: Path objects do not support the ':<45s' format spec
        print(f"  {str(p.relative_to(SAVE_DIR)):<45s} {size:>7.2f} MB")
print("-" * 50)
print(f"  {'Total':<45s} {total_mb:>7.2f} MB")
print("=" * 50)