# Import

In [None]:
!pip install kornia
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/gdrive')

import os
import sys
import argparse
import time
from pathlib import Path
import matplotlib.pyplot as plt
import torch.nn.functional as F
import json
import kornia
from kornia.morphology import erosion

# Drive directory that contains the MultimodalUnetGAN module imported below;
# adjust this path to your own Drive layout.
model_dir_in_drive = '/content/gdrive/MyDrive/ColabModel/'
# Make that directory importable; the membership check keeps re-running the
# cell from appending the same entry twice.
if model_dir_in_drive in sys.path:
    print(f"{model_dir_in_drive} already in sys.path")
else:
    sys.path.append(model_dir_in_drive)
    print(f"Added {model_dir_in_drive} to sys.path")

from MultimodalUnetGAN import UnetGAN


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
from tqdm import tqdm
from scipy.ndimage import binary_erosion
from skimage.metrics import structural_similarity as ssim

import importlib

# GPU optimization: switch cuDNN benchmark mode on, but only when a CUDA
# device is actually available (the flag is left untouched on CPU-only hosts).
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True


# Config

In [None]:
# -*- coding: utf-8 -*-

class MSELoss(nn.Module):
    """Masked mean-squared-error loss, scaled by ``alpha``.

    The squared error is averaged only over the elements selected by the
    mask, so invalid pixels do not contribute to the loss value.
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha
        # Keep the per-element loss map; the (masked) reduction is done in forward().
        self.mse = nn.MSELoss(reduction="none")

    def forward(self, outputs, targets, mask_s30):
        """
        outputs: Tensor[B, C, H, W]
        targets: Tensor[B, C, H, W]
        mask_s30:Tensor[B, 1, H, W], mask=1 indicates valid pixels

        Returns a scalar tensor: alpha * (masked sum of squared errors /
        number of valid elements), or 0 when the mask selects nothing.
        The result always stays attached to the autograd graph of ``outputs``,
        so ``backward()`` works (with zero gradients) even for an empty mask.
        """
        # Calculate per-pixel MSE
        loss_map = self.mse(outputs, targets)  # [B, C, H, W]
        # Expand mask to channel dimension
        mask_expand = mask_s30.expand_as(loss_map)  # [B, C, H, W]
        valid_pixels = torch.sum(mask_expand)  # Number of valid elements
        masked_sum = torch.sum(loss_map * mask_expand)

        # Tensor-level branching instead of a Python `if`: no host sync, and the
        # empty-mask case yields a differentiable zero instead of a constant.
        has_valid = valid_pixels > 0
        safe_count = torch.where(has_valid, valid_pixels, torch.ones_like(valid_pixels))
        main_loss = torch.where(
            has_valid,
            masked_sum / safe_count,
            torch.zeros_like(masked_sum),
        )
        return self.alpha * main_loss

# ======= Modification: add warmup / transition parameters in parse_args =======
def parse_args(argv=None):
    """Build the CLI parser for GAN training and parse ``argv``.

    ``argv`` is the list of argument strings; ``None`` makes argparse read
    ``sys.argv``. Returns the populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(
        description="GAN training script for multimodal image reconstruction with Warm-up/Transition strategy"
    )
    add = cli.add_argument

    # --- general run configuration ---
    add("--mode", type=str, choices=["train", "test"], default="train")
    add("--model", type=str, required=True, help="Model name (must have generator and discriminator)")
    add("--loss", type=str, required=True, help="Loss name for reconstruction")
    add("--data_dir", type=str, default="data", help="Data directory")
    add("--batch_size", type=int, default=32)
    add("--epochs", type=int, default=100)
    add("--gpu", type=int, default=0)
    add("--ckpt_dir", type=str, default="checkpoints")
    add("--log_dir", type=str, default="logs")
    add("--plot_dir", type=str, default="plots")
    add("--resume", type=str, default=None)

    # --- adversarial training configuration ---
    add("--gan_mode", type=str, choices=["vanilla", "wgan-gp"], default="wgan-gp")
    add("--d_steps", type=int, default=5, help="Discriminator steps per generator step")
    add("--g_steps", type=int, default=1, help="Generator steps per discriminator step")
    add("--gp_weight", type=float, default=10.0, help="Gradient penalty weight")
    add("--adv_weight", type=float, default=0.01, help="Final adversarial loss weight for generator")
    add("--lr_g", type=float, default=1e-4, help="Generator learning rate")
    add("--lr_d", type=float, default=1e-4, help="Discriminator learning rate")
    add("--weight_decay", type=float, default=1e-5)
    add("--use_amp", action="store_true", help="Use automatic mixed precision")

    # --- warm-up / transition schedule for the adversarial weight ---
    add(
        "--warmup_epochs",
        type=int,
        default=5,
        help="Number of initial epochs using only reconstruction loss (adv_weight=0, don't update D)",
    )
    add(
        "--transition_epochs",
        type=int,
        default=10,
        help="Number of epochs to linearly increase adv_weight from 0→adv_weight",
    )

    namespace = cli.parse_args(argv)
    return namespace

class OptimizedPatchDataset(Dataset):
    """Dataset over ``sample_*.pt`` files with a small in-memory cache.

    Samples whose L30, S1 and Planet masks are all zero are dropped at
    construction time. Note that this filtering loads every file once in
    ``__init__``.
    """

    # Mask entries inspected by the construction-time filter.
    _MASK_KEYS = ("mask_l30", "mask_s1", "mask_planet")

    def __init__(self, data_dir, cache_size=50, map_location="cpu"):
        """
        data_dir:     directory containing the ``sample_*.pt`` files
        cache_size:   maximum number of samples kept in memory
        map_location: forwarded to ``torch.load``

        Raises RuntimeError when no file is found or every sample is filtered out.
        """
        self.data_dir = Path(data_dir)
        all_files = sorted(self.data_dir.glob("sample_*.pt"))
        if not all_files:
            raise RuntimeError(f"[OptimizedPatchDataset] No sample_*.pt files found in {self.data_dir}")

        # Filter: only keep samples with at least one non-zero mask
        valid_files = [
            p for p in all_files
            if self._has_valid_mask(torch.load(p, map_location=map_location))
        ]
        if not valid_files:
            raise RuntimeError("[OptimizedPatchDataset] All samples have zero masks!")

        self.files = valid_files
        self.cache = {}
        self.cache_size = cache_size
        self.map_location = map_location

    @classmethod
    def _has_valid_mask(cls, sample):
        """Return True if any of the inspected masks has a non-zero sum.

        A mask key that is absent from the sample is treated as "no valid
        pixels" instead of crashing on ``None.sum()``.
        """
        for key in cls._MASK_KEYS:
            mask = sample.get(key)
            if mask is not None and mask.sum() != 0:
                return True
        return False

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load (or fetch from cache) the sample at ``idx`` as a dict."""
        if idx in self.cache:
            return self.cache[idx]
        sample = torch.load(self.files[idx], map_location=self.map_location)
        # Type adjustment & contiguous: float64 tensors are cast to float32.
        # Only existing keys are reassigned, so iterating the dict is safe.
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                if value.dtype == torch.float64:
                    value = value.float()
                sample[key] = value.contiguous()
        # Cache: fill until cache_size entries are stored; nothing is evicted.
        if len(self.cache) < self.cache_size:
            self.cache[idx] = sample
        return sample


# Metric functions (remain unchanged)
def psnr_masked(img_ref, img_test, valid_mask, data_range=1.0):
    """Calculate PSNR on masked regions.

    img_ref, img_test: arrays of identical shape (C, H, W)
    valid_mask:        array of shape (H, W); pixels with mask > 0 are used
    data_range:        value range of the images (peak signal value)

    Returns the PSNR in dB as a float, ``np.nan`` when no pixel is selected,
    and ``inf`` when the masked MSE is zero.
    Raises ValueError on shape mismatch.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[psnr_masked] Shape mismatch")
    if valid_mask.shape != img_ref.shape[1:]:
        raise ValueError("[psnr_masked] Mask shape mismatch")

    # One boolean selection drives both the pixel count and the error sum, so
    # masks that are not strictly 0/1 (e.g. 0/255) still give a correct mean.
    selected = valid_mask > 0
    if not np.any(selected):
        return np.nan

    I = img_ref.astype(np.float64)
    K = img_test.astype(np.float64)
    squared_error = (I - K) ** 2
    # Boolean indexing on the spatial axes keeps every channel of each valid pixel.
    mse_masked = float(np.mean(squared_error[:, selected]))
    if mse_masked <= 0:
        return float('inf')
    psnr_val = 10.0 * np.log10((data_range ** 2) / mse_masked)
    return float(psnr_val)

def ssim_eroded_mask(img_ref, img_test, valid_mask, max_val=1.0, **ssim_kwargs):
    """Calculate SSIM with eroded mask.

    img_ref, img_test: arrays of identical shape (C, H, W)
    valid_mask:        (H, W) array; pixels with mask > 0 are valid
    max_val:           passed to ``ssim`` as ``data_range``
    ssim_kwargs:       extra keyword arguments for ``ssim``; ``win_size``
                       (default ``min(3, H, W)``, forced odd) is used both as
                       the SSIM window and as the erosion structuring element

    The mask is eroded by the window size and the per-channel SSIM map is
    averaged over the remaining pixels only; the channel means are then
    averaged, ignoring NaN. Returns ``np.nan`` when the window is smaller
    than 3, when the eroded mask is empty, or when every channel failed.
    Raises ValueError on shape mismatch.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[ssim_eroded_mask] Shape mismatch")

    C, H, W = img_ref.shape
    orig_mask = valid_mask > 0
    win_size = ssim_kwargs.pop('win_size', min(3, H, W))
    if win_size % 2 == 0:
        win_size -= 1
    if win_size < 3:
        return np.nan

    struct_el = np.ones((win_size, win_size), dtype=bool)
    core_mask = binary_erosion(orig_mask, structure=struct_el, border_value=0)
    if np.count_nonzero(core_mask) == 0:
        return np.nan

    ssim_vals = []
    for c in range(C):
        ref_chan = img_ref[c, :, :]
        test_chan = img_test[c, :, :]
        try:
            # Forward win_size so the SSIM window matches the erosion element;
            # previously it was popped from the kwargs and silently dropped.
            _, ssim_map = ssim(
                ref_chan, test_chan,
                data_range=max_val, full=True, win_size=win_size, **ssim_kwargs
            )
            ssim_vals.append(np.mean(ssim_map[core_mask]))
        except Exception as e:
            # Deliberate best-effort: a failing channel is reported and skipped.
            print(f"[ssim_eroded_mask] Warning: Channel {c} SSIM failed: {e}")
            ssim_vals.append(np.nan)

    # Avoid the "mean of empty slice" RuntimeWarning when every channel failed.
    if np.all(np.isnan(ssim_vals)):
        return np.nan
    return float(np.nanmean(ssim_vals))

def sam_masked(img_ref, img_test, valid_mask):
    """Mean Spectral Angle Mapper (degrees) over the masked pixels.

    Both images must share one 3-D shape; the first axis holds the spectral
    vector of each pixel. Pixels with ``valid_mask > 0`` are evaluated, and
    pixels where either spectral vector has zero norm are left out. Returns
    ``np.nan`` if no pixel remains. Raises ValueError on shape mismatch.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[sam_masked] Shape mismatch")

    # Three-way unpacking rejects inputs that are not 3-D.
    _bands, _rows, _cols = img_ref.shape

    selected = valid_mask > 0
    if not np.any(selected):
        return np.nan

    # Columns are the spectral vectors of the selected pixels.
    ref_spectra = img_ref.astype(np.float64)[:, selected]
    test_spectra = img_test.astype(np.float64)[:, selected]

    inner = np.sum(ref_spectra * test_spectra, axis=0)
    norm_product = np.linalg.norm(ref_spectra, axis=0) * np.linalg.norm(test_spectra, axis=0)

    usable = norm_product > 0
    if not np.any(usable):
        return np.nan

    # Clip to guard arccos against rounding slightly outside [-1, 1].
    cosines = np.clip(inner[usable] / norm_product[usable], -1.0, 1.0)
    return float(np.mean(np.degrees(np.arccos(cosines))))

def psnr_masked_gpu_batch(
    img_ref: torch.Tensor,   # shape (B, C, H, W)
    img_test: torch.Tensor,  # shape (B, C, H, W)
    valid_mask: torch.Tensor,# shape (B, 1, H, W) or (B, H, W)
    data_range: float = 1.0
) -> torch.Tensor:
    """
    Per-sample masked PSNR computed with torch ops. Returns a tensor of shape (B,).

    The MSE is floored at 1e-10 before the logarithm, and samples whose mask
    sums to zero yield NaN.
    """
    # Bring a (B, H, W) mask to (B, 1, H, W) so it broadcasts over channels.
    mask = valid_mask.unsqueeze(1) if valid_mask.dim() == 3 else valid_mask

    channels = img_ref.shape[1]
    # Mask sum per sample times the channel count = number of contributing elements.
    element_count = mask.sum(dim=(1, 2, 3)) * channels  # (B,)

    masked_sq_err = (img_ref - img_test).pow(2) * mask  # (B, C, H, W)
    # clamp(min=1) only prevents 0/0; such samples are replaced by NaN below.
    mse = masked_sq_err.sum(dim=(1, 2, 3)) / element_count.clamp(min=1)  # (B,)

    peak_sq = data_range ** 2
    psnr = 10.0 * torch.log10(peak_sq / mse.clamp(min=1e-10))

    nan_value = torch.tensor(float('nan'), device=img_ref.device)
    return torch.where(element_count > 0, psnr, nan_value)  # (B,)

def ssim_eroded_mask_gpu_batch(
    img_ref: torch.Tensor,
    img_test: torch.Tensor,
    valid_mask: torch.Tensor,
    window_size: int = 3,
    data_range: float = 1.0
) -> torch.Tensor:
    """
    Per-sample SSIM averaged over an eroded validity mask. Returns shape (B,).

    A Gaussian-weighted SSIM map is computed per channel, averaged over the
    channels, and then averaged over the pixels that survive eroding the mask
    with a ``window_size`` x ``window_size`` kernel. Samples whose eroded mask
    is empty yield NaN.
    """
    # Four-way unpacking also rejects inputs that are not (B, C, H, W).
    _, channels, _, _ = img_ref.shape

    mask = valid_mask.unsqueeze(1) if valid_mask.dim() == 3 else valid_mask
    mask = mask.float()

    # SSIM stabilisation constants
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    # Separable Gaussian window. sigma = 1.5 is what the original author
    # describes as kornia's default (not verified here).
    sigma = 1.5
    offsets = torch.arange(window_size, device=img_ref.device, dtype=img_ref.dtype)
    offsets = offsets - (window_size - 1) / 2
    gauss_1d = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    gauss_1d = gauss_1d / gauss_1d.sum()
    gauss_2d = gauss_1d.unsqueeze(0) * gauss_1d.unsqueeze(1)
    gauss_2d = gauss_2d.unsqueeze(0).unsqueeze(0)
    gauss_2d = gauss_2d / gauss_2d.sum()
    # One copy of the window per channel for a depthwise (groups=C) convolution.
    gauss_2d = gauss_2d.expand(channels, 1, window_size, window_size).contiguous()

    half = window_size // 2

    def local_mean(x):
        """Gaussian-weighted local average, computed independently per channel."""
        return F.conv2d(x, gauss_2d, padding=half, groups=channels)

    mean_ref = local_mean(img_ref)
    mean_test = local_mean(img_test)

    mean_ref_sq = mean_ref.pow(2)
    mean_test_sq = mean_test.pow(2)
    mean_cross = mean_ref * mean_test

    var_ref = local_mean(img_ref.pow(2)) - mean_ref_sq
    var_test = local_mean(img_test.pow(2)) - mean_test_sq
    covariance = local_mean(img_ref * img_test) - mean_cross

    ssim_map = ((2 * mean_cross + c1) * (2 * covariance + c2)) / (
        (mean_ref_sq + mean_test_sq + c1) * (var_ref + var_test + c2)
    )
    # Collapse the channel axis so the map lines up with the 1-channel mask.
    ssim_map = ssim_map.mean(dim=1, keepdim=True)

    # Shrink the mask with a square kernel of the window size, then re-binarise.
    structuring = torch.ones((window_size, window_size), device=mask.device)
    eroded = erosion(mask, structuring, border_type='constant', border_value=0.0)
    eroded = (eroded > 0.5).float()

    kept = eroded.sum(dim=(1, 2, 3))
    total = (ssim_map * eroded).sum(dim=(1, 2, 3))

    return torch.where(
        kept > 0,
        total / kept,
        torch.tensor(float('nan'), device=img_ref.device)
    )

def sam_masked_gpu_batch(
    img_ref: torch.Tensor,    # (B, C, H, W)
    img_pred: torch.Tensor,   # (B, C, H, W)
    valid_mask: torch.Tensor, # (B, 1, H, W) or (B, H, W)
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Batch Masked Spectral Angle Mapper on GPU. Returns a tensor of shape (B,) in degrees.

    The angle between the reference and predicted spectral vectors is computed
    for every pixel and averaged over the pixels with ``valid_mask > 0``.
    Samples without any valid pixel yield NaN. ``eps`` is the lower bound on the
    product of the two vector norms (avoids division by zero). A 2-D (H, W)
    mask is also accepted and shared by all samples. The result has the dtype
    of the computed angles.
    """
    B, C, H, W = img_ref.shape

    # Ensure mask shape is (B, H, W)
    if valid_mask.dim() == 4:
        valid_mask = valid_mask.squeeze(1)
    elif valid_mask.dim() == 2:
        valid_mask = valid_mask.unsqueeze(0).expand(B, -1, -1)

    # reshape (not view) so non-contiguous inputs, e.g. transposed or expanded
    # tensors, are accepted as well.
    img_ref_flat = img_ref.reshape(B, C, -1)   # (B, C, H*W)
    img_pred_flat = img_pred.reshape(B, C, -1)
    mask_flat = valid_mask.reshape(B, -1) > 0  # (B, H*W)

    # Compute spectral angles for all pixels
    dot_product = (img_ref_flat * img_pred_flat).sum(dim=1)  # (B, H*W)
    norm_ref = img_ref_flat.norm(dim=1)  # (B, H*W)
    norm_pred = img_pred_flat.norm(dim=1)

    # Avoid division by zero
    denominator = (norm_ref * norm_pred).clamp(min=eps)
    cos_theta = (dot_product / denominator).clamp(-1.0, 1.0)

    # Convert to angles in degrees
    # NOTE(review): d/dx acos(x) is unbounded at x = +/-1, so when this value is
    # used as a training loss, pixels with exactly parallel vectors could give
    # non-finite gradients - confirm whether a tighter clamp is wanted.
    angles = torch.acos(cos_theta) * (180.0 / torch.pi)  # (B, H*W)

    # Vectorised masked mean: torch.where (rather than multiplying by the mask)
    # keeps NaN/inf at masked-out pixels from leaking into the sum, and an
    # empty mask gives 0/0 = NaN for that sample.
    masked_angles = torch.where(mask_flat, angles, torch.zeros_like(angles))
    valid_count = mask_flat.sum(dim=1)
    sam_mean = masked_angles.sum(dim=1) / valid_count

    return sam_mean  # (B,)


# Model and loss loading functions (remain unchanged)
def get_model_class(model_name: str):
    """Dynamically load a model class.

    Imports the module ``Models.<model_name>`` and returns its module-level
    ``MODEL_CLASS`` attribute.

    Raises ImportError if the module cannot be imported (the original error
    is chained as ``__cause__``) and AttributeError if the module does not
    define ``MODEL_CLASS``.
    """
    module_path = f"Models.{model_name}"
    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        # Chain explicitly so the traceback shows the root cause as the direct cause.
        raise ImportError(f"[get_model_class] Cannot import {module_path}: {e}") from e

    if not hasattr(module, "MODEL_CLASS"):
        raise AttributeError(f"[get_model_class] {module_path}.py must define MODEL_CLASS")
    return module.MODEL_CLASS

def get_loss_class(loss_name: str):
    """Dynamically load a loss class.

    Imports the module ``Losses.<loss_name>`` and returns its module-level
    ``LOSS_CLASS`` attribute.

    Raises ImportError if the module cannot be imported (the original error
    is chained as ``__cause__``) and AttributeError if the module does not
    define ``LOSS_CLASS``.
    """
    module_path = f"Losses.{loss_name}"
    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        # Chain explicitly so the traceback shows the root cause as the direct cause.
        raise ImportError(f"[get_loss_class] Cannot import {module_path}: {e}") from e

    if not hasattr(module, "LOSS_CLASS"):
        raise AttributeError(f"[get_loss_class] {module_path}.py must define LOSS_CLASS")
    return module.LOSS_CLASS
def init_weights(m):
    """Initialise a single module in place.

    Conv2d and Linear modules get Kaiming-normal weights (ReLU gain) and, if
    they have a bias, a zeroed bias. Any other module is left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)
def GenerateMask(x,dim=1):
    """Count the non-zero entries of ``x`` along ``dim``.

    The reduced result gets a new singleton axis inserted at position 1, so
    with the default ``dim=1`` the reduced axis is kept with size 1. The
    values are counts (floats), not a 0/1 mask.
    """
    nonzero_count = (x != 0).float().sum(dim=dim)
    return nonzero_count.unsqueeze(1)
def compute_gradient_penalty(discriminator, real_samples, fake_samples, mask_s30,device):
    """
    Modified version: compute gradient penalty only for S30 images

    discriminator: callable invoked as discriminator(l30, s1, planet, s30, mask_s30);
                   may return one tensor or a list/tuple of tensors (multi-scale)
    real_samples:  sequence (l30, s1, planet, s30_real)
    fake_samples:  sequence of four items; only the last one (s30_fake) is used
    mask_s30:      passed unchanged to the discriminator
    device:        device on which the random interpolation factors are created

    Returns a scalar tensor: the mean over the batch of (||grad||_2 - 1)^2, where
    grad is the gradient of the discriminator output w.r.t. the interpolated S30
    image, averaged over all discriminator scales. The graph is kept
    (create_graph=True) so the penalty itself can be back-propagated.
    """
    batch_size = real_samples[0].shape[0]
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)

    # Interpolate only for S30 (as it is the generation target); the conditioning
    # inputs are taken from the real tuple.
    l30, s1, planet, s30_real = real_samples
    _, _, _, s30_fake = fake_samples

    s30_interp = alpha * s30_real + (1 - alpha) * s30_fake
    s30_interp.requires_grad_(True)

    disc_interp = discriminator(l30, s1, planet, s30_interp, mask_s30)

    # Treat a single-scale output as a one-element list so that both cases share
    # one code path; averaging over one scale returns that scale's penalty.
    if isinstance(disc_interp, (list, tuple)):
        scale_outputs = list(disc_interp)
    else:
        scale_outputs = [disc_interp]

    penalties = []
    for scale_output in scale_outputs:
        gradients = torch.autograd.grad(
            outputs=scale_output,
            inputs=s30_interp,
            grad_outputs=torch.ones_like(scale_output),
            create_graph=True,
            # The same forward graph is differentiated once per scale.
            retain_graph=True,
            only_inputs=True
        )[0]

        # reshape instead of view: the returned gradient may be non-contiguous.
        grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
        penalties.append(((grad_norm - 1) ** 2).mean())

    return sum(penalties) / len(penalties)
def save_epoch_metrics(history, epoch, log_dir,log_name):
    """Append the latest metrics of the current epoch to a CSV file.

    history:  dict mapping metric name -> list of per-epoch values; the last
              entry of each tracked metric is written
    epoch:    zero-based epoch index (stored one-based in the file)
    log_dir:  target directory (str or Path); created if it does not exist
    log_name: suffix of the file name ``training_metrics_<log_name>.csv``

    Raises KeyError / IndexError if a tracked metric is missing or empty.
    """
    metric_keys = (
        "train_g_loss",
        "train_d_loss",
        "train_rec_loss",
        "train_adv_loss",
        "val_rec_loss",
        "val_d_real",
        "val_d_fake",
        "lr_g",
        "lr_d",
    )
    # Gather the values first so a missing metric fails before the file is touched.
    current_metrics = {"epoch": epoch + 1}
    for key in metric_keys:
        current_metrics[key] = history[key][-1]

    # Accept plain strings as well as Path objects.
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    # Continuously update a single CSV file (easier to import into analysis tools)
    csv_file = log_dir / f"training_metrics_{log_name}.csv"

    # The header is needed for a new file and for an existing but empty one.
    needs_header = not csv_file.exists() or csv_file.stat().st_size == 0

    with open(csv_file, 'a') as f:
        if needs_header:
            header = ",".join(current_metrics.keys())
            f.write(f"{header}\n")
        values = ",".join(str(v) for v in current_metrics.values())
        f.write(f"{values}\n")

# Main

In [None]:
def main(args):
    # Device selection
    if args.gpu >= 0 and torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
    else:
        device = torch.device("cpu")
    print(f"[INFO] Using device: {device}")

    # Create directories
    ckpt_dir = Path(args.ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    log_dir = Path(args.log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    plot_dir = Path(args.plot_dir)
    plot_dir.mkdir(parents=True, exist_ok=True)

    model = UnetGAN(use_meta=False, use_selfattention=True,
                     use_spatial_attention=True).to(device)#True,False

    if not (hasattr(model, 'generator') and hasattr(model, 'discriminator')):
        raise AttributeError(f"Model '{args.model}' must have 'generator' and 'discriminator' attributes")
    generator = model.generator
    discriminator = model.discriminator

    # Count parameters
    g_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    d_params = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"[INFO] Generator parameters: {g_params:,}")
    print(f"[INFO] Discriminator parameters: {d_params:,}")

    # Create optimizers
    optimizer_g = optim.AdamW(generator.parameters(), lr=args.lr_g, weight_decay=args.weight_decay)
    optimizer_d = optim.AdamW(discriminator.parameters(), lr=args.lr_d, weight_decay=args.weight_decay)

    # Learning rate schedulers
    # scheduler_g = optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", factor=0.5, patience=5)
    # scheduler_d = optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode="min", factor=0.5, patience=5)

    # Loss functions
    criterion_rec = MSELoss()  # Reconstruction loss
    if args.gan_mode == "vanilla":
        criterion_adv = nn.BCEWithLogitsLoss()

    # AMP setup
    use_amp = args.use_amp and device.type.startswith("cuda")
    scaler_g = torch.amp.GradScaler() if use_amp else None
    scaler_d = torch.amp.GradScaler() if use_amp else None
    # Prepare datasets
    train_dir = Path(args.data_dir) / "train"
    val_dir = Path(args.data_dir) / "val"
    test_dir = Path(args.data_dir) / "test"

    for d in (train_dir, val_dir, test_dir):
        if not d.exists():
            raise RuntimeError(f"[ERROR] Directory '{d}' does not exist")

    train_dataset = OptimizedPatchDataset(train_dir, cache_size=100)
    val_dataset = OptimizedPatchDataset(val_dir, cache_size=100)
    test_dataset = OptimizedPatchDataset(test_dir, cache_size=100)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)

    print(f"[INFO] Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    # Learning rate schedulers
    per_epoch_iteration = len(train_dataset) // args.batch_size
    total_iteration_g = per_epoch_iteration*(args.warmup_epochs+(args.epochs-args.warmup_epochs)*args.g_steps)
    total_iteration_d = per_epoch_iteration*(args.epochs-args.warmup_epochs)*args.d_steps

    scheduler_g = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_g, total_iteration_g, eta_min=1e-6)
    scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_d, total_iteration_d, eta_min=1e-6)

    # Load checkpoint if specified
    start_epoch = 0
    best_val_loss = float("inf")

    # Training history for plotting
    history = {
        'train_g_loss': [],
        'train_rec_loss': [],      # New addition
        'train_adv_loss': [],
        'train_d_loss': [],
        'train_d_real': [],
        'train_d_fake': [],
        'train_gp': [],
        'val_rec_loss': [],
        'val_d_real': [],
        'val_d_fake': [],
        'lr_g': [],
        'lr_d': []
    }

    if args.resume:
        if os.path.isfile(args.resume):
            ckpt = torch.load(args.resume, map_location=device, weights_only=False)
            generator.load_state_dict(ckpt["generator_state_dict"])
            discriminator.load_state_dict(ckpt["discriminator_state_dict"])
            optimizer_g.load_state_dict(ckpt["optimizer_g_state_dict"])
            optimizer_d.load_state_dict(ckpt["optimizer_d_state_dict"])
            scheduler_g.load_state_dict(ckpt["scheduler_g_state_dict"])
            scheduler_d.load_state_dict(ckpt["scheduler_d_state_dict"])
            start_epoch = ckpt.get("epoch", 0)
            best_val_loss = ckpt.get("best_val_loss", best_val_loss)
            history = ckpt.get("history", history)
            print(f"[INFO] Resumed from epoch {start_epoch}")

    # ======= Modification: add Warm-up/Transition logic in train_one_epoch =======
    def train_one_epoch(epoch_idx):
        """GAN training for one epoch, with Warm-up→Transition→Full-Adversarial"""
        generator.train()
        discriminator.train()

        # Calculate current epoch adversarial loss weight adv_w
        # 1) warmup phase: adv_weight_current = 0
        # 2) transition phase: adv_weight_current linearly from 0 to args.adv_weight
        # 3) full phase: adv_weight_current = args.adv_weight
        if epoch_idx < args.warmup_epochs:
            adv_weight_current = 0.0
            do_discriminator = False      # Don't update D during warmup
        elif epoch_idx < args.warmup_epochs + args.transition_epochs:
            # Transition phase, linear interpolation
            alpha = (epoch_idx - args.warmup_epochs + 1) / args.transition_epochs
            adv_weight_current = args.adv_weight * min(alpha, 1.0)
            do_discriminator = True
        else:
            adv_weight_current = args.adv_weight
            do_discriminator = True

        epoch_metrics = {
            'g_loss': [], 'd_loss': [], 'd_real': [],
            'd_fake': [], 'gp': [], 'rec_loss': [],
            'train_adv_loss':[]
        }#,'train_rec_loss':[]

        pbar = tqdm(train_loader, desc=f"Epoch {epoch_idx+1}/{args.epochs}")

        for batch_idx, batch in enumerate(pbar):
            # Move data to device
            l30_img = batch["l30_img"].to(device, non_blocking=True)
            # mask_l30 = batch["mask_l30"].to(device, non_blocking=True)
            l30_meta = batch["l30_meta"].to(device, non_blocking=True)
            s1_img = batch["s1_img"].to(device, non_blocking=True)
            # mask_s1 = batch["mask_s1"].to(device, non_blocking=True)
            s1_meta = batch["s1_meta"].to(device, non_blocking=True)
            planet_img = batch["planet_img"].to(device, non_blocking=True)
            # mask_planet = batch["mask_planet"].to(device, non_blocking=True)
            planet_meta = batch["planet_meta"].to(device, non_blocking=True)
            s30_gt = batch["s30_img_gt"].to(device, non_blocking=True)
            mask_s30 = batch["mask_s30"].to(device, non_blocking=True)

            batch_size = l30_img.shape[0]

            # =====================
            # 1) If in Warm-up phase, only train Generator with reconstruction, don't update discriminator D
            # =====================
            if not do_discriminator:
                optimizer_g.zero_grad()

                # Only calculate reconstruction loss
                fake_s30 = generator(
                    l30_img, l30_meta,
                    s1_img,  s1_meta,
                    planet_img, planet_meta
                )
                loss_rec = criterion_rec(fake_s30, s30_gt, mask_s30)
                ssims=ssim_eroded_mask_gpu_batch(s30_gt, fake_s30, mask_s30).mean()
                sams=sam_masked_gpu_batch(s30_gt, fake_s30, mask_s30).mean()
                psnrs=psnr_masked_gpu_batch(s30_gt, fake_s30, mask_s30).mean()

                loss_ssim  = (1 - ssims)   # Structure loss
                loss_sam   = sams/180         # Spectral angle loss
                loss_psnr  = 1.0 - psnrs/50.0
                loss_rec = 1*loss_rec + 0.5*loss_ssim + 0.5*loss_sam + 0.2 * loss_psnr

                loss_g = loss_rec  # adv_weight_current == 0, so adversarial loss is 0

                if use_amp:
                    scaler_g.scale(loss_g).backward()
                    scaler_g.unscale_(optimizer_g)
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
                    scaler_g.step(optimizer_g)
                    scaler_g.update()
                else:
                    loss_g.backward()
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
                    optimizer_g.step()
                scheduler_g.step()
                # Record metrics
                epoch_metrics['g_loss'].append(loss_g.item())
                epoch_metrics['d_loss'].append(0.0)
                epoch_metrics['d_real'].append(0.0)
                epoch_metrics['d_fake'].append(0.0)
                epoch_metrics['gp'].append(0.0)
                epoch_metrics['rec_loss'].append(loss_rec.item())
                epoch_metrics['train_adv_loss'].append(loss_g.item()-loss_rec.item())
                # epoch_metrics['train_rec_loss'].append(train_metrics['rec_loss'])

                pbar.set_postfix({
                    'G(rec)': f"{loss_rec.item():.3e}",
                    'advW': f"{adv_weight_current:.2e}"
                })
                continue  # Skip to next batch

            # =====================
            # 2) If in Transition / Full phase, first update discriminator D, then update generator G (with dynamic adv_weight_current)
            # =====================

            # ---------------------
            # 2.1 Train Discriminator
            # ---------------------
            d_losses_step = []
            d_real_scores = []
            d_fake_scores = []
            gp_values = []

            for _ in range(args.d_steps):
                optimizer_d.zero_grad()

                # Generate fake images (detach)
                with torch.no_grad():
                    fake_s30 = generator(
                        l30_img,  l30_meta,
                        s1_img,  s1_meta,
                        planet_img,  planet_meta
                    )
                    mask_ext = mask_s30.expand(-1, fake_s30.size(1), -1, -1)
                    fake_s30 = fake_s30 * mask_ext
                if use_amp:
                    with torch.amp.autocast(device_type=device.type):
                        # Discriminator output on real samples
                        d_real = discriminator(
                            l30_img, s1_img,
                            planet_img, s30_gt, mask_s30
                        )
                        # Discriminator output on fake samples
                        d_fake = discriminator(
                            l30_img,  s1_img,
                            planet_img, fake_s30.detach(), mask_s30
                        )

                        # Calculate D loss for multiple scales
                        if isinstance(d_real, list):
                            loss_d_real = sum(-torch.mean(dr) for dr in d_real) / len(d_real)
                            loss_d_fake = sum(torch.mean(df) for df in d_fake) / len(d_fake)
                            d_real_score = sum(dr.mean().item() for dr in d_real) / len(d_real)
                            d_fake_score = sum(df.mean().item() for df in d_fake) / len(d_fake)
                        else:
                            loss_d_real = -torch.mean(d_real)
                            loss_d_fake = torch.mean(d_fake)
                            d_real_score = d_real.mean().item()
                            d_fake_score = d_fake.mean().item()

                        if args.gan_mode == "wgan-gp":
                            # Calculate gradient penalty
                            gp = compute_gradient_penalty(
                                discriminator,
                                [l30_img, s1_img, planet_img, s30_gt],
                                [l30_img, s1_img, planet_img, fake_s30.detach()],mask_s30,
                                device
                            )
                            gp = torch.clamp(gp, 0, 50)  # Prevent gradient explosion
                            loss_d = loss_d_real + loss_d_fake + args.gp_weight * gp
                            gp_values.append(gp.item())
                        else:
                            loss_d = loss_d_real + loss_d_fake

                    scaler_d.scale(loss_d).backward()
                    scaler_d.unscale_(optimizer_d)
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                    scaler_d.step(optimizer_d)
                    scaler_d.update()
                else:
                    d_real = discriminator(
                        l30_img,  s1_img,
                        planet_img, s30_gt, mask_s30
                    )
                    d_fake = discriminator(
                        l30_img,  s1_img,
                        planet_img, fake_s30.detach(), mask_s30
                    )

                    if isinstance(d_real, list):
                        loss_d_real = sum(-torch.mean(dr) for dr in d_real) / len(d_real)
                        loss_d_fake = sum(torch.mean(df) for df in d_fake) / len(d_fake)
                        d_real_score = sum(dr.mean().item() for dr in d_real) / len(d_real)
                        d_fake_score = sum(df.mean().item() for df in d_fake) / len(d_fake)
                    else:
                        loss_d_real = -torch.mean(d_real)
                        loss_d_fake = torch.mean(d_fake)
                        d_real_score = d_real.mean().item()
                        d_fake_score = d_fake.mean().item()

                    if args.gan_mode == "wgan-gp":
                        gp = compute_gradient_penalty(
                            discriminator,
                            [l30_img, s1_img, planet_img, s30_gt],
                            [l30_img, s1_img, planet_img, fake_s30.detach()],mask_s30,
                            device
                        )
                        gp = torch.clamp(gp, 0, 50)  # Prevent gradient explosion
                        loss_d = loss_d_real + loss_d_fake + args.gp_weight * gp
                        gp_values.append(gp.item())
                    else:
                        loss_d = loss_d_real + loss_d_fake

                    loss_d.backward()
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                    optimizer_d.step()
                scheduler_d.step()
                d_losses_step.append(loss_d.item())
                d_real_scores.append(d_real_score)
                d_fake_scores.append(d_fake_score)

            # Calculate average D metrics
            avg_d_loss = np.mean(d_losses_step)
            avg_d_real = np.mean(d_real_scores)
            avg_d_fake = np.mean(d_fake_scores)
            avg_gp = np.mean(gp_values) if gp_values else 0

            # ---------------------
            # 2.2 Train Generator
            # ---------------------
            rec_losses, adv_losses, total_losses = [], [], []
            for _ in range(args.g_steps):
                optimizer_g.zero_grad()
                fake_s30 = generator(
                    l30_img, l30_meta,
                    s1_img, s1_meta,
                    planet_img, planet_meta
                )

                loss_rec = criterion_rec(fake_s30, s30_gt, mask_s30)

                ssims=ssim_eroded_mask_gpu_batch(s30_gt, fake_s30, mask_s30).mean()
                sams=sam_masked_gpu_batch(s30_gt, fake_s30, mask_s30).mean()
                psnrs=psnr_masked_gpu_batch(s30_gt, fake_s30, mask_s30).mean()

                loss_ssim  = (1 - ssims)   # Structure loss
                loss_sam   = sams/180         # Spectral angle loss
                loss_psnr  = 1.0 - psnrs/50.0
                loss_rec = 1*loss_rec + 0.5*loss_ssim + 0.5*loss_sam + 0.2 * loss_psnr

                d_fake_for_g = discriminator(
                    l30_img,  s1_img,
                    planet_img,  fake_s30, mask_s30
                )
                if isinstance(d_fake_for_g, list):
                    loss_adv = sum(-df.mean() for df in d_fake_for_g) / len(d_fake_for_g)
                else:
                    loss_adv = -d_fake_for_g.mean()

                loss_g = loss_rec + adv_weight_current * loss_adv
                if use_amp:
                    scaler_g.scale(loss_g).backward()
                    scaler_g.unscale_(optimizer_g)
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
                    scaler_g.step(optimizer_g)
                    scaler_g.update()
                else:
                    loss_g.backward()
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
                    optimizer_g.step()
                scheduler_g.step()
                rec_losses.append(loss_rec.item())
                adv_losses.append((adv_weight_current * loss_adv).item())
                total_losses.append(loss_g.item())

            # Average metrics from g_steps updates as this batch's metrics
            mean_rec   = sum(rec_losses)   / len(rec_losses)
            mean_adv   = sum(adv_losses)   / len(adv_losses)
            mean_total = sum(total_losses) / len(total_losses)

            # Record this batch's metrics
            epoch_metrics['g_loss'].append(mean_total)
            epoch_metrics['rec_loss'].append(mean_rec)
            epoch_metrics['train_adv_loss'].append(mean_adv)
            epoch_metrics['d_loss'].append(avg_d_loss)
            epoch_metrics['d_real'].append(avg_d_real)
            epoch_metrics['d_fake'].append(avg_d_fake)
            epoch_metrics['gp'].append(avg_gp)


            # Update progress bar info
            pbar.set_postfix({
                'G': f"{mean_total:.3e}",
                'G(adv)': f"{mean_adv:.3e}",
                'G(rec)': f"{mean_rec:.3e}",
                'D': f"{avg_d_loss:.3e}",
                'Dr': f"{avg_d_real:.3f}",
                'Df': f"{avg_d_fake:.3f}",
                'advW': f"{adv_weight_current:.2e}"

            })

        # Return this epoch's average metrics
        return {k: np.mean(v) for k, v in epoch_metrics.items()}

    # validate function remains unchanged
    def validate(epoch_idx):
        """Score the generator (and, after warm-up, the discriminator) on the validation set.

        Both networks are switched to eval mode and every batch of
        ``val_loader`` is processed under ``torch.no_grad()``.

        Args:
            epoch_idx: Epoch index supplied by the training loop. While it is
                below ``args.warmup_epochs`` the discriminator is not run and
                both of its scores are recorded as 0 for every batch.

        Returns:
            dict: Mean over all validation batches of ``'rec_loss'`` (the
            composite reconstruction loss built below), ``'d_real'`` and
            ``'d_fake'`` (mean discriminator outputs on real / generated data).
        """
        generator.eval()
        discriminator.eval()
        val_metrics = {'rec_loss': [], 'd_real': [], 'd_fake': []}

        # The discriminator is only scored once the warm-up epochs are over.
        do_discriminator = epoch_idx >= args.warmup_epochs

        with torch.no_grad():
            pbar = tqdm(val_loader, desc="[Validation]")
            for batch in pbar:
                # Move data to device
                l30_img = batch["l30_img"].to(device, non_blocking=True)
                l30_meta = batch["l30_meta"].to(device, non_blocking=True)
                s1_img = batch["s1_img"].to(device, non_blocking=True)
                s1_meta = batch["s1_meta"].to(device, non_blocking=True)
                planet_img = batch["planet_img"].to(device, non_blocking=True)
                planet_meta = batch["planet_meta"].to(device, non_blocking=True)
                s30_gt = batch["s30_img_gt"].to(device, non_blocking=True)
                mask_s30 = batch["mask_s30"].to(device, non_blocking=True)

                # Generate fake images
                fake_s30 = generator(
                    l30_img, l30_meta,
                    s1_img, s1_meta,
                    planet_img, planet_meta
                )

                # Composite reconstruction loss: the same weighted sum of
                # criterion, SSIM, SAM and PSNR terms used in the generator step.
                rec_loss = criterion_rec(fake_s30, s30_gt, mask_s30)
                ssims = ssim_eroded_mask_gpu_batch(s30_gt, fake_s30, mask_s30).mean()
                sams = sam_masked_gpu_batch(s30_gt, fake_s30, mask_s30).mean()
                psnrs = psnr_masked_gpu_batch(s30_gt, fake_s30, mask_s30).mean()

                loss_ssim = 1 - ssims           # Structure loss
                loss_sam = sams / 180           # Spectral angle loss
                loss_psnr = 1.0 - psnrs / 50.0
                rec_loss = 1 * rec_loss + 0.5 * loss_ssim + 0.5 * loss_sam + 0.2 * loss_psnr
                val_metrics['rec_loss'].append(rec_loss.item())

                if do_discriminator:
                    d_real = discriminator(
                        l30_img, s1_img, planet_img, s30_gt, mask_s30
                    )
                    d_fake = discriminator(
                        l30_img, s1_img, planet_img, fake_s30, mask_s30
                    )
                    if isinstance(d_real, list):
                        # Multi-scale discriminator: average the per-scale means.
                        d_real_score = sum(dr.mean().item() for dr in d_real) / len(d_real)
                        d_fake_score = sum(df.mean().item() for df in d_fake) / len(d_fake)
                    else:
                        d_real_score = d_real.mean().item()
                        d_fake_score = d_fake.mean().item()
                else:
                    # Placeholder scores keep the metric lists aligned during warm-up.
                    d_real_score = 0
                    d_fake_score = 0

                val_metrics['d_real'].append(d_real_score)
                val_metrics['d_fake'].append(d_fake_score)

                pbar.set_postfix({
                    'rec': f"{rec_loss.item():.3e}",
                    'Dr': f"{d_real_score:.3f}",
                    'Df': f"{d_fake_score:.3f}"
                })
        return {k: np.mean(v) for k, v in val_metrics.items()}

    # test_and_evaluate remains unchanged
    def test_and_evaluate():
        """Evaluate the generator on the test set and write a metrics report.

        Each prediction is multiplied by ``mask_s30`` and masked MSE/RMSE,
        PSNR, SSIM and SAM are computed per sample. Samples without any valid
        pixel (NaN MSE) are left out of the statistics. Up to ten
        ground-truth / prediction / difference figures are written to
        ``log_dir / "test_samples"`` and the summary is written to
        ``log_dir / "test_results.txt"`` and printed.

        Returns:
            dict: ``{metric}_mean`` and ``{metric}_std`` for ``mse``, ``rmse``,
            ``psnr``, ``ssim`` and ``sam`` (NaN when no sample was valid), plus
            ``num_valid_samples``.
        """
        generator.eval()

        metric_names = ("mse", "rmse", "psnr", "ssim", "sam")
        per_sample = {name: [] for name in metric_names}
        num_valid_samples = 0

        save_samples = True
        num_samples_to_save = 10
        saved_cnt = 0
        sample_dir = log_dir / "test_samples"
        if save_samples:
            sample_dir.mkdir(parents=True, exist_ok=True)

        def _to_rgb(img, use_first_three):
            """Turn a channel-first image into an H x W x 3 array for display."""
            if use_first_three:
                return np.clip(img[:3].transpose(1, 2, 0), 0, 1)
            # Fewer than three bands: replicate the first band as grey-scale.
            return np.repeat(img[0:1].transpose(1, 2, 0), 3, axis=2)

        def _save_comparison(gt_rgb, pred_rgb, psnr_val, out_path):
            """Save a three-panel ground truth / prediction / |difference| figure."""
            diff = np.abs(pred_rgb - gt_rgb)
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            axes[0].imshow(gt_rgb)
            axes[0].set_title("Ground Truth")
            axes[1].imshow(pred_rgb)
            axes[1].set_title(f"Prediction (PSNR: {psnr_val:.2f} dB)")
            # NOTE: imshow ignores `cmap` for RGB input, so the difference is
            # shown in its per-channel colours rather than the 'hot' colormap.
            axes[2].imshow(diff, cmap='hot')
            axes[2].set_title("Absolute Difference")
            for ax in axes:
                ax.axis("off")
            fig.tight_layout()
            fig.savefig(out_path)
            plt.close(fig)

        def _mean_std(values):
            """Return (mean, std) as floats, or NaNs for an empty list."""
            if not values:
                # Explicit NaN instead of numpy's "Mean of empty slice" warnings.
                return float("nan"), float("nan")
            return float(np.mean(values)), float(np.std(values))

        with torch.no_grad():
            pbar = tqdm(test_loader, desc="[Test]")
            for batch_idx, batch in enumerate(pbar):
                # Move data to device
                l30_img = batch["l30_img"].to(device, non_blocking=True)
                l30_meta = batch["l30_meta"].to(device, non_blocking=True)
                s1_img = batch["s1_img"].to(device, non_blocking=True)
                s1_meta = batch["s1_meta"].to(device, non_blocking=True)
                planet_img = batch["planet_img"].to(device, non_blocking=True)
                planet_meta = batch["planet_meta"].to(device, non_blocking=True)
                s30_gt = batch["s30_img_gt"].to(device, non_blocking=True)
                mask_s30 = batch["mask_s30"].to(device, non_blocking=True)

                # Generate predictions
                outputs = generator(
                    l30_img, l30_meta,
                    s1_img, s1_meta,
                    planet_img, planet_meta
                )
                # Zero the prediction outside the valid-pixel mask.
                mask_ext = mask_s30.expand(-1, outputs.size(1), -1, -1)
                outputs = outputs * mask_ext

                pred_np = outputs.cpu().numpy()
                gt_np = s30_gt.cpu().numpy()
                mask_np = mask_s30.cpu().numpy().squeeze(1)

                for i, (pred_i, gt_i, mask_i) in enumerate(zip(pred_np, gt_np, mask_np)):
                    # Masked MSE averaged over valid pixels times channels.
                    sq_err = (pred_i - gt_i) ** 2
                    num_valid_pixels = np.sum(mask_i) * pred_i.shape[0]
                    if num_valid_pixels > 0:
                        mse_val = np.sum(sq_err * mask_i[np.newaxis, ...]) / num_valid_pixels
                    else:
                        mse_val = np.nan
                    rmse_val = np.sqrt(mse_val) if not np.isnan(mse_val) else np.nan

                    psnr_val = psnr_masked(gt_i, pred_i, mask_i, data_range=1.0)
                    ssim_val = ssim_eroded_mask(gt_i, pred_i, mask_i, max_val=1.0)
                    sam_val = sam_masked(gt_i, pred_i, mask_i)

                    if not np.isnan(mse_val):
                        per_sample["mse"].append(mse_val)
                        per_sample["rmse"].append(rmse_val)
                        per_sample["psnr"].append(psnr_val)
                        per_sample["ssim"].append(ssim_val)
                        per_sample["sam"].append(sam_val)
                        num_valid_samples += 1

                    # Save visualizations for the first few samples only.
                    if save_samples and saved_cnt < num_samples_to_save:
                        use_first_three = pred_i.shape[0] >= 3
                        _save_comparison(
                            _to_rgb(gt_i, use_first_three),
                            _to_rgb(pred_i, use_first_three),
                            psnr_val,
                            sample_dir / f"sample_{batch_idx:04d}_{i:02d}.png",
                        )
                        saved_cnt += 1

                if per_sample["mse"]:
                    pbar.set_postfix({
                        "MSE": f"{np.mean(per_sample['mse']):.6f}",
                        "PSNR": f"{np.mean(per_sample['psnr']):.2f}",
                        "SSIM": f"{np.mean(per_sample['ssim']):.4f}"
                    })

        # Final results (key order matters: it is the order they are printed in).
        results = {}
        for name in metric_names:
            mean_val, std_val = _mean_std(per_sample[name])
            results[f"{name}_mean"] = mean_val
            results[f"{name}_std"] = std_val
        results["num_valid_samples"] = num_valid_samples

        # Save results; the explicit encoding keeps "±" portable across platforms.
        results_file = log_dir / "test_results.txt"
        with open(results_file, "w", encoding="utf-8") as f:
            f.write("===== Test Set Evaluation =====\n")
            f.write(f"MSE:  {results['mse_mean']:.6f} ± {results['mse_std']:.6f}\n")
            f.write(f"RMSE: {results['rmse_mean']:.6f} ± {results['rmse_std']:.6f}\n")
            f.write(f"PSNR: {results['psnr_mean']:.4f} ± {results['psnr_std']:.4f} dB\n")
            f.write(f"SSIM: {results['ssim_mean']:.4f} ± {results['ssim_std']:.4f}\n")
            f.write(f"SAM:  {results['sam_mean']:.4f} ± {results['sam_std']:.4f} degrees\n")
            f.write(f"Valid Samples: {results['num_valid_samples']}\n")

        print(f"\n[INFO] Test results saved to {results_file}")
        for k, v in results.items():
            if isinstance(v, float):
                print(f"{k}: {v:.6f}")
            else:
                print(f"{k}: {v}")

        return results

    # Test mode
    if args.mode == "test":
        print("[INFO] Running test evaluation only")
        if args.resume is None:
            raise RuntimeError("Test mode requires --resume checkpoint")
        test_results = test_and_evaluate()
        return

    # Training loop
    print("\n[INFO] Starting GAN training...")

    for epoch in range(start_epoch, args.epochs):
        epoch_start = time.time()

        # Train (will do Warm-up / Transition / Full based on epoch internally)
        train_metrics = train_one_epoch(epoch)

        # Validate
        val_metrics = validate(epoch)

        # Update learning rates (only use validation reconstruction loss as metric)
        # scheduler_g.step(val_metrics['rec_loss'])
        # scheduler_d.step(val_metrics['rec_loss'])

        # Record history
        history['train_g_loss'].append(train_metrics['g_loss'])
        history['train_adv_loss'].append(train_metrics['train_adv_loss'])
        history['train_rec_loss'].append(train_metrics['rec_loss'])
        history['train_d_loss'].append(train_metrics['d_loss'])
        history['train_d_real'].append(train_metrics['d_real'])
        history['train_d_fake'].append(train_metrics['d_fake'])
        history['train_gp'].append(train_metrics['gp'])
        history['val_rec_loss'].append(val_metrics['rec_loss'])
        history['val_d_real'].append(val_metrics['d_real'])
        history['val_d_fake'].append(val_metrics['d_fake'])
        history['lr_g'].append(optimizer_g.param_groups[0]['lr'])
        history['lr_d'].append(optimizer_d.param_groups[0]['lr'])

        # Print epoch summary
        elapsed = time.time() - epoch_start
        print(f"\n[Epoch {epoch+1}/{args.epochs}] Time: {elapsed:.1f}s")
        print(f"  Train - G: {train_metrics['g_loss']:.4f}, D: {train_metrics['d_loss']:.4f}, "
              f"D_real: {train_metrics['d_real']:.3f}, D_fake: {train_metrics['d_fake']:.3f}, "
              f"advW: {('%.3e' % (train_metrics['g_loss'] - train_metrics['rec_loss'])) if epoch >= args.warmup_epochs else '0.0e+00'}")
        print(f"  Val   - Rec: {val_metrics['rec_loss']:.4f}, "
              f"D_real: {val_metrics['d_real']:.3f}, D_fake: {val_metrics['d_fake']:.3f}")

        # Save best checkpoint if improved
        if val_metrics['rec_loss'] < best_val_loss:
            best_val_loss = val_metrics['rec_loss']
            ckpt_path = ckpt_dir / f"best_epoch{epoch+1}_val{best_val_loss:.4f}.pth"
            torch.save({
                "epoch": epoch + 1,
                "generator_state_dict": generator.state_dict(),
                "discriminator_state_dict": discriminator.state_dict(),
                "optimizer_g_state_dict": optimizer_g.state_dict(),
                "optimizer_d_state_dict": optimizer_d.state_dict(),
                "scheduler_g_state_dict": scheduler_g.state_dict(),
                "scheduler_d_state_dict": scheduler_d.state_dict(),
                "best_val_loss": best_val_loss,
                "history": history
            }, ckpt_path)
            print(f"  Saved best checkpoint to {ckpt_path}")

        # Save regular checkpoint every 10 epochs
        # if (epoch + 1) % 10 == 0:
        #     ckpt_path = ckpt_dir / f"checkpoint_epoch{epoch+1}.pth"
        #     torch.save({
        #         "epoch": epoch + 1,
        #         "generator_state_dict": generator.state_dict(),
        #         "discriminator_state_dict": discriminator.state_dict(),
        #         "optimizer_g_state_dict": optimizer_g.state_dict(),
        #         "optimizer_d_state_dict": optimizer_d.state_dict(),
        #         "scheduler_g_state_dict": scheduler_g.state_dict(),
        #         "scheduler_d_state_dict": scheduler_d.state_dict(),
        #         "best_val_loss": best_val_loss,
        #         "history": history
        #     }, ckpt_path)

        # Plot training curves every 10 epochs
        if (epoch + 1) % 10 == 0:
            plot_training_curves(history, plot_dir, epoch + 1)
        save_epoch_metrics(history, epoch, log_dir,args.log_name)
    print(f"\n[INFO] Training complete. Best validation loss: {best_val_loss:.6f}")

    # Final test evaluation
    print("\n" + "="*60)
    print("[INFO] Running final test evaluation...")
    print("="*60)
    test_results = test_and_evaluate()

    # Save final plots
    plot_training_curves(history, plot_dir, args.epochs, final=True)

    # Save training history
    history_file = log_dir / "training_history.json"
    with open(history_file, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"\n[INFO] Training history saved to {history_file}")


def _draw_training_panels(fig, history):
    """Populate ``fig`` with the nine training-curve panels read from ``history``."""

    def _finish(ax, ylabel, title):
        # Shared decoration for every panel that shows curves.
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _sign_accuracy(real_scores, fake_scores):
        # Counts a "hit" when the mean real score is > 0 / the mean fake score is < 0.
        return [
            ((1 if r > 0 else 0) + (1 if f < 0 else 0)) / 2
            for r, f in zip(real_scores, fake_scores)
        ]

    # 1. Generator and Discriminator Loss
    ax1 = fig.add_subplot(3, 3, 1)
    ax1.plot(history['train_g_loss'], label='Generator Loss', color='blue', linewidth=2)
    ax1.plot(history['train_d_loss'], label='Discriminator Loss', color='red', linewidth=2)
    _finish(ax1, 'Loss', 'Generator vs Discriminator Loss')

    # 2. Discriminator Scores on Real/Fake
    ax2 = fig.add_subplot(3, 3, 2)
    ax2.plot(history['train_d_real'], label='D(real) - Train', color='green', linewidth=2)
    ax2.plot(history['train_d_fake'], label='D(fake) - Train', color='orange', linewidth=2)
    ax2.plot(history['val_d_real'], label='D(real) - Val', color='darkgreen', linestyle='--')
    ax2.plot(history['val_d_fake'], label='D(fake) - Val', color='darkorange', linestyle='--')
    _finish(ax2, 'Discriminator Score', 'Discriminator Scores (Higher=Real, Lower=Fake)')

    # 3. Reconstruction / adversarial loss
    ax3 = fig.add_subplot(3, 3, 3)
    ax3.plot(history['train_rec_loss'], label='Train Rec Loss', color='orange', linewidth=2)
    ax3.plot(history['train_adv_loss'], label='Train Adv Loss', color='blue', linewidth=2)
    ax3.plot(history['val_rec_loss'], label='Validation Rec Loss', color='purple', linewidth=2)
    _finish(ax3, 'Reconstruction Loss', 'Rec/Adv Loss')

    # 4. Gradient Penalty (only meaningful when some epoch recorded a non-zero value)
    ax4 = fig.add_subplot(3, 3, 4)
    if history['train_gp'] and any(history['train_gp']):
        ax4.plot(history['train_gp'], label='Gradient Penalty', color='brown', linewidth=2)
        _finish(ax4, 'GP Value', 'WGAN-GP Gradient Penalty')
    else:
        ax4.text(0.5, 0.5, 'N/A for Vanilla GAN', ha='center', va='center',
                 transform=ax4.transAxes)
        ax4.set_title('Gradient Penalty')

    # 5. Learning Rates
    ax5 = fig.add_subplot(3, 3, 5)
    ax5.plot(history['lr_g'], label='Generator LR', color='blue', linewidth=2)
    ax5.plot(history['lr_d'], label='Discriminator LR', color='red', linewidth=2)
    _finish(ax5, 'Learning Rate', 'Learning Rate Schedule')
    ax5.set_yscale('log')

    # 6. Wasserstein Distance Estimate (D_real - D_fake)
    ax6 = fig.add_subplot(3, 3, 6)
    w_dist_train = [r - f for r, f in zip(history['train_d_real'], history['train_d_fake'])]
    w_dist_val = [r - f for r, f in zip(history['val_d_real'], history['val_d_fake'])]
    ax6.plot(w_dist_train, label='Train', color='blue', linewidth=2)
    ax6.plot(w_dist_val, label='Val', color='red', linewidth=2)
    _finish(ax6, 'D(real) - D(fake)', 'Wasserstein Distance Estimate')

    # 7. Loss Ratio G/D (0 is plotted where the discriminator loss is exactly 0)
    ax7 = fig.add_subplot(3, 3, 7)
    loss_ratio = [
        g / d if d != 0 else 0
        for g, d in zip(history['train_g_loss'], history['train_d_loss'])
    ]
    ax7.plot(loss_ratio, label='G/D Loss Ratio', color='magenta', linewidth=2)
    ax7.axhline(y=1, color='k', linestyle='--', alpha=0.5)
    _finish(ax7, 'Ratio', 'Generator/Discriminator Loss Ratio')

    # 8. Discriminator Accuracy (simplified: sign of the epoch-mean scores)
    ax8 = fig.add_subplot(3, 3, 8)
    train_acc = _sign_accuracy(history['train_d_real'], history['train_d_fake'])
    val_acc = _sign_accuracy(history['val_d_real'], history['val_d_fake'])
    ax8.plot(train_acc, label='Train Accuracy', color='blue', linewidth=2)
    ax8.plot(val_acc, label='Val Accuracy', color='red', linewidth=2)
    ax8.set_ylim([0, 1.1])
    _finish(ax8, 'Accuracy', 'Discriminator Classification Accuracy')

    # 9. Loss Components (Log Scale)
    # NOTE(review): non-positive values cannot be drawn on a log axis, so a
    # negative discriminator loss will simply be missing from this panel.
    ax9 = fig.add_subplot(3, 3, 9)
    ax9.plot(history['train_g_loss'], label='G Loss', color='blue', linewidth=2, alpha=0.7)
    ax9.plot(history['train_d_loss'], label='D Loss', color='red', linewidth=2, alpha=0.7)
    ax9.plot(history['val_rec_loss'], label='Val Rec Loss', color='green', linewidth=2, alpha=0.7)
    ax9.set_yscale('log')
    _finish(ax9, 'Loss (Log Scale)', 'All Losses (Log Scale)')


def plot_training_curves(history, plot_dir, epoch, final=False):
    """Render a 3x3 dashboard of GAN training curves and save it as a PNG.

    The seaborn dark-grid style is applied only while this figure is built, so
    figures created elsewhere keep their own style.

    Args:
        history: Mapping from metric name to a per-epoch sequence. Keys read:
            'train_g_loss', 'train_d_loss', 'train_adv_loss', 'train_rec_loss',
            'train_d_real', 'train_d_fake', 'train_gp', 'val_rec_loss',
            'val_d_real', 'val_d_fake', 'lr_g', 'lr_d'.
        plot_dir: Output directory (``str`` or ``Path``); created if missing.
        epoch: Epoch number shown in the figure title and, unless ``final`` is
            set, embedded in the file name.
        final: When true the file is ``training_curves_final.png`` instead of
            ``training_curves_epoch{epoch}.png``.
    """
    plot_dir = Path(plot_dir)
    plot_dir.mkdir(parents=True, exist_ok=True)

    # Matplotlib 3.6 renamed the bundled seaborn styles; fall back gracefully.
    style_name = 'seaborn-v0_8-darkgrid'
    if style_name not in plt.style.available:
        legacy_name = 'seaborn-darkgrid'
        style_name = legacy_name if legacy_name in plt.style.available else 'default'

    if final:
        plot_path = plot_dir / "training_curves_final.png"
    else:
        plot_path = plot_dir / f"training_curves_epoch{epoch}.png"

    with plt.style.context(style_name):
        fig = plt.figure(figsize=(20, 15))
        try:
            _draw_training_panels(fig, history)
            fig.suptitle(f'GAN Training Progress - Epoch {epoch}', fontsize=16)
            fig.tight_layout()
            fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release this specific figure, even if drawing or saving fails.
            plt.close(fig)

#Copy

In [None]:
import shutil
import os

src = "/content/gdrive/MyDrive/ColabData/ReconstructionDataset_Final"
dst = "/content/data/ReconstructionDataset_Final"

# copytree refuses to write into an existing directory, so remove any copy
# left over from a previous run before mirroring the Drive dataset locally.
stale_copy_exists = os.path.exists(dst)
if stale_copy_exists:
    shutil.rmtree(dst)
shutil.copytree(src, dst)

# List the top-level entries as a quick sanity check of the copy.
copied_entries = os.listdir(dst)
print("Copied files:", copied_entries)

# Train

In [None]:
# Announce that this run uses the hard-coded IDEArgs settings defined below.
print("[INFO] Running in IDE debug mode with default settings")

class IDEArgs:
    """Hard-coded run configuration exposed as plain class attributes.

    An instance of this class is passed to ``main`` as its ``args`` object.
    """

    # --- run selection ---
    mode = "train"
    model = "MultimodalUnetGAN"
    loss = "mse_loss"
    gpu = 0
    resume = None

    # --- data and schedule ---
    data_dir = "/content/data/ReconstructionDataset_Final"
    batch_size = 64
    epochs = 300

    # --- output locations, one sub-folder per model name ---
    # NOTE(review): these live under "Colabdata" while the dataset copy cell
    # reads from "ColabData" — confirm the differing capitalisation is intended.
    ckpt_dir = f"/content/gdrive/MyDrive/Colabdata/checkpoints/{model}"
    log_dir = f"/content/gdrive/MyDrive/Colabdata/logs/{model}"
    plot_dir = f"/content/gdrive/MyDrive/Colabdata/plots/{model}"
    log_name = model

    # --- GAN objective ---
    gan_mode = "wgan-gp"
    d_steps = 5
    g_steps = 2
    gp_weight = 10
    adv_weight = 0.01

    # --- optimisation ---
    lr_g = 2e-4
    lr_d = 1e-4
    weight_decay = 1e-5
    use_amp = True

    # --- adversarial-weight schedule ---
    warmup_epochs = 10
    transition_epochs = 10

# Instantiate the hard-coded configuration and hand it to main(), which runs
# training (or test-only evaluation when args.mode == "test").
args = IDEArgs()
main(args)

# infer-import

In [None]:
!pip install rasterio
import sys
from google.colab import drive

# Mount Google Drive so the model code stored there can be imported.
drive.mount('/content/gdrive')
# Drive directory expected to contain the MultimodalUnetGAN module imported
# below; replace with your own path if the layout differs.
model_dir_in_drive = '/content/gdrive/MyDrive/ColabModel/'
# Add the directory to the Python path (only once, so re-running this cell
# does not append duplicates)
if model_dir_in_drive not in sys.path:
    sys.path.append(model_dir_in_drive)
    print(f"Added {model_dir_in_drive} to sys.path")
else:
    print(f"{model_dir_in_drive} already in sys.path")

from MultimodalUnetGAN import UnetGAN
import pandas as pd
import ast
import os
import argparse
from pathlib import Path
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader

from scipy.ndimage import binary_erosion
from skimage.metrics import structural_similarity as ssim

import importlib


import rasterio
from rasterio.transform import Affine
from rasterio.errors import NotGeoreferencedWarning
import warnings

# Silence rasterio's NotGeoreferencedWarning for the rest of the session
# (presumably some input rasters carry no georeferencing — TODO confirm).
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

# infer-config

In [None]:
# -------------------------------------------------------------
# 1. Define dataset class (consistent with OptimizedPatchDataset in training script)
# -------------------------------------------------------------
class OptimizedPatchDataset(Dataset):
    """Dataset over ``sample_*.pt`` files with a small in-memory cache.

    At construction every file is loaded once and samples whose L30, S1 and
    Planet masks are all zero are dropped. A mask key that is absent from a
    sample is treated as contributing no valid pixels.

    Args:
        data_dir: Directory searched (non-recursively) for ``sample_*.pt``.
        cache_size: Maximum number of loaded samples kept in memory.
        map_location: Forwarded to ``torch.load``.

    Raises:
        RuntimeError: If no sample file exists or every sample is filtered out.
    """

    # Input masks inspected when deciding whether a sample is usable.
    _MASK_KEYS = ('mask_l30', 'mask_s1', 'mask_planet')

    def __init__(self, data_dir, cache_size=50, map_location="cpu"):
        self.data_dir = Path(data_dir)
        self.map_location = map_location
        self.cache = {}
        self.cache_size = cache_size

        all_files = sorted(self.data_dir.glob("sample_*.pt"))
        if not all_files:
            raise RuntimeError(f"[OptimizedPatchDataset] No sample_*.pt files found in {self.data_dir}")

        # Filter: only keep samples with at least one non-zero mask.
        # NOTE: torch.load unpickles file contents; only use trusted data.
        valid_files = [
            p for p in all_files
            if self._has_valid_mask(torch.load(p, map_location=map_location))
        ]
        if not valid_files:
            raise RuntimeError("[OptimizedPatchDataset] All samples have zero masks!")
        self.files = valid_files

    @classmethod
    def _has_valid_mask(cls, sample):
        """Return True when any present input mask has a non-zero sum."""
        for key in cls._MASK_KEYS:
            mask = sample.get(key)
            # A missing mask cannot contribute valid pixels; skip it instead of
            # failing on ``None.sum()``.
            if mask is not None and mask.sum() != 0:
                return True
        return False

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` (from the cache when available).

        float64 tensors are converted to float32, all tensors are made
        contiguous, and the source file path is stored under ``"__path__"``.
        """
        if idx in self.cache:
            return self.cache[idx]
        sample = torch.load(self.files[idx], map_location=self.map_location)
        # Type adjustment & contiguous
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                if value.dtype == torch.float64:
                    value = value.float()
                sample[key] = value.contiguous()
        sample["__path__"] = str(self.files[idx])
        # The cache never evicts: once full, later samples are reloaded each time.
        if len(self.cache) < self.cache_size:
            self.cache[idx] = sample
        return sample
def pad_to_multiple(tensor, mult=16):
    """
    Zero-pad the last two dimensions of ``tensor`` up to the next multiple of ``mult``.

    Padding is appended on the bottom (height) and right (width) only. A tensor
    whose height and width are already multiples of ``mult`` is returned as-is
    (the same object, no copy).
    """
    height, width = tensor.shape[-2:]
    # -x % mult is the distance from x up to the next multiple of mult.
    extra_rows = -height % mult
    extra_cols = -width % mult
    if not (extra_rows or extra_cols):
        return tensor
    # F.pad takes pads for the last dimension first: (left, right, top, bottom).
    return F.pad(tensor, (0, extra_cols, 0, extra_rows))
class MultimodalDataset(Dataset):
    """
    CSV-driven dataset: each row names one S30 target raster plus optional
    L30 / S1 / Planet rasters and their masks.

    Every list column holds ';'-separated file names, optionally wrapped in
    '[...]'; only the first file of each list is read. A modality whose list
    is empty (or whose cell is missing) is replaced by all-zero tensors sized
    from the S30 target. All tensors with three or more dimensions are
    zero-padded so H and W are multiples of 16.
    """

    def __init__(self,
                 csv_path,
                 root_dir,
                 bands_l30=11,
                 bands_s1=3,
                 bands_planet=7,
                 meta_dim=11):
        # dtype=str keeps populated cells textual; empty cells still load as NaN
        self.df = pd.read_csv(csv_path, dtype=str)
        self.root = root_dir
        # Band counts are only used to size the zero placeholders
        self.bands_l30    = bands_l30
        self.bands_s1     = bands_s1
        self.bands_planet = bands_planet
        self.meta_dim     = meta_dim

    def __len__(self):
        """Return the number of CSV rows."""
        return len(self.df)

    def _parse_list(self, cell):
        """
        Split a CSV cell such as '[a.tif; b.tif]' into ['a.tif', 'b.tif'].

        Anything that is not a string (None, or the float NaN pandas uses for
        an empty cell) yields an empty list, so a missing modality takes the
        zero-fill path instead of failing on ``.strip()``.
        """
        if not isinstance(cell, str):
            return []
        txt = cell.strip()
        if txt.startswith('[') and txt.endswith(']'):
            txt = txt[1:-1]
        return [x.strip() for x in txt.split(';') if x.strip()]

    def _read_raster(self, folder, fname):
        """Read every band of root/folder/fname as a float32 tensor."""
        path = os.path.join(self.root, folder, fname)
        with rasterio.open(path) as src:
            arr = src.read().astype(np.float32)
        return torch.from_numpy(arr)

    def _read_mask(self, folder, fname):
        """Read band 1 of root/folder/fname as a float tensor with a leading axis of size 1."""
        path = os.path.join(self.root, folder, fname)
        with rasterio.open(path) as src:
            m = src.read(1).astype(np.uint8)
        return torch.from_numpy(m[None]).float()

    def _read_mask_or_zeros(self, folder, files, H, W):
        """Read the first listed mask, or return a (1, H, W) zero mask when none is listed."""
        if files:
            return self._read_mask(folder, files[0])
        return torch.zeros(1, H, W)

    def _read_modality(self, img_folder, img_files, mask_folder, mask_files, bands, H, W):
        """
        Return (image, mask) for one optional input modality.

        With no image listed, both are zero tensors of shape (bands, H, W)
        and (1, H, W). With an image but no mask listed, the mask falls back
        to zeros, mirroring how a missing S30 mask is handled.
        """
        if not img_files:
            return torch.zeros(bands, H, W), torch.zeros(1, H, W)
        img = self._read_raster(img_folder, img_files[0])
        mask = self._read_mask_or_zeros(mask_folder, mask_files, H, W)
        return img, mask

    def __getitem__(self, idx):
        """
        Build the sample dict for row idx.

        Raises:
            RuntimeError: If the row lists no S30 raster; the target defines
                H and W for every placeholder, so it cannot be defaulted.
        """
        row = self.df.iloc[idx]

        # Parse all lists
        s30_files      = self._parse_list(row.get('S30', '[]'))
        mask_s30_files = self._parse_list(row.get('S30_mask', '[]'))
        l30_files      = self._parse_list(row.get('L30', '[]'))
        mask_l30_files = self._parse_list(row.get('L30_mask', '[]'))
        s1_files       = self._parse_list(row.get('S1', '[]'))
        mask_s1_files  = self._parse_list(row.get('S1_mask', '[]'))
        planet_files       = self._parse_list(row.get('Planet', '[]'))
        mask_planet_files  = self._parse_list(row.get('Planet_mask', '[]'))

        # —— 1) Read S30 GT & mask ——
        if not s30_files:
            raise RuntimeError(f"No S30 for row {idx}")
        s30_fname = s30_files[0]
        s30_img_gt = self._read_raster('S30', s30_fname)

        # Get H,W for other defaults
        _, H, W = s30_img_gt.shape
        mask_s30 = self._read_mask_or_zeros('S30_mask', mask_s30_files, H, W)

        # —— 2) L30 ——
        l30_img, mask_l30 = self._read_modality(
            'L30', l30_files, 'L30_mask', mask_l30_files, self.bands_l30, H, W)

        # —— 3) S1 ——
        s1_img, mask_s1 = self._read_modality(
            'S1', s1_files, 'S1_mask', mask_s1_files, self.bands_s1, H, W)

        # —— 4) Planet ——
        planet_img, mask_planet = self._read_modality(
            'Planet', planet_files, 'Planet_mask', mask_planet_files, self.bands_planet, H, W)

        # —— 5) Meta all zeros ——
        l30_meta    = torch.zeros(self.meta_dim)
        s1_meta     = torch.zeros(self.meta_dim)
        planet_meta = torch.zeros(self.meta_dim)

        sample = {
           'l30_img':      l30_img,
           'mask_l30':     mask_l30,
           'l30_meta':     l30_meta,
           's1_img':       s1_img,
           'mask_s1':      mask_s1,
           's1_meta':      s1_meta,
           'planet_img':   planet_img,
           'mask_planet':  mask_planet,
           'planet_meta':  planet_meta,
           's30_img_gt':   s30_img_gt,
           'mask_s30':     mask_s30,
           's30_fname':   s30_fname
        }

        # —— Uniformly pad all images and masks here ——
        for k, v in sample.items():
            # Only tensors with >= 3 dims (images and masks) are padded;
            # the 1-D meta vectors and the file name are left alone
            if isinstance(v, torch.Tensor) and v.dim() >= 3:
                sample[k] = pad_to_multiple(v, mult=16)

        return sample
# -------------------------------------------------------------
# 2. Define metric functions (same version as used in training)
# -------------------------------------------------------------

def psnr_masked(img_ref, img_test, valid_mask, data_range=1.0):
    """
    Calculate PSNR over the pixels where valid_mask > 0.

    Args:
        img_ref: Reference array of shape (C, H, W).
        img_test: Test array with the same shape as img_ref.
        valid_mask: Array of shape (H, W); a pixel counts as valid when its
            value is > 0, whatever the magnitude.
        data_range: Peak signal value used in the PSNR formula.

    Returns:
        PSNR in dB as a float; NaN when no pixel is valid, +inf when the
        masked error is zero.

    Raises:
        ValueError: If the image shapes differ or the mask is not (H, W).
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[psnr_masked] Shape mismatch")
    if valid_mask.shape != img_ref.shape[1:]:
        raise ValueError("[psnr_masked] Mask shape mismatch")

    # Count valid pixels with the same ">0" rule used to select them, so the
    # mean stays correct for masks that are not strictly 0/1 (e.g. 0/255)
    mask_bool = np.asarray(valid_mask) > 0
    num_valid = int(np.count_nonzero(mask_bool))
    if num_valid == 0:
        return float("nan")

    ref = img_ref.astype(np.float64)
    test = img_test.astype(np.float64)
    masked_sq = (ref - test)[:, mask_bool] ** 2  # (C, num_valid)
    mse = np.sum(masked_sq) / (num_valid * ref.shape[0])
    if mse <= 0:
        return float("inf")
    psnr_val = 10.0 * np.log10((data_range ** 2) / mse)
    return float(psnr_val)

def ssim_eroded_mask(img_ref, img_test, valid_mask, max_val=1.0, **ssim_kwargs):
    """
    Calculate SSIM averaged over the eroded interior of valid_mask.

    The mask (valid where > 0) is eroded with a square win_size x win_size
    element, a full SSIM map is computed per channel, and each map is averaged
    over the eroded region only. The result is the NaN-ignoring mean over
    channels. NaN is returned when the window is smaller than 3 or the eroded
    mask is empty; a channel whose SSIM call fails contributes NaN.

    NOTE(review): 'win_size' is popped from ssim_kwargs and only sizes the
    erosion element; it is not forwarded to ssim(), which therefore uses its
    own default window. Confirm this is intended.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[ssim_eroded_mask] Shape mismatch")
    n_channels, height, width = img_ref.shape
    mask_bool = valid_mask > 0

    # Window defaults to 3 (less for tiny images) and is forced to be odd
    window = ssim_kwargs.pop('win_size', min(3, height, width))
    if window % 2 == 0:
        window -= 1
    if window < 3:
        return float("nan")

    footprint = np.ones((window, window), dtype=bool)
    interior = binary_erosion(mask_bool, structure=footprint, border_value=0)
    if not interior.any():
        return float("nan")

    def _channel_score(ref_band, test_band):
        # Best-effort per channel: any failure becomes NaN and is skipped by nanmean
        try:
            _, score_map = ssim(
                ref_band, test_band,
                data_range=max_val,
                full=True,
                **ssim_kwargs
            )
            return np.mean(score_map[interior])
        except Exception:
            return float("nan")

    per_channel = [
        _channel_score(img_ref[c, :, :], img_test[c, :, :])
        for c in range(n_channels)
    ]
    return float(np.nanmean(per_channel))

def sam_masked(img_ref, img_test, valid_mask):
    """
    Calculate the mean Spectral Angle Mapper (SAM), in degrees, over valid pixels.

    Each pixel where valid_mask > 0 is treated as a C-element spectrum; the
    angle between the reference and test spectra is averaged over all pixels
    whose two spectra both have non-zero norm. Returns NaN when no pixel is
    valid or every valid pixel has a zero-norm spectrum.
    """
    if img_ref.shape != img_test.shape:
        raise ValueError("[sam_masked] Shape mismatch")
    # Unpacking also enforces a 3-D (C, H, W) layout
    _n_bands, _n_rows, _n_cols = img_ref.shape

    ref64 = img_ref.astype(np.float64)
    test64 = img_test.astype(np.float64)
    selected = valid_mask > 0
    if not np.any(selected):
        return float("nan")

    # Columns are the spectra of the selected pixels: shape (C, n_selected)
    ref_vecs = ref64[:, selected]
    test_vecs = test64[:, selected]

    inner = (ref_vecs * test_vecs).sum(axis=0)
    magnitude = np.linalg.norm(ref_vecs, axis=0) * np.linalg.norm(test_vecs, axis=0)

    # Zero-norm spectra have no defined angle, so they are left out
    usable = magnitude > 0
    if not np.any(usable):
        return float("nan")

    # Clip guards arccos against rounding slightly outside [-1, 1]
    cosines = np.clip(inner[usable] / magnitude[usable], -1.0, 1.0)
    return float(np.mean(np.degrees(np.arccos(cosines))))

# -------------------------------------------------------------
# 3. Output helpers: multi-band GeoTIFF export and input-mask generation
# -------------------------------------------------------------
# Assume pred is a NumPy array with shape (12, H, W), dtype float32
# Similarly, gt is the ground truth array with shape (12, H, W)
def save_multiband_tif_rasterio(array_3d, output_path):
    """
    Use rasterio to write array_3d with shape (C, H, W) to output_path as a
    C-band GeoTIFF, without a CRS and with an identity affine transform.
    """
    band_count, height, width = array_3d.shape
    assert band_count <= 2**16, "Number of bands should not be too large!"

    # No georeferencing: the identity transform maps pixel coordinates to themselves
    profile = dict(
        driver='GTiff',
        height=height,
        width=width,
        count=band_count,
        dtype=array_3d.dtype,
        crs=None,
        transform=Affine.identity(),
    )

    with rasterio.open(output_path, 'w', **profile) as dst:
        # rasterio band indices are 1-based
        for band_number, band in enumerate(array_3d, start=1):
            dst.write(band, band_number)
def GenerateMask(x, dim=1):
    """
    Count the non-zero entries of x along dim, keeping that dimension with size 1.

    The result is a float tensor: 0 where every entry along dim is zero, and
    the number of non-zero entries otherwise (it is a count, not a 0/1 mask).
    keepdim=True puts the singleton axis at dim itself; for the default dim=1
    this is the same shape the previous unsqueeze(1) produced.
    """
    return (x != 0).float().sum(dim=dim, keepdim=True)

#infer-main

In [None]:
# -------------------------------------------------------------
# 4. Inference main function
# -------------------------------------------------------------
def main(args):
    """
    Run the generator over the test split, save predictions and report metrics.

    Args:
        args: Object exposing ``device``, ``output_dir``, ``model``,
            ``ckpt_path``, ``data_dir``, ``batch_size`` and ``num_workers``.

    Raises:
        AttributeError: If the model has no ``generator`` attribute.
        FileNotFoundError: If ``args.ckpt_path`` is not an existing file.
        RuntimeError: If ``<data_dir>/test`` does not exist.

    Files written under ``args.output_dir``:
        arrays/<sample>_pred.tif and arrays/<sample>_gt.tif,
        visualizations/<sample>.png, inference_results.json and
        inference_summary_by_combo.json.
    """
    # Imported here so this cell does not rely on an import made in another cell
    from collections import defaultdict

    # Select device (falls back to CPU whenever CUDA is unavailable)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    # Create output directories
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    vis_dir = output_dir / "visualizations"
    vis_dir.mkdir(parents=True, exist_ok=True)
    arrays_dir = output_dir / "arrays"
    arrays_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # 4.1 Load model
    # ------------------------------------------------------------------
    # model = UnetGAN(use_meta=True,use_selfattention=True).to(device)#False,True
    model = UnetGAN(use_meta=True, use_selfattention=True,
                    use_spatial_attention=True).to(device)#True,False
    # Check if has generator attribute
    if not hasattr(model, "generator"):
        raise AttributeError(f"[ERROR] Model '{args.model}' does not have attribute 'generator'")
    generator = model.generator

    # Load checkpoint
    if not os.path.isfile(args.ckpt_path):
        raise FileNotFoundError(f"[ERROR] Checkpoint not found: {args.ckpt_path}")
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints
    ckpt = torch.load(args.ckpt_path, map_location=device,weights_only=False)
    # Accept either a training checkpoint dict or a bare generator state dict
    if "generator_state_dict" in ckpt:
        generator.load_state_dict(ckpt["generator_state_dict"])
    else:
        generator.load_state_dict(ckpt)
    generator.eval()
    print(f"[INFO] Loaded checkpoint from {args.ckpt_path}")

    # ------------------------------------------------------------------
    # 4.2 Prepare test dataset DataLoader
    # ------------------------------------------------------------------
    test_dir = Path(args.data_dir) / "test"
    if not test_dir.exists():
        raise RuntimeError(f"[ERROR] Test directory '{test_dir}' does not exist")
    test_dataset = OptimizedPatchDataset(test_dir, cache_size=500, map_location="cpu")
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        # Pinned host memory only helps (and only avoids a warning) on CUDA
        pin_memory=(device.type == "cuda")
    )

    # Alternative data source (CSV-driven MultimodalDataset), kept for reference:
    # merged_base =  Path(args.data_dir)
    # # r'C:/Users/TongYu/Desktop/processed_merged'
    # csv_path    = os.path.join(merged_base, 'processed_image_info_merged.csv')
    #
    # test_dataset = MultimodalDataset(
    #     csv_path=csv_path,
    #     root_dir=merged_base,
    #     bands_l30=11,
    #     bands_s1=3,
    #     bands_planet=7,
    #     meta_dim=11
    # )
    # test_loader = DataLoader(
    #     test_dataset,
    #     batch_size=args.batch_size,
    #     shuffle=False,
    #     num_workers=args.num_workers,
    #     pin_memory=True)

    # ------------------------------------------------------------------
    # 4.3 Inference loop: run inference on all samples in test_loader and calculate metrics
    # ------------------------------------------------------------------
    results = []
    total_mse, total_rmse, total_psnr, total_ssim, total_sam = [], [], [], [], []
    metrics_by_combo = defaultdict(list)

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="[Inference]")
        for batch_idx, batch in enumerate(pbar):
            # Move data to device
            l30_img = batch["l30_img"].to(device, non_blocking=True)
            # mask_l30 = batch["mask_l30"].to(device, non_blocking=True)
            l30_meta = batch["l30_meta"].to(device, non_blocking=True)
            s1_img   = batch["s1_img"].to(device, non_blocking=True)
            # mask_s1  = batch["mask_s1"].to(device, non_blocking=True)
            s1_meta  = batch["s1_meta"].to(device, non_blocking=True)
            planet_img = batch["planet_img"].to(device, non_blocking=True)
            # mask_planet = batch["mask_planet"].to(device, non_blocking=True)
            planet_meta = batch["planet_meta"].to(device, non_blocking=True)
            s30_gt   = batch["s30_img_gt"].to(device, non_blocking=True)
            mask_s30 = batch["mask_s30"].to(device, non_blocking=True)

            # Forward inference
            fake_s30 = generator(
                l30_img, l30_meta,
                s1_img, s1_meta,
                planet_img, planet_meta
            )
            #********************************
            # Alternative masking (inputs' union intersected with the S30 mask):
            # input_union = (mask_l30 > 0) | (mask_s1 > 0) | (mask_planet > 0)
            # # 2) Intersect with S30 mask
            # joint_mask = input_union & (mask_s30 > 0)
            #
            # # 3) Expand to (B, C_out, H, W) and convert to float
            # mask_ext = joint_mask.float().expand(-1, fake_s30.size(1), -1, -1)
            #
            # # 4) Crop output with intersection mask
            # fake_s30 = fake_s30 * mask_ext
            # s30_gt = s30_gt * mask_ext
            #********************************
            # Only the prediction is multiplied by the S30 mask; the ground truth is left as loaded
            mask_ext = mask_s30.expand(-1, fake_s30.size(1), -1, -1)
            fake_s30 = fake_s30 * mask_ext

            # Convert to numpy array for metric calculation and saving
            fake_np = fake_s30.detach().cpu().numpy()  # shape (B, C=12, H, W)
            gt_np   = s30_gt.detach().cpu().numpy()    # shape (B, C=12, H, W)
            mask_np = mask_s30.detach().cpu().numpy().squeeze(1)  # (B, H, W)

            # Input-availability masks derived from the images themselves
            # (non-zero pixels). Moved to NumPy once per batch so the
            # per-sample .any() checks below do not each sync with the device.
            mask_l30_np = GenerateMask(l30_img, dim=1).squeeze(1).cpu().numpy()
            mask_s1_np = GenerateMask(s1_img, dim=1).squeeze(1).cpu().numpy()
            mask_planet_np = GenerateMask(planet_img, dim=1).squeeze(1).cpu().numpy()

            # mask_l30_np    = mask_l30.detach().cpu().numpy().squeeze(1)
            # mask_s1_np     = mask_s1.detach().cpu().numpy().squeeze(1)
            # mask_planet_np = mask_planet.detach().cpu().numpy().squeeze(1)

            B = fake_np.shape[0]
            for i in range(B):
                pred = fake_np[i]  # (12, H, W)
                gt   = gt_np[i]    # (12, H, W)
                mask = mask_np[i]  # (H, W)

                # Per-pixel MSE / RMSE
                sq_err = (pred - gt) ** 2
                num_valid = np.sum(mask) * pred.shape[0]
                if num_valid > 0:
                    mse_val = np.sum(sq_err * mask[np.newaxis, ...]) / num_valid
                else:
                    mse_val = float("nan")
                rmse_val = np.sqrt(mse_val) if not np.isnan(mse_val) else float("nan")

                # PSNR / SSIM / SAM
                psnr_val = psnr_masked(gt, pred, mask, data_range=1.0)
                ssim_val = ssim_eroded_mask(gt, pred, mask, max_val=1.0)
                sam_val  = sam_masked(gt, pred, mask)

                # Group metrics by input combination
                combo = []
                if mask_l30_np[i].any():    combo.append("L30")
                if mask_s1_np[i].any():     combo.append("S1")
                if mask_planet_np[i].any(): combo.append("Planet")
                combo_key = "+".join(combo) if combo else "None"
                metrics_by_combo[combo_key].append({"mse": mse_val, "rmse": rmse_val, "psnr": psnr_val, "ssim": ssim_val, "sam": sam_val})

                total_mse.append(mse_val)
                total_rmse.append(rmse_val)
                total_psnr.append(psnr_val)
                total_ssim.append(ssim_val)
                total_sam.append(sam_val)

                # Construct sample name: use original .pt filename (without extension)
                full_path = batch["__path__"][i]
                sample_name = Path(full_path).stem

                # Naming used with the CSV-driven dataset, kept for reference:
                # # batch['s30_fname'][i] example: 'F1_5_8_2022_S30.tif'
                # stem = Path(batch['s30_fname'][i]).stem
                # # Remove '_S30' suffix to get 'F1_5_8_2022'
                # if stem.endswith('_S30'):
                #     sample_name = stem[:-len('_S30')]
                # else:
                #     sample_name = stem

                # Save complete 12-band array as TIFF
                pred_save_path = arrays_dir / f"{sample_name}_pred.tif"
                gt_save_path   = arrays_dir / f"{sample_name}_gt.tif"
                save_multiband_tif_rasterio(pred, pred_save_path)
                save_multiband_tif_rasterio(gt,   gt_save_path)

                # Save metric information
                results.append({
                    "sample": sample_name,
                    "mse": float(mse_val),
                    "rmse": float(rmse_val),
                    "psnr": float(psnr_val),
                    "ssim": float(ssim_val),
                    "sam": float(sam_val),
                })

                # Visualize and save: RGB bands use 2,1,0
                if pred.shape[0] >= 3:
                    pred_rgb = np.clip(pred[[2,1,0]].transpose(1, 2, 0), 0, 1)
                    gt_rgb   = np.clip(gt[[2,1,0]].transpose(1, 2, 0), 0, 1)
                else:
                    pred_rgb = np.repeat(pred[0:1].transpose(1, 2, 0), 3, axis=2)
                    gt_rgb   = np.repeat(gt[0:1].transpose(1, 2, 0), 3, axis=2)
                diff = np.abs(pred_rgb - gt_rgb)

                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                axes[0].imshow(gt_rgb)
                axes[0].set_title("Ground Truth (bands 2,1,0)")
                axes[0].axis("off")
                axes[1].imshow(pred_rgb)
                axes[1].set_title(f"Prediction (PSNR: {psnr_val:.2f} dB)")
                axes[1].axis("off")
                # NOTE(review): diff is an (H, W, 3) image, for which imshow ignores cmap
                axes[2].imshow(diff, cmap='hot')
                axes[2].set_title("Absolute Difference")
                axes[2].axis("off")
                plt.tight_layout()

                # Construct visualization save path
                vis_save_path = vis_dir / f"{sample_name}.png"
                plt.savefig(vis_save_path, dpi=200, bbox_inches='tight')
                plt.close(fig)

            # Update progress bar with the running means
            pbar.set_postfix({
                "MSE": f"{np.nanmean(total_mse):.6f}",
                "PSNR": f"{np.nanmean(total_psnr):.2f}",
                "SSIM": f"{np.nanmean(total_ssim):.4f}"
            })

    # ------------------------------------------------------------------
    # 4.4 Aggregate and save metrics
    # ------------------------------------------------------------------
    summary = {
        "num_samples": len(total_mse),
        "mse_mean": float(np.nanmean(total_mse)),
        "mse_std":  float(np.nanstd(total_mse)),
        "rmse_mean": float(np.nanmean(total_rmse)),
        "rmse_std":  float(np.nanstd(total_rmse)),
        "psnr_mean": float(np.nanmean(total_psnr)),
        "psnr_std":  float(np.nanstd(total_psnr)),
        "ssim_mean": float(np.nanmean(total_ssim)),
        "ssim_std":  float(np.nanstd(total_ssim)),
        "sam_mean":  float(np.nanmean(total_sam)),
        "sam_std":   float(np.nanstd(total_sam)),
    }

    # Save all sample metrics to JSON
    results_file = output_dir / "inference_results.json"
    with open(results_file, "w") as f:
        json.dump({
            "summary": summary,
            "details": results
        }, f, indent=2)

    # Performance summary by input combination (computed after the results
    # file has been closed; it does not need that handle)
    combo_summaries = {}
    for combo_key, records in metrics_by_combo.items():
        arr = np.array([[r["psnr"], r["ssim"], r["sam"], r["mse"], r["rmse"]] for r in records])
        combo_summaries[combo_key] = {
            "count": len(records),
            "psnr_mean": float(np.nanmean(arr[:,0])), "psnr_std": float(np.nanstd(arr[:,0])),
            "ssim_mean": float(np.nanmean(arr[:,1])), "ssim_std": float(np.nanstd(arr[:,1])),
            "sam_mean": float(np.nanmean(arr[:,2])),   "sam_std": float(np.nanstd(arr[:,2])),
            "mse_mean": float(np.nanmean(arr[:,3])),   "mse_std": float(np.nanstd(arr[:,3])),
            "rmse_mean": float(np.nanmean(arr[:,4])),  "rmse_std": float(np.nanstd(arr[:,4])),
        }
    full_summary = {"summary": summary, "by_input_combo": combo_summaries}
    with open(output_dir/"inference_summary_by_combo.json", "w") as f2:
        json.dump(full_summary, f2, indent=2)
    print("\n=== Performance by Input Combination ===")
    for combo, stats in combo_summaries.items():
        print(f"{combo}: {stats['count']} samples — PSNR {stats['psnr_mean']:.2f}±{stats['psnr_std']:.2f}, SSIM {stats['ssim_mean']:.3f}±{stats['ssim_std']:.3f}, SAM {stats['sam_mean']:.2f}±{stats['sam_std']:.2f}")

    print(f"\n[INFO] Inference completed. Results saved to {results_file}")
    print(f"       Full 12-band TIFFs saved to {arrays_dir}")
    print(f"       Visualizations saved to {vis_dir}")

    # Print summary information
    print("\nInference Summary:")
    for k, v in summary.items():
        if isinstance(v, float):
            print(f"  {k}: {v:.6f}")
        else:
            print(f"  {k}: {v}")

# -------------------------------------------------------------
# 5. Command line entry & IDE debugging
# -------------------------------------------------------------
def parse_args():
    """Build the inference command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Inference script for trained GAN generator (save 12-band TIFF)"
    )

    # (flag, add_argument keyword options), registered in this order
    option_specs = (
        ("--model", dict(
            type=str, required=True,
            help="Model name (MODEL_CLASS must be defined in Models/{model}.py)")),
        ("--ckpt-path", dict(
            type=str, required=True,
            help="Generator checkpoint path (can be entire model or generator_state_dict)")),
        ("--data-dir", dict(
            type=str, required=True,
            help="Data root directory, should contain 'test' subdirectory (consistent with training script)")),
        ("--output-dir", dict(
            type=str, default="inference_results",
            help="Inference results output directory (contains TIFF and visualizations)")),
        ("--batch-size", dict(type=int, default=16)),
        ("--num-workers", dict(type=int, default=4)),
        ("--device", dict(
            type=str, default="cuda:0",
            help="Device to run inference (e.g., 'cuda:0' or 'cpu')")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()

#infer-run

In [None]:
print("[INFO] Running in IDE debug mode with default settings")


class IDEArgs:
    # Hard-coded stand-in for the argparse namespace: exposes the attributes
    # main() reads (model, ckpt_path, data_dir, output_dir, batch_size,
    # num_workers, device) so the cell can run without a command line.

    # Model name; also used as the results sub-folder name below
    model = "MultimodalUnetGAN"

    # NOTE(review): "****" looks like a placeholder for the epoch number —
    # main() raises FileNotFoundError unless this is an existing file
    ckpt_path = "/content/gdrive/MyDrive/Colabdata/checkpoints/MultimodalUnetGAN_adap_all/best_epoch****.pth"

    # main() reads samples from the "test" sub-directory of this folder
    data_dir = "/content/data/ReconstructionDataset_Final"
    output_dir = f"/content/gdrive/MyDrive/Colabdata/inference_results/{model}"

    device = "cuda:0"
    batch_size = 64
    num_workers = 4


args = IDEArgs()
main(args)