In [None]:
!pip install -qq ultralytics
!pip install -qq torchmetrics
!pip install -qq lpips

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
[2K   [91m━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━[0m [32m185.4/363.4 MB[0m [31m5.9 MB/s[0m eta [36m0:00:31[0m

### Dataset and Heatmap Download

This block performs the following steps:

1. **Download CelebA Dataset**  
   Uses `kagglehub` to fetch the *CelebA* dataset (`jessicali9530/celeba-dataset`)  
   and prints the local download path.

2. **Set PyTorch Device**  
   Detects if a GPU is available and sets `torch.device` to `"cuda"` or `"cpu"` accordingly.

3. **Download Precomputed Heatmaps**  
   Retrieves an `.h5` file containing precomputed heatmaps from Hugging Face.  
   - Multiple dataset sizes are available:
     - **10k** samples (small, for quick tests)
     - **30k** samples (medium, balanced)
     - **50k** samples (full set)  
   - **Important:** only one `curl` command should be uncommented at a time.

4. **Purpose**  
   The CelebA images serve as the main training data,  
   while the heatmaps (e.g., facial landmark or attention maps)  
   are used as auxiliary inputs or for attention-based loss functions in the model.


In [None]:
import kagglehub
import torch
# Fetch the CelebA dataset through kagglehub; the returned location is kept in
# `path` and is reused later in the notebook as the dataset root.
path = kagglehub.dataset_download("jessicali9530/celeba-dataset")
print("Path to dataset files:", path)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# download heatmaps (keep exactly one of the curl commands below uncommented)
#small 10k
!curl -L -o heatmaps.h5 https://huggingface.co/datasets/RiccardoCarraro/heatmaps/resolve/main/heatmaps_10k.h5

#medium 30k
#!curl -L -o heatmaps.h5 https://huggingface.co/datasets/RiccardoCarraro/heatmaps/resolve/main/heatmaps_30k.h5

#uncomment the following line to use the 50k version of the dataset
#!curl -L -o heatmaps.h5 https://huggingface.co/datasets/RiccardoCarraro/heatmaps/resolve/main/heatmaps.h5

### CelebDataSet: CelebA + (optional) facial landmark heatmaps

This `torch.utils.data.Dataset` builds a **progressive, multi-scale** face SR pipeline (16→32→64→128) and can optionally load **landmark heatmaps** from an HDF5 file. It mirrors the common setup in progressive face super‑resolution (e.g., Kim et al., *Progressive Face Super-Resolution via Attention to Facial Landmark*, 2019).

**What it returns (per sample)**
- `x2` — 32×32 target tensor  
- `x4` — 64×64 target tensor  
- `hr` — 128×128 “high‑res” target tensor  
- `lr` — 16×16 low‑res input tensor  
- `heatmap` — (1×128×128) float tensor with facial landmark attention (zeros if not provided)

**How it works**
1. **Splits**  
   Reads `list_eval_partition.csv` to build train/val/test file lists, then picks the split via `state={'train','val','test'}`.

2. **Preprocessing & Augmentation**  
   - Center-crop to 178×178 → resize to 128×128.  
   - Optional training‑time aug: horizontal flip, ±20° rotation (bilinear), color jitter.  
   - Normalize to `[-1, 1]` per channel with `transforms.Normalize((0.5,)*3, (0.5,)*3)`.

3. **Multi‑scale pyramid**  
   From the 128×128 image, it creates:
   - `x4`: 64×64  
   - `x2`: 32×32 (downsample of `x4`)  
   - `lr`: 16×16 (downsample of `x2`)  
   All are converted to tensors with the same normalization as `hr`.

4. **Heatmaps (optional)**  
   If `heatmap_h5` is provided, the code loads **all** heatmaps into RAM from the `heatmaps` dataset inside the HDF5 file (shape `(N, 128, 128)`).  
   - At `__getitem__`, it fetches `heatmaps[index]`, wraps it as a `(1,128,128)` tensor, and returns it.  
   - If not provided, it returns a zero heatmap.  
   **Tip:** ensure the HDF5 ordering matches CelebA’s file order used here (same split and sort).

**Why this layout?**  
- Produces supervision at **multiple scales** (16/32/64/128), which is standard for progressive face SR and aligns with setups that use **landmark‑based attention/heatmap losses** (as in Kim et al., 2019).  
- Keeps the dataloader simple and fast: images are read from disk; heatmaps (small) can be preloaded to RAM.

**Usage**
```python
ds = CelebDataSet(
    data_path="/path/to/celeba",
    state="train",
    heatmap_h5="heatmaps_30k.h5"  # or None
)
x2, x4, hr, lr, heat = ds[0]
```


In [None]:
import h5py
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from os.path import join, splitext
from PIL import Image
import csv
import numpy as np

class CelebDataSet(Dataset):
    """
    CelebA dataset with optional landmark-heatmap loading from HDF5.

    Every sample is built from one aligned CelebA image: it is center-cropped
    to 178x178, resized to 128x128 and then progressively downsampled
    (128 -> 64 -> 32 -> 16). All four scales are converted to tensors and
    normalized with mean 0.5 / std 0.5 per channel (i.e. mapped to [-1, 1]).

    Args:
        data_path: dataset root that contains ``img_align_celeba/img_align_celeba/``
            and ``list_eval_partition.csv``.
        state: ``'train'`` or ``'val'``; any other value selects the test split.
        data_augmentation: when True and ``state == 'train'``, a random
            horizontal flip, a random rotation (up to 20 degrees, bilinear) and
            color jitter are applied to the image.
        heatmap_h5: optional path to an HDF5 file; its ``'heatmaps'`` dataset is
            read fully into memory and indexed by sample position.

    Returns (from ``__getitem__``): (x2, x4, hr, lr, heatmap)
      - x2: 32×32 target tensor
      - x4: 64×64 target tensor
      - hr: 128×128 target tensor
      - lr: 16×16 input tensor
      - heatmap: heatmap with a leading channel axis (expected 1×128×128);
        a zero tensor of shape 1×128×128 when no HDF5 file was given

    NOTE(review): heatmaps are looked up by the sample's position in the sorted
    file list of the selected split, so the HDF5 file must follow that exact
    order. The geometric augmentations are applied to the image only; the
    heatmap is returned untouched, so image and heatmap may be misaligned when
    ``data_augmentation=True``.
    """

    # Partition codes found in list_eval_partition.csv -> split name.
    _SPLIT_CODES = {'0': 'train', '1': 'val', '2': 'test'}

    def __init__(
        self,
        data_path: str = './dataset/',
        state: str = 'train',
        data_augmentation: bool = False,
        heatmap_h5: str = None,
    ):
        self.main_path = data_path
        self.state = state
        self.data_augmentation = data_augmentation
        self.img_path = join(self.main_path, 'img_align_celeba/img_align_celeba/')
        self.eval_partition_path = join(self.main_path, 'list_eval_partition.csv')

        # load train/val/test split
        splits = self._read_partitions(self.eval_partition_path)
        if state == 'train':
            self.image_list = sorted(splits['train'])
        elif state == 'val':
            self.image_list = sorted(splits['val'])
        else:
            self.image_list = sorted(splits['test'])

        # transforms
        if state == 'train' and data_augmentation:
            self.pre_process = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.CenterCrop((178, 178)),
                transforms.Resize((128, 128)),
                transforms.RandomRotation(
                    20,
                    interpolation=InterpolationMode.BILINEAR
                ),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ])
        else:
            self.pre_process = transforms.Compose([
                transforms.CenterCrop((178, 178)),
                transforms.Resize((128, 128)),
            ])

        self.totensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.down64 = transforms.Resize((64, 64))
        self.down32 = transforms.Resize((32, 32))
        self.down16 = transforms.Resize((16, 16))

        # Load the whole heatmap dataset into RAM so the HDF5 file does not
        # have to stay open (and is never shared across DataLoader workers).
        if heatmap_h5:
            with h5py.File(heatmap_h5, 'r') as h5_file:
                self.heatmaps = np.array(h5_file['heatmaps'])  # expected shape: (N, 128, 128)
        else:
            self.heatmaps = None

    @staticmethod
    def _read_partitions(csv_path):
        """Parse the partition CSV into ``{'train': [...], 'val': [...], 'test': [...]}``.

        Rows whose second column is not a known partition code (for example a
        ``image_id,partition`` header line) and rows with fewer than two
        columns (blank lines) are skipped instead of being counted as test
        images or crashing the loader.
        """
        splits = {'train': [], 'val': [], 'test': []}
        with open(csv_path, 'r', newline='') as f:
            for row in csv.reader(f):
                if len(row) < 2:
                    continue
                fname, code = row[0].strip(), row[1].strip()
                split = CelebDataSet._SPLIT_CODES.get(code)
                if split is None:
                    continue
                splits[split].append(fname)
        return splits

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Return ``(x2, x4, hr, lr, heatmap)`` for the image at position ``index``."""
        # load image; the context manager releases the file handle right away
        # (convert() returns a fully loaded copy that outlives the file).
        fname = self.image_list[index]
        with Image.open(join(self.img_path, fname)) as raw:
            img = raw.convert('RGB')
        img = self.pre_process(img)

        # build multi-scale pyramid, each level derived from the previous one
        x4 = self.down64(img)    # 64x64
        x2 = self.down32(x4)     # 32x32
        lr = self.down16(x2)     # 16x16

        # to tensor
        hr_tensor = self.totensor(img)
        x4_tensor = self.totensor(x4)
        x2_tensor = self.totensor(x2)
        lr_tensor = self.totensor(lr)

        # load heatmap; copy() gives torch its own buffer, separate from the preloaded array
        if self.heatmaps is not None:
            hm = self.heatmaps[index]
            heat = torch.from_numpy(hm.copy()).unsqueeze(0)  # add channel axis
        else:
            heat = torch.zeros(1, 128, 128)

        return x2_tensor, x4_tensor, hr_tensor, lr_tensor, heat

### SuperResolutionUNet Architecture

This is an **efficient U-Net–style model** for image super-resolution, inspired by encoder–decoder
designs but optimized for speed and stability.  
- **Input:** Low-resolution 16×16 RGB image.  
- **Output:** High-resolution 128×128 RGB reconstruction.  
- **Encoder:** Four `DoubleConv` blocks; the first keeps the 16×16 resolution and the next three use stride-2 convolutions instead of pooling, downsampling 16×16 → 8×8 → 4×4 → 2×2.  
- **Bottleneck:** A deeper `DoubleConv` block processes the compressed representation.  
- **Decoder:** Symmetric `UpBlock` modules (upsample + skip connection) restore spatial resolution to 16×16.  
- **Learned Upsampling:** Three `UpLearn` stages (nearest-neighbor + conv) upscale from 16×16 → 32×32 → 64×64 → 128×128.  
- **Refinement:** Several residual blocks refine details at full resolution.  
- **Global Skip:** The bilinearly upscaled input is added to the output for stability and better identity preservation.  

This structure is lightweight yet expressive, making it suitable for fast face SR experiments and attention-based enhancements.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv3x3(in_ch, out_ch, stride=1):
    """Build a biased 3x3 convolution with padding 1; `stride` sets the spatial step."""
    layer = nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )
    return layer

class DoubleConv(nn.Module):
    """Two 3x3 conv + LeakyReLU(0.2) stages; only the first convolution applies `stride`."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        stages = [
            conv3x3(in_ch, out_ch, stride=stride),  # may downsample
            nn.LeakyReLU(0.2, inplace=True),
            conv3x3(out_ch, out_ch),                # keeps resolution
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through both conv/activation stages."""
        return self.conv(x)

class UpBlock(nn.Module):
    """Decoder step: 2x bilinear upsample, concatenate the skip tensor, then a DoubleConv."""

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # The fused tensor carries the upsampled channels followed by the skip channels.
        merged_ch = in_ch + skip_ch
        self.conv = DoubleConv(merged_ch, out_ch)

    def forward(self, x, skip):
        """Upsample `x`, join it with `skip` along the channel axis and convolve."""
        upsampled = self.up(x)
        fused = torch.cat((upsampled, skip), dim=1)
        return self.conv(fused)

class UpLearn(nn.Module):
    """Learned 2x upsampling: nearest-neighbour resize, then a 3x3 conv and LeakyReLU(0.2)."""

    def __init__(self, ch):
        super().__init__()
        resize = nn.Upsample(scale_factor=2, mode='nearest')
        smooth = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=True)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.up = nn.Sequential(resize, smooth, activation)

    def forward(self, x):
        """Double the spatial size of `x` while keeping the channel count."""
        return self.up(x)

class ResBlock(nn.Module):
    """Residual block: conv3x3 -> LeakyReLU(0.2) -> conv3x3, added back onto the input."""

    def __init__(self, ch):
        super().__init__()
        first = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=True)
        activation = nn.LeakyReLU(0.2, inplace=True)
        second = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=True)
        self.body = nn.Sequential(first, activation, second)

    def forward(self, x):
        """Return the input plus the residual computed by `body`."""
        residual = self.body(x)
        return x + residual

class SuperResolutionUNet(nn.Module):
    """
    Efficient U-Net for 8x super-resolution (16x16 -> 128x128 for a 16x16 input).

    Pipeline:
      - encoder of four DoubleConv blocks, the last three with stride 2
        (strided convolutions replace max-pooling)
      - bottleneck DoubleConv at the lowest resolution
      - three UpBlock decoder steps (bilinear, align_corners=False) that return
        to the input resolution using the encoder skips
      - three UpLearn stages (nearest + conv), each doubling the resolution
      - a 1x1 conv followed by `refine_blocks` ResBlocks at full resolution
      - 1x1 projection to `out_channels`, added to the bilinearly upsampled input
    """

    def __init__(self, in_channels=3, base_filters=32, out_channels=3, refine_blocks=3):
        super().__init__()
        # Channel widths at the four encoder levels.
        c1, c2, c4, c8 = (base_filters * mult for mult in (1, 2, 4, 8))

        # Encoder: first level keeps the resolution, the others halve it.
        self.enc1 = DoubleConv(in_channels, c1, stride=1)
        self.enc2 = DoubleConv(c1, c2, stride=2)
        self.enc3 = DoubleConv(c2, c4, stride=2)
        self.enc4 = DoubleConv(c4, c8, stride=2)

        # Bottleneck at the coarsest resolution.
        self.bottleneck = DoubleConv(c8, c8)

        # Decoder: each step doubles the resolution and merges an encoder skip.
        self.up3 = UpBlock(c8, c4, c4)
        self.up2 = UpBlock(c4, c2, c2)
        self.up1 = UpBlock(c2, c1, c1)

        # Learned upsampling beyond the input resolution (three 2x stages).
        self.up_learn1 = UpLearn(c1)
        self.up_learn2 = UpLearn(c1)
        self.up_learn3 = UpLearn(c1)

        # Refinement head at the output resolution.
        self.refine_in = nn.Conv2d(c1, c1, 1)
        self.refine = nn.Sequential(*(ResBlock(c1) for _ in range(refine_blocks)))

        # Projection to the output channels.
        self.final_conv = nn.Conv2d(c1, out_channels, 1)

    def forward(self, x):
        """Map a low-resolution batch `x` to its 8x upscaled reconstruction."""
        # Encoder, keeping the three finest feature maps as skips.
        skip_full = self.enc1(x)
        skip_half = self.enc2(skip_full)
        skip_quarter = self.enc3(skip_half)
        feat = self.bottleneck(self.enc4(skip_quarter))

        # Decoder back to the input resolution, coarsest skip first.
        decoder_steps = (
            (self.up3, skip_quarter),
            (self.up2, skip_half),
            (self.up1, skip_full),
        )
        for up_block, skip in decoder_steps:
            feat = up_block(feat, skip)

        # Three learned 2x upsampling stages.
        for stage in (self.up_learn1, self.up_learn2, self.up_learn3):
            feat = stage(feat)

        feat = self.refine(self.refine_in(feat))
        residual = self.final_conv(feat)

        # Global skip: the network predicts a residual on top of the
        # bilinearly upsampled input.
        base = F.interpolate(x, size=residual.shape[2:], mode='bilinear', align_corners=False)
        return residual + base


### Loss Functions for Super-Resolution Training

This section defines multiple complementary loss functions used to guide the SR model:

1. **Pixel Loss (MSE)**  
   - Standard mean squared error between the super-resolved (`sr`) and high-resolution (`hr`) images.  
   - Encourages overall pixel-level accuracy.

2. **Perceptual Loss (VGG-based)**  
   - Extracts features from a pretrained VGG-16 network (layers conv1_2, conv2_2, conv3_3).  
   - Inputs are expected in `[-1, 1]` range; internally mapped to `[0, 1]` and normalized with ImageNet statistics.  
   - Computes the sum of MSE losses between corresponding SR and HR feature maps.  
   - Focuses on perceptual similarity and high-level structure rather than exact pixel match.

3. **Attention Loss (Masked MAE)**  
   - Uses a heatmap (e.g., from facial landmark detection) to weight pixel errors, focusing more on important regions (eyes, mouth, nose, etc.).  
   - Supports parameters:
     - `gamma`: Controls how strongly to emphasize hot areas.
     - `floor`: Minimum weight for pixels outside the mask to keep gradients stable.
   - Normalizes weights per sample so the loss is scale-invariant.

4. **LPIPS Loss**  
   - Learned perceptual similarity metric based on deep features (VGG backbone).  
   - Inputs in `[-1, 1]`; returns a scalar measuring perceptual distance between SR and HR.  
   - Complements MSE/VGG losses by capturing human-judged similarity.

**Usage:**  
These losses are typically combined with tuned weights (e.g., more weight on perceptual/attention losses for better visual quality, more on pixel loss for higher PSNR).


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16, VGG16_Weights
import lpips

# -------------------------
# Pixel loss
# -------------------------
# Mean squared error between the SR output and the HR target
# (nn.MSELoss with its default settings).
pixel_crit = nn.MSELoss()

# -------------------------
# Perceptual (VGG) loss
# -------------------------
class VGGPerceptualLoss(nn.Module):
    """
    Perceptual loss on frozen VGG-16 features.

    Expects inputs in [-1,1]. Internally maps to [0,1] and applies ImageNet mean/std.
    The VGG feature extraction runs in float32 even when the caller is inside an
    autocast region: autocast is disabled locally, because casting the inputs
    with ``.float()`` alone does not stop autocast from re-casting the
    convolutions to reduced precision.

    Returns the sum of the MSEs between SR and HR features taken after three
    consecutive slices of ``vgg16().features`` (layers [0:4], [4:9], [9:16]).
    """
    def __init__(self):
        super().__init__()
        vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES).features
        self.slice1 = nn.Sequential(*list(vgg[:4])).eval()   # conv1_2
        self.slice2 = nn.Sequential(*list(vgg[4:9])).eval()  # conv2_2
        self.slice3 = nn.Sequential(*list(vgg[9:16])).eval() # conv3_3
        # The feature extractor is a fixed reference: never train it.
        for m in (self.slice1, self.slice2, self.slice3):
            for p in m.parameters():
                p.requires_grad = False

        # ImageNet norm buffers (move with the module across devices)
        self.register_buffer('mean', torch.tensor([0.485,0.456,0.406]).view(1,3,1,1))
        self.register_buffer('std',  torch.tensor([0.229,0.224,0.225]).view(1,3,1,1))

    def _prep(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` from [-1,1] to ImageNet-normalized float32 (values are clamped first)."""
        # Cast before the arithmetic so clamping/normalization is not done in half precision.
        x01 = (x.float().clamp(-1, 1) + 1) / 2
        return (x01 - self.mean) / self.std

    def forward(self, sr: torch.Tensor, hr: torch.Tensor) -> torch.Tensor:
        """Return the summed feature-space MSE between ``sr`` and ``hr``."""
        # Force fp32 for the VGG path even if the outer training is mixed precision.
        with torch.autocast(device_type=sr.device.type, enabled=False):
            sr32 = self._prep(sr)
            hr32 = self._prep(hr)

            # Each slice consumes the output of the previous one.
            f1_sr, f1_hr = self.slice1(sr32), self.slice1(hr32)
            f2_sr, f2_hr = self.slice2(f1_sr), self.slice2(f1_hr)
            f3_sr, f3_hr = self.slice3(f2_sr), self.slice3(f2_hr)

            # Sum of MSEs across a few layers
            return (F.mse_loss(f1_sr, f1_hr) +
                    F.mse_loss(f2_sr, f2_hr) +
                    F.mse_loss(f3_sr, f3_hr))

# Select the compute device (GPU when CUDA is available) and build the
# perceptual criterion on it; its VGG parameters have requires_grad=False.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
perceptual_crit = VGGPerceptualLoss().to(device)

def attention_loss(
    sr,
    hr,
    heat,
    *,
    gamma: float = 1.3,
    floor: float = 0.10,
    eps: float = 1e-6,
):
    """
    Heatmap-weighted mean absolute error, normalized per sample.

    For every sample: loss_i = sum(w_i * |sr - hr|) / sum(w_i); the result is
    the mean of loss_i over the batch.

    Args:
        sr, hr: (B, C, H, W) tensors (documented range [-1, 1]).
        heat: (B, 1, H, W) or (B, H, W) heatmap, or None (returns a zero scalar).
        gamma: exponent applied to the normalized heatmap; values above 1
            concentrate the weight on the hottest regions.
        floor: weight assigned where the normalized heatmap is 0, so pixels
            outside the mask still contribute.
        eps: lower bound for the denominators and threshold below which a
            heatmap is treated as constant.
    """
    if heat is None:
        return sr.new_tensor(0.0)

    # Bring the heatmap to (B, 1, H, W) on the same device/dtype as the prediction.
    heat4d = heat.unsqueeze(1) if heat.dim() == 3 else heat
    heat4d = heat4d.to(device=sr.device, dtype=sr.dtype)
    batch = heat4d.size(0)

    # Per-sample min-max normalization to [0, 1].
    per_sample = heat4d.reshape(batch, -1)
    lowest = per_sample.min(dim=1, keepdim=True)[0].reshape(batch, 1, 1, 1)
    highest = per_sample.max(dim=1, keepdim=True)[0].reshape(batch, 1, 1, 1)
    value_range = highest - lowest
    normalized = (heat4d - lowest) / value_range.clamp_min(eps)

    # Sharpen only when gamma actually differs from 1.
    sharpen = abs(gamma - 1.0) > 1e-6
    if sharpen:
        normalized = normalized.clamp(0, 1).pow(gamma)

    # Affine map of the heatmap into [floor, 1].
    weights = floor + (1.0 - floor) * normalized

    # A (nearly) constant heatmap carries no information: use uniform weights of 1.
    is_flat = value_range <= eps
    weights = torch.where(is_flat, torch.ones_like(weights), weights)

    # The weights are constants for the optimizer: no gradient flows through them.
    weights = weights.expand_as(sr).detach()

    # Per-sample reduction (reshape works on the non-contiguous expanded view).
    weighted_error = weights * (sr - hr).abs()
    numerator = weighted_error.reshape(batch, -1).sum(dim=1)
    denominator = weights.reshape(batch, -1).sum(dim=1).clamp_min(eps)
    return (numerator / denominator).mean()


# -------------------------
# LPIPS loss
# -------------------------
# lpips expects inputs in [-1,1]; returns (B,1,1,1) or (B,)
# One shared LPIPS module with a VGG backbone, moved to `device` once at import time.
lpips_crit = lpips.LPIPS(net='vgg').to(device)

def lpips_loss(sr, hr):
    """Return the LPIPS distance between `sr` and `hr`, averaged to a scalar.

    Delegates to the module-level `lpips_crit`.
    """
    distances = lpips_crit(sr, hr)
    return distances.mean()

## `Trainer` Class Overview

The `Trainer` class handles **training**, **evaluation**, and **metric computation** for a **Super-Resolution U-Net** model with optional perceptual, attention, and LPIPS losses.
It supports **mixed-precision training** with `GradScaler` and computes multiple image quality metrics during training.

---

### **Initialization (`__init__`)**

* **Stores configuration** (`cfg`).
* Builds the **Generator network** (`SuperResolutionUNet`) with parameters from `cfg`.
* Creates an **Adam optimizer** for the generator.
* Sets up a **`GradScaler`** for mixed-precision training (enabled if CUDA is available).
* Defines `metric_stride` to control how often metrics are computed during training.
* Initializes **Multi-Scale SSIM** metric object (`MSSSIMMetric`) configured for images in `[-1, 1]`.

---

### **`psnr_from_mse(mse, max_val)`**

* Computes **Peak Signal-to-Noise Ratio (PSNR)** from a Mean Squared Error value, assuming a given pixel value range (`max_val=2.0` for `[-1, 1]`).

---

### **`compute_overall_metrics(sr, hr)`**

* Takes **super-resolved** (`sr`) and **high-resolution** (`hr`) tensors in `[-1, 1]`.
* Computes:

  * **PSNR** (per batch, averaged).
  * **SSIM** (single-scale).
  * **MS-SSIM** (multi-scale).
  * **LPIPS** (perceptual similarity).
* Returns all metrics in a dictionary.

---

### **`train_epoch(loader)`**

* Runs one full training epoch over the data `loader`.
* Steps:

  1. Sets model to **train** mode.
  2. Iterates over the dataloader:

     * Moves data to GPU.
     * **Forward pass** inside `autocast` for mixed precision:

       * Computes pixel loss (**MSE**).
       * Optionally computes **perceptual loss**, **attention loss**, **LPIPS loss** depending on `cfg`.
       * Combines them using configured weights (`w_perc`, `w_attn`, `w_lpips`).
     * **Backward pass** using gradient scaling.
     * Updates optimizer.
  3. Every `metric_stride` steps, computes and accumulates **overall quality metrics**.
  4. Returns **average losses and metrics** over the epoch.

---

### **`evaluate(loader, num_samples)`**

* Runs evaluation (no gradients) on at most `num_samples` batches.
* Steps:

  1. Sets model to **eval** mode.
  2. Iterates over the loader:

     * Computes `sr` outputs.
     * Measures overall metrics (`psnr_overall`, `ssim_overall`, `msssim_overall`, `lpips_overall`).
     * Computes all active loss components and the combined loss.
  3. Returns **average losses and metrics** over evaluated samples.

In [None]:
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torchmetrics.functional import structural_similarity_index_measure as ssim
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure as MSSSIMMetric


class Trainer:
    """
    Training / evaluation driver for `SuperResolutionUNet`.

    `cfg` is a dict. Required keys: ``'lr_g'`` (Adam learning rate) and
    ``'name'`` (used in progress-bar labels). Optional keys:
      - model: ``in_channels``, ``base_filters``, ``out_channels``, ``refine_blocks``
      - loss switches: ``use_perc``, ``use_attn``, ``use_lpips`` (default False)
      - loss weights: ``w_perc``, ``w_attn``, ``w_lpips`` (default 0.0)

    The weights are read from `cfg` on every batch, so they can be changed
    between epochs by mutating the dict.
    """

    # Keys of the per-epoch result dicts, in reporting order.
    _LOSS_KEYS = ('loss_pixel', 'loss_perc', 'loss_attn', 'loss_lpips', 'loss_combined')
    _METRIC_KEYS = ('psnr_overall', 'ssim_overall', 'msssim_overall', 'lpips_overall')

    def __init__(self, cfg):
        self.cfg    = cfg
        self.G = SuperResolutionUNet(
            in_channels   = cfg.get('in_channels', 3),
            base_filters  = cfg.get('base_filters', 32),
            out_channels  = cfg.get('out_channels', 3),
            refine_blocks = cfg.get('refine_blocks', 3)
        ).to(device)
        self.optG   = torch.optim.Adam(self.G.parameters(), lr=cfg['lr_g'])
        # Gradient scaling is only active when CUDA (and thus fp16 autocast) is available.
        self.scaler = GradScaler(enabled=torch.cuda.is_available())
        self.metric_stride = 5  # compute metrics every N batches
        self.msssim_metric = MSSSIMMetric(
            data_range=2.0,
            kernel_size=(7, 7),                         # smaller window for 128×128
            betas=(0.0448, 0.2856, 0.3001, 0.2363),     # 4 scales
            normalize="relu",
        ).to(device)

    @staticmethod
    def psnr_from_mse(mse, max_val=2.0):  # max_val=2.0 for [-1,1]
        """PSNR in dB for a given MSE tensor and signal range `max_val`."""
        # The small constant keeps the ratio finite when mse is 0.
        return 10.0 * torch.log10((max_val * max_val) / (mse + 1e-8))

    def compute_overall_metrics(self, sr, hr):
        """Overall PSNR/SSIM/MS-SSIM/LPIPS with inputs in [-1,1].

        Returns a dict of Python floats keyed by ``psnr_overall``,
        ``ssim_overall``, ``msssim_overall`` and ``lpips_overall``.
        """
        with torch.no_grad():
            # Metrics are always computed in fp32 so a half-precision `sr`
            # coming out of an autocast region cannot mismatch `hr`.
            sr = sr.float()
            hr = hr.float()
            mse_overall = ((sr - hr) ** 2).mean(dim=(1,2,3))
            psnr_overall   = self.psnr_from_mse(mse_overall, max_val=2.0).mean().item()
            ssim_overall   = float(ssim(sr, hr, data_range=2.0))
            msssim_overall = float(self.msssim_metric(sr, hr).cpu())
            # Drop the state accumulated by the metric object: only the
            # per-batch value returned above is used.
            self.msssim_metric.reset()
            lpips_overall  = lpips_loss(sr, hr)
            if torch.is_tensor(lpips_overall):
                lpips_overall = lpips_overall.mean().item()
            return {
                'psnr_overall': psnr_overall,
                'ssim_overall': ssim_overall,
                'msssim_overall': msssim_overall,
                'lpips_overall': lpips_overall
            }

    def _weighted_total(self, Lpix, Lperc, Lattn, Llp):
        """Pixel loss plus each enabled auxiliary loss times its `cfg` weight.

        Auxiliary terms that are not tensors (the ``0.0`` placeholder of a
        disabled loss) contribute nothing.
        """
        total = Lpix
        for weight_key, term in (('w_perc', Lperc), ('w_attn', Lattn), ('w_lpips', Llp)):
            if isinstance(term, torch.Tensor):
                total = total + self.cfg.get(weight_key, 0.0) * term
        return total

    def _compute_losses(self, sr, hr, heat):
        """Evaluate every loss enabled in `cfg`; disabled ones are the float 0.0."""
        cfg = self.cfg
        Lpix  = pixel_crit(sr, hr)
        Lperc = perceptual_crit(sr, hr)      if cfg.get('use_perc',  False) else 0.0
        Lattn = attention_loss(sr, hr, heat) if cfg.get('use_attn',  False) else 0.0
        Llp   = lpips_loss(sr, hr)           if cfg.get('use_lpips', False) else 0.0
        return {
            'loss_pixel': Lpix,
            'loss_perc': Lperc,
            'loss_attn': Lattn,
            'loss_lpips': Llp,
            'loss_combined': self._weighted_total(Lpix, Lperc, Lattn, Llp),
        }

    @staticmethod
    def _as_float(value):
        """Convert a scalar loss (tensor or plain number) to a Python float."""
        if torch.is_tensor(value):
            return float(value.detach())
        return float(value)

    @staticmethod
    def _averages(agg, loss_steps, metric_steps):
        """Divide accumulated losses / metrics by their step counts.

        A count of 0 is treated as 1 so an empty loader yields zeros instead of
        raising ZeroDivisionError.
        """
        loss_div = max(loss_steps, 1)
        metric_div = max(metric_steps, 1)
        out = {k: agg[k] / loss_div for k in Trainer._LOSS_KEYS}
        out.update({k: agg[k] / metric_div for k in Trainer._METRIC_KEYS})
        return out

    def train_epoch(self, loader):
        """Run one optimization pass over `loader`.

        `loader` must yield ``(x2, x4, hr, lr, heat)`` batches; only ``hr``,
        ``lr`` and (when ``use_attn`` is set) ``heat`` are used. Quality metrics
        are sampled every `metric_stride` batches. Returns the per-batch
        averages of all losses and of the sampled metrics.
        """
        agg = dict.fromkeys(self._LOSS_KEYS + self._METRIC_KEYS, 0.0)
        self.G.train()
        step = 0
        metric_steps = 0
        use_cuda = torch.cuda.is_available()
        use_attn = self.cfg.get('use_attn', False)

        for _, _, hr, lr, heat in tqdm(loader, desc=f"Training {self.cfg['name']}"):
            lr   = lr.to(device, non_blocking=True)
            hr   = hr.to(device, non_blocking=True)
            heat = heat.to(device, non_blocking=True) if use_attn else None

            self.optG.zero_grad(set_to_none=True)
            with autocast(device_type='cuda', enabled=use_cuda, dtype=torch.float16):
                sr = self.G(lr)
                losses = self._compute_losses(sr, hr, heat)

            # Scaled backward + step; the scaler is a no-op when disabled.
            self.scaler.scale(losses['loss_combined']).backward()
            self.scaler.step(self.optG)
            self.scaler.update()

            for k in self._LOSS_KEYS:
                agg[k] += self._as_float(losses[k])

            if step % self.metric_stride == 0:
                m = self.compute_overall_metrics(sr.detach(), hr)
                for k in self._METRIC_KEYS:
                    agg[k] += m[k]
                metric_steps += 1

            step += 1

        return self._averages(agg, step, metric_steps)

    def evaluate(self, loader, num_samples=500):
        """Evaluate without gradients on at most `num_samples` batches of `loader`.

        Note that `num_samples` counts batches, not individual images. Returns
        the per-batch averages of all losses and metrics.
        """
        self.G.eval()
        agg = dict.fromkeys(self._LOSS_KEYS + self._METRIC_KEYS, 0.0)
        use_attn = self.cfg.get('use_attn', False)
        n = 0
        with torch.no_grad():
            for _, _, hr, lr, heat in tqdm(loader, desc=f"Evaluating {self.cfg['name']}"):
                if n >= num_samples:
                    break
                lr   = lr.to(device, non_blocking=True)
                hr   = hr.to(device, non_blocking=True)
                heat = heat.to(device, non_blocking=True) if use_attn else None

                sr = self.G(lr)

                m = self.compute_overall_metrics(sr, hr)
                for k in self._METRIC_KEYS:
                    agg[k] += m[k]

                losses = self._compute_losses(sr, hr, heat)
                for k in self._LOSS_KEYS:
                    agg[k] += self._as_float(losses[k])

                n += 1

        return self._averages(agg, n, n)


# Training Script Overview

This script wires up datasets, a `Trainer`, schedulers, and utilities to **train/evaluate** a Super-Resolution U-Net with optional losses (pixel/perceptual/LPIPS/attention), prints metrics (incl. **MS-SSIM**), **adapts loss weights over time**, and handles **checkpointing + early stop**.

---

## 1) Setup & Reproducibility

* Selects `device` (CUDA if available) and fixes seeds (`torch`, `numpy`) with deterministic cuDNN.
* Optional pretrained checkpoint path: `ckpt_path`.

---

## 2) Data, Limits, and Loaders

* Reads the number of available heatmaps from `heatmaps.h5` → `HM_COUNT`.
* Sets **training limit** to `HM_COUNT` and **validation limit** to 20% of that.
* Builds `CelebDataSet` for `'train'` and `'val'`, clamps subset sizes to the limits, and wraps them in `DataLoader`s (batch size 64, 2 workers, `pin_memory` if CUDA).
* Prepares a **fixed sample** (`lr_vis`, `hr_vis`, `heat_vis`) from the first training batch for quick visualization.

---

## 3) Model Config(s) → Trainers

* Defines one config (ablation example):

  * `name='ABLATION_30K_mse+vgg+lpips'`
  * `use_attn=False`, `use_perc=True`, `use_lpips=True`
  * loss weights initialized: `w_attn=0`, `w_perc=0.02`, `w_lpips=0.15`
  * optimizer LR `lr_g=1e-4`
  * UNet capacity: `base_filters=48`, `refine_blocks=5`
* For each config:

  * Instantiates a `Trainer(cfg)` (your Trainer class builds the SR-UNet, optimizer, AMP scaler, metrics).
  * Attaches **ReduceLROnPlateau** (`mode='min'`, `patience=6`, `factor=0.5`) targeting a **validation objective** (see §6).

---

## 4) Utilities

### LossWeightScheduler

Time-based weights over the course of training (`epoch/num_epochs`):

* **0–30% (warm-up):** ramp `w_perc` 0→0.06, `w_attn` 0.7→1.2, `w_lpips` off.
* **30–85% (mid):** `w_perc` 0.06→0.09, `w_attn` 1.2→2.0, `w_lpips` 0.00→0.22.
* **85–100% (late):** `w_perc` fixed 0.09, `w_attn` 2.0→1.0, `w_lpips` 0.22→0.25.
  (If a loss is disabled in `cfg`, its weight is forced to 0.)

### EarlyStopping

* Tracks best **validation objective**; stops if it hasn’t improved by `min_delta=1e-4` for `patience=15` epochs.

### Validation Objective (`val_objective`)

Scalar to **minimize** (lower is better):

```
1.5*LPIPS + 0.8*max(0, 0.02 - SSIM) + 0.2*max(0, 20 - PSNR)
```

Penalizes high LPIPS, SSIM below 0.02, and PSNR below 20 dB.

### Misc

* `to01(x)`: maps tensors from `[-1,1]` → `[0,1]` for visualization.
* `save_checkpoint(...)`: saves a **full** training state dict (model/optimizer/scaler/scheduler + early-best).

---

## 5) (Optional) Checkpoint Restore

If `ckpt_path` is set:

* Loads `model/opt/scaler/sched` states into the trainer.
* Restores `early_stop.best` (best validation objective so far).
* Sets `start_epoch = saved_epoch + 1`.
  Prints a confirmation.

---

## 6) Epoch Loop

For `epoch in start_epoch .. num_epochs`:

### a) Update Loss Weights

* Calls `LossWeightScheduler.weights_at(...)` and writes `w_perc/w_attn/w_lpips` into each `cfg`.

### b) TRAIN

* For each config:

  * `train_epoch(loader)` on its `Trainer`:

    * Mixed precision + gradient scaling.
    * Computes the enabled losses and their **weighted sum**.
    * Periodically (every `metric_stride` steps from the `Trainer`) accumulates **PSNR / SSIM / MS-SSIM / LPIPS** on the training stream.
  * Appends per-epoch training metrics to `history`.

* Prints a **training table** per model with:

  * `pix`, `perc`, `attn`, `lpips_loss`, `comb`
  * `PSNR`, `SSIM`, `MS-SSIM`, `LPIPS`

### c) VALIDATION

* For each config:

  * `evaluate(val_loader)` (no grad, averages over the set).
  * Appends validation metrics to `history` (prefixed with `val_`).

* Computes the **reference** validation objective on the **first config** (`ref_name = configs[0]['name']`), then:

  * Steps each **ReduceLROnPlateau** with that model’s own validation objective.
  * Steps **EarlyStopping** with the reference objective.

* Prints a **validation table** with the same metrics plus `val_obj`.

### d) Quick Visualization

* Grabs 5 examples from the validation loader.
* Shows: **upsampled LR**, **HR GT**, and **each model’s reconstruction** (mapped to `[0,1]`).

### e) Save Best & Periodic Checkpoints

* **Best (by val objective, first config):** saves `./<ref_name>/<ref_name>_best.pth` (weights only).
* **Every 10 epochs:** saves a **full checkpoint** (`.pt`) for each config via `save_checkpoint(...)` and triggers `files.download(path)` (handy in Colab).

### f) Early Stop

* If `early_stop.should_stop` is set, prints a message with the **best val objective** and breaks.

In [None]:
# --- core imports
import os, h5py, numpy as np
import torch
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from tabulate import tabulate
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
import h5py
from google.colab import files


# --- device & seeds
# Run on the GPU when one is available; `pin` is passed as `pin_memory` to the DataLoaders below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pin = torch.cuda.is_available()

# Seed PyTorch (CPU and, when present, CUDA) and NumPy, and configure cuDNN for
# deterministic behaviour (no autotuner benchmarking).
# NOTE(review): Python's built-in `random` module is not seeded here.
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- paths & basic hyperparams
data_path  = path            # CelebA root; `path` is the kagglehub download location from the download cell
heat_h5    = './heatmaps.h5' # heatmaps file fetched with curl in the download cell
batch_size = 64
num_epochs = 200
ckpt_path = None
# To resume training from a saved checkpoint, set `ckpt_path` to the checkpoint file
# (e.g. by uncommenting the line below); the resume logic further down runs only when it is not None.
#ckpt_path = "./PATH_TO_CHECKPOINT.pt"

# Read how many heatmaps the HDF5 file holds (first dimension of the 'heatmaps' dataset).
with h5py.File(heat_h5, 'r') as f:
    HM_COUNT = int(f['heatmaps'].shape[0])
print("Heatmaps available:", HM_COUNT)

TRAIN_LIMIT = HM_COUNT                        # cap the training set at the number of available heatmaps
VAL_LIMIT   = max(1, int(0.2 * TRAIN_LIMIT))  # validation cap: 20% of the training cap, at least 1 (e.g. 10k train -> 2000 val)

# --- build datasets
train_ds = CelebDataSet(data_path, 'train', heatmap_h5=heat_h5)
val_ds   = CelebDataSet(data_path, 'val',   heatmap_h5=heat_h5)

# Never request more samples than the datasets actually contain.
train_n = min(TRAIN_LIMIT, len(train_ds))
val_n   = min(VAL_LIMIT,   len(val_ds))

print(f"Using TRAIN_LIMIT={train_n}, VAL_LIMIT={val_n}")

# Use the first `train_n` / `val_n` items of each split.
train_subset = Subset(train_ds, range(train_n))
val_subset   = Subset(val_ds,   range(val_n))

# --- dataloaders
# Training batches are shuffled; validation order is fixed.
loader     = DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                        num_workers=2, pin_memory=pin)
val_loader = DataLoader(val_subset,   batch_size=batch_size, shuffle=False,
                        num_workers=2, pin_memory=pin)

print("Train samples:", len(train_subset))
print("Val samples:  ", len(val_subset))

# --- one fixed batch for viz
# Takes the first sample of one batch drawn from the (shuffled) TRAINING loader and moves it to `device`.
# NOTE(review): `lr_vis` / `hr_vis` / `heat_vis` are not referenced anywhere else in this cell (the loop
# below visualizes samples from `val_loader` instead); drawing this batch still advances the RNG used for shuffling.
data_iter = iter(loader)
_, _, hr_f, lr_f, heat_f = next(data_iter)
lr_vis, hr_vis, heat_vis = [t.to(device, non_blocking=True) for t in (lr_f[0:1], hr_f[0:1], heat_f[0:1])]

# --- experiment configurations (a single entry for now; append more dicts to compare variants)
configs = [
    # Reference configuration (noted as the one that worked best); use it as the baseline for changes.
    {
        'name': 'base_on_10k',
        'use_attn': True, 'use_perc': True, 'use_lpips': True,
        'w_attn': 2.0, 'w_perc': 0.02, 'w_lpips': 0.15,
        'lr_g': 1e-4,
        'base_filters': 48,
        'refine_blocks': 5,
    },
]

# Normalize every config so later code can rely on all keys being present:
#   * missing loss weights default to 0.0,
#   * `_base_<weight>` remembers the configured target weight (the scheduler rescales from it),
#   * missing `use_*` flags default to False and `schedule` defaults to True.
for cfg in configs:
    for weight_key in ('w_perc', 'w_attn', 'w_lpips'):
        if weight_key not in cfg:
            cfg[weight_key] = 0.0
        base_key = f'_base_{weight_key}'
        if base_key not in cfg:
            cfg[base_key] = cfg[weight_key]
    for flag_key, flag_default in (('use_perc', False), ('use_attn', False),
                                   ('use_lpips', False), ('schedule', True)):
        cfg.setdefault(flag_key, flag_default)

# --- build trainers from configs
def _make_plateau_scheduler(optimizer):
    """Create the ReduceLROnPlateau scheduler used for every trainer.

    The scheduler halves the learning rate (factor=0.5) after 6 epochs without
    improvement of the monitored value (mode='min').

    `verbose=True` is requested first so that PyTorch versions that still support
    it keep printing LR reductions. The argument is deprecated and believed to be
    rejected by newer PyTorch releases, so on a TypeError the scheduler is rebuilt
    without it.
    """
    plateau_kwargs = dict(mode='min', patience=6, factor=0.5)
    try:
        return ReduceLROnPlateau(optimizer, verbose=True, **plateau_kwargs)
    except TypeError:
        return ReduceLROnPlateau(optimizer, **plateau_kwargs)


# One Trainer per config, keyed by the config name.
trainers = {}
for cfg in configs:
    t = Trainer(cfg)
    # Optional per-config override of how often the trainer accumulates metrics.
    if 'metric_stride' in cfg:
        t.metric_stride = cfg['metric_stride']
    # The training loop steps this scheduler with the model's validation objective.
    t.lr_sched = _make_plateau_scheduler(t.optG)
    trainers[cfg['name']] = t

# ==== helpers ===============================================================
# Dynamically adjusts the weights of the perceptual, attention, and LPIPS losses
# throughout training based on the current epoch fraction t = epoch / num_epochs.
#
# The scheduler produces SCALE factors that multiply each config's base target weights
# (`_base_w_perc`, `_base_w_attn`, `_base_w_lpips`). The schedule has three phases:
#   1. Warm-up (t <= 0.30): the perceptual scale ramps linearly 0 -> 1.00 and the attention
#      scale ramps 0.60 -> 1.00; the LPIPS loss stays off (scale 0).
#   2. Mid-phase (0.30 < t <= 0.85): perceptual and attention stay at full scale (1.00)
#      while the LPIPS scale ramps linearly 0 -> 1.00.
#   3. Late-phase (t > 0.85): perceptual and LPIPS stay at full scale; the attention
#      scale decays linearly from 1.00 to 0.75.
#
# If a given loss type is disabled in `cfg`, its weight is set to 0 for the whole schedule.
# If `cfg['schedule']` is False, the weights in `cfg` are left untouched.
class LossWeightScheduler:
    """Epoch-dependent scaling of the perceptual / attention / LPIPS loss weights.

    The schedule is expressed as scale factors over the training progress
    ``t = epoch / num_epochs``; `apply` multiplies them by the base weights
    stored in a config under ``_base_w_perc`` / ``_base_w_attn`` / ``_base_w_lpips``.
    """

    def __init__(self, num_epochs):
        # Total number of epochs; used as the denominator of the progress fraction.
        self.E = num_epochs

    def scales_at(self, epoch, cfg):
        """Return the scale factors ``(perc, attn, lpips)`` for `epoch`.

        These are multipliers of the base target weights, NOT absolute weights.
        `cfg` is accepted for interface symmetry with `apply` but is not read.
        """
        progress = epoch / self.E  # fraction of training completed

        # Phase 1 (t <= 0.30): perc 0 -> 1.00, attn 0.60 -> 1.00, LPIPS off.
        if progress <= 0.30:
            ramp = progress / 0.30
            return ramp, 0.60 + 0.40 * ramp, 0.00

        # Phase 2 (0.30 < t <= 0.85): perc/attn at full target, LPIPS 0 -> 1.00.
        if progress <= 0.85:
            ramp = (progress - 0.30) / 0.55
            return 1.00, 1.00, ramp

        # Phase 3 (t > 0.85): attn decays to 75% of its target; perc/LPIPS at full target.
        ramp = (progress - 0.85) / 0.15
        return 1.00, 1.00 - 0.25 * ramp, 1.00

    def apply(self, epoch, cfg):
        """Write the scheduled ``w_perc`` / ``w_attn`` / ``w_lpips`` into `cfg` in place.

        Does nothing when ``cfg['schedule']`` is falsy. A loss whose ``use_*`` flag
        is off (or missing) gets weight 0.0; a missing base weight counts as 0.0.
        """
        if not cfg.get('schedule', True):
            # Scheduling disabled for this config: leave its weights untouched.
            return

        loss_terms = ('perc', 'attn', 'lpips')
        for term, scale in zip(loss_terms, self.scales_at(epoch, cfg)):
            enabled = cfg.get(f'use_{term}', False)
            base_weight = cfg.get(f'_base_w_{term}', 0.0)
            cfg[f'w_{term}'] = (base_weight * scale) if enabled else 0.0


# Monitors a validation score and triggers early stopping when it stops improving.
#
# Parameters:
#   patience  – number of consecutive epochs without significant improvement
#               (greater than `min_delta`) before stopping.
#   min_delta – minimum required improvement in the monitored score to be considered progress.
#
# Behavior:
#   - Tracks the best (lowest) score seen so far.
#   - Resets the bad epoch counter when improvement is detected.
#   - Increments the bad epoch counter otherwise.
#   - Sets `should_stop=True` when the bad epoch counter reaches `patience`.

class EarlyStopping:
    """Track the lowest score seen and flag when it has stopped improving.

    Attributes:
        patience:    number of consecutive non-improving steps that triggers the stop flag.
        min_delta:   a score must be lower than ``best - min_delta`` to count as an improvement.
        best:        best (lowest) score so far, or None before the first step.
        bad_epochs:  consecutive steps without improvement.
        should_stop: set to True once ``bad_epochs`` reaches ``patience``; never reset here.
    """

    def __init__(self, patience=15, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.bad_epochs = 0
        self.should_stop = False

    def step(self, score):
        """Record one new score and update the counters/flag accordingly."""
        improved = self.best is None or score < self.best - self.min_delta
        if improved:
            self.best = score
            self.bad_epochs = 0
            return

        self.bad_epochs += 1
        if self.bad_epochs >= self.patience:
            self.should_stop = True

# Computes a scalar "validation objective" score to guide model selection,
# learning rate scheduling, and early stopping. Lower is better.
#
# The score is a weighted sum of penalties:
#   - 1.5 * LPIPS: strong penalty for poor perceptual similarity (higher LPIPS).
#   - 0.8 * max(0, 0.02 - SSIM): penalty if SSIM falls below 0.02 (no penalty otherwise).
#   - 0.2 * max(0, 20.0 - PSNR): penalty if PSNR is below 20 dB (no penalty otherwise).
#
# Inputs:
#   m – dictionary of validation metrics containing:
#       'ssim_overall', 'lpips_overall', 'psnr_overall'
#
# Output:
#   A single float representing the validation objective; smaller values indicate better quality.
def val_objective(m):
    """Collapse validation metrics into one scalar (lower is better).

    The score is ``1.5 * LPIPS`` plus hinge penalties ``0.8 * max(0, 0.02 - SSIM)``
    and ``0.2 * max(0, 20.0 - PSNR)``.

    Args:
        m: mapping that may contain 'lpips_overall', 'ssim_overall' and
           'psnr_overall'. Missing entries fall back to the worst-case
           defaults 1.0, 0.0 and 0.0 respectively.

    Returns:
        The objective as a float.
    """
    lpips_value = float(m.get('lpips_overall', 1.0))
    ssim_value = float(m.get('ssim_overall', 0.0))
    psnr_value = float(m.get('psnr_overall', 0.0))

    lpips_term = 1.5 * lpips_value
    ssim_term = 0.8 * max(0.0, 0.02 - ssim_value)   # active only while SSIM < 0.02
    psnr_term = 0.2 * max(0.0, 20.0 - psnr_value)   # active only while PSNR < 20 dB
    return lpips_term + ssim_term + psnr_term

def to01(x):
    """Clamp `x` to [-1, 1] and rescale it linearly to [0, 1]."""
    clipped = x.clamp(min=-1, max=1)
    return clipped.add(1).div(2)

def save_checkpoint(trainer, history, epoch, best_score, path):
    """Serialize a resumable training checkpoint to `path` with ``torch.save``.

    Args:
        trainer:    object exposing ``G``, ``optG`` and ``scaler`` (each with
                    ``state_dict()``) and, optionally, ``lr_sched``.
        history:    metric history, stored under 'history'.
        epoch:      epoch number, stored under 'epoch'.
        best_score: best validation objective so far, stored under 'early_best'.
        path:       destination file.

    The scheduler state is stored under 'sched', or None when the trainer has
    no ``lr_sched`` attribute.
    """
    payload = {'epoch': epoch}
    payload['model'] = trainer.G.state_dict()
    payload['opt'] = trainer.optG.state_dict()
    payload['scaler'] = trainer.scaler.state_dict()
    if hasattr(trainer, 'lr_sched'):
        payload['sched'] = trainer.lr_sched.state_dict()
    else:
        payload['sched'] = None
    payload['early_best'] = best_score
    payload['history'] = history
    torch.save(payload, path)

# ===========================================================================


# Fresh training state; everything below is overwritten when resuming from a checkpoint.
loss_sched  = LossWeightScheduler(num_epochs)
early_stop  = EarlyStopping(patience=15, min_delta=1e-4)
history     = {}
start_epoch = 1

# Optional resume: restore the FIRST config's trainer from `ckpt_path`
# (a file in the format written by `save_checkpoint`).
if ckpt_path is not None:
    name = configs[0]['name']
    ckpt = torch.load(ckpt_path, map_location=device)

    resumed_trainer = trainers[name]
    resumed_trainer.G.load_state_dict(ckpt['model'])
    resumed_trainer.optG.load_state_dict(ckpt['opt'])
    resumed_trainer.scaler.load_state_dict(ckpt['scaler'])
    sched_state = ckpt.get('sched')
    if sched_state:
        resumed_trainer.lr_sched.load_state_dict(sched_state)
    print("Loaded configuration from", ckpt_path)

    # Carry over the best validation objective so best-model saving and early
    # stopping continue from where the previous run left off.
    early_stop.best = ckpt.get('early_best', float('inf'))
    resumed_trainer._best_score = early_stop.best

    # Continue with the epoch after the saved one, keeping the recorded metric history.
    start_epoch = ckpt.get('epoch', 0) + 1
    history = ckpt.get('history', {})

last_round_epoch = 1     # epoch of the most recent periodic checkpoint (1 until one is written)
last_ckpt_path   = None  # path of the most recent periodic checkpoint written in THIS run

METRIC_HEADERS = ["Model","pix","perc","attn","lpips_loss","comb","PSNR","SSIM","MS-SSIM","LPIPS"]

# The first config is the reference model: its validation objective drives
# early stopping and decides when the "best" weights are saved.
ref_name  = configs[0]['name']
best_dir  = f"./{ref_name}"
best_path = os.path.join(best_dir, f"{ref_name}_best.pth")
os.makedirs(best_dir, exist_ok=True)


def _metrics_row(cfg, m, with_objective=False):
    """Format one table row for the model described by `cfg` from the metrics dict `m`.

    Losses that are disabled in `cfg` are shown as "-". When `with_objective`
    is True, the validation objective of `m` is appended as the last column.
    """
    row = [
        cfg['name'],
        f"{m['loss_pixel']:.4e}",
        f"{m['loss_perc']:.4e}"    if cfg.get('use_perc', False)  else "-",
        f"{m['loss_attn']:.4e}"    if cfg.get('use_attn', False)  else "-",
        f"{m['loss_lpips']:.4e}"   if cfg.get('use_lpips', False) else "-",
        f"{m['loss_combined']:.4e}",
        f"{m['psnr_overall']:.2f}",
        f"{m['ssim_overall']:.4f}",
        f"{m['msssim_overall']:.4f}",
        f"{m['lpips_overall']:.4f}",
    ]
    if with_objective:
        row.append(f"{val_objective(m):.4f}")
    return row


def _show_reconstructions(suptitle, label_suffix="", max_rows=5):
    """Plot upsampled LR, HR ground truth and every model's output for a few validation samples.

    Uses up to `max_rows` samples from the first batch of `val_loader`; fewer rows
    are drawn when that batch is smaller. `label_suffix` is appended to each
    model's column title. Every generator is switched to eval mode for the
    forward pass.
    """
    _, _, hr_batch, lr_batch, _ = next(iter(val_loader))
    lr_val = lr_batch[:max_rows].to(device)
    hr_val = hr_batch[:max_rows].cpu()
    n_rows = lr_val.shape[0]

    lr_up_val = to01(F.interpolate(lr_val, size=(128,128), mode='bilinear', align_corners=False)).cpu()
    recon_val = {}
    with torch.no_grad():
        for cfg in configs:
            nm = cfg['name']
            recon_val[nm] = to01(trainers[nm].G.eval()(lr_val)).cpu()

    n_cols = 2 + len(configs)
    # squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False)
    for i in range(n_rows):
        row = axes[i]
        row[0].imshow(lr_up_val[i].permute(1,2,0)); row[0].set_title("LR ↑"); row[0].axis('off')
        row[1].imshow(to01(hr_val[i]).permute(1,2,0)); row[1].set_title("HR GT"); row[1].axis('off')
        for j, cfg in enumerate(configs, start=2):
            nm = cfg['name']
            row[j].imshow(recon_val[nm][i].permute(1,2,0)); row[j].set_title(f"{nm}{label_suffix}"); row[j].axis('off')
    plt.suptitle(suptitle, fontsize=16)
    plt.tight_layout(); plt.show()


# TRAINING LOOP
for epoch in range(start_epoch, num_epochs+1):
    print(f"\n=== Epoch {epoch}/{num_epochs} ===")

    # update loss weights
    for cfg in configs:
        loss_sched.apply(epoch, cfg)

    # ---- TRAIN ----
    epoch_metrics = {}
    for cfg in configs:
        name    = cfg['name']
        metrics = trainers[name].train_epoch(loader)
        epoch_metrics[name] = metrics

        # history (train)
        hist = history.setdefault(name, {})
        for k, v in metrics.items():
            hist.setdefault(k, []).append(v)

    # print train metrics
    table = [_metrics_row(cfg, epoch_metrics[cfg['name']]) for cfg in configs]
    print(tabulate(table, headers=METRIC_HEADERS, tablefmt="github"))

    # ---- VALIDATION ----
    val_metrics = {}
    for cfg in configs:
        name = cfg['name']
        vm = trainers[name].evaluate(val_loader)
        val_metrics[name] = vm

        # history (val) — stored next to the train metrics under a 'val_' prefix
        hist = history[name]
        for k, v in vm.items():
            hist.setdefault('val_' + k, []).append(v)

    score = val_objective(val_metrics[ref_name])

    # schedulers + early stop: each LR scheduler follows its own model's objective,
    # early stopping follows the reference model's objective.
    for cfg in configs:
        trainers[cfg['name']].lr_sched.step(val_objective(val_metrics[cfg['name']]))
    early_stop.step(score)

    # print val metrics (same columns plus the validation objective)
    table = [_metrics_row(cfg, val_metrics[cfg['name']], with_objective=True) for cfg in configs]
    print("\n")
    print(tabulate(table, headers=METRIC_HEADERS + ["val_obj"], tablefmt="github"))

    # quick viz every 5 epochs
    if epoch % 5 == 0:
        _show_reconstructions(f"Epoch {epoch} — Validation Reconstructions")

    # save best weights of the reference model (weights only)
    if score <= getattr(trainers[ref_name], "_best_score", float("inf")):
        torch.save(trainers[ref_name].G.state_dict(), best_path)
        trainers[ref_name]._best_score = score

    # periodic full checkpoints for every config
    if epoch % 10 == 0:
        last_round_epoch = epoch
        for cfg in configs:
            name   = cfg['name']
            folder = f"./{name}"; os.makedirs(folder, exist_ok=True)
            path_to_save   =  f"./{name}/{name}_epoch{epoch:03d}.pt"
            save_checkpoint(trainers[name], history, epoch, early_stop.best,path_to_save)
            last_ckpt_path = path_to_save

    if early_stop.should_stop:
        print(f"\nEarly stopping at epoch {epoch} (best val objective: {early_stop.best:.4f}).")

        # --- load the best weights of the reference model ---
        # On a resumed run the file may be missing if no improvement was saved in this session.
        if os.path.exists(best_path):
            best_ckpt = torch.load(best_path, map_location=device)
            trainers[ref_name].G.load_state_dict(best_ckpt)
        else:
            print(f"Best-weights file {best_path} not found; reporting the current weights instead.")

        # --- recompute validation metrics for the best model ---
        best_metrics = trainers[ref_name].evaluate(val_loader)

        print("\n=== Best Model Validation Metrics ===")
        print(tabulate([_metrics_row(configs[0], best_metrics, with_objective=True)],
                       headers=METRIC_HEADERS + ["val_obj"], tablefmt="github"))

        # --- visualize best model ---
        _show_reconstructions("Best Model — Validation Reconstructions", label_suffix=" (best)")

        break


# Offer the most recent periodic checkpoint for download (Colab). Periodic checkpoints
# are only written every 10 epochs, so a run that stops earlier has none; the download
# is skipped in that case instead of failing on a missing file.
if last_ckpt_path is not None and os.path.exists(last_ckpt_path):
    path_to_save = last_ckpt_path
    files.download(path_to_save)
else:
    print("No periodic checkpoint was written during this run; skipping download.")

| Column          | Meaning                                                                                     | What It Tells You                                                                                                                                  |
| --------------- | ------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Model**       | Name of the model/config used in training.                                                  | Helps identify which configuration the metrics refer to.                                                                                           |
| **pix**         | Pixel-wise reconstruction loss (typically MSE) between SR output and ground truth HR image. | Lower is better; measures raw pixel accuracy but doesn’t account for perceptual quality.                                                           |
| **perc**        | Perceptual loss (e.g., VGG-based feature difference).                                       | Lower means the SR image is closer in high-level visual features to the HR image. Large values are normal due to the scale of feature differences. |
| **attn**        | Attention loss (if used).                                                                   | Quantifies how well the model attends to important regions (e.g., face parts). Not present (`-`) if attention is disabled.                         |
| **lpips\_loss** | Learned Perceptual Image Patch Similarity (LPIPS) metric.                                   | Lower is better; correlates well with human perception of image similarity.                                                                        |
| **comb**        | Weighted sum of all active loss components.                                                 | This is the actual loss value used for optimization.                                                                                               |
| **PSNR**        | Peak Signal-to-Noise Ratio (in dB).                                                         | Higher is better; measures reconstruction fidelity at the pixel level.                                                                             |
| **SSIM**        | Structural Similarity Index.                                                                | Higher is better (max = 1); measures structural similarity between SR and HR images.                                                               |
| **MS-SSIM**     | Multi-Scale SSIM.                                                                           | Higher is better; similar to SSIM but considers image structures at multiple scales.                                                               |
| **LPIPS**       | LPIPS metric computed as evaluation (same scale as `lpips_loss`).                           | Lower means better perceptual similarity on the validation set.                                                                                    |
| **val\_obj**    | Validation objective score (from `val_objective`).                                          | Lower is better; combines LPIPS, SSIM, and PSNR into a single scalar for model selection/early stopping.                                           |

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Select which model's curves to plot: the first entry of `history`
# (there is only one model while `configs` holds a single config).
name = list(history)[0]
H = history[name]  # metric name -> list of per-epoch values; validation series use a 'val_' prefix

def _ema(values, alpha=0.9):
    """Simple EMA over a 1D list/array. Returns a list of same length."""
    if values is None or len(values) == 0:
        return []
    out = []
    m = None
    for v in values:
        v = float(v)
        m = v if (m is None) else alpha * m + (1 - alpha) * v
        out.append(m)
    return out

def plot_metric(metric_key, title, ylabel="", invert=False, ema_alpha=0.9, show_raw=True):
    """
    Plot the train and validation curves of one metric from the global history `H`.

    - metric_key: key of the training series in `H`; the validation series is read
      from ``"val_" + metric_key``. A missing or empty series is skipped.
    - title / ylabel: figure title and y-axis label (`metric_key` is used when ylabel is empty).
    - invert: flip the y-axis.
    - ema_alpha: None to disable EMA smoothing; otherwise the smoothing factor, e.g. 0.9.
    - show_raw: when smoothing, also draw the raw curve (dashed, faint) next to the EMA.
    """
    series_by_split = (
        ("train", H.get(metric_key, [])),
        ("val",   H.get(f"val_{metric_key}", [])),
    )

    plt.figure(figsize=(6,4))

    for split, series in series_by_split:
        if not series:
            continue
        if ema_alpha is None:
            # No smoothing requested: draw the raw curve only.
            plt.plot(series, label=split)
            continue
        plt.plot(_ema(series, alpha=ema_alpha), label=f"{split} (EMA {ema_alpha})")
        if show_raw:
            plt.plot(series, label=f"{split} (raw)", linestyle="--", alpha=0.35)

    if invert:
        plt.gca().invert_yaxis()
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel or metric_key)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

# === Losses ===
# The combined and pixel losses are always plotted; the optional losses only
# when they were recorded in the history.
for loss_key, loss_title in (("loss_combined", "Combined Loss"),
                             ("loss_pixel",    "Pixel Loss (MSE)")):
    plot_metric(loss_key, loss_title, ylabel="Loss", ema_alpha=0.9)

for loss_key, loss_title in (("loss_perc",  "Perceptual Loss"),
                             ("loss_attn",  "Attention Loss"),
                             ("loss_lpips", "LPIPS Loss")):
    if loss_key in H:
        plot_metric(loss_key, loss_title, ylabel="Loss", ema_alpha=0.9)

# === Image Quality Metrics ===
# LPIPS is a distance (lower is better), so its y-axis is inverted.
for metric_key, metric_title, unit, flip_axis in (
    ("psnr_overall",   "PSNR",    "dB",       False),
    ("ssim_overall",   "SSIM",    "Score",    False),
    ("msssim_overall", "MS-SSIM", "Score",    False),
    ("lpips_overall",  "LPIPS",   "Distance", True),
):
    plot_metric(metric_key, metric_title, ylabel=unit, invert=flip_axis, ema_alpha=0.9)

In [None]:
# This cleanup sequence is only used in case of runtime errors or memory issues.
# It manually deletes large objects, triggers Python’s garbage collector,
# clears PyTorch’s GPU memory cache, resets memory stats, and synchronizes the device
# to fully release unused GPU resources before retrying or exiting.


# 1) delete big objects
#del trainers, history

# 2) force Python GC
#import gc
#gc.collect()

# 3) release PyTorch cache
#import torch
#torch.cuda.empty_cache()
#torch.cuda.reset_peak_memory_stats()

# 4) (for good measure) synchronize
#torch.cuda.synchronize()
