In [2]:
import os
import sys
import math
import logging
from glob import glob
from typing import List, Dict

import numpy as np
import rasterio
from rasterio.enums import Resampling as RioResampling
from rasterio.warp import reproject
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm
import earthaccess

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
RAW_DIR    = "data_hls_raw"        # earthaccess download target
OUT_DIR    = "data_hls_processed"  # extracted / resampled single-band GeoTIFFs
MODELS_DIR = "models"              # training checkpoints

# Create every working directory up front (no-op if it already exists).
for _work_dir in (RAW_DIR, OUT_DIR, MODELS_DIR):
    os.makedirs(_work_dir, exist_ok=True)
del _work_dir

# HLS search region / time window (edit to taste).
# earthaccess bounding_box order: (lower-left lon, lower-left lat, upper-right lon, upper-right lat)
BBOX         = (78.3, 17.2, 78.7, 17.6)
DATE_RANGE   = ("2023-01-01", "2023-01-31")
MAX_GRANULES = 8  # max granules downloaded PER product (up to 8 L30 and up to 8 S30)

In [4]:
UPSCALE       = 2            # SR factor: the output grid has UPSCALE x the LR thermal resolution
HR_PATCH      = 128          # HR patch side in pixels; must be divisible by UPSCALE
LR_PATCH      = HR_PATCH // UPSCALE
BATCH_SIZE    = 4
NUM_EPOCHS    = 40
LEARNING_RATE = 1e-4
PHYS_LAMBDA   = 0.1          # physics-aware loss weight
SEED          = 42           # seeds numpy + torch (patch sampling, weight init); CUDA kernels may still be non-deterministic

# A non-divisible patch would make the LR patch cover a different footprint than the HR patch.
assert HR_PATCH % UPSCALE == 0, "HR_PATCH must be divisible by UPSCALE"

np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s:%(name)s: %(message)s"
)
logger = logging.getLogger("hls_ssl4eo")

In [5]:
def norm_np(a: np.ndarray) -> np.ndarray:
    """Min-max scale ``a`` into [0, 1] and return it as float32.

    The min and max are taken over the WHOLE input array; call it once per band
    to get per-band scaling. NaN and +/-Inf are replaced with 0 before scaling.
    A (near-)constant input, whose range is below 1e-6, yields an all-zero array.
    """
    values = np.nan_to_num(
        np.array(a, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0
    )
    lo = float(np.nanmin(values))
    hi = float(np.nanmax(values))
    span = hi - lo
    if span < 1e-6:
        return np.zeros_like(values, dtype=np.float32)
    return ((values - lo) / span).astype(np.float32)


def compute_metrics(pred: np.ndarray, target: np.ndarray):
    """Return ``(psnr, ssim, rmse)`` of ``pred`` against a [0, 1]-normalized ``target``.

    Metrics assume a data range of 1.0. The network output is unbounded, so
    ``pred`` is clipped to [0, 1] before scoring; otherwise PSNR/SSIM would be
    evaluated outside the declared data range. Non-finite inputs are replaced
    (NaN -> 0, +Inf -> 1, -Inf -> 0).

    Returns:
        psnr: in dB; 100.0 when the MSE is (near) zero, 0.0 if the MSE is not finite.
        ssim: structural similarity, or 0.0 (with a logged warning) if it
              cannot be computed, e.g. inputs smaller than the SSIM window.
        rmse: root mean squared error (inf if the MSE is not finite).
    """
    pred   = np.clip(np.nan_to_num(pred, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0)
    target = np.nan_to_num(target, nan=0.0, posinf=1.0, neginf=0.0)
    mse = float(np.mean((pred - target) ** 2))
    if not np.isfinite(mse):
        # An overflowing error must not be reported as a perfect reconstruction.
        psnr_val = 0.0
        rmse_val = float("inf")
    elif mse < 1e-12:
        psnr_val = 100.0
        rmse_val = 0.0
    else:
        psnr_val = 10 * math.log10(1.0 / mse)
        rmse_val = math.sqrt(mse)
    try:
        ssim_val = float(ssim(target, pred, data_range=1.0))
    except Exception as exc:
        logging.getLogger("hls_ssl4eo").warning(f"SSIM computation failed ({exc}); reporting 0.0")
        ssim_val = 0.0
    return psnr_val, ssim_val, rmse_val


In [6]:
def _search_and_download_hls(short_name, description, bbox, date_range, max_granules, out_dir):
    """Search one HLS collection and download at most ``max_granules`` of its granules.

    Returns the list of local paths reported by ``earthaccess.download``; it is
    empty when nothing was found or downloaded.
    """
    logger.info(f"Searching {short_name} ({description})...")
    results = earthaccess.search_data(
        short_name=short_name,
        temporal=date_range,
        bounding_box=bbox,
    )
    product = short_name.replace("HLS", "", 1)  # "HLSL30" -> "L30"
    logger.info(f"Found {len(results)} {product} granules. Downloading up to {max_granules}...")
    if not results:
        return []
    downloads = earthaccess.download(results[:max_granules], out_dir)
    return list(downloads) if downloads is not None else []


def earthdata_download_hls(
    bbox=BBOX, date_range=DATE_RANGE,
    max_granules=MAX_GRANULES, out_dir=RAW_DIR
):
    """
    Download HLSL30 (Landsat-based, incl. thermal) and HLSS30 (Sentinel-2-based
    optical) granules intersecting ``bbox`` within ``date_range`` into ``out_dir``.

    ``max_granules`` caps the number of granules fetched per product.
    Returns two lists of local file paths: (l30_files, s30_files). earthaccess
    downloads every file of a granule, so these hold several paths per granule.

    Raises:
        RuntimeError: if the Earthdata login reports failure.
    """
    logger.info("Logging into Earthdata (HLS).")
    auth = earthaccess.login()  # interactive if not already logged in
    if auth is None or getattr(auth, "authenticated", True) is False:
        raise RuntimeError("Earthdata login failed; cannot search or download HLS granules.")

    l30_files = _search_and_download_hls(
        "HLSL30", "Landsat-based, includes thermal band",
        bbox, date_range, max_granules, out_dir,
    )
    s30_files = _search_and_download_hls(
        "HLSS30", "Sentinel-2-based optical B02/B03/B04",
        bbox, date_range, max_granules, out_dir,
    )
    return l30_files, s30_files


In [7]:
def _split_hls_name(path: str):
    """Return ``(granule_id, band)`` for an HLS file path. ``band`` is None when the
    name has no recognizable band/layer suffix, e.g. a single multi-band file."""
    import re

    name = os.path.basename(path)
    match = re.match(
        r"^(?P<granule>.+)\.(?P<band>B\d{1,2}A?|Fmask|SAA|SZA|VAA|VZA)\.tiff?$",
        name,
        flags=re.IGNORECASE,
    )
    if match is None:
        return os.path.splitext(name)[0], None
    band = match.group("band").upper()
    if len(band) == 2:  # 'B2' -> 'B02' so lookups do not depend on zero padding
        band = "B0" + band[1]
    return match.group("granule"), band


def _group_hls_files(paths: List[str]) -> Dict[str, Dict]:
    """Group file paths by granule ID into {granule_id: {band: path}}, in first-seen order.
    Browse images and metadata files (.jpg/.jpeg/.png/.xml/.json/.txt) are ignored."""
    skip_exts = {".jpg", ".jpeg", ".png", ".xml", ".json", ".txt"}
    groups: Dict[str, Dict] = {}
    for raw_path in paths:
        path = str(raw_path)
        if os.path.splitext(path)[1].lower() in skip_exts:
            continue
        granule, band = _split_hls_name(path)
        groups.setdefault(granule, {}).setdefault(band, path)
    return groups


def _hls_tile(granule_id: str):
    """Return the tile ID of an HLS v2 granule ID
    ('HLS.L30.T44QKE.2023005T051234.v2.0' -> 'T44QKE'), or None if it cannot be parsed."""
    parts = granule_id.split(".")
    if len(parts) > 2 and parts[0].upper() == "HLS" and parts[2][:1].upper() == "T":
        return parts[2].upper()
    return None


def _pair_hls_granules(l30_groups: Dict[str, Dict], s30_groups: Dict[str, Dict]):
    """Pair L30 with S30 granules on the same tile, in first-seen order within a tile.
    Granules whose tile cannot be parsed share one bucket and are paired in order."""
    def by_tile(groups):
        buckets: Dict = {}
        for granule_id, bands in groups.items():
            buckets.setdefault(_hls_tile(granule_id), []).append((granule_id, bands))
        return buckets

    s30_by_tile = by_tile(s30_groups)
    pairs = []
    for tile, l30_list in by_tile(l30_groups).items():
        pairs.extend(zip(l30_list, s30_by_tile.get(tile, [])))
    return pairs


def _read_hls_thermal(bands: Dict):
    """Read the thermal band (B10, else B11) of an L30 granule and return (array, profile).
    A granule made of one unlabeled file falls back to that file's band 1."""
    path = bands.get("B10") or bands.get("B11")
    if path is None and list(bands) == [None]:
        path = bands[None]
    if path is None:
        raise RuntimeError("L30 granule has no thermal band (B10/B11).")
    with rasterio.open(path) as src:
        return src.read(1), src.profile


def _read_hls_rgb(bands: Dict):
    """Read B02/B03/B04 of an S30 granule and return (b02, b03, b04, profile of B02).
    Without per-band files, an unlabeled multi-band file supplies bands 1-3."""
    codes = ("B02", "B03", "B04")
    if all(code in bands for code in codes):
        arrays = []
        profile = None
        for code in codes:
            with rasterio.open(bands[code]) as src:
                arrays.append(src.read(1))
                if profile is None:
                    profile = src.profile
        return arrays[0], arrays[1], arrays[2], profile
    path = bands.get(None)
    if path is None:
        raise RuntimeError("S30 granule doesn't have RGB bands.")
    with rasterio.open(path) as src:
        if src.count < 3:
            raise RuntimeError("S30 granule doesn't have RGB bands.")
        return src.read(1), src.read(2), src.read(3), src.profile


def _write_single_band(path: str, array: np.ndarray, profile) -> None:
    """Write ``array`` as band 1 of a GeoTIFF, reusing ``profile`` (copied, not mutated)."""
    out_profile = profile.copy()
    out_profile.update(driver="GTiff", count=1, dtype=array.dtype)
    with rasterio.open(path, "w", **out_profile) as dst:
        dst.write(array, 1)


def extract_and_save_bands(l30_files: List[str], s30_files: List[str], out_dir=OUT_DIR):
    """
    Pair downloaded L30 and S30 granules and save single-band GeoTIFFs per pair:
      - thermal from the L30 granule (B10, falling back to B11);
      - B02/B03/B04 from the S30 granule.

    ``l30_files`` / ``s30_files`` are the per-file paths from ``earthaccess.download``,
    with several files per granule. Files are grouped by the granule ID parsed from
    the HLS file name. An L30 and an S30 granule are paired only when they share a
    tile. A file without a band suffix is treated as a whole granule
    (thermal = band 1, RGB = bands 1-3).

    Output names keep the historical "_30m" / "_10m" suffixes; the actual pixel
    size is whatever the source granule uses.

    Returns a list of dicts with paths: { "thermal30", "b02", "b03", "b04" }
    """
    saved_pairs: List[Dict[str, str]] = []
    pairs = _pair_hls_granules(_group_hls_files(l30_files), _group_hls_files(s30_files))
    logger.info(f"Extracting bands for {len(pairs)} matched HLS L30/S30 pairs.")

    for i, ((l30_id, l30_bands), (s30_id, s30_bands)) in enumerate(pairs):
        try:
            thermal, thermal_profile = _read_hls_thermal(l30_bands)
        except Exception as e:
            logger.error(f"Unable to read L30 granule {l30_id}: {e}")
            continue
        try:
            b02_arr, b03_arr, b04_arr, optical_profile = _read_hls_rgb(s30_bands)
        except Exception as e:
            logger.error(f"Unable to extract RGB bands for S30 granule {s30_id}: {e}")
            continue

        base = f"hls_sample_{i:03d}"
        sample = {
            "thermal30": os.path.join(out_dir, f"{base}_thermal_30m.tif"),
            "b02": os.path.join(out_dir, f"{base}_B02_10m.tif"),
            "b03": os.path.join(out_dir, f"{base}_B03_10m.tif"),
            "b04": os.path.join(out_dir, f"{base}_B04_10m.tif"),
        }
        _write_single_band(sample["thermal30"], thermal, thermal_profile)
        _write_single_band(sample["b02"], b02_arr, optical_profile)
        _write_single_band(sample["b03"], b03_arr, optical_profile)
        _write_single_band(sample["b04"], b04_arr, optical_profile)

        saved_pairs.append(sample)
        logger.info(
            f"Saved HLS sample {i}: {sample['thermal30']}, {sample['b02']}, "
            f"{sample['b03']}, {sample['b04']}"
        )

    return saved_pairs

In [8]:
def resample_thermal_to_optical(thermal_path: str, optical_reference_path: str, out_path: str):
    """
    Reproject the thermal raster onto the optical reference grid and write it out.

    The target grid has the reference's CRS, transform, width and height. Bilinear
    resampling is used, and the result is a single-band float32 GeoTIFF at ``out_path``.

    The source array is passed to ``reproject`` as a plain ndarray, which carries no
    nodata information, so the thermal file's nodata value is forwarded explicitly.
    Otherwise fill values would be interpolated as real data. Nodata pixels and
    pixels outside the thermal footprint are written as NaN, and NaN is recorded
    as the output's nodata value.

    Returns:
        out_path
    """
    with rasterio.open(optical_reference_path) as ref:
        target_profile = ref.profile.copy()
        ref_transform = ref.transform
        ref_crs = ref.crs
        ref_width = ref.width
        ref_height = ref.height

    target_profile.update({
        "driver": "GTiff",
        "dtype": "float32",
        "count": 1,
        "width": ref_width,
        "height": ref_height,
        "crs": ref_crs,
        "transform": ref_transform,
        "nodata": np.nan,
    })

    with rasterio.open(thermal_path) as src:
        data = src.read(1)
        src_transform = src.transform
        src_crs = src.crs
        src_nodata = src.nodata  # None when the file declares no nodata value

    dest = np.full((ref_height, ref_width), np.nan, dtype=np.float32)
    reproject(
        source=data,
        destination=dest,
        src_transform=src_transform,
        src_crs=src_crs,
        src_nodata=src_nodata,
        dst_transform=ref_transform,
        dst_crs=ref_crs,
        dst_nodata=np.nan,
        resampling=RioResampling.bilinear
    )
    with rasterio.open(out_path, "w", **target_profile) as dst:
        dst.write(dest, 1)
    logger.info(f"Resampled thermal -> {out_path}")
    return out_path

In [9]:
class HlsPatchDataset(Dataset):
    """
    Random-patch dataset over co-registered HLS samples.

    Each sample dict needs the keys "b02", "b03", "b04" and "thermal10".
    For every item the dataset:
      - picks a random sample and reads B04/B03/B02 as HR RGB (stacked in R,G,B
        order) plus the resampled thermal as HR thermal, each min-max normalized
        with ``norm_np``;
      - crops a random HR patch whose offset is a multiple of ``upscale``, so the
        LR patch covers exactly the same footprint;
      - synthesizes the LR thermal patch by area averaging (block mean) of the HR
        thermal patch. This is the same operator the physics-aware loss in the
        training loop uses (``F.interpolate(mode="area")``); for ``upscale == 2`` it
        equals the previous bilinear downsampling.
    Returns:
        lr_thermal (1, LR, LR), hr_rgb (3, HR, HR), hr_thermal (1, HR, HR)

    ``cache_arrays=True`` keeps each sample's normalized arrays in memory instead
    of re-reading four rasters for every patch. This trades RAM for speed.
    """
    def __init__(
        self,
        samples: List[Dict[str, str]],
        hr_patch: int = HR_PATCH,
        upscale: int = UPSCALE,
        patches_per_image: int = 32,
        cache_arrays: bool = False,
    ):
        super().__init__()
        self.samples = list(samples)
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // upscale
        self.upscale = upscale
        self.patches_per_image = patches_per_image
        self.cache_arrays = cache_arrays
        self._cache: Dict[int, tuple] = {}

    def __len__(self):
        # Nominal epoch size: patches_per_image per sample, but never below 1000.
        return max(len(self.samples) * self.patches_per_image, 1000)

    def _read_arrays_for_sample(self, sample_idx: int):
        """Read and normalize one sample and return (rgb_n (3,H,W) in R,G,B order, thr_hr_n (H,W))."""
        s = self.samples[sample_idx % len(self.samples)]

        b02_path = s["b02"]
        b03_path = s["b03"]
        b04_path = s["b04"]
        therm10_path = s["thermal10"]

        # Read optical
        with rasterio.open(b02_path) as sb2:
            r00 = sb2.read(1).astype(np.float32)
        with rasterio.open(b03_path) as sb3:
            r01 = sb3.read(1).astype(np.float32)
        with rasterio.open(b04_path) as sb4:
            r02 = sb4.read(1).astype(np.float32)

        rgb = np.stack([r02, r01, r00], axis=0)  # (3,H,W): B04,B03,B02 -> R,G,B

        # HR thermal, already resampled onto the optical grid
        with rasterio.open(therm10_path) as t10f:
            thr_hr = t10f.read(1).astype(np.float32)

        rgb_n = np.stack([norm_np(rgb[c]) for c in range(3)], axis=0)  # (3,H,W)
        thr_hr_n = norm_np(thr_hr)                                     # (H,W)
        return rgb_n, thr_hr_n

    def _get_arrays(self, sample_idx: int):
        """Return the normalized arrays of a sample, served from the cache when enabled."""
        key = sample_idx % len(self.samples)
        if not self.cache_arrays:
            return self._read_arrays_for_sample(key)
        if key not in self._cache:
            self._cache[key] = self._read_arrays_for_sample(key)
        return self._cache[key]

    def __getitem__(self, idx):
        if not self.samples:
            raise ValueError("HlsPatchDataset has no samples to draw patches from")
        # `idx` is ignored: every item is a fresh random (sample, patch) draw.
        sample_idx = np.random.randint(0, len(self.samples))
        rgb_n, thr_hr_n = self._get_arrays(sample_idx)

        up = self.upscale
        # Only the region divisible by `upscale` is usable.
        H_hr = thr_hr_n.shape[0] - thr_hr_n.shape[0] % up
        W_hr = thr_hr_n.shape[1] - thr_hr_n.shape[1] % up
        if H_hr < self.hr_patch or W_hr < self.hr_patch:
            raise ValueError("HR image smaller than HR_PATCH; choose smaller patch or larger AOI")

        # Draw the offset on the LR grid and scale it to HR, so LR/HR patches stay co-located.
        y = up * np.random.randint(0, (H_hr - self.hr_patch) // up + 1)
        x = up * np.random.randint(0, (W_hr - self.hr_patch) // up + 1)
        rows = slice(y, y + self.hr_patch)
        cols = slice(x, x + self.hr_patch)

        # .copy(): never hand out views into (possibly cached) arrays.
        hr_t = torch.from_numpy(thr_hr_n[rows, cols].copy()).unsqueeze(0).float()   # (1,HR,HR)
        hr_rgb = torch.from_numpy(rgb_n[:, rows, cols].copy()).float()              # (3,HR,HR)
        lr_t = F.interpolate(
            hr_t.unsqueeze(0), size=(self.lr_patch, self.lr_patch), mode="area"
        ).squeeze(0)                                                                 # (1,LR,LR)

        return lr_t, hr_rgb, hr_t

In [10]:
class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Each channel of ``x`` is rescaled by a sigmoid weight computed from its global
    average through a two-layer 1x1-conv bottleneck.

    Args:
        channels: number of input/output channels.
        reduction: bottleneck reduction ratio. The hidden width is
            ``max(1, channels // reduction)``, so channel counts below ``reduction``
            still get a real (non-empty) bottleneck.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        weights = self.fc(self.avgpool(x))  # one weight per (batch, channel)
        return x * weights


class SpatialAttention(nn.Module):
    """Gate every spatial location of ``x`` with a one-channel sigmoid map.

    The map is predicted by a two-layer 3x3 conv head of hidden width
    ``max(8, in_channels // 2)`` and broadcast over all channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        hidden = max(8, in_channels // 2)
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        gate = self.conv(x)  # single channel, same spatial size as x
        return x * gate


class RCAB(nn.Module):
    """Residual channel-attention block: ``x + 0.1 * CA(conv -> ReLU -> conv)(x)``."""

    def __init__(self, channels, kernel_size=3, reduction=16):
        super().__init__()
        padding = kernel_size // 2  # keeps spatial size for odd kernels
        first_conv = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        activation = nn.ReLU(inplace=True)
        second_conv = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.body = nn.Sequential(first_conv, activation, second_conv)
        self.ca = ChannelAttention(channels, reduction=reduction)
        self.res_scale = 0.1  # damps the residual branch

    def forward(self, x):
        scaled_residual = self.ca(self.body(x)) * self.res_scale
        return x + scaled_residual


class ResidualGroup(nn.Module):
    """A stack of ``n_rcab`` RCAB blocks wrapped in one long skip connection."""

    def __init__(self, channels, n_rcab=4):
        super().__init__()
        self.body = nn.Sequential(*(RCAB(channels) for _ in range(n_rcab)))

    def forward(self, x):
        skip = x
        return self.body(x) + skip


class LearnedUpsampler(nn.Module):
    """Sub-pixel (pixel-shuffle) upsampler.

    A 3x3 conv expands to ``out_channels * scale**2`` channels, pixel shuffle
    rearranges them into a ``scale``-times larger map, and a 3x3 conv + ReLU
    follows. If ``target_size`` is given, the result is then resized bilinearly.
    """

    def __init__(self, in_channels, out_channels, scale=UPSCALE):
        super().__init__()
        self.scale = scale
        expanded_channels = out_channels * scale * scale
        self.proj = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.post = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, target_size=None):
        upsampled = self.post(F.pixel_shuffle(self.proj(x), self.scale))
        if target_size is None:
            return upsampled
        return F.interpolate(upsampled, size=target_size, mode='bilinear', align_corners=False)


class DualEDSRPlus(nn.Module):
    """
    Dual-branch thermal super-resolution network guided by high-res optical input.

    Inputs:  xT - low-res thermal (1 channel); xO - high-res optical (3 channels).
    Output:  1-channel thermal prediction on xO's spatial grid.

    Pipeline:
      1. Each branch runs a conv stem and residual groups.
      2. The thermal features pass through the learned pixel-shuffle upsampler and
         are bilinearly snapped to the optical feature size.
      3. Both branches are concatenated and fused by a 1x1 conv.
      4. Channel attention, then spatial attention, is applied.
      5. Two refinement convs and an output conv produce the prediction.
    """
    def __init__(self, n_resgroups=4, n_rcab=4, n_feats=64, upscale=UPSCALE):
        super().__init__()
        self.upscale = upscale
        self.n_feats = n_feats

        # Stems: thermal (1 ch) and optical (3 ch) -> n_feats
        self.convT_in = nn.Conv2d(1, n_feats, 3, padding=1)
        self.convO_in = nn.Conv2d(3, n_feats, 3, padding=1)

        # Per-branch residual trunks
        self.t_groups = nn.Sequential(*(ResidualGroup(n_feats, n_rcab) for _ in range(n_resgroups)))
        self.o_groups = nn.Sequential(*(ResidualGroup(n_feats, n_rcab) for _ in range(n_resgroups)))

        self.t_upsampler = LearnedUpsampler(n_feats, n_feats, scale=upscale)

        # Fusion head
        self.convFuse = nn.Conv2d(2 * n_feats, n_feats, kernel_size=1)
        self.fuse_ca  = ChannelAttention(n_feats)
        self.fuse_sa  = SpatialAttention(n_feats)
        self.refine = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.convOut = nn.Conv2d(n_feats, 1, kernel_size=3, padding=1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_in, ReLU gain) for every conv weight; zero every conv bias."""
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, a=0, mode="fan_in", nonlinearity="relu")
            if conv.bias is not None:
                nn.init.zeros_(conv.bias)

    def forward(self, xT, xO):
        thermal = self.t_groups(F.relu(self.convT_in(xT)))
        optical = self.o_groups(F.relu(self.convO_in(xO)))

        # Upsample thermal features, then snap them onto the optical feature grid.
        thermal_up = F.interpolate(
            self.t_upsampler(thermal),
            size=(optical.shape[2], optical.shape[3]),
            mode="bilinear",
            align_corners=False,
        )

        fused = F.relu(self.convFuse(torch.cat([thermal_up, optical], dim=1)))
        fused = self.fuse_sa(self.fuse_ca(fused))
        return self.convOut(self.refine(fused))


In [11]:
def partial_load_weights(model: nn.Module, ckpt_path: str, verbose=False):
    """
    Warm-start ``model`` from a checkpoint, e.g. your SSL4EO model.

    Every checkpoint tensor whose key and shape match ``model``'s state dict is
    copied; all other model parameters keep their current values.

    Accepted checkpoint layouts:
      - a raw state dict;
      - a dict wrapping one under "model_state" or "state_dict";
      - a pickled ``nn.Module``.
    A leading "module." key prefix, as written by DataParallel/DDP, is stripped.

    Returns:
        Number of tensors loaded; 0 if the file is missing or contains no state dict.

    NOTE(review): torch.load unpickles the file. Only load trusted checkpoints.
    """
    if not os.path.exists(ckpt_path):
        logger.warning(f"No checkpoint at {ckpt_path}")
        return 0
    src = torch.load(ckpt_path, map_location=DEVICE)
    if isinstance(src, nn.Module):
        src = src.state_dict()
    if isinstance(src, dict):
        for wrapper_key in ("model_state", "state_dict"):
            if isinstance(src.get(wrapper_key), dict):
                src = src[wrapper_key]
                break
    if not isinstance(src, dict):
        logger.warning(f"Checkpoint {ckpt_path} holds no state dict; nothing loaded")
        return 0

    model_dict = model.state_dict()
    loaded = 0
    for k, v in src.items():
        key = k[len("module."):] if isinstance(k, str) and k.startswith("module.") else k
        target = model_dict.get(key)
        if isinstance(v, torch.Tensor) and target is not None and target.shape == v.shape:
            model_dict[key] = v
            loaded += 1
        elif verbose:
            logger.info(f"Skipping {k}; mismatch or missing.")
    model.load_state_dict(model_dict)
    logger.info(f"Partial-loaded {loaded} tensors from {ckpt_path}")
    return loaded

In [12]:
def _split_samples(samples: List[Dict[str, str]]):
    """Ordered 70/15/15 train/val/test split.

    The train split always gets at least one sample when any exist; whatever
    remains after train and val goes to test.
    """
    n = len(samples)
    n_train = min(n, max(1, int(0.7 * n)))
    n_val = min(int(0.15 * n), n - n_train)
    return samples[:n_train], samples[n_train:n_train + n_val], samples[n_train + n_val:]


def _evaluate_thermal_sr(model: nn.Module, loader, desc: str, seed: int = 0):
    """Return the mean (PSNR, SSIM, RMSE) of ``model`` over ``loader``, or None if it yields nothing.

    The eval loaders draw random patches from numpy's global RNG in the main
    process (num_workers=0). The RNG is therefore re-seeded with ``seed`` for this
    pass, so every epoch is scored on the same patches. The previous RNG state is
    restored afterwards so the training stream is unaffected.
    """
    rng_state = np.random.get_state()
    np.random.seed(seed)
    model.eval()
    totals = [0.0, 0.0, 0.0]
    count = 0
    try:
        with torch.no_grad():
            for lr_t, hr_rgb, hr_t in tqdm(loader, desc=desc):
                out = model(lr_t.to(DEVICE), hr_rgb.to(DEVICE))
                metrics = compute_metrics(out.cpu().squeeze().numpy(), hr_t.squeeze().numpy())
                totals = [t + m for t, m in zip(totals, metrics)]
                count += 1
    finally:
        np.random.set_state(rng_state)
    if count == 0:
        return None
    return tuple(t / count for t in totals)


def train_hls_ssl4eo():
    """
    End-to-end HLS thermal super-resolution pipeline:
      1. download HLS L30/S30 granules;
      2. extract thermal + B02/B03/B04 per pair;
      3. resample thermal onto the B02 grid;
      4. split the samples 70/15/15;
      5. build random-patch loaders;
      6. build DualEDSRPlus, warm-started from ``models/ssl4eo_best.pth`` if present
         and resumed from the LAST checkpoint if present;
      7. train with MSE + PHYS_LAMBDA * LR-consistency loss, keeping the
         best-validation-PSNR checkpoint;
      8. report test metrics using the BEST model, or the final model if no BEST exists.

    Raises:
        RuntimeError: if no granules are downloaded or no samples are extracted.
        (Raising instead of sys.exit keeps a notebook kernel usable.)
    """
    # STEP 1: Download HLS data
    logger.info("STEP 1: Download HLS L30/S30 granules")
    l30_files, s30_files = earthdata_download_hls()
    if len(l30_files) == 0 or len(s30_files) == 0:
        logger.error("No HLS granules downloaded; aborting.")
        raise RuntimeError("No HLS granules downloaded; aborting.")

    # STEP 2: Extract bands (thermal + B02/B03/B04)
    logger.info("STEP 2: Extract thermal + optical bands")
    samples = extract_and_save_bands(l30_files, s30_files, OUT_DIR)
    if len(samples) == 0:
        logger.error("No samples extracted; aborting.")
        raise RuntimeError("No samples extracted; aborting.")

    # STEP 3: Resample thermal onto the optical grid
    logger.info("STEP 3: Resample thermal 30m -> optical grid (~10m)")
    for s in samples:
        t30 = s["thermal30"]
        t10_out = t30.replace("_thermal_30m.tif", "_thermal_10m.tif")
        resample_thermal_to_optical(t30, s["b02"], t10_out)  # B02 defines the reference grid
        s["thermal10"] = t10_out

    # STEP 4: Train / Val / Test split (on sample list)
    train_samples, val_samples, test_samples = _split_samples(samples)
    logger.info(
        f"HLS samples: total={len(samples)}, train={len(train_samples)}, "
        f"val={len(val_samples)}, test={len(test_samples)}"
    )

    # STEP 5: Datasets + loaders
    pin = DEVICE.type == "cuda"
    train_ds = HlsPatchDataset(train_samples, hr_patch=HR_PATCH, upscale=UPSCALE, patches_per_image=32)
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, pin_memory=pin
    )

    def make_eval_loader(subset):
        """Batch-size-1 loader over ``subset``, or None when the subset is empty."""
        if not subset:
            return None
        ds = HlsPatchDataset(subset, hr_patch=HR_PATCH, upscale=UPSCALE, patches_per_image=8)
        return DataLoader(ds, batch_size=1, shuffle=False, num_workers=0, pin_memory=pin)

    val_loader = make_eval_loader(val_samples)
    test_loader = make_eval_loader(test_samples)

    # STEP 6: Model, optimizer, losses
    model = DualEDSRPlus(n_resgroups=4, n_rcab=4, n_feats=64, upscale=UPSCALE).to(DEVICE)

    # Optional warm-start from SSL4EO:
    ssl4eo_best = os.path.join(MODELS_DIR, "ssl4eo_best.pth")
    if os.path.exists(ssl4eo_best):
        logger.info(f"Warm-starting from SSL4EO weights: {ssl4eo_best}")
        partial_load_weights(model, ssl4eo_best, verbose=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    mse_loss  = nn.MSELoss()

    BEST_PATH = os.path.join(MODELS_DIR, "hls_ssl4eo_best.pth")
    LAST_PATH = os.path.join(MODELS_DIR, "hls_ssl4eo_last.pth")

    best_val_psnr = -1e9
    start_epoch   = 1

    # Resume if LAST exists
    if os.path.exists(LAST_PATH):
        try:
            ckpt = torch.load(LAST_PATH, map_location=DEVICE)
            if isinstance(ckpt, dict) and "model_state" in ckpt:
                model.load_state_dict(ckpt["model_state"])
                if "optimizer_state" in ckpt:
                    optimizer.load_state_dict(ckpt["optimizer_state"])
                if "epoch" in ckpt:
                    start_epoch = ckpt["epoch"] + 1
                best_val_psnr = ckpt.get("best_val_psnr", -1e9)
                logger.info(
                    f"Resuming from {LAST_PATH}: start_epoch={start_epoch}, best_val_psnr={best_val_psnr:.3f}"
                )
        except Exception as e:
            logger.warning(f"Failed to load {LAST_PATH}: {e}")

    logger.info(f"Training HLS SSL4EO model from epoch {start_epoch} to {NUM_EPOCHS}")

    # STEP 7: Training loop
    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        model.train()  # re-enable after the previous epoch's eval pass
        running = 0.0
        it = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS} (train)")

        for lr_t, hr_rgb, hr_t in pbar:
            lr_t   = lr_t.to(DEVICE)      # (B,1,LR,LR)
            hr_rgb = hr_rgb.to(DEVICE)    # (B,3,HR,HR)
            hr_t   = hr_t.to(DEVICE)      # (B,1,HR,HR)

            optimizer.zero_grad()
            pred_hr = model(lr_t, hr_rgb)

            # data fidelity loss
            loss_fid = mse_loss(pred_hr, hr_t)

            # physics-aware loss: area-average the prediction back to LR and match the input
            pred_lr = F.interpolate(
                pred_hr, size=(lr_t.shape[2], lr_t.shape[3]),
                mode="area"
            )
            loss_phys = mse_loss(pred_lr, lr_t)

            loss = loss_fid + PHYS_LAMBDA * loss_phys
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running += float(loss.item())
            it += 1
            pbar.set_postfix(loss=running / max(1, it))

        avg_train_loss = running / max(1, it)
        logger.info(f"Epoch {epoch} TRAIN loss={avg_train_loss:.6f}")

        # Validation (fixed patches every epoch, so PSNR is comparable across epochs)
        if val_loader is not None:
            val_metrics = _evaluate_thermal_sr(model, val_loader, f"Epoch {epoch} (val)")
            if val_metrics is not None:
                avg_ps, avg_ss, avg_rm = val_metrics
                logger.info(
                    f"Epoch {epoch} VAL: PSNR={avg_ps:.3f} dB, SSIM={avg_ss:.4f}, RMSE={avg_rm:.6f}"
                )
                if avg_ps > best_val_psnr:
                    best_val_psnr = avg_ps
                    torch.save(
                        {
                            "model_state": model.state_dict(),
                            "epoch": epoch,
                            "best_val_psnr": best_val_psnr,
                        },
                        BEST_PATH
                    )
                    logger.info(f"Saved BEST HLS model -> {BEST_PATH} (PSNR={avg_ps:.3f})")

        # Save LAST checkpoint
        torch.save(
            {
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "epoch": epoch,
                "best_val_psnr": best_val_psnr,
            },
            LAST_PATH
        )
        logger.info(f"Saved LAST HLS model -> {LAST_PATH} (epoch={epoch})")

    # STEP 8: Final test evaluation (BEST model if available, else the final model)
    if test_loader is not None:
        if os.path.exists(BEST_PATH):
            ckpt = torch.load(BEST_PATH, map_location=DEVICE)
            model.load_state_dict(ckpt["model_state"])
            logger.info(f"Loaded BEST HLS model from {BEST_PATH} for TEST evaluation.")
        else:
            logger.warning(f"No BEST checkpoint at {BEST_PATH}; evaluating the final model on TEST.")

        test_metrics = _evaluate_thermal_sr(model, test_loader, "TEST (HLS)")
        if test_metrics is not None:
            avg_ps, avg_ss, avg_rm = test_metrics
            logger.info(
                f"HLS TEST SUMMARY (DualEDSRPlus, physics-aware): "
                f"PSNR={avg_ps:.3f} dB, SSIM={avg_ss:.4f}, RMSE={avg_rm:.6f}"
            )

In [None]:
# Entry point: runs the full download -> preprocess -> train -> test pipeline.
# Inside a Jupyter kernel __name__ is also "__main__", so executing this cell starts training.
if __name__ == "__main__":
    train_hls_ssl4eo()

2025-12-08 16:20:03,133 INFO:hls_ssl4eo: STEP 1: Download HLS L30/S30 granules
2025-12-08 16:20:03,133 INFO:hls_ssl4eo: Logging into Earthdata (HLS).
2025-12-08 16:20:20,415 INFO:earthaccess.auth: You're now authenticated with NASA Earthdata Login
2025-12-08 16:20:22,733 INFO:hls_ssl4eo: Searching HLSL30 (Landsat-based, includes thermal band)...
2025-12-08 16:20:24,164 INFO:earthaccess.api: Granules found: 8
2025-12-08 16:20:25,348 INFO:hls_ssl4eo: Found 8 L30 granules. Downloading up to 8...
2025-12-08 16:20:25,349 INFO:earthaccess.store:  Getting 8 granules, approx download size: 1.39 GB
QUEUEING TASKS | : 100%|██████████| 120/120 [00:00<00:00, 13234.37it/s]
PROCESSING TASKS | : 100%|██████████| 120/120 [05:54<00:00,  2.95s/it]
COLLECTING RESULTS | : 100%|██████████| 120/120 [00:00<00:00, 489607.47it/s]
2025-12-08 16:26:19,522 INFO:hls_ssl4eo: Searching HLSS30 (Sentinel-2-based optical B02/B03/B04)...
2025-12-08 16:26:23,005 INFO:earthaccess.api: Granules found: 12
2025-12-08 16:26:2