# In [None]:
# %% SSL4EO + SwinIR-style SR with physics-based loss
import os, math, logging, tarfile, random
from glob import glob

import numpy as np
import rasterio
from rasterio.warp import Resampling
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm

# -----------------------
# Directories & constants
# -----------------------
# Where the SSL4EO data lives and where checkpoints are written.
SSL4EO_DIR = "data_ssl4eo"
MODELS_DIR = "models"
SCENES_ROOT = os.path.join(SSL4EO_DIR, "scenes")  # searched recursively for all_bands.tif

# Placeholders for future ECOSTRESS / HLS pretraining; not used by this script.
ECOSTRESS_DIR = "data_ecostress"
PROCESSED_DIR = "data_processed"
RAW_DIR = "data_raw"
ECO_BEST = os.path.join(MODELS_DIR, "ecostress_pretrained_best.pth")
ECO_LAST = os.path.join(MODELS_DIR, "ecostress_pretrained_last.pth")

# Make sure every directory this script writes into exists before anything runs.
for _required_dir in (SSL4EO_DIR, SCENES_ROOT, MODELS_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# Training / patch configuration
UPSCALE = 2                          # super-resolution factor
HR_PATCH = 128                       # side length of a high-resolution patch
LR_PATCH = HR_PATCH // UPSCALE       # side length of the matching low-resolution patch
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
PHYS_LAMBDA = 0.1                    # weight of the LR-consistency ("physics") loss term

# Number of random patches drawn from each scene per epoch, by split
PATCHES_PER_SCENE_TRAIN = 4
PATCHES_PER_SCENE_VAL = 2
PATCHES_PER_SCENE_TEST = 2

# Prefer the GPU when one is visible to PyTorch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

logging.basicConfig(
    format="%(asctime)s %(levelname)s:%(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("infranova_ssl4eo_swinir")

# SSL4EO benchmark archive on the Hugging Face hub
SSL4EO_URL = "/".join((
    "https://huggingface.co/datasets/torchgeo/ssl4eo_l_benchmark/resolve/main",
    "ssl4eo_l_oli_tirs_toa_benchmark.tar.gz?download=true",
))

# Band numbers requested from all_bands.tif; only these five bands are read.
BAND_IDX = dict(
    B2=2,    # Blue
    B3=3,    # Green
    B4=4,    # Red
    B10=10,  # Thermal IR 1
    B11=11,  # Thermal IR 2
)

# -----------------------
# Utils: normalization & metrics
# -----------------------
def norm_np(a: np.ndarray) -> np.ndarray:
    """
    Min-max scale an array to [0, 1] as float32.

    NaN and +/-Inf entries are replaced by 0.0 *before* the minimum and maximum
    are taken, so they take part in the range. An array whose value range is
    below 1e-6 is returned as all zeros to avoid dividing by (almost) zero.
    """
    data = np.array(a, dtype=np.float32)
    if not np.isfinite(data).all():
        data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)

    lo = float(np.nanmin(data))
    hi = float(np.nanmax(data))
    span = hi - lo
    if span < 1e-6:
        # Flat input: there is no contrast to stretch.
        return np.zeros_like(data, dtype=np.float32)

    scaled = (data - lo) / span
    return scaled.astype(np.float32)


def compute_metrics(pred: np.ndarray, target: np.ndarray):
    """
    Return (PSNR, SSIM, RMSE) for two arrays assumed to be scaled to [0, 1].

    Non-finite entries are sanitised first (NaN -> 0, +Inf -> 1, -Inf -> 0).
    When the mean squared error is non-finite or below 1e-12, PSNR is reported
    as 100.0 dB and RMSE as 0.0. If the SSIM computation raises for any reason,
    SSIM is reported as 0.0 instead of propagating the error.

    The RMSE is in normalised units; mapping it back to a physical (Kelvin)
    error would require the original value range of the scene.
    """
    pred, target = (
        np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=0.0)
        for arr in (pred, target)
    )

    diff = pred - target
    mse = float(np.mean(diff * diff))

    degenerate = (not np.isfinite(mse)) or mse < 1e-12
    if degenerate:
        psnr_val, rmse_val = 100.0, 0.0
    else:
        # Peak signal is 1.0 because inputs are normalised to [0, 1].
        psnr_val = 10 * math.log10(1.0 / mse)
        rmse_val = math.sqrt(mse)

    try:
        ssim_val = ssim(target, pred, data_range=1.0)
    except Exception:
        # Best-effort metric: fall back to 0.0 rather than aborting evaluation.
        ssim_val = 0.0

    return psnr_val, ssim_val, rmse_val

# -----------------------
# SSL4EO discovery / download
# -----------------------
def discover_ssl4eo_scenes(root=SCENES_ROOT):
    """
    Return the sorted paths of every `all_bands.tif` found anywhere below `root`.

    The number of matches is logged at INFO level. An empty list is returned
    when nothing matches.
    """
    scene_files = glob(os.path.join(root, "**", "all_bands.tif"), recursive=True)
    scene_files.sort()
    logger.info(f"Discovered {len(scene_files)} SSL4EO scenes with all_bands.tif")
    return scene_files


def _download_ssl4eo_archive(archive_path):
    """
    Stream the SSL4EO archive from SSL4EO_URL into `archive_path`.

    The data is written to `<archive_path>.part` and only renamed to its final
    name once the transfer completed, so an interrupted download can never be
    mistaken for a finished archive on the next run.

    Raises:
        requests.RequestException: on HTTP / connection errors.
        OSError: if fewer bytes arrived than the server announced.
    """
    import requests

    partial_path = archive_path + ".part"
    try:
        # (connect timeout, read timeout): fail instead of hanging on a dead connection.
        with requests.get(SSL4EO_URL, stream=True, timeout=(30, 300)) as r:
            r.raise_for_status()
            total = int(r.headers.get("content-length", 0))
            logger.info(f"Downloading SSL4EO archive ({total/1e9:.2f} GB approx)...")
            written = 0
            # total=None gives tqdm an open-ended bar when the size is unknown.
            with open(partial_path, "wb") as f, tqdm(
                total=total or None, unit="B", unit_scale=True, desc="ssl4eo_download"
            ) as pbar:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
                        written += len(chunk)
                        pbar.update(len(chunk))
        if total and written != total:
            raise OSError(
                f"Incomplete SSL4EO download: got {written} of {total} bytes"
            )
    except BaseException:
        # Never leave a truncated file behind, including on Ctrl-C.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.replace(partial_path, archive_path)


def _check_tar_members_inside(tf, dest_dir):
    """Raise tarfile.TarError if any archive member would be written outside `dest_dir`."""
    dest_root = os.path.realpath(dest_dir)
    for member in tf.getmembers():
        target = os.path.realpath(os.path.join(dest_root, member.name))
        if os.path.commonpath([dest_root, target]) != dest_root:
            raise tarfile.TarError(f"Unsafe path in archive: {member.name!r}")


def maybe_download_ssl4eo(archive_dir=SSL4EO_DIR, scenes_root=SCENES_ROOT):
    """
    Return the list of SSL4EO scene files, downloading and extracting the
    benchmark archive first if no scene is present under `scenes_root`.

    Args:
        archive_dir: directory in which `ssl4eo_benchmark.tar.gz` is stored.
        scenes_root: directory the archive is extracted into and searched.

    Returns:
        Sorted list of `all_bands.tif` paths found after the (optional)
        download and extraction; may be empty.

    Raises:
        requests.RequestException / OSError: if the download fails.
        tarfile.TarError: if the archive is unreadable or contains unsafe paths.
    """
    scenes = discover_ssl4eo_scenes(scenes_root)
    if scenes:
        return scenes

    logger.info("No SSL4EO scenes found. Attempting to download archive...")
    os.makedirs(archive_dir, exist_ok=True)
    archive_path = os.path.join(archive_dir, "ssl4eo_benchmark.tar.gz")

    if os.path.exists(archive_path):
        logger.info(f"Found existing archive: {archive_path}")
    else:
        _download_ssl4eo_archive(archive_path)

    logger.info(f"Extracting {archive_path} -> {scenes_root}")
    os.makedirs(scenes_root, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tf:
        if hasattr(tarfile, "data_filter"):
            # The "data" extraction filter rejects absolute paths, ".." escapes
            # and links pointing outside the destination.
            tf.extractall(path=scenes_root, filter="data")
        else:
            # Older interpreters without extraction filters: validate paths by hand.
            _check_tar_members_inside(tf, scenes_root)
            tf.extractall(path=scenes_root)
    logger.info("Extraction finished.")
    return discover_ssl4eo_scenes(scenes_root)

# -----------------------
# Dataset: SSL4EOPatchDataset
# -----------------------
class SSL4EOPatchDataset(Dataset):
    """
    Patch dataset for thermal super-resolution on SSL4EO scenes.

    For each requested item the scene's `all_bands.tif` is loaded and:
      - B2/B3/B4 are min-max normalised and used as HR optical guidance (3 channels)
      - either B10 or B11 is picked as the HR thermal "truth" and normalised
      - an HR crop of size `hr_patch` is taken at an offset that is a multiple
        of `upscale`, so HR and LR pixel grids stay co-registered
      - the LR thermal patch is synthesised by bilinear downsampling of that
        HR crop by the factor `upscale`

    Returns per item:
        lr_thermal (1, LR, LR), hr_rgb (3, HR, HR), hr_thermal (1, HR, HR)

    Each scene contributes `patches_per_scene` patches per epoch.

    Randomness:
      - mode == "train": band choice and crop offsets come from the global
        `random` / `np.random` generators (fresh patches every epoch).
      - any other mode: band choice and offsets are derived from the item
        index only, so validation / test patches are identical across epochs
        and their metrics are comparable.

    Raises:
        ValueError: if `hr_patch` is not a positive multiple of `upscale`.
    """
    def __init__(self, scene_files, hr_patch=HR_PATCH, upscale=UPSCALE,
                 patches_per_scene=4, mode="train"):
        super().__init__()
        if upscale < 1 or hr_patch < upscale or hr_patch % upscale != 0:
            # Otherwise LR * upscale != HR and the patches cannot be aligned.
            raise ValueError(
                f"hr_patch ({hr_patch}) must be a positive multiple of upscale ({upscale})"
            )
        self.scene_files = list(scene_files)
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // upscale
        self.upscale = upscale
        self.patches_per_scene = patches_per_scene
        self.mode = mode

    def __len__(self):
        # Each scene contributes 'patches_per_scene' patches per epoch
        return len(self.scene_files) * self.patches_per_scene

    def _read_bands(self, scene_path):
        """Read B2, B3, B4, B10, B11 (in that order) as a float32 array of shape (5, H, W)."""
        with rasterio.open(scene_path) as src:
            bands = src.read([
                BAND_IDX["B2"], BAND_IDX["B3"], BAND_IDX["B4"],
                BAND_IDX["B10"], BAND_IDX["B11"]
            ]).astype(np.float32)
        return bands

    def _random_offset(self, rng, slack):
        """Pick a crop offset in [0, slack] that is a multiple of `upscale` (0 if no slack)."""
        if slack <= 0:
            return 0
        return int(rng.randint(0, slack // self.upscale + 1)) * self.upscale

    def __getitem__(self, idx):
        # Map global index to scene index
        scene_idx = idx // self.patches_per_scene
        scene_path = self.scene_files[scene_idx]

        # Choose the randomness source (see class docstring).
        if self.mode == "train":
            rng = np.random
            use_b10 = random.random() < 0.5
        else:
            rng = np.random.RandomState(int(idx) % (2 ** 32))
            use_b10 = rng.random_sample() < 0.5

        bands = self._read_bands(scene_path)          # (5, H, W)
        thermal_hr = bands[3] if use_b10 else bands[4]

        # Per-band min-max normalisation over the whole scene
        rgb_n = np.stack([norm_np(bands[c]) for c in range(3)], axis=0)  # (3, H, W)
        thr_n = norm_np(thermal_hr)                                      # (H, W)

        # Reflect-pad scenes that are smaller than one HR patch
        H, W = thr_n.shape
        pad_y = max(0, self.hr_patch - H)
        pad_x = max(0, self.hr_patch - W)
        if pad_y or pad_x:
            thr_n = np.pad(thr_n, ((0, pad_y), (0, pad_x)), mode='reflect')
            rgb_n = np.pad(rgb_n, ((0, 0), (0, pad_y), (0, pad_x)), mode='reflect')
            H, W = thr_n.shape

        # Crop origin; each axis is drawn independently so a scene that is tight
        # along one axis still gets random placement along the other.
        y = self._random_offset(rng, H - self.hr_patch)
        x = self._random_offset(rng, W - self.hr_patch)

        hr_t_patch = np.ascontiguousarray(
            thr_n[y:y + self.hr_patch, x:x + self.hr_patch])          # (HR, HR)
        hr_rgb_patch = np.ascontiguousarray(
            rgb_n[:, y:y + self.hr_patch, x:x + self.hr_patch])       # (3, HR, HR)

        hr_t = torch.from_numpy(hr_t_patch).unsqueeze(0).float()      # (1, HR, HR)
        hr_rgb = torch.from_numpy(hr_rgb_patch).float()               # (3, HR, HR)

        # Synthesise the LR observation from the HR crop itself: the LR patch
        # then covers exactly the same footprint as the HR patch, and only
        # hr_patch^2 pixels are resampled instead of the whole scene.
        lr_t = F.interpolate(
            hr_t.unsqueeze(0),                                        # (1, 1, HR, HR)
            size=(self.lr_patch, self.lr_patch),
            mode="bilinear",
            align_corners=False
        ).squeeze(0)                                                  # (1, LR, LR)

        return lr_t, hr_rgb, hr_t

# -----------------------
# SwinIR-style model (simplified)
# -----------------------
# This is a lightweight, SwinIR-inspired architecture:
#  - early fusion of upsampled LR thermal + HR RGB  -> 4 channels
#  - patch-embedding conv
#  - several "Swin-style" residual blocks using conv + windowed self-attention-ish
#  - final conv to 1-channel HR thermal

class PatchEmbed(nn.Module):
    """
    Project an image tensor to `embed_dim` feature channels with one convolution.

    The convolution uses a `patch_size` x `patch_size` kernel, no padding and
    the default stride, so with `patch_size=1` the spatial size is unchanged.
    """

    def __init__(self, in_channels=4, embed_dim=96, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            padding=0,
        )

    def forward(self, x):
        """Map (B, in_channels, H, W) to (B, embed_dim, H', W')."""
        embedded = self.proj(x)
        return embedded


class SimpleWindowAttention(nn.Module):
    """
    Very simplified stand-in for Swin window attention.

    Despite the name there is no windowing: all H*W positions of the feature
    map are flattened into one sequence, layer-normalised over channels and
    passed through a single `nn.MultiheadAttention` (global self-attention).

    Args:
        dim: number of feature channels (must be divisible by `num_heads`).
        num_heads: number of attention heads.
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        x: (B, C, H, W) -> (B, C, H, W)
        Flatten -> (B, H*W, C), apply self-attention, reshape back.
        """
        B, C, H, W = x.shape
        # flatten/reshape (rather than view) also accept inputs whose memory
        # layout does not allow H and W to be merged without a copy.
        tokens = self.norm(x.flatten(2).transpose(1, 2))        # (B, N, C)
        # The attention map itself is never used, so do not ask for it:
        # need_weights=False avoids building and averaging the (B, N, N)
        # weight tensor and lets PyTorch use its fused attention kernel.
        out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        return out.transpose(1, 2).reshape(B, C, H, W)


class SwinIRBlock(nn.Module):
    """
    Pre-norm residual block in the style of SwinIR.

    Two residual branches are applied in sequence:
      1. channel LayerNorm -> attention module
      2. channel LayerNorm -> 1x1-conv MLP (dim -> dim*mlp_ratio -> dim, GELU)
    Input and output both have shape (B, dim, H, W).
    """
    def __init__(self, dim, mlp_ratio=2.0, num_heads=4):
        super().__init__()
        self.attn = SimpleWindowAttention(dim, num_heads=num_heads)

        hidden_dim = int(dim * mlp_ratio)
        mlp_layers = [
            nn.Conv2d(dim, hidden_dim, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, kernel_size=1),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    @staticmethod
    def _channel_layer_norm(norm, feat):
        """Apply `norm` over the channel axis of a (B, C, H, W) tensor."""
        B, C, H, W = feat.shape
        tokens = feat.view(B, C, H * W).permute(0, 2, 1)        # (B, N, C)
        return norm(tokens).permute(0, 2, 1).view(B, C, H, W)

    def forward(self, x):
        # Attention branch with residual connection
        x = x + self.attn(self._channel_layer_norm(self.norm1, x))
        # MLP branch with residual connection
        x = x + self.mlp(self._channel_layer_norm(self.norm2, x))
        return x


class SwinIRFusionNet(nn.Module):
    """
    RGB-guided thermal super-resolution network with a SwinIR-style backbone.

    Inputs:
      - xT_lr: (B,1,LR,LR)  low-resolution thermal
      - xO_hr: (B,3,HR,HR)  high-resolution optical (RGB), HR = LR * upscale
    Output:
      - (B,1,HR,HR) predicted high-resolution thermal

    Pipeline: bilinear upsampling of the thermal input -> channel concat with
    RGB (4 channels) -> 1x1 embedding conv -> `num_blocks` SwinIRBlocks ->
    two 3x3 conv+ReLU refinement layers -> 3x3 output conv.
    """
    def __init__(self, embed_dim=96, num_blocks=6, upscale=UPSCALE):
        super().__init__()
        self.upscale = upscale

        # Brings the LR thermal input onto the HR pixel grid
        self.thermal_upsample = nn.Upsample(
            scale_factor=upscale, mode="bilinear", align_corners=False
        )

        # 1 thermal + 3 RGB channels -> embed_dim features
        self.patch_embed = PatchEmbed(in_channels=4, embed_dim=embed_dim, patch_size=1)

        # Backbone: a stack of identical SwinIR-style blocks
        self.blocks = nn.Sequential(*[
            SwinIRBlock(dim=embed_dim, mlp_ratio=2.0, num_heads=4)
            for _ in range(num_blocks)
        ])

        # Convolutional refinement after the attention backbone
        self.refine = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Single-channel thermal output head
        self.conv_out = nn.Conv2d(embed_dim, 1, kernel_size=3, padding=1)

        # Kaiming-normal weights and zero biases for every Conv2d in the network
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, a=0, mode="fan_in", nonlinearity="relu")
            if conv.bias is not None:
                nn.init.zeros_(conv.bias)

    def forward(self, xT_lr, xO_hr):
        # Upsampled thermal and RGB share the HR grid, so they can be stacked
        fused = torch.cat([self.thermal_upsample(xT_lr), xO_hr], dim=1)  # (B,4,HR,HR)
        feat = self.blocks(self.patch_embed(fused))                      # (B,embed_dim,HR,HR)
        return self.conv_out(self.refine(feat))                          # (B,1,HR,HR)

# -----------------------
# Training with physics-based loss + resume
# -----------------------
def _save_checkpoint(state, path):
    """Save `state` to `path` via a temporary file so an interrupted write cannot corrupt `path`."""
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def _evaluate_swinir(model, loader, desc):
    """
    Run `model` over `loader` without gradients and average the per-sample metrics.

    Returns:
        (mean PSNR, mean SSIM, mean RMSE), or None if the loader yields no samples.
    """
    model.eval()
    ps_sum = ss_sum = rm_sum = 0.0
    cnt = 0
    with torch.no_grad():
        for lr_t, hr_rgb, hr_t in tqdm(loader, desc=desc):
            out = model(lr_t.to(DEVICE), hr_rgb.to(DEVICE))   # (B,1,HR,HR)
            preds = out[:, 0].cpu().numpy()                   # (B,HR,HR)
            tgts = hr_t[:, 0].cpu().numpy()
            # Score every sample on its own so any batch size gives the same averages.
            for pred, tgt in zip(preds, tgts):
                ps, ss, rm = compute_metrics(pred, tgt)
                ps_sum += ps
                ss_sum += ss
                rm_sum += rm
                cnt += 1
    if cnt == 0:
        return None
    return ps_sum / cnt, ss_sum / cnt, rm_sum / cnt


def train_ssl4eo_swinir(num_epochs=NUM_EPOCHS):
    """
    Train SwinIRFusionNet on SSL4EO patches and evaluate it on a held-out test split.

    Steps: locate (or download) scenes, split them 70/15/15 by scene with a
    fixed seed, resume from the LAST checkpoint if one can be loaded, train
    with `MSE(pred, hr) + PHYS_LAMBDA * MSE(downsample(pred), lr)`, keep the
    checkpoint with the best validation PSNR, and finally report test metrics
    using that best checkpoint when it exists.

    Args:
        num_epochs: index of the last epoch to run (epochs are 1-based; a
            resumed run continues after the checkpoint's epoch).

    Returns:
        The trained model, or None when no scenes are available.
    """
    # 1) Find / download scenes (the helper returns early if scenes already exist)
    scenes = maybe_download_ssl4eo()
    if not scenes:
        logger.error("No SSL4EO scenes available; aborting.")
        return None

    # 2) Train/val/test split
    random.seed(42)
    np.random.seed(42)
    scenes_shuffled = list(scenes)
    random.shuffle(scenes_shuffled)

    n_total = len(scenes_shuffled)
    # Always keep at least one training scene: a shuffling DataLoader refuses
    # an empty dataset. For two or more scenes the split is the plain 70/15/15.
    n_train = max(1, int(0.7 * n_total))
    n_val   = min(int(0.15 * n_total), n_total - n_train)

    train_scenes = scenes_shuffled[:n_train]
    val_scenes   = scenes_shuffled[n_train:n_train + n_val]
    test_scenes  = scenes_shuffled[n_train + n_val:]

    logger.info(
        "Starting SSL4EO SwinIR training with physics-aware loss.\n"
        f"Scene split -> Train: {len(train_scenes)}, Val: {len(val_scenes)}, Test: {len(test_scenes)}"
    )

    # 3) Datasets / loaders
    train_ds = SSL4EOPatchDataset(
        train_scenes, hr_patch=HR_PATCH, upscale=UPSCALE,
        patches_per_scene=PATCHES_PER_SCENE_TRAIN, mode="train"
    )
    val_ds = SSL4EOPatchDataset(
        val_scenes, hr_patch=HR_PATCH, upscale=UPSCALE,
        patches_per_scene=PATCHES_PER_SCENE_VAL, mode="val"
    )
    test_ds = SSL4EOPatchDataset(
        test_scenes, hr_patch=HR_PATCH, upscale=UPSCALE,
        patches_per_scene=PATCHES_PER_SCENE_TEST, mode="test"
    )

    pin = (DEVICE.type == "cuda")
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, pin_memory=pin
    )
    val_loader = DataLoader(
        val_ds, batch_size=1, shuffle=False,
        num_workers=0, pin_memory=pin
    )
    test_loader = DataLoader(
        test_ds, batch_size=1, shuffle=False,
        num_workers=0, pin_memory=pin
    )

    # 4) Model, optimizer, loss
    model = SwinIRFusionNet(embed_dim=96, num_blocks=6, upscale=UPSCALE).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    mse_loss  = nn.MSELoss()

    BEST_PATH = os.path.join(MODELS_DIR, "ssl4eo_best_swinir.pth")
    LAST_PATH = os.path.join(MODELS_DIR, "ssl4eo_last_swinir.pth")

    best_val_psnr = -1e9
    start_epoch = 1

    # Resume training if LAST_PATH exists
    if os.path.exists(LAST_PATH):
        try:
            ckpt = torch.load(LAST_PATH, map_location=DEVICE)
            # Pull out every field first so a missing key is detected before
            # the model or optimizer has been modified.
            model_state   = ckpt["model_state"]
            optim_state   = ckpt["optimizer_state"]
            resumed_epoch = int(ckpt["epoch"])
            resumed_best  = float(ckpt.get("best_val_psnr", -1e9))
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optim_state)
        except Exception as e:
            # Best effort: an unreadable checkpoint means training starts over.
            logger.warning(f"Failed to load checkpoint {LAST_PATH}: {e}")
        else:
            start_epoch = resumed_epoch + 1
            best_val_psnr = resumed_best
            logger.info(
                f"Resuming from checkpoint {LAST_PATH}: "
                f"start_epoch={start_epoch}, best_val_psnr={best_val_psnr:.3f}"
            )
    else:
        logger.info("No previous SwinIR checkpoint found; starting at epoch 1.")

    logger.info(f"Training epochs: {start_epoch} -> {num_epochs}")

    # 5) Training loop
    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        running = 0.0
        it = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} (train)")
        for lr_t, hr_rgb, hr_t in pbar:
            lr_t   = lr_t.to(DEVICE)      # (B,1,LR,LR)
            hr_rgb = hr_rgb.to(DEVICE)    # (B,3,HR,HR)
            hr_t   = hr_t.to(DEVICE)      # (B,1,HR,HR)

            optimizer.zero_grad()
            pred_hr = model(lr_t, hr_rgb)  # (B,1,HR,HR)

            # Data fidelity loss at HR
            loss_fid = mse_loss(pred_hr, hr_t)

            # Physics-aware loss: the prediction, averaged back down to the LR
            # grid, must reproduce the LR thermal observation it was built from.
            pred_lr = F.interpolate(
                pred_hr, size=(lr_t.shape[2], lr_t.shape[3]),
                mode="area"  # average-style downsampling
            )
            loss_phys = mse_loss(pred_lr, lr_t)

            loss = loss_fid + PHYS_LAMBDA * loss_phys
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running += float(loss.item())
            it += 1
            pbar.set_postfix(loss=running / max(1, it))

        avg_train_loss = running / max(1, it)
        logger.info(f"Epoch {epoch} TRAIN loss={avg_train_loss:.6f}")

        # 6) Validation
        val_metrics = _evaluate_swinir(model, val_loader, f"Epoch {epoch} (val)")
        if val_metrics is not None:
            avg_ps, avg_ss, avg_rm = val_metrics
            logger.info(
                f"Epoch {epoch} VAL (SwinIR) PSNR={avg_ps:.3f} dB, "
                f"SSIM={avg_ss:.4f}, RMSE={avg_rm:.6f}"
            )
            if avg_ps > best_val_psnr:
                best_val_psnr = avg_ps
                _save_checkpoint(
                    {
                        "model_state": model.state_dict(),
                        "epoch": epoch,
                        "best_val_psnr": best_val_psnr,
                    },
                    BEST_PATH,
                )
                logger.info(f"Saved BEST SwinIR model -> {BEST_PATH} (PSNR={avg_ps:.3f})")

        # Always save LAST (includes optimizer state so training can resume)
        _save_checkpoint(
            {
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "epoch": epoch,
                "best_val_psnr": best_val_psnr,
            },
            LAST_PATH,
        )
        logger.info(f"Saved LAST SwinIR model -> {LAST_PATH} (epoch={epoch})")

    # 7) Final TEST evaluation with best model (falls back to the current
    #    weights when no BEST checkpoint was ever written)
    if os.path.exists(BEST_PATH):
        ckpt = torch.load(BEST_PATH, map_location=DEVICE)
        model.load_state_dict(ckpt["model_state"])
        logger.info(f"Loaded BEST SwinIR model from {BEST_PATH} for TEST evaluation.")

    test_metrics = _evaluate_swinir(model, test_loader, "TEST (SwinIR)")
    if test_metrics is not None:
        avg_ps, avg_ss, avg_rm = test_metrics
        logger.info(
            f"TEST SUMMARY (SSL4EO + SwinIR-style, RGB-guided, physics-aware): "
            f"PSNR={avg_ps:.3f} dB, SSIM={avg_ss:.4f}, RMSE={avg_rm:.6f}"
        )

    return model

# %% Script entry
if __name__ == "__main__":
    # Start training with the default number of epochs, but only when this
    # file is executed directly; importing it must not kick off a run.
    train_ssl4eo_swinir()
