
# Import

In [None]:
# -*- coding: utf-8 -*-
"""
Updated Training Script for Multi-Modal Satellite Image Fusion
Supports L30, S1, Planet inputs with Masked operations
"""
# Colab setup: install kornia (its morphology.erosion is imported below) and
# mount Google Drive, which holds the model and loss module directories.
!pip install kornia
from google.colab import drive
# Flush and unmount any existing Drive mount before mounting it again.
drive.flush_and_unmount()
drive.mount('/content/drive')

import os
import sys
import argparse
import time
from pathlib import Path
import matplotlib.pyplot as plt
import kornia
from kornia.morphology import erosion
import torch.nn.functional as F
import platform

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
from tqdm import tqdm
from scipy.ndimage import binary_erosion
from skimage.metrics import structural_similarity as ssim

import importlib
# GPU optimization: let cuDNN benchmark and pick convolution algorithms.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    device_props = torch.cuda.get_device_properties(0)

# Directories on the mounted Drive containing the model module
# (MultimodalUnetnewadap.py) and the loss module (mse_loss.py).
model_dir_in_drive = '/content/drive/MyDrive/ColabModel/'
loss_dir_in_drive = '/content/drive/MyDrive/ColabLoss/'

# Make both directories importable. Both are guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
for _extra_dir in (model_dir_in_drive, loss_dir_in_drive):
    if _extra_dir not in sys.path:
        sys.path.append(_extra_dir)
        print(f"Added {_extra_dir} to sys.path")
    else:
        print(f"{_extra_dir} already in sys.path")
del _extra_dir

# from MultimodalUnetnew import Unet
from MultimodalUnetnewadap import Unet

from mse_loss import MSELoss

# Fall back to the CPU instead of hard-coding cuda:0, which fails on
# runtimes without a GPU as soon as a tensor is moved to DEVICE.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Function

In [None]:
class OptimizedPatchDataset(Dataset):
    """Dataset of pre-saved ``sample_*.pt`` dictionaries with a fill-once in-memory cache.

    At construction every ``sample_*.pt`` file in ``data_dir`` is loaded once, and
    samples whose L30, S1 and Planet masks (``mask_l30``, ``mask_s1``,
    ``mask_planet``) are all zero are dropped. A mask key that is absent from a
    sample is treated as an all-zero mask.

    Args:
        data_dir: directory containing the ``sample_*.pt`` files.
        cache_size: maximum number of processed samples kept in memory.
        map_location: forwarded to ``torch.load``.

    Raises:
        RuntimeError: if no ``sample_*.pt`` file exists, or every sample has empty masks.
    """

    # Per-source mask keys used to decide whether a sample carries any data.
    _MASK_KEYS = ('mask_l30', 'mask_s1', 'mask_planet')

    def __init__(self, data_dir, cache_size=50, map_location="cpu"):
        self.data_dir = Path(data_dir)
        all_files = sorted(self.data_dir.glob("sample_*.pt"))
        if not all_files:
            raise RuntimeError(f"[OptimizedPatchDataset] No sample_*.pt files found in {self.data_dir}")

        # Filter: only keep samples with at least one non-zero mask
        valid_files = [
            p for p in all_files
            if self._has_valid_mask(torch.load(p, map_location=map_location))
        ]
        if not valid_files:
            raise RuntimeError("[OptimizedPatchDataset] All samples have zero masks!")

        self.files = valid_files
        self.cache = {}
        self.cache_size = cache_size
        self.map_location = map_location

    @classmethod
    def _has_valid_mask(cls, sample):
        """Return True if at least one source mask in ``sample`` has a non-zero entry."""
        for key in cls._MASK_KEYS:
            mask = sample.get(key)
            # A missing mask counts as empty instead of crashing on ``None.sum()``.
            if mask is not None and int(torch.count_nonzero(torch.as_tensor(mask))) > 0:
                return True
        return False

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return sample ``idx`` with float64 tensors cast to float32 and all tensors contiguous."""
        if idx in self.cache:
            return self.cache[idx]
        sample = torch.load(self.files[idx], map_location=self.map_location)
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                if value.dtype == torch.float64:
                    value = value.float()
                sample[key] = value.contiguous()
        # The cache only fills until it reaches cache_size; entries are never evicted.
        if len(self.cache) < self.cache_size:
            self.cache[idx] = sample
        return sample

def psnr_masked(img_ref, img_test, valid_mask, data_range=1.0):
    """
    Calculate PSNR only on pixels where valid_mask is non-zero.
    img_ref:    numpy array, shape = (C, H, W)
    img_test:   numpy array, shape = (C, H, W)
    valid_mask: numpy array, shape = (H, W); any value > 0 marks a valid pixel
    data_range: float, image maximum value range (set to 1.0 if normalized to [0,1])

    The MSE is averaged over every channel of every valid pixel.

    Returns: a scalar PSNR (dB). Returns np.nan if no pixel is valid and
    inf if the images are identical on the valid pixels.
    Raises: ValueError if image shapes differ or the mask does not match (H, W).
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[psnr_masked] Input shapes are inconsistent.")
    if valid_mask.shape != img_ref.shape[1:]:
        raise ValueError(f"[psnr_masked] mask shape {valid_mask.shape} does not match image shape {img_ref.shape}.")

    # Binarize once so the pixel count and the selected pixels always agree,
    # even for non-binary masks such as 0/255 or float weights.
    mask_bool = np.asarray(valid_mask) > 0
    if np.count_nonzero(mask_bool) == 0:
        return np.nan

    I = img_ref.astype(np.float64)
    K = img_test.astype(np.float64)

    # Boolean indexing over the trailing (H, W) axes -> shape (C, N_valid).
    diff = I[:, mask_bool] - K[:, mask_bool]
    mse_masked = np.mean(diff ** 2)
    if mse_masked <= 0:
        return float('inf')
    psnr_val = 10.0 * np.log10((data_range ** 2) / mse_masked)
    return float(psnr_val)


def ssim_eroded_mask(img_ref, img_test, valid_mask, max_val=1.0, **ssim_kwargs):
    """
    Calculate SSIM restricted to the interior of a mask.

    The mask is eroded by a win_size x win_size square. The per-channel SSIM map
    (skimage.metrics.structural_similarity with full=True) is computed with the
    same win_size and averaged over the eroded pixels. This way every kept
    pixel's SSIM window lies entirely inside the valid region.

    img_ref:    numpy array, shape = (C, H, W)
    img_test:   numpy array, shape = (C, H, W)
    valid_mask: numpy array, shape = (H, W); any value > 0 marks a valid pixel
    max_val:    image maximum value range (set to 1.0 if normalized to [0,1])
    ssim_kwargs: forwarded to skimage's SSIM; 'win_size' (default min(3, H, W),
                 reduced to an odd value) also sets the erosion size.

    Returns: a scalar SSIM (nan-mean over channels), or np.nan if the window is
    smaller than 3 or no pixel survives erosion.
    Raises: ValueError if the two images differ in shape.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[ssim_eroded_mask] Input shapes are inconsistent.")

    C, H, W = img_ref.shape
    orig_mask = valid_mask > 0

    # One odd window size drives both the erosion and the SSIM computation;
    # previously it was popped here and never forwarded, so skimage silently
    # used its own (larger) default window.
    win_size = ssim_kwargs.pop('win_size', min(3, H, W))
    if win_size % 2 == 0:
        win_size -= 1
    if win_size < 3:
        return np.nan
    struct_el = np.ones((win_size, win_size), dtype=bool)
    core_mask = binary_erosion(orig_mask, structure=struct_el, border_value=0)
    if np.count_nonzero(core_mask) == 0:
        # No valid pixels after erosion
        return np.nan

    ssim_vals = []
    for c in range(C):
        try:
            # full=True returns (score, ssim_map); ssim_map has shape (H, W)
            _, ssim_map = ssim(
                img_ref[c], img_test[c],
                data_range=max_val,
                full=True,
                win_size=win_size,
                **ssim_kwargs
            )
            ssim_vals.append(np.mean(ssim_map[core_mask]))
        except Exception as e:
            # Best-effort per channel: report and skip via NaN.
            print(f"[ssim_eroded_mask] Warning: Channel {c} SSIM calculation failed: {e}")
            ssim_vals.append(np.nan)

    return float(np.nanmean(ssim_vals))

def sam_masked(img_ref, img_test, valid_mask):
    """
    Mean Spectral Angle Mapper (SAM) over the pixels selected by a mask.

    img_ref:    numpy array, shape = (C, H, W)
    img_test:   numpy array, shape = (C, H, W)
    valid_mask: numpy array, shape = (H, W); any value > 0 selects a pixel

    Each selected pixel's C-band spectrum in img_ref is compared with the one in
    img_test. Pixels where either spectrum has zero norm are skipped.
    Returns: mean angle in degrees as a float, or np.nan if nothing is selected
    or every selected pixel was skipped.
    Raises: ValueError if the two images differ in shape.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[sam_masked] Input shapes are inconsistent.")
    # Unpacking also rejects inputs that are not 3-D.
    _n_bands, _height, _width = img_ref.shape
    ref64 = img_ref.astype(np.float64)
    test64 = img_test.astype(np.float64)

    selected = valid_mask > 0
    if not selected.any():
        return np.nan

    # One column per selected pixel: (C, N_selected)
    ref_cols = ref64[:, selected]
    test_cols = test64[:, selected]

    dots = (ref_cols * test_cols).sum(axis=0)
    norm_products = np.linalg.norm(ref_cols, axis=0) * np.linalg.norm(test_cols, axis=0)

    usable = norm_products > 0
    if not usable.any():
        return np.nan

    # Clip guards arccos against rounding slightly outside [-1, 1].
    cosines = np.clip(dots[usable] / norm_products[usable], -1.0, 1.0)
    return float(np.degrees(np.arccos(cosines)).mean())


def psnr_masked_gpu_batch(
    img_ref: torch.Tensor,   # shape (B, C, H, W)
    img_test: torch.Tensor,  # shape (B, C, H, W)
    valid_mask: torch.Tensor,# shape (B, 1, H, W) or (B, H, W)
    data_range: float = 1.0
) -> torch.Tensor:
    """
    Per-sample masked PSNR for a batch, computed entirely with tensor ops.

    The mask is a per-pixel weight shared by all channels, so the MSE of each
    sample is sum(mask * (ref - test)^2) / (C * sum(mask)). The MSE is floored
    at 1e-10 before taking the log.

    Returns: tensor of shape (B,); NaN where the mask sums to zero.
    """
    weights = valid_mask.unsqueeze(1) if valid_mask.dim() == 3 else valid_mask
    n_channels = img_ref.shape[1]

    weight_total = weights.sum(dim=(1, 2, 3)) * n_channels
    weighted_sq_err = (img_ref - img_test).pow(2) * weights
    mse = weighted_sq_err.sum(dim=(1, 2, 3)) / weight_total.clamp(min=1)

    psnr = 10.0 * torch.log10((data_range ** 2) / mse.clamp(min=1e-10))
    nan_value = torch.tensor(float('nan'), device=img_ref.device)
    return torch.where(weight_total > 0, psnr, nan_value)

def ssim_eroded_mask_gpu_batch(
    img_ref: torch.Tensor,
    img_test: torch.Tensor,
    valid_mask: torch.Tensor,
    window_size: int = 3,
    data_range: float = 1.0
) -> torch.Tensor:
    """
    Per-sample masked SSIM for a batch, averaged over channels and eroded mask pixels.

    A separable Gaussian window (sigma = 1.5, side ``window_size``) drives
    depthwise ``F.conv2d`` calls with zero padding. These compute local means,
    variances and covariance, from which the SSIM map is formed. The map is
    averaged over channels and then over the pixels that remain after eroding
    ``valid_mask`` with a ``window_size`` x ``window_size`` kernel (constant
    zero border). Pixels whose window touches an invalid pixel or the image
    border are therefore excluded.

    Args:
        img_ref:     reference batch, unpacked as (B, C, H, W).
        img_test:    predicted batch, same shape as ``img_ref``.
        valid_mask:  (B, 1, H, W) or (B, H, W); a 3-D mask gets a channel axis added.
        window_size: side of the Gaussian window and of the erosion kernel.
                     NOTE(review): padding = window_size // 2 only preserves H, W
                     for odd sizes — assumes callers pass an odd value.
        data_range:  dynamic range of the images; sets the SSIM constants C1, C2.

    Returns:
        Tensor of shape (B,); NaN for samples with no pixel left after erosion.
    """
    B, C, H, W = img_ref.shape

    if valid_mask.dim() == 3:
        valid_mask = valid_mask.unsqueeze(1)
    valid_mask = valid_mask.float()

    # SSIM constants
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    # Create Gaussian window (1-D kernel, then outer product to 2-D)
    sigma = 1.5  # author notes this as kornia's default — NOTE(review): confirm
    coords = torch.arange(window_size, device=img_ref.device, dtype=img_ref.dtype)
    coords = coords - (window_size - 1) / 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    window = g.unsqueeze(0) * g.unsqueeze(1)
    window = window.unsqueeze(0).unsqueeze(0)

    # Normalize window (already sums to ~1 since g does; re-normalizing is harmless)
    window = window / window.sum()

    # Expand to all channels: one (1, ws, ws) kernel per channel for groups=C
    window = window.expand(C, 1, window_size, window_size).contiguous()

    # Calculate local statistics (zero padding; border pixels are removed by the erosion below)
    pad = window_size // 2

    mu1 = F.conv2d(img_ref, window, padding=pad, groups=C)
    mu2 = F.conv2d(img_test, window, padding=pad, groups=C)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Var/cov via E[x^2] - E[x]^2 over the Gaussian window
    sigma1_sq = F.conv2d(img_ref.pow(2), window, padding=pad, groups=C) - mu1_sq
    sigma2_sq = F.conv2d(img_test.pow(2), window, padding=pad, groups=C) - mu2_sq
    sigma12 = F.conv2d(img_ref * img_test, window, padding=pad, groups=C) - mu1_mu2

    # Calculate SSIM
    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    # Average across channel dimension -> (B, 1, H, W)
    ssim_map = ssim_map.mean(dim=1, keepdim=True)

    # Erode mask with a zero border, then threshold back to 0/1
    kernel = torch.ones((window_size, window_size), device=valid_mask.device)
    mask_eroded = erosion(valid_mask, kernel, border_type='constant', border_value=0.0)
    mask_eroded = (mask_eroded > 0.5).float()

    # Calculate masked mean per sample
    valid_pixels = mask_eroded.sum(dim=(1, 2, 3))
    ssim_masked = ssim_map * mask_eroded
    ssim_sum = ssim_masked.sum(dim=(1, 2, 3))

    # NaN for samples whose eroded mask is empty
    ssim_mean = torch.where(
        valid_pixels > 0,
        ssim_sum / valid_pixels,
        torch.tensor(float('nan'), device=img_ref.device)
    )

    return ssim_mean

def sam_masked_gpu_batch(
    img_ref: torch.Tensor,    # (B, C, H, W)
    img_pred: torch.Tensor,   # (B, C, H, W)
    valid_mask: torch.Tensor, # (B, 1, H, W), (B, H, W) or (H, W)
    eps: float = 1e-8,
    cos_margin: float = 1e-7,
) -> torch.Tensor:
    """
    Batch masked Spectral Angle Mapper. Returns a tensor of shape (B,) in degrees.

    For every pixel the angle between the reference and predicted C-band
    spectra is computed. Each sample's result is the mean angle over pixels
    where ``valid_mask > 0``, or NaN if there are none.

    Args:
        img_ref, img_pred: (B, C, H, W) tensors; they need not be contiguous.
        valid_mask: (B, 1, H, W), (B, H, W), or (H, W), which is shared across the batch.
        eps: lower bound on the product of spectral norms (avoids division by zero).
        cos_margin: cosines are clamped to [-1 + m, 1 - m], where m is the larger
            of ``cos_margin`` and the dtype's machine epsilon. This is needed
            because d/dx acos(x) is infinite at x = +/-1, so (nearly) parallel
            spectra would otherwise give inf/NaN gradients when this is used as a
            loss. Pass 0.0 to clamp to exactly [-1, 1].

    Returns:
        Tensor of shape (B,) with the mean spectral angle per sample, in degrees.
    """
    B, C, H, W = img_ref.shape

    # Normalize the mask to (B, H, W)
    if valid_mask.dim() == 4:
        valid_mask = valid_mask.squeeze(1)
    elif valid_mask.dim() == 2:
        valid_mask = valid_mask.unsqueeze(0).expand(B, -1, -1)

    # reshape (not view) so non-contiguous inputs (transposes, slices) also work
    ref_flat = img_ref.reshape(B, C, -1)        # (B, C, H*W)
    pred_flat = img_pred.reshape(B, C, -1)
    mask_flat = valid_mask.reshape(B, -1) > 0   # (B, H*W)

    # Per-pixel spectral angle
    dot_product = (ref_flat * pred_flat).sum(dim=1)   # (B, H*W)
    norm_ref = ref_flat.norm(dim=1)
    norm_pred = pred_flat.norm(dim=1)
    denominator = (norm_ref * norm_pred).clamp(min=eps)
    cos_raw = dot_product / denominator

    # A margin below the dtype's resolution would round back to +/-1.
    margin = max(cos_margin, torch.finfo(cos_raw.dtype).eps) if cos_margin > 0 else 0.0
    cos_theta = cos_raw.clamp(-1.0 + margin, 1.0 - margin)
    angles = torch.acos(cos_theta) * (180.0 / torch.pi)  # (B, H*W)

    # Vectorized masked mean: no per-sample Python loop or host/device sync
    valid_count = mask_flat.sum(dim=1)
    angle_sum = torch.where(mask_flat, angles, torch.zeros_like(angles)).sum(dim=1)
    sam_mean = angle_sum / valid_count.clamp(min=1)
    return torch.where(valid_count > 0, sam_mean, torch.full_like(sam_mean, float('nan')))


def init_weights(m):
    """Kaiming-normal (ReLU gain) weights and zero biases for Conv2d/Linear; other modules untouched.

    Intended for ``model.apply(init_weights)``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)
# -------------------------------------------------------------------------
# 3. Parse command-line arguments (edit this function to add or change CLI options)
# -------------------------------------------------------------------------
def parse_args(argv=None):
    """
    Build the training script's command-line interface and parse ``argv``.

    Args:
        argv: list of argument strings; None lets argparse read sys.argv[1:].

    Returns:
        argparse.Namespace with: mode, model, loss, data_dir, batch_size, epochs,
        lr, weight_decay, use_amp, gpu, ckpt_dir, log_dir, plot_dir, resume.
    """
    parser = argparse.ArgumentParser(
        description="Unified training script: dynamically load model and loss classes, then train/val/test."
    )

    # (flag, add_argument keyword arguments), in --help display order.
    option_specs = (
        ("--mode", {
            "type": str,
            "choices": ["train", "test"],
            "default": "train",
            "help": "Choose 'train' to train from scratch (or resume), or 'test' to only run evaluation on a saved checkpoint.",
        }),
        # Model & loss
        ("--model", {"type": str, "required": True,
                     "help": "Model name, matching a file in models/ that defines MODEL_CLASS."}),
        ("--loss", {"type": str, "required": True,
                    "help": "Loss name, matching a file in losses/ that defines LOSS_CLASS."}),
        # Data & optimization
        ("--data-dir", {"type": str, "default": "data",
                        "help": "Root data directory (must contain train/, val/, test/ subfolders)."}),
        ("--batch-size", {"type": int, "default": 32, "help": "Batch size for train/val/test."}),
        ("--epochs", {"type": int, "default": 15, "help": "Total number of training epochs."}),
        ("--lr", {"type": float, "default": 1e-4, "help": "Initial learning rate."}),
        ("--weight-decay", {"type": float, "default": 1e-5, "help": "Weight decay (L2) coefficient."}),
        ("--use-amp", {"action": "store_true",
                       "help": "If set, enable PyTorch AMP (automatic mixed precision)."}),
        ("--gpu", {"type": int, "default": 0, "help": "GPU id to use; set to -1 for CPU."}),
        # Output locations & resuming
        ("--ckpt-dir", {"type": str, "default": "checkpoints",
                        "help": "Directory for saving checkpoints."}),
        ("--log-dir", {"type": str, "default": "logs",
                       "help": "Directory for saving logs and intermediate results."}),
        ("--plot-dir", {"type": str, "default": "plots",
                        "help": "Directory for saving plots."}),
        ("--resume", {"type": str, "default": None,
                      "help": "If specified, load this checkpoint and resume training or do test (depending on --mode)."}),
    )

    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args(argv)

# Main

In [None]:
def main(args):
    # Device selection
    if args.gpu >= 0 and torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
    else:
        device = torch.device("cpu")
    print(f"[INFO] Using device: {device}")

    # Create checkpoint and log directories
    ckpt_dir = Path(args.ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    log_dir = Path(args.log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    plot_dir = Path(args.plot_dir)
    plot_dir.mkdir(parents=True, exist_ok=True)

    # Instantiate model, loss, optimizer, scheduler
    model = Unet(use_meta=False, use_selfattention=True,use_spatial_attention=True)#True,False
    model = model.to(device)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[INFO] Loaded model `{args.model}`, trainable params = {num_params:,}")
    criterion = MSELoss()

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    use_amp = args.use_amp and (device.type.startswith("cuda"))
    # scaler = torch.cuda.amp.GradScaler() if use_amp else None
    scaler = torch.amp.GradScaler() if use_amp else None
    if use_amp:
        print("[INFO] Enabled AMP mixed‐precision.")

    # 5) Prepare datasets and dataloaders
    train_dir = Path(args.data_dir) / "train"
    val_dir   = Path(args.data_dir) / "val"
    test_dir  = Path(args.data_dir) / "test"
    for d in (train_dir, val_dir, test_dir):
        if not d.exists() or not d.is_dir():
            raise RuntimeError(f"[ERROR] Directory `{d}` does not exist or is not a folder.")
    print("[INFO] train dataset creating")
    train_dataset = OptimizedPatchDataset(train_dir, cache_size=500, map_location="cpu")
    print("[INFO] train dataset created")
    val_dataset   = OptimizedPatchDataset(val_dir,   cache_size=500, map_location="cpu")
    print("[INFO] val dataset created")
    test_dataset  = OptimizedPatchDataset(test_dir,  cache_size=500, map_location="cpu")
    print("[INFO] test dataset created")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,  num_workers=0, pin_memory=True)
    print("[INFO] train dataset loaded")
    val_loader   = DataLoader(val_dataset,   batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)
    test_loader  = DataLoader(test_dataset,  batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)
    print(f"[INFO] #train = {len(train_dataset)}, #val = {len(val_dataset)}, #test = {len(test_dataset)}")

    # Learning rate schedulers
    total_iteration = (len(train_dataset) // args.batch_size)*args.epochs
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_iteration, eta_min=1e-6)

    # 6) If resume is specified, load checkpoint (for either training or testing)
    start_epoch   = 0
    best_val_loss = float("inf")
    if args.resume is not None:
        if os.path.isfile(args.resume):
            ckpt = torch.load(args.resume, map_location=device,)
            model.load_state_dict(ckpt["model_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            start_epoch   = ckpt.get("epoch", 0)
            best_val_loss = ckpt.get("best_val_loss", best_val_loss)
            print(f"[INFO] Loaded checkpoint `{args.resume}`, starting epoch = {start_epoch}, best_val_loss = {best_val_loss:.6f}")
        else:
            print(f"[WARN] Specified resume file `{args.resume}` not found, ignoring.")

    # 7) Define train_one_epoch, validate_one_epoch, test_and_evaluate (unchanged) …
    def train_one_epoch(epoch_idx):
        model.train()
        running_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch_idx+1}/{args.epochs} [Train]")
        for batch in pbar:
            # Move inputs to device
            l30_img     = batch["l30_img"].to(device, non_blocking=True)
            # mask_l30    = batch["mask_l30"].to(device, non_blocking=True)
            l30_meta    = batch["l30_meta"].to(device, non_blocking=True)

            s1_img      = batch["s1_img"].to(device, non_blocking=True)
            # mask_s1     = batch["mask_s1"].to(device, non_blocking=True)
            s1_meta     = batch["s1_meta"].to(device, non_blocking=True)

            planet_img  = batch["planet_img"].to(device, non_blocking=True)
            # mask_planet = batch["mask_planet"].to(device, non_blocking=True)
            planet_meta = batch["planet_meta"].to(device, non_blocking=True)

            s30_gt      = batch["s30_img_gt"].to(device, non_blocking=True)
            mask_s30    = batch["mask_s30"].to(device, non_blocking=True)

            optimizer.zero_grad()
            if use_amp:
                with torch.amp.autocast('cuda'):
                    # outputs = model(
                    #     l30_img, mask_l30, l30_meta,
                    #     s1_img, mask_s1, s1_meta,
                    #     planet_img, mask_planet, planet_meta
                    # )
                    outputs = model(
                        l30_img, l30_meta,
                        s1_img, s1_meta,
                        planet_img, planet_meta
                    )
                    loss = criterion(outputs, s30_gt, mask_s30)
                    ssims=ssim_eroded_mask_gpu_batch(s30_gt, outputs, mask_s30).mean()
                    sams=sam_masked_gpu_batch(s30_gt, outputs, mask_s30).mean()
                    psnrs=psnr_masked_gpu_batch(s30_gt, outputs, mask_s30).mean()

                    loss_ssim  = (1 - ssims)   # 结构损失
                    loss_sam   = sams/180         # 光谱角度损失
                    loss_psnr  = 1.0 - psnrs/50.0
                    loss = 1*loss + 0.5*loss_ssim + 0.5*loss_sam + 0.2 * loss_psnr

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                # outputs = model(
                #     l30_img, mask_l30, l30_meta,
                #     s1_img, mask_s1, s1_meta,
                #     planet_img, mask_planet, planet_meta
                # )
                outputs = model(
                        l30_img, l30_meta,
                        s1_img, s1_meta,
                        planet_img, planet_meta
                )
                loss = criterion(outputs, s30_gt, mask_s30)

                ssims=ssim_eroded_mask_gpu_batch(s30_gt, outputs, mask_s30).mean()
                sams=sam_masked_gpu_batch(s30_gt, outputs, mask_s30).mean()
                psnrs=psnr_masked_gpu_batch(s30_gt, outputs, mask_s30).mean()

                loss_ssim  = (1 - ssims)   # 结构损失
                loss_sam   = sams/180         # 光谱角度损失
                loss_psnr  = 1.0 - psnrs/50.0
                loss = 1*loss + 0.5*loss_ssim + 0.5*loss_sam + 0.2 * loss_psnr

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            scheduler.step()
            running_loss += loss.item()
            num_batches += 1
            pbar.set_postfix({"loss": f"{loss.item():.4e}"})

        return running_loss / max(num_batches, 1)

    def validate_one_epoch():
        model.eval()
        val_loss_sum = 0.0
        num_batches = 0
        with torch.no_grad():
            pbar = tqdm(val_loader, desc="[Validate]")
            for batch in pbar:
                # Move to device ...
                l30_img     = batch["l30_img"].to(device, non_blocking=True)
                # mask_l30    = batch["mask_l30"].to(device, non_blocking=True)
                l30_meta    = batch["l30_meta"].to(device, non_blocking=True)

                s1_img      = batch["s1_img"].to(device, non_blocking=True)
                # mask_s1     = batch["mask_s1"].to(device, non_blocking=True)
                s1_meta     = batch["s1_meta"].to(device, non_blocking=True)

                planet_img  = batch["planet_img"].to(device, non_blocking=True)
                # mask_planet = batch["mask_planet"].to(device, non_blocking=True)
                planet_meta = batch["planet_meta"].to(device, non_blocking=True)

                s30_gt      = batch["s30_img_gt"].to(device, non_blocking=True)
                mask_s30    = batch["mask_s30"].to(device, non_blocking=True)

                # outputs = model(
                #     l30_img, mask_l30, l30_meta,
                #     s1_img, mask_s1, s1_meta,
                #     planet_img, mask_planet, planet_meta
                # )
                outputs = model(
                        l30_img, l30_meta,
                        s1_img, s1_meta,
                        planet_img, planet_meta
                )
                loss = criterion(outputs, s30_gt, mask_s30)

                ssims=ssim_eroded_mask_gpu_batch(s30_gt, outputs, mask_s30).mean()
                sams=sam_masked_gpu_batch(s30_gt, outputs, mask_s30).mean()
                psnrs=psnr_masked_gpu_batch(s30_gt, outputs, mask_s30).mean()

                loss_ssim  = (1 - ssims)   # 结构损失
                loss_sam   = sams/180         # 光谱角度损失
                loss_psnr  = 1.0 - psnrs/50.0
                loss = 1*loss + 0.5*loss_ssim + 0.5*loss_sam + 0.2 * loss_psnr

                val_loss_sum += loss.item()
                num_batches += 1
                pbar.set_postfix({"val_loss": f"{loss.item():.4e}"})

        return val_loss_sum / max(num_batches, 1)

    def test_and_evaluate():
        """
        Run full test‐set evaluation (masked MSE/RMSE/PSNR/SSIM/SAM), save results.
        """
        model.eval()
        mse_list, rmse_list, psnr_list, ssim_list, sam_list = [], [], [], [], []
        num_valid_samples = 0
        total_valid_pixels = 0

        save_sample_vis = True
        num_samples_to_save = 10
        saved_cnt = 0
        sample_vis_dir = log_dir / "sample_visuals"
        if save_sample_vis:
            sample_vis_dir.mkdir(parents=True, exist_ok=True)

        with torch.no_grad():
            pbar = tqdm(test_loader, desc="[Test]")
            for batch_idx, batch in enumerate(pbar):
                # Move to device ...
                l30_img     = batch["l30_img"].to(device, non_blocking=True)
                # mask_l30    = batch["mask_l30"].to(device, non_blocking=True)
                l30_meta    = batch["l30_meta"].to(device, non_blocking=True)

                s1_img      = batch["s1_img"].to(device, non_blocking=True)
                # mask_s1     = batch["mask_s1"].to(device, non_blocking=True)
                s1_meta     = batch["s1_meta"].to(device, non_blocking=True)

                planet_img  = batch["planet_img"].to(device, non_blocking=True)
                # mask_planet = batch["mask_planet"].to(device, non_blocking=True)
                planet_meta = batch["planet_meta"].to(device, non_blocking=True)

                s30_gt      = batch["s30_img_gt"].to(device, non_blocking=True)
                mask_s30    = batch["mask_s30"].to(device, non_blocking=True)

                # outputs = model(
                #     l30_img, mask_l30, l30_meta,
                #     s1_img, mask_s1, s1_meta,
                #     planet_img, mask_planet, planet_meta
                # )
                outputs = model(
                        l30_img, l30_meta,
                        s1_img, s1_meta,
                        planet_img, planet_meta
                )
                mask_ext = mask_s30.expand(-1, outputs.size(1), -1, -1)
                outputs = outputs * mask_ext

                pred_np = outputs.detach().cpu().numpy()    # (B, C, H, W)
                gt_np   = s30_gt.detach().cpu().numpy()     # (B, C, H, W)
                mask_np = mask_s30.detach().cpu().numpy().squeeze(1)  # (B, H, W)

                B = pred_np.shape[0]
                for i in range(B):
                    pred_i = pred_np[i]
                    gt_i   = gt_np[i]
                    mask_i = mask_np[i]

                    # Masked MSE / RMSE
                    mse_map = (pred_i - gt_i) ** 2  # (C, H, W)
                    num_valid_pixels = np.sum(mask_i) * pred_i.shape[0]
                    if num_valid_pixels > 0:
                        mse_val = np.sum(mse_map * mask_i[np.newaxis, ...]) / num_valid_pixels
                    else:
                        mse_val = np.nan
                    rmse_val = np.sqrt(mse_val) if not np.isnan(mse_val) else np.nan

                    # Masked PSNR
                    psnr_val = psnr_masked(gt_i, pred_i, mask_i, data_range=1.0)

                    # Masked & eroded SSIM
                    ssim_val = ssim_eroded_mask(gt_i, pred_i, mask_i, max_val=1.0)

                    # Masked SAM
                    sam_val = sam_masked(gt_i, pred_i, mask_i)

                    if not np.isnan(mse_val):
                        mse_list.append(mse_val)
                        rmse_list.append(rmse_val)
                        psnr_list.append(psnr_val)
                        ssim_list.append(ssim_val)
                        sam_list.append(sam_val)
                        num_valid_samples += 1
                        total_valid_pixels += int(np.sum(mask_i))

                    # Visualization: save first few samples as RGB comparison
                    if save_sample_vis and saved_cnt < num_samples_to_save:
                        import matplotlib.pyplot as plt

                        if pred_i.shape[0] >= 3:
                            pred_rgb = np.clip(pred_i[:3].transpose(1, 2, 0), 0, 1)
                            gt_rgb   = np.clip(gt_i[:3].transpose(1, 2, 0), 0, 1)
                        else:
                            pred_rgb = np.repeat(pred_i[0:1].transpose(1, 2, 0), 3, axis=2)
                            gt_rgb   = np.repeat(gt_i[0:1].transpose(1, 2, 0), 3, axis=2)

                        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
                        axes[0].imshow(gt_rgb)
                        axes[0].set_title("GT (RGB)")
                        axes[0].axis("off")
                        axes[1].imshow(pred_rgb)
                        axes[1].set_title("Pred (RGB)")
                        axes[1].axis("off")
                        plt.tight_layout()
                        vis_path = sample_vis_dir / f"sample_{batch_idx:04d}_{i:02d}.png"
                        plt.savefig(vis_path)
                        plt.close(fig)
                        saved_cnt += 1

                if mse_list:
                    pbar.set_postfix({
                        "MSE":  f"{np.nanmean(mse_list):.6f}",
                        "RMSE": f"{np.nanmean(rmse_list):.6f}",
                        "PSNR": f"{np.nanmean(psnr_list):.4f}"
                    })

        # Compute means and stds
        mse_arr  = np.array(mse_list)
        rmse_arr = np.array(rmse_list)
        psnr_arr = np.array(psnr_list)
        ssim_arr = np.array(ssim_list)
        sam_arr  = np.array(sam_list)

        results = {
            "mse_mean": float(np.nanmean(mse_arr)),
            "mse_std":  float(np.nanstd(mse_arr)),
            "rmse_mean": float(np.nanmean(rmse_arr)),
            "rmse_std":  float(np.nanstd(rmse_arr)),
            "psnr_mean": float(np.nanmean(psnr_arr)),
            "psnr_std":  float(np.nanstd(psnr_arr)),
            "ssim_mean": float(np.nanmean(ssim_arr)),
            "ssim_std":  float(np.nanstd(ssim_arr)),
            "sam_mean":  float(np.nanmean(sam_arr)),
            "sam_std":   float(np.nanstd(sam_arr)),
            "num_valid_samples": num_valid_samples,
            "total_valid_pixels": total_valid_pixels,
        }

        results_file = log_dir / "test_results.txt"
        with open(results_file, "w") as f:
            f.write("===== Test Set Evaluation =====\n")
            f.write(f"MSE (Masked)   : {results['mse_mean']:.6f} ± {results['mse_std']:.6f}\n")
            f.write(f"RMSE (Masked)  : {results['rmse_mean']:.6f} ± {results['rmse_std']:.6f}\n")
            f.write(f"PSNR (Masked)  : {results['psnr_mean']:.4f} ± {results['psnr_std']:.4f} dB\n")
            f.write(f"SSIM (Masked)  : {results['ssim_mean']:.4f} ± {results['ssim_std']:.4f}\n")
            f.write(f"SAM (Masked)   : {results['sam_mean']:.4f} ± {results['sam_std']:.4f} degrees\n")
            f.write(f"Valid Samples  : {results['num_valid_samples']}\n")
            f.write(f"Valid Pixels   : {results['total_valid_pixels']:,}\n")
        print(f"[INFO] Test results saved to {results_file}")
        print(results)
        return results

    # -------------------------------------------------------------------------
    # 8) If mode == "test", skip training loop and just run evaluation
    # -------------------------------------------------------------------------
    if args.mode == "test":
        print("[INFO] Mode = TEST → skipping training, running only test/evaluation.")
        # If resume was provided, the checkpoint was already loaded above.
        # If resume is None, we cannot test—raise an error.
        if args.resume is None:
            raise RuntimeError("[ERROR] --mode test requires --resume <checkpoint_path> to be specified.")
        # Run test_and_evaluate() and then exit
        test_results = test_and_evaluate()
        print("[INFO] Done with test/evaluation. Exiting.")
        return

    # -------------------------------------------------------------------------
    # 9) Otherwise (mode == "train"), run the full train/val/test loop
    # -------------------------------------------------------------------------
    train_losses=[]
    val_losses=[]
    for epoch in range(start_epoch, args.epochs):
        epoch_start_time = time.time()

        train_loss = train_one_epoch(epoch)
        train_losses.append(train_loss)
        val_loss   = validate_one_epoch()
        val_losses.append(val_loss)

        # scheduler.step(val_loss)

        elapsed = time.time() - epoch_start_time
        print(
            f"[Epoch {epoch+1}/{args.epochs}] "
            f"Train Loss: {train_loss:.6f}  Val Loss: {val_loss:.6f}  "
            f"Time: {elapsed:.1f}s"
        )

        # Save best checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            ckpt_path = ckpt_dir / f"best_{args.model}_epoch{epoch+1}_valloss{val_loss:.4f}.pth"
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "best_val_loss": best_val_loss
            }, ckpt_path)
            print(f"[INFO] Validation loss improved → saved checkpoint to {ckpt_path}")

    print(f"[INFO] Training finished, best val loss = {best_val_loss:.6f}")

    # Finally run full test/evaluation once more
    print("\n" + "="*60)
    print("[INFO] Running final test/evaluation …")
    print("="*60)
    _ = test_and_evaluate()
    print("[INFO] Test evaluation completed.\n")
    #%% --- Plot Training Curves ---
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', color='blue')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (MSE - Masked)')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', color='blue')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (Log Scale)')
    plt.title('Training and Validation Loss (Log Scale)')
    plt.yscale('log')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plot_path = plot_dir / "loss_curves.png"
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"Saved loss curves to {plot_path}")


# Copy

In [None]:
import shutil
import os

src = "/content/drive/MyDrive/ColabData/ReconstructionDataset_Final"
dst = "/content/data/ReconstructionDataset_Final"

# Mirror the dataset from the Drive mount onto the local disk, replacing any
# earlier local copy first (shutil.copytree raises if dst already exists).
dst_already_present = os.path.exists(dst)
if dst_already_present:
    shutil.rmtree(dst)
shutil.copytree(src, dst)

local_entries = os.listdir(dst)
print("Copied files:", local_entries)

# Train

In [None]:
class IDEArgs:
    """Plain attribute container standing in for an argparse namespace."""


_model_name = "MultimodalUnetadap"
_drive_out_root = "/content/drive/MyDrive/Colabdata/"

# Populate the namespace in one place; the output folders are derived from the
# model name so checkpoints / logs / plots of different models stay separate.
args = IDEArgs()
vars(args).update(
    mode="train",
    model=_model_name,
    loss="mse_loss",
    data_dir="/content/data/ReconstructionDataset_Final",
    batch_size=64,
    epochs=300,
    lr=1e-4,
    weight_decay=1e-5,
    use_amp=True,  # alternative: False
    gpu=0,
    ckpt_dir=_drive_out_root + "checkpoints/" + _model_name,
    log_dir=_drive_out_root + "logs/" + _model_name,
    plot_dir=_drive_out_root + "plots/" + _model_name,
    resume=None,
)
main(args)

# infer-import

In [None]:
!pip install rasterio
from google.colab import drive
drive.mount('/content/drive')

# Drive directories holding the project's model and loss modules (presumably
# MultimodalUnetnewadap.py lives in the model dir); adjust to your layout.
model_dir_in_drive = '/content/drive/MyDrive/ColabModel/'
loss_dir_in_drive = '/content/drive/MyDrive/ColabLoss/'

import sys
# Add each directory to the import path only once, so re-running this cell
# does not keep appending duplicate sys.path entries.
if model_dir_in_drive not in sys.path:
    sys.path.append(model_dir_in_drive)
    print(f"Added {model_dir_in_drive} to sys.path")
else:
    print(f"{model_dir_in_drive} already in sys.path")
if loss_dir_in_drive not in sys.path:
    sys.path.append(loss_dir_in_drive)

from MultimodalUnetnewadap import Unet
import os
import argparse
from pathlib import Path
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from scipy.ndimage import binary_erosion
from skimage.metrics import structural_similarity as ssim
import importlib

import rasterio
from rasterio.transform import Affine
from rasterio.errors import NotGeoreferencedWarning
import warnings

warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

# infer-config

In [None]:
# -------------------------------------------------------------
# 1. Define dataset class (consistent with OptimizedPatchDataset in training script)
# -------------------------------------------------------------
class OptimizedPatchDataset(Dataset):
    """Map-style dataset over the ``sample_*.pt`` files of one directory.

    Construction loads every file once and keeps only those in which at least
    one of ``mask_l30``, ``mask_s1`` or ``mask_planet`` has a non-zero sum;
    files whose three masks all sum to zero are dropped.

    Args:
        data_dir: directory containing the ``sample_*.pt`` files.
        cache_size: maximum number of loaded samples kept in memory. The first
            ``cache_size`` distinct indices requested are cached and never
            evicted.
        map_location: forwarded to ``torch.load``.

    Raises:
        RuntimeError: if no ``sample_*.pt`` file exists, or if every file is
            filtered out.
    """

    def __init__(self, data_dir, cache_size=50, map_location="cpu"):
        self.data_dir = Path(data_dir)
        candidates = sorted(self.data_dir.glob("sample_*.pt"))
        if not candidates:
            raise RuntimeError(f"[OptimizedPatchDataset] No sample_*.pt files found in {self.data_dir}")

        def has_nonzero_mask(path):
            # The whole file is loaded just to inspect its three masks.
            record = torch.load(path, map_location=map_location)
            masks = (record.get('mask_l30'), record.get('mask_s1'), record.get('mask_planet'))
            return any(m.sum() != 0 for m in masks)

        kept = [path for path in candidates if has_nonzero_mask(path)]
        if not kept:
            raise RuntimeError("[OptimizedPatchDataset] All samples have zero masks!")

        self.map_location = map_location
        self.cache_size = cache_size
        self.cache = {}
        self.files = kept

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        float64 tensors are converted to float32, every tensor is made
        contiguous, and ``"__path__"`` is set to the source file path.
        """
        cached = self.cache.get(idx)
        if cached is not None:
            return cached

        path = self.files[idx]
        sample = torch.load(path, map_location=self.map_location)
        sample.update({
            key: (value.float() if value.dtype == torch.float64 else value).contiguous()
            for key, value in sample.items()
            if isinstance(value, torch.Tensor)
        })
        sample["__path__"] = str(path)

        # Fill the cache until it is full; nothing is ever evicted.
        if len(self.cache) < self.cache_size:
            self.cache[idx] = sample
        return sample

# -------------------------------------------------------------
# 2. Define metric functions (same version as used in training)
# -------------------------------------------------------------

def psnr_masked(img_ref, img_test, valid_mask, data_range=1.0):
    """Calculate masked PSNR (dB) between two (C, H, W) images.

    Only pixels where ``valid_mask > 0`` contribute, and every band of each
    valid pixel enters the mean squared error:
    ``MSE = sum_{valid px, bands} (ref - test)^2 / (N_valid * C)``.

    Args:
        img_ref: reference image, shape (C, H, W).
        img_test: test image, same shape as ``img_ref``.
        valid_mask: (H, W) array; entries > 0 mark valid pixels. Any positive
            value counts as exactly one valid pixel, so boolean, 0/1 and 0/255
            masks all give the same result.
        data_range: peak signal value used in the PSNR formula.

    Returns:
        PSNR in dB; NaN if no pixel is valid; +inf if the masked MSE is 0.

    Raises:
        ValueError: on image or mask shape mismatch.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[psnr_masked] Shape mismatch")
    if valid_mask.shape != img_ref.shape[1:]:
        raise ValueError("[psnr_masked] Mask shape mismatch")
    valid = valid_mask > 0
    # Count selected pixels instead of summing mask values, so the denominator
    # always matches the number of elements actually selected below.
    num_valid = np.count_nonzero(valid)
    if num_valid == 0:
        return float("nan")
    ref_px = img_ref.astype(np.float64)[:, valid]    # (C, num_valid)
    test_px = img_test.astype(np.float64)[:, valid]  # (C, num_valid)
    mse = np.sum((ref_px - test_px) ** 2) / (num_valid * img_ref.shape[0])
    if mse <= 0:
        return float("inf")
    return float(10.0 * np.log10((data_range ** 2) / mse))

def ssim_eroded_mask(img_ref, img_test, valid_mask, max_val=1.0, **ssim_kwargs):
    """Mean SSIM over the interior of a valid-pixel mask, averaged over bands.

    The boolean mask ``valid_mask > 0`` is eroded with a ``win_size`` x
    ``win_size`` square (pixels at the image border count as invalid). The
    per-pixel SSIM map of each band is then averaged over the surviving "core"
    pixels. The same ``win_size`` is forwarded to skimage's
    ``structural_similarity``. With its default uniform window, every core
    pixel's SSIM window therefore lies inside the mask.

    Args:
        img_ref: reference image, shape (C, H, W).
        img_test: test image, same shape as ``img_ref``.
        valid_mask: (H, W) array; entries > 0 mark valid pixels.
        max_val: data range passed to SSIM as ``data_range``.
        **ssim_kwargs: extra keyword arguments for ``structural_similarity``.
            ``win_size`` (default ``min(3, H, W)``) is rounded down to an odd
            number and used both for the erosion and as the SSIM window.

    Returns:
        NaN-ignoring mean of the per-band SSIM values. Returns NaN when the
        window is smaller than 3 or the eroded mask is empty.

    Raises:
        ValueError: if the two images differ in shape.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[ssim_eroded_mask] Shape mismatch")
    C, H, W = img_ref.shape
    valid = valid_mask > 0
    win_size = ssim_kwargs.pop('win_size', min(3, H, W))
    if win_size % 2 == 0:
        win_size -= 1  # SSIM windows must be odd-sized
    if win_size < 3:
        return float("nan")
    struct_el = np.ones((win_size, win_size), dtype=bool)
    core = binary_erosion(valid, structure=struct_el, border_value=0)
    if np.count_nonzero(core) == 0:
        return float("nan")
    ssim_vals = []
    for c in range(C):
        try:
            _, ssim_map = ssim(
                img_ref[c], img_test[c],
                data_range=max_val,
                full=True,
                win_size=win_size,  # must match the erosion window above
                **ssim_kwargs,
            )
            ssim_vals.append(np.mean(ssim_map[core]))
        except Exception:
            # Best effort per band: a band SSIM cannot be computed for
            # contributes NaN, which nanmean below ignores.
            ssim_vals.append(float("nan"))
    return float(np.nanmean(ssim_vals))

def sam_masked(img_ref, img_test, valid_mask):
    """Mean Spectral Angle Mapper (SAM), in degrees, over masked pixels.

    Each pixel is treated as a C-dimensional spectral vector. The angle between
    the reference and test vectors is averaged over pixels with
    ``valid_mask > 0``. Pixels where either vector has zero norm are skipped.

    Args:
        img_ref: reference image, shape (C, H, W).
        img_test: test image, same shape as ``img_ref``.
        valid_mask: (H, W) array; entries > 0 mark pixels to include.

    Returns:
        Mean angle in degrees. Returns NaN when no pixel is selected or every
        selected pixel has a zero-norm vector.

    Raises:
        ValueError: if the images differ in shape (or are not 3-D).
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[sam_masked] Shape mismatch")
    n_bands, height, width = img_ref.shape  # also enforces a 3-D input

    selected = valid_mask > 0
    if not np.any(selected):
        return float("nan")

    ref_vecs = img_ref.astype(np.float64)[:, selected]    # (C, N)
    test_vecs = img_test.astype(np.float64)[:, selected]  # (C, N)
    norm_products = np.linalg.norm(ref_vecs, axis=0) * np.linalg.norm(test_vecs, axis=0)
    usable = norm_products > 0
    if not np.any(usable):
        return float("nan")

    dots = np.sum(ref_vecs * test_vecs, axis=0)
    # Clip guards arccos against rounding just outside [-1, 1].
    cosines = np.clip(dots[usable] / norm_products[usable], -1.0, 1.0)
    return float(np.mean(np.degrees(np.arccos(cosines))))

def save_multiband_tif_rasterio(array_3d, output_path):
    """Save a (C, H, W) array as a C-band GeoTIFF without georeferencing.

    Band ``i + 1`` of the file holds ``array_3d[i]``. No CRS is written and the
    transform is the identity, so coordinates are plain pixel indices. The
    raster dtype is taken from ``array_3d.dtype``.

    Args:
        array_3d: array of shape (C, H, W).
        output_path: destination file path.

    Raises:
        ValueError: if ``array_3d`` is not 3-D or has more than 2**16 bands.
    """
    if array_3d.ndim != 3:
        raise ValueError(f"Expected a (C, H, W) array, got shape {array_3d.shape}")
    C, H, W = array_3d.shape
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if C > 2**16:
        raise ValueError("Number of bands should not be too large!")

    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=H,
        width=W,
        count=C,                       # number of bands
        dtype=array_3d.dtype,
        crs=None,                      # no coordinate system
        transform=Affine.identity(),   # pixel-coordinate identity transform
    ) as dst:
        # Without band indexes, rasterio writes array_3d[i] to band i + 1.
        dst.write(array_3d)

# infer-main


In [None]:
# -------------------------------------------------------------
# 4. Inference main function
# -------------------------------------------------------------
def _infer_load_model(ckpt_path, device):
    """Build the Unet, load weights from ``ckpt_path`` and return it in eval mode.

    Raises:
        FileNotFoundError: if ``ckpt_path`` is not an existing file.
    """
    # These flags must match the architecture the checkpoint was trained with;
    # load_state_dict is strict by default and raises on key mismatches.
    model = Unet(use_meta=False, use_selfattention=True, use_spatial_attention=True).to(device)
    if not Path(ckpt_path).is_file():
        raise FileNotFoundError(f"[ERROR] Checkpoint not found: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)
    # Training checkpoints nest the weights under "model_state_dict";
    # a bare state_dict is accepted as well.
    state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
    model.load_state_dict(state_dict)
    model.eval()
    print(f"[INFO] Loaded checkpoint from {ckpt_path}")
    return model


def _infer_build_test_loader(data_dir, batch_size, num_workers, pin_memory):
    """Return an unshuffled DataLoader over ``<data_dir>/test``.

    Raises:
        RuntimeError: if the test directory does not exist.
    """
    test_dir = Path(data_dir) / "test"
    if not test_dir.exists():
        raise RuntimeError(f"[ERROR] Test directory '{test_dir}' does not exist")
    test_dataset = OptimizedPatchDataset(test_dir, cache_size=100, map_location="cpu")
    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def _infer_masked_mse(pred, gt, mask):
    """Return the float64 MSE over pixels with ``mask > 0``, averaged over bands.

    ``pred``/``gt`` are (C, H, W) and ``mask`` is (H, W). Returns NaN when no
    pixel is valid.
    """
    # Select valid pixels rather than multiplying by the mask. Otherwise
    # NaN/inf values outside the mask would poison the sum (NaN * 0 == NaN),
    # and a non-binary mask would weight pixels unevenly. psnr_masked and
    # sam_masked also select pixels with mask > 0.
    valid = mask > 0
    if not valid.any():
        return float("nan")
    diff = pred.astype(np.float64)[:, valid] - gt.astype(np.float64)[:, valid]
    return float(np.mean(diff ** 2))


def _infer_input_combo_key(has_l30, has_s1, has_planet):
    """Return a label such as "L30+Planet" naming the inputs present, or "None"."""
    flags = (("L30", has_l30), ("S1", has_s1), ("Planet", has_planet))
    present = [name for name, flag in flags if flag]
    return "+".join(present) if present else "None"


def _infer_rgb_preview(img):
    """Return an (H, W, 3) display image from a (C, H, W) array.

    Uses bands 2,1,0 clipped to [0, 1] when C >= 3; otherwise repeats band 0
    three times (unclipped).
    """
    if img.shape[0] >= 3:
        return np.clip(img[[2, 1, 0]].transpose(1, 2, 0), 0, 1)
    return np.repeat(img[0:1].transpose(1, 2, 0), 3, axis=2)


def _infer_save_comparison(gt, pred, psnr_val, save_path):
    """Save a ground-truth / prediction / absolute-difference figure for one sample."""
    gt_rgb = _infer_rgb_preview(gt)
    pred_rgb = _infer_rgb_preview(pred)
    # imshow ignores `cmap` for (H, W, 3) data, so reduce the per-channel
    # difference to a 2-D map for the 'hot' colormap to take effect.
    diff = np.abs(pred_rgb - gt_rgb).mean(axis=2)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        axes[0].imshow(gt_rgb)
        axes[0].set_title("Ground Truth (bands 2,1,0)")
        axes[0].axis("off")
        axes[1].imshow(pred_rgb)
        axes[1].set_title(f"Prediction (PSNR: {psnr_val:.2f} dB)")
        axes[1].axis("off")
        diff_im = axes[2].imshow(diff, cmap='hot')
        axes[2].set_title("Absolute Difference")
        axes[2].axis("off")
        fig.colorbar(diff_im, ax=axes[2], fraction=0.046, pad=0.04)
        fig.tight_layout()
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
    finally:
        # Always release the figure; one figure is created per test sample.
        plt.close(fig)


def _infer_mean_std(records, metric_names):
    """Return ``{"<m>_mean": ..., "<m>_std": ...}`` for each metric, ignoring NaNs."""
    stats = {}
    for name in metric_names:
        values = np.asarray([record[name] for record in records], dtype=np.float64)
        stats[f"{name}_mean"] = float(np.nanmean(values))
        stats[f"{name}_std"] = float(np.nanstd(values))
    return stats


def main(args):
    """Run the trained Unet on ``<args.data_dir>/test``; save predictions and metrics.

    ``args`` must provide: device, output_dir, ckpt_path, data_dir, batch_size,
    num_workers. Under ``args.output_dir`` this writes:
      - ``arrays/<sample>_pred.tif`` and ``arrays/<sample>_gt.tif``
      - ``visualizations/<sample>.png``
      - ``inference_results.json`` (summary + per-sample metrics)
      - ``inference_summary_by_combo.json`` (summary + metrics grouped by which
        of the L30 / S1 / Planet inputs have a non-empty mask)

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if ``<data_dir>/test`` does not exist.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    output_dir = Path(args.output_dir)
    vis_dir = output_dir / "visualizations"
    arrays_dir = output_dir / "arrays"
    for directory in (output_dir, vis_dir, arrays_dir):
        directory.mkdir(parents=True, exist_ok=True)

    model = _infer_load_model(args.ckpt_path, device)
    # Pinned host memory only helps host-to-GPU copies.
    test_loader = _infer_build_test_loader(
        args.data_dir, args.batch_size, args.num_workers,
        pin_memory=(device.type == "cuda"),
    )

    results = []                          # per-sample metrics ("details" in JSON)
    metrics_by_combo = defaultdict(list)  # input-combination label -> metric dicts

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="[Inference]")
        for batch in pbar:
            l30_img = batch["l30_img"].to(device, non_blocking=True)
            l30_meta = batch["l30_meta"].to(device, non_blocking=True)
            s1_img = batch["s1_img"].to(device, non_blocking=True)
            s1_meta = batch["s1_meta"].to(device, non_blocking=True)
            planet_img = batch["planet_img"].to(device, non_blocking=True)
            planet_meta = batch["planet_meta"].to(device, non_blocking=True)
            s30_gt = batch["s30_img_gt"].to(device, non_blocking=True)
            mask_s30 = batch["mask_s30"].to(device, non_blocking=True)

            fake_s30 = model(l30_img, l30_meta, s1_img, s1_meta, planet_img, planet_meta)
            # mask_s30 is (B, 1, H, W); broadcast it over the predicted bands so
            # the prediction is zero outside the valid S30 footprint.
            fake_s30 = fake_s30 * mask_s30.expand(-1, fake_s30.size(1), -1, -1)

            fake_np = fake_s30.cpu().numpy()             # (B, C, H, W)
            gt_np = s30_gt.cpu().numpy()                 # (B, C, H, W)
            mask_np = mask_s30.cpu().numpy().squeeze(1)  # (B, H, W)
            # The input masks are only inspected on the host, so read them
            # straight from the CPU batch instead of round-tripping the device.
            mask_l30_np = batch["mask_l30"].numpy().squeeze(1)
            mask_s1_np = batch["mask_s1"].numpy().squeeze(1)
            mask_planet_np = batch["mask_planet"].numpy().squeeze(1)

            for i in range(fake_np.shape[0]):
                pred, gt, mask = fake_np[i], gt_np[i], mask_np[i]
                mse_val = _infer_masked_mse(pred, gt, mask)
                metrics = {
                    "mse": mse_val,
                    "rmse": float(np.sqrt(mse_val)),  # NaN propagates
                    "psnr": psnr_masked(gt, pred, mask, data_range=1.0),
                    "ssim": ssim_eroded_mask(gt, pred, mask, max_val=1.0),
                    "sam": sam_masked(gt, pred, mask),
                }
                combo_key = _infer_input_combo_key(
                    mask_l30_np[i].any(), mask_s1_np[i].any(), mask_planet_np[i].any()
                )
                metrics_by_combo[combo_key].append(metrics)

                # Name outputs after the source .pt file (without extension).
                sample_name = Path(batch["__path__"][i]).stem
                save_multiband_tif_rasterio(pred, arrays_dir / f"{sample_name}_pred.tif")
                save_multiband_tif_rasterio(gt, arrays_dir / f"{sample_name}_gt.tif")
                results.append({"sample": sample_name, **{k: float(v) for k, v in metrics.items()}})

                _infer_save_comparison(gt, pred, metrics["psnr"], vis_dir / f"{sample_name}.png")

            pbar.set_postfix({
                "MSE": f"{np.nanmean([r['mse'] for r in results]):.6f}",
                "PSNR": f"{np.nanmean([r['psnr'] for r in results]):.2f}",
                "SSIM": f"{np.nanmean([r['ssim'] for r in results]):.4f}",
            })

    # Aggregate over all samples.
    summary = {
        "num_samples": len(results),
        **_infer_mean_std(results, ("mse", "rmse", "psnr", "ssim", "sam")),
    }
    results_file = output_dir / "inference_results.json"
    with open(results_file, "w") as f:
        json.dump({"summary": summary, "details": results}, f, indent=2)

    # Same statistics, grouped by which inputs were available per sample.
    combo_summaries = {
        combo_key: {
            "count": len(records),
            **_infer_mean_std(records, ("psnr", "ssim", "sam", "mse", "rmse")),
        }
        for combo_key, records in metrics_by_combo.items()
    }
    with open(output_dir / "inference_summary_by_combo.json", "w") as f:
        json.dump({"summary": summary, "by_input_combo": combo_summaries}, f, indent=2)

    print("\n=== Performance by Input Combination ===")
    for combo_key, stats in combo_summaries.items():
        print(f"{combo_key}: {stats['count']} samples — PSNR {stats['psnr_mean']:.2f}±{stats['psnr_std']:.2f}, SSIM {stats['ssim_mean']:.3f}±{stats['ssim_std']:.3f}, SAM {stats['sam_mean']:.2f}±{stats['sam_std']:.2f}")
    print(f"\n[INFO] Inference completed. Results saved to {results_file}")
    print(f"       Full 12-band TIFFs saved to {arrays_dir}")
    print(f"       Visualizations saved to {vis_dir}")

    print("\nInference Summary:")
    for k, v in summary.items():
        if isinstance(v, float):
            print(f"  {k}: {v:.6f}")
        else:
            print(f"  {k}: {v}")

# -------------------------------------------------------------
# 5. Command line entry & IDE debugging
# -------------------------------------------------------------
def parse_args(argv=None):
    """Parse command-line options for the inference script.

    Args:
        argv: list of argument strings; ``None`` (default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with model, ckpt_path, data_dir, output_dir,
        batch_size, num_workers and device.
    """
    parser = argparse.ArgumentParser(
        description="Inference script for the trained multimodal Unet (saves multi-band GeoTIFFs)"
    )
    parser.add_argument("--model",      type=str, required=True,
                        help="Model/run name (currently not used by main(); the Unet architecture is fixed)")
    parser.add_argument("--ckpt-path",  type=str, required=True,
                        help="Checkpoint path: a dict containing 'model_state_dict' (as saved by training) or a bare state_dict")
    parser.add_argument("--data-dir",   type=str, required=True,
                        help="Data root directory, should contain 'test' subdirectory (consistent with training script)")
    parser.add_argument("--output-dir", type=str, default="inference_results",
                        help="Inference results output directory (contains TIFF and visualizations)")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--device",     type=str, default="cuda:0",
                        help="Device to run inference (e.g., 'cuda:0' or 'cpu')")
    return parser.parse_args(argv)

# infer-run

In [None]:
import glob
import os

print("[INFO] Running in IDE debug mode with default settings")


class IDEArgs:
    """Notebook stand-in for the argparse namespace consumed by main()."""
    model = "MultimodalUnetadap_all"
    # Glob wildcards are allowed here and are resolved below.
    # NOTE(review): the training loop saves best checkpoints as
    # best_{model}_epoch{E}_valloss{L}.pth inside checkpoints/{model};
    # confirm this filename prefix matches the run you want to evaluate.
    ckpt_path = "/content/drive/MyDrive/Colabdata/checkpoints/MultimodalUnetadap_all/best_MultimodalUnetadap_epoch***.pth"
    data_dir = "/content/data/ReconstructionDataset_Final"
    output_dir = "/content/drive/MyDrive/Colabdata/inference_results/MultimodalUnetadap_all"
    batch_size = 64
    num_workers = 4
    device = "cuda:0"


args = IDEArgs()

# Resolve a wildcard checkpoint pattern to the most recently written match.
# Best checkpoints are only written when validation loss improves, so within
# one run the newest file is the best one. Concrete paths pass through
# unchanged, and an unmatched pattern is left as-is so main() reports it.
if any(ch in args.ckpt_path for ch in "*?["):
    matches = glob.glob(args.ckpt_path)
    if matches:
        args.ckpt_path = max(matches, key=os.path.getmtime)
        print(f"[INFO] Resolved checkpoint pattern to {args.ckpt_path}")
    else:
        print(f"[WARN] No checkpoint matches pattern {args.ckpt_path}")

main(args)