In [None]:
# Install extra packages for this notebook (-qq keeps pip output quiet).
# torchmetrics and lpips are imported by later cells.
# NOTE(review): ultralytics is not imported in the cells visible here — confirm it is still needed.
!pip install -qq ultralytics
!pip install -qq torchmetrics
!pip install -qq lpips

In [2]:
import kagglehub
import torch
# Fetch CelebA through kagglehub; `path` is the local dataset root used by later cells.
path = kagglehub.dataset_download("jessicali9530/celeba-dataset")
print("Path to dataset files:", path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Download the precomputed heatmaps to ./heatmaps.h5.
# Exactly one of the curl commands below should be active; the others stay commented out.
# Small variant (10k):
#!curl -L -o heatmaps.h5 https://huggingface.co/datasets/RiccardoCarraro/heatmaps/resolve/main/heatmaps_10k.h5

# Medium variant (30k) — currently active:
!curl -L -o heatmaps.h5 https://huggingface.co/datasets/RiccardoCarraro/heatmaps/resolve/main/heatmaps_30k.h5

# Uncomment the following line (and comment out the one above) to use the 50k variant:
#!curl -L -o heatmaps.h5 https://huggingface.co/datasets/RiccardoCarraro/heatmaps/resolve/main/heatmaps.h5

Path to dataset files: /kaggle/input/celeba-dataset
Using device: cuda
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1116  100  1116    0     0   3009      0 --:--:-- --:--:-- --:--:--  3016
100 1862M  100 1862M    0     0  52.9M      0  0:00:35  0:00:35 --:--:-- 54.5M


In [3]:
import h5py
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from os.path import join, splitext
from PIL import Image
import csv
import numpy as np

class CelebDataSet(Dataset):
    """
    CelebA dataset with optional landmark-heatmap loading from HDF5.

    Each item is a 5-tuple ``(x2, x4, hr, lr, heatmap)``:
      - x2: 32×32 target tensor
      - x4: 64×64 target tensor
      - hr: 128×128 target tensor
      - lr: 16×16 input tensor
      - heatmap: 1×128×128 float32 tensor (all zeros when no HDF5 file is given)

    Image tensors are normalised to [-1, 1] (ToTensor followed by Normalize(0.5, 0.5)).

    Args:
        data_path: dataset root containing ``img_align_celeba/img_align_celeba/`` and
            ``list_eval_partition.csv``.
        state: ``'train'`` or ``'val'``; any other value selects the test split.
        data_augmentation: enable random flip / rotation / colour jitter (train only).
        heatmap_h5: optional path to an HDF5 file with a ``'heatmaps'`` dataset, which is
            loaded fully into memory and indexed by the item's position in the sorted
            image list.

    Notes:
        - The augmentation pipeline is applied to the image only; the heatmap is returned
          unchanged, so geometric augmentation (flip/rotation) leaves the two misaligned.
        - ``__len__`` reports the number of images, which may exceed the number of stored
          heatmaps; callers are expected to restrict the index range accordingly.
    """

    def __init__(
        self,
        data_path: str = './dataset/',
        state: str = 'train',
        data_augmentation: bool = False,
        heatmap_h5: str = None,
    ):
        self.main_path = data_path
        self.state = state
        self.data_augmentation = data_augmentation
        self.img_path = join(self.main_path, 'img_align_celeba/img_align_celeba/')
        self.eval_partition_path = join(self.main_path, 'list_eval_partition.csv')

        # Load the train/val/test split; sorting keeps index -> file mapping deterministic.
        splits = self._read_partition(self.eval_partition_path)
        if state == 'train':
            self.image_list = sorted(splits['train'])
        elif state == 'val':
            self.image_list = sorted(splits['val'])
        else:
            self.image_list = sorted(splits['test'])

        # PIL-level preprocessing (crop + resize, plus optional augmentation).
        if state == 'train' and data_augmentation:
            self.pre_process = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.CenterCrop((178, 178)),
                transforms.Resize((128, 128)),
                transforms.RandomRotation(
                    20,
                    interpolation=InterpolationMode.BILINEAR
                ),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ])
        else:
            self.pre_process = transforms.Compose([
                transforms.CenterCrop((178, 178)),
                transforms.Resize((128, 128)),
            ])

        # PIL -> tensor in [-1, 1].
        self.totensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.down64 = transforms.Resize((64, 64))
        self.down32 = transforms.Resize((32, 32))
        self.down16 = transforms.Resize((16, 16))

        # Load the whole heatmap dataset into RAM so __getitem__ never touches the file.
        if heatmap_h5:
            with h5py.File(heatmap_h5, 'r') as h5_file:
                self.heatmaps = np.array(h5_file['heatmaps'])  # expected shape: (N, 128, 128)
        else:
            self.heatmaps = None

    @staticmethod
    def _read_partition(csv_path):
        """
        Parse the partition CSV into ``{'train': [...], 'val': [...], 'test': [...]}``.

        Rows are ``filename,code`` with code 0/1/2 for train/val/test. Rows with fewer
        than two fields (blank lines) or an unrecognised code (e.g. the
        ``image_id,partition`` header) are skipped instead of being treated as test
        images.
        """
        code_to_split = {'0': 'train', '1': 'val', '2': 'test'}
        splits = {'train': [], 'val': [], 'test': []}
        with open(csv_path, 'r', newline='') as f:
            for row in csv.reader(f):
                if len(row) < 2:
                    continue
                fname, code = row[0].strip(), row[1].strip()
                split = code_to_split.get(code)
                if split is None:
                    continue  # header row or unknown partition code
                splits[split].append(fname)
        return splits

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        # Load and preprocess the 128x128 image.
        fname = self.image_list[index]
        img = Image.open(join(self.img_path, fname)).convert('RGB')
        img = self.pre_process(img)

        # Build the multi-scale pyramid by successive downsampling.
        x4 = self.down64(img)    # 64x64
        x2 = self.down32(x4)     # 32x32
        lr = self.down16(x2)     # 16x16

        # Convert every scale to a normalised tensor.
        hr_tensor = self.totensor(img)
        x4_tensor = self.totensor(x4)
        x2_tensor = self.totensor(x2)
        lr_tensor = self.totensor(lr)

        # Heatmap (already 128x128). Copy so the tensor does not alias the in-memory
        # array, and cast to float32 so the dtype matches the zero fallback regardless
        # of how the HDF5 file stores the data.
        if self.heatmaps is not None:
            hm = self.heatmaps[index]                                   # (128, 128)
            heat = torch.from_numpy(hm.copy()).unsqueeze(0).float()     # (1, 128, 128)
        else:
            heat = torch.zeros(1, 128, 128)

        return x2_tensor, x4_tensor, hr_tensor, lr_tensor, heat

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv3x3(in_ch, out_ch, stride=1):
    """Build a biased 3x3 convolution with padding 1; `stride` is the only spatial knob."""
    return nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )

class DoubleConv(nn.Module):
    """Two (3x3 conv -> LeakyReLU(0.2)) stages; only the first conv applies `stride`."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Layer order (and therefore Sequential indices / state-dict keys) is fixed:
        # 0: conv, 1: act, 2: conv, 3: act.
        stages = [
            conv3x3(in_ch, out_ch, stride=stride),
            nn.LeakyReLU(0.2, inplace=True),
            conv3x3(out_ch, out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UpBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, concatenate the skip tensor, then DoubleConv."""

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # The fused tensor carries the upsampled channels followed by the skip channels.
        self.conv = DoubleConv(in_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)
        fused = torch.cat((upsampled, skip), dim=1)
        return self.conv(fused)

class UpLearn(nn.Module):
    """2x nearest-neighbour upsample followed by a channel-preserving 3x3 conv + LeakyReLU(0.2)."""

    def __init__(self, ch):
        super().__init__()
        # Index 1 holds the only trainable layer, so state-dict keys are `up.1.*`.
        stages = (
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(ch, ch, 3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        return self.up(x)

class ResBlock(nn.Module):
    """Residual block: x + conv3x3(LeakyReLU(conv3x3(x))), channel count unchanged."""

    def __init__(self, ch):
        super().__init__()
        branch = (
            nn.Conv2d(ch, ch, 3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1, bias=True),
        )
        self.body = nn.Sequential(*branch)

    def forward(self, x):
        residual = self.body(x)
        return x + residual

class SuperResolutionUNet(nn.Module):
    """
    U-Net style 8x super-resolution network.

    A strided encoder (three stride-2 stages) and a skip-connected decoder operate at the
    input resolution and below; three learned 2x upsamplers then bring the features to
    8x the input size, where a residual refine head produces the image. The bilinearly
    upsampled input is added to the final output.

    With ``deep_supervision=True`` (default) ``forward`` returns ``(out2x, out4x, out8x)``
    — i.e. 32/64/128 px for a 16 px input; otherwise only the 8x output is returned.
    The 2x/4x heads are plain 1x1 convs on the intermediate features (no input residual).
    The input height/width must be divisible by 8 for the skip concatenations to line up.
    """

    def __init__(self, in_channels=3, base_filters=32, out_channels=3, refine_blocks=3, deep_supervision=True):
        super().__init__()
        self.deep_supervision = deep_supervision
        c1, c2, c4, c8 = (base_filters * mult for mult in (1, 2, 4, 8))

        # NOTE: modules are registered in the same order as before so that parameter
        # order (optimizer state) and initialisation RNG consumption are unchanged.

        # Encoder: resolution comments assume a 16x16 input.
        self.enc1 = DoubleConv(in_channels, c1, stride=1)  # 16x16
        self.enc2 = DoubleConv(c1, c2, stride=2)           # 8x8
        self.enc3 = DoubleConv(c2, c4, stride=2)           # 4x4
        self.enc4 = DoubleConv(c4, c8, stride=2)           # 2x2

        # Bottleneck at the coarsest resolution.
        self.bottleneck = DoubleConv(c8, c8)

        # Decoder with skip connections back up to the input resolution.
        self.up3 = UpBlock(c8, c4, c4)  # 4x4
        self.up2 = UpBlock(c4, c2, c2)  # 8x8
        self.up1 = UpBlock(c2, c1, c1)  # 16x16

        # Learned upsampling beyond the input resolution.
        self.up_learn1 = UpLearn(c1)  # 32x32
        self.up_learn2 = UpLearn(c1)  # 64x64
        self.up_learn3 = UpLearn(c1)  # 128x128

        # Refine head at the final resolution.
        self.refine_in = nn.Conv2d(c1, c1, 1)
        self.refine = nn.Sequential(*(ResBlock(c1) for _ in range(refine_blocks)))
        self.final_conv = nn.Conv2d(c1, out_channels, 1)

        # Auxiliary heads for deep supervision.
        if self.deep_supervision:
            self.out32 = nn.Conv2d(c1, out_channels, 1)
            self.out64 = nn.Conv2d(c1, out_channels, 1)

    def forward(self, x):
        # Encoder path; keep the first three stages as skip tensors.
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        coarse = self.bottleneck(self.enc4(skip3))

        # Decoder path back to the input resolution.
        dec = self.up3(coarse, skip3)
        dec = self.up2(dec, skip2)
        dec = self.up1(dec, skip1)

        # Learned upsampling: 2x, 4x, 8x the input size.
        feat2x = self.up_learn1(dec)
        feat4x = self.up_learn2(feat2x)
        feat8x = self.up_learn3(feat4x)

        out128 = self.final_conv(self.refine(self.refine_in(feat8x)))

        # Global residual: the network predicts a correction to the bilinear upsample.
        out128 = out128 + F.interpolate(x, size=out128.shape[2:], mode='bilinear', align_corners=False)

        if not self.deep_supervision:
            return out128
        out32 = self.out32(feat2x)
        out64 = self.out64(feat4x)
        return out32, out64, out128


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16, VGG16_Weights
import lpips

# -------------------------
# Pixel loss
# -------------------------
# Mean-squared error between prediction and target; shared module-level criterion
# used by the trainer at every supervised scale.
pixel_crit = nn.MSELoss()

# -------------------------
# Perceptual (VGG) loss
# -------------------------
class VGGPerceptualLoss(nn.Module):
    """
    Perceptual loss on three frozen VGG16 feature stages.

    Expects inputs in [-1, 1]. Internally maps to [0, 1] and applies ImageNet mean/std.
    The whole feature path runs in float32 with autocast explicitly disabled: casting
    the inputs alone is not enough, because an enclosing autocast region would still run
    the convolutions in reduced precision.

    forward(sr, hr) returns the sum of the MSEs between the SR and HR feature maps at
    the three stages, as a scalar tensor.
    """

    def __init__(self):
        super().__init__()
        vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_FEATURES).features
        self.slice1 = nn.Sequential(*list(vgg[:4])).eval()   # conv1_2
        self.slice2 = nn.Sequential(*list(vgg[4:9])).eval()  # conv2_2
        self.slice3 = nn.Sequential(*list(vgg[9:16])).eval() # conv3_3
        # The VGG weights are a fixed feature extractor; never train them.
        for m in (self.slice1, self.slice2, self.slice3):
            for p in m.parameters():
                p.requires_grad = False

        # ImageNet normalisation constants, stored as buffers so .to(device) moves them.
        self.register_buffer('mean', torch.tensor([0.485,0.456,0.406]).view(1,3,1,1))
        self.register_buffer('std',  torch.tensor([0.229,0.224,0.225]).view(1,3,1,1))

    def _prep(self, x: torch.Tensor) -> torch.Tensor:
        """Clamp to [-1, 1], rescale to [0, 1], then apply ImageNet normalisation."""
        x01 = (x.clamp(-1,1) + 1) / 2
        return (x01 - self.mean) / self.std

    def forward(self, sr: torch.Tensor, hr: torch.Tensor) -> torch.Tensor:
        # Disable autocast for this sub-region and feed float32 inputs so that the VGG
        # convolutions and the MSE reductions really execute in full precision even when
        # the caller is inside a mixed-precision training step.
        with torch.autocast(device_type=sr.device.type, enabled=False):
            sr32 = self._prep(sr.float())
            hr32 = self._prep(hr.float())

            f1_sr, f1_hr = self.slice1(sr32), self.slice1(hr32)
            f2_sr, f2_hr = self.slice2(f1_sr), self.slice2(f1_hr)
            f3_sr, f3_hr = self.slice3(f2_sr), self.slice3(f2_hr)

            # Sum of MSEs across the three feature stages.
            return (F.mse_loss(f1_sr, f1_hr) +
                    F.mse_loss(f2_sr, f2_hr) +
                    F.mse_loss(f3_sr, f3_hr))

# Pick the compute device and build the shared perceptual criterion on it.
# Constructing VGGPerceptualLoss loads the VGG16 IMAGENET1K_FEATURES weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
perceptual_crit = VGGPerceptualLoss().to(device)

def attention_loss(
    sr, hr, heat,
    *, gamma: float = 1.3,   # >1 = more focus on hot zones
       floor: float = 0.10,  # weight outside the mask (0..1)
       eps: float = 1e-6
):
    """
    Heatmap-weighted mean absolute error, normalised per sample.

    For each sample i: loss_i = sum(w_i * |sr - hr|) / sum(w_i); the result is the mean
    of loss_i over the batch. Weights are derived from the heatmap and carry no gradient.

    - sr, hr: (B, C, H, W) in [-1, 1]
    - heat: (B, 1, H, W) or (B, H, W); ``None`` yields a zero scalar on sr's device/dtype
    - gamma: exponent applied to the min-max normalised heatmap (selectivity)
    - floor: minimum weight, so pixels outside the mask still receive gradient
    - eps: guard against division by zero and threshold for "constant" heatmaps
    """
    if heat is None:
        return sr.new_tensor(0.0)

    if heat.dim() == 3:
        heat = heat.unsqueeze(1)  # (B,1,H,W)
    heat = heat.to(device=sr.device, dtype=sr.dtype)
    batch = heat.size(0)

    # Per-sample min-max normalisation to [0, 1].
    per_sample = heat.reshape(batch, -1)
    lo = per_sample.min(dim=1, keepdim=True)[0].reshape(batch, 1, 1, 1)
    hi = per_sample.max(dim=1, keepdim=True)[0].reshape(batch, 1, 1, 1)
    span = hi - lo
    normed = (heat - lo) / span.clamp_min(eps)
    if abs(gamma - 1.0) > 1e-6:
        normed = normed.clamp(0, 1).pow(gamma)

    # Affine map into [floor, 1].
    weights = floor + (1.0 - floor) * normed

    # A (nearly) constant heatmap carries no spatial information: fall back to uniform weights.
    is_flat = span <= eps
    if is_flat.any():
        weights = torch.where(is_flat, torch.ones_like(weights), weights)

    # Broadcast over channels; no gradient flows through the weights.
    weights = weights.expand_as(sr).detach()

    # Per-sample reduction (reshape avoids contiguity issues with the expanded view).
    weighted_err = (weights * (sr - hr).abs()).reshape(batch, -1).sum(dim=1)
    weight_total = weights.reshape(batch, -1).sum(dim=1).clamp_min(eps)
    return (weighted_err / weight_total).mean()


# -------------------------
# LPIPS loss
# -------------------------
# LPIPS criterion with the VGG backbone, created once and moved to the training device.
# lpips expects inputs in [-1,1]; returns (B,1,1,1) or (B,)
# NOTE(review): the range/shape statement above comes from the original comment —
# confirm against the installed lpips version.
lpips_crit = lpips.LPIPS(net='vgg').to(device)

def lpips_loss(sr, hr):
    """Batch-mean LPIPS distance between `sr` and `hr`, using the module-level `lpips_crit`."""
    distances = lpips_crit(sr, hr)
    return distances.mean()

In [6]:
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torchmetrics.functional import structural_similarity_index_measure as ssim
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure as MSSSIMMetric

def _area_resize(x, size_hw):
    if x is None: return None
    if x.shape[-2:] == size_hw: return x
    return F.interpolate(x, size=size_hw, mode="area")

def _renorm_heat(h):
    if h is None: return None
    m = h.amax(dim=(-2, -1), keepdim=True).clamp_min(1e-6)
    return h / m

class Trainer:
    """
    Bundles the generator, its optimizer / AMP grad scaler, and the train / eval loops.

    `cfg` is a plain dict. Keys read by this class: ``name``, ``lr_g`` (both required),
    ``in_channels``, ``base_filters``, ``out_channels``, ``refine_blocks``,
    ``use_perc`` / ``use_attn`` / ``use_lpips``, ``w_perc`` / ``w_attn`` / ``w_lpips``,
    ``ms_perc_min_side``, ``ms_lpips_min_side``.

    Losses are supervised at three scales (full, 1/2, 1/4 of the HR size) with weights
    ``ms_weights``; perceptual and LPIPS terms are only evaluated at scales whose
    shorter side is at least the configured minimum.
    """

    def __init__(self, cfg):
        self.cfg    = cfg
        self.G = SuperResolutionUNet(
            in_channels   = cfg.get('in_channels', 3),
            base_filters  = cfg.get('base_filters', 32),
            out_channels  = cfg.get('out_channels', 3),
            refine_blocks = cfg.get('refine_blocks', 3)
        ).to(device)
        self.optG   = torch.optim.Adam(self.G.parameters(), lr=cfg['lr_g'])
        # Loss scaling is only meaningful (and only enabled) when CUDA is available.
        self.scaler = GradScaler(enabled=torch.cuda.is_available())
        # Image-quality metrics are computed every `metric_stride` training steps.
        self.metric_stride = 5

        # data_range=2.0 because images live in [-1, 1].
        self.msssim_metric = MSSSIMMetric(
            data_range=2.0,
            kernel_size=(7, 7),
            betas=(0.0448, 0.2856, 0.3001, 0.2363),
            normalize="relu",
        ).to(device)

        # Relative scale of each supervised output and its loss weight.
        self.ms_scales  = [1.0, 0.5, 0.25]
        self.ms_weights = [1.0, 0.2, 0.05]
        assert len(self.ms_scales) == len(self.ms_weights)

        self.perc_min_side  = int(cfg.get("ms_perc_min_side",  96))
        self.lpips_min_side = int(cfg.get("ms_lpips_min_side", 96))

    @staticmethod
    def psnr_from_mse(mse, max_val=2.0):
        """PSNR in dB for a given MSE; `max_val` is the signal range (2.0 for [-1, 1])."""
        return 10.0 * torch.log10((max_val * max_val) / (mse + 1e-8))

    @staticmethod
    def _to_float(value):
        """Return a Python float for a loss term that may be a tensor or a plain number.

        Loss accumulators start as the float 0.0 and only become tensors once at least
        one scale contributes, so callers cannot assume `.detach()` exists.
        """
        if torch.is_tensor(value):
            return float(value.detach())
        return float(value)

    def compute_overall_metrics(self, sr, hr):
        """Return batch-averaged PSNR, SSIM, MS-SSIM and LPIPS as Python floats."""
        with torch.no_grad():
            mse_overall = ((sr - hr) ** 2).mean(dim=(1,2,3))
            psnr_overall   = self.psnr_from_mse(mse_overall).mean().item()
            ssim_overall   = float(ssim(sr, hr, data_range=2.0))
            msssim_overall = float(self.msssim_metric(sr, hr).cpu())
            # The metric object accumulates state across calls; reset to keep it per-batch.
            self.msssim_metric.reset()
            lpips_overall  = lpips_loss(sr, hr)
            if torch.is_tensor(lpips_overall):
                lpips_overall = lpips_overall.mean().item()
            return {
                'psnr_overall': psnr_overall,
                'ssim_overall': ssim_overall,
                'msssim_overall': msssim_overall,
                'lpips_overall': lpips_overall
            }

    def _build_ms_pairs(self, out32, out64, out128, x2, x4, hr, heat):
        """
        Pair each network output with its target, heatmap, gating flags and weight.

        Returns a list of tuples
        ``(tag, pred, target, heat, allow_perc, allow_lpips, weight)``; heatmaps for the
        smaller scales are area-resized and re-normalised to a peak of 1.
        """
        def flags(h, w):
            return (min(h, w) >= self.perc_min_side), (min(h, w) >= self.lpips_min_side)

        allow_perc_128, allow_lp_128 = flags(*hr.shape[-2:])
        allow_perc_64,  allow_lp_64  = flags(*x4.shape[-2:])
        allow_perc_32,  allow_lp_32  = flags(*x2.shape[-2:])

        heat128 = _renorm_heat(heat)
        heat64  = _renorm_heat(_area_resize(heat, (64, 64))) if heat is not None else None
        heat32  = _renorm_heat(_area_resize(heat, (32, 32))) if heat is not None else None

        sw = dict(zip(self.ms_scales, self.ms_weights))
        return [
            ('1x', out128, hr,   heat128, allow_perc_128, allow_lp_128, sw.get(1.0, 0.0)),
            ('x4', out64,  x4,   heat64,  allow_perc_64,  allow_lp_64,  sw.get(0.5, 0.0)),
            ('x2', out32,  x2,   heat32,  allow_perc_32,  allow_lp_32,  sw.get(0.25, 0.0)),
        ]

    def _multiscale_loss_from_pairs(self, ms_pairs, epoch=None, scheduler=None):
        """
        Accumulate the weighted per-scale losses and combine them.

        When both `epoch` and `scheduler` are given, `scheduler.scales_at` supplies extra
        multipliers for the perceptual / attention / LPIPS terms; otherwise they are 1.
        Returns ``(Lpix, Lperc, Lattn, Llp, loss)``. A term that never received a
        contribution stays the plain float 0.0 rather than a tensor.
        """
        use_perc  = self.cfg.get('use_perc',  False)
        use_attn  = self.cfg.get('use_attn',  False)
        use_lpips = self.cfg.get('use_lpips', False)

        if epoch is not None and scheduler is not None:
            scale_perc, scale_attn, scale_lp = scheduler.scales_at(epoch, self.cfg)
        else:
            scale_perc, scale_attn, scale_lp = 1.0, 1.0, 1.0

        Lpix = 0.0; Lperc = 0.0; Lattn = 0.0; Llp = 0.0
        for tag, pred, tgt_img, tgt_heat, allow_perc, allow_lp, w in ms_pairs:
            if w <= 0.0: continue
            Lpix += w * pixel_crit(pred, tgt_img)
            if use_attn and tgt_heat is not None:
                Lattn += w * attention_loss(pred, tgt_img, tgt_heat)
            if use_perc and allow_perc:
                Lperc += w * perceptual_crit(pred, tgt_img)
            if use_lpips and allow_lp:
                Llp += w * lpips_loss(pred, tgt_img)

        loss = Lpix \
            + self.cfg.get('w_perc', 0.0)  * (scale_perc * Lperc if torch.is_tensor(Lperc) else 0.0) \
            + self.cfg.get('w_attn', 0.0)  * (scale_attn * Lattn if torch.is_tensor(Lattn) else 0.0) \
            + self.cfg.get('w_lpips', 0.0) * (scale_lp   * Llp   if torch.is_tensor(Llp)   else 0.0)
        return Lpix, Lperc, Lattn, Llp, loss

    def train_epoch(self, loader, epoch=None, scheduler=None):
        """
        Run one optimisation pass over `loader` under fp16 autocast (CUDA only).

        Returns a dict of epoch averages: loss entries are averaged over all steps,
        metric entries over the steps on which metrics were computed
        (every `metric_stride` steps).
        """
        agg = {k: 0.0 for k in [
            'loss_pixel','loss_perc','loss_attn','loss_lpips','loss_combined',
            'psnr_overall','ssim_overall','msssim_overall','lpips_overall'
        ]}
        self.G.train()
        step = metric_steps = 0
        use_cuda = torch.cuda.is_available()

        for x2, x4, hr, lr, heat in tqdm(loader, desc=f"Training {self.cfg['name']}"):
            lr, hr, x4, x2 = (x.to(device, non_blocking=True) for x in [lr, hr, x4, x2])
            # Heatmaps are only transferred when the attention loss is in use.
            heat = heat.to(device, non_blocking=True) if self.cfg.get('use_attn', False) else None

            self.optG.zero_grad(set_to_none=True)
            with autocast(device_type='cuda', enabled=use_cuda, dtype=torch.float16):
                out32, out64, out128 = self.G(lr)
                ms_pairs = self._build_ms_pairs(out32, out64, out128, x2, x4, hr, heat)
                Lpix, Lperc, Lattn, Llp, loss = self._multiscale_loss_from_pairs(ms_pairs, epoch, scheduler)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optG)
            self.scaler.update()

            # Optional terms may still be the float 0.0 (e.g. every scale gated out by
            # the min-side thresholds), so convert through _to_float instead of .detach().
            agg['loss_pixel']    += self._to_float(Lpix)
            agg['loss_perc']     += self._to_float(Lperc) if self.cfg.get('use_perc', False) else 0.0
            agg['loss_attn']     += self._to_float(Lattn) if self.cfg.get('use_attn', False) else 0.0
            agg['loss_lpips']    += self._to_float(Llp)   if self.cfg.get('use_lpips', False) else 0.0
            agg['loss_combined'] += self._to_float(loss)

            if step % self.metric_stride == 0:
                m = self.compute_overall_metrics(out128.detach(), hr)
                for k in ('psnr_overall','ssim_overall','msssim_overall','lpips_overall'):
                    agg[k] += m[k]
                metric_steps += 1
            step += 1

        return {k: (v / max(step,1) if "loss" in k else v / max(metric_steps,1)) for k, v in agg.items()}

    def evaluate(self, loader, num_samples=500, epoch=None, scheduler=None):
        """
        Evaluate on `loader` without gradients and return per-batch averages.

        NOTE: despite its name, `num_samples` is compared against the number of batches
        processed so far, so it caps the number of *batches*, not individual samples.
        """
        self.G.eval()
        agg = {k: 0.0 for k in [
            'loss_pixel','loss_perc','loss_attn','loss_lpips','loss_combined',
            'psnr_overall','ssim_overall','msssim_overall','lpips_overall'
        ]}
        n = 0  # batches processed
        with torch.no_grad():
            for x2, x4, hr, lr, heat in tqdm(loader, desc=f"Evaluating {self.cfg['name']}"):
                if n >= num_samples: break
                lr, hr, x4, x2 = (x.to(device, non_blocking=True) for x in [lr, hr, x4, x2])
                heat = heat.to(device, non_blocking=True) if self.cfg.get('use_attn', False) else None

                out32, out64, out128 = self.G(lr)
                m = self.compute_overall_metrics(out128, hr)
                for k in ('psnr_overall','ssim_overall','msssim_overall','lpips_overall'):
                    agg[k] += m[k]

                ms_pairs = self._build_ms_pairs(out32, out64, out128, x2, x4, hr, heat)
                Lpix, Lperc, Lattn, Llp, Lcomb = self._multiscale_loss_from_pairs(ms_pairs, epoch, scheduler)

                agg['loss_pixel']    += self._to_float(Lpix)
                agg['loss_perc']     += self._to_float(Lperc)  if self.cfg.get('use_perc', False)  else 0.0
                agg['loss_attn']     += self._to_float(Lattn)  if self.cfg.get('use_attn', False)  else 0.0
                agg['loss_lpips']    += self._to_float(Llp)    if self.cfg.get('use_lpips', False) else 0.0
                agg['loss_combined'] += self._to_float(Lcomb)
                n += 1

        return {k: (v / max(n, 1)) for k, v in agg.items()}


In [None]:
# --- core imports
import os, h5py, numpy as np
import torch
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from tabulate import tabulate
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
import h5py
from google.colab import files


# --- device & seeds
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pin = torch.cuda.is_available()  # pinned host memory only helps when copying to a GPU

# Seed every RNG in use and force deterministic cuDNN kernels for reproducibility.
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- paths & basic hyperparams
data_path  = path            # CelebA root (set by the kagglehub download cell)
heat_h5    = './heatmaps.h5' # heatmap file downloaded earlier
batch_size = 64
num_epochs = 200
ckpt_path = None
# To resume from a saved checkpoint, uncomment the next line and point it at the file;
# the resume block further down is then executed.
#ckpt_path = "./PATH_TO_CHECKPOINT.pt"

# Number of heatmaps stored in the file.
with h5py.File(heat_h5, 'r') as f:
    HM_COUNT = int(f['heatmaps'].shape[0])
print("Heatmaps available:", HM_COUNT)

TRAIN_LIMIT = HM_COUNT                        # never index past the last available heatmap
VAL_LIMIT   = max(1, int(0.2 * TRAIN_LIMIT))  # validation capped at 20% of the train limit (e.g. 10k train -> 2000 val)

# --- build datasets
# NOTE(review): both datasets index the same heatmap file by their own item position;
# confirm that the file's heatmaps are also aligned with the validation images.
train_ds = CelebDataSet(data_path, 'train', heatmap_h5=heat_h5)
val_ds   = CelebDataSet(data_path, 'val',   heatmap_h5=heat_h5)

train_n = min(TRAIN_LIMIT, len(train_ds))
val_n   = min(VAL_LIMIT,   len(val_ds))

print(f"Using TRAIN_LIMIT={train_n}, VAL_LIMIT={val_n}")

# Use only the first train_n / val_n items of each (sorted) split.
train_subset = Subset(train_ds, range(train_n))
val_subset   = Subset(val_ds,   range(val_n))

# --- dataloaders
loader     = DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                        num_workers=2, pin_memory=pin)
val_loader = DataLoader(val_subset,   batch_size=batch_size, shuffle=False,
                        num_workers=2, pin_memory=pin)

print("Train samples:", len(train_subset))
print("Val samples:  ", len(val_subset))

# --- one fixed batch for viz
# The first sample of this batch is kept on the device and reused for visualisation
# during training. Note that it is drawn from the (shuffled) training loader.
data_iter = iter(loader)
_, _, hr_f, lr_f, heat_f = next(data_iter)
lr_vis, hr_vis, heat_vis = [t.to(device, non_blocking=True) for t in (lr_f[0:1], hr_f[0:1], heat_f[0:1])]

# --- single config (you can add more later)
configs = [
    # Best-performing configuration so far; use it as the reference when making changes.
    dict(
        name='deep_supervision_30k',
        use_attn=True, use_perc=True, use_lpips=True,
        w_attn=2.0, w_perc=0.01, w_lpips=0.15,
        lr_g=1e-4,
        base_filters=48,
        refine_blocks=5
    ),
]

# Fill in defaults and remember each loss weight's base value under `_base_<key>`,
# so the loss-weight scheduler can rescale from the original target every epoch.
for cfg in configs:
    for k in ('w_perc','w_attn','w_lpips'):
        cfg.setdefault(k, 0.0)
        cfg.setdefault(f'_base_{k}', cfg[k])
    cfg.setdefault('use_perc', False)
    cfg.setdefault('use_attn', False)
    cfg.setdefault('use_lpips', False)
    cfg.setdefault('schedule', True)

# --- build trainers from configs
trainers = {}
for cfg in configs:
    t = Trainer(cfg)
    if 'metric_stride' in cfg:
        t.metric_stride = cfg['metric_stride']
    # Halve the learning rate after 6 epochs without improvement of the monitored value.
    t.lr_sched = ReduceLROnPlateau(t.optG, mode='min', patience=6, factor=0.5, verbose=True)
    trainers[cfg['name']] = t

# ==== helpers ===============================================================
# Dynamically adjusts the weights of the perceptual, attention, and LPIPS losses
# throughout training based on the current epoch fraction t = epoch / num_epochs.
#
# The schedule is divided into three phases:
#   1. Warm-up (t <= 0.30): the perceptual scale ramps linearly from 0 to 1 and the
#      attention scale from 0.60 to 1; the LPIPS loss remains off.
#   2. Mid-phase (0.30 < t <= 0.85): perceptual and attention scales stay at 1 while
#      the LPIPS scale ramps linearly from 0 to 1.
#   3. Late-phase (t > 0.85): perceptual and LPIPS scales stay at 1 while the attention
#      scale decays linearly from 1 to 0.75.
#
# If a given loss type is disabled in `cfg`, its weight is kept at 0 for the whole schedule.
# The scheduler produces SCALES (fractions of the target), which are then multiplied
# by the base target weights stored in `cfg`.
class LossWeightScheduler:
    """Piecewise-linear schedule for the perceptual / attention / LPIPS loss weights."""

    def __init__(self, num_epochs):
        self.E = num_epochs

    def scales_at(self, epoch, cfg):
        """Return scales for (perc, attn, lpips) as fractions of the base targets, NOT absolute weights."""
        t = epoch / self.E  # training progress

        if t <= 0.30:
            # Warm-up: perc 0 -> 1, attn 0.60 -> 1, lpips off.
            u = (t / 0.30)
            return 0.00 + 1.00 * u, 0.60 + 0.40 * u, 0.00

        if t <= 0.85:
            # Mid-phase: perc and attn at full target, lpips 0 -> 1.
            u = (t - 0.30) / 0.55
            return 1.00, 1.00, 0.00 + 1.00 * u

        # Late-phase: attn decays to 75% of its target, the others stay at full target.
        u = (t - 0.85) / 0.15
        return 1.00, 1.00 - 0.25 * u, 1.00

    def apply(self, epoch, cfg):
        """Write the scheduled `w_perc` / `w_attn` / `w_lpips` into `cfg` in place.

        Does nothing when ``cfg['schedule']`` is falsy. A loss whose ``use_*`` flag is
        off gets weight 0.0.
        """
        if not cfg.get('schedule', True):
            return  # leave the weights untouched

        scales = dict(zip(('perc', 'attn', 'lpips'), self.scales_at(epoch, cfg)))
        for key, scale in scales.items():
            enabled = cfg.get(f'use_{key}', False)
            base = cfg.get(f'_base_w_{key}', 0.0)
            cfg[f'w_{key}'] = (base * scale) if enabled else 0.0


# Monitors a validation score and triggers early stopping when it stops improving.
#
# Parameters:
#   patience  – number of consecutive epochs without significant improvement
#               (greater than `min_delta`) before stopping.
#   min_delta – minimum required improvement in the monitored score to be considered progress.
#
# Behavior:
#   - Tracks the best (lowest) score seen so far.
#   - Resets the bad epoch counter when improvement is detected.
#   - Increments the bad epoch counter otherwise.
#   - Sets `should_stop=True` when the bad epoch counter reaches `patience`.

class EarlyStopping:
    """Track the lowest score seen and flag a stop after `patience` non-improving steps."""

    def __init__(self, patience=15, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.bad_epochs = 0
        self.should_stop = False

    def step(self, score):
        """Record one score; an improvement must beat `best` by more than `min_delta`."""
        improved = self.best is None or score < self.best - self.min_delta
        if improved:
            self.best = score
            self.bad_epochs = 0
            return
        self.bad_epochs += 1
        if self.bad_epochs >= self.patience:
            self.should_stop = True

# Computes a scalar "validation objective" score to guide model selection,
# learning rate scheduling, and early stopping. Lower is better.
#
# The score is a weighted sum of penalties:
#   - 1.5 * LPIPS: strong penalty for poor perceptual similarity (higher LPIPS).
#   - 0.8 * max(0, 0.02 - SSIM): penalty if SSIM falls below 0.02 (no penalty otherwise).
#   - 0.2 * max(0, 20.0 - PSNR): penalty if PSNR is below 20 dB (no penalty otherwise).
#
# Inputs:
#   m – dictionary of validation metrics containing:
#       'ssim_overall', 'lpips_overall', 'psnr_overall'
#
# Output:
#   A single float representing the validation objective; smaller values indicate better quality.
def val_objective(m):
    """Scalar validation objective (lower is better) built from LPIPS, SSIM and PSNR.

    Missing metrics default to the worst case used here: SSIM 0.0, LPIPS 1.0, PSNR 0.0.
    """
    ssim_val = float(m.get('ssim_overall', 0.0))
    lpips_val = float(m.get('lpips_overall', 1.0))
    psnr_val = float(m.get('psnr_overall', 0.0))

    # Hinge penalties: only active below the respective thresholds.
    ssim_penalty = max(0.0, 0.02 - ssim_val)
    psnr_penalty = max(0.0, 20.0 - psnr_val)
    return (1.5 * lpips_val) + (0.8 * ssim_penalty) + (0.2 * psnr_penalty)

def to01(x):
    """Map a tensor from [-1, 1] to [0, 1], clamping out-of-range values first."""
    clipped = x.clamp(-1, 1)
    return (clipped + 1) / 2

def save_checkpoint(trainer, history, epoch, best_score, path):
    """Serialize a resumable training snapshot to `path` with `torch.save`.

    The stored dict has the keys:
      'epoch'      – the `epoch` argument
      'model'      – `trainer.G.state_dict()`
      'opt'        – `trainer.optG.state_dict()`
      'scaler'     – `trainer.scaler.state_dict()`
      'sched'      – `trainer.lr_sched.state_dict()`, or None when the trainer
                     has no `lr_sched` attribute
      'early_best' – the `best_score` argument
      'history'    – the `history` argument
    """
    snapshot = {
        'epoch': epoch,
        'model': trainer.G.state_dict(),
        'opt': trainer.optG.state_dict(),
        'scaler': trainer.scaler.state_dict(),
    }

    # The LR scheduler is optional on the trainer.
    if hasattr(trainer, 'lr_sched'):
        snapshot['sched'] = trainer.lr_sched.state_dict()
    else:
        snapshot['sched'] = None

    snapshot['early_best'] = best_score
    snapshot['history'] = history

    torch.save(snapshot, path)

# ===========================================================================


# Fresh-run defaults; when a checkpoint is given, the resume block below
# overrides `history`, `start_epoch` and the early-stopping baseline.
loss_sched = LossWeightScheduler(num_epochs)
early_stop = EarlyStopping(min_delta=1e-4, patience=15)
history, start_epoch = {}, 1

# Optional resume: only the first configured model is restored from `ckpt_path`.
if ckpt_path is not None:
    name = configs[0]['name']
    ckpt = torch.load(ckpt_path, map_location=device)

    # Generator weights, optimizer state and scaler state must be in the checkpoint.
    for _attr, _key in (('G', 'model'), ('optG', 'opt'), ('scaler', 'scaler')):
        getattr(trainers[name], _attr).load_state_dict(ckpt[_key])

    # The LR-scheduler state is restored only when a non-empty one was saved.
    if ckpt.get('sched'):
        trainers[name].lr_sched.load_state_dict(ckpt['sched'])
    print("Loaded configuration from", ckpt_path)

    # Carry over the best objective so early stopping and best-model saving
    # continue from where the previous run left off.
    early_stop.best = ckpt.get('early_best', float('inf'))
    trainers[name]._best_score = early_stop.best

    # Continue with the epoch after the saved one, on top of the saved history.
    start_epoch = ckpt.get('epoch', 0) + 1
    history = ckpt.get('history', {})

# Epoch of the most recent periodic checkpoint written by THIS run.
# None means no such file exists yet, so there is nothing to download.
last_round_epoch = None

# Column headers shared by the train / validation / best-model tables.
_TABLE_HEADERS = ["Model","pix","perc","attn","lpips_loss","comb","PSNR","SSIM","MS-SSIM","LPIPS"]
# Period (in epochs) of both the qualitative preview and the full checkpoints.
_REPORT_EVERY = 10


def _metrics_row(cfg, m, with_objective=False):
    """Format one table row for config `cfg` from the metrics dict `m`.

    Optional loss terms are rendered as "-" when the config does not enable
    them (so the corresponding key is never read from `m`). With
    `with_objective=True` the validation objective is appended as the last
    cell; pair it with `_TABLE_HEADERS + ["val_obj"]`.
    """
    row = [
        cfg['name'],
        f"{m['loss_pixel']:.4e}",
        f"{m['loss_perc']:.4e}"    if cfg.get('use_perc', False)  else "-",
        f"{m['loss_attn']:.4e}"    if cfg.get('use_attn', False)  else "-",
        f"{m['loss_lpips']:.4e}"   if cfg.get('use_lpips', False) else "-",
        f"{m['loss_combined']:.4e}",
        f"{m['psnr_overall']:.2f}",
        f"{m['ssim_overall']:.4f}",
        f"{m['msssim_overall']:.4f}",
        f"{m['lpips_overall']:.4f}",
    ]
    if with_objective:
        row.append(f"{val_objective(m):.4f}")
    return row


def _show_reconstructions(suptitle, title_suffix="", max_samples=5):
    """Plot upsampled LR / HR ground truth / per-model output for the first validation batch.

    Args:
        suptitle:     figure title.
        title_suffix: text appended to every model-column title.
        max_samples:  maximum number of rows; fewer are drawn when the first
                      validation batch is smaller.
    """
    # The loader yields 5-tuples; only the HR target and LR input are needed here.
    _, _, hr_val_b, lr_val_b, _ = next(iter(val_loader))
    lr_val = lr_val_b[:max_samples].to(device)
    hr_val = hr_val_b[:max_samples].cpu()
    n_rows = lr_val.shape[0]
    n_cols = 2 + len(configs)

    lr_up_val = to01(F.interpolate(lr_val, size=(128,128), mode='bilinear', align_corners=False)).cpu()
    recon_val = {}
    with torch.no_grad():
        for cfg in configs:
            nm = cfg['name']
            G = trainers[nm].G
            was_training = G.training
            sr_out = G.eval()(lr_val)
            G.train(was_training)  # leave the generator in the mode it was found in
            # A tuple output is unpacked as three values; the last one is displayed.
            if isinstance(sr_out, tuple):
                _, _, sr = sr_out
            else:
                sr = sr_out
            recon_val[nm] = to01(sr).cpu()

    # squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False)
    for i in range(n_rows):
        row = axes[i]
        row[0].imshow(lr_up_val[i].permute(1,2,0)); row[0].set_title("LR ↑"); row[0].axis('off')
        row[1].imshow(to01(hr_val[i]).permute(1,2,0)); row[1].set_title("HR GT"); row[1].axis('off')
        for j, cfg in enumerate(configs, start=2):
            nm = cfg['name']
            row[j].imshow(recon_val[nm][i].permute(1,2,0)); row[j].set_title(f"{nm}{title_suffix}"); row[j].axis('off')
    plt.suptitle(suptitle, fontsize=16)
    plt.tight_layout(); plt.show()


# The first config is the reference model: its validation objective drives
# early stopping and best-weights saving.
ref_name  = configs[0]['name']
best_dir  = f"./{ref_name}"; os.makedirs(best_dir, exist_ok=True)
best_path = os.path.join(best_dir, f"{ref_name}_best.pth")

# TRAINING LOOP
for epoch in range(start_epoch, num_epochs+1):
    print(f"\n=== Epoch {epoch}/{num_epochs} ===")

    # update loss weights
    for cfg in configs:
        loss_sched.apply(epoch, cfg)

    # ---- TRAIN ----
    epoch_metrics = {}
    for cfg in configs:
        name    = cfg['name']
        metrics = trainers[name].train_epoch(loader)
        epoch_metrics[name] = metrics

        # history (train)
        hist = history.setdefault(name, {})
        for k, v in metrics.items():
            hist.setdefault(k, []).append(v)

    # print train (with MS-SSIM)
    print(tabulate(
        [_metrics_row(cfg, epoch_metrics[cfg['name']]) for cfg in configs],
        headers=_TABLE_HEADERS, tablefmt="github"))

    # ---- VALIDATION ----
    val_metrics = {}
    for cfg in configs:
        name = cfg['name']
        vm = trainers[name].evaluate(val_loader)
        val_metrics[name] = vm

        # history (val) is stored next to the train series under a 'val_' prefix
        hist = history[name]
        for k, v in vm.items():
            hist.setdefault('val_' + k, []).append(v)

    score = val_objective(val_metrics[ref_name])

    # schedulers + early stop: every model's scheduler sees its own objective,
    # early stopping follows the reference model only
    for cfg in configs:
        trainers[cfg['name']].lr_sched.step(val_objective(val_metrics[cfg['name']]))
    early_stop.step(score)

    # print val (with MS-SSIM and the validation objective)
    print("\n")
    print(tabulate(
        [_metrics_row(cfg, val_metrics[cfg['name']], with_objective=True) for cfg in configs],
        headers=_TABLE_HEADERS + ["val_obj"], tablefmt="github"))

    # quick viz
    if epoch % _REPORT_EVERY == 0:
        _show_reconstructions(f"Epoch {epoch} — Validation Reconstructions")

    # save best (reference model weights only)
    if score <= getattr(trainers[ref_name], "_best_score", float("inf")):
        torch.save(trainers[ref_name].G.state_dict(), best_path)
        trainers[ref_name]._best_score = score

    # periodic full checkpoints for every model
    if epoch % _REPORT_EVERY == 0:
        for cfg in configs:
            name   = cfg['name']
            folder = f"./{name}"; os.makedirs(folder, exist_ok=True)
            path_to_save   =  f"./{name}/{name}_epoch{epoch:03d}.pt"
            save_checkpoint(trainers[name], history, epoch, early_stop.best,path_to_save)
        last_round_epoch = epoch

    if early_stop.should_stop:
        print(f"\nEarly stopping at epoch {epoch} (best val objective: {early_stop.best:.4f}).")

        # --- load the best checkpoint ---
        # After resuming in a fresh session the best-weights file may not exist
        # (the best score is restored, the file is not), so guard the load.
        if os.path.exists(best_path):
            best_ckpt = torch.load(best_path, map_location=device)
            trainers[ref_name].G.load_state_dict(best_ckpt)
        else:
            print(f"Best checkpoint {best_path} not found; reporting the current weights instead.")

        # --- recompute validation metrics for best model ---
        best_metrics = trainers[ref_name].evaluate(val_loader)

        print("\n=== Best Model Validation Metrics ===")
        print(tabulate(
            [_metrics_row(configs[0], best_metrics, with_objective=True)],
            headers=_TABLE_HEADERS + ["val_obj"], tablefmt="github"))

        # --- visualize best model ---
        _show_reconstructions("Best Model — Validation Reconstructions", title_suffix=" (best)")

        break


# Download the latest periodic checkpoint of the last configured model.
# Checkpoints are only written every `_REPORT_EVERY` epochs, so a short or
# early-stopped run may have produced none.
if last_round_epoch is None:
    print("No periodic checkpoint was written during this run; nothing to download.")
else:
    name = configs[-1]['name']
    path_to_save   =  f"./{name}/{name}_epoch{last_round_epoch:03d}.pt"
    # NOTE(review): `files` is presumably google.colab's `files` module — confirm it is imported above.
    files.download(path_to_save)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Choose which model's curves to plot: the first entry recorded in `history`
# (only one model is configured at the moment).
name, H = list(history.items())[0]

def _ema(values, alpha=0.9):
    """Simple EMA over a 1D list/array. Returns a list of same length."""
    if values is None or len(values) == 0:
        return []
    out = []
    m = None
    for v in values:
        v = float(v)
        m = v if (m is None) else alpha * m + (1 - alpha) * v
        out.append(m)
    return out

def plot_metric(metric_key, title, ylabel="", invert=False, ema_alpha=0.9, show_raw=True):
    """
    Draw the train and validation curves of one metric from `H` in a new figure.

    - metric_key: key of the train series in `H`; the validation series is
      looked up under `val_<metric_key>`. A missing or empty series is skipped.
    - title / ylabel: figure title and y-axis label (the label falls back to
      `metric_key` when empty).
    - invert: flip the y-axis.
    - ema_alpha: None to plot the raw series only; otherwise the smoothing
      factor passed to `_ema` (e.g. 0.9).
    - show_raw: when smoothing, also draw the raw series (dashed, faint).
    """
    curves = (
        ("train", H.get(metric_key, [])),
        ("val",   H.get(f"val_{metric_key}", [])),
    )

    plt.figure(figsize=(6,4))

    for split, series in curves:
        if not series:
            continue

        if ema_alpha is None:
            plt.plot(series, label=split)
            continue

        # Smoothed curve first, then the optional raw overlay.
        plt.plot(_ema(series, alpha=ema_alpha), label=f"{split} (EMA {ema_alpha})")
        if show_raw:
            plt.plot(series, label=f"{split} (raw)", linestyle="--", alpha=0.35)

    if invert:
        plt.gca().invert_yaxis()

    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel or metric_key)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

# === Losses ===
# Combined and pixel losses are always plotted; the optional terms only when
# they were recorded in the history.
for _key, _title, _always in (
    ("loss_combined", "Combined Loss",    True),
    ("loss_pixel",    "Pixel Loss (MSE)", True),
    ("loss_perc",     "Perceptual Loss",  False),
    ("loss_attn",     "Attention Loss",   False),
    ("loss_lpips",    "LPIPS Loss",       False),
):
    if _always or _key in H:
        plot_metric(_key, _title, ylabel="Loss", ema_alpha=0.9)

# === Image Quality Metrics ===
for _key, _title, _unit in (
    ("psnr_overall",   "PSNR",    "dB"),
    ("ssim_overall",   "SSIM",    "Score"),
    ("msssim_overall", "MS-SSIM", "Score"),
):
    plot_metric(_key, _title, ylabel=_unit, ema_alpha=0.9)

# LPIPS is a distance, so its y-axis is drawn inverted.
plot_metric("lpips_overall", "LPIPS", ylabel="Distance", invert=True, ema_alpha=0.9)