In [None]:
import gc
# Run a garbage-collection pass when this cell executes (presumably to release
# memory still held by objects from earlier runs in this kernel).
gc.collect()

In [None]:
################
import torch
import numpy as np
from PIL import Image
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error
from skimage.metrics import structural_similarity as ssim
import os
import time
import rasterio
from scipy.ndimage import gaussian_filter
import matplotlib
# Use DejaVu Sans for all figure text.
# NOTE(review): DejaVu Sans probably has no CJK glyphs, so the Chinese titles
# used later (e.g. the patch-preview suptitle) may render as empty boxes —
# confirm, or switch to a CJK-capable font.
matplotlib.rcParams['font.family'] = 'DejaVu Sans'

In [None]:
## BI: bicubic-interpolation baseline

In [None]:
import numpy as np
from scipy.ndimage import zoom, gaussian_filter

def upsample_bicubic_and_metrics(
    lr_data, hr_data, scale=None,
    save_tif_path=None, reference_tif_path=None
):
    """Bicubic-upsample ``lr_data`` onto the grid of ``hr_data`` and score it.

    Parameters
    ----------
    lr_data : np.ndarray, 2-D
        Low-resolution raster.
    hr_data : np.ndarray, 2-D
        High-resolution reference; its shape is the output shape.
    scale : float, optional
        Zoom factor applied to both axes. If None, the per-axis factors
        ``H_hr / H_lr`` and ``W_hr / W_lr`` are used. The zoomed result is
        cropped / edge-padded to exactly ``hr_data.shape``.
    save_tif_path : str, optional
        If given, the upsampled array is written as a single-band float32 GeoTIFF.
    reference_tif_path : str, optional
        GeoTIFF whose profile (CRS, transform, ...) is copied for the output.
        Required when ``save_tif_path`` is given.

    Returns
    -------
    up : np.ndarray
        float32 array with shape ``hr_data.shape``.
    metrics : dict
        ``{'MAE', 'RMSE', 'SSIM'}`` over pixels that are finite in both arrays.
        SSIM uses a Gaussian window (sigma=1.5) and the HR value range.

    Raises
    ------
    ValueError
        If ``save_tif_path`` is given without ``reference_tif_path``, or if the
        two arrays share no finite pixel.
    """
    # Fail fast, before doing any work, if the output could not be georeferenced.
    if save_tif_path is not None and reference_tif_path is None:
        raise ValueError("reference_tif_path is required (to copy CRS/transform) when save_tif_path is set.")

    H_hr, W_hr = hr_data.shape
    H_lr, W_lr = lr_data.shape
    if scale is None:
        zoom_rows, zoom_cols = H_hr / H_lr, W_hr / W_lr
    else:
        zoom_rows = zoom_cols = float(scale)

    up = zoom(lr_data, zoom=(zoom_rows, zoom_cols), order=3, mode='reflect', prefilter=True)
    # zoom() rounds the output size, so force it onto the HR grid:
    # crop any excess, then edge-pad any shortfall.
    up = up[:H_hr, :W_hr]
    pad_h, pad_w = H_hr - up.shape[0], W_hr - up.shape[1]
    if pad_h > 0 or pad_w > 0:
        up = np.pad(up, ((0, pad_h), (0, pad_w)), mode='edge')

    up = up.astype(np.float32, copy=False)
    hr = hr_data.astype(np.float32, copy=False)

    valid = np.isfinite(up) & np.isfinite(hr)
    if not valid.any():
        raise ValueError("lr_data and hr_data have no finite pixels in common.")

    diff = up[valid] - hr[valid]
    mae = float(np.mean(np.abs(diff)))
    rmse = float(np.sqrt(np.mean(diff ** 2)))

    # SSIM needs complete 2-D images: fill invalid pixels with each image's mean.
    # They are excluded again when the SSIM map is averaged below.
    up_f = np.where(valid, up, up[valid].mean())
    hr_f = np.where(valid, hr, hr[valid].mean())

    # Dynamic range for the SSIM stabilising constants, taken from the HR image.
    L = float(hr_f.max() - hr_f.min())
    if not np.isfinite(L) or L == 0.0:
        L = 1.0
    C1, C2 = (0.01 * L) ** 2, (0.03 * L) ** 2

    def _blur(a):
        # Gaussian window of the standard SSIM formulation (sigma = 1.5).
        return gaussian_filter(a, sigma=1.5, truncate=3.5)

    mu_x, mu_y = _blur(hr_f), _blur(up_f)
    sigma_x2 = _blur(hr_f * hr_f) - mu_x * mu_x
    sigma_y2 = _blur(up_f * up_f) - mu_y * mu_y
    sigma_xy = _blur(hr_f * up_f) - mu_x * mu_y
    ssim_map = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
        (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x2 + sigma_y2 + C2))
    ssim_val = float(np.mean(ssim_map[valid]))
    metrics = {'MAE': mae, 'RMSE': rmse, 'SSIM': ssim_val}

    # ------------ Optionally save as GeoTIFF ------------
    if save_tif_path is not None:
        import os
        import rasterio

        with rasterio.open(reference_tif_path) as ref:
            profile = ref.profile.copy()
        profile.update({
            'height': up.shape[0],
            'width':  up.shape[1],
            'count':  1,
            'dtype':  'float32',
            'driver': 'GTiff',
            'compress': 'lzw',
        })
        dir_ = os.path.dirname(save_tif_path)
        if dir_:
            os.makedirs(dir_, exist_ok=True)
        with rasterio.open(save_tif_path, 'w', **profile) as dst:
            dst.write(up, 1)
        print(f"Upsampled (bicubic) GeoTIFF saved to: {save_tif_path}")

    return up, metrics



# Load the inputs explicitly so this cell also works on a fresh kernel
# (lr_data / hr_data are not defined by any earlier cell).
import rasterio

with rasterio.open('JAG/lr_map_simulated_4t.tif') as src:
    lr_data = src.read(1).astype(np.float32)
with rasterio.open('JAG/磁异常.tif') as src:
    hr_data = src.read(1).astype(np.float32)

up_bicubic, metrics = upsample_bicubic_and_metrics(
    lr_data, hr_data, scale=4,
    save_tif_path='JAG/upsampled_bicubic_4.tif',
    reference_tif_path='JAG/磁异常.tif'
)
print("Bicubic：", metrics)

In [None]:
import numpy as np
from PIL import Image
from skimage.transform import resize
import matplotlib.pyplot as plt
import rasterio
import os
import time
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# ---- input rasters ----
lr_image_path = 'JAG/lr_map_simulated_4t.tif'   # low-resolution input raster
hr_image_path = 'JAG/磁异常.tif'                 # high-resolution reference raster

# ---- LR -> HR resampling (skimage resize in main) ----
UPSAMPLE_INTERP_ORDER = 3        # spline order (3 = bicubic)
UPSAMPLE_ANTIALIAS = True

# ---- patch extraction / preview ----
PATCH_SIZE, PATCH_STRIDE = 64, 32
NUM_PATCH_SHOW = 4               # number of LR/HR pairs drawn in the preview figure

os.makedirs('JAG', exist_ok=True)

def image_normalization(image_array: np.ndarray):
    """Min-max scale ``image_array`` to [0, 1], ignoring NaNs when finding the bounds.

    Returns ``(normalized, min, max, range)``:
      * ``normalized`` is float32; NaN -> 0, +inf -> 1, -inf -> 0;
      * ``range`` is the divisor actually used, i.e. 1.0 when max - min < 1e-6
        (near-constant data), otherwise max - min.
    """
    lo = np.nanmin(image_array)
    hi = np.nanmax(image_array)
    span = hi - lo
    divisor = 1.0 if span < 1e-6 else span
    scaled = np.nan_to_num((image_array - lo) / divisor, nan=0.0, posinf=1.0, neginf=0.0)
    return scaled.astype(np.float32), float(lo), float(hi), float(divisor)

def extract_patches(hr_img: np.ndarray, lr_img: np.ndarray, patch_size=64, stride=32):
    """Cut aligned square patches from a pair of same-shaped 2-D images.

    Windows of ``patch_size`` x ``patch_size`` start at every multiple of
    ``stride`` that fits completely inside the image (a right/bottom border
    narrower than a stride step is not covered). If the image is smaller than
    one patch, the whole image is returned as a single patch.

    Parameters
    ----------
    hr_img, lr_img : np.ndarray
        2-D arrays of identical shape.
    patch_size : int
        Side length of each square patch (> 0).
    stride : int
        Step between patch origins (> 0).

    Returns
    -------
    (hr_patches, lr_patches) : tuple of lists of np.ndarray
        Copies (not views) of the patches, in row-major window order.

    Raises
    ------
    AssertionError
        If the two images have different shapes.
    ValueError
        If ``patch_size``/``stride`` are not positive, or an image axis is empty.
    """
    assert hr_img.shape == lr_img.shape, (
        f"HR/LR shape mismatch: {hr_img.shape} vs {lr_img.shape}")
    if patch_size <= 0 or stride <= 0:
        raise ValueError(f"patch_size and stride must be positive, got {patch_size}, {stride}")
    h, w = hr_img.shape

    if h < patch_size or w < patch_size:
        # Image smaller than one patch: fall back to a single whole-image patch.
        if h > 0 and w > 0:
            print(f"警告: 图像尺寸({h}x{w}) < patch 尺寸({patch_size}x{patch_size})，使用整图作为单个 patch")
            return [hr_img.copy()], [lr_img.copy()]
        raise ValueError(f"无效图像尺寸: {h}x{w}")

    hr_patches, lr_patches = [], []
    # The range bounds guarantee every window is fully inside the image,
    # so each slice is exactly patch_size x patch_size.
    for i in range(0, h - patch_size + 1, stride):
        for j in range(0, w - patch_size + 1, stride):
            window = (slice(i, i + patch_size), slice(j, j + patch_size))
            hr_patches.append(hr_img[window].copy())
            lr_patches.append(lr_img[window].copy())
    return hr_patches, lr_patches

def visualize_patch_pairs(lr_list, hr_list, num_show=4, save_path='results2/patch_examples.png'):
    """Save a 2-row preview grid: LR patches on the top row, matching HR patches below.

    Parameters
    ----------
    lr_list, hr_list : sequences of 2-D arrays
        Patches to draw; only the first ``num_show`` pairs are shown.
    num_show : int
        Maximum number of pairs to draw.
    save_path : str
        Output image path; its parent directory is created if missing.

    Prints a warning and returns without writing anything when there is nothing to draw.
    """
    n = min(num_show, len(lr_list), len(hr_list))
    if n <= 0:
        print("警告: 没有可视化的 patch，跳过。")
        return

    # savefig does not create directories; without this the default
    # 'results2/' location raises FileNotFoundError.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, axes = plt.subplots(2, n, figsize=(12, 6), squeeze=False)
    for i in range(n):
        for ax, patch, label in ((axes[0, i], lr_list[i], 'LR'), (axes[1, i], hr_list[i], 'HR')):
            ax.imshow(patch, cmap='gray')
            ax.set_title(f'{label} Patch {i+1}')
            ax.axis('off')
    fig.suptitle(f'Patch 示例（上: LR / 下: HR）- 共 {len(hr_list)} 块')
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)  # release the figure; important when called repeatedly in a notebook

def main():
    """Build the x4 training set from the HR/LR rasters and save it to disk.

    Steps: read band 1 of ``hr_image_path`` / ``lr_image_path``; resample the
    LR raster onto the HR grid; min-max normalise both with the HR statistics;
    cut aligned patches (with an automatic smaller-patch fallback when fewer
    than 4 are found); save previews and the dataset.

    Writes:
      * JAG/normalized_high_res.png, JAG/normalized_low_res.png (8-bit previews)
      * results2/patch_examples.png (patch preview)
      * JAG/preprocessed_data_4.npz (patches, full normalised images, HR
        min/max/range, and the patch size / stride actually used)

    Raises
    ------
    RuntimeError
        If no patch could be extracted.
    """
    vis_path = 'results2/patch_examples.png'

    print("加载图像...")
    start_load = time.time()

    with rasterio.open(hr_image_path) as src:
        hr_array = src.read(1).astype(np.float32)
    H_hr, W_hr = hr_array.shape
    print(f"高分辨率图像尺寸: {H_hr} x {W_hr}")

    with rasterio.open(lr_image_path) as src:
        lr_array = src.read(1).astype(np.float32)
    H_lr, W_lr = lr_array.shape
    print(f"低分辨率图像尺寸: {H_lr} x {W_lr}")

    print("上采样低分辨率图像以对齐 HR 尺寸...")
    # preserve_range=True states explicitly that physical values are kept
    # (same setting as the alignment cell below).
    lr_array_upsampled = resize(
        lr_array,
        (H_hr, W_hr),
        order=UPSAMPLE_INTERP_ORDER,
        anti_aliasing=UPSAMPLE_ANTIALIAS,
        preserve_range=True,
    ).astype(np.float32)

    print("归一化到 [0, 1]（以 HR min-max 为准）...")
    hr_array_norm, hr_min, hr_max, hr_range = image_normalization(hr_array)
    # LR is normalised with the *HR* statistics so both share one [0, 1] space.
    lr_array_norm = (lr_array_upsampled - hr_min) / (hr_range if hr_range != 0 else 1.0)
    lr_array_norm = np.clip(np.nan_to_num(lr_array_norm, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0).astype(np.float32)

    Image.fromarray((hr_array_norm * 255).astype(np.uint8)).save('JAG/normalized_high_res.png')
    Image.fromarray((lr_array_norm * 255).astype(np.uint8)).save('JAG/normalized_low_res.png')

    print(f"图像加载与预处理完成，耗时: {time.time() - start_load:.2f} 秒")

    print("提取图像块（patches）...")
    t0 = time.time()
    # Track the parameters actually used so the saved metadata stays truthful
    # when the fallback below kicks in.
    patch_size_used, stride_used = PATCH_SIZE, PATCH_STRIDE
    hr_patches, lr_patches = extract_patches(
        hr_array_norm, lr_array_norm,
        patch_size=patch_size_used,
        stride=stride_used
    )

    if len(hr_patches) < 4:
        print(f"警告: 仅提取到 {len(hr_patches)} 个 patch，尝试自动降低门槛...")
        patch_size_used = min(32, H_hr, W_hr)
        stride_used = max(8, patch_size_used // 4)
        print(f"自动调整为：patch_size={patch_size_used}, stride={stride_used}")
        hr_patches, lr_patches = extract_patches(
            hr_array_norm, lr_array_norm,
            patch_size=patch_size_used,
            stride=stride_used
        )
        print(f"重新提取后共有 {len(hr_patches)} 个 patch")

    print(f"最终提取 {len(hr_patches)} 个 patch，耗时: {time.time() - t0:.2f} 秒")

    if len(hr_patches) == 0:
        raise RuntimeError("未提取到任何 patch，请调小 PATCH_SIZE 或减小 STRIDE。")

    os.makedirs(os.path.dirname(vis_path), exist_ok=True)
    visualize_patch_pairs(lr_patches, hr_patches, num_show=NUM_PATCH_SHOW, save_path=vis_path)

    hr_patches_arr = np.stack(hr_patches, axis=0).astype(np.float32)
    lr_patches_arr = np.stack(lr_patches, axis=0).astype(np.float32)

    np.savez(
        'JAG/preprocessed_data_4.npz',
        hr_patches=hr_patches_arr,   # (N, H, W)
        lr_patches=lr_patches_arr,   # (N, H, W)
        hr_array_norm=hr_array_norm, # (H, W)
        lr_array_norm=lr_array_norm, # (H, W)
        hr_min=hr_min,
        hr_max=hr_max,
        hr_range=hr_range,
        # shape[1] is also correct for the whole-image single-patch fallback
        patch_size_used=int(hr_patches_arr.shape[1]),
        patch_stride_used=int(stride_used),
        upsample_order=int(UPSAMPLE_INTERP_ORDER)
    )

    print("\n=========== 预处理完成 Summary ===========")
    print(f"HR 尺寸: {H_hr} x {W_hr}")
    print(f"LR 原始尺寸: {H_lr} x {W_lr}  -> 上采样到 HR 尺寸, order={UPSAMPLE_INTERP_ORDER}")
    print(f"归一化参数: min={hr_min:.3f}, max={hr_max:.3f}, range={hr_range:.3f}")
    print(f"Patch: size={hr_patches_arr.shape[1]} stride={stride_used}  -> 数量={len(hr_patches_arr)}")
    print(f"可视化: {vis_path}")
    print("数据包: JAG/preprocessed_data_4.npz")
    print("=========================================\n")

if __name__ == "__main__":
    main()

In [None]:
import rasterio, numpy as np
from skimage.transform import resize

hr_path = 'JAG/磁异常.tif'
lr_path = 'JAG/lr_map_simulated_4t.tif'
scale   = 4

with rasterio.open(hr_path) as hr_ds, rasterio.open(lr_path) as lr_ds:
    hr_arr = hr_ds.read(1).astype(np.float32)
    lr_arr = lr_ds.read(1).astype(np.float32)

    Hh, Wh = hr_ds.height, hr_ds.width
    Hl, Wl = lr_ds.height, lr_ds.width

    # LR/HR pixel-size ratio from the affine transforms (a = x pixel size,
    # e = y pixel size); for a consistent pair this should be close to `scale`.
    ratio_x = abs(lr_ds.transform.a / hr_ds.transform.a)
    ratio_y = abs(lr_ds.transform.e / hr_ds.transform.e)
    print(f"HR({Hh}x{Wh})  LR({Hl}x{Wl})  pixel-size ratio ≈ {ratio_x:.3f}, {ratio_y:.3f}")

# Resample LR onto the *full* HR grid first, so each LR sample keeps its
# geographic position. Cropping HR first and then resizing the full-extent LR
# onto the smaller cropped shape would stretch LR and misalign the pair.
lr_up = resize(
    lr_arr, (Hh, Wh),
    order=3, anti_aliasing=False, preserve_range=True
).astype(np.float32)

# Crop both rasters identically so the scale factor divides both dimensions.
H2, W2 = (Hh // scale) * scale, (Wh // scale) * scale
if (H2, W2) != (Hh, Wh):
    hr_arr = hr_arr[:H2, :W2]
    lr_up = lr_up[:H2, :W2]
    print(f"Cropped HR to ({H2},{W2}) so {scale}× divides both dims.")

# Normalise both with the HR min-max so they share one [0, 1] space.
mn, mx = float(np.nanmin(hr_arr)), float(np.nanmax(hr_arr))
rg = mx - mn if mx > mn else 1.0
hr01 = np.clip((hr_arr - mn) / rg, 0, 1).astype(np.float32)
lr01 = np.clip((lr_up - mn) / rg, 0, 1).astype(np.float32)

print('OK for 4× SR' if hr01.shape == lr01.shape else 'mismatch!')

In [None]:
if torch.cuda.is_available(): device = torch.device('cuda')  # use the GPU when present
else: device = torch.device('cpu')

In [None]:
# ---------------- SRCNN (super-resolution CNN) ----------------

In [None]:

import os
import time
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore", category=UserWarning)

# ---- optimisation ----
NUM_EPOCHS      = 1_000
INIT_LR         = 1e-4
BATCH_SIZE      = 64
EARLY_STOP      = 10     # evaluations without score improvement before stopping

# ---- validation ----
EVAL_EVERY      = 1      # evaluate every N epochs
VAL_RATIO       = 0.1    # fraction of patches held out for validation
VAL_MAX_SAMPLES = 512    # (approximate) cap on validation patches per evaluation
DS_FACTOR       = 4      # average-pool factor applied before validation metrics

# ---- runtime / output ----
USE_COMPILE     = True   # try torch.compile on the model
MODEL_PATH      = 'JAG/best_srcnn_fast.pth'
hr_image_path   = 'JAG/磁异常.tif'   # HR raster whose profile georeferences the output GeoTIFF

os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)


# Load the preprocessed patch dataset and the full normalised images.
# NOTE(review): the preprocessing cell in this notebook writes
# 'JAG/preprocessed_data_4.npz', but this cell reads '..._3.npz'; on a fresh
# run that file must come from elsewhere — confirm the intended scale/file.
# NOTE(review): allow_pickle=True is only needed for object arrays; the keys
# saved by the preprocessing cell are numeric, so allow_pickle=False would be
# safer if this file could ever come from an untrusted source.
data = np.load('JAG/preprocessed_data_3.npz', allow_pickle=True)
hr_patches    = data['hr_patches']         # (N, H, W), 0~1
lr_patches    = data['lr_patches']         # (N, H, W), 0~1
hr_array_norm = data['hr_array_norm']      # (H, W),   0~1
lr_array_norm = data['lr_array_norm']      # (H, W),   0~1
# HR normalisation statistics, used later to map predictions back to physical units.
hr_min  = float(data['hr_min'])
hr_max  = float(data['hr_max'])
hr_range = float(data['hr_range'])


class SRCNN(nn.Module):
    """Three-layer SRCNN operating at HR resolution on a pre-upsampled input.

    9x9 feature extraction (1 -> 64) -> 1x1 non-linear mapping (64 -> 32)
    -> 5x5 reconstruction (32 -> 1). "Same" padding keeps H and W unchanged.
    Attribute names are part of the checkpoint format (state_dict keys).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=9 // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=5 // 2)

    def forward(self, x):
        features = self.relu1(self.conv1(x))
        mapped = self.relu2(self.conv2(features))
        reconstruction = self.conv3(mapped)
        return reconstruction

class SuperResolutionDataset(Dataset):
    """Pairs of (LR, HR) patches served as float32 tensors of shape (1, H, W).

    Both arrays are copied to contiguous float32 once at construction, so
    ``__getitem__`` only wraps a slice with ``torch.from_numpy`` (no copy).
    """

    def __init__(self, hr_patches, lr_patches):
        hr_f32 = hr_patches.astype(np.float32)
        lr_f32 = lr_patches.astype(np.float32)
        self.hr = np.ascontiguousarray(hr_f32)
        self.lr = np.ascontiguousarray(lr_f32)

    def __len__(self):
        return len(self.hr)

    def __getitem__(self, idx):
        # [None] adds the leading channel axis: (H, W) -> (1, H, W).
        lr_tensor = torch.from_numpy(self.lr[idx])[None]
        hr_tensor = torch.from_numpy(self.hr[idx])[None]
        return lr_tensor, hr_tensor

# ---- train / validation split ----
SPLIT_SEED = 0  # fixed seed so the split (and hence the reported metrics) is reproducible
N = len(hr_patches)
if N == 0:
    raise RuntimeError("No patches found in the preprocessed dataset.")
idx = np.random.default_rng(SPLIT_SEED).permutation(N)
split = int(N * (1 - VAL_RATIO))
if N == 1:
    # A single patch is used for both training and validation.
    train_idx = val_idx = idx
else:
    # Keep at least one patch on each side; without the clamp, N < 10 with
    # VAL_RATIO = 0.1 gives an empty validation set (all-zero metrics).
    split = min(max(split, 1), N - 1)
    train_idx, val_idx = idx[:split], idx[split:]
train_set = SuperResolutionDataset(hr_patches[train_idx], lr_patches[train_idx])
val_set   = SuperResolutionDataset(hr_patches[val_idx],  lr_patches[val_idx])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_workers = min(8, os.cpu_count() or 4)
pin = torch.cuda.is_available()

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=(num_workers > 0),
    prefetch_factor=4,
    # Dropping the last partial batch is only safe when at least one full
    # batch exists; otherwise the loader would yield nothing at all.
    drop_last=(len(train_set) >= BATCH_SIZE),
)
val_loader = DataLoader(
    val_set,
    batch_size=min(BATCH_SIZE, 32),
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=(num_workers > 0),
    prefetch_factor=4,
)

# Prefer torchmetrics (tensor-based PSNR/SSIM that can run on the GPU). If it
# cannot be imported, flag it and fall back to scikit-image's SSIM; PSNR is
# then computed manually in evaluate_on_val.
use_torchmetrics = True
try:
    from torchmetrics.functional import structural_similarity_index_measure as tm_ssim
    from torchmetrics.functional import peak_signal_noise_ratio as tm_psnr
except Exception:
    use_torchmetrics = False
    from skimage.metrics import structural_similarity as sk_ssim

def _downsample_torch(x, factor=1):
    if factor <= 1: return x
    return torch.nn.functional.avg_pool2d(x, kernel_size=factor, stride=factor, ceil_mode=False)

@torch.no_grad()
def evaluate_on_val(model, loader, eval_max=512, ds_factor=2, device='cuda'):
    """Evaluate ``model`` on validation patches; return (PSNR, SSIM, MAE, RMSE) in [0, 1] space.

    Parameters
    ----------
    model : torch.nn.Module
    loader : DataLoader yielding (lr, hr) batches
    eval_max : int
        Stop once at least this many samples were seen (checked per batch, so
        up to one batch more may be evaluated).
    ds_factor : int
        Average-pool factor applied to prediction and target before scoring
        (1 = full resolution).
    device : str or torch.device

    Returns
    -------
    tuple of float
        Means of the per-batch (torchmetrics) or per-sample (scikit-image)
        metrics; all zeros when the loader yields nothing.
    """
    model.eval()
    n_seen = 0
    psnr_list, ssim_list, mae_list, rmse_list = [], [], [], []

    for lr, hr in loader:
        if n_seen >= eval_max:
            break
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)

        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            sr = model(lr)
        # Under CUDA autocast `sr` is float16 while `hr` is float32; score in
        # float32 so the metric inputs share a dtype and keep full precision.
        sr = sr.float()

        # Guard against off-by-one output sizes before comparing.
        H = min(sr.shape[-2], hr.shape[-2])
        W = min(sr.shape[-1], hr.shape[-1])
        sr = sr[..., :H, :W].clamp_(0, 1)
        hr = hr[..., :H, :W].clamp_(0, 1)

        sr_eval = _downsample_torch(sr, ds_factor)
        hr_eval = _downsample_torch(hr, ds_factor)

        if use_torchmetrics:
            psnr = tm_psnr(sr_eval, hr_eval, data_range=1.0)
            ssim = tm_ssim(sr_eval, hr_eval, data_range=1.0)
            mae  = torch.mean(torch.abs(sr_eval - hr_eval))
            rmse = torch.sqrt(torch.mean((sr_eval - hr_eval) ** 2))

            psnr_list.append(psnr.detach().float().item())
            ssim_list.append(ssim.detach().float().item())
            mae_list.append(mae.detach().float().item())
            rmse_list.append(rmse.detach().float().item())
        else:
            sr_cpu = sr_eval.squeeze(1).detach().cpu().numpy()  # (B, H, W)
            hr_cpu = hr_eval.squeeze(1).detach().cpu().numpy()
            for i in range(sr_cpu.shape[0]):
                diff = hr_cpu[i] - sr_cpu[i]
                mse  = float(np.mean(diff * diff))
                rmse = float(np.sqrt(mse))
                psnr = 20 * np.log10(1.0 / (rmse + 1e-12))
                ssim = sk_ssim(hr_cpu[i], sr_cpu[i], data_range=1.0)
                mae  = float(np.mean(np.abs(diff)))

                psnr_list.append(float(psnr))
                ssim_list.append(float(ssim))
                mae_list.append(mae)
                rmse_list.append(rmse)

        n_seen += lr.size(0)

    if len(psnr_list) == 0:
        return 0.0, 0.0, 0.0, 0.0

    return (
        float(np.mean(psnr_list)),
        float(np.mean(ssim_list)),
        float(np.mean(mae_list)),
        float(np.mean(rmse_list)),
    )


# Let cuDNN pick the fastest conv algorithms for the fixed patch size.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

model = SRCNN().to(device).to(memory_format=torch.channels_last)
if USE_COMPILE:
    try:
        # NOTE: torch.compile typically compiles lazily on the first forward
        # call, so most compilation errors surface in the training loop.
        model = torch.compile(model, mode="reduce-overhead")  # or "max-autotune"
    except Exception as exc:
        # Best effort: continue in eager mode, but report it instead of failing silently.
        print(f">>> torch.compile unavailable, using eager model: {exc!r}")


optimizer = optim.AdamW(model.parameters(), lr=INIT_LR, fused=torch.cuda.is_available())
criterion = nn.L1Loss()
# Loss scaling for mixed precision; a no-op when CUDA is unavailable.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


# Halve the LR when the validation score (higher is better) plateaus.
# NOTE(review): patience=10 equals EARLY_STOP=10 and the LR is only reduced
# after *more than* `patience` bad evaluations, so early stopping will usually
# trigger first — consider a smaller patience.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=10,
    verbose=True,
    min_lr=1e-6
)

best_score = -float('inf')  # best validation score so far (higher is better)
epochs_no_improve = 0

# Model-selection score: W_SSIM * SSIM + W_MAE * (1 - MAE), the same
# combination used by the RDN cell. These weights were previously missing
# here, and `score` was never computed (NameError on the first evaluation).
W_SSIM = 0.7
W_MAE  = 0.3


def _get_lr(optim_):
    """Learning rate of the first parameter group (None if there are no groups)."""
    for pg in optim_.param_groups:
        return pg.get("lr", None)


for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    t0 = time.time()
    running = 0.0
    num_batches = 0

    for lr, hr in train_loader:
        # `lr`/`hr` are the low-/high-resolution batches, not the learning rate.
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            pred = model(lr)
            loss = criterion(pred, hr)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running += loss.item()
        num_batches += 1

    # max(1, ...) avoids ZeroDivisionError if the loader produced no batches.
    avg_loss = running / max(1, num_batches)
    log = f"[Epoch {epoch}/{NUM_EPOCHS}] lr={_get_lr(optimizer):.2e} loss={avg_loss:.6f} time={time.time()-t0:.1f}s"

    if epoch % EVAL_EVERY == 0:
        psnr, ssim, mae, rmse = evaluate_on_val(
            model, val_loader,
            eval_max=VAL_MAX_SAMPLES,
            ds_factor=DS_FACTOR,
            device=device
        )
        score = W_SSIM * ssim + W_MAE * (1.0 - mae)
        log += f" | val_mae={mae:.4f} val_rmse={rmse:.4f} val_ssim={ssim:.4f} score={score:.4f}"
        scheduler.step(score)
        if score > best_score + 1e-6:
            best_score = score
            epochs_no_improve = 0
            torch.save(model.state_dict(), MODEL_PATH)
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOP:
            print(log)
            print(">>> 早停触发（组合分数连续无改善）")
            break

    print(log)

# Make sure the inference step below always has a checkpoint to load.
if not os.path.exists(MODEL_PATH):
    torch.save(model.state_dict(), MODEL_PATH)

@torch.no_grad()
def predict_full(model, lr_image_01, device='cuda'):
    """Run ``model`` on one whole 2-D image and return the prediction as float32 numpy.

    Parameters
    ----------
    model : torch.nn.Module
        Applied to the image reshaped to (1, 1, H, W); its output is assumed
        to have the same leading (1, 1) dimensions.
    lr_image_01 : np.ndarray
        2-D input, expected to be normalised to [0, 1].
    device : str or torch.device

    Returns
    -------
    np.ndarray
        2-D float32 prediction clamped to [0, 1].
    """
    model.eval()
    x = torch.from_numpy(lr_image_01.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(device)
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        y = model(x)
    # Under CUDA autocast the output is float16. Cast back to float32 here,
    # otherwise the later de-normalisation (sr_01 * hr_range + hr_min) runs
    # in float16 and loses precision in the saved GeoTIFF and the metrics.
    y = y.float().squeeze(0).squeeze(0).clamp_(0, 1)
    return y.detach().cpu().numpy()


# Rebuild the network and wrap it the same way as during training before
# loading the best checkpoint.
# NOTE(review): a compiled module's state_dict keys likely carry an
# '_orig_mod.' prefix, so the torch.compile wrapping here has to match the
# one used when MODEL_PATH was written — confirm.
best = SRCNN().to(device)
if USE_COMPILE:
    try:
        best = torch.compile(best, mode="reduce-overhead")
    except Exception:
        pass
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
best.load_state_dict(torch.load(MODEL_PATH, map_location=device))
best.eval()

# Super-resolve the full normalised LR image (values in [0, 1]) in one pass.
sr_01 = predict_full(best, lr_array_norm, device=device)


# ---- Full-image evaluation in physical units ----
# Undo the [0, 1] normalisation with the HR statistics saved by preprocessing.
sr_orig = (sr_01 * hr_range) + hr_min
sr_orig = np.clip(sr_orig, hr_min, hr_max).astype(np.float32)
hr_orig = (hr_array_norm * hr_range) + hr_min

# Non-finite values are zero-filled, so `finite_mask` below is all True; it is
# kept as a guard in case the cleaning step changes.
hr_clean = np.nan_to_num(hr_orig, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float64)
sr_clean = np.nan_to_num(sr_orig, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float64)

finite_mask = np.isfinite(hr_clean) & np.isfinite(sr_clean)
if finite_mask.sum() == 0:
    raise RuntimeError("整图没有可用的像素")

hr_f = hr_clean[finite_mask]
sr_f = sr_clean[finite_mask]

hr_min_f = float(np.min(hr_f))
hr_max_f = float(np.max(hr_f))
data_range_f = hr_max_f - hr_min_f

diff = sr_f - hr_f
mse_val = float(np.mean(diff * diff))
mae_val = float(np.mean(np.abs(diff)))
rmse_val = float(np.sqrt(mse_val))

eps = 1e-12
# PSNR relative to the HR dynamic range (falls back to a unit range for constant HR).
if data_range_f < eps:
    psnr_val = float('inf') if mse_val < eps else 20 * np.log10((1.0) / np.sqrt(mse_val + eps))
else:
    psnr_val = 20 * np.log10(data_range_f / np.sqrt(mse_val + eps))

# SSIM is a windowed *spatial* statistic, so it must be computed on the 2-D
# images. Flattening the pixels first (into a 1-pixel-wide column or an
# arbitrary near-square) scrambles every neighbourhood and makes SSIM meaningless.
if data_range_f < eps:
    ssim_val = 1.0 if mse_val < eps else 0.0
else:
    hr01 = (hr_clean - hr_min_f) / (data_range_f + eps)
    sr01 = (sr_clean - hr_min_f) / (data_range_f + eps)
    if use_torchmetrics and torch.cuda.is_available():
        hr_t = torch.from_numpy(hr01.astype(np.float32))[None, None].to(device)  # (1, 1, H, W)
        sr_t = torch.from_numpy(sr01.astype(np.float32))[None, None].to(device)
        ssim_val = float(tm_ssim(sr_t, hr_t, data_range=1.0).item())
    else:
        from skimage.metrics import structural_similarity as sk_ssim
        ssim_val = float(sk_ssim(hr01, sr01, data_range=1.0))

print("\n===== Final Model Metrics (Full Image) =====")
print(f"PSNR: {psnr_val:.4f} dB")
print(f"SSIM: {ssim_val:.4f}")
print(f"MAE : {mae_val:.6f}")
print(f"MSE : {mse_val:.6f}")
print(f"RMSE: {rmse_val:.6f}")
print("============================================\n")

# Write the reconstruction as a single-band float32 GeoTIFF that reuses the
# HR raster's profile (CRS, transform, size, ...).
import rasterio

out_tif = 'JAG/srcnn_reconstructed_3.tif'

with rasterio.open(hr_image_path) as ref_ds:
    profile = ref_ds.profile
profile.update({'count': 1, 'dtype': 'float32'})

with rasterio.open(out_tif, mode='w', **profile) as out_ds:
    out_ds.write(sr_orig, indexes=1)

print(f">>> 推理完成，GeoTIFF 保存：{out_tif}")

In [None]:
# DRN: residual dense network (ResidualDenseNetwork) super-resolution

In [None]:

import os
import time
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore", category=UserWarning)


# ---- optimisation ----
NUM_EPOCHS      = 600
INIT_LR         = 1e-4      # 5e-5 was also tried
BATCH_SIZE      = 128
EARLY_STOP      = 5         # evaluations without score improvement before stopping

# ---- validation / model selection ----
EVAL_EVERY      = 1
VAL_RATIO       = 0.3       # 0.2 was also tried
VAL_MAX_SAMPLES = 4096
DS_FACTOR       = 1         # 1 = no average-pooling before validation metrics
W_SSIM, W_MAE   = 0.7, 0.3  # score = W_SSIM * SSIM + W_MAE * (1 - MAE)

# ---- model / IO ----
SCALE_IN_MODEL  = 1         # 1: the network itself performs no upsampling
USE_COMPILE     = True
MODEL_PATH      = 'JAG/best_rdn_fast.pth'
hr_image_path   = 'JAG/磁异常.tif'

os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)


# Load the x4 dataset written by the preprocessing cell (main()).
# NOTE(review): allow_pickle=True is not needed for the numeric arrays saved
# there; allow_pickle=False would be safer for files from untrusted sources.
data = np.load('JAG/preprocessed_data_4.npz', allow_pickle=True)
hr_patches    = data['hr_patches']         # (N, H, W), 0~1
lr_patches    = data['lr_patches']         # (N, H, W), 0~1
hr_array_norm = data['hr_array_norm']      # (H, W),   0~1
lr_array_norm = data['lr_array_norm']      # (H, W),   0~1
# HR normalisation statistics, used later to map predictions back to physical units.
hr_min  = float(data['hr_min'])
hr_max  = float(data['hr_max'])
hr_range = float(data['hr_range'])


class ResidualDenseBlock(nn.Module):
    """Five densely-connected 3x3 convolutions with a scaled (x0.2) local residual.

    Conv ``k`` receives the block input concatenated with the outputs of all
    previous convs, so its input width is ``nf + (k - 1) * gc``. The last conv
    maps back to ``nf`` channels and is added to the input.
    Attribute names are part of the checkpoint format (state_dict keys).
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nf + 0 * gc, gc, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(nf + 1 * gc, gc, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, stride=1, padding=1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        dense_inputs = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            dense_inputs.append(self.lrelu(conv(torch.cat(dense_inputs, dim=1))))
        fused = self.conv5(torch.cat(dense_inputs, dim=1))
        return x + 0.2 * fused

class ResidualDenseNetwork(nn.Module):
    """RDN-style network: shallow conv -> ``nb`` residual dense blocks -> global feature fusion.

    The outputs of every dense block are concatenated and fused (1x1 then 3x3
    conv), added to the shallow features, optionally upsampled by ``scale``
    via repeated (nearest x2 + conv + LeakyReLU) stages, and projected to
    ``out_nc`` channels. With ``scale == 1`` the spatial size is unchanged.
    Attribute names are part of the checkpoint format (state_dict keys).
    """

    # number of x2 upsampling stages for each supported scale
    _UP_STAGES = {1: 0, 2: 1, 4: 2, 8: 3}

    def __init__(self, in_nc=1, out_nc=1, nf=64, nb=16, gc=32, scale=1):
        super().__init__()
        assert scale in (1, 2, 4, 8)
        self.scale = scale
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1)
        self.RDBs = nn.ModuleList(ResidualDenseBlock(nf, gc) for _ in range(nb))
        self.gff = nn.Sequential(
            nn.Conv2d(nb * nf, nf, 1, 1, 0),
            nn.Conv2d(nf, nf, 3, 1, 1),
        )
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        shallow = self.conv_first(x)
        fea = shallow.clone()
        block_outputs = []
        for rdb in self.RDBs:
            fea = rdb(fea)
            block_outputs.append(fea)
        fea = shallow + self.gff(torch.cat(block_outputs, 1))

        n_stages = self._UP_STAGES.get(self.scale, 0)
        for upconv in (self.upconv1, self.upconv2, self.upconv3)[:n_stages]:
            fea = self.lrelu(upconv(F.interpolate(fea, scale_factor=2, mode='nearest')))
        return self.conv_last(fea)


class SuperResolutionDataset(Dataset):
    """(LR, HR) patch pairs returned as float32 tensors of shape (1, H, W)."""

    def __init__(self, hr_patches, lr_patches):
        # One up-front float32 + contiguous copy, so indexing never copies.
        self.hr = np.ascontiguousarray(hr_patches.astype(np.float32))
        self.lr = np.ascontiguousarray(lr_patches.astype(np.float32))

    def __len__(self):
        return len(self.hr)

    def __getitem__(self, idx):
        pair = (self.lr[idx], self.hr[idx])
        return tuple(torch.from_numpy(arr)[None] for arr in pair)

N = len(hr_patches)
idx = np.arange(N)
np.random.shuffle(idx)
if N == 0:
    raise RuntimeError("数据集中没有任何 patch")

# Hold out VAL_RATIO of the patches, clamping the split point to [1, N-1] so
# both sides are non-empty; a single patch is shared by train and val.
proposed_split = int(N * (1 - VAL_RATIO))
if N == 1:
    split = 1
    train_idx = val_idx = idx
    print(">>> 警告：仅有 1 个 patch。")
else:
    split = min(max(proposed_split, 1), N - 1)
    train_idx, val_idx = idx[:split], idx[split:]

train_set = SuperResolutionDataset(hr_patches[train_idx], lr_patches[train_idx])
val_set   = SuperResolutionDataset(hr_patches[val_idx],  lr_patches[val_idx])

n_train, n_val = len(train_set), len(val_set)
print(f">>> 数据切分：N={N}, train={n_train}, val={n_val}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_workers = min(8, os.cpu_count() or 4)
pin = torch.cuda.is_available()

# Batch sizes never exceed the split sizes (and are at least 1).
TRAIN_BS = max(1, min(BATCH_SIZE, n_train))
VAL_BS   = max(1, min(BATCH_SIZE, n_val, 32))

_loader_kwargs = dict(
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=(num_workers > 0),
    prefetch_factor=4,
)
train_loader = DataLoader(train_set, batch_size=TRAIN_BS, shuffle=True, drop_last=False, **_loader_kwargs)
val_loader = DataLoader(val_set, batch_size=VAL_BS, shuffle=False, **_loader_kwargs)

# Prefer torchmetrics (tensor-based PSNR/SSIM); if it cannot be imported, set
# the flag to False and fall back to scikit-image's SSIM (PSNR is then
# computed manually in evaluate_on_val).
use_torchmetrics = True
try:
    from torchmetrics.functional import structural_similarity_index_measure as tm_ssim
    from torchmetrics.functional import peak_signal_noise_ratio as tm_psnr
except Exception:
    use_torchmetrics = False
    from skimage.metrics import structural_similarity as sk_ssim

def _downsample_torch(x, factor=1):
    if factor <= 1: return x
    return torch.nn.functional.avg_pool2d(x, kernel_size=factor, stride=factor, ceil_mode=False)

@torch.no_grad()
def evaluate_on_val(model, loader, eval_max=512, ds_factor=2, device='cuda'):
    """Evaluate ``model`` on validation patches; return (PSNR, SSIM, MAE, RMSE) in [0, 1] space.

    ``eval_max`` is checked per batch, so up to one batch more may be scored.
    ``ds_factor`` average-pools prediction and target before scoring
    (1 = full resolution). Returns all zeros when the loader is empty.
    """
    model.eval()
    if len(loader) == 0:
        return 0.0, 0.0, 0.0, 0.0
    n_seen = 0
    psnr_list, ssim_list, mae_list, rmse_list = [], [], [], []
    for lr, hr in loader:
        if n_seen >= eval_max:
            break
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            sr = model(lr)
        # Under CUDA autocast `sr` is float16 while `hr` is float32; score in
        # float32 so the metric inputs share a dtype and keep full precision.
        sr = sr.float()
        # Guard against off-by-one output sizes before comparing.
        H = min(sr.shape[-2], hr.shape[-2])
        W = min(sr.shape[-1], hr.shape[-1])
        sr = sr[..., :H, :W].clamp_(0, 1)
        hr = hr[..., :H, :W].clamp_(0, 1)
        sr_eval = _downsample_torch(sr, ds_factor)
        hr_eval = _downsample_torch(hr, ds_factor)
        if use_torchmetrics:
            psnr = tm_psnr(sr_eval, hr_eval, data_range=1.0)
            ssim = tm_ssim(sr_eval, hr_eval, data_range=1.0)
            mae  = torch.mean(torch.abs(sr_eval - hr_eval))
            rmse = torch.sqrt(torch.mean((sr_eval - hr_eval) ** 2))
            psnr_list.append(psnr.detach().float().item())
            ssim_list.append(ssim.detach().float().item())
            mae_list.append(mae.detach().float().item())
            rmse_list.append(rmse.detach().float().item())
        else:
            sr_cpu = sr_eval.squeeze(1).detach().cpu().numpy()  # (B, H, W)
            hr_cpu = hr_eval.squeeze(1).detach().cpu().numpy()
            for i in range(sr_cpu.shape[0]):
                diff = hr_cpu[i] - sr_cpu[i]
                mse  = float(np.mean(diff * diff))
                rmse = float(np.sqrt(mse))
                psnr = 20 * np.log10(1.0 / (rmse + 1e-12))
                ssim = sk_ssim(hr_cpu[i], sr_cpu[i], data_range=1.0)
                mae  = float(np.mean(np.abs(diff)))
                psnr_list.append(float(psnr))
                ssim_list.append(float(ssim))
                mae_list.append(mae)
                rmse_list.append(rmse)
        n_seen += lr.size(0)
    if len(psnr_list) == 0:
        return 0.0, 0.0, 0.0, 0.0
    return (float(np.mean(psnr_list)), float(np.mean(ssim_list)),
            float(np.mean(mae_list)),  float(np.mean(rmse_list)))

# Let cuDNN pick the fastest conv algorithms for the fixed patch size.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

model = ResidualDenseNetwork(in_nc=1, out_nc=1, nf=64, nb=16, gc=32, scale=SCALE_IN_MODEL).to(device).to(memory_format=torch.channels_last)
if USE_COMPILE:
    try:
        # NOTE: torch.compile typically compiles lazily on the first forward
        # call, so most compilation errors surface in the training loop.
        model = torch.compile(model, mode="reduce-overhead")
    except Exception as exc:
        # Best effort: continue in eager mode, but report it instead of failing silently.
        print(f">>> torch.compile unavailable, using eager model: {exc!r}")

optimizer = optim.AdamW(
    model.parameters(),
    lr=INIT_LR,
    betas=(0.9, 0.98),
    weight_decay=1e-3,
    fused=torch.cuda.is_available()
)

criterion = nn.L1Loss()
# Loss scaling for mixed precision; a no-op when CUDA is unavailable.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Halve the LR when the validation score (higher is better) plateaus.
# NOTE(review): patience=5 equals EARLY_STOP=5 and the LR is only reduced after
# *more than* `patience` bad evaluations, so early stopping will usually
# trigger first — consider a smaller patience.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, verbose=True, min_lr=1e-6
)

best_score = -float('inf')  # best validation score so far (higher is better)
epochs_no_improve = 0


def _get_lr(optim_):
    """Learning rate of the first parameter group (None if there are no groups)."""
    return next((pg.get("lr", None) for pg in optim_.param_groups), None)

# Training loop. Each epoch is one pass over train_loader (mixed precision when
# CUDA is available). Every EVAL_EVERY epochs a validation pass:
#   * computes score = W_SSIM * SSIM + W_MAE * (1 - MAE) (higher is better),
#   * feeds the score to the ReduceLROnPlateau scheduler,
#   * checkpoints the model to MODEL_PATH when the score improves by > 1e-6,
#   * stops early after EARLY_STOP evaluations without improvement.
for epoch in range(1, NUM_EPOCHS + 1):
    model.train(); t0 = time.time(); running = 0.0
    num_batches = 0
    for lr, hr in train_loader:
        # `lr`/`hr` are the low-/high-resolution batches, not the learning rate.
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            pred = model(lr); loss = criterion(pred, hr)
        # Scaled backward / step / scale update for mixed-precision training.
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        running += loss.item(); num_batches += 1
    # max(1, ...) guards against an empty loader.
    avg_loss = running / max(1, num_batches)
    log = f"[Epoch {epoch}/{NUM_EPOCHS}] lr={_get_lr(optimizer):.2e} loss={avg_loss:.6f} time={time.time()-t0:.1f}s"

    if epoch % EVAL_EVERY == 0:
        psnr, ssim, mae, rmse = evaluate_on_val(model, val_loader, eval_max=VAL_MAX_SAMPLES, ds_factor=DS_FACTOR, device=device)
        score = W_SSIM * ssim + W_MAE * (1.0 - mae)
        log += f" | val_mae={mae:.4f} val_rmse={rmse:.4f} val_ssim={ssim:.4f} score={score:.4f}"
        scheduler.step(score)
        if score > best_score + 1e-6:
            best_score = score; epochs_no_improve = 0
            torch.save(model.state_dict(), MODEL_PATH)
        else:
            epochs_no_improve += 1
        # NOTE(review): with scheduler patience (5) equal to EARLY_STOP (5),
        # the LR is unlikely to be reduced before this early stop fires.
        if epochs_no_improve >= EARLY_STOP:
            print(log); print(">>> 早停触发"); break
    print(log)

# Make sure the inference step below always has a checkpoint to load.
if not os.path.exists(MODEL_PATH):
    torch.save(model.state_dict(), MODEL_PATH)

@torch.no_grad()
def predict_full(model, lr_image_01, device='cuda'):
    """Run `model` on one whole 2-D image and return the prediction as float32.

    Args:
        model: network taking a (1, 1, H, W) tensor; presumably returns (1, 1, H', W').
        lr_image_01: 2-D array, presumably normalised to [0, 1] (TODO confirm at caller).
        device: device the model lives on.

    Returns:
        2-D float32 numpy array clamped to [0, 1].
    """
    model.eval()
    x = torch.from_numpy(lr_image_01.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(device)
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        y = model(x)
    # Under CUDA autocast the output is float16. Cast back to float32 so the caller's
    # de-normalisation (sr * range + min) does not run in half precision.
    y = y.float().squeeze(0).squeeze(0).clamp_(0, 1)
    return y.detach().cpu().numpy()

# Reload the best checkpoint saved during training and super-resolve the full LR image.
best = ResidualDenseNetwork(in_nc=1, out_nc=1, nf=64, nb=16, gc=32, scale=SCALE_IN_MODEL).to(device)
if USE_COMPILE:
    # Best-effort wrap, presumably mirroring how the training model was wrapped so the
    # saved state-dict keys match (TODO confirm).
    try: best = torch.compile(best, mode="reduce-overhead")
    except Exception: pass
best.load_state_dict(torch.load(MODEL_PATH, map_location=device))
best.eval()

sr_01 = predict_full(best, lr_array_norm, device=device)

# Undo the HR min-max normalisation so the metrics are in the original data units.
sr_orig = (sr_01 * hr_range) + hr_min
sr_orig = np.clip(sr_orig, hr_min, hr_max).astype(np.float32)
hr_orig = (hr_array_norm * hr_range) + hr_min

# Compare on the common extent in case the SR output and the HR grid differ in size.
H = min(hr_orig.shape[0], sr_orig.shape[0]); W = min(hr_orig.shape[1], sr_orig.shape[1])
hr_clean = np.asarray(hr_orig[:H, :W], dtype=np.float64)
sr_clean = np.asarray(sr_orig[:H, :W], dtype=np.float64)
# Build the validity mask BEFORE any NaN/Inf replacement. Otherwise every pixel looks
# finite, and invalid pixels would enter the metrics as zeros.
finite_mask = np.isfinite(hr_clean) & np.isfinite(sr_clean)
if finite_mask.sum() == 0: raise RuntimeError("整图没有可用的有限像素用于评估。")
hr_f = hr_clean[finite_mask]; sr_f = sr_clean[finite_mask]
hr_min_f = float(np.min(hr_f)); hr_max_f = float(np.max(hr_f))
data_range_f = hr_max_f - hr_min_f
diff = sr_f - hr_f
mse_val = float(np.mean(diff * diff))
mae_val = float(np.mean(np.abs(diff)))
rmse_val = float(np.sqrt(mse_val))
eps = 1e-12
if data_range_f < eps:
    psnr_val = float('inf') if mse_val < eps else 20 * np.log10(1.0 / np.sqrt(mse_val + eps))
else:
    psnr_val = 20 * np.log10(data_range_f / np.sqrt(mse_val + eps))

# SSIM is a windowed (spatial) metric, so it must run on the 2-D images, not on a
# flattened pixel vector. Invalid pixels are set to 0 in BOTH images so they add no error.
if data_range_f < eps:
    ssim_val = 1.0 if mse_val < eps else 0.0
else:
    with np.errstate(invalid='ignore'):
        hr01 = np.where(finite_mask, (hr_clean - hr_min_f) / (data_range_f + eps), 0.0)
        sr01 = np.where(finite_mask, (sr_clean - hr_min_f) / (data_range_f + eps), 0.0)
    if use_torchmetrics and torch.cuda.is_available():
        hr_t = torch.from_numpy(np.ascontiguousarray(hr01[None, None])).to(device)
        sr_t = torch.from_numpy(np.ascontiguousarray(sr01[None, None])).to(device)
        # No autocast: the Gaussian filtering inside SSIM should stay in full precision.
        ssim_val = float(tm_ssim(sr_t, hr_t, data_range=1.0).item())
    else:
        from skimage.metrics import structural_similarity as sk_ssim
        ssim_val = float(sk_ssim(hr01, sr01, data_range=1.0))

print("\n===== Final Model Metrics (Full Image) =====")
print(f"PSNR: {psnr_val:.4f} dB")
print(f"SSIM: {ssim_val:.4f}")
print(f"MAE : {mae_val:.6f}")
print(f"MSE : {mse_val:.6f}")
print(f"RMSE: {rmse_val:.6f}")
print("============================================\n")

import rasterio
from rasterio.transform import Affine

def _round_to_16_leq(x, default=256):
    if x < 16: return 16
    return max(16, int(x // 16) * 16)

def _safe_remove(path):
    try:
        if os.path.exists(path):
            os.remove(path)
    except Exception:
        pass

def _cleanup_sidecars(path):
    """Best-effort removal of ``path`` and the side-car files that may sit next to it
    (``.aux.xml``, ``.ovr``, ``.msk`` and their combinations)."""
    suffixes = ("", ".aux.xml", ".ovr", ".msk", ".msk.ovr", ".msk.aux.xml")
    for suffix in suffixes:
        _safe_remove(path + suffix)

def save_geotiff_no_hole(out_path, arr, ref_path, scale_in_model=1):
    """Write a 2-D array as a single-band float32 GeoTIFF with a validity mask.

    CRS and transform are copied from ``ref_path``. When ``scale_in_model != 1`` the
    pixel size (transform ``a``/``e``) is divided by it. Non-finite pixels are filled
    with the median of the finite pixels (0.0 if none are finite), so the stored values
    contain no NaN/Inf "holes", and they are flagged invalid in the mask.

    Writing is tried three times, cleaning partial output between attempts:
    tiled + LZW, striped + LZW, and finally a minimal uncompressed profile.
    The written file is then re-opened and its mask coverage is printed.

    Args:
        out_path: destination file path.
        arr: 2-D array to write.
        ref_path: raster whose CRS/transform are reused.
        scale_in_model: factor by which the output grid is finer than the reference.
    """
    with rasterio.open(ref_path) as src:
        crs = src.crs
        t = src.transform

    arr = np.asarray(arr)
    H, W = int(arr.shape[0]), int(arr.shape[1])

    # Fill non-finite pixels with the median of the valid ones and record them in a
    # GDAL-style mask (0 = invalid, 255 = valid).
    valid = np.isfinite(arr)
    med = float(np.median(arr[valid])) if valid.any() else 0.0
    clean = np.where(valid, arr, med).astype(np.float32)
    valid_mask = valid.astype(np.uint8) * 255

    transform = t if scale_in_model == 1 else Affine(t.a / scale_in_model, t.b, t.c,
                                                     t.d, t.e / scale_in_model, t.f)
    profile = {
        "driver": "GTiff",
        "height": H,
        "width": W,
        "count": 1,
        "dtype": "float32",
        "compress": "lzw",
        "crs": crs,
        "transform": transform,
        "photometric": "MINISBLACK",
    }

    # Attempt 1: tiled. Tile edges must be multiples of 16.
    bx = _round_to_16_leq(min(256, W))
    by = _round_to_16_leq(min(256, H))
    written = False
    try:
        with rasterio.Env(GDAL_TIFF_INTERNAL_MASK="YES", BIGTIFF="IF_SAFER"):
            with rasterio.open(out_path, "w", **profile, tiled=True, blockxsize=bx, blockysize=by) as dst:
                dst.write(clean, 1)
                dst.write_mask(valid_mask)
        print(f"GTiff 写出成功（tiled=True, block={bx}x{by}）。")
        written = True
    except Exception as e:
        print(">>> 警告：tiled 写出失败，将清理残留并回退条带。错误：", repr(e))
        _cleanup_sidecars(out_path)

    if not written:
        # Attempt 2: striped (non-tiled) with explicit rows per strip.
        rps = min(512, H)
        try:
            with rasterio.Env(GDAL_TIFF_INTERNAL_MASK="YES", BIGTIFF="IF_SAFER"):
                with rasterio.open(out_path, "w", **profile, tiled=False, blockysize=rps) as dst:
                    dst.write(clean, 1)
                    dst.write_mask(valid_mask)
            print(f"GTiff 写出成功（tiled=False, rows_per_strip={rps}）。")
        except Exception as e2:
            # Attempt 3: clean up again and fall back to a minimal, uncompressed profile.
            print(">>> 二次写出仍失败，尝试无压缩最小参数。错误：", repr(e2))
            _cleanup_sidecars(out_path)
            minimal = {
                "driver": "GTiff",
                "height": H,
                "width": W,
                "count": 1,
                "dtype": "float32",
                "crs": crs,
                "transform": transform,
            }
            with rasterio.open(out_path, "w", **minimal) as dst:
                dst.write(clean, 1)
                dst.write_mask(valid_mask)

    # Self-check: re-open the written file and report how much of it the mask marks valid.
    with rasterio.open(out_path) as chk:
        m = chk.read_masks(1)
    print(f"自检：mask 有效像素 {int(np.count_nonzero(m))}/{m.size}")

# Save: first report non-finite pixel counts and compare reference vs predicted shape,
# then write the prediction as a GeoTIFF georeferenced from the HR reference raster.
print(">>> 保存前检查：NaN/Inf =",
      int(np.count_nonzero(np.isnan(sr_orig))),
      int(np.count_nonzero(np.isinf(sr_orig))))
with rasterio.open(hr_image_path) as _src_chk:
    print("    ref(H,W) =", _src_chk.height, _src_chk.width,
          " pred(H,W) =", sr_orig.shape)

out_tif = 'JAG/rdn_reconstructed_4.tif'
save_geotiff_no_hole(
    out_tif,
    sr_orig,
    ref_path=hr_image_path,
    scale_in_model=SCALE_IN_MODEL,
)
print(f">>> 推理完成，GeoTIFF 保存：{out_tif}")

In [None]:
import os, time, warnings, gc
import numpy as np
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
import matplotlib.pyplot as plt

import torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore", category=UserWarning)
os.makedirs('results_x4', exist_ok=True)   # preview figures are written here

# ---- Data ----
HR_TIF = 'JAG/磁异常.tif'                 # high-resolution (magnetic anomaly) reference raster
LR_TIF = 'JAG/lr_map_simulated_4t.tif'     # simulated low-resolution input raster
SCALE  = 4                                 # HR pixels per LR pixel along each axis

# Patch geometry is given on the HR grid and converted to LR pixels (at least 1).
HR_PATCH_SIZE   = 64
HR_PATCH_STRIDE = 32
PATCH_LR        = max(1, round(HR_PATCH_SIZE / SCALE))
STRIDE_LR       = max(1, round(HR_PATCH_STRIDE / SCALE))
NUM_PATCH_SHOW  = 4                        # patch pairs drawn in the preview figure

# ---- Training ----
# NOTE(review): BATCH_SIZE, INIT_LR and EVAL_EVERY are not set in this cell; main()
# relies on values defined by an earlier cell.
NUM_EPOCHS      = 300
VAL_RATIO       = 0.1                      # fraction of patches held out for validation
VAL_MAX_SAMPLES = 512                      # cap on validation samples per evaluation
USE_COMPILE     = False
EARLY_STOP      = 8                        # evaluations without improvement before stopping
MODEL_PATH_G    = 'JAG/best_esrgan_x4_gen.pth'

def _read_first_band(path):
    """Read band 1 of the raster at ``path`` as float32 and return ``(array, profile)``."""
    with rasterio.open(path) as ds:
        band, profile = ds.read(1).astype(np.float32), ds.profile
    return band, profile

def _reproject_match(src_arr, src_prof, dst_prof):
    """Bilinearly resample ``src_arr`` onto the grid described by ``dst_prof``.

    Both profiles must provide a CRS (ValueError otherwise). The destination buffer is
    zero-initialised float32 of shape (dst height, dst width); no nodata is configured.
    """
    both_have_crs = src_prof.get('crs') is not None and dst_prof.get('crs') is not None
    if not both_have_crs:
        raise ValueError("reproject 需要 src/dst 都有 CRS")

    out = np.zeros((dst_prof['height'], dst_prof['width']), dtype=np.float32)
    reproject(
        source=src_arr,
        destination=out,
        src_transform=src_prof['transform'],
        src_crs=src_prof['crs'],
        dst_transform=dst_prof['transform'],
        dst_crs=dst_prof['crs'],
        resampling=Resampling.bilinear,
    )
    return out

def _ensure_ratio_scale(hr_arr, hr_prof, lr_arr, lr_prof, scale=SCALE):
    """Align the LR raster to the HR grid so the LR:HR pixel ratio is exactly 1:``scale``.

    When both profiles carry a CRS:
      1. LR is reprojected into the HR CRS if the CRSs differ.
      2. LR is resampled onto a grid with the HR origin and ``scale`` × the HR pixel size.
    Without CRS information both steps are skipped (a message is printed).
    Finally both arrays are cropped so that the HR shape is exactly ``scale`` × the LR shape.

    Returns:
        ``(hr_arr, hr_prof, lr_arr, lr_prof)``. ``hr_prof`` is returned unchanged even
        though ``hr_arr`` may have been cropped.
    """
    reprojectable = hr_prof.get('crs') is not None and lr_prof.get('crs') is not None

    if reprojectable and lr_prof['crs'] != hr_prof['crs']:
        lr_bounds = rasterio.transform.array_bounds(lr_prof['height'], lr_prof['width'], lr_prof['transform'])
        warped_transform, warped_w, warped_h = calculate_default_transform(
            lr_prof['crs'], hr_prof['crs'], lr_prof['width'], lr_prof['height'], *lr_bounds
        )
        warped_prof = lr_prof.copy()
        warped_prof.update(crs=hr_prof['crs'], transform=warped_transform, width=warped_w, height=warped_h)
        lr_arr, lr_prof = _reproject_match(lr_arr, lr_prof, warped_prof), warped_prof

    # Target LR grid: HR origin, pixel size multiplied by `scale` (rotation terms assumed 0).
    hr_t = hr_prof['transform']
    coarse_prof = hr_prof.copy()
    coarse_prof.update(
        transform=rasterio.Affine(hr_t.a * scale, 0, hr_t.c,
                                  0, hr_t.e * scale, hr_t.f),
        width=int(np.ceil(hr_prof['width'] / scale)),
        height=int(np.ceil(hr_prof['height'] / scale)),
    )

    if reprojectable:
        lr_arr, lr_prof = _reproject_match(lr_arr, lr_prof, coarse_prof), coarse_prof
    else:
        print("跳过")

    rows_lr = min(lr_arr.shape[0], hr_arr.shape[0] // scale)
    cols_lr = min(lr_arr.shape[1], hr_arr.shape[1] // scale)
    return (hr_arr[:rows_lr * scale, :cols_lr * scale], hr_prof,
            lr_arr[:rows_lr, :cols_lr], lr_prof)

def image_normalization(arr):
    """Min-max scale ``arr`` to [0, 1], ignoring NaNs when finding the extremes.

    A range below 1e-6 is replaced by 1.0 to avoid dividing by ~0. NaN outputs become 0,
    +Inf becomes 1 and -Inf becomes 0.

    Returns:
        ``(scaled_float32, min, max, range)``, with the statistics as Python floats.
    """
    lo = np.nanmin(arr)
    hi = np.nanmax(arr)
    span = hi - lo
    span = 1.0 if span < 1e-6 else span
    scaled = np.nan_to_num((arr - lo) / span, nan=0.0, posinf=1.0, neginf=0.0)
    return scaled.astype(np.float32), float(lo), float(hi), float(span)

def extract_pairs(hr01, lr01, scale=SCALE, patch_lr=PATCH_LR, stride_lr=STRIDE_LR):
    """Cut aligned LR/HR patch pairs with a sliding window defined on the LR grid.

    An LR window at (row, col) of size ``patch_lr`` maps to the HR window at
    (row*scale, col*scale) of size ``patch_lr*scale``. Pairs whose HR window would run
    past the HR image are skipped.

    Returns:
        ``(hr_patches, lr_patches)`` as float32 arrays, in row-major window order.
    """
    rows, cols = lr01.shape
    hr_size = patch_lr * scale
    lr_out, hr_out = [], []
    for top in range(0, rows - patch_lr + 1, stride_lr):
        for left in range(0, cols - patch_lr + 1, stride_lr):
            hr_top, hr_left = top * scale, left * scale
            hr_tile = hr01[hr_top:hr_top + hr_size, hr_left:hr_left + hr_size]
            if hr_tile.shape != (hr_size, hr_size):
                continue
            lr_out.append(lr01[top:top + patch_lr, left:left + patch_lr].copy())
            hr_out.append(hr_tile.copy())
    return np.asarray(hr_out, np.float32), np.asarray(lr_out, np.float32)

def preprocess_all(hr_path, lr_path, scale, patch_lr, stride_lr, num_show=4):
    """Load, align, normalise and patchify an HR/LR raster pair.

    Both rasters are normalised with the HR min/max, so they share one [0, 1] space.
    A preview of the first ``num_show`` pairs is saved to
    ``results_x4/patch_examples_x4.png``.

    Returns:
        ``(hr_p, lr_p, hr01, lr01, hr_min, hr_max, hr_range, hr_prof, lr_prof)``.

    Raises:
        RuntimeError: if no patch pair could be extracted.
    """
    t0 = time.time()
    print(">> 读取 HR/LR ...")
    hr_arr, hr_prof = _read_first_band(hr_path)
    lr_arr, lr_prof = _read_first_band(lr_path)

    print(f">> 对齐网格（确保 LR:HR = 1:{scale}） ...")
    hr_arr, hr_prof, lr_arr, lr_prof = _ensure_ratio_scale(hr_arr, hr_prof, lr_arr, lr_prof, scale)

    print(">> 归一化（以 HR 的 min-max 为准） ...")
    hr01, hr_min, hr_max, hr_range = image_normalization(hr_arr)
    # np.clip keeps NaN, so map non-finite LR values the same way image_normalization
    # maps HR ones. Otherwise NaNs would reach the patches and make the loss NaN.
    lr01 = (lr_arr - hr_min) / (hr_range if hr_range != 0 else 1.0)
    lr01 = np.clip(np.nan_to_num(lr01, nan=0.0, posinf=1.0, neginf=0.0), 0, 1).astype(np.float32)

    print(f">> 切 patch：HR规格≈{HR_PATCH_SIZE} / {HR_PATCH_STRIDE}  →  LR规格={patch_lr} / {stride_lr}")
    hr_p, lr_p = extract_pairs(hr01, lr01, scale, patch_lr, stride_lr)
    if hr_p.size == 0:
        raise RuntimeError("未切出任何 patch，检查尺寸/stride/patch。")
    print(f">> 样本数：{len(hr_p)}")

    # Preview: LR patches on the top row, matching HR patches below.
    nshow = min(num_show, len(hr_p))
    if nshow > 0:
        fig = plt.figure(figsize=(12, 4))
        for k in range(nshow):
            plt.subplot(2, nshow, k+1)
            plt.imshow(lr_p[k], cmap='gray'); plt.axis('off'); plt.title(f'LR#{k+1}')
            plt.subplot(2, nshow, k+1+nshow)
            plt.imshow(hr_p[k], cmap='gray'); plt.axis('off'); plt.title(f'HR#{k+1}')
        fig.tight_layout(); fig.savefig('results_x4/patch_examples_x4.png', dpi=200); plt.close(fig)

    print(f">> 预处理完成，用时 {time.time()-t0:.1f}s")
    return (hr_p, lr_p, hr01, lr01, hr_min, hr_max, hr_range, hr_prof, lr_prof)

class ResidualDenseBlock(nn.Module):
    """ESRGAN-style five-conv residual dense block.

    Convs c1..c4 each see the concatenation of the block input and all earlier conv
    outputs and emit ``gc`` channels (LeakyReLU 0.2). c5 fuses everything back to ``nf``
    channels, and the result is added to the input with a 0.2 residual scale.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        # Submodule names c1..c5 / a are state_dict keys, registered in this order.
        for k in range(4):
            setattr(self, f"c{k + 1}", nn.Conv2d(nf + k * gc, gc, 3, 1, 1))
        self.c5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1)
        self.a = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        features = [x]
        for conv in (self.c1, self.c2, self.c3, self.c4):
            features.append(self.a(conv(torch.cat(features, 1))))
        return x + 0.2 * self.c5(torch.cat(features, 1))

class RRDB(nn.Module):
    """Residual-in-residual dense block: three ResidualDenseBlocks in sequence, with an
    outer skip connection scaled by 0.2."""

    def __init__(self, nf=64, gc=32):
        super().__init__()
        # b1, b2, b3 are state_dict keys; they are created and registered in this order.
        self.b1, self.b2, self.b3 = (ResidualDenseBlock(nf, gc) for _ in range(3))

    def forward(self, x):
        out = x
        for block in (self.b1, self.b2, self.b3):
            out = block(out)
        return x + 0.2 * out

class ESRGANGenerator(nn.Module):
    """RRDB-based ESRGAN generator with nearest-neighbour + conv upsampling.

    head conv -> ``nb`` RRDBs -> trunk conv (+ long skip) -> upsampling stages -> tail conv.
    The output is not clamped or activated.
    """

    # Nearest-neighbour factors per supported scale. Stage i uses conv up{i+1} + LeakyReLU.
    _UP_STEPS = {2: (2,), 3: (3,), 4: (2, 2), 8: (2, 2, 2)}

    def __init__(self, in_nc=1, out_nc=1, nf=64, nb=8, gc=32, scale=SCALE):
        super().__init__()
        assert scale in (2,3,4,8), "scale 仅支持 2/3/4/8"
        self.scale = scale
        # Module creation and registration order is unchanged: it defines both the
        # state_dict keys and the random initialisation sequence.
        self.head = nn.Conv2d(in_nc, nf, 3, 1, 1)
        self.trunk = nn.Sequential(*(RRDB(nf, gc) for _ in range(nb)))
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1)
        self.up1, self.up2, self.up3 = (nn.Conv2d(nf, nf, 3, 1, 1) for _ in range(3))
        self.tail = nn.Conv2d(nf, out_nc, 3, 1, 1)
        self.a = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        fea = self.head(x)
        fea = fea + self.trunk_conv(self.trunk(fea))
        stages = zip((self.up1, self.up2, self.up3), self._UP_STEPS.get(self.scale, ()))
        for conv, factor in stages:
            fea = self.a(conv(F.interpolate(fea, scale_factor=factor, mode='nearest')))
        return self.tail(fea)


class PairSet(Dataset):
    """In-memory dataset of (LR, HR) patch pairs.

    Each item is ``(lr, hr)``, with both as float32 tensors with a leading channel axis of 1.
    """

    def __init__(self, hr_p, lr_p):
        # Keep contiguous float32 copies so items can be wrapped without further copies.
        self.hr, self.lr = (np.ascontiguousarray(a.astype(np.float32)) for a in (hr_p, lr_p))

    def __len__(self):
        return self.hr.shape[0]

    def __getitem__(self, i):
        return torch.from_numpy(self.lr[i])[None], torch.from_numpy(self.hr[i])[None]

# Prefer torchmetrics (tensor-side SSIM/PSNR). If it cannot be imported, fall back to
# scikit-image SSIM and numpy-based metrics.
try:
    from torchmetrics.functional import structural_similarity_index_measure as tm_ssim
    from torchmetrics.functional import peak_signal_noise_ratio as tm_psnr
except Exception:
    use_torchmetrics = False
    from skimage.metrics import structural_similarity as sk_ssim
else:
    use_torchmetrics = True

@torch.no_grad()
def evaluate_on_val(gen, loader, eval_max=512, device='cuda'):
    """Mean PSNR / SSIM / MAE / RMSE of ``gen`` on the validation loader.

    Predictions and targets are cropped to their common size and clamped to [0, 1], and
    metrics use data_range=1.0. With torchmetrics, metrics are computed per batch;
    otherwise per image with numpy/skimage. Evaluation stops once ``eval_max`` samples
    have been seen. The check runs before each batch, so the last batch may overshoot.

    Returns:
        ``(psnr, ssim, mae, rmse)`` means, or four zeros if nothing was evaluated.
    """
    gen.eval()
    psnrs, ssims, maes, rmses = [], [], [], []
    seen = 0
    amp_on = torch.cuda.is_available()
    for lr, hr in loader:
        if seen >= eval_max:
            break
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        with torch.cuda.amp.autocast(enabled=amp_on):
            sr = gen(lr)
        h = min(sr.shape[-2], hr.shape[-2])
        w = min(sr.shape[-1], hr.shape[-1])
        sr = sr[..., :h, :w].clamp_(0, 1)
        hr = hr[..., :h, :w].clamp_(0, 1)

        if use_torchmetrics:
            err = sr - hr
            batch_psnr = tm_psnr(sr, hr, data_range=1.0)
            batch_ssim = tm_ssim(sr, hr, data_range=1.0)
            batch_mae = torch.mean(torch.abs(err))
            batch_rmse = torch.sqrt(torch.mean(err ** 2))
            for bucket, value in ((psnrs, batch_psnr), (ssims, batch_ssim),
                                  (maes, batch_mae), (rmses, batch_rmse)):
                bucket.append(value.detach().float().item())
        else:
            sr_imgs = sr.squeeze(1).detach().cpu().numpy()
            hr_imgs = hr.squeeze(1).detach().cpu().numpy()
            for s_img, h_img in zip(sr_imgs, hr_imgs):
                d = h_img - s_img
                mse = float((d * d).mean())
                psnrs.append(float(20 * np.log10(1.0 / (np.sqrt(mse) + 1e-12))))
                ssims.append(float(sk_ssim(h_img, s_img, data_range=1.0)))
                maes.append(float(np.abs(d).mean()))
                rmses.append(float(np.sqrt(mse)))
        seen += lr.size(0)

    if not psnrs:
        return 0.0, 0.0, 0.0, 0.0
    return tuple(float(np.mean(vals)) for vals in (psnrs, ssims, maes, rmses))


def main():
    """Train the x4 ESRGAN generator on HR/LR raster patches, then evaluate and export it.

    Pipeline:
      1. Preprocess (align, normalise, patchify) via ``preprocess_all``.
      2. Random train/validation split.
      3. L1 training with AMP and a plateau LR schedule. Validation score is
         PSNR + 10*SSIM. The best weights go to ``MODEL_PATH_G``, training stops early
         after ``EARLY_STOP`` non-improving evaluations, and Ctrl-C saves an
         ``_interrupt`` checkpoint.
      4. Reload the best weights, super-resolve the full LR image, print full-image
         metrics in original units and write a GeoTIFF.

    NOTE(review): BATCH_SIZE, INIT_LR and EVAL_EVERY are not defined in this cell and
    must already exist from an earlier cell.
    """
    # Free memory left over from earlier notebook cells.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    (hr_patches, lr_patches, hr01_full, lr01_full,
     hr_min, hr_max, hr_range, hr_prof, lr_prof) = preprocess_all(
        HR_TIF, LR_TIF, SCALE, PATCH_LR, STRIDE_LR, NUM_PATCH_SHOW
    )

    # Random (unseeded) train/validation split of the patch pairs.
    N = len(hr_patches)
    perm = np.random.permutation(N)
    split = int(N*(1-VAL_RATIO))
    train_idx, val_idx = perm[:split], perm[split:]
    train_set = PairSet(hr_patches[train_idx], lr_patches[train_idx])
    val_set   = PairSet(hr_patches[val_idx],  lr_patches[val_idx])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_workers = min(8, os.cpu_count() or 4)
    pin = torch.cuda.is_available()

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=num_workers, pin_memory=pin, persistent_workers=(num_workers>0),
                              prefetch_factor=4, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=min(BATCH_SIZE,32), shuffle=False,
                            num_workers=num_workers, pin_memory=pin, persistent_workers=(num_workers>0),
                            prefetch_factor=4)

    G = ESRGANGenerator(scale=SCALE, nb=8).to(device).to(memory_format=torch.channels_last)
    if USE_COMPILE:
        try: G=torch.compile(G, mode="reduce-overhead")
        except Exception: pass

    opt = optim.AdamW(G.parameters(), lr=INIT_LR, fused=torch.cuda.is_available())
    crit = nn.L1Loss()
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    # The scheduler follows the validation score (higher is better).
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.5, patience=3, verbose=True, min_lr=1e-6)

    best_score=-1e9; epochs_no_improve=0
    def _getlr(optim_):
        for pg in optim_.param_groups: return pg.get("lr", None)

    try:
        for epoch in range(1, NUM_EPOCHS+1):
            G.train(); t0=time.time(); run=0.0; nstep=0
            for lr,hr in train_loader:
                lr=lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
                hr=hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
                opt.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                    sr=G(lr); loss=crit(sr,hr)
                scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
                run+=float(loss.item()); nstep+=1
            avg_loss=run/max(1,nstep)
            log=f"[Epoch {epoch}/{NUM_EPOCHS}] lr={_getlr(opt):.2e} loss={avg_loss:.6f} time={time.time()-t0:.1f}s"

            if epoch % EVAL_EVERY==0:
                psnr, ssim, mae, rmse = evaluate_on_val(G, val_loader, eval_max=VAL_MAX_SAMPLES, device=device)
                log += f" | val_mae={mae:.4f} val_rmse={rmse:.4f} val_ssim={ssim:.4f} val_psnr={psnr:.3f}"
                # Model-selection score: PSNR plus SSIM weighted x10 (to bring it to a comparable magnitude).
                score = psnr + 10.0*ssim
                sched.step(score)
                if score > best_score + 1e-6:
                    best_score = score; epochs_no_improve=0
                    torch.save(G.state_dict(), MODEL_PATH_G)
                else:
                    epochs_no_improve += 1
                if epochs_no_improve >= EARLY_STOP:
                    print(log); print(">>> 早停触发"); break
            print(log)

    except KeyboardInterrupt:
        ckpt = MODEL_PATH_G.replace('.pth', '_interrupt.pth')
        torch.save(G.state_dict(), ckpt)
        print(f"\n>>> 收到 KeyboardInterrupt，已保存中断权重到 {ckpt}")

    # Make sure a checkpoint exists even if no evaluation ever ran.
    if not os.path.exists(MODEL_PATH_G):
        torch.save(G.state_dict(), MODEL_PATH_G)

    @torch.no_grad()
    def predict_full(model, lr01, device='cuda'):
        """Run `model` on one 2-D image; return a float32 array clamped to [0, 1]."""
        model.eval()
        x=torch.from_numpy(lr01.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(device)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            y=model(x)
        # Cast the (possibly fp16 autocast) output to float32 before de-normalisation.
        return y.float().squeeze(0).squeeze(0).clamp_(0,1).detach().cpu().numpy()

    best = ESRGANGenerator(scale=SCALE, nb=8).to(device)
    if USE_COMPILE:
        try: best=torch.compile(best, mode="reduce-overhead")
        except Exception: pass
    best.load_state_dict(torch.load(MODEL_PATH_G, map_location=device))
    best.eval()

    sr01 = predict_full(best, lr01_full, device=device)

    # Back to original data units.
    sr_orig = (sr01 * hr_range) + hr_min
    sr_orig = np.clip(sr_orig, hr_min, hr_max).astype(np.float32)
    hr_orig = (hr01_full * hr_range) + hr_min

    hr_f = np.nan_to_num(hr_orig, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float64)
    sr_f = np.nan_to_num(sr_orig, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float64)
    H = min(hr_f.shape[0], sr_f.shape[0]); W = min(hr_f.shape[1], sr_f.shape[1])
    hr_f, sr_f = hr_f[:H,:W], sr_f[:H,:W]

    diff = sr_f - hr_f
    mse_val = float(np.mean(diff*diff))
    mae_val = float(np.mean(np.abs(diff)))
    rmse_val = float(np.sqrt(mse_val))
    hr_min_f = float(np.min(hr_f)); hr_max_f = float(np.max(hr_f))
    rg = hr_max_f - hr_min_f; eps=1e-12
    psnr_val = (20*np.log10(rg/np.sqrt(mse_val+eps))) if rg>=eps else (float('inf') if mse_val<eps else 20*np.log10(1.0/np.sqrt(mse_val+eps)))

    try:
        from skimage.metrics import structural_similarity as sk_ssim
        hr01m = (hr_f - hr_min_f)/(rg+eps); sr01m=(sr_f - hr_min_f)/(rg+eps)
        ssim_val = float(sk_ssim(hr01m, sr01m, data_range=1.0))
    except Exception:
        ssim_val = float('nan')

    print("\n===== Final Model Metrics (Full Image) =====")
    print(f"PSNR: {psnr_val:.4f} dB")
    print(f"SSIM: {ssim_val:.4f}")
    print(f"MAE : {mae_val:.6f}")
    print(f"MSE : {mse_val:.6f}")
    print(f"RMSE: {rmse_val:.6f}")
    print("============================================\n")

    # Output profile: start from the HR profile and place the SR raster on the LR grid
    # with pixel size divided by SCALE.
    lr_t = lr_prof['transform']
    new_prof = dict(hr_prof)
    new_prof.update(
        driver='GTiff',
        height=sr_orig.shape[0],
        width =sr_orig.shape[1],
        dtype='float32',
        count=1,
        transform=rasterio.Affine(
            lr_t.a/SCALE, 0, lr_t.c,
            0, lr_t.e/SCALE, lr_t.f
        )
    )
    out_tif = 'JAG/esrgan_x4_reconstructed.tif'
    with rasterio.open(out_tif, 'w', **new_prof) as dst:
        dst.write(sr_orig, 1)
    print(f">>> 推理完成，GeoTIFF 保存：{out_tif}")

if __name__ == "__main__":
    main()

In [None]:
#SRFORMER

In [None]:
import os
import time
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore", category=UserWarning)

# ---- SRFormer training hyper-parameters ----
NUM_EPOCHS  = 1000
INIT_LR     = 2e-4
BATCH_SIZE  = 64
EVAL_EVERY  = 1        # validate every epoch
DS_FACTOR   = 1        # avg-pool factor applied before computing val metrics (1 = none)
EARLY_STOP  = 20       # evaluations without improvement before stopping
USE_COMPILE = True
MODEL_PATH  = 'JAG/best_srformer_fast.pth'

# Weights of the validation score W_SSIM * ssim + W_MAE * (1 - mae).
W_SSIM, W_MAE = 0.7, 0.3

# HR normalisation statistics from the preprocessing archive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# code. Only use this with trusted files. hr_patches / lr_patches / VAL_RATIO used below
# are not loaded here; they must come from an earlier cell.
data = np.load('JAG/preprocessed_data_4.npz', allow_pickle=True)
hr_min, hr_max = (float(data[key]) for key in ('hr_min', 'hr_max'))

class DropPath(nn.Module):
    """Stochastic depth: in training mode, drop an entire residual branch per sample.

    Each sample is kept with probability ``1 - drop_prob``, and kept samples are rescaled
    by ``1 / (1 - drop_prob)`` so the expectation is unchanged. The module is the identity
    in eval mode or when ``drop_prob == 0``. With ``drop_prob >= 1`` every sample is dropped.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # All samples are dropped. Dividing by keep_prob would yield inf * 0 = NaN.
            return torch.zeros_like(x)
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one draw per sample, broadcast over the rest
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        return x.div(keep_prob) * random_tensor

class PermutedSelfAttention(nn.Module):
    """Axial self-attention over a (B, C, H, W) feature map.

    Two independent multi-head attentions run along the height axis (one sequence per
    column) and along the width axis (one sequence per row). Their outputs are averaged,
    then passed through a 1x1 conv projection and dropout. Output shape equals input shape.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.dim = dim

        def _axis_attention():
            return nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads,
                                         dropout=attn_drop, batch_first=True)

        self.h_attn = _axis_attention()   # sequences along H
        self.w_attn = _axis_attention()   # sequences along W
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, C, H, W = x.shape

        # Height axis: (B, C, H, W) -> (B*W, H, C) -> attention -> back to (B, C, H, W).
        cols = x.permute(0, 3, 2, 1).reshape(B * W, H, C).contiguous()
        cols_out = self.h_attn(cols, cols, cols, need_weights=False)[0]
        cols_out = cols_out.reshape(B, W, H, C).permute(0, 3, 2, 1).contiguous()

        # Width axis: (B, C, H, W) -> (B*H, W, C) -> attention -> back to (B, C, H, W).
        rows = x.permute(0, 2, 3, 1).reshape(B * H, W, C).contiguous()
        rows_out = self.w_attn(rows, rows, rows, need_weights=False)[0]
        rows_out = rows_out.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()

        return self.proj_drop(self.proj(0.5 * (cols_out + rows_out)))

class ConvMLP(nn.Module):
    """Convolutional feed-forward block.

    Layers in order: 1x1 expand to ``int(dim * expansion)`` channels, 3x3 depthwise conv,
    GELU, 1x1 project back to ``dim``, dropout.
    """

    def __init__(self, dim, expansion=2.0, drop=0.0):
        super().__init__()
        width = int(dim * expansion)
        self.pw1 = nn.Conv2d(dim, width, kernel_size=1, bias=True)
        self.dw = nn.Conv2d(width, width, kernel_size=3, padding=1, groups=width, bias=True)
        self.act = nn.GELU()
        self.pw2 = nn.Conv2d(width, dim, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        for layer in (self.pw1, self.dw, self.act, self.pw2, self.drop):
            x = layer(x)
        return x

class LayerNorm2d(nn.Module):
    """LayerNorm over the channel axis of a (B, C, H, W) tensor.

    Each spatial position is normalised independently over its C channels. The output is
    contiguous in (B, C, H, W) layout.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(num_channels, eps=eps)

    def forward(self, x):
        # nn.LayerNorm normalises the trailing dim, so move channels last and back.
        normalised = self.norm(x.permute(0, 2, 3, 1))
        return normalised.permute(0, 3, 1, 2).contiguous()

class SRFormerBlock(nn.Module):
    """Pre-norm transformer block on (B, C, H, W) maps.

    The block computes ``x + DropPath(Attn(LN(x)))`` followed by
    ``x + DropPath(MLP(LN(x)))``. DropPath is replaced by Identity when
    ``drop_path == 0``.
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=2.0, attn_drop=0.0, proj_drop=0.0, drop_path=0.0):
        super().__init__()

        def _stochastic_depth():
            return DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Submodules are registered in the same order as before (state_dict / init order).
        self.norm1 = LayerNorm2d(dim)
        self.attn = PermutedSelfAttention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        self.drop_path1 = _stochastic_depth()
        self.norm2 = LayerNorm2d(dim)
        self.mlp = ConvMLP(dim, expansion=mlp_ratio, drop=proj_drop)
        self.drop_path2 = _stochastic_depth()

    def forward(self, x):
        attn_branch = self.drop_path1(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.drop_path2(self.mlp(self.norm2(x)))
        return x + mlp_branch

class SRFormer(nn.Module):
    """Same-resolution SRFormer-style restoration network.

    The network applies a 3x3 embedding conv, ``depth`` SRFormerBlocks and a 3x3 output
    conv, and adds a global residual to the input. There is no upsampling, so the output
    has the input's spatial size. This presumably expects LR inputs already resampled to
    the HR grid (TODO confirm against the data pipeline). Drop-path rates grow linearly
    from 0 to ``drop_path_rate`` across blocks.
    """

    def __init__(self,
                 in_ch=1,
                 embed_dim=64,
                 depth=8,
                 num_heads=8,
                 mlp_ratio=2.0,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 drop_path_rate=0.0):
        super().__init__()
        self.in_proj = nn.Conv2d(in_ch, embed_dim, kernel_size=3, padding=1)
        if depth > 1:
            rates = torch.linspace(0, drop_path_rate, steps=depth).tolist()
        else:
            rates = [drop_path_rate]
        self.blocks = nn.Sequential(*(
            SRFormerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                drop_path=rates[i],
            )
            for i in range(depth)
        ))
        self.out_proj = nn.Conv2d(embed_dim, in_ch, kernel_size=3, padding=1)

    def forward(self, x):
        # Global residual: the network predicts a correction that is added to its input.
        return x + self.out_proj(self.blocks(self.in_proj(x)))

class SuperResolutionDataset(Dataset):
    """In-memory (LR, HR) patch-pair dataset.

    Items are float32 tensors shaped (1, H, W), returned as ``(lr, hr)``.
    """

    def __init__(self, hr_patches, lr_patches):
        # Contiguous float32 storage lets __getitem__ wrap slices with torch.from_numpy.
        self.hr = np.ascontiguousarray(hr_patches.astype(np.float32))
        self.lr = np.ascontiguousarray(lr_patches.astype(np.float32))

    def __len__(self):
        return self.hr.shape[0]

    def __getitem__(self, idx):
        lr_t, hr_t = (torch.from_numpy(store[idx]).unsqueeze(0) for store in (self.lr, self.hr))
        return lr_t, hr_t

# Random (unseeded) train/validation split of the patch pairs. np.random.permutation(N)
# consumes the global RNG exactly like shuffling np.arange(N).
N = len(hr_patches)
idx = np.random.permutation(N)
split = int(N * (1 - VAL_RATIO))
train_idx, val_idx = idx[:split], idx[split:]
train_set, val_set = (SuperResolutionDataset(hr_patches[sel], lr_patches[sel])
                      for sel in (train_idx, val_idx))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_workers = min(8, os.cpu_count() or 4)
pin = torch.cuda.is_available()

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=(num_workers > 0),
    prefetch_factor=4,
)
val_loader = DataLoader(
    val_set,
    batch_size=min(BATCH_SIZE, 32),
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=(num_workers > 0),
    prefetch_factor=4,
)

# Prefer torchmetrics (tensor-side SSIM/PSNR). If it cannot be imported, fall back to
# scikit-image SSIM and numpy-based metrics.
try:
    from torchmetrics.functional import structural_similarity_index_measure as tm_ssim
    from torchmetrics.functional import peak_signal_noise_ratio as tm_psnr
except Exception:
    use_torchmetrics = False
    from skimage.metrics import structural_similarity as sk_ssim
else:
    use_torchmetrics = True

def _downsample_torch(x, factor=1):
    if factor <= 1: return x
    return torch.nn.functional.avg_pool2d(x, kernel_size=factor, stride=factor, ceil_mode=False)

@torch.no_grad()
def evaluate_on_val(model, loader, eval_max=512, ds_factor=2, device='cuda'):
    """Score ``model`` on (a prefix of) a validation loader.

    Each batch goes through the model (CUDA autocast when available). It is cropped to the
    common spatial size and cast to float32. It is then clamped to [0, 1] -- the same clamp
    ``predict_full`` applies at inference. Finally it is average-pooled by ``ds_factor`` and
    scored.

    Args:
        model: network mapping LR batches to SR batches.
        loader: yields ``(lr, hr)`` tensor pairs; presumably (B, 1, H, W) in [0, 1] as built
            by SuperResolutionDataset -- confirm for other loaders.
        eval_max: stop once at least this many samples have been scored. The check happens
            per batch, so up to one extra batch may be included.
        ds_factor: average-pooling factor applied before scoring; <= 1 disables it.
        device: device to move batches to.

    Returns:
        ``(psnr, ssim, mae, rmse)`` as floats. With torchmetrics each is a mean over batches;
        with the scikit-image fallback it is a mean over images. All zeros if nothing was scored.
    """
    model.eval()
    n_seen = 0
    psnr_list, ssim_list, mae_list, rmse_list = [], [], [], []
    for lr, hr in loader:
        if n_seen >= eval_max:
            break
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
            sr = model(lr)
        # Score what inference actually produces: common size, float32, clamped to [0, 1].
        H = min(sr.shape[-2], hr.shape[-2])
        W = min(sr.shape[-1], hr.shape[-1])
        sr = sr[..., :H, :W].float().clamp(0, 1)
        hr = hr[..., :H, :W].float().clamp(0, 1)
        sr_eval = _downsample_torch(sr, ds_factor)
        hr_eval = _downsample_torch(hr, ds_factor)
        if use_torchmetrics:
            psnr = tm_psnr(sr_eval, hr_eval, data_range=1.0)
            ssim = tm_ssim(sr_eval, hr_eval, data_range=1.0)
            mae  = torch.mean(torch.abs(sr_eval - hr_eval))
            mse  = torch.mean((sr_eval - hr_eval) ** 2)
            rmse = torch.sqrt(mse)
            psnr_list.append(psnr.detach().float().item())
            ssim_list.append(ssim.detach().float().item())
            mae_list.append(mae.detach().float().item())
            rmse_list.append(rmse.detach().float().item())
        else:
            sr_cpu = sr_eval.squeeze(1).detach().cpu().numpy()  # (B, H, W)
            hr_cpu = hr_eval.squeeze(1).detach().cpu().numpy()
            for i in range(sr_cpu.shape[0]):
                diff = hr_cpu[i] - sr_cpu[i]
                mse  = float(np.mean(diff * diff))
                rmse = float(np.sqrt(mse))
                psnr = 20 * np.log10(1.0 / (np.sqrt(mse) + 1e-12))
                ssim = sk_ssim(hr_cpu[i], sr_cpu[i], data_range=1.0)
                mae  = float(np.mean(np.abs(diff)))
                psnr_list.append(psnr); ssim_list.append(ssim); mae_list.append(mae); rmse_list.append(rmse)
        n_seen += lr.size(0)
    if len(psnr_list) == 0:
        return 0.0, 0.0, 0.0, 0.0
    return (float(np.mean(psnr_list)), float(np.mean(ssim_list)),
            float(np.mean(mae_list)), float(np.mean(rmse_list)))

# Backend flags, model, optimiser, AMP scaler, LR scheduler and early-stopping state.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

model = SRFormer(in_ch=1, embed_dim=64, depth=8, num_heads=8,
                 mlp_ratio=2.0, attn_drop=0.0, proj_drop=0.0,
                 drop_path_rate=0.0).to(device).to(memory_format=torch.channels_last)
if USE_COMPILE:
    try:
        model = torch.compile(model, mode="reduce-overhead")
    except Exception as exc:
        # Best-effort: keep going in eager mode, but do not hide the failure.
        # NOTE: torch.compile is typically lazy, so many errors only show on the first forward.
        print(f">>> torch.compile failed, falling back to eager mode: {exc!r}")

optimizer = optim.AdamW(model.parameters(), lr=INIT_LR, fused=torch.cuda.is_available())
criterion = nn.L1Loss()
# torch.amp.GradScaler('cuda', ...) replaces the deprecated torch.cuda.amp.GradScaler.
scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

# mode='max': the monitored value is the validation score (higher is better).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=10,
    verbose=True,
    min_lr=1e-6
)

best_score = -float('inf')
epochs_no_improve = 0

def _get_lr(optim_):
    for pg in optim_.param_groups:
        return pg.get("lr", None)

# Training loop.
# - Each epoch makes one AMP pass over train_loader.
# - Every EVAL_EVERY epochs a validation pass computes a combined score. That score drives the
#   LR scheduler and checkpointing (best score -> MODEL_PATH), and stops training after
#   EARLY_STOP evaluations without improvement.
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    t0 = time.time()
    running = 0.0
    for lr, hr in train_loader:
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        optimizer.zero_grad(set_to_none=True)
        # torch.amp.autocast('cuda', ...) replaces the deprecated torch.cuda.amp.autocast.
        with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
            pred = model(lr)
            loss = criterion(pred, hr)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running += loss.item()
    # Guard against an empty loader (drop_last=True with fewer samples than BATCH_SIZE).
    avg_loss = running / max(1, len(train_loader))
    log = f"[Epoch {epoch}/{NUM_EPOCHS}] lr={_get_lr(optimizer):.2e} loss={avg_loss:.6f} time={time.time()-t0:.1f}s"

    if epoch % EVAL_EVERY == 0:
        # val_* names keep the module-level `ssim` (skimage function) from being overwritten.
        val_psnr, val_ssim, val_mae, val_rmse = evaluate_on_val(
            model, val_loader,
            eval_max=VAL_MAX_SAMPLES,
            ds_factor=DS_FACTOR,
            device=device
        )

        score = W_SSIM * val_ssim + W_MAE * (1.0 - val_mae)
        log += f" | val_mae={val_mae:.4f} val_rmse={val_rmse:.4f} val_ssim={val_ssim:.4f} score={score:.4f}"

        scheduler.step(score)

        if score > best_score + 1e-6:
            best_score = score
            epochs_no_improve = 0
            torch.save(model.state_dict(), MODEL_PATH)
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= EARLY_STOP:
            print(log)
            print(">>> 早停触发（组合分数连续无改善）")
            break
    print(log)

# Ensure some checkpoint exists even if no evaluation ever ran.
# NOTE: a checkpoint left over from an earlier run at MODEL_PATH also satisfies this check.
if not os.path.exists(MODEL_PATH):
    torch.save(model.state_dict(), MODEL_PATH)

@torch.no_grad()
def predict_full(model, lr_image_01, device='cuda'):
    """Run ``model`` on one whole normalised LR image and return the result clamped to [0, 1].

    Args:
        model: network mapping a (1, 1, H, W) tensor to a (1, 1, H', W') tensor.
        lr_image_01: 2-D array; expected to be normalised to [0, 1].
        device: device the model lives on.

    Returns:
        2-D float32 numpy array with values in [0, 1].
    """
    model.eval()
    x = torch.from_numpy(lr_image_01.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(device)
    # Mixed precision only applies on CUDA; torch.amp.autocast replaces the deprecated
    # torch.cuda.amp.autocast.
    use_amp = torch.device(device).type == 'cuda' and torch.cuda.is_available()
    with torch.amp.autocast('cuda', enabled=use_amp):
        y = model(x)
    # Cast to float32 in case the network's last op ran in half precision.
    return y.squeeze(0).squeeze(0).float().clamp_(0, 1).detach().cpu().numpy()


# ====== Final evaluation of the best checkpoint on the full image ======
best = SRFormer(in_ch=1, embed_dim=64, depth=8, num_heads=8,
                mlp_ratio=2.0, attn_drop=0.0, proj_drop=0.0,
                drop_path_rate=0.0).to(device)
if USE_COMPILE:
    # Wrapped the same way as the training model so the checkpoint's key names line up.
    try:
        best = torch.compile(best, mode="reduce-overhead")
    except Exception as exc:
        print(f">>> torch.compile failed, falling back to eager mode: {exc!r}")
best.load_state_dict(torch.load(MODEL_PATH, map_location=device))
best.eval()

sr_01 = predict_full(best, lr_array_norm, device=device)

# Undo the [0, 1] normalisation (x * hr_range + hr_min).
sr_orig = (sr_01 * hr_range) + hr_min
sr_orig = np.clip(sr_orig, hr_min, hr_max).astype(np.float32)
hr_orig = (hr_array_norm * hr_range) + hr_min

# Identify valid pixels BEFORE nan_to_num: afterwards every pixel is finite, so NaN (no-data)
# pixels would silently be scored as 0.
finite_mask = np.isfinite(hr_orig) & np.isfinite(sr_orig)
if not finite_mask.any():
    raise RuntimeError("整图没有可用的有限像素")

hr_clean = np.nan_to_num(hr_orig, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float64)
sr_clean = np.nan_to_num(sr_orig, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float64)

hr_f = hr_clean[finite_mask]
sr_f = sr_clean[finite_mask]

hr_min_f = float(np.min(hr_f))
hr_max_f = float(np.max(hr_f))
data_range_f = hr_max_f - hr_min_f

diff = sr_f - hr_f
mse_val = float(np.mean(diff * diff))
mae_val = float(np.mean(np.abs(diff)))
rmse_val = float(np.sqrt(mse_val))

eps = 1e-12
# PSNR peak = dynamic range of the valid HR pixels.
if data_range_f < eps:
    psnr_val = float('inf') if mse_val < eps else 20 * np.log10((1.0) / np.sqrt(mse_val + eps))
else:
    psnr_val = 20 * np.log10(data_range_f / np.sqrt(mse_val + eps))

if data_range_f < eps:
    ssim_val = 1.0 if mse_val < eps else 0.0
else:
    # SSIM is a local-window statistic, so it needs the real 2-D layout. Flattening the valid
    # pixels and re-folding them into a square scrambles neighbourhoods. Instead, invalid
    # pixels are filled with the mean of the valid ones, and both images are scaled with the
    # HR range to [0, 1].
    hr_img = np.where(finite_mask, hr_clean, hr_f.mean())
    sr_img = np.where(finite_mask, sr_clean, sr_f.mean())
    hr01 = (hr_img - hr_min_f) / (data_range_f + eps)
    sr01 = (sr_img - hr_min_f) / (data_range_f + eps)
    if use_torchmetrics and torch.cuda.is_available():
        from torchmetrics.functional import structural_similarity_index_measure as tm_ssim
        hr_t = torch.from_numpy(hr01[None, None]).to(device)  # (1, 1, H, W)
        sr_t = torch.from_numpy(sr01[None, None]).to(device)
        ssim_val = float(tm_ssim(sr_t, hr_t, data_range=1.0).item())
    else:
        from skimage.metrics import structural_similarity as sk_ssim
        ssim_val = float(sk_ssim(hr01, sr01, data_range=1.0))

print("\n===== Final Model Metrics (Full Image) =====")
print(f"PSNR: {psnr_val:.4f} dB")
print(f"SSIM: {ssim_val:.4f}")
print(f"MAE : {mae_val:.6f}")
print(f"MSE : {mse_val:.6f}")
print(f"RMSE: {rmse_val:.6f}")
print("============================================\n")

# ====== Save as GeoTIFF, reusing the HR raster's georeferencing profile ======
import rasterio
with rasterio.open(hr_image_path) as src:
    profile = src.profile
profile.update(count=1, dtype='float32')

out_tif = 'JAG/srformer_reconstructed_4.tif'
with rasterio.open(out_tif, 'w', **profile) as dst:
    dst.write(sr_orig, 1)

print(f">>> 推理完成，GeoTIFF 保存：{out_tif}")

In [None]:
#####FGSA

In [None]:
import os
import time
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from einops import rearrange
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

warnings.filterwarnings("ignore", category=UserWarning)
os.makedirs('JAG', exist_ok=True)

# ---- training / evaluation configuration ----
NUM_EPOCHS      = 1000
INIT_LR         = 1e-4
BATCH_SIZE      = 32
EVAL_EVERY      = 1         # validate every N epochs
VAL_RATIO       = 0.1       # fraction of patches held out for validation
VAL_MAX_SAMPLES = 512       # cap on validation samples scored per evaluation
DS_FACTOR       = 4         # average-pool factor applied before validation metrics
EARLY_STOP      = 6         # evaluations without improvement before stopping
USE_COMPILE     = False

MODEL_PATH      = 'JAG/best_srcnn_fma_small_anom_aggressive.pth'

hr_image_path   = 'JAG/磁异常.tif'

# The training loop's checkpoint score uses W_SSIM and W_MAE. This cell does not define
# them; presumably they are set in an earlier cell. Fail here with a clear message rather than
# with a NameError after the first full training epoch.
_missing_weights = [name for name in ("W_SSIM", "W_MAE") if name not in globals()]
if _missing_weights:
    raise NameError(f"{', '.join(_missing_weights)} must be defined before running this cell")

# Use NpzFile as a context manager so the archive's file handle is closed. Each data[...]
# read loads the whole array into memory, so the arrays stay valid after the block.
# NOTE(review): allow_pickle=True lets a crafted file execute code -- only load trusted archives.
with np.load('JAG/preprocessed_data_4.npz', allow_pickle=True) as data:
    hr_patches    = data['hr_patches']         # (N, H, W), 0~1
    lr_patches    = data['lr_patches']         # (N, H, W), 0~1
    hr_array_norm = data['hr_array_norm']      # (H, W),   0~1
    lr_array_norm = data['lr_array_norm']      # (H, W),   0~1
    hr_min  = float(data['hr_min'])
    hr_max  = float(data['hr_max'])
    hr_range = float(data['hr_range'])

def window_partition(x, window_size):
    """Split an NCHW tensor into non-overlapping ``window_size`` x ``window_size`` tiles.

    H and W are first reflect-padded on the bottom/right up to multiples of ``window_size``.

    Returns:
        ``(tiles, Hp, Wp, (pad_h, pad_w))``. ``tiles`` has shape
        (B * Hp/ws * Wp/ws, C, ws, ws), ordered batch-major, then window row, then window column.
        ``Hp, Wp`` are the padded sizes; the pad amounts are needed by ``window_unpartition``.
    """
    batch, channels, height, width = x.shape
    ws = window_size
    pad_h = -height % ws  # amount needed to reach the next multiple of ws
    pad_w = -width % ws
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
    padded_h, padded_w = x.shape[-2], x.shape[-1]
    n_h, n_w = padded_h // ws, padded_w // ws
    tiles = x.view(batch, channels, n_h, ws, n_w, ws).permute(0, 2, 4, 1, 3, 5)
    tiles = tiles.contiguous().view(-1, channels, ws, ws)
    return tiles, padded_h, padded_w, (pad_h, pad_w)

def window_unpartition(windows, Hp, Wp, window_size, pad):
    """Inverse of ``window_partition``: stitch tiles back into (B, C, H, W) and drop padding.

    Args:
        windows: (B * Hp/ws * Wp/ws, C, ws, ws) tiles in window_partition's order.
        Hp, Wp: padded spatial size the tiles were cut from.
        window_size: tile side length ``ws``.
        pad: ``(pad_h, pad_w)`` added on the bottom/right by window_partition.
    """
    ws = window_size
    n_h, n_w = Hp // ws, Wp // ws
    batch = windows.shape[0] // (n_h * n_w)
    channels = windows.shape[1]
    grid = windows.view(batch, n_h, n_w, channels, ws, ws)
    full = grid.permute(0, 3, 1, 4, 2, 5).contiguous().view(batch, channels, Hp, Wp)
    crop_h, crop_w = pad
    if not (crop_h or crop_w):
        return full
    return full[:, :, :Hp - crop_h, :Wp - crop_w]

class LayerNormCF(nn.Module):
    """Layer normalisation over the channel axis (dim 1) of an NCHW tensor.

    Each pixel's channel vector is centred and scaled to unit (biased) variance, then an
    affine per-channel ``weight``/``bias`` is applied.
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        # Broadcast the (C,) affine parameters over H and W.
        return self.weight[:, None, None] * normed + self.bias[:, None, None]

class FourierUnitEnhanced(nn.Module):
    """Magnitude-gated spectral mixing of an NCHW feature map.

    The block works in four steps:
    1. Take the input to the 2-D real FFT domain.
    2. Scale every frequency bin by a sigmoid gate computed from the spectrum magnitude.
    3. Mix the stacked real/imaginary planes with a 1x1-conv bottleneck of width
       ``max(1, int(dim * bottleneck_ratio))``.
    4. Transform back with the inverse FFT to the input's spatial size and dtype.

    Args:
        dim: number of input channels.
        groups: group count of the two 1x1 bottleneck convolutions.
        fft_norm: ``norm`` passed to ``torch.fft.rfft2`` / ``irfft2``.
        bottleneck_ratio: hidden width of the bottleneck relative to ``dim``.
    """

    def __init__(self, dim, groups=1, fft_norm='ortho', bottleneck_ratio=0.75):
        super().__init__()
        self.groups = groups
        self.fft_norm = fft_norm
        hidden = max(1, int(dim * bottleneck_ratio))
        gate_hidden = max(1, dim // 2)
        self.proj_in = nn.Conv2d(dim * 2, hidden * 2, 1, bias=False, groups=self.groups)
        self.act = nn.GELU()
        self.proj_out = nn.Conv2d(hidden * 2, dim * 2, 1, bias=False, groups=self.groups)
        self.mag_gate = nn.Sequential(
            nn.Conv2d(dim, gate_hidden, 1, bias=True),
            nn.GELU(),
            nn.Conv2d(gate_hidden, dim, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        _, _, height, width = x.size()
        # Keep the whole spectral path in float32 even inside a CUDA autocast region.
        with torch.amp.autocast('cuda', enabled=False):
            spectrum = torch.fft.rfft2(x.float(), norm=self.fft_norm)
            gate = self.mag_gate(torch.abs(spectrum).float())
            spectrum = spectrum * gate
            planes = torch.cat([spectrum.real, spectrum.imag], dim=1).float()
            mixed = self.proj_out(self.act(self.proj_in(planes)))
            re_part, im_part = torch.chunk(mixed, 2, dim=1)
            restored = torch.fft.irfft2(torch.complex(re_part, im_part), s=(height, width), norm=self.fft_norm)
            restored = restored.float()
        return restored.to(x.dtype)

class FMAPlus(nn.Module):
    """Windowed Fourier-modulated attention block with a residual connection.

    The layer-normalised input feeds two branches:
    - A spectral branch (``FourierUnitEnhanced``) producing ``A``.
    - A local conv branch (1x1 -> depth-wise 3x3 -> GELU -> 1x1) producing ``V``.

    Both are cut into non-overlapping ``window_size`` x ``window_size`` windows and split into
    ``num_heads`` channel groups. The element-wise product ``A * V`` is scaled by a
    softplus-positive per-head temperature and softmax-normalised over the pixels of each
    window. The windows are stitched back together, and a depth-wise positional encoding of
    the block input is added. The result is projected by a 1x1 conv, scaled by
    ``layer_scale`` (initialised to 1e-6) and added to the input.

    NOTE(review): the softmax weights themselves are the window output; V only enters through
    the ``A * V`` logits and is not aggregated by the weights -- confirm this is intended.

    Args:
        dim: input/output channels; must be divisible by ``num_heads``.
        num_heads: number of channel groups, each with its own temperature.
        window_size: side length of the attention windows.
        bottleneck_ratio: hidden-width ratio of the spectral branch.
        temp_init: initial pre-softplus temperature of every head.
    """

    def __init__(self, dim=64, num_heads=16, window_size=3, bottleneck_ratio=0.75, temp_init=3.0):
        super().__init__()
        assert dim % num_heads == 0, f"dim ({dim}) must be divisible by num_heads ({num_heads})"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size

        self.norm = LayerNormCF(dim, eps=1e-6)
        self.fourier = FourierUnitEnhanced(dim, bottleneck_ratio=bottleneck_ratio)

        self.v_proj1 = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
        self.v_dw    = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=False)
        self.v_act   = nn.GELU()
        self.v_proj2 = nn.Conv2d(dim, dim, kernel_size=1, bias=False)

        # Raw per-head temperature; softplus in forward keeps the effective value positive.
        self._temp = nn.Parameter(torch.ones(num_heads) * float(temp_init))
        self.softplus = nn.Softplus()

        self.cpe  = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        # Tiny initial scale: the block starts close to an identity mapping.
        self.layer_scale = nn.Parameter(1e-6 * torch.ones(dim))

    def forward(self, x):
        B, C, H, W = x.shape
        shortcut = x
        # Conditional positional encoding from the un-normalised input.
        pos = self.cpe(x)
        x = self.norm(x)

        A = self.fourier(x)
        V = self.v_proj1(x); V = self.v_dw(V); V = self.v_act(V); V = self.v_proj2(V)

        w = self.window_size
        A_win, Hp, Wp, pad = window_partition(A, w)
        V_win, _,  _,  _   = window_partition(V, w)

        # (windows, heads, channels-per-head, pixels-per-window)
        A_tok = rearrange(A_win, 'bn (h c) a b -> bn h c (a b)', h=self.num_heads)
        V_tok = rearrange(V_win, 'bn (h c) a b -> bn h c (a b)', h=self.num_heads)

        temp = self.softplus(self._temp).view(1, self.num_heads, 1, 1)
        attn = (A_tok * V_tok) * temp
        attn = F.softmax(attn, dim=-1)  # normalise over the pixels of each window

        Xw = rearrange(attn, 'bn h c (a b) -> bn (h c) a b', a=w, b=w)
        X  = window_unpartition(Xw, Hp, Wp, w, pad)

        X  = X + pos
        X  = self.proj(X)
        X  = self.layer_scale.view(1, -1, 1, 1) * X
        return X + shortcut

class MultiScaleFMA(nn.Module):
    """Learned blend of two FMAPlus blocks with different window sizes.

    Output is ``g * fma_s(x) + (1 - g) * fma_m(x)``, where ``g = sigmoid(gate)`` is a single
    learned scalar whose raw value starts at 0.5.
    """

    def __init__(self, dim=64, heads=16, win_small=3, win_mid=5, bottle=0.75, temp_init=3.0):
        super().__init__()
        self.fma_s = FMAPlus(dim=dim, num_heads=heads, window_size=win_small,
                             bottleneck_ratio=bottle, temp_init=temp_init)
        self.fma_m = FMAPlus(dim=dim, num_heads=heads, window_size=win_mid,
                             bottleneck_ratio=bottle, temp_init=temp_init)
        self.gate = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        small_view = self.fma_s(x)
        mid_view = self.fma_m(x)
        weight = torch.sigmoid(self.gate)
        return weight * small_view + (1 - weight) * mid_view

class HighPassResidual(nn.Module):
    """Residual high-frequency boost: ``x + a * dw(x) + (1 - a) * laplacian(x)``.

    ``dw`` is a learned depth-wise 3x3 conv. ``laplacian`` is a fixed 4-neighbour Laplacian
    applied per channel with zero padding. ``a = sigmoid(alpha)`` is a learned scalar mixing
    the two.
    """

    def __init__(self, ch=64):
        super().__init__()
        self.dw = nn.Conv2d(ch, ch, 3, 1, 1, groups=ch, bias=False)
        laplacian = torch.tensor([
            [0., -1., 0.],
            [-1., 4., -1.],
            [0., -1., 0.],
        ])
        self.register_buffer('lap', laplacian.view(1, 1, 3, 3))
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        channels = x.size(1)
        learned = self.dw(x)
        # Same fixed kernel replicated for every channel (grouped conv).
        kernel = self.lap.to(dtype=x.dtype, device=x.device).expand(channels, 1, 3, 3)
        fixed = F.conv2d(x, kernel, padding=1, groups=channels)
        mix = torch.sigmoid(self.alpha)
        return x + mix * learned + (1 - mix) * fixed
class SRCNN_FMA_AnomAgg(nn.Module):
    """SRCNN-style single-channel SR network with three multi-scale FMA stages.

    The pipeline is:
    conv9x9 + ReLU -> high-pass residual -> 3 x MultiScaleFMA -> 1x1 conv to 32 channels
    + ReLU -> conv5x5 to 1 channel. The input is added to the result, so the network predicts
    a correction to its (same-size) input.
    """

    def __init__(self, dim=64, heads=16):
        super().__init__()
        self.conv1 = nn.Conv2d(1, dim, kernel_size=9, padding=4)
        self.relu1 = nn.ReLU(inplace=True)
        self.hires = HighPassResidual(ch=dim)
        stage_cfg = dict(dim=dim, heads=heads, win_small=3, win_mid=5, bottle=0.75, temp_init=3.0)
        self.ms1 = MultiScaleFMA(**stage_cfg)
        self.ms2 = MultiScaleFMA(**stage_cfg)
        self.ms3 = MultiScaleFMA(**stage_cfg)
        self.conv2 = nn.Conv2d(dim, 32, kernel_size=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(32, 1, kernel_size=5, padding=2)

    def forward(self, x):
        residual = x
        feats = self.hires(self.relu1(self.conv1(x)))
        for stage in (self.ms1, self.ms2, self.ms3):
            feats = stage(feats)
        feats = self.conv3(self.relu2(self.conv2(feats)))
        return feats + residual

class SRDataset(Dataset):
    """LR/HR patch pairs served as ``(lr, hr)`` float32 tensors of shape (1, H, W)."""

    def __init__(self, hr_patches, lr_patches):
        # Convert once up front so item access is just a zero-copy tensor wrap.
        self.hr, self.lr = [np.ascontiguousarray(a.astype(np.float32)) for a in (hr_patches, lr_patches)]

    def __len__(self):
        return self.hr.shape[0]

    def __getitem__(self, idx):
        lr_t = torch.from_numpy(self.lr[idx])
        hr_t = torch.from_numpy(self.hr[idx])
        return lr_t.unsqueeze(0), hr_t.unsqueeze(0)

def patch_variance_weights(arr):
    """Sampling weights proportional to each patch's pixel variance.

    Each patch's variance (+1e-6) is divided by the mean variance across patches. The result
    is clipped to [0.2, 5.0] so no patch is sampled too rarely or too often.

    Args:
        arr: array of shape (N, ...) -- one entry per patch.

    Returns:
        float64 array of shape (N,).
    """
    variances = arr.reshape(arr.shape[0], -1).var(axis=1) + 1e-6
    relative = variances / variances.mean()
    return np.clip(relative, 0.2, 5.0).astype(np.float64)

# Train/validation split. A fixed seed makes the validation set -- and therefore model
# selection, LR scheduling and early stopping -- reproducible from run to run.
SPLIT_SEED = 42
N = len(hr_patches)
idx_all = np.random.default_rng(SPLIT_SEED).permutation(N)
split = int(N * (1 - VAL_RATIO))
train_idx, val_idx = idx_all[:split], idx_all[split:]

train_set = SRDataset(hr_patches[train_idx], lr_patches[train_idx])
val_set   = SRDataset(hr_patches[val_idx],  lr_patches[val_idx])

# Sample training patches with replacement, favouring high-variance HR patches
# (one weight per training patch, see patch_variance_weights).
weights = patch_variance_weights(hr_patches[train_idx])
sampler = WeightedRandomSampler(weights=weights, num_samples=len(train_idx), replacement=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_workers = min(8, os.cpu_count() or 4)
pin = torch.cuda.is_available()

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=(num_workers > 0),
    prefetch_factor=4,
    drop_last=True,
)
val_loader = DataLoader(
    val_set,
    batch_size=min(BATCH_SIZE, 32),
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin,
    persistent_workers=(num_workers > 0),
    prefetch_factor=4,
)

# Choose the validation-metric backend: torchmetrics when it imports cleanly,
# otherwise scikit-image SSIM on CPU arrays.
try:
    from torchmetrics.functional import structural_similarity_index_measure as tm_ssim
    from torchmetrics.functional import peak_signal_noise_ratio as tm_psnr
except Exception:
    use_torchmetrics = False
    from skimage.metrics import structural_similarity as sk_ssim
else:
    use_torchmetrics = True

def _downsample_torch(x, factor=1):
    if factor <= 1: return x
    return F.avg_pool2d(x, kernel_size=factor, stride=factor, ceil_mode=False)

@torch.no_grad()
def evaluate_on_val(model, loader, eval_max=512, ds_factor=2, device='cuda'):
    """Mean PSNR / SSIM / MAE / RMSE of ``model`` on (a prefix of) ``loader``.

    Each batch is predicted under CUDA autocast (when available). Prediction and target are
    cropped to their common spatial size and clamped to [0, 1]. They are then average-pooled
    by ``ds_factor`` before scoring. Scoring stops once ``eval_max`` samples have been seen
    (checked per batch).

    Returns:
        ``(psnr, ssim, mae, rmse)`` as floats, or four zeros if nothing was scored.
    """
    model.eval()
    names = ("psnr", "ssim", "mae", "rmse")
    scores = {name: [] for name in names}
    seen = 0
    for lr, hr in loader:
        if seen >= eval_max:
            break
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
            sr = model(lr)
        rows = min(sr.shape[-2], hr.shape[-2])
        cols = min(sr.shape[-1], hr.shape[-1])
        sr_eval = _downsample_torch(sr[..., :rows, :cols].clamp_(0, 1), ds_factor)
        hr_eval = _downsample_torch(hr[..., :rows, :cols].clamp_(0, 1), ds_factor)
        if use_torchmetrics:
            err = sr_eval - hr_eval
            batch_scores = (
                tm_psnr(sr_eval, hr_eval, data_range=1.0),
                tm_ssim(sr_eval, hr_eval, data_range=1.0),
                torch.mean(torch.abs(err)),
                torch.sqrt(torch.mean(err ** 2)),
            )
            for name, value in zip(names, batch_scores):
                scores[name].append(value.detach().float().item())
        else:
            sr_np = sr_eval.squeeze(1).detach().cpu().numpy()
            hr_np = hr_eval.squeeze(1).detach().cpu().numpy()
            for hr_img, sr_img in zip(hr_np, sr_np):
                residual = hr_img - sr_img
                mse = float(np.mean(residual * residual))
                scores["psnr"].append(20 * np.log10(1.0 / (np.sqrt(mse) + 1e-12)))
                scores["ssim"].append(sk_ssim(hr_img, sr_img, data_range=1.0))
                scores["mae"].append(float(np.mean(np.abs(residual))))
                scores["rmse"].append(float(np.sqrt(mse)))
        seen += lr.size(0)
    if not scores["psnr"]:
        return 0.0, 0.0, 0.0, 0.0
    return tuple(float(np.mean(scores[name])) for name in names)

# Loss, backend flags, model, optimiser, AMP scaler, LR scheduler and early-stopping state.
criterion = nn.L1Loss()

torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

model = SRCNN_FMA_AnomAgg(dim=64, heads=16).to(device).to(memory_format=torch.channels_last)
if USE_COMPILE:
    try:
        model = torch.compile(model, mode="reduce-overhead")
    except Exception as exc:
        # Best-effort: keep going in eager mode, but do not hide the failure.
        # NOTE: torch.compile is typically lazy, so many errors only show on the first forward.
        print(f">>> torch.compile failed, falling back to eager mode: {exc!r}")

optimizer = optim.AdamW(model.parameters(), lr=INIT_LR, fused=torch.cuda.is_available())
scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

# mode='max': the monitored value is the validation score (higher is better).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3, verbose=True, min_lr=1e-6
)

best_score = -1e9
epochs_no_improve = 0

def _get_lr(optim_):
    for pg in optim_.param_groups:
        return pg.get("lr", None)

# Training loop.
# - Each epoch makes one AMP pass over the weighted sampler.
# - Every EVAL_EVERY epochs a validation pass computes a combined score. That score drives the
#   LR scheduler and checkpointing (best score -> MODEL_PATH), and stops training after
#   EARLY_STOP evaluations without improvement.
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    t0 = time.time()
    running = 0.0
    for lr, hr in train_loader:
        lr = lr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        hr = hr.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
            pred = model(lr)
            # No clamp inside the loss: clamp has zero gradient outside [0, 1], so overshooting
            # pixels would get no learning signal. Clamping happens at evaluation/inference.
            loss = criterion(pred, hr)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running += loss.item()

    avg_loss = running / max(1, len(train_loader))
    log = f"[Epoch {epoch}/{NUM_EPOCHS}] lr={_get_lr(optimizer):.2e} loss={avg_loss:.6f} time={time.time()-t0:.1f}s"

    if epoch % EVAL_EVERY == 0:
        # val_* names keep a module-level `ssim` function (if any) from being overwritten.
        val_psnr, val_ssim, val_mae, val_rmse = evaluate_on_val(
            model, val_loader, eval_max=VAL_MAX_SAMPLES, ds_factor=DS_FACTOR, device=device
        )
        score = W_SSIM * val_ssim + W_MAE * (1.0 - val_mae)
        log += f" | val_mae={val_mae:.4f} val_rmse={val_rmse:.4f} val_ssim={val_ssim:.4f} score={score:.4f}"

        scheduler.step(score)

        if score > best_score + 1e-6:
            best_score = score
            epochs_no_improve = 0
            # Save the un-wrapped module's weights, so keys have no torch.compile "_orig_mod."
            # prefix and the plain model built for inference can load them with strict=True.
            torch.save(getattr(model, "_orig_mod", model).state_dict(), MODEL_PATH)
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOP:
            print(log); print(">>> 早停触发（组合分数连续无改善）"); break

    print(log)

# Ensure some checkpoint exists even if no evaluation ever ran.
# NOTE: a checkpoint left over from an earlier run at MODEL_PATH also satisfies this check.
if not os.path.exists(MODEL_PATH):
    torch.save(getattr(model, "_orig_mod", model).state_dict(), MODEL_PATH)

@torch.no_grad()
def predict_full(model, lr_image_01, device='cuda'):
    """Super-resolve a whole normalised 2-D LR image and return it as a numpy array in [0, 1].

    The image is wrapped as a (1, 1, H, W) float32 batch. It is run under CUDA autocast when
    available, and the squeezed output is clamped to [0, 1] before it goes back to the CPU.
    """
    model.eval()
    batch = torch.from_numpy(lr_image_01.astype(np.float32))[None, None].to(device)
    with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
        prediction = model(batch).squeeze(0).squeeze(0)
        prediction = prediction.clamp_(0, 1)
    return prediction.detach().cpu().numpy()

# ====== Final evaluation of the best checkpoint on the full image ======
best = SRCNN_FMA_AnomAgg(dim=64, heads=16).to(device)
state = torch.load(MODEL_PATH, map_location=device)
# A checkpoint saved from a torch.compile-wrapped model has an "_orig_mod." prefix on every
# key; strip it so this plain (uncompiled) model can load it with strict=True.
_compile_prefix = "_orig_mod."
state = {(k[len(_compile_prefix):] if k.startswith(_compile_prefix) else k): v for k, v in state.items()}
best.load_state_dict(state, strict=True)
best.eval()

sr_01 = predict_full(best, lr_array_norm, device=device)

# Undo the [0, 1] normalisation (x * hr_range + hr_min).
sr_orig = (sr_01 * hr_range) + hr_min
sr_orig = np.clip(sr_orig, hr_min, hr_max).astype(np.float32)
hr_orig = (hr_array_norm * hr_range) + hr_min

# Identify valid pixels BEFORE nan_to_num: afterwards every pixel is finite, so NaN (no-data)
# pixels would silently be scored as 0.
finite_mask = np.isfinite(hr_orig) & np.isfinite(sr_orig)
if not finite_mask.any():
    raise RuntimeError("no finite pixels in the full image to evaluate")
hr_clean = np.nan_to_num(hr_orig, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float64)
sr_clean = np.nan_to_num(sr_orig, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float64)
hr_f = hr_clean[finite_mask]; sr_f = sr_clean[finite_mask]
hr_min_f = float(np.min(hr_f)); hr_max_f = float(np.max(hr_f))
data_range_f = hr_max_f - hr_min_f
diff = sr_f - hr_f
mse_val = float(np.mean(diff * diff)); mae_val = float(np.mean(np.abs(diff))); rmse_val = float(np.sqrt(mse_val))

eps = 1e-12
# PSNR peak = dynamic range of the valid HR pixels.
if data_range_f < eps:
    psnr_val = float('inf') if mse_val < eps else 20 * np.log10((1.0) / np.sqrt(mse_val + eps))
else:
    psnr_val = 20 * np.log10(data_range_f / np.sqrt(mse_val + eps))

if data_range_f < eps:
    ssim_val = 1.0 if mse_val < eps else 0.0
else:
    from skimage.metrics import structural_similarity as sk_ssim
    # SSIM is a local-window statistic, so it needs the real 2-D layout. Flattening the valid
    # pixels and re-folding them into a square scrambles neighbourhoods. Instead, invalid
    # pixels are filled with the mean of the valid ones, and both images are scaled with the
    # HR range to [0, 1].
    hr_img = np.where(finite_mask, hr_clean, hr_f.mean())
    sr_img = np.where(finite_mask, sr_clean, sr_f.mean())
    hr01 = (hr_img - hr_min_f) / (data_range_f + eps)
    sr01 = (sr_img - hr_min_f) / (data_range_f + eps)
    ssim_val = float(sk_ssim(hr01, sr01, data_range=1.0))

print("\n===== Final Model Metrics (Full Image) =====")
print(f"PSNR: {psnr_val:.4f} dB")
print(f"SSIM: {ssim_val:.4f}")
print(f"MAE : {mae_val:.6f}")
print(f"MSE : {mse_val:.6f}")
print(f"RMSE: {rmse_val:.6f}")
print("============================================\n")

# ====== Save as GeoTIFF, reusing the HR raster's georeferencing profile ======
import rasterio
with rasterio.open(hr_image_path) as src:
    profile = src.profile
profile.update(count=1, dtype='float32')

out_tif = 'JAG/fgsa_reconstructed4.tif'
with rasterio.open(out_tif, 'w', **profile) as dst:
    dst.write(sr_orig, 1)

print(f">>> 推理完成，GeoTIFF 保存：{out_tif}")