In [None]:
import os
import math
import logging
from glob import glob

import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm

# -------------------------
# Paths / constants
# -------------------------
# Inputs and outputs: tile folder (tile_0000.tif etc.), trained weights, CSV report.
TILES_DIR = "tiles"
MODEL_PATH = os.path.join("models", "hls_ssl4eo_best.pth")  # or ssl4eo_best.pth
RESULTS_CSV = "results_tiles_eval2.csv"

# Super-resolution factor; must match the value used for training.
UPSCALE = 2

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Root logging setup plus the named logger used by the functions below.
logging.basicConfig(
    format="%(asctime)s %(levelname)s:%(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("tiles_eval")

# -------------------------
# Band mapping (same as training)
# -------------------------
# Maps band names "B1".."B11" to the band index handed to rasterio's `read`
# (the index equals the number in the band name).
BAND_IDX = {f"B{number}": number for number in range(1, 12)}

# -------------------------
# Utilities from training
# -------------------------
def norm_np(a: np.ndarray) -> np.ndarray:
    """Min-max normalize an array to [0, 1] with NaN/Inf protection.

    The minimum and maximum are taken over the whole array that is passed in,
    so pass a single band at a time to get per-band normalization.

    Args:
        a: Array-like of numeric samples; it is copied and cast to float32.

    Returns:
        A float32 array of the same shape. NaN and +/-Inf samples are replaced
        by 0.0 before the range is measured. If the value range is narrower
        than 1e-6 the result is all zeros. An empty input is returned as an
        empty float32 array.
    """
    a = np.array(a, dtype=np.float32)
    # Finite samples pass through nan_to_num unchanged, so no pre-check is needed.
    a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    if a.size == 0:
        # min()/max() of an empty array raise ValueError; nothing to scale anyway.
        return a
    mn = float(a.min())
    mx = float(a.max())
    if mx - mn < 1e-6:
        # (Nearly) constant input: avoid dividing by a vanishing range.
        return np.zeros_like(a, dtype=np.float32)
    return ((a - mn) / (mx - mn)).astype(np.float32)

def compute_metrics(pred: np.ndarray, target: np.ndarray):
    """Compute PSNR / SSIM / RMSE between two arrays normalized to [0, 1].

    Args:
        pred: Predicted image.
        target: Reference image; it is subtracted element-wise from `pred`.

    Returns:
        A `(psnr, ssim, rmse)` tuple. PSNR is in dB for a peak value of 1.0
        and is reported as 100.0 (with RMSE 0.0) when the MSE is below 1e-12
        or not finite. SSIM falls back to 0.0 when it cannot be computed.
    """
    # Replace NaN with 0 and clamp infinities to the ends of the [0, 1] range.
    pred = np.nan_to_num(pred, nan=0.0, posinf=1.0, neginf=0.0)
    target = np.nan_to_num(target, nan=0.0, posinf=1.0, neginf=0.0)

    mse = float(np.mean((pred - target) ** 2))
    if not np.isfinite(mse) or mse < 1e-12:
        psnr_val = 100.0
        rmse_val = 0.0
    else:
        psnr_val = 10 * math.log10(1.0 / mse)
        rmse_val = math.sqrt(mse)

    try:
        ssim_val = ssim(target, pred, data_range=1.0)
    except Exception as exc:
        # Best-effort fallback is kept, but the failure is no longer silent:
        # a 0.0 written without a trace would quietly drag the averages down.
        logging.getLogger("tiles_eval").warning(
            "SSIM could not be computed (%s: %s); reporting 0.0",
            type(exc).__name__, exc,
        )
        ssim_val = 0.0
    return psnr_val, ssim_val, rmse_val

# -------------------------
# Model (must match training)
# -------------------------
class ChannelAttention(nn.Module):
    """Channel attention gate.

    Each channel is averaged to a single value, passed through a two-layer
    1x1-conv bottleneck ending in a sigmoid, and the resulting per-channel
    weight multiplies the input.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck divisor; the hidden width is
            `channels // reduction`, clamped to at least one channel.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # `channels // reduction` is 0 whenever channels < reduction, which
        # would build a zero-width bottleneck; keep at least one channel.
        # The width is unchanged whenever channels >= reduction.
        hidden = max(1, channels // reduction)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return `x` scaled by its per-channel attention weights."""
        y = self.avgpool(x)
        y = self.fc(y)
        return x * y

class SpatialAttention(nn.Module):
    """Spatial attention gate.

    Two 3x3 convolutions reduce the feature map to a single-channel sigmoid
    mask, which is multiplied onto the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Hidden width is half the input channels but never fewer than 8.
        hidden = max(8, in_channels // 2)
        squeeze = nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1)
        to_mask = nn.Conv2d(hidden, 1, kernel_size=3, padding=1)
        self.conv = nn.Sequential(
            squeeze,
            nn.ReLU(inplace=True),
            to_mask,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return `x` multiplied by its one-channel attention mask."""
        return x * self.conv(x)

class RCAB(nn.Module):
    """Residual channel-attention block.

    Applies conv -> ReLU -> conv, gates the result with `ChannelAttention`,
    scales it by `res_scale` (0.1) and adds it back onto the input.
    """

    def __init__(self, channels, kernel_size=3, reduction=16):
        super().__init__()
        half_kernel = kernel_size // 2  # padding for both convolutions
        conv_first = nn.Conv2d(channels, channels, kernel_size, padding=half_kernel)
        conv_second = nn.Conv2d(channels, channels, kernel_size, padding=half_kernel)
        self.body = nn.Sequential(
            conv_first,
            nn.ReLU(inplace=True),
            conv_second,
        )
        self.ca = ChannelAttention(channels, reduction=reduction)
        self.res_scale = 0.1

    def forward(self, x):
        """Return `x` plus the attention-gated, down-scaled residual."""
        gated = self.ca(self.body(x))
        return x + gated * self.res_scale

class ResidualGroup(nn.Module):
    """A chain of `n_rcab` RCAB blocks wrapped in one outer skip connection."""

    def __init__(self, channels, n_rcab=4):
        super().__init__()
        self.body = nn.Sequential(*(RCAB(channels) for _ in range(n_rcab)))

    def forward(self, x):
        """Return the RCAB chain output added to the group input."""
        processed = self.body(x)
        return processed + x

class LearnedUpsampler(nn.Module):
    """Pixel-shuffle upsampler with a trailing conv + ReLU.

    A 3x3 convolution expands the features to `out_channels * scale**2`
    channels, pixel shuffle rearranges them into a `scale`-times larger map,
    and a second 3x3 convolution with ReLU post-processes the result.
    """

    def __init__(self, in_channels, out_channels, scale=UPSCALE):
        super().__init__()
        self.scale = scale
        shuffle_channels = out_channels * (scale * scale)
        self.proj = nn.Conv2d(in_channels, shuffle_channels, kernel_size=3, padding=1)
        self.post = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, target_size=None):
        """Upsample `x`; if `target_size` is given, bilinearly resize to it."""
        upsampled = self.post(F.pixel_shuffle(self.proj(x), self.scale))
        if target_size is None:
            return upsampled
        return F.interpolate(
            upsampled, size=target_size, mode="bilinear", align_corners=False
        )

class DualEDSRPlus(nn.Module):
    """Two-branch super-resolution network with attention-based fusion.

    The `xT` branch takes a 1-channel input and the `xO` branch a 3-channel
    input (presumably thermal and optical imagery; confirm against the
    training code). Each branch runs through its own stack of residual
    groups. The `xT` features are upsampled and resized to the `xO` feature
    grid, the two are concatenated, fused by a 1x1 convolution, gated by
    channel and spatial attention, refined, and projected to a 1-channel
    output.

    Attribute names are part of the checkpoint format (state_dict keys) and
    must not be renamed.
    """

    def __init__(self, n_resgroups=4, n_rcab=4, n_feats=64, upscale=UPSCALE):
        super().__init__()
        self.upscale = upscale
        self.n_feats = n_feats

        def build_trunk():
            # One stack of residual groups; each branch owns a separate copy.
            groups = [ResidualGroup(n_feats, n_rcab) for _ in range(n_resgroups)]
            return nn.Sequential(*groups)

        # Input stems: 1 channel for xT, 3 channels for xO.
        self.convT_in = nn.Conv2d(1, n_feats, 3, padding=1)
        self.convO_in = nn.Conv2d(3, n_feats, 3, padding=1)

        # Per-branch feature extractors.
        self.t_groups = build_trunk()
        self.o_groups = build_trunk()

        self.t_upsampler = LearnedUpsampler(n_feats, n_feats, scale=upscale)

        # Fusion of the concatenated branch features.
        self.convFuse = nn.Conv2d(2 * n_feats, n_feats, kernel_size=1)
        self.fuse_ca = ChannelAttention(n_feats)
        self.fuse_sa = SpatialAttention(n_feats)

        self.refine = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.convOut = nn.Conv2d(n_feats, 1, kernel_size=3, padding=1)

        # Kaiming-normal weights and zero biases for every convolution.
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(
                module.weight, a=0, mode="fan_in", nonlinearity="relu"
            )
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, xT, xO):
        """Return the 1-channel prediction on the spatial grid of `xO`."""
        feats_t = self.t_groups(F.relu(self.convT_in(xT)))
        feats_o = self.o_groups(F.relu(self.convO_in(xO)))

        # Learned upsampling first, then an exact bilinear resize so the xT
        # features line up with the xO feature map before concatenation.
        grid_hw = (feats_o.shape[2], feats_o.shape[3])
        feats_t_up = F.interpolate(
            self.t_upsampler(feats_t),
            size=grid_hw,
            mode="bilinear",
            align_corners=False,
        )

        fused = F.relu(self.convFuse(torch.cat([feats_t_up, feats_o], dim=1)))
        fused = self.fuse_sa(self.fuse_ca(fused))
        return self.convOut(self.refine(fused))

# -------------------------
# Load model
# -------------------------
def load_model(model_path: str) -> DualEDSRPlus:
    """Build a default `DualEDSRPlus` on `DEVICE` and load weights into it.

    Args:
        model_path: Checkpoint file. It may hold either a dict with a
            "model_state" entry or a plain state_dict.

    Returns:
        The model in eval mode with the checkpoint weights loaded.
    """
    net = DualEDSRPlus().to(DEVICE)
    # NOTE(review): torch.load goes through pickle; depending on the PyTorch
    # version and its `weights_only` default this can execute code embedded
    # in the file, so only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=DEVICE)

    # Unwrap a training checkpoint; otherwise treat the object as the weights.
    is_wrapped = isinstance(checkpoint, dict) and "model_state" in checkpoint
    weights = checkpoint["model_state"] if is_wrapped else checkpoint

    net.load_state_dict(weights)
    net.eval()
    logger.info(f"Loaded model from {model_path}")
    return net

# -------------------------
# Evaluate on all tiles
# -------------------------
def evaluate_tiles(tiles_dir: str, model_path: str):
    """Evaluate the model on every `*.tif` tile in a folder and save a CSV.

    For each tile, bands B2/B3/B4 and B10 are read and min-max normalized per
    band. The normalized B10 band is bilinearly downsampled by `UPSCALE` and
    fed to the model together with the three normalized B2-B4 bands. The
    prediction is scored against the full-resolution normalized B10 band.

    Per-tile PSNR/SSIM/RMSE rows, plus a final "ALL" row with the averages
    (height/width set to -1), are written to `RESULTS_CSV`.

    Args:
        tiles_dir: Folder searched (non-recursively) for `*.tif` files.
        model_path: Checkpoint passed to `load_model`.

    Returns:
        None. If no tiles are found an error is logged and no CSV is written.
    """
    import csv

    model = load_model(model_path)
    tile_paths = sorted(glob(os.path.join(tiles_dir, "*.tif")))
    if not tile_paths:
        logger.error(f"No .tif files found in {tiles_dir}")
        return

    fieldnames = ["tile", "psnr", "ssim", "rmse", "height", "width"]
    rows = []

    logger.info(f"Evaluating on {len(tile_paths)} tiles from {tiles_dir}")
    for tile_path in tqdm(tile_paths, desc="Tiles"):
        tile_name = os.path.basename(tile_path)

        with rasterio.open(tile_path) as src:
            rgb = src.read([
                BAND_IDX["B2"],
                BAND_IDX["B3"],
                BAND_IDX["B4"],
            ]).astype(np.float32)
            thr = src.read(BAND_IDX["B10"]).astype(np.float32)

        # Normalize every band independently.
        rgb_n = np.stack([norm_np(band) for band in rgb], axis=0)
        thr_n = norm_np(thr)

        H, W = thr_n.shape
        H_lr, W_lr = H // UPSCALE, W // UPSCALE
        if H_lr < 1 or W_lr < 1:
            # A zero-sized interpolate target would raise and abort the whole
            # run before any results are written; skip just this tile.
            logger.warning(
                f"Skipping {tile_name}: size {H}x{W} is smaller than "
                f"the upscale factor {UPSCALE}"
            )
            continue

        # Simulate the low-resolution input by downsampling the target band.
        lr = torch.from_numpy(thr_n).unsqueeze(0).unsqueeze(0)
        lr_down = torch.nn.functional.interpolate(
            lr.float(), size=(H_lr, W_lr), mode="bilinear", align_corners=False
        )

        xT = lr_down.to(DEVICE)
        xO = torch.from_numpy(rgb_n).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            pred_hr = model(xT, xO)

        # Drop only the batch and channel axes; a bare squeeze() would also
        # remove a spatial axis of length 1 and break the shape match.
        pred = pred_hr[0, 0].cpu().numpy()
        gt = thr_n

        ps, ss, rm = compute_metrics(pred, gt)

        logger.info(
            f"TILE {tile_name} -> "
            f"PSNR={ps:.3f} dB, SSIM={ss:.4f}, RMSE={rm:.6f}"
        )
        rows.append({
            "tile": tile_name,
            "psnr": ps,
            "ssim": ss,
            "rmse": rm,
            "height": H,
            "width": W,
        })

    if rows:
        # Average over the tiles that were actually evaluated.
        n_tiles = len(rows)
        avg_ps = sum(r["psnr"] for r in rows) / n_tiles
        avg_ss = sum(r["ssim"] for r in rows) / n_tiles
        avg_rm = sum(r["rmse"] for r in rows) / n_tiles

        logger.info(
            f"OVERALL on {tiles_dir} -> "
            f"PSNR={avg_ps:.3f} dB, SSIM={avg_ss:.4f}, RMSE={avg_rm:.6f}"
        )
        rows.append({
            "tile": "ALL",
            "psnr": avg_ps,
            "ssim": avg_ss,
            "rmse": avg_rm,
            "height": -1,
            "width": -1,
        })

    # Save CSV with per-tile and overall metrics
    with open(RESULTS_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    logger.info(f"Saved tile evaluation summary -> {RESULTS_CSV}")

# Entry point: evaluate the configured tile folder with the configured checkpoint.
if __name__ == "__main__":
    evaluate_tiles(tiles_dir=TILES_DIR, model_path=MODEL_PATH)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("results_tiles_eval2.csv")

# The CSV also holds a summary row labelled "ALL" with the averages; drawing
# it as one more point would present the mean as if it were a tile.
tiles_df = df[df["tile"] != "ALL"].reset_index(drop=True)

# One line plot per metric: (CSV column, y-axis label, figure title).
for column, ylabel, title in (
    ("psnr", "PSNR (dB)", "PSNR per Tile"),
    ("ssim", "SSIM", "SSIM per Tile"),
    ("rmse", "RMSE", "RMSE per Tile"),
):
    plt.figure()
    plt.plot(tiles_df.index, tiles_df[column], marker="o", linestyle="-")
    plt.xlabel("Tile Index")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.show()
