In [82]:
!nvidia-smi
!pip -q install --upgrade pip
!pip -q install basicsr timm einops gdown imageio opencv-python albumentations matplotlib torchvision numpy==1.26.4 scipy==1.13.1

# Fresh clone
!rm -rf Restormer
!git clone https://github.com/swz30/Restormer.git
%cd Restormer

Fri Nov 28 10:55:55 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off |   00000000:00:04.0 Off |                    0 |
| N/A   38C    P0             33W /  250W |    8337MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [83]:
import numpy as np
import scipy

print("NumPy version:", np.__version__)
print("SciPy version:", scipy.__version__)

NumPy version: 1.26.4
SciPy version: 1.13.1


In [84]:
# ---- TorchVision shim for legacy imports (functional_tensor) ----
import sys, types
import torch, torchvision
print("Torch:", torch.__version__, "| TorchVision:", torchvision.__version__)

from torchvision.transforms import functional as F
ft_mod = types.ModuleType("torchvision.transforms.functional_tensor")
for k, v in F.__dict__.items():
    setattr(ft_mod, k, v)
sys.modules["torchvision.transforms.functional_tensor"] = ft_mod
print("Shim installed: torchvision.transforms.functional_tensor ‚Üí .functional ‚úÖ")

Torch: 2.6.0+cu124 | TorchVision: 0.21.0+cu124
Shim installed: torchvision.transforms.functional_tensor ‚Üí .functional ‚úÖ


In [85]:
# ‚úÖ Run this as a shell cell (notice the ! at the start of each command)
CKPT_URL="https://github.com/swz30/Restormer/releases/download/v1.0/deraining.pth"

!mkdir -p ./checkpoints
!rm -f ./checkpoints/pretrained_task.pth
!curl -L "$CKPT_URL" -o ./checkpoints/pretrained_task.pth
!ls -lh ./checkpoints

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 99.8M  100 99.8M    0     0   113M      0 --:--:-- --:--:-- --:--:--  154M
total 100M
-rw-r--r-- 1 root root 100M Nov 28 10:56 pretrained_task.pth


In [86]:
import os, torch
CKPT_PATH = "./checkpoints/pretrained_task.pth"
print("File size (MB):", os.path.getsize(CKPT_PATH)/1e6)
# Should be roughly > 50 MB
assert os.path.getsize(CKPT_PATH) > 10_000_000, "Checkpoint looks incomplete!"

# Quick load test
try:
    _ = torch.load(CKPT_PATH, map_location="cpu")
    print("‚úÖ torch.load works fine ‚Äî checkpoint OK!")
except Exception as e:
    print("‚ùå Problem loading checkpoint:", e)

File size (MB): 104.700429
‚úÖ torch.load works fine ‚Äî checkpoint OK!


In [87]:
# ==== RECOVERY: make sure Restormer is present and importable ====
import os, sys, subprocess, shutil, glob, importlib.util
from pathlib import Path

REPO_URL  = "https://github.com/swz30/Restormer.git"
REPO_DIR  = Path("/kaggle/working/Restormer")   # fixed absolute path

# 1) Fresh clone if missing or empty
def is_dir_empty(p: Path):
    return (not p.exists()) or (p.exists() and len(list(p.rglob("*"))) == 0)

if is_dir_empty(REPO_DIR):
    if REPO_DIR.exists():
        shutil.rmtree(REPO_DIR, ignore_errors=True)
    print("[i] Cloning Restormer repo...")
    subprocess.check_call(["git", "clone", "--depth", "1", REPO_URL, str(REPO_DIR)])
else:
    print("[i] Restormer repo already present at", REPO_DIR)

# 2) Put repo on sys.path so 'basicsr' (inside repo) is importable
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))
if str(REPO_DIR.parent) not in sys.path:
    sys.path.insert(0, str(REPO_DIR.parent))

# 3) Try canonical import; if it fails, import by file path
Restormer = None
try:
    from basicsr.models.archs.restormer_arch import Restormer  # type: ignore
    print("[‚úì] Imported Restormer from basicsr.models.archs.restormer_arch")
except Exception as e:
    print("[!] Canonical import failed:", e)
    cand = list(REPO_DIR.rglob("restormer_arch.py"))
    assert cand, "restormer_arch.py not found under repo"
    mod_path = cand[0]
    spec = importlib.util.spec_from_file_location("restormer_local", str(mod_path))
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # type: ignore
    Restormer = getattr(mod, "Restormer")
    print(f"[‚úì] Imported Restormer from file: {mod_path}")

[i] Restormer repo already present at /kaggle/working/Restormer
[!] Canonical import failed: No module named 'basicsr.models.archs'
[‚úì] Imported Restormer from file: /kaggle/working/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/basicsr/models/archs/restormer_arch.py


In [88]:
# === ONE-CELL SAFE LOADER FOR RESTORMER + CHECKPOINT ===
import sys, types, importlib.util, re
from pathlib import Path
from collections import OrderedDict
import torch

# ---- Shim for legacy torchvision import used by some basicsr code ----
try:
    from torchvision.transforms import functional as F
    ft_mod = types.ModuleType("torchvision.transforms.functional_tensor")
    for k, v in F.__dict__.items():
        setattr(ft_mod, k, v)
    sys.modules["torchvision.transforms.functional_tensor"] = ft_mod
    print("Shim ok: torchvision.transforms.functional_tensor -> .functional")
except Exception as e:
    print("Shim skipped:", e)

# ---- Robust import of Restormer (canonical or file-path fallback) ----
repo_root = "/kaggle/working/Restormer"
if repo_root not in sys.path:
    sys.path.append(repo_root)

Restormer = None
try:
    from basicsr.models.archs.restormer_arch import Restormer  # type: ignore
    print("Imported Restormer from basicsr.models.archs.restormer_arch")
except Exception as e:
    print("Canonical import failed:", e)
    candidates = list(Path(repo_root).rglob("*restormer*.py"))
    assert candidates, "Restormer source not found under repo_root"
    p = str(candidates[0])
    spec = importlib.util.spec_from_file_location("restormer_local", p)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # type: ignore
    Restormer = getattr(mod, "Restormer")
    print("Imported Restormer from file:", p)

# ---- Load checkpoint and pick the correct inner state dict ----
ckpt_path = "./checkpoints/pretrained_task.pth"
state = torch.load(ckpt_path, map_location="cpu")

if isinstance(state, dict):
    if isinstance(state.get("params_ema"), dict):
        sd = state["params_ema"]; print("Using checkpoint['params_ema']")
    elif isinstance(state.get("params"), dict):
        sd = state["params"];     print("Using checkpoint['params']")
    elif isinstance(state.get("state_dict"), dict):
        sd = state["state_dict"]; print("Using checkpoint['state_dict']")
    elif isinstance(state.get("model"), dict):
        sd = state["model"];      print("Using checkpoint['model']")
    else:
        sd = state;               print("Using checkpoint as-is (flat dict)")
else:
    sd = state;                   print("Using checkpoint as-is (non-dict)")

# strip DDP prefix if present
if any(k.startswith("module.") for k in sd.keys()):
    sd = OrderedDict((re.sub(r"^module\.", "", k), v) for k, v in sd.items())
    print("Stripped 'module.' prefixes")

def build_model(layernorm_type: str):
    return Restormer(
        inp_channels=3, out_channels=3,
        dim=48,
        num_blocks=[4,6,6,8],
        num_refinement_blocks=4,
        heads=[1,2,4,8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type=layernorm_type,  # 'WithBias' or 'BiasFree'
        dual_pixel_task=False
    )

def try_load(layernorm_type: str):
    m = build_model(layernorm_type)
    missing, unexpected = m.load_state_dict(sd, strict=False)
    print(f"[{layernorm_type}] missing: {len(missing)} | unexpected: {len(unexpected)}")
    if missing:   print("  sample missing:", missing[:5])
    if unexpected:print("  sample unexpected:", unexpected[:5])
    return m, missing, unexpected

# Try BiasFree first, then WithBias
model, missing, unexpected = try_load("BiasFree")
if missing or unexpected:
    print("Retrying WithBias ‚Ä¶")
    model, missing, unexpected = try_load("WithBias")

print("\nFinal -> Loaded with strict=False")
print("missing:", len(missing), "unexpected:", len(unexpected))

Shim ok: torchvision.transforms.functional_tensor -> .functional
Canonical import failed: No module named 'basicsr.models.archs'
Imported Restormer from file: /kaggle/working/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/Restormer/basicsr/models/archs/restormer_arch.py
Using checkpoint['params']
[BiasFree] missing: 0 | unexpected: 88
  sample unexpected: ['encoder_level1.0.norm1.body.bias', 'encoder_level1.0.norm2.body.bias', 'encoder_level1.1.norm1.body.bias', 'encoder_level1.1.norm2.body.bias', 'encoder_level1.2.norm1.body.bias']
Retrying WithBias ‚Ä¶
[WithBias] missing: 0 | unexpected: 0

Final -> Loaded with strict=False
missing: 0 unexpected: 0


In [89]:
# freeze everything
for p in model.parameters(): p.requires_grad = False

# unfreeze late/refinement + output head (adjust names if needed)
to_unfreeze = []
for n, p in model.named_parameters():
    if any(k in n.lower() for k in ["refinement", "reconstruct", "reconstruction", "conv_out", "tail"]):
        p.requires_grad = True
        to_unfreeze.append(n)

# fallback: if nothing matched, unfreeze last ~10 params
if not to_unfreeze:
    for n,p in list(model.named_parameters())[-10:]:
        p.requires_grad = True
        to_unfreeze.append(n)

print("Unfreezing:", *to_unfreeze[:8], "...", sep="\n")

import torch
trainable = [p for p in model.parameters() if p.requires_grad]
opt_G = torch.optim.Adam(trainable, lr=1e-5, betas=(0.9,0.999))

Unfreezing:
refinement.0.norm1.body.weight
refinement.0.norm1.body.bias
refinement.0.attn.temperature
refinement.0.attn.qkv.weight
refinement.0.attn.qkv_dwconv.weight
refinement.0.attn.project_out.weight
refinement.0.norm2.body.weight
refinement.0.norm2.body.bias
...


In [90]:
import torch.nn as nn, torchvision

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).train()

# perceptual VGG features
vgg = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1).features[:16].eval().to(device)
for p in vgg.parameters(): p.requires_grad = False
l1, mse = nn.L1Loss(), nn.MSELoss()

# tiny PatchGAN
class PatchD(nn.Module):
    def __init__(self, in_ch=3, base=64):
        super().__init__()
        def blk(ic, oc, norm=True):
            m=[nn.Conv2d(ic, oc, 4, 2, 1)]
            if norm: m+=[nn.InstanceNorm2d(oc, affine=True)]
            m+=[nn.LeakyReLU(0.2, inplace=True)]
            return nn.Sequential(*m)
        self.net = nn.Sequential(
            blk(in_ch, base, norm=False),
            blk(base, base*2),
            blk(base*2, base*4),
            nn.Conv2d(base*4, 1, 3, 1, 1)
        )
    def forward(self, x): return self.net(x)

disc = PatchD().to(device).train()
opt_D = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5,0.999))

def adv_loss(pred, is_real): 
    tgt = torch.ones_like(pred) if is_real else torch.zeros_like(pred)
    return mse(pred, tgt)

L_ADV, L_PERC, L_ID = 1.0, 0.1, 0.1

In [91]:
# ==== CELL 10 (MODIFIED): Single-Domain Data Loaders ====
import os, glob, random, cv2, torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import albumentations as A

BASE_DIR = "/kaggle/input/claude-synthesis-new"

DATA_ROOTS = {
    "clean":     os.path.join(BASE_DIR, "clean"),
    "fog":       os.path.join(BASE_DIR, "fog"),
    "rain":      os.path.join(BASE_DIR, "rain"),
    "lowlight":  os.path.join(BASE_DIR, "lowlight"),
}

DEG_DOMAINS = [d for d in DATA_ROOTS.keys() if d != "clean"]
CLEAN_DIR   = DATA_ROOTS["clean"]

def list_images(path):
    exts = ("*.jpg","*.jpeg","*.png","*.bmp")
    files = []
    for e in exts:
        files += glob.glob(os.path.join(path, e))
        files += glob.glob(os.path.join(path, "**", e), recursive=True)
    if not files:
        raise FileNotFoundError(f"No images found in {path}")
    return sorted(files)

# MODIFIED: Simpler augmentations for EDAR training
train_tf = A.Compose([
    A.LongestMaxSize(max_size=384),
    A.PadIfNeeded(min_height=384, min_width=384, border_mode=cv2.BORDER_REFLECT_101),
    A.RandomCrop(256, 256),
    A.HorizontalFlip(p=0.5),
    # REMOVED: RandomBrightnessContrast (let EDAR learn from clean data)
])

val_tf = A.Compose([
    A.LongestMaxSize(max_size=512),
    A.PadIfNeeded(min_height=512, min_width=512, border_mode=cv2.BORDER_REFLECT_101),
])

def to_tensor(img):
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.ascontiguousarray(img).astype(np.float32)/255.0
    return torch.from_numpy(np.transpose(img,(2,0,1)))

# MODIFIED: Paired dataset (degraded -> clean pairs, not random)
class PairedDataset(Dataset):
    """Returns paired (degraded, clean) samples"""
    def __init__(self, deg_files, cln_files, transform=None, paired=True):
        self.deg_files = deg_files
        self.cln_files = cln_files
        self.tf = transform
        self.paired = paired
        
        # For paired training, ensure same count
        if paired:
            min_len = min(len(deg_files), len(cln_files))
            self.deg_files = deg_files[:min_len]
            self.cln_files = cln_files[:min_len]
    
    def __len__(self): 
        return len(self.deg_files)
    
    def __getitem__(self, idx):
        dimg = cv2.imread(self.deg_files[idx])
        
        if self.paired:
            cimg = cv2.imread(self.cln_files[idx])  # Same index for paired
        else:
            cimg = cv2.imread(random.choice(self.cln_files))  # Random for unpaired
        
        if self.tf:
            # Apply same augmentation to both (paired consistency)
            seed = np.random.randint(0, 999999)
            random.seed(seed); np.random.seed(seed)
            dimg = self.tf(image=dimg)['image']
            random.seed(seed); np.random.seed(seed)
            cimg = self.tf(image=cimg)['image']
        
        return to_tensor(dimg), to_tensor(cimg), os.path.basename(self.deg_files[idx])

def build_loader_for(domain, batch_size=2, train=True, paired=False):
    """
    Build dataloader for single domain
    
    Args:
        domain: 'fog', 'rain', or 'lowlight'
        batch_size: batch size
        train: training or validation mode
        paired: if True, use paired data (same index); if False, random clean images
    """
    deg_dir = DATA_ROOTS[domain]
    deg_files = list_images(deg_dir)
    cln_files = list_images(CLEAN_DIR)
    
    tf = train_tf if train else val_tf
    ds = PairedDataset(deg_files, cln_files, transform=tf, paired=paired)
    
    return DataLoader(
        ds, 
        batch_size=batch_size, 
        shuffle=train, 
        num_workers=2, 
        pin_memory=True, 
        drop_last=train
    ), len(deg_files)

# Build loaders for all domains
train_loaders = {}
val_loaders   = {}
counts = {}

for d in DEG_DOMAINS:
    # TRAIN: Unpaired (random clean images) - better for unlabeled scenario
    tl, n = build_loader_for(d, batch_size=4, train=True, paired=False)
    
    # VAL: Use all samples, no pairing needed
    vl, _ = build_loader_for(d, batch_size=1, train=False, paired=False)
    
    train_loaders[d] = tl
    val_loaders[d]   = vl
    counts[d] = n

print("‚úÖ Data loaders created (EDAR optimized)")
print("\nDomain image counts:")
for d, n in counts.items():
    print(f"  {d:12s}: {n}")

print("\nTrain loader batches per domain:")
for d, dl in train_loaders.items():
    print(f"  {d:12s}: {len(dl)} batches")

print("\nVal samples per domain:")
for d, dl in val_loaders.items():
    print(f"  {d:12s}: {len(dl.dataset)} samples")

‚úÖ Data loaders created (EDAR optimized)

Domain image counts:
  fog         : 1000
  rain        : 1000
  lowlight    : 1000

Train loader batches per domain:
  fog         : 250 batches
  rain        : 250 batches
  lowlight    : 250 batches

Val samples per domain:
  fog         : 1000 samples
  rain        : 1000 samples
  lowlight    : 1000 samples


In [92]:
# ============================================================
# CELL 13: EDAR MODULES (Edge-Aware Domain-Adaptive Restormer)
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F

# -------- 1. EDGE EXTRACTOR --------
class FastEdgeExtractor(nn.Module):
    """Sobel edge detection - fast and differentiable"""
    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        self.register_buffer('sobel_x', sobel_x.view(1, 1, 3, 3))
        self.register_buffer('sobel_y', sobel_y.view(1, 1, 3, 3))
        
    def forward(self, x):
        """x: [B, 3, H, W] -> edges: [B, 1, H, W]"""
        gray = 0.299*x[:,0:1] + 0.587*x[:,1:2] + 0.114*x[:,2:3]
        grad_x = F.conv2d(gray, self.sobel_x, padding=1)
        grad_y = F.conv2d(gray, self.sobel_y, padding=1)
        edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6)
        return edges


# -------- 2. DOMAIN ENCODER --------
class SimpleDomainEncoder(nn.Module):
    """Maps domain name to learnable embedding"""
    def __init__(self, n_domains=3, embed_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(n_domains, embed_dim)
        self.domain_map = {'fog': 0, 'rain': 1, 'lowlight': 2}
        
    def forward(self, domain_name):
        """domain_name: str -> [1, embed_dim, 1, 1]"""
        idx = torch.tensor([self.domain_map[domain_name]], device=self.embedding.weight.device)
        emb = self.embedding(idx)
        return emb[:, :, None, None]


# -------- 3. EDGE-GUIDED REFINEMENT --------
class LightweightEdgeRefinement(nn.Module):
    """Injects edge info into features"""
    def __init__(self, in_channels=96):
        super().__init__()
        self.edge_process = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, 1, 1, bias=False),
            nn.ReLU(inplace=True)
        )
        self.gate = nn.Sequential(
            nn.Conv2d(in_channels + 32, 32, 1, bias=False),
            nn.Sigmoid()
        )
        self.fuse = nn.Conv2d(in_channels + 32, in_channels, 1, bias=False)
        
    def forward(self, feat, edge_map):
        """feat: [B, C, H, W], edge_map: [B, 1, H, W]"""
        if edge_map.shape[2:] != feat.shape[2:]:
            edge_map = F.interpolate(edge_map, size=feat.shape[2:], mode='bilinear', align_corners=False)
        
        edge_feat = self.edge_process(edge_map)
        combined = torch.cat([feat, edge_feat], dim=1)
        attn = self.gate(combined)
        edge_weighted = edge_feat * attn
        fused = self.fuse(torch.cat([feat, edge_weighted], dim=1))
        return feat + fused


# -------- 4. DOMAIN-ADAPTIVE MODULATION --------
class DomainModulation(nn.Module):
    """FiLM-style domain conditioning"""
    def __init__(self, feat_channels=96, domain_dim=64):
        super().__init__()
        self.scale = nn.Linear(domain_dim, feat_channels)
        self.shift = nn.Linear(domain_dim, feat_channels)
        
    def forward(self, feat, domain_emb):
        """feat: [B, C, H, W], domain_emb: [B, D, 1, 1]"""
        domain_vec = domain_emb.squeeze(-1).squeeze(-1)
        gamma = self.scale(domain_vec)[:, :, None, None]
        beta = self.shift(domain_vec)[:, :, None, None]
        return feat * (1 + gamma) + beta


# -------- 5. EDAR WRAPPER --------
class EDARWrapper(nn.Module):
    """Wraps Restormer with EDAR components"""
    def __init__(self, base_restormer, n_domains=3):
        super().__init__()
        self.restormer = base_restormer
        
        # EDAR modules
        self.edge_extractor = FastEdgeExtractor()
        self.domain_encoder = SimpleDomainEncoder(n_domains, embed_dim=64)
        self.edge_refine = LightweightEdgeRefinement(in_channels=96)
        self.domain_mod = DomainModulation(feat_channels=96, domain_dim=64)
        
        self.use_edar = True
        
    def forward(self, x, domain=None):
        """
        x: [B, 3, H, W], domain: str or None
        Returns: (output, edges)
        """
        edges = self.edge_extractor(x)
        
        domain_emb = None
        if domain is not None and self.use_edar:
            domain_emb = self.domain_encoder(domain)
        
        # Restormer encoder
        inp_enc = self.restormer.patch_embed(x)
        enc1 = self.restormer.encoder_level1(inp_enc)
        enc2 = self.restormer.encoder_level2(self.restormer.down1_2(enc1))
        enc3 = self.restormer.encoder_level3(self.restormer.down2_3(enc2))
        
        # Bottleneck
        latent = self.restormer.latent(self.restormer.down3_4(enc3))
        
        # Decoder
        dec3 = self.restormer.decoder_level3(
            self.restormer.reduce_chan_level3(
                torch.cat([self.restormer.up4_3(latent), enc3], 1)
            )
        )
        dec2 = self.restormer.decoder_level2(
            self.restormer.reduce_chan_level2(
                torch.cat([self.restormer.up3_2(dec3), enc2], 1)
            )
        )
        dec1 = self.restormer.decoder_level1(
            torch.cat([self.restormer.up2_1(dec2), enc1], 1)
        )
        
        # EDAR injection
        if domain_emb is not None:
            dec1 = self.domain_mod(dec1, domain_emb)
        
        refined = self.restormer.refinement(dec1)
        if self.use_edar:
            refined = self.edge_refine(refined, edges)
        
        output = self.restormer.output(refined)
        return output, edges
    
    def baseline_mode(self):
        self.use_edar = False
        
    def edar_mode(self):
        self.use_edar = True

print("‚úÖ EDAR modules loaded successfully!")

‚úÖ EDAR modules loaded successfully!


In [93]:
# ============================================================
# CELL 14: TEST EDAR WRAPPER
# ============================================================

device = "cuda" if torch.cuda.is_available() else "cpu"

# Wrap your existing model
edar_model = EDARWrapper(model, n_domains=3).to(device)
print(f"‚úÖ EDAR model created on {device}")

# Test forward pass
test_input = torch.randn(1, 3, 256, 256).to(device)
with torch.no_grad():
    test_out, test_edges = edar_model(test_input, domain='fog')
    
print(f"‚úÖ Forward pass successful!")
print(f"   Input shape:  {tuple(test_input.shape)}")
print(f"   Output shape: {tuple(test_out.shape)}")
print(f"   Edges shape:  {tuple(test_edges.shape)}")

# Count trainable parameters
edar_params = sum(p.numel() for p in edar_model.parameters() if p.requires_grad)
print(f"‚úÖ Trainable parameters: {edar_params:,}")

‚úÖ EDAR model created on cuda
‚úÖ Forward pass successful!
   Input shape:  (1, 3, 256, 256)
   Output shape: (1, 3, 256, 256)
   Edges shape:  (1, 1, 256, 256)
‚úÖ Trainable parameters: 505,292


In [94]:
CHECKPOINT_PATH = "/kaggle/input/novelty-epoch-19/edar_multidomain_ep19.pth"

In [101]:
# ============================================================
# CELL: Set CHECKPOINT_PATH here, smart-load checkpoint once, then upload image & run inference
# Paste this AFTER your EDAR model class + `edar_model` are defined.
# ============================================================
import os, re, torch, numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# ------------- CONFIG: set your epoch-19 .pth path here -------------
# Example Kaggle dataset path:
# CHECKPOINT_PATH = "/kaggle/input/novelty-epoch-19/edar_multidomain_ep19.pth"
# Or local working path if you uploaded to notebook files:
# CHECKPOINT_PATH = "/kaggle/working/edar_multidomain_ep19.pth"
try:
    CHECKPOINT_PATH  # if already set earlier, keep it
except NameError:
    CHECKPOINT_PATH = "/kaggle/input/novelty-epoch-19/edar_multidomain_ep19.pth"  # <-- update this if different

# ------------- device fallback -------------
try:
    device
except NameError:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")
print(f"Checkpoint path (will try to load now): {CHECKPOINT_PATH}")

# --------- Lightweight image utilities (same as used earlier) ----------
import cv2
def load_image_local(path, max_size=1024):
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Image not found or unreadable: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    if max(h, w) > max_size:
        scale = max_size / max(h, w)
        img = cv2.resize(img, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_LANCZOS4)
    return img

def img_to_tensor_local(img):
    img = img.astype(np.float32) / 255.0
    t = torch.from_numpy(np.transpose(img, (2,0,1))).unsqueeze(0)
    return t

def tensor_to_img_local(t):
    img = t.squeeze(0).permute(1,2,0).cpu().numpy()
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    return img

# === Replace/patch restore_image_local with padding-handling version ===
import math, cv2, numpy as np, torch
from PIL import Image

def pad_to_multiple(img_np, multiple=16):
    h, w = img_np.shape[:2]
    tgt_h = int(math.ceil(h / multiple) * multiple)
    tgt_w = int(math.ceil(w / multiple) * multiple)
    pad_h = tgt_h - h
    pad_w = tgt_w - w
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left
    padded = np.pad(
        img_np,
        ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
        mode='reflect'
    )
    return padded, (pad_top, pad_bottom, pad_left, pad_right)

def unpad_img(img_np, pads):
    pad_top, pad_bottom, pad_left, pad_right = pads
    h, w = img_np.shape[:2]
    return img_np[pad_top:h-pad_bottom if pad_bottom>0 else h,
                  pad_left:w-pad_right if pad_right>0 else w]

def img_to_tensor_with_padding(img_np, multiple=16, device='cuda'):
    orig_h, orig_w = img_np.shape[:2]
    padded_np, pads = pad_to_multiple(img_np, multiple=multiple)
    tensor = torch.from_numpy(
        np.transpose(padded_np.astype(np.float32)/255.0, (2,0,1))
    ).unsqueeze(0).to(device)
    return tensor, pads, (orig_h, orig_w)

def tensor_to_img_and_unpad(tensor, pads, orig_size):
    img = tensor.squeeze(0).permute(1,2,0).cpu().numpy()
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = unpad_img(img, pads)
    img = cv2.resize(img, (orig_size[1], orig_size[0]), interpolation=cv2.INTER_LINEAR)
    return img

@torch.no_grad()
def restore_image_local(model, image_path, domain='fog', device='cuda', pad_multiple=16):
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Image not found or unreadable: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    inp_tensor, pads, orig_size = img_to_tensor_with_padding(img, multiple=pad_multiple, device=device)

    model.to(device).eval()
    out_t, edges_t = model(inp_tensor, domain=domain)

    restored_img = tensor_to_img_and_unpad(out_t, pads, orig_size)

    try:
        if edges_t.dim() == 4 and edges_t.size(1) == 1:
            edges_t3 = edges_t.repeat(1,3,1,1)
        else:
            edges_t3 = edges_t
        edges_img = tensor_to_img_and_unpad(edges_t3, pads, orig_size)
    except:
        e = edges_t.squeeze().cpu().numpy()
        if e.ndim == 2:
            e = np.stack([e]*3, axis=-1)
        e = (np.clip(e,0,1)*255).astype(np.uint8)
        edges_img = unpad_img(e, pads)

    return img, restored_img, edges_img


# ---------- checkpoint helpers (smart .body handling) ----------
def pick_state_dict(ckpt):
    for k in ("edar_model","state_dict","model"):
        if isinstance(ckpt, dict) and k in ckpt:
            return ckpt[k]
    return ckpt

def strip_module(sd):
    return { (k[7:] if k.startswith("module.") else k): v for k,v in sd.items() }

def remove_body_dot(sd):
    new = {}
    for k,v in sd.items():
        nk = k.replace(".body.", ".")
        if nk.endswith(".body"):
            nk = nk[:-5]
        new[nk] = v
    return new

def add_body_dot_for_norms(sd):
    # insert ".body." between ".normN" and ".weight/.bias" if missing
    new = {}
    for k,v in sd.items():
        nk = re.sub(r"(\.norm\d+)(\.(?:weight|bias|gamma|beta))$", r"\1.body\2", k)
        new[nk] = v
    return new

def try_best_load(model, ckpt_path, map_location=None):
    """
    Try both removal and insertion of '.body.' norm naming and pick best by minimal mismatch.
    Loads chosen version into `model` and returns (chosen_label, info).
    """
    map_location = map_location or device
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint path not found: {ckpt_path}")
    raw = torch.load(ckpt_path, map_location=map_location)
    sd_raw = pick_state_dict(raw)
    sd_strip = strip_module(sd_raw)

    # Save original model weights (so we can test both variants without corruption)
    orig_state = {k: v.clone().cpu() for k, v in model.state_dict().items()}

    # Variant A: remove .body.
    sd_a = remove_body_dot(sd_strip)
    info_a = model.load_state_dict(sd_a, strict=False)
    score_a = len(info_a.missing_keys) + len(info_a.unexpected_keys)

    # restore original
    model.load_state_dict(orig_state, strict=False)

    # Variant B: add .body. for norm entries
    sd_b = add_body_dot_for_norms(sd_strip)
    info_b = model.load_state_dict(sd_b, strict=False)
    score_b = len(info_b.missing_keys) + len(info_b.unexpected_keys)

    # choose best
    if score_a <= score_b:
        model.load_state_dict(sd_a, strict=False)
        chosen = "remove .body. (A)"
        info = info_a
    else:
        model.load_state_dict(sd_b, strict=False)
        chosen = "add .body. for norms (B)"
        info = info_b

    return chosen, info

# ----------------- Try to load checkpoint NOW (once) -----------------
loaded_info = None
loaded_strategy = None
try:
    chosen, info = try_best_load(edar_model, CHECKPOINT_PATH, map_location=device)
    loaded_strategy, loaded_info = chosen, info
    print("Checkpoint load strategy:", chosen)
    print(f"Missing keys: {len(info.missing_keys)} | Unexpected keys: {len(info.unexpected_keys)}")
    if info.missing_keys:
        print("  Example missing keys (up to 12):")
        for k in info.missing_keys[:12]:
            print("   ", k)
    if info.unexpected_keys:
        print("  Example unexpected keys (up to 12):")
        for k in info.unexpected_keys[:12]:
            print("   ", k)
except Exception as e:
    print("‚ùå Checkpoint load failed:", e)
    import traceback; traceback.print_exc()

# ----------------- UI: upload image + run inference -----------------
img_upload = widgets.FileUpload(accept='image/*', multiple=False, description='Upload Image')
domain_selector = widgets.Dropdown(options=['fog','rain','lowlight'], value='fog', description='Domain:')
run_btn = widgets.Button(description='Run EDAR', button_style='success')
out = widgets.Output(layout={'border':'1px solid black'})

def save_uploaded_image(upload_widget, target_prefix='uploaded_input'):
    val = upload_widget.value
    if not val:
        return None
    # dict-style
    if isinstance(val, dict):
        fname = list(val.keys())[0]
        entry = val[fname]
        content = entry.get('content') if isinstance(entry, dict) else entry
        ext = os.path.splitext(fname)[1] or '.png'
        out_path = f'./{target_prefix}{ext}'
        with open(out_path, 'wb') as f:
            f.write(content)
        return out_path
    # list/tuple style
    if isinstance(val, (list, tuple)):
        entry = val[0]
        if isinstance(entry, dict) and 'name' in entry and 'content' in entry:
            fname = entry['name']; content = entry['content']
        elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
            fname = entry[0]; content = entry[1]
        else:
            fname = getattr(entry, 'name', f'{target_prefix}.png'); content = entry if isinstance(entry, (bytes, bytearray)) else None
        if content is None:
            raise ValueError("Couldn't extract uploaded image bytes.")
        ext = os.path.splitext(fname)[1] or '.png'
        out_path = f'./{target_prefix}{ext}'
        with open(out_path, 'wb') as f:
            f.write(content)
        return out_path
    # raw bytes
    if isinstance(val, (bytes, bytearray)):
        out_path = f'./{target_prefix}.png'
        with open(out_path, 'wb') as f:
            f.write(val)
        return out_path
    raise ValueError("Unrecognized upload.value format: " + str(type(val)))

def on_run(b):
    with out:
        clear_output()
        # ensure checkpoint was loaded
        if loaded_info is None:
            print("‚ùå Checkpoint was not successfully loaded earlier. Fix CHECKPOINT_PATH and re-run this cell.")
            return

        # save uploaded image
        if not img_upload.value:
            print("‚ùå Please upload an input image.")
            return
        try:
            img_path = save_uploaded_image(img_upload)
        except Exception as e:
            print("‚ùå Failed to save uploaded image:", e)
            import traceback; traceback.print_exc()
            return

        print("üì• Uploaded image saved to:", img_path)
        print("‚ñ∂ Running inference (domain =", domain_selector.value, ") ...")

        try:
            edar_model.to(device).eval()
            inp, restored, edges = restore_image_local(edar_model, img_path, domain=domain_selector.value, device=device)
        except Exception:
            import traceback; traceback.print_exc()
            print("‚ùå Inference failed.")
            return

        # show results
        try:
            fig, ax = plt.subplots(1,3, figsize=(18,6))
            ax[0].imshow(inp); ax[0].axis('off'); ax[0].set_title(f"Input ({domain_selector.value})")
            ax[1].imshow(restored); ax[1].axis('off'); ax[1].set_title("Restored")
            ax[2].imshow(edges); ax[2].axis('off'); ax[2].set_title("Edge Map")
            plt.tight_layout()
            plt.show()
            Image.fromarray(restored).save("restored_output.png")
            fig.savefig("comparison_output.png", dpi=150, bbox_inches='tight')
            print("‚úÖ Saved: restored_output.png, comparison_output.png")
        except Exception:
            import traceback; traceback.print_exc()
            print("‚ùå Failed to display/save outputs.")

run_btn.on_click(on_run)

display(widgets.VBox([
    widgets.HTML("<h3>EDAR Tester ‚Äî checkpoint auto-loaded; upload input image to run</h3>"),
    widgets.HBox([img_upload, domain_selector, run_btn]),
    out
]))


Device: cuda
Checkpoint path (will try to load now): /kaggle/input/novelty-epoch-19/edar_multidomain_ep19.pth
Checkpoint load strategy: add .body. for norms (B)
Missing keys: 0 | Unexpected keys: 0


VBox(children=(HTML(value='<h3>EDAR Tester ‚Äî checkpoint auto-loaded; upload input image to run</h3>'), HBox(ch‚Ä¶

In [110]:
# ------------------------
# Gradio UI for EDAR (no input shown in outputs)
# Paste this after EDAR class definitions and after edar_model is created.
# Edit CHECKPOINT_PATH below to point to your epoch-19 .pth file.
# ------------------------
import os, re, math, io, torch, numpy as np
from PIL import Image
import cv2
import gradio as gr

# ---------------- CONFIG: set your checkpoint path here ----------------
try:
    CHECKPOINT_PATH  # keep existing if already set
except NameError:
    CHECKPOINT_PATH = "/kaggle/input/novelty-epoch-19/edar_multidomain_ep19.pth"

# ---------------- device ----------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
print("Checkpoint:", CHECKPOINT_PATH)

# ----------------- sanity check for edar_model -----------------
try:
    edar_model
except NameError:
    raise RuntimeError("`edar_model` is not defined. Create EDARWrapper(base_restormer) and assign to edar_model before running this cell.")

# ----------------- Utilities: padding + tensor conversions -----------------
def pad_to_multiple(img_np, multiple=16):
    h, w = img_np.shape[:2]
    tgt_h = int(math.ceil(h / multiple) * multiple)
    tgt_w = int(math.ceil(w / multiple) * multiple)
    pad_h = tgt_h - h
    pad_w = tgt_w - w
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left
    padded = np.pad(img_np, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode='reflect')
    return padded, (pad_top, pad_bottom, pad_left, pad_right)

def unpad_img(img_np, pads):
    pad_top, pad_bottom, pad_left, pad_right = pads
    h, w = img_np.shape[:2]
    top = pad_top
    bottom = h - pad_bottom if pad_bottom>0 else h
    left = pad_left
    right = w - pad_right if pad_right>0 else w
    return img_np[top:bottom, left:right]

def pil_to_rgb_np(pil_img):
    arr = np.array(pil_img.convert("RGB"))
    return arr

def img_to_tensor_with_padding(img_np, multiple=16, device='cuda'):
    orig_h, orig_w = img_np.shape[:2]
    padded_np, pads = pad_to_multiple(img_np, multiple=multiple)
    tensor = torch.from_numpy(np.transpose(padded_np.astype(np.float32)/255.0, (2,0,1))).unsqueeze(0).to(device)
    return tensor, pads, (orig_h, orig_w)

def tensor_to_img_and_unpad(tensor, pads, orig_size):
    img = tensor.squeeze(0).permute(1,2,0).cpu().numpy()
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = unpad_img(img, pads)
    img = cv2.resize(img, (orig_size[1], orig_size[0]), interpolation=cv2.INTER_LINEAR)
    return img

# ----------------- smart checkpoint loader (try both variants) -----------------
def pick_state_dict(ckpt):
    for k in ("edar_model","state_dict","model"):
        if isinstance(ckpt, dict) and k in ckpt:
            return ckpt[k]
    return ckpt

def strip_module(sd):
    return { (k[7:] if k.startswith("module.") else k): v for k,v in sd.items() }

def remove_body_dot(sd):
    new = {}
    for k,v in sd.items():
        nk = k.replace(".body.", ".")
        if nk.endswith(".body"):
            nk = nk[:-5]
        new[nk] = v
    return new

def add_body_dot_for_norms(sd):
    new = {}
    for k,v in sd.items():
        nk = re.sub(r"(\.norm\d+)(\.(?:weight|bias|gamma|beta))$", r"\1.body\2", k)
        new[nk] = v
    return new

def try_best_load(model, ckpt_path, map_location=None):
    map_location = map_location or device
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    raw = torch.load(ckpt_path, map_location=map_location)
    sd_raw = pick_state_dict(raw)
    sd_strip = strip_module(sd_raw)

    # save original model state (cpu copy)
    orig_state = {k: v.clone().cpu() for k,v in model.state_dict().items()}

    # Variant A
    sd_a = remove_body_dot(sd_strip)
    info_a = model.load_state_dict(sd_a, strict=False)
    model.load_state_dict(orig_state, strict=False)

    # Variant B
    sd_b = add_body_dot_for_norms(sd_strip)
    info_b = model.load_state_dict(sd_b, strict=False)
    # choose best by fewer mismatches
    score_a = len(info_a.missing_keys) + len(info_a.unexpected_keys)
    score_b = len(info_b.missing_keys) + len(info_b.unexpected_keys)
    if score_a <= score_b:
        model.load_state_dict(sd_a, strict=False)
        chosen, info = "remove .body. (A)", info_a
    else:
        model.load_state_dict(sd_b, strict=False)
        chosen, info = "add .body. for norms (B)", info_b
    return chosen, info

# ----------------- Load checkpoint ONCE -----------------
print("Loading checkpoint (smart)...")
try:
    chosen, info = try_best_load(edar_model, CHECKPOINT_PATH, map_location=device)
    print("Loaded checkpoint with strategy:", chosen)
    print(f"Missing keys: {len(info.missing_keys)} | Unexpected keys: {len(info.unexpected_keys)}")
except Exception as e:
    raise RuntimeError(f"Failed to load checkpoint: {e}")

# ----------------- Inference wrapper used by Gradio (returns only restored + edges) -----------------
@torch.no_grad()
def gradio_infer(pil_img, domain="fog"):
    if pil_img is None:
        return None, None
    inp_np = pil_to_rgb_np(pil_img)
    inp_tensor, pads, orig_size = img_to_tensor_with_padding(inp_np, multiple=16, device=device)
    edar_model.to(device).eval()
    out_t, edges_t = edar_model(inp_tensor, domain=domain)
    restored_np = tensor_to_img_and_unpad(out_t, pads, orig_size)
    if edges_t.dim() == 4 and edges_t.size(1) == 1:
        edges_t3 = edges_t.repeat(1,3,1,1)
    else:
        edges_t3 = edges_t
    edges_np = tensor_to_img_and_unpad(edges_t3, pads, orig_size)
    restored_pil = Image.fromarray(restored_np)
    edges_pil = Image.fromarray(edges_np)
    return restored_pil, edges_pil

# ----------------- Build Gradio Interface (no input output) -----------------
demo = gr.Interface(
    fn=gradio_infer,
    inputs=[
        gr.Image(label="Upload Input Image", type="pil"),
        gr.Dropdown(choices=["fog","rain","lowlight"], value="fog", label="Domain")
    ],
    outputs=[
        gr.Image(type="pil", label="EDAR Restored"),
        gr.Image(type="pil", label="Edge Map")
    ],
    title="EDAR ‚Äî Image Restoration (multi-domain)",
    description="Upload an image, pick domain (fog/rain/lowlight). The model is loaded from CHECKPOINT_PATH set in the cell.",
    allow_flagging="never",
    examples=None
)

# ----------------- Launch Gradio -----------------
print("Launching Gradio app... (press stop in the notebook to kill)")
demo.launch(share=True, server_name="0.0.0.0", server_port=7868)


Device: cuda
Checkpoint: /kaggle/input/novelty-epoch-19/edar_multidomain_ep19.pth
Loading checkpoint (smart)...
Loaded checkpoint with strategy: add .body. for norms (B)
Missing keys: 0 | Unexpected keys: 0
Launching Gradio app... (press stop in the notebook to kill)




* Running on local URL:  http://0.0.0.0:7868
* Running on public URL: https://626d2a5197f8dfe158.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [79]:
# # ============================================================
# # CELL 15 (UPDATED): TRAINING FUNCTION WITH RESUME SUPPORT
# # ============================================================

# from tqdm import tqdm
# import os
# import torch.nn as nn
# import torchvision
# from itertools import cycle
# import matplotlib.pyplot as plt
# from IPython.display import display
# import numpy as np

# def train_edar_multidomain_with_gan(
#     edar_model,
#     discriminator,
#     vgg_model,
#     train_loaders,
#     val_loaders,
#     epochs=11,
#     device='cuda',
#     save_dir='./edar_multidomain_checkpoints',
#     start_epoch=9,  # NEW: Start from this epoch
#     resume_checkpoint="/kaggle/input/novelty-epoch-8/edar_multidomain_ep08.pth"  # NEW: Path to checkpoint to resume from
# ):
#     """
#     Train EDAR on all 3 domains with support for resuming
    
#     Args:
#         start_epoch: Epoch number to start from (default: 1)
#         resume_checkpoint: Path to checkpoint file to resume from (optional)
#     """
    
#     os.makedirs(save_dir, exist_ok=True)
#     os.makedirs('./edar_multidomain_samples', exist_ok=True)
#     for domain in ['fog', 'rain', 'lowlight']:
#         os.makedirs(f'./edar_multidomain_samples/{domain}', exist_ok=True)
    
#     # Setup
#     for param in edar_model.restormer.parameters():
#         param.requires_grad = False
#     for param in edar_model.restormer.refinement.parameters():
#         param.requires_grad = True
#     for param in edar_model.restormer.output.parameters():
#         param.requires_grad = True
    
#     gen_params = (
#         list(edar_model.edge_extractor.parameters()) +
#         list(edar_model.domain_encoder.parameters()) +
#         list(edar_model.edge_refine.parameters()) +
#         list(edar_model.domain_mod.parameters()) +
#         list(edar_model.restormer.refinement.parameters()) +
#         list(edar_model.restormer.output.parameters())
#     )
    
#     opt_G = torch.optim.Adam(gen_params, lr=1e-5, betas=(0.9, 0.999))
#     opt_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    
#     # ===== RESUME FROM CHECKPOINT =====
#     if resume_checkpoint is not None and os.path.exists(resume_checkpoint):
#         print(f"üìÇ Loading checkpoint: {resume_checkpoint}")
#         checkpoint = torch.load(resume_checkpoint, map_location=device)
        
#         edar_model.load_state_dict(checkpoint['edar_model'])
#         discriminator.load_state_dict(checkpoint['discriminator'])
        
#         # Load optimizer states if available
#         if 'opt_G' in checkpoint:
#             opt_G.load_state_dict(checkpoint['opt_G'])
#             print("   ‚úÖ Loaded Generator optimizer state")
#         if 'opt_D' in checkpoint:
#             opt_D.load_state_dict(checkpoint['opt_D'])
#             print("   ‚úÖ Loaded Discriminator optimizer state")
        
#         print(f"   ‚úÖ Resumed from epoch {checkpoint.get('epoch', start_epoch-1)}")
#         print()
    
#     l1_loss = nn.L1Loss()
#     mse_loss = nn.MSELoss()
    
#     def adv_loss(pred, is_real):
#         target = torch.ones_like(pred) if is_real else torch.zeros_like(pred)
#         return mse_loss(pred, target)
    
#     L_ADV, L_PERC, L_ID, L_EDGE = 1.0, 0.1, 0.1, 0.2
    
#     scaler_G = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))
#     scaler_D = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))
    
#     # Round-robin iterator
#     def round_robin_batches(loaders_dict):
#         iters = {k: cycle(v) for k, v in loaders_dict.items()}
#         domains = list(loaders_dict.keys())
#         while True:
#             for domain in domains:
#                 yield domain, next(iters[domain])
    
#     steps_per_epoch = min(len(dl) for dl in train_loaders.values()) * len(train_loaders)
#     rr_iter = round_robin_batches(train_loaders)
    
#     if start_epoch == 1:
#         print(f"Steps per epoch: {steps_per_epoch} (balanced across domains)")
#     else:
#         print(f"Resuming training: Epochs {start_epoch} ‚Üí {start_epoch + epochs - 1}")
#         print(f"Steps per epoch: {steps_per_epoch}")
#     print()
    
#     # Training loop
#     best_loss = float('inf')
    
#     # Calculate end epoch
#     end_epoch = start_epoch + epochs
    
#     for epoch in range(start_epoch, end_epoch):
#         edar_model.train()
#         discriminator.train()
        
#         epoch_losses = {
#             'G_total': [], 'G_adv': [], 'G_perc': [], 
#             'G_id': [], 'G_edge': [], 'D': [],
#             'fog': [], 'rain': [], 'lowlight': []
#         }
        
#         pbar = tqdm(range(steps_per_epoch), desc=f'Epoch {epoch}/{end_epoch-1}')
        
#         for step in pbar:
#             domain, (degraded, clean, _) = next(rr_iter)
#             degraded, clean = degraded.to(device), clean.to(device)
            
#             # ========== GENERATOR STEP ==========
#             opt_G.zero_grad(set_to_none=True)
            
#             with torch.cuda.amp.autocast(enabled=(device=="cuda")):
#                 fake, edges_pred = edar_model(degraded, domain=domain)
#                 edges_target = edar_model.edge_extractor(clean)
                
#                 loss_g_adv = adv_loss(discriminator(fake), True) * L_ADV
#                 loss_g_perc = l1_loss(vgg_model(fake), vgg_model(clean)) * L_PERC
                
#                 identity, _ = edar_model(clean, domain=domain)
#                 loss_g_id = l1_loss(identity, clean) * L_ID
                
#                 loss_g_edge = l1_loss(edges_pred, edges_target) * L_EDGE
                
#                 loss_G = loss_g_adv + loss_g_perc + loss_g_id + loss_g_edge
            
#             scaler_G.scale(loss_G).backward()
#             scaler_G.step(opt_G)
#             scaler_G.update()
            
#             # ========== DISCRIMINATOR STEP ==========
#             opt_D.zero_grad(set_to_none=True)
            
#             with torch.cuda.amp.autocast(enabled=(device=="cuda")):
#                 pred_real = discriminator(clean)
#                 pred_fake = discriminator(fake.detach())
#                 loss_D = 0.5 * (adv_loss(pred_real, True) + adv_loss(pred_fake, False))
            
#             scaler_D.scale(loss_D).backward()
#             scaler_D.step(opt_D)
#             scaler_D.update()
            
#             # Track losses
#             epoch_losses['G_total'].append(loss_G.item())
#             epoch_losses['G_adv'].append(loss_g_adv.item())
#             epoch_losses['G_perc'].append(loss_g_perc.item())
#             epoch_losses['G_id'].append(loss_g_id.item())
#             epoch_losses['G_edge'].append(loss_g_edge.item())
#             epoch_losses['D'].append(loss_D.item())
#             epoch_losses[domain].append(loss_G.item())
            
#             pbar.set_postfix({
#                 'G': f'{loss_G.item():.3f}',
#                 'D': f'{loss_D.item():.3f}',
#                 'Edge': f'{loss_g_edge.item():.3f}',
#                 'dom': domain
#             })
        
#         # ========== EPOCH SUMMARY ==========
#         print(f"\n{'='*60}")
#         print(f"Epoch {epoch}/{end_epoch-1} Summary:")
#         print(f"  Overall Generator Loss: {np.mean(epoch_losses['G_total']):.4f}")
#         print(f"    ‚îú‚îÄ Adversarial: {np.mean(epoch_losses['G_adv']):.4f}")
#         print(f"    ‚îú‚îÄ Perceptual:  {np.mean(epoch_losses['G_perc']):.4f}")
#         print(f"    ‚îú‚îÄ Identity:    {np.mean(epoch_losses['G_id']):.4f}")
#         print(f"    ‚îî‚îÄ Edge (NEW):  {np.mean(epoch_losses['G_edge']):.4f}")
#         print(f"  Discriminator Loss:     {np.mean(epoch_losses['D']):.4f}")
#         print(f"\n  Per-Domain Generator Losses:")
#         print(f"    üå´Ô∏è  Fog:      {np.mean(epoch_losses['fog']):.4f}")
#         print(f"    üåßÔ∏è  Rain:     {np.mean(epoch_losses['rain']):.4f}")
#         print(f"    üåô Lowlight: {np.mean(epoch_losses['lowlight']):.4f}")
#         print(f"{'='*60}\n")
        
#         # ========== SAVE & DISPLAY SAMPLES ==========
#         print("Generating preview samples...")
#         sample_images = save_and_collect_samples(
#             edar_model, val_loaders, epoch, device
#         )
        
#         # Display inline
#         display_samples_inline(sample_images, epoch)
        
#         # ========== CHECKPOINTING ==========
#         avg_G = np.mean(epoch_losses['G_total'])
#         ckpt_path = f'{save_dir}/edar_multidomain_ep{epoch:02d}.pth'
#         torch.save({
#             'epoch': epoch,
#             'edar_model': edar_model.state_dict(),
#             'discriminator': discriminator.state_dict(),
#             'opt_G': opt_G.state_dict(),
#             'opt_D': opt_D.state_dict(),
#         }, ckpt_path)
        
#         if avg_G < best_loss:
#             best_loss = avg_G
#             best_path = f'{save_dir}/edar_multidomain_best.pth'
#             torch.save(edar_model.state_dict(), best_path)
#             print(f"‚úÖ Best model saved: {best_path}\n")
    
#     print(f"\n{'='*60}")
#     print("‚úÖ TRAINING COMPLETE!")
#     print(f"   Trained epochs: {start_epoch} ‚Üí {end_epoch-1}")
#     print(f"{'='*60}")
#     return edar_model


# def save_and_collect_samples(model, val_loaders, epoch, device):
#     """Save samples for all domains and return them for display"""
#     import torchvision.utils as vutils
    
#     model.eval()
#     sample_images = {}
    
#     with torch.no_grad():
#         for domain in ['fog', 'rain', 'lowlight']:
#             domain_samples = []
            
#             for idx, (degraded, clean, name) in enumerate(val_loaders[domain]):
#                 if idx >= 2:
#                     break
                
#                 degraded = degraded.to(device)
#                 clean = clean.to(device)
                
#                 restored, edges = model(degraded, domain=domain)
#                 restored = torch.clamp(restored, 0, 1)
                
#                 grid = torch.cat([
#                     degraded[0:1],
#                     restored[0:1],
#                     clean[0:1],
#                     edges[0:1].repeat(1, 3, 1, 1)
#                 ], dim=3)
                
#                 save_path = f'./edar_multidomain_samples/{domain}/ep{epoch:02d}_sample{idx}.png'
#                 vutils.save_image(grid, save_path, normalize=False)
                
#                 grid_np = grid[0].permute(1, 2, 0).cpu().numpy()
#                 domain_samples.append(grid_np)
            
#             sample_images[domain] = domain_samples
    
#     model.train()
#     return sample_images


# def display_samples_inline(sample_images, epoch):
#     """Display samples inline in notebook"""
#     fig, axes = plt.subplots(3, 2, figsize=(20, 15))
    
#     domains = ['fog', 'rain', 'lowlight']
#     domain_icons = {'fog': 'üå´Ô∏è', 'rain': 'üåßÔ∏è', 'lowlight': 'üåô'}
    
#     for row_idx, domain in enumerate(domains):
#         samples = sample_images[domain]
        
#         for col_idx, sample in enumerate(samples):
#             ax = axes[row_idx, col_idx]
#             ax.imshow(sample)
#             ax.axis('off')
            
#             if col_idx == 0:
#                 ax.set_title(
#                     f"{domain_icons[domain]} {domain.upper()} - Sample {col_idx+1}\n"
#                     f"[Degraded | Restored | Clean | Edges]",
#                     fontsize=12, fontweight='bold', pad=10
#                 )
#             else:
#                 ax.set_title(
#                     f"Sample {col_idx+1}\n[Degraded | Restored | Clean | Edges]",
#                     fontsize=12, fontweight='bold', pad=10
#                 )
    
#     plt.suptitle(
#         f'EPOCH {epoch} - EDAR Multi-Domain Results',
#         fontsize=16, fontweight='bold', y=0.995
#     )
#     plt.tight_layout()
#     plt.show()
    
#     print(f"‚úÖ Samples saved to: ./edar_multidomain_samples/[domain]/ep{epoch:02d}_sampleX.png\n")


# print("‚úÖ Updated training function with resume support ready!")

In [80]:
# # ============================================================
# # CELL 16: SETUP DISCRIMINATOR + VGG (Your Working Setup)
# # ============================================================

# import torch.nn as nn
# import torchvision

# # ===== DISCRIMINATOR (Your PatchGAN) =====
# class PatchD(nn.Module):
#     def __init__(self, in_ch=3, base=64):
#         super().__init__()
#         def blk(ic, oc, norm=True):
#             m=[nn.Conv2d(ic, oc, 4, 2, 1)]
#             if norm: m+=[nn.InstanceNorm2d(oc, affine=True)]
#             m+=[nn.LeakyReLU(0.2, inplace=True)]
#             return nn.Sequential(*m)
#         self.net = nn.Sequential(
#             blk(in_ch, base, norm=False),
#             blk(base, base*2),
#             blk(base*2, base*4),
#             nn.Conv2d(base*4, 1, 3, 1, 1)
#         )
#     def forward(self, x): return self.net(x)

# disc = PatchD().to(device).train()
# print("‚úÖ Discriminator loaded")

# # ===== VGG FOR PERCEPTUAL LOSS =====
# vgg = torchvision.models.vgg19(
#     weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1
# ).features[:16].eval().to(device)

# for p in vgg.parameters(): 
#     p.requires_grad = False

# print("‚úÖ VGG-19 loaded for perceptual loss")

In [81]:
# # ============================================================
# # CELL 17: TRAIN MULTI-DOMAIN EDAR
# # ============================================================

# print("="*60)
# print("üåê TRAINING EDAR ACROSS ALL DOMAINS (FOG + RAIN + LOWLIGHT)")
# print("="*60)

# # Wrap model
# if 'edar_model' not in locals():
#     edar_model = EDARWrapper(model, n_domains=3).to(device)

# # Train on ALL domains
# edar_multidomain = train_edar_multidomain_with_gan(
#     edar_model=edar_model,
#     discriminator=disc,
#     vgg_model=vgg,
#     train_loaders=train_loaders,  # Pass all 3 loaders!
#     val_loaders=val_loaders,
#     epochs=11,
#     device=device,
#     save_dir='./edar_multidomain_checkpoints'
# )

# print("\n‚úÖ MULTI-DOMAIN TRAINING COMPLETE!")
# print("Model can now handle fog, rain, AND lowlight with domain adaptation!")