In [None]:
!nvidia-smi
!pip -q install --upgrade pip
!pip -q install basicsr timm einops gdown imageio opencv-python albumentations matplotlib torchvision numpy==1.26.4 scipy==1.13.1

# Fresh clone
!rm -rf Restormer
!git clone https://github.com/swz30/Restormer.git
%cd Restormer

In [None]:
import numpy as np
import scipy

# Report the numeric stack versions actually loaded in this kernel.
for _label, _module in (("NumPy", np), ("SciPy", scipy)):
    print(f"{_label} version:", _module.__version__)

In [None]:
# ---- TorchVision shim for legacy imports (functional_tensor) ----
# Newer torchvision releases no longer ship `torchvision.transforms.functional_tensor`,
# but some basicsr code still imports from it; alias it to `.functional` when missing.
import sys, types, importlib
import torch, torchvision
print("Torch:", torch.__version__, "| TorchVision:", torchvision.__version__)

from torchvision.transforms import functional as F

_FT_NAME = "torchvision.transforms.functional_tensor"
try:
    # Leave a real module (older torchvision) or a shim from an earlier run in place.
    importlib.import_module(_FT_NAME)
    print(f"{_FT_NAME} already importable; shim not needed ✅")
except ImportError:
    ft_mod = types.ModuleType(_FT_NAME)
    # Copy only non-dunder names: copying __name__/__spec__/__file__ would make the
    # shim module masquerade as `.functional` itself.
    for k, v in vars(F).items():
        if not k.startswith("__"):
            setattr(ft_mod, k, v)
    sys.modules[_FT_NAME] = ft_mod
    print("Shim installed: torchvision.transforms.functional_tensor → .functional ✅")

In [None]:
# ✅ Run this as a shell cell (notice the ! at the start of each command)
CKPT_URL="https://github.com/swz30/Restormer/releases/download/v1.0/deraining.pth"

!mkdir -p ./checkpoints
!rm -f ./checkpoints/pretrained_task.pth
!curl -L "$CKPT_URL" -o ./checkpoints/pretrained_task.pth
!ls -lh ./checkpoints

In [None]:
# Sanity-check the downloaded checkpoint: file size first, then a CPU load.
import os, torch
CKPT_PATH = "./checkpoints/pretrained_task.pth"
_ckpt_size = os.path.getsize(CKPT_PATH)
print("File size (MB):", _ckpt_size / 1e6)
# Guards against truncated or failed downloads; the threshold enforced is 10 MB.
assert _ckpt_size > 10_000_000, "Checkpoint looks incomplete!"

# Quick load test — best-effort: a failure is reported, not raised.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
try:
    _ = torch.load(CKPT_PATH, map_location="cpu")
except Exception as e:
    print("❌ Problem loading checkpoint:", e)
else:
    print("✅ torch.load works fine — checkpoint OK!")

In [None]:
# ==== RECOVERY: make sure Restormer is present and importable ====
import os, sys, subprocess, shutil, glob, importlib.util
from pathlib import Path

REPO_URL = "https://github.com/swz30/Restormer.git"
REPO_DIR = Path("/kaggle/working") / "Restormer"   # fixed absolute path

# 1) Fresh clone if missing or empty
def is_dir_empty(p: Path) -> bool:
    """Return True if `p` does not exist or is a directory with no entries.

    Only the directory's immediate entries are inspected (an empty subdirectory
    still counts as content), instead of walking the whole tree with rglob just
    to count it. An existing non-directory path is reported as not empty, so the
    caller never treats a stray file as a missing checkout.
    """
    if not p.exists():
        return True
    if not p.is_dir():
        return False
    return not any(p.iterdir())

# Clone only when the checkout is missing or an empty directory.
if not is_dir_empty(REPO_DIR):
    print("[i] Restormer repo already present at", REPO_DIR)
else:
    if REPO_DIR.exists():
        shutil.rmtree(REPO_DIR, ignore_errors=True)
    print("[i] Cloning Restormer repo...")
    subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, str(REPO_DIR)])

# 2) Prepend the repo (then its parent) to sys.path so the repo's own `basicsr`
#    package is importable; inserting in this order leaves the parent at index 0.
for _path_entry in (str(REPO_DIR), str(REPO_DIR.parent)):
    if _path_entry not in sys.path:
        sys.path.insert(0, _path_entry)

# 3) Import Restormer via its package path; if that fails for any reason, load
#    the first restormer_arch.py found under the repo directly from its file.
Restormer = None
try:
    from basicsr.models.archs.restormer_arch import Restormer  # type: ignore
    print("[✓] Imported Restormer from basicsr.models.archs.restormer_arch")
except Exception as e:
    print("[!] Canonical import failed:", e)
    _arch_files = list(REPO_DIR.rglob("restormer_arch.py"))
    assert _arch_files, "restormer_arch.py not found under repo"
    _arch_path = _arch_files[0]
    _arch_spec = importlib.util.spec_from_file_location("restormer_local", str(_arch_path))
    _arch_mod = importlib.util.module_from_spec(_arch_spec)
    _arch_spec.loader.exec_module(_arch_mod)  # type: ignore
    Restormer = _arch_mod.Restormer
    print(f"[✓] Imported Restormer from file: {_arch_path}")

In [None]:
# === ONE-CELL SAFE LOADER FOR RESTORMER + CHECKPOINT ===
import sys, types, importlib.util, re
from pathlib import Path
from collections import OrderedDict
import torch

# ---- Shim for legacy torchvision import used by some basicsr code ----
try:
    from torchvision.transforms import functional as F
    _FT_NAME = "torchvision.transforms.functional_tensor"
    try:
        # A real module (older torchvision) or a shim installed earlier wins.
        importlib.import_module(_FT_NAME)
        print(f"Shim not needed: {_FT_NAME} already importable")
    except ImportError:
        ft_mod = types.ModuleType(_FT_NAME)
        # Copy only non-dunder names; copying __name__/__spec__ would make the shim
        # claim to be `.functional` itself.
        for k, v in vars(F).items():
            if not k.startswith("__"):
                setattr(ft_mod, k, v)
        sys.modules[_FT_NAME] = ft_mod
        print("Shim ok: torchvision.transforms.functional_tensor -> .functional")
except Exception as e:
    print("Shim skipped:", e)

# ---- Robust import of Restormer (canonical or file-path fallback) ----
repo_root = "/kaggle/working/Restormer"
# Prepend rather than append: with append, a pip-installed `basicsr` in
# site-packages is searched first, and the canonical import below could resolve
# against it instead of the repo's own package.
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

Restormer = None
try:
    from basicsr.models.archs.restormer_arch import Restormer  # type: ignore
    print("Imported Restormer from basicsr.models.archs.restormer_arch")
except Exception as e:
    print("Canonical import failed:", e)
    # Prefer the exact architecture module; the loose "*restormer*.py" pattern can
    # also match unrelated scripts, so it is only a fallback. Sorting makes the
    # chosen file deterministic instead of depending on filesystem order.
    candidates = (sorted(Path(repo_root).rglob("restormer_arch.py"))
                  or sorted(Path(repo_root).rglob("*restormer*.py")))
    assert candidates, "Restormer source not found under repo_root"
    p = str(candidates[0])
    spec = importlib.util.spec_from_file_location("restormer_local", p)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # type: ignore
    Restormer = getattr(mod, "Restormer")
    print("Imported Restormer from file:", p)

# ---- Load checkpoint and pick the correct inner state dict ----
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ckpt_path = "./checkpoints/pretrained_task.pth"
state = torch.load(ckpt_path, map_location="cpu")

# Wrapper keys checked in priority order; the first one holding a dict is used.
if not isinstance(state, dict):
    sd, _sd_note = state, "Using checkpoint as-is (non-dict)"
else:
    sd, _sd_note = state, "Using checkpoint as-is (flat dict)"
    for _wrap_key in ("params_ema", "params", "state_dict", "model"):
        if isinstance(state.get(_wrap_key), dict):
            sd, _sd_note = state[_wrap_key], f"Using checkpoint['{_wrap_key}']"
            break
print(_sd_note)

# Strip a leading DDP "module." prefix from keys when any key carries it.
_DDP_PREFIX = "module."
if any(k.startswith(_DDP_PREFIX) for k in sd.keys()):
    sd = OrderedDict(
        (k[len(_DDP_PREFIX):] if k.startswith(_DDP_PREFIX) else k, v)
        for k, v in sd.items()
    )
    print("Stripped 'module.' prefixes")

def build_model(layernorm_type: str):
    """Instantiate Restormer with the hyper-parameters hard-coded below.

    `layernorm_type` is forwarded as `LayerNorm_type` ('WithBias' or 'BiasFree').
    """
    arch_kwargs = dict(
        inp_channels=3,
        out_channels=3,
        dim=48,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        dual_pixel_task=False,
    )
    return Restormer(LayerNorm_type=layernorm_type, **arch_kwargs)

def try_load(layernorm_type: str):
    """Build a model for `layernorm_type` and load the global `sd` into it non-strictly.

    Prints the number of missing/unexpected keys (plus up to five examples of each)
    and returns ``(model, missing_keys, unexpected_keys)``.
    """
    net = build_model(layernorm_type)
    missing, unexpected = net.load_state_dict(sd, strict=False)
    print(f"[{layernorm_type}] missing: {len(missing)} | unexpected: {len(unexpected)}")
    for label, keys in (("  sample missing:", missing), ("  sample unexpected:", unexpected)):
        if keys:
            print(label, keys[:5])
    return net, missing, unexpected

# Try BiasFree first; if it does not load cleanly, try WithBias and keep whichever
# variant matched the checkpoint better (ties go to WithBias, as before).
model, missing, unexpected = try_load("BiasFree")
if missing or unexpected:
    print("Retrying WithBias …")
    _alt_model, _alt_missing, _alt_unexpected = try_load("WithBias")
    if len(_alt_missing) + len(_alt_unexpected) <= len(missing) + len(unexpected):
        model, missing, unexpected = _alt_model, _alt_missing, _alt_unexpected
    else:
        print("WithBias matched worse; keeping BiasFree")

print("\nFinal -> Loaded with strict=False")
print("missing:", len(missing), "unexpected:", len(unexpected))
if missing or unexpected:
    # strict=False does not fail on this; make the partial load impossible to miss.
    print("⚠️ Checkpoint did not fully match the architecture; "
          "missing weights keep their constructor-initialised values.")

In [None]:
# Freeze everything, then unfreeze only the late refinement stage + output head.
for p in model.parameters():
    p.requires_grad = False

# Substrings matched against lower-cased parameter names.
# "output" is included because Restormer's final conv is expected to be registered
# as `output` (NOTE(review): confirm against restormer_arch.py); none of the other
# keys match it, so the output head would otherwise stay frozen.
UNFREEZE_KEYS = ("refinement", "output", "reconstruct", "reconstruction", "conv_out", "tail")
to_unfreeze = []
for n, p in model.named_parameters():
    if any(k in n.lower() for k in UNFREEZE_KEYS):
        p.requires_grad = True
        to_unfreeze.append(n)

# fallback: if nothing matched, unfreeze the last 10 parameter tensors
if not to_unfreeze:
    for n, p in list(model.named_parameters())[-10:]:
        p.requires_grad = True
        to_unfreeze.append(n)

# Show up to 8 names; add "..." only when the list was actually truncated.
_more = ["..."] if len(to_unfreeze) > 8 else []
print("Unfreezing:", *to_unfreeze[:8], *_more, sep="\n")
print(f"({len(to_unfreeze)} parameter tensors unfrozen)")

import torch
trainable = [p for p in model.parameters() if p.requires_grad]
opt_G = torch.optim.Adam(trainable, lr=1e-5, betas=(0.9, 0.999))

In [None]:
import torch.nn as nn, torchvision

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).train()


class _ImageNetNormalize(nn.Module):
    """Map [0, 1] RGB tensors (as produced by `to_tensor`) to the ImageNet-normalised
    input that torchvision's pretrained VGG weights were trained on."""

    def __init__(self):
        super().__init__()
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std


# Perceptual feature extractor: ImageNet normalisation + VGG19 features[:16], frozen.
_vgg_weights = torchvision.models.VGG19_Weights.IMAGENET1K_V1
vgg = nn.Sequential(
    _ImageNetNormalize(),
    torchvision.models.vgg19(weights=_vgg_weights).features[:16],
).eval().to(device)
for p in vgg.parameters():
    p.requires_grad = False
l1, mse = nn.L1Loss(), nn.MSELoss()

# tiny PatchGAN
class PatchD(nn.Module):
    """Small PatchGAN discriminator.

    Three stride-2 4x4 conv stages (InstanceNorm on all but the first, LeakyReLU 0.2)
    with widths base, 2*base, 4*base, followed by a 3x3 conv to one logit map.
    """

    def __init__(self, in_ch=3, base=64):
        super().__init__()

        def _down(c_in, c_out, use_norm=True):
            layers = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if use_norm:
                layers.append(nn.InstanceNorm2d(c_out, affine=True))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        widths = [in_ch, base, base * 2, base * 4]
        stages = [_down(widths[i], widths[i + 1], use_norm=(i > 0)) for i in range(3)]
        stages.append(nn.Conv2d(widths[-1], 1, 3, 1, 1))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

# Discriminator and its optimizer (beta1 = 0.5, a common choice for GAN training).
disc = PatchD().to(device)
disc.train()
opt_D = torch.optim.Adam(params=disc.parameters(), lr=2e-4, betas=(0.5, 0.999))

def adv_loss(pred, is_real):
    """Least-squares adversarial loss: MSE between `pred` and an all-ones target
    (when `is_real`) or an all-zeros target (otherwise)."""
    fill_value = 1.0 if is_real else 0.0
    return mse(pred, torch.full_like(pred, fill_value))

# Generator loss weights: adversarial, perceptual (VGG), identity.
L_ADV = 1.0
L_PERC = 0.1
L_ID = 0.1

In [None]:
# ==== Multi-degradation dataset & loaders (REPLACES old Cell 8 + 10) ====
import os, glob, random, cv2, torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import albumentations as A

# 1) Point BASE_DIR at the dataset root. Its sub-folders named in _ACTIVE_SUBDIRS
#    are used; "clean" is the (unpaired) target pool, the rest are degradations.
#    Optional extra sub-folders that can be added to _ACTIVE_SUBDIRS:
#    "fog_rain", "fog_lowlight", "rain_lowlight", "extreme".
BASE_DIR = "/kaggle/input/claude-synthesis"   # <-- CHANGE THIS to your dataset root

_ACTIVE_SUBDIRS = ("clean", "fog", "rain", "lowlight")
DATA_ROOTS = {name: os.path.join(BASE_DIR, name) for name in _ACTIVE_SUBDIRS}

DEG_DOMAINS = [name for name in DATA_ROOTS if name != "clean"]
CLEAN_DIR   = DATA_ROOTS["clean"]

def list_images(path, exts=(".jpg", ".jpeg", ".png", ".bmp")):
    """Return the sorted, de-duplicated image files under `path`, searched recursively.

    Args:
        path: root directory (str or path-like). It is glob-escaped, so directory
            names containing glob metacharacters such as ``[`` are matched literally.
        exts: accepted file-name suffixes, compared case-insensitively
            (``photo.JPG`` is included).

    Returns:
        Sorted list of file paths (str), each listed once.

    Raises:
        FileNotFoundError: if no matching file exists under `path`.

    With ``recursive=True`` the ``**`` component also matches zero directories, so a
    single recursive pattern covers the top level too; a second non-recursive pattern
    would list every top-level image twice. Hidden files/directories (leading ".")
    are skipped, as glob does by default.
    """
    suffixes = tuple(e.lower() for e in exts)
    pattern = os.path.join(glob.escape(os.fspath(path)), "**", "*")
    files = {
        f for f in glob.glob(pattern, recursive=True)
        if f.lower().endswith(suffixes) and os.path.isfile(f)
    }
    if not files:
        raise FileNotFoundError(f"No images found in {path}")
    return sorted(files)

# 2) Augmentations: keep geometry; gentle photometric jitter only.
#    Train: rescale so the longest side is 384, reflect-pad to >= 384x384, random
#    256x256 crop, horizontal flip, mild brightness/contrast (same for all domains).
#    Val: rescale longest side to 512 and reflect-pad to >= 512x512, so the padded
#    border is part of every val image (and of anything computed on it).
_train_steps = [
    A.LongestMaxSize(max_size=384),
    A.PadIfNeeded(min_height=384, min_width=384, border_mode=cv2.BORDER_REFLECT_101),
    A.RandomCrop(height=256, width=256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.08, contrast_limit=0.08, p=0.3),
]
train_tf = A.Compose(_train_steps)

_val_steps = [
    A.LongestMaxSize(max_size=512),
    A.PadIfNeeded(min_height=512, min_width=512, border_mode=cv2.BORDER_REFLECT_101),
]
val_tf = A.Compose(_val_steps)

def to_tensor(img):
    """Convert a BGR uint8 image (as returned by cv2.imread, assumed H x W x 3) into a
    float32 RGB tensor of shape (3, H, W) with values scaled into [0, 1]."""
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    scaled = np.ascontiguousarray(rgb).astype(np.float32) / 255.0
    return torch.from_numpy(scaled.transpose(2, 0, 1))

class UnpairedDataset(Dataset):
    """Unpaired degraded/clean image dataset.

    Item ``idx`` is ``(degraded_tensor, clean_tensor, degraded_basename)``. The clean
    image is drawn at random from ``cln_files`` on every access, so it has no pixel
    correspondence with the degraded image.
    """

    def __init__(self, deg_files, cln_files, transform=None):
        self.deg_files, self.cln_files, self.tf = deg_files, cln_files, transform

    def __len__(self):
        return len(self.deg_files)

    @staticmethod
    def _read(path):
        """cv2.imread that raises on unreadable files instead of returning None
        (which would otherwise fail later with an unrelated-looking error)."""
        img = cv2.imread(path)
        if img is None:
            raise IOError(f"cv2.imread failed to read image: {path}")
        return img

    def __getitem__(self, idx):
        deg_path = self.deg_files[idx]
        dimg = self._read(deg_path)
        cimg = self._read(random.choice(self.cln_files))
        if self.tf:
            # Re-seed Python/NumPy RNGs before each call so both images draw the same
            # augmentation parameters. NOTE(review): only effective if the installed
            # albumentations samples from these global RNGs — confirm for your version.
            seed = np.random.randint(0, 999999)
            random.seed(seed); np.random.seed(seed)
            dimg = self.tf(image=dimg)['image']
            random.seed(seed); np.random.seed(seed)
            cimg = self.tf(image=cimg)['image']
        return to_tensor(dimg), to_tensor(cimg), os.path.basename(deg_path)

def build_loader_for(domain, batch_size=2, train=True):
    """Return ``(DataLoader, n_degraded_images)`` for one degradation domain.

    Degraded images come from ``DATA_ROOTS[domain]`` and are paired at random with
    images from ``CLEAN_DIR``. ``train`` selects train_tf vs val_tf, shuffling, and
    ``drop_last``.
    NOTE(review): ``train=False`` reads the same files as ``train=True`` — there is
    no held-out split.
    """
    degraded = list_images(DATA_ROOTS[domain])
    dataset = UnpairedDataset(
        degraded,
        list_images(CLEAN_DIR),
        transform=train_tf if train else val_tf,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,
        num_workers=2,
        pin_memory=True,
        drop_last=train,
    )
    return loader, len(degraded)

# 3) One train/val loader per degradation domain, plus a size report.
train_loaders, val_loaders, counts = {}, {}, {}
for d in DEG_DOMAINS:
    tl, n = build_loader_for(d, batch_size=2, train=True)
    vl, _ = build_loader_for(d, batch_size=1, train=False)
    train_loaders[d], val_loaders[d], counts[d] = tl, vl, n


def _print_table(title, rows):
    """Print `title`, then one aligned '  name: value' line per (name, value) pair."""
    print(title)
    for name, value in rows:
        print(f"  {name:12s}: {value}")


_print_table("Domain image counts:", counts.items())
_print_table("\nTrain loader batches per domain:",
             ((k, len(loader)) for k, loader in train_loaders.items()))
_print_table("\nVal samples per domain:",
             ((k, len(loader.dataset)) for k, loader in val_loaders.items()))

In [None]:
# ==== Quick per-domain forward check ====
# Push one val batch per domain through the model (no grad) and print the shapes.
model.eval()
with torch.no_grad():
    for dom, vloader in val_loaders.items():
        first = next(iter(vloader), None)
        if first is None:
            print(f"[{dom}] empty val loader, skipping")
            continue
        d, c, name = first
        d = d.to(device)
        o = model(d)
        print(f"[{dom}] Shapes: {tuple(d.shape)} -> {tuple(o.shape)}  | sample: {name[0]}")
model.train()

In [None]:
# ==== Stage A: Train one common model across all degradations (REPLACES old Cell 11) ====
import os, cv2, numpy as np
from itertools import cycle
from tqdm import tqdm

import torch

# ---- NEW: metrics (PSNR/SSIM) for saved previews ----
METRICS_CSV = "/kaggle/working/sample_metrics.csv"
def _init_metrics_csv(path=METRICS_CSV):
    if not os.path.exists(path):
        with open(path, "w") as f:
            f.write("epoch,domain,name,psnr,ssim\n")

# skimage helpers (kept tiny; guarded for API differences)
from skimage.metrics import peak_signal_noise_ratio as _psnr
from skimage.metrics import structural_similarity as _ssim
def _ssim_safe(ref_rgb_uint8, img_rgb_uint8):
    """SSIM of two RGB uint8 images, tolerant of the skimage channel-argument rename.

    Calls skimage's structural_similarity with ``channel_axis=2`` first and retries
    with ``multichannel=True`` if that raises ``TypeError``.
    """
    shared = {"data_range": 255}
    try:
        return _ssim(ref_rgb_uint8, img_rgb_uint8, channel_axis=2, **shared)
    except TypeError:
        return _ssim(ref_rgb_uint8, img_rgb_uint8, multichannel=True, **shared)

# Output locations for preview images and per-epoch checkpoints (created if missing).
SAMPLES_ROOT = "/kaggle/working/samples_mixed"
CKPT_ROOT    = "/kaggle/working/checkpoints_mixed"
for _out_dir in (SAMPLES_ROOT, CKPT_ROOT):
    os.makedirs(_out_dir, exist_ok=True)

# round-robin iterator that yields (domain, batch)
def rr_batches(loaders_dict):
    """Yield ``(domain, batch)`` forever, visiting the loaders in dict order.

    Each loader is re-iterated from scratch whenever it is exhausted. This is
    deliberately not ``itertools.cycle``: cycle saves a copy of every element from
    the first pass and replays those copies afterwards. For a DataLoader that would
    freeze the shuffle order, augmentations and random clean pairings after the first
    pass, and keep every first-pass batch alive in memory.

    Raises:
        ValueError: if a loader yields no batches at all (e.g. fewer samples than
            ``batch_size`` with ``drop_last=True``).
    """
    def _endless(name, loader):
        # Fresh iteration of `loader` on every pass.
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                raise ValueError(f"loader for domain {name!r} yielded no batches")

    streams = {k: _endless(k, v) for k, v in loaders_dict.items()}
    while True:
        for k in loaders_dict.keys():
            yield k, next(streams[k])

# Balanced epochs: each domain contributes as many batches as the smallest train loader.
_min_batches = min(len(loader) for loader in train_loaders.values())
steps_per_epoch = _min_batches * len(train_loaders)
print("steps_per_epoch =", steps_per_epoch)

# Separate AMP grad scalers for generator and discriminator (no-ops off CUDA).
_amp_enabled = (device == "cuda")
scalerG = torch.cuda.amp.GradScaler(enabled=_amp_enabled)
scalerD = torch.cuda.amp.GradScaler(enabled=_amp_enabled)

def clamp01(x):
    """Clamp tensor values into the closed interval [0, 1]."""
    return x.clamp(min=0, max=1)

@torch.no_grad()
def save_previews_by_domain(epoch):
    """
    Save up to 3 preview triptychs per domain (degraded | output | clean) under
    SAMPLES_ROOT/<domain>/ and RETURN a list of metric rows
    (epoch, domain, name, psnr, ssim) computed on the exact samples displayed.

    NOTE(review): the val "clean" image is drawn at random by UnpairedDataset and
    does not correspond to the degraded input, so these PSNR/SSIM values compare the
    output with an unrelated image — a rough signal at best, not restoration fidelity.
    """
    def _to_uint8_hwc(t):
        # (3, H, W) float in [0, 1] -> (H, W, 3) uint8. Round rather than truncate, so
        # e.g. 0.999 maps to 255 instead of 254 (truncation biases PSNR/SSIM).
        arr = t.permute(1, 2, 0).cpu().numpy() * 255.0
        return np.clip(np.round(arr), 0, 255).astype(np.uint8)

    model.eval()
    rows = []
    for dom, vloader in val_loaders.items():
        outdir = os.path.join(SAMPLES_ROOT, dom)
        os.makedirs(outdir, exist_ok=True)
        for cnt, (d, c, name) in enumerate(vloader, start=1):
            d, c = d.to(device), c.to(device)
            o = clamp01(model(d))

            # triptych preview of the first item in the batch: [1, 3, H, 3W]
            grid = torch.cat([d[:1], o[:1], c[:1]], dim=3)
            cv2.imwrite(
                os.path.join(outdir, f"ep{epoch:03d}_{dom}_{name[0]}"),
                cv2.cvtColor(_to_uint8_hwc(grid[0]), cv2.COLOR_RGB2BGR),
            )

            # PSNR/SSIM of the shown sample (output vs the displayed clean image)
            out_img = _to_uint8_hwc(o[0])
            gt_img = _to_uint8_hwc(c[0])
            psnr = float(_psnr(gt_img, out_img, data_range=255))
            ssim = float(_ssim_safe(gt_img, out_img))
            rows.append((epoch, dom, name[0], psnr, ssim))

            if cnt >= 3:  # a few previews per domain
                break
    model.train()
    return rows

# Init the tiny CSV once
import csv

_init_metrics_csv(METRICS_CSV)

EPOCHS = 10  # an earlier note suggested 5–7 epochs as a good range for base robustness
rr = rr_batches(train_loaders)
_amp_on = (device == "cuda")

for ep in range(1, EPOCHS+1):
    pbar = tqdm(range(steps_per_epoch), desc=f"Epoch {ep}/{EPOCHS}")
    model.train(); disc.train()
    for _ in pbar:
        domain, batch = next(rr)
        degraded, clean, _ = batch
        degraded, clean = degraded.to(device, non_blocking=True), clean.to(device, non_blocking=True)

        # --- G step ---
        # Freeze D while computing the generator loss: gradients still flow through D
        # into `fake`, but D's own parameter grads (which opt_D.zero_grad() discards
        # before the D step anyway) are no longer computed.
        disc.requires_grad_(False)
        opt_G.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=_amp_on):
            fake = model(degraded)
            loss_g_adv  = adv_loss(disc(fake), True) * L_ADV
            loss_g_perc = l1(vgg(fake), vgg(clean)) * L_PERC
            # identity term: a clean input should come back unchanged
            loss_g_id   = l1(model(clean), clean) * L_ID
            loss_g = loss_g_adv + loss_g_perc + loss_g_id
        scalerG.scale(loss_g).backward()
        scalerG.step(opt_G); scalerG.update()
        disc.requires_grad_(True)

        # --- D step --- (targets: real -> 1, fake -> 0)
        opt_D.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=_amp_on):
            r = disc(clean); f = disc(fake.detach())
            loss_d = 0.5*(adv_loss(r, True) + adv_loss(f, False))
        scalerD.scale(loss_d).backward()
        scalerD.step(opt_D); scalerD.update()

        pbar.set_postfix(G=float(loss_g.item()), D=float(loss_d.item()), dom=domain)

    # ---- save epoch-specific checkpoint ----
    ep_ckpt = os.path.join(CKPT_ROOT, f"restormer_mixed_ep{ep:03d}.pth")
    torch.save(model.state_dict(), ep_ckpt)

    # single-name copy that always holds the latest epoch
    torch.save(model.state_dict(), "./restormer_mixed_base.pth")

    # ---- previews + sample metrics (PSNR/SSIM for those exact images) ----
    sample_rows = save_previews_by_domain(ep)
    if sample_rows:
        # csv.writer quotes fields that need it (e.g. file names containing commas),
        # which a plain ",".join would silently corrupt.
        with open(METRICS_CSV, "a", newline="") as f:
            csv.writer(f, lineterminator="\n").writerows(sample_rows)

# Final summary of where Stage A wrote its artifacts.
_summary_lines = [
    "✅ Stage A done.",
    f"Per-epoch checkpoints: {CKPT_ROOT}/restormer_mixed_epXXX.pth",
    "Also wrote: ./restormer_mixed_base.pth (last epoch)",
    f"Previews:    {SAMPLES_ROOT}/<domain>/",
    f"Sample metrics CSV: {METRICS_CSV} (columns: epoch,domain,name,psnr,ssim)",
]
print("\n".join(_summary_lines))

In [None]:
# # ==== Stage A: Train one common model across all degradations (metrics + per-epoch ckpts) ====
# import os, cv2, time, json, numpy as np
# from itertools import cycle
# from collections import Counter
# from contextlib import suppress
# from tqdm import tqdm

# import torch
# import torch.nn.functional as F

# # ---------- Paths ----------
# SAMPLES_ROOT = "/kaggle/working/samples_mixed"
# CKPT_ROOT    = "/kaggle/working/checkpoints_mixed"
# LOG_ROOT     = "/kaggle/working/logs_mixed"
# os.makedirs(SAMPLES_ROOT, exist_ok=True)
# os.makedirs(CKPT_ROOT,    exist_ok=True)
# os.makedirs(LOG_ROOT,     exist_ok=True)

# METRICS_CSV   = os.path.join(LOG_ROOT, "metrics_summary.csv")
# TRAIN_CSV     = os.path.join(LOG_ROOT, "train_trace_epoch.csv")        # A1
# DOMCOUNT_CSV  = os.path.join(LOG_ROOT, "domain_steps.csv")             # A2
# BEST_CSV      = os.path.join(LOG_ROOT, "best_epochs_per_domain.csv")   # A3

# # ---------- Metrics helpers ----------
# _have_brisque = False
# _have_niqe = False
# try:
#     from skimage.metrics import peak_signal_noise_ratio as _psnr
#     from skimage.metrics import structural_similarity as _ssim
#     try:
#         from skimage.metrics import niqe as _niqe
#         _have_niqe = True
#     except Exception:
#         _have_niqe = False
#     try:
#         from skimage.metrics import brisque as _brisque
#         _have_brisque = True
#     except Exception:
#         _have_brisque = False
# except Exception as e:
#     raise RuntimeError(
#         "scikit-image not available; run `pip install -q scikit-image` in a Kaggle cell."
#     ) from e

# def _ssim_safe(img, ref):
#     """Handle skimage API differences (channel_axis vs multichannel)."""
#     try:
#         return _ssim(ref, img, data_range=255, channel_axis=2)
#     except TypeError:
#         return _ssim(ref, img, data_range=255, multichannel=True)

# def clamp01(x): 
#     return torch.clamp(x, 0, 1)

# def rr_batches(loaders_dict):
#     iters = {k: cycle(v) for k, v in loaders_dict.items()}
#     while True:
#         for k in loaders_dict.keys():
#             yield k, next(iters[k])

# def init_metrics_csv(path):
#     if not os.path.exists(path):
#         cols = ["epoch","domain","psnr","ssim"]
#         if _have_niqe: cols.append("niqe")
#         if _have_brisque: cols.append("brisque")
#         cols += ["mean_G_loss","mean_D_loss","mean_adv","mean_perc","mean_id"]
#         with open(path, "w", encoding="utf-8") as f:
#             f.write(",".join(cols) + "\n")

# def init_train_csv(path):  # A1 header
#     if not os.path.exists(path):
#         with open(path, "w", encoding="utf-8") as f:
#             f.write("epoch,steps,batch_size,imgs_seen,sec_per_epoch,imgs_per_sec,lr_G,lr_D,mean_G,mean_D,mean_adv,mean_perc,mean_id\n")

# def init_domcount_csv(path):  # A2 header
#     if not os.path.exists(path):
#         with open(path, "w", encoding="utf-8") as f:
#             f.write("epoch,domain,steps\n")

# def init_best_csv(path):  # A3 header
#     if not os.path.exists(path):
#         with open(path, "w", encoding="utf-8") as f:
#             f.write("domain,best_psnr_epoch,best_psnr,best_ssim_epoch,best_ssim,ckpt_at_best_psnr,ckpt_at_best_ssim\n")

# def write_metrics_row(path, epoch, domain, metrics, loss_means):
#     cols = [
#         str(epoch),
#         domain,
#         f"{metrics.get('psnr', np.nan)}",
#         f"{metrics.get('ssim', np.nan)}",
#     ]
#     if _have_niqe:    cols.append(f"{metrics.get('niqe', np.nan)}")
#     if _have_brisque: cols.append(f"{metrics.get('brisque', np.nan)}")
#     cols += [
#         f"{loss_means['G']}",
#         f"{loss_means['D']}",
#         f"{loss_means['adv']}",
#         f"{loss_means['perc']}",
#         f"{loss_means['id']}",
#     ]
#     with open(path, "a", encoding="utf-8") as f:
#         f.write(",".join(cols) + "\n")

# def compute_metrics_batch(fake_t, clean_t):
#     """
#     fake_t, clean_t: tensors in [0,1], shape [B,3,H,W]
#     Returns dict of mean metrics over the batch.
#     """
#     fake = (fake_t.detach().clamp(0,1).cpu().permute(0,2,3,1).numpy() * 255).astype(np.uint8)
#     gt   = (clean_t.detach().clamp(0,1).cpu().permute(0,2,3,1).numpy() * 255).astype(np.uint8)

#     psnr_vals, ssim_vals, niqe_vals, brisque_vals = [], [], [], []
#     for f, r in zip(fake, gt):
#         try: psnr_vals.append(_psnr(r, f, data_range=255))
#         except: pass
#         try: ssim_vals.append(_ssim_safe(f, r))
#         except: pass
#         if _have_niqe:
#             try:
#                 g = cv2.cvtColor(f, cv2.COLOR_RGB2GRAY).astype(np.float64) / 255.0
#                 niqe_vals.append(_niqe(g))
#             except: pass
#         if _have_brisque:
#             try:
#                 g = cv2.cvtColor(f, cv2.COLOR_RGB2GRAY).astype(np.float64) / 255.0
#                 brisque_vals.append(_brisque(g))
#             except: pass

#     out = {
#         "psnr": float(np.mean(psnr_vals)) if psnr_vals else np.nan,
#         "ssim": float(np.mean(ssim_vals)) if ssim_vals else np.nan,
#     }
#     if _have_niqe:    out["niqe"] = float(np.mean(niqe_vals)) if niqe_vals else np.nan
#     if _have_brisque: out["brisque"] = float(np.mean(brisque_vals)) if brisque_vals else np.nan
#     return out

# @torch.no_grad()
# def save_previews_by_domain(epoch):
#     model.eval()
#     for dom, vloader in val_loaders.items():
#         outdir = os.path.join(SAMPLES_ROOT, dom)
#         os.makedirs(outdir, exist_ok=True)
#         cnt = 0
#         for d, c, name in vloader:
#             d, c = d.to(device), c.to(device)
#             o = clamp01(model(d))
#             grid = torch.cat([d[:1], o[:1], c[:1]], dim=3)  # [B,3,H,3W]
#             arr = (grid[0].permute(1,2,0).cpu().numpy()*255).astype(np.uint8)
#             cv2.imwrite(
#                 os.path.join(outdir, f"ep{epoch:03d}_{dom}_{name[0]}.jpg"),
#                 cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
#             )
#             cnt += 1
#             if cnt >= 3:  # a few previews per domain
#                 break
#     model.train()

# @torch.no_grad()
# def evaluate_domains(epoch):
#     """
#     Evaluate model on validation sets per domain.
#     Returns { domain: {psnr, ssim, (niqe), (brisque)} } averaged over the loader.
#     """
#     model.eval()
#     results = {}
#     for dom, vloader in val_loaders.items():
#         vals = {"psnr": [], "ssim": []}
#         if _have_niqe: vals["niqe"] = []
#         if _have_brisque: vals["brisque"] = []

#         for d, c, _ in vloader:
#             d, c = d.to(device), c.to(device)
#             o = clamp01(model(d))
#             m = compute_metrics_batch(o, c)
#             for k, v in m.items():
#                 vals[k].append(v)

#         results[dom] = {k: (float(np.nanmean(v)) if len(v) else np.nan) for k, v in vals.items()}
#     model.train()
#     return results

# # ---------- Train ----------
# steps_per_epoch = min(len(dl) for dl in train_loaders.values()) * len(train_loaders)
# print("steps_per_epoch =", steps_per_epoch)

# scalerG = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))
# scalerD = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

# # init logs
# init_metrics_csv(METRICS_CSV)
# init_train_csv(TRAIN_CSV)       # A1
# init_domcount_csv(DOMCOUNT_CSV) # A2
# init_best_csv(BEST_CSV)         # A3

# # in-memory best tracker (A3)
# _best = {}  # domain -> {psnr, psnr_epoch, ssim, ssim_epoch, ckpt_psnr, ckpt_ssim}

# EPOCHS = 25  # keep as you set
# rr = rr_batches(train_loaders)

# for ep in range(1, EPOCHS+1):
#     pbar = tqdm(range(steps_per_epoch), desc=f"Epoch {ep}/{EPOCHS}")
#     model.train(); disc.train()

#     # A1/A2 bookkeeping
#     epoch_start = time.perf_counter()
#     dom_counter = Counter()
#     batch_size_seen = None

#     # running sums for loss means
#     sum_G = sum_D = 0.0
#     sum_adv = sum_perc = sum_id = 0.0
#     n_steps = 0

#     for _ in pbar:
#         domain, batch = next(rr)
#         degraded, clean, _ = batch
#         degraded = degraded.to(device, non_blocking=True)
#         clean    = clean.to(device, non_blocking=True)

#         # record domain step (A2) + batch size (A1)
#         dom_counter[domain] += 1
#         if batch_size_seen is None:
#             batch_size_seen = degraded.size(0)

#         # --- G step ---
#         opt_G.zero_grad(set_to_none=True)
#         with torch.cuda.amp.autocast(enabled=(device=="cuda")):
#             fake = model(degraded)
#             loss_g_adv  = adv_loss(disc(fake), True) * L_ADV
#             loss_g_perc = l1(vgg(fake), vgg(clean)) * L_PERC
#             loss_g_id   = l1(model(clean), clean) * L_ID
#             loss_g = loss_g_adv + loss_g_perc + loss_g_id
#         scalerG.scale(loss_g).backward()
#         scalerG.step(opt_G); scalerG.update()

#         # --- D step ---
#         opt_D.zero_grad(set_to_none=True)
#         with torch.cuda.amp.autocast(enabled=(device=="cuda")):
#             r = disc(clean); f = disc(fake.detach())
#             loss_d = 0.5*(adv_loss(r, True) + adv_loss(f, False))
#         scalerD.scale(loss_d).backward()
#         scalerD.step(opt_D); scalerD.update()

#         # running logs
#         sum_G   += float(loss_g.item())
#         sum_D   += float(loss_d.item())
#         sum_adv += float(loss_g_adv.item())
#         sum_perc+= float(loss_g_perc.item())
#         sum_id  += float(loss_g_id.item())
#         n_steps += 1

#         pbar.set_postfix(G=float(loss_g.item()), D=float(loss_d.item()), dom=domain)

#     # mean losses for the epoch
#     denom = max(1, n_steps)
#     loss_means = {
#         "G":   round(sum_G/denom, 6),
#         "D":   round(sum_D/denom, 6),
#         "adv": round(sum_adv/denom, 6),
#         "perc":round(sum_perc/denom, 6),
#         "id":  round(sum_id/denom, 6),
#     }

#     # ---- save checkpoint (separate per epoch) ----
#     ckpt_path = os.path.join(CKPT_ROOT, f"restormer_mixed_ep{ep:03d}.pth")
#     torch.save(model.state_dict(), ckpt_path)
#     # also keep a latest pointer
#     torch.save(model.state_dict(), os.path.join(CKPT_ROOT, "restormer_mixed_latest.pth"))

#     # ---- previews + metrics per domain ----
#     save_previews_by_domain(ep)
#     dom_metrics = evaluate_domains(ep)
#     for dom, metrics in dom_metrics.items():
#         write_metrics_row(METRICS_CSV, ep, dom, metrics, loss_means)

#     # ---- A1: epoch timing, throughput, learning rates ----
#     epoch_time = time.perf_counter() - epoch_start
#     imgs_seen  = int(n_steps * (batch_size_seen if batch_size_seen is not None else 0))
#     with suppress(Exception):
#         lr_G = opt_G.param_groups[0]['lr']
#     with suppress(Exception):
#         lr_D = opt_D.param_groups[0]['lr']
#     lr_G = lr_G if 'lr_G' in locals() else float('nan')
#     lr_D = lr_D if 'lr_D' in locals() else float('nan')
#     with open(TRAIN_CSV, "a", encoding="utf-8") as f:
#         f.write(",".join(map(str, [
#             ep, n_steps, (batch_size_seen or 0), imgs_seen,
#             round(epoch_time, 4),
#             round(imgs_seen/max(1e-8, epoch_time), 3),
#             lr_G, lr_D,
#             loss_means["G"], loss_means["D"], loss_means["adv"], loss_means["perc"], loss_means["id"]
#         ])) + "\n")

#     # ---- A2: write domain step counts ----
#     with open(DOMCOUNT_CSV, "a", encoding="utf-8") as f:
#         for dname, steps in sorted(dom_counter.items()):
#             f.write(f"{ep},{dname},{steps}\n")

#     # ---- A3: update best epochs per domain (by PSNR & SSIM) ----
#     for dom, m in dom_metrics.items():
#         cur = _best.get(dom, {"psnr": -1, "ssim": -1})
#         # PSNR
#         if m.get("psnr", -1) > cur.get("psnr", -1):
#             cur["psnr"] = m["psnr"]
#             cur["psnr_epoch"] = ep
#             cur["ckpt_psnr"] = ckpt_path
#         # SSIM
#         if m.get("ssim", -1) > cur.get("ssim", -1):
#             cur["ssim"] = m["ssim"]
#             cur["ssim_epoch"] = ep
#             cur["ckpt_ssim"] = ckpt_path
#         _best[dom] = cur

#     # flush best table after each epoch
#     with open(BEST_CSV, "w", encoding="utf-8") as f:
#         f.write("domain,best_psnr_epoch,best_psnr,best_ssim_epoch,best_ssim,ckpt_at_best_psnr,ckpt_at_best_ssim\n")
#         for dom, cur in sorted(_best.items()):
#             f.write(",".join(map(str, [
#                 dom,
#                 cur.get("psnr_epoch",""),
#                 cur.get("psnr",""),
#                 cur.get("ssim_epoch",""),
#                 cur.get("ssim",""),
#                 cur.get("ckpt_psnr",""),
#                 cur.get("ckpt_ssim",""),
#             ])) + "\n")

# print("✅ Stage A done.")
# print(f"Checkpoints: {CKPT_ROOT}/restormer_mixed_epXXX.pth (+ restormer_mixed_latest.pth)")
# print(f"Previews:    {SAMPLES_ROOT}/<domain>/")
# print(f"Metrics CSV: {METRICS_CSV}")
# print(f"A1: Train trace: {TRAIN_CSV}")
# print(f"A2: Domain steps: {DOMCOUNT_CSV}")
# print(f"A3: Best epochs:  {BEST_CSV}")


In [None]:
# # ==== Stage A: Train one common model across all degradations (metrics + per-epoch ckpts) ====
# import os, cv2, time, json, numpy as np
# from itertools import cycle
# from tqdm import tqdm

# import torch
# import torch.nn.functional as F

# # ---------- Paths ----------
# SAMPLES_ROOT = "/kaggle/working/samples_mixed"
# CKPT_ROOT    = "/kaggle/working/checkpoints_mixed"
# LOG_ROOT     = "/kaggle/working/logs_mixed"
# os.makedirs(SAMPLES_ROOT, exist_ok=True)
# os.makedirs(CKPT_ROOT,    exist_ok=True)
# os.makedirs(LOG_ROOT,     exist_ok=True)

# METRICS_CSV = os.path.join(LOG_ROOT, "metrics_summary.csv")

# # ---------- Metrics helpers ----------
# _have_brisque = False
# _have_niqe = False
# try:
#     from skimage.metrics import peak_signal_noise_ratio as _psnr
#     from skimage.metrics import structural_similarity as _ssim
#     try:
#         from skimage.metrics import niqe as _niqe
#         _have_niqe = True
#     except Exception:
#         _have_niqe = False
#     try:
#         from skimage.metrics import brisque as _brisque
#         _have_brisque = True
#     except Exception:
#         _have_brisque = False
# except Exception as e:
#     raise RuntimeError(
#         "scikit-image not available; run `pip install -q scikit-image` in a Kaggle cell."
#     ) from e

# def _ssim_safe(img, ref):
#     """Handle skimage API differences (channel_axis vs multichannel)."""
#     try:
#         return _ssim(ref, img, data_range=255, channel_axis=2)
#     except TypeError:
#         return _ssim(ref, img, data_range=255, multichannel=True)

# def clamp01(x): 
#     return torch.clamp(x, 0, 1)

# def rr_batches(loaders_dict):
#     iters = {k: cycle(v) for k, v in loaders_dict.items()}
#     while True:
#         for k in loaders_dict.keys():
#             yield k, next(iters[k])

# def init_metrics_csv(path):
#     if not os.path.exists(path):
#         cols = ["epoch","domain","psnr","ssim"]
#         if _have_niqe: cols.append("niqe")
#         if _have_brisque: cols.append("brisque")
#         cols += ["mean_G_loss","mean_D_loss","mean_adv","mean_perc","mean_id"]
#         with open(path, "w", encoding="utf-8") as f:
#             f.write(",".join(cols) + "\n")

# def write_metrics_row(path, epoch, domain, metrics, loss_means):
#     cols = [
#         str(epoch),
#         domain,
#         f"{metrics.get('psnr', np.nan)}",
#         f"{metrics.get('ssim', np.nan)}",
#     ]
#     if _have_niqe:    cols.append(f"{metrics.get('niqe', np.nan)}")
#     if _have_brisque: cols.append(f"{metrics.get('brisque', np.nan)}")
#     cols += [
#         f"{loss_means['G']}",
#         f"{loss_means['D']}",
#         f"{loss_means['adv']}",
#         f"{loss_means['perc']}",
#         f"{loss_means['id']}",
#     ]
#     with open(path, "a", encoding="utf-8") as f:
#         f.write(",".join(cols) + "\n")

# def compute_metrics_batch(fake_t, clean_t):
#     """
#     fake_t, clean_t: tensors in [0,1], shape [B,3,H,W]
#     Returns dict of mean metrics over the batch.
#     """
#     fake = (fake_t.detach().clamp(0,1).cpu().permute(0,2,3,1).numpy() * 255).astype(np.uint8)
#     gt   = (clean_t.detach().clamp(0,1).cpu().permute(0,2,3,1).numpy() * 255).astype(np.uint8)

#     psnr_vals, ssim_vals, niqe_vals, brisque_vals = [], [], [], []
#     for f, r in zip(fake, gt):
#         try: psnr_vals.append(_psnr(r, f, data_range=255))
#         except: pass
#         try: ssim_vals.append(_ssim_safe(f, r))
#         except: pass
#         if _have_niqe:
#             try:
#                 g = cv2.cvtColor(f, cv2.COLOR_RGB2GRAY).astype(np.float64) / 255.0
#                 niqe_vals.append(_niqe(g))
#             except: pass
#         if _have_brisque:
#             try:
#                 g = cv2.cvtColor(f, cv2.COLOR_RGB2GRAY).astype(np.float64) / 255.0
#                 brisque_vals.append(_brisque(g))
#             except: pass

#     out = {
#         "psnr": float(np.mean(psnr_vals)) if psnr_vals else np.nan,
#         "ssim": float(np.mean(ssim_vals)) if ssim_vals else np.nan,
#     }
#     if _have_niqe:    out["niqe"] = float(np.mean(niqe_vals)) if niqe_vals else np.nan
#     if _have_brisque: out["brisque"] = float(np.mean(brisque_vals)) if brisque_vals else np.nan
#     return out

# @torch.no_grad()
# def save_previews_by_domain(epoch):
#     model.eval()
#     for dom, vloader in val_loaders.items():
#         outdir = os.path.join(SAMPLES_ROOT, dom)
#         os.makedirs(outdir, exist_ok=True)
#         cnt = 0
#         for d, c, name in vloader:
#             d, c = d.to(device), c.to(device)
#             o = clamp01(model(d))
#             grid = torch.cat([d[:1], o[:1], c[:1]], dim=3)  # [B,3,H,3W]
#             arr = (grid[0].permute(1,2,0).cpu().numpy()*255).astype(np.uint8)
#             cv2.imwrite(
#                 os.path.join(outdir, f"ep{epoch:03d}_{dom}_{name[0]}.jpg"),
#                 cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
#             )
#             cnt += 1
#             if cnt >= 3:  # a few previews per domain
#                 break
#     model.train()

# @torch.no_grad()
# def evaluate_domains(epoch):
#     """
#     Evaluate model on validation sets per domain.
#     Returns { domain: {psnr, ssim, (niqe), (brisque)} } averaged over the loader.
#     """
#     model.eval()
#     results = {}
#     for dom, vloader in val_loaders.items():
#         vals = {"psnr": [], "ssim": []}
#         if _have_niqe: vals["niqe"] = []
#         if _have_brisque: vals["brisque"] = []

#         for d, c, _ in vloader:
#             d, c = d.to(device), c.to(device)
#             o = clamp01(model(d))
#             m = compute_metrics_batch(o, c)
#             for k, v in m.items():
#                 vals[k].append(v)

#         results[dom] = {k: (float(np.nanmean(v)) if len(v) else np.nan) for k, v in vals.items()}
#     model.train()
#     return results

# # ---------- Train ----------
# steps_per_epoch = min(len(dl) for dl in train_loaders.values()) * len(train_loaders)
# print("steps_per_epoch =", steps_per_epoch)

# scalerG = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))
# scalerD = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

# init_metrics_csv(METRICS_CSV)

# EPOCHS = 25  # keep as you set
# rr = rr_batches(train_loaders)

# for ep in range(1, EPOCHS+1):
#     pbar = tqdm(range(steps_per_epoch), desc=f"Epoch {ep}/{EPOCHS}")
#     model.train(); disc.train()

#     # running sums for loss means
#     sum_G = sum_D = 0.0
#     sum_adv = sum_perc = sum_id = 0.0
#     n_steps = 0

#     for _ in pbar:
#         domain, batch = next(rr)
#         degraded, clean, _ = batch
#         degraded = degraded.to(device, non_blocking=True)
#         clean    = clean.to(device, non_blocking=True)

#         # --- G step ---
#         opt_G.zero_grad(set_to_none=True)
#         with torch.cuda.amp.autocast(enabled=(device=="cuda")):
#             fake = model(degraded)
#             loss_g_adv  = adv_loss(disc(fake), True) * L_ADV
#             loss_g_perc = l1(vgg(fake), vgg(clean)) * L_PERC
#             loss_g_id   = l1(model(clean), clean) * L_ID
#             loss_g = loss_g_adv + loss_g_perc + loss_g_id
#         scalerG.scale(loss_g).backward()
#         scalerG.step(opt_G); scalerG.update()

#         # --- D step ---
#         opt_D.zero_grad(set_to_none=True)
#         with torch.cuda.amp.autocast(enabled=(device=="cuda")):
#             r = disc(clean); f = disc(fake.detach())
#             loss_d = 0.5*(adv_loss(r, True) + adv_loss(f, False))
#         scalerD.scale(loss_d).backward()
#         scalerD.step(opt_D); scalerD.update()

#         # running logs
#         sum_G   += float(loss_g.item())
#         sum_D   += float(loss_d.item())
#         sum_adv += float(loss_g_adv.item())
#         sum_perc+= float(loss_g_perc.item())
#         sum_id  += float(loss_g_id.item())
#         n_steps += 1

#         pbar.set_postfix(G=float(loss_g.item()), D=float(loss_d.item()), dom=domain)

#     # mean losses for the epoch
#     denom = max(1, n_steps)
#     loss_means = {
#         "G":   round(sum_G/denom, 6),
#         "D":   round(sum_D/denom, 6),
#         "adv": round(sum_adv/denom, 6),
#         "perc":round(sum_perc/denom, 6),
#         "id":  round(sum_id/denom, 6),
#     }

#     # ---- save checkpoint (separate per epoch) ----
#     ckpt_path = os.path.join(CKPT_ROOT, f"restormer_mixed_ep{ep:03d}.pth")
#     torch.save(model.state_dict(), ckpt_path)
#     # also keep a latest pointer
#     torch.save(model.state_dict(), os.path.join(CKPT_ROOT, "restormer_mixed_latest.pth"))

#     # ---- previews + metrics per domain ----
#     save_previews_by_domain(ep)
#     dom_metrics = evaluate_domains(ep)
#     for dom, metrics in dom_metrics.items():
#         write_metrics_row(METRICS_CSV, ep, dom, metrics, loss_means)

# print("✅ Stage A done.")
# print(f"Checkpoints: {CKPT_ROOT}/restormer_mixed_epXXX.pth (+ restormer_mixed_latest.pth)")
# print(f"Previews:    {SAMPLES_ROOT}/<domain>/")
# print(f"Metrics CSV: {METRICS_CSV}")


In [None]:
# ==== Zip all generated outputs ====
import os, shutil
import zipfile
from datetime import datetime

OUTPUT_ROOT = "/kaggle/working"
ZIP_NAME = f"restormer_outputs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
ZIP_PATH = os.path.join(OUTPUT_ROOT, ZIP_NAME)

# Include these paths (relative to OUTPUT_ROOT) if they exist
INCLUDE = [
    "samples_mixed",            # previews
    "checkpoints_mixed",        # per-epoch .pth
    "sample_metrics.csv",       # PSNR/SSIM for preview samples
    "restormer_mixed_base.pth", # last-epoch convenience checkpoint
]


def zip_outputs(paths, root, zip_path):
    """Write the given files/directories into a new zip archive.

    Only the entries in ``paths`` are archived, and directories are added
    recursively, so the archive matches the manifest printed below. Member names
    are stored relative to ``root``. For example, ``<root>/samples_mixed/rain/x.jpg``
    becomes ``samples_mixed/rain/x.jpg``, so the folder layout survives unzipping.

    Args:
        paths: existing file or directory paths located under ``root``.
        root: directory that member names are made relative to.
        zip_path: output archive path; overwritten if it already exists.

    Returns:
        Sorted list of member names written to the archive.
    """
    zip_abs = os.path.abspath(zip_path)
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:

        def _add(file_path):
            # Skip the archive currently being written (reachable if a listed
            # directory contains zip_path).
            if os.path.abspath(file_path) == zip_abs:
                return
            zf.write(file_path, os.path.relpath(file_path, root))

        for p in paths:
            if os.path.isdir(p):
                for dirpath, dirnames, filenames in os.walk(p):
                    dirnames.sort()  # deterministic traversal order
                    for fn in sorted(filenames):
                        _add(os.path.join(dirpath, fn))
            else:
                _add(p)
        return sorted(zf.namelist())


paths = [os.path.join(OUTPUT_ROOT, p) for p in INCLUDE if os.path.exists(os.path.join(OUTPUT_ROOT, p))]

if not paths:
    print("⚠️ No generated files found to zip.")
else:
    print("Adding to archive:")
    for p in paths:
        print("  -", p)
    # Archive only the INCLUDE entries. Zipping all of OUTPUT_ROOT would also pull
    # in the Restormer clone and any earlier restormer_outputs_*.zip archives.
    members = zip_outputs(paths, OUTPUT_ROOT, ZIP_PATH)
    print(f"\n✅ Zipped to: {ZIP_PATH} ({len(members)} files)")

In [None]:
# ==== Graphs from sample_metrics.csv ====
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

METRICS_CSV = "/kaggle/working/sample_metrics.csv"
assert os.path.exists(METRICS_CSV), "sample_metrics.csv not found. Run training cell first."

# Per-image rows; expected columns: epoch, domain, name, psnr, ssim
df = pd.read_csv(METRICS_CSV)
df["epoch"] = df["epoch"].astype(int)

# One row per (epoch, domain): mean/std of each metric plus the number of images
agg = (
    df.groupby(["epoch", "domain"], as_index=False)
      .agg(
          psnr_mean=("psnr", "mean"),
          psnr_std=("psnr", "std"),
          ssim_mean=("ssim", "mean"),
          ssim_std=("ssim", "std"),
          n=("name", "count"),
      )
)


def plot_metric_curves(table, metric, label, unit_label):
    """Show one `<metric>_mean`-vs-epoch line chart per domain (domains in sorted order)."""
    for domain in sorted(table["domain"].unique()):
        rows = table[table["domain"] == domain]
        fig, ax = plt.subplots()
        ax.plot(rows["epoch"], rows[f"{metric}_mean"], marker="o")
        ax.set_title(f"{label} vs Epoch — {domain}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(unit_label)
        ax.grid(True)
        plt.show()


def plot_metric_bars(rows, metric, label, unit_label, epoch):
    """Show a per-domain bar chart of `<metric>_mean` with `<metric>_std` error bars."""
    fig, ax = plt.subplots()
    ax.bar(rows["domain"], rows[f"{metric}_mean"], yerr=rows[f"{metric}_std"], capsize=3)
    ax.set_title(f"{label} by Domain — Last Epoch {epoch}")
    ax.set_xlabel("Domain")
    ax.set_ylabel(unit_label)
    ax.tick_params(axis="x", labelrotation=20)
    ax.grid(True, axis="y")
    plt.show()


# 1) + 2) Per-domain learning curves: every PSNR chart first, then every SSIM chart
plot_metric_curves(agg, "psnr", "PSNR", "PSNR (dB)")
plot_metric_curves(agg, "ssim", "SSIM", "SSIM")

# 3) Last-epoch bar charts (mean ± std) per domain
last_ep = agg["epoch"].max()
last = agg.loc[agg["epoch"] == last_ep].copy().sort_values("domain")

plot_metric_bars(last, "psnr", "PSNR", "PSNR (dB)", last_ep)
plot_metric_bars(last, "ssim", "SSIM", "SSIM", last_ep)

In [None]:
# ==== Present results: tables and quick summaries ====
import os
import pandas as pd
import numpy as np

METRICS_CSV = "/kaggle/working/sample_metrics.csv"
assert os.path.exists(METRICS_CSV), "sample_metrics.csv not found."

df = pd.read_csv(METRICS_CSV)  # epoch,domain,name,psnr,ssim
df["epoch"] = df["epoch"].astype(int)

# Mean/std per (epoch, domain)
agg = df.groupby(["epoch","domain"], as_index=False).agg(
    psnr_mean=("psnr","mean"),
    psnr_std =("psnr","std"),
    ssim_mean=("ssim","mean"),
    ssim_std =("ssim","std"),
    n=("name","count")
).sort_values(["domain","epoch"])


def best_epoch_index(table, score_col):
    """Return the index label of the highest-`score_col` row for each domain.

    Rows with a NaN score are dropped first. A domain whose scores are all NaN is
    therefore omitted instead of yielding a NaN label, which would make
    `table.loc[...]` fail; newer pandas also deprecates or raises on all-NA idxmax.
    Ties resolve to the first row in `table`'s current order (idxmax semantics).
    """
    scored = table.dropna(subset=[score_col])
    return scored.groupby("domain")[score_col].idxmax()


# 1) Best epoch per domain (by PSNR and by SSIM)
idx_psnr = best_epoch_index(agg, "psnr_mean")
idx_ssim = best_epoch_index(agg, "ssim_mean")
best_psnr = agg.loc[idx_psnr].reset_index(drop=True)
best_ssim = agg.loc[idx_ssim].reset_index(drop=True)

# Make any domain dropped for lack of valid scores visible instead of silently missing
for metric_label, best in (("PSNR", best_psnr), ("SSIM", best_ssim)):
    skipped = sorted(set(agg["domain"]) - set(best["domain"]))
    if skipped:
        print(f"⚠️ No valid {metric_label} values for domain(s) {skipped}; omitted from best-epoch table.")

print("=== Best Epoch per Domain (by PSNR) ===")
display(best_psnr[["domain","epoch","psnr_mean","psnr_std","n"]])

print("=== Best Epoch per Domain (by SSIM) ===")
display(best_ssim[["domain","epoch","ssim_mean","ssim_std","n"]])

# 2) Last epoch summary (mean ± std) per domain
last_ep = agg["epoch"].max()
last = agg[agg["epoch"] == last_ep].copy().sort_values("domain")

# Readable columns for a report; a single-image group has NaN std, shown as 0
last["PSNR (mean±std)"] = last["psnr_mean"].round(3).astype(str) + " ± " + last["psnr_std"].fillna(0).round(3).astype(str)
last["SSIM (mean±std)"] = last["ssim_mean"].round(4).astype(str) + " ± " + last["ssim_std"].fillna(0).round(4).astype(str)

print(f"=== Last Epoch Summary (Epoch {last_ep}) ===")
display(last[["domain","PSNR (mean±std)","SSIM (mean±std)","n"]])

# 3) Save tidy CSVs for report attachments; report the paths actually written
OUT_DIR = "/kaggle/working"
saved_paths = []
for fname, table in (
    ("sample_metrics_aggregated.csv", agg),
    ("best_epochs_by_psnr.csv", best_psnr),
    ("best_epochs_by_ssim.csv", best_ssim),
):
    out_path = os.path.join(OUT_DIR, fname)
    table.to_csv(out_path, index=False)
    saved_paths.append(out_path)

print("\nSaved:")
for out_path in saved_paths:
    print(" -", out_path)