In [23]:
import os, math, logging, tarfile, random
from glob import glob

import numpy as np
import rasterio
from rasterio.warp import Resampling
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm

In [2]:
SSL4EO_DIR   = "data_ssl4eo"
SCENES_ROOT  = os.path.join(SSL4EO_DIR, "scenes")   # will search recursively for all_bands.tif
MODELS_DIR   = "models"

# kept ONLY as future hooks (not used now, but available for pretraining warm-start if you want)
ECOSTRESS_DIR    = "data_ecostress"
PROCESSED_DIR    = "data_processed"
RAW_DIR          = "data_raw"
ECO_BEST         = os.path.join(MODELS_DIR, "ecostress_pretrained_best.pth")
ECO_LAST         = os.path.join(MODELS_DIR, "ecostress_pretrained_last.pth")

os.makedirs(SSL4EO_DIR,  exist_ok=True)
os.makedirs(SCENES_ROOT, exist_ok=True)
os.makedirs(MODELS_DIR,  exist_ok=True)


### Hyperparameters

* We are doing 2x super-resolution 
* For every batch sample: 
    * HR patch is 128x128
    * LR patch is 64x64 (128/2)
* `PHYS_LAMBDA` controls how strong the physics-aware loss is compared to pixel MSE

In [3]:
UPSCALE        = 2         # 2× or 4× are both allowed by PS; here 2× (you can change to 4)
HR_PATCH       = 128       # HR patch size
LR_PATCH       = HR_PATCH // UPSCALE
BATCH_SIZE     = 4
NUM_EPOCHS     = 50
LEARNING_RATE  = 1e-4
PHYS_LAMBDA    = 0.1    

### What is a scene?
* A scene is a full satellite image.
* The data set has 17,500 training scenes and each scene is huge 
* For example if an image is 3000 x 3000 pixels then the scene is huge

### What is a Patch? 
* A patch is a small crop from that big image 
* For example a patch can be 128 x 128 pixels 
* Instead of training on the whole scene (Which is impossible due to memory) we cut the scene into multiple patches 

### Why not cut all possible patches? 
* Because one scene can yield hundreds or millions of patches 
* And 17,500 training scenes can create billions of patches 

### Our solution 
* Each epoch, we choose only 4 random locations from each scene 
* So in 1 Epoch:
    * `Training Scenes` :  17,500
    * `PATCHES_PER_SCENE` =  4
    * `TOTAL_PATCHES_PER_EPOCH` = 17,500 x 4 = 70,000
* Meaning: We will train on 70,000 patches per each epoch 
* But next epoch, we take different random patches from scene 
* Over many epochs, the model sees diverse area from each scene 


In [4]:
PATCHES_PER_SCENE_TRAIN = 4
PATCHES_PER_SCENE_VAL   = 2
PATCHES_PER_SCENE_TEST  = 2

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s:%(name)s: %(message)s"
)
logger = logging.getLogger("infranova_ssl4eo")


In [6]:
SSL4EO_URL = (
    "https://huggingface.co/datasets/torchgeo/ssl4eo_l_benchmark/resolve/main/"
    "ssl4eo_l_oli_tirs_toa_benchmark.tar.gz?download=true"
)

### Bands we use
* B2,B3,B4 - Optical RGB guidance 
* B10,B11 - Thermal IR channels - we treat them as two possible HR thermal signals 
* We do not use B1, B5–B9 here because:
    * B2–B4 are already strong for edges, textures, landcover boundaries → perfect guidance.
    * B10/B11 are the actual temperature-sensitive channels

In [7]:
BAND_IDX = {
    "B2": 2,    # Blue
    "B3": 3,    # Green
    "B4": 4,    # Red
    "B10": 10,  # Thermal IR 1
    "B11": 11   # Thermal IR 2
}

### Normalization 
* `norm_np(a)`
    * Converts data into `float32`
    * Replaces NaN/Inf with reasonable numbers using `np.nan_to_num`
    * Min-Max normalizes to [0,1] per patch 
    * IF patch is nearly constant `(mx - mn < 1e-6)`, returns a zeros array 

* This ensures:
    * stable gradients
    * consistent PSNR,SSIM meaning


In [8]:
def norm_np(a: np.ndarray) -> np.ndarray:
    """Per-band min-max normalization to [0,1] with NaN/Inf protection."""
    a = np.array(a, dtype=np.float32)
    if np.isnan(a).any() or np.isinf(a).any():
        a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    mn = float(np.nanmin(a))
    mx = float(np.nanmax(a))
    if mx - mn < 1e-6:
        return np.zeros_like(a, dtype=np.float32)
    return ((a - mn) / (mx - mn)).astype(np.float32)

### Metrics
* `compute_metrics(pred,target)`
    * Both inputs are expected in [0,1]
    * Computes:
        1. MSE
        2. PSNR = 10 * log10(1 / MSE)  {Clamped to 100 dB if MSE extremly small}
        3. RMSE = sqrt(MSE)
        4. SSIm using `skimage.ssim(data_range=1.0)`.

Metrics such as PSNR, SSIM, and RMSE are computed in normalized form for consistency. The RMSE we report can be directly converted back to Kelvin by multiplying with the original temperature range. For example, if the normalized RMSE is 0.012 and the temperature range is 50K, then the true thermal error is 0.6 Kelvin. This aligns with real-world expectations and ensures physical meaning.

In [9]:
def compute_metrics(pred: np.ndarray, target: np.ndarray):
    """PSNR / SSIM / RMSE on [0,1] normalized arrays."""
    pred   = np.nan_to_num(pred,   nan=0.0, posinf=1.0, neginf=0.0)
    target = np.nan_to_num(target, nan=0.0, posinf=1.0, neginf=0.0)
    mse = float(np.mean((pred - target) ** 2))
    if not np.isfinite(mse) or mse < 1e-12:
        psnr_val = 100.0
        rmse_val = 0.0
    else:
        psnr_val = 10 * math.log10(1.0 / mse)
        rmse_val = math.sqrt(mse)
    try:
        ssim_val = ssim(target, pred, data_range=1.0)
    except Exception:
        ssim_val = 0.0
    return psnr_val, ssim_val, rmse_val

### Dataset Download and extraction process

In [10]:
def discover_ssl4eo_scenes(root=SCENES_ROOT):
    """
    Recursively find all scenes that contain `all_bands.tif`.
    Your existing extract matches this layout.
    """
    pattern = os.path.join(root, "**", "all_bands.tif")
    scene_files = sorted(glob(pattern, recursive=True))
    logger.info(f"Discovered {len(scene_files)} SSL4EO scenes with all_bands.tif")
    return scene_files

In [11]:
def maybe_download_ssl4eo(archive_dir=SSL4EO_DIR, scenes_root=SCENES_ROOT):
    """
    If no scenes found, optionally download the SSL4EO benchmark archive and extract.
    If you've already downloaded/extracted, this will just return quickly.
    """
    scenes = discover_ssl4eo_scenes(scenes_root)
    if len(scenes) > 0:
        return scenes  # already present

    logger.info("No SSL4EO scenes found. Attempting to download archive...")
    os.makedirs(archive_dir, exist_ok=True)
    archive_path = os.path.join(archive_dir, "ssl4eo_benchmark.tar.gz")

    if not os.path.exists(archive_path):
        import requests
        with requests.get(SSL4EO_URL, stream=True) as r:
            r.raise_for_status()
            total = int(r.headers.get("content-length", 0))
            logger.info(f"Downloading SSL4EO archive ({total/1e9:.2f} GB approx)...")
            with open(archive_path, "wb") as f, tqdm(
                total=total, unit="B", unit_scale=True, desc="ssl4eo_download"
            ) as pbar:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))
    else:
        logger.info(f"Found existing archive: {archive_path}")

    logger.info(f"Extracting {archive_path} -> {scenes_root}")
    os.makedirs(scenes_root, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tf:
        tf.extractall(path=scenes_root)
    logger.info("Extraction finished.")
    return discover_ssl4eo_scenes(scenes_root)

### SSL4EOPatchDataset - The data pipeline

For each training step, the dataset does this:
* "Pick a scene -> pick a random location -> Cut out a 128,128 high resolution patch of thermal + rgb -> downsample thermal to low resolution -> Give the model low resolution thermal +  high resolution RGB -> Predict High resolution thermal (Ground truth)."

So the dataset is a factory that takes big scenes and turns them into millions of tiny supervised examples:
* `scene_files` : list of paths likes `.../ssl4eo/scenes/.../some_scene/all_bands.tif.`
* `hr_patch = 128` : The output resolution you want the network to produce (HR Thermal)
* `upscale = 2` : We pretend LR sensor sees a coarser 64x64 version of the world, and we want to go to 128x128
* `lr_patch = 128 // 2 = 64` : Size of the LR patch
* `patches_per_scene`: How many random patches you want to extract per scene per epoch

This keeps:
1. Training faster
2. Per-scene contribution balanced 
3. And still covers diverse locations across epochs



#### `def __len__(self)` - how many patches per epoch:
Example: 
* len(scene_files) = 25,000 -> Total SSL4EO scenes
* Train split = 17,500 scenes
* PATCHES_PER_SCENE_TRAIN = 4

So for the training dataset:
* Length = 17500 x 4 = 70,000 patches per epoch 
* Batch Size = 4 -> steps per epoch = 70,000 / 4 = 17,500 iteratiions

So even though the scenes are large, each epoch only uses a subsampled set of patches

#### `def _read_bands()` -  reading the needed channels
* all_bands.tif has many bands (B1-B11)
* We don't need all of them 
* We select only 
    * B2,B3,B4 -> RGB 
    * B10,B11 -> Thermal
Output Shape = (5,H,W):
* bands[0]: B2 (blue)
* bands[1]: B3 (green)
* bands[2]: B4 (red)
* bands[3]: B10 (thermal 1)
* bands[4]: B11 (thermal 2)


#### `def __getitem__` - How one patch is created
1. Choose which scene this index belongs to
* scene_idx = idx // self.patches_per_scene
* scene_path = self.scene_files[scene_idx]

__getitem__ – how one patch is created

This is the important flow. Let’s go line by line.

* 4.1 Choose which scene this index belongs to
    * scene_idx = idx // self.patches_per_scene
    * scene_path = self.scene_files[scene_idx]
* 4.2 Read bands
    * bands = self._read_bands(scene_path)   # (5, H, W)
    * rgb = bands[0:3, :, :]                 # (3, H, W)   B2,B3,B4
    * t10 = bands[3, :, :]                   # (H, W)
    * t11 = bands[4, :, :]                   # (H, W)
SO now we have: 
1. rgb = 3 channel high res optical
2. t10,t11 = two thermal channels 

* 4.3 Pick which thermal band to supervise with:
    *   `if random.random() < 0.5:`
    *        `thermal_hr = t10`
    *   `else:`
    *        `thermal_hr = t11`

For this patch:
* Sometimes we use B10 as ground truth
* Sometimes we use B11

Why?
1. Both bands are thermal, buth with slight spectral differences
2. Randomizing teaches the model to be robust and not overfit to just one band
3. Over many epoch, it 'sees' both


* 4.4 Normalize optical and thermal
    * rgb_n = np.stack([norm_np(rgb[c]) for c in range(3)], axis=0) # (3, H, W)
    * thr_n = norm_np(thermal_hr)                                   # (H, W)
Each band is mapped to [0,1] using norm_np:
1. removes NaN/Infs
2. min-> 0, max -> 1
This stablizes training and makes metrics well defined 

* 4.5 Ensure scene big enough - padding if needed
    * H, W = thr_n.shape
    * if H < self.hr_patch or W < self.hr_patch:
        * pad_y = max(0, self.hr_patch - H)
        * pad_x = max(0, self.hr_patch - W)
        * thr_n = np.pad(thr_n, ((0, pad_y), (0, pad_x)), mode='reflect')
        * rgb_n = np.pad(rgb_n, ((0, 0), (0, pad_y), (0, pad_x)), mode='reflect')
        * H, W = thr_n.shape
If the scene is smaller than 128×128:
* We pad using reflection (mirror the edges).
* This avoids black borders and keeps patterns natural.


#### `Synthesizing the LR thermal (fake low-res sensor)`
Now we simulate what a coarser thermal sensor would see. 
* H_lr, W_lr = H // self.upscale, W // self.upscale
* lr_full = F.interpolate(
    * torch.from_numpy(thr_n).unsqueeze(0).unsqueeze(0).float(),  # (1,1,H,W)
    * size=(H_lr, W_lr),
    * mode="bilinear",
    * align_corners=False
* ).squeeze().numpy()  # (H_lr, W_lr)

1. Input: HR thermal thr_n with shape (H, W) → add batch + channel → (1, 1, H, W).
2. Downsample to (H/2, W/2) if UPSCALE=2.
3. Output: lr_full is the synthetic low-res thermal.

So: 
* thr_n = what a hypothetical super - high resolution thermal sensor would measure
* lr_full = what the real lower - res satellite thermal band would see
We teach the model to go from `lr_full` back up to `thr_n`, guided by high-res RGB

#### `Random crop – aligned HR + LR patches`
Now we select a random HR patch and align its LR counterparts
* max_y = H - self.hr_patch
* max_x = W - self.hr_patch
* if max_y <= 0 or max_x <= 0:
    * y = 0
    * x = 0
* else:
    * y = np.random.randint(0, max_y + 1)
    * x = np.random.randint(0, max_x + 1)
Pick a valid top-left corner (y, x) so that [y:y+128, x:x+128] fits.

* 6.1 Crop HR thermal and RGB
    * hr_t_patch   = thr_n[y:y + self.hr_patch, x:x + self.hr_patch]   # (128,128)
    * hr_rgb_patch = rgb_n[:, y:y + self.hr_patch, x:x + self.hr_patch]# (3,128,128)
So: 
1. hr_t_patch: high-res target thermal patch.
2. hr_rgb_patch: aligned high-res optical patch.

* 6.2 Crop alogned LR thermal patch
    * ly, lx = y // self.upscale, x // self.upscale
    * lr_t_patch = lr_full[ly:ly + self.lr_patch, lx:lx + self.lr_patch]  # (64,64)
if upscale = 2
1. 1 LR picel = 2 HR pixel
2. So we divide y and x by 2
Now we have a fully aligned triplet
1. LR thermal = `lr_t_patch` (64x64)
2. HR RGB = `hr_rgb_patch` (3,128,128)
3. HR thermal = `hr_t_patch` (128x128)
This is exactly the relationship the model is supposed to learn 

#### `Safety shape checks`
* There's some shape cleaning.
* This just protects against weird border cases (like tiny off by one due to divide/round)

#### `Convert to tensors for the model`
* lr_t  = torch.from_numpy(lr_t_patch).unsqueeze(0).float()      # (1, LR, LR)
* hr_rgb = torch.from_numpy(hr_rgb_patch).float()                # (3, HR, HR)
* hr_t   = torch.from_numpy(hr_t_patch).unsqueeze(0).float()     # (1, HR, HR)
* return lr_t, hr_rgb, hr_t

Final shapes
* lr_t: (1, 64, 64) → low-res thermal input.
* hr_rgb: (3, 128, 128) → high-res optical guidance.
* hr_t: (1, 128, 128) → ground truth high-res thermal.
When we wrap this in a DataLoader with batch_size=BATCH_SIZE, PyTorch adds the batch dimension:
* lr_t: (B, 1, 64, 64)
* hr_rgb: (B, 3, 128, 128)
* hr_t: (B, 1, 128, 128)


In [12]:
class SSL4EOPatchDataset(Dataset):
    """
    For each scene:
      - Loads all_bands.tif
      - Uses B2/B3/B4 as HR optical guidance (3×channels)
      - Uses either B10 or B11 (chosen randomly) as HR thermal "truth"
      - Synthesizes LR thermal by downsampling with factor UPSCALE
      - Returns aligned patches: lr_thermal (1, LR, LR), hr_rgb (3, HR, HR), hr_thermal (1, HR, HR)
    The length is (#scenes * patches_per_scene) so each epoch samples a fixed number
    of random patches per scene, instead of exploding to millions of iterations.
    """
    def __init__(self, scene_files, hr_patch=HR_PATCH, upscale=UPSCALE,
                 patches_per_scene=4, mode="train"):
        super().__init__()
        self.scene_files = list(scene_files)
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // upscale
        self.upscale = upscale
        self.patches_per_scene = patches_per_scene
        self.mode = mode

    def __len__(self):
        return len(self.scene_files) * self.patches_per_scene

    def _read_bands(self, scene_path):
        # rasterio uses 1-based band indices
        with rasterio.open(scene_path) as src:
            # order = [B2,B3,B4,B10,B11]
            bands = src.read([
                BAND_IDX["B2"], BAND_IDX["B3"], BAND_IDX["B4"],
                BAND_IDX["B10"], BAND_IDX["B11"]
            ]).astype(np.float32)  # shape (5, H, W)
        return bands

    def __getitem__(self, idx):
        # map global index -> which scene
        scene_idx = idx // self.patches_per_scene
        scene_path = self.scene_files[scene_idx]

        bands = self._read_bands(scene_path)   # (5, H, W)
        rgb = bands[0:3, :, :]                 # B2,B3,B4
        t10 = bands[3, :, :]                   # B10
        t11 = bands[4, :, :]                   # B11

        # Pick one thermal band per patch (so model learns both across epochs)
        if random.random() < 0.5:
            thermal_hr = t10
        else:
            thermal_hr = t11

        # Normalize optical & thermal
        rgb_n = np.stack([norm_np(rgb[c]) for c in range(3)], axis=0)       # (3,H,W)
        thr_n = norm_np(thermal_hr)                                        # (H,W)

        H, W = thr_n.shape
        # if scene smaller than patch, reflect-pad
        if H < self.hr_patch or W < self.hr_patch:
            pad_y = max(0, self.hr_patch - H)
            pad_x = max(0, self.hr_patch - W)
            thr_n = np.pad(thr_n, ((0, pad_y), (0, pad_x)), mode='reflect')
            rgb_n = np.pad(rgb_n, ((0, 0), (0, pad_y), (0, pad_x)), mode='reflect')
            H, W = thr_n.shape

        # synthesize LR by downsampling HR thermal
        H_lr, W_lr = H // self.upscale, W // self.upscale
        lr_full = F.interpolate(
            torch.from_numpy(thr_n).unsqueeze(0).unsqueeze(0).float(),  # (1,1,H,W)
            size=(H_lr, W_lr),
            mode="bilinear",
            align_corners=False
        ).squeeze().numpy()  # (H_lr, W_lr)

        # random crop on HR, with aligned LR crop
        max_y = H - self.hr_patch
        max_x = W - self.hr_patch
        if max_y <= 0 or max_x <= 0:
            y = 0
            x = 0
        else:
            y = np.random.randint(0, max_y + 1)
            x = np.random.randint(0, max_x + 1)

        hr_t_patch   = thr_n[:, y:y + self.hr_patch,] if thr_n.ndim == 3 else thr_n[y:y + self.hr_patch,
                                                                                    x:x + self.hr_patch]
        hr_rgb_patch = rgb_n[:, y:y + self.hr_patch, x:x + self.hr_patch]

        ly, lx = y // self.upscale, x // self.upscale
        lr_t_patch = lr_full[ly:ly + self.lr_patch, lx:lx + self.lr_patch]

        # ensure correct shapes
        hr_t_patch = hr_t_patch if hr_t_patch.ndim == 2 else hr_t_patch[0]
        if hr_t_patch.shape != (self.hr_patch, self.hr_patch):
            # safety crop
            hr_t_patch = hr_t_patch[:self.hr_patch, :self.hr_patch]
            hr_rgb_patch = hr_rgb_patch[:, :self.hr_patch, :self.hr_patch]
        if lr_t_patch.shape != (self.lr_patch, self.lr_patch):
            lr_t_patch = lr_t_patch[:self.lr_patch, :self.lr_patch]

        # to tensors
        lr_t  = torch.from_numpy(lr_t_patch).unsqueeze(0).float()      # (1, LR, LR)
        hr_rgb = torch.from_numpy(hr_rgb_patch).float()                # (3, HR, HR)
        hr_t   = torch.from_numpy(hr_t_patch).unsqueeze(0).float()     # (1, HR, HR)

        return lr_t, hr_rgb, hr_t

In [13]:
class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        y = self.avgpool(x)
        y = self.fc(y)
        return x * y

In [14]:
class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, max(8, in_channels//2), kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(max(8, in_channels//2), 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )
    def forward(self, x):
        att = self.conv(x)
        return x * att

In [15]:
class RCAB(nn.Module):
    def __init__(self, channels, kernel_size=3, reduction=16):
        super().__init__()
        pad = kernel_size // 2
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size, padding=pad),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size, padding=pad)
        )
        self.ca = ChannelAttention(channels, reduction=reduction)
        self.res_scale = 0.1
    def forward(self, x):
        res = self.body(x)
        res = self.ca(res)
        return x + res * self.res_scale

In [16]:
class ResidualGroup(nn.Module):
    def __init__(self, channels, n_rcab=4):
        super().__init__()
        layers = [RCAB(channels) for _ in range(n_rcab)]
        self.body = nn.Sequential(*layers)
    def forward(self, x):
        return self.body(x) + x

In [17]:
class LearnedUpsampler(nn.Module):
    def __init__(self, in_channels, out_channels, scale=UPSCALE):
        super().__init__()
        self.scale = scale
        self.proj = nn.Conv2d(in_channels, out_channels * (scale*scale), kernel_size=3, padding=1)
        self.post = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x, target_size=None):
        x = self.proj(x)
        x = F.pixel_shuffle(x, self.scale)
        x = self.post(x)
        if target_size is not None:
            x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)
        return x

### DualEDSRPlus - The Model Architecture

Inputs:
* xT: Low resolution thermal map, shape (B,1,H_LR,W_LR)
* xO: high resoluton optical RGB image, shape (B,3,H_HR,W_HR) where H_HR = upscale x H_LR

Output: 
* out: High-resolution thermal map, shape (B,1,H_HR,W_HR)

Conceptually:
1. Extract features from thermal and optical seperately
2. Upsample thermal features to the optical features
3. Fuse Thermal + Optical features with channel and spatial attention
4. Refine fused featured and predict a high-res thermal map

It is called Dual EDSRPlus because:
* Dual: Represents the dual stream of optical and thermal
* EDSR: Heavy residual blocks like the EDSR super resolution architecture
* PLUS: Extra fusion and attention modules

#### 1.The constructor __init__
* `n_resgroups`: How many ResidualGroup blocks per stream
* `n_rcab`: how many RCAB blocks insideeach group
* `n_feats`: Feature channels in the hidden layer
* `upscale`: 2 for 2x SR, 4 for 4x etc
These hyperparameters control the deapth and width of the network

##### 1.1 Input Convolutions
* convT_in: First layer for thermal
    * Input: 1 Channel (Thermal)
    * Output: 64 channels (by default)
    * Kernel: 3x3, padding 1 -> preserves spatial size
* convO_int: first layer of optical
    * Input: 3 Channels (RGB)
    * Output: 64 Channels
    * Kernel: Same 3x3
Intuition: Convert raw images into a shared feature space of 64 channels , but with seperate weights for thermal and optical

##### 1.2 Deep feature extraction: Residual groups
* Thermal stream: t_groups
* Optical stream: o_groups
Each `ResidualGroup` contains multiple RCAB's and a group skip connection 
* RCAB = Conv -> ReLU -> Conv + ChannelAttention + Residual
* Several RCAB's in series = deeper representation
* Group-level skip = stabilizes training, preserves low level info
Effect:
* The thermal stream learns rich temperature - related patterns
* The optical stream learns edges,textrues,land covers boundaries etc


##### 1.3 Learnable upsamples for thermal
* Takes low resolution thermal features (B,64,H_LR,W_LR) and upsamples them to (B,64,H_HR,W_HR) using `LearnedUpsampler`
Intuition:
* Instead of just doing bilinear upsampling on the image, we upsample in feature space with learnable parameters,more powerful and sharp


##### 1.4 Fusion + attention modules
After upsampling thermal features to match optical resolution, we will:
1. Concatenate: `[thermal_up,optical_features]` Along with channels _> 64+64 = 128 channels
2. ConvFuse: 1x1 conv to mix and comparess 128 -> 64 channels
3. fuse_ca: ChannelAttention -> Decide which channels are important globally
4. fuse_sa: SpatialAttention -> Decide which pixels/locations are important

ChannelAttenion:
* Computes a channel descriptor via global avergae pooling
* Passes it through small MLP (conv 1x1) and sigmoid to get weights in [0,1]
* Mutiplies each channel by its importance

SpatialAttention:
* Convs over the feature map to produce a 1-channel importance mask [H,w] in [0,1]
* Multplies the whole feature map by this mask -> Focuses on hotspots,boundaries

##### 1.5 Refinement + output
* `refine`: two extra 3x3 convs with ReLU to clean up fused representation
* `convOut`: final 3x3 conv to map features (64 channels) -> 1 channel

So final output has shape (B,1,H_HR,W_HR)


In [18]:
class DualEDSRPlus(nn.Module):
    def __init__(self, n_resgroups=4, n_rcab=4, n_feats=64, upscale=UPSCALE):
        super().__init__()
        self.upscale = upscale
        self.n_feats = n_feats

        self.convT_in = nn.Conv2d(1, n_feats, 3, padding=1)
        self.convO_in = nn.Conv2d(3, n_feats, 3, padding=1)

        self.t_groups = nn.Sequential(*[ResidualGroup(n_feats, n_rcab) for _ in range(n_resgroups)])
        self.o_groups = nn.Sequential(*[ResidualGroup(n_feats, n_rcab) for _ in range(n_resgroups)])

        self.t_upsampler = LearnedUpsampler(n_feats, n_feats, scale=upscale)

        self.convFuse = nn.Conv2d(2 * n_feats, n_feats, kernel_size=1)
        self.fuse_ca  = ChannelAttention(n_feats)
        self.fuse_sa  = SpatialAttention(n_feats)

        self.refine = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.convOut = nn.Conv2d(n_feats, 1, kernel_size=3, padding=1)

        # Kaiming init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, xT, xO):
        fT = F.relu(self.convT_in(xT))
        fO = F.relu(self.convO_in(xO))

        fT = self.t_groups(fT)
        fO = self.o_groups(fO)

        fT_up_raw = self.t_upsampler(fT)
        target_hw = (fO.shape[2], fO.shape[3])
        fT_up = F.interpolate(fT_up_raw, size=target_hw, mode="bilinear", align_corners=False)

        f = torch.cat([fT_up, fO], dim=1)
        f = F.relu(self.convFuse(f))
        f = self.fuse_ca(f)
        f = self.fuse_sa(f)
        f = self.refine(f)
        out = self.convOut(f)
        return out


In [19]:
def partial_load_weights(model, ckpt_path, verbose=False):
    if not os.path.exists(ckpt_path):
        logger.warning(f"No checkpoint at {ckpt_path}")
        return 0
    src = torch.load(ckpt_path, map_location=DEVICE)
    if isinstance(src, dict) and "model_state" in src:
        src = src["model_state"]
    model_dict = model.state_dict()
    loaded = 0
    for k, v in src.items():
        if k in model_dict and model_dict[k].shape == v.shape:
            model_dict[k] = v
            loaded += 1
        elif verbose:
            logger.info(f"Skipping {k}; mismatch or missing.")
    model.load_state_dict(model_dict)
    logger.info(f"Partial-loaded {loaded} tensors from {ckpt_path}")
    return loaded

### Physics-aware loss - respecting temperature consistency

* In training:
    * pred_hr = model(lr_t,hr_rgb)
    * loss_fid = mse_loss(pred_hr, hr_t)
This is the standard super resolution loss: predicted HR vs groud truth HR

* Now the physics-aware part: 
    * pred_lr = F.interpolate(
    * pred_hr, size=(lr_t.shape[2], lr_t.shape[3]),
    * mode="area"  # average pooling style
    * )
    * loss_phys = mse_loss(pred_lr, lr_t)
    * loss = loss_fid + PHYS_LAMBDA * loss_phys

* What is happening?
    1. WE take the predicted HR thermal map, and downsample it back to LR using mode = 'area', which is like block averaging
    2. That gives an LR map representing average energy/temperature over coarse pixel
    3. We force it to be close to the original LR thermal input
* Why is this physics-aware?
    * Real sensors (e.g TIRS) measure average radiance/temperature over each coarse pixel footprint
    * If we "invent" crazy fine scae structure that changes the average, we would be violating physics
    * This loss says: ` Whatever details you invent at high resolution must still average back to the same coarse thermal value he satellite saw`

Our optical guidance can suggest edges and shapes, but the coarse thermal energy constraint keeps temperature in check


### Training,Validation, test flow

* Scene - level split (no leakage of tiles from same scene across sets)
* 70% train, 15% test, 15% validation

Each epoch: 
* For each train scene: 4 random patches
* For each val scene: 2 patches
* For each test scene: 2 patches
So epochs are manageable but still diverse

1. Training loop:
* Per epoch:
    * Iterate `train_loader`.
    * For each batch:
        1. forawrd -> pred_hr
        2. compute loss_fid and loss_phys.
        3. Back propogration, gradient clipping, optimizer step
2. Validation phase:
    * No gradients
    * For each val batch:
        1. forward -> `out`
        2. convert to Numpy, compute PSNR/SSIM/RMSE
    * Average metrics across validation patches
    * If PSNR improved -> save best model
3. Save last model every epoch


#### Training evaluation
After training: 
* Reload best model
* Run through test_loader exactly like val
* Log final PSNR/SSIM/RMSE -> this is what you can compare against literature

In [21]:
def train_ssl4eo(num_epochs=NUM_EPOCHS):
    # 1) Find / download scenes
    scenes = discover_ssl4eo_scenes()
    if len(scenes) == 0:
        scenes = maybe_download_ssl4eo()
    if len(scenes) == 0:
        logger.error("No SSL4EO scenes available; aborting.")
        return

    # 2) Train/val/test split
    random.seed(42)
    np.random.seed(42)
    scenes_shuffled = scenes.copy()
    random.shuffle(scenes_shuffled)

    n_total = len(scenes_shuffled)
    n_train = int(0.7 * n_total)
    n_val   = int(0.15 * n_total)
    n_test  = n_total - n_train - n_val

    train_scenes = scenes_shuffled[:n_train]
    val_scenes   = scenes_shuffled[n_train:n_train + n_val]
    test_scenes  = scenes_shuffled[n_train + n_val:]

    logger.info(
        f"Starting SSL4EO training with DualEDSRPlus + physics-aware loss.\n"
        f"Scene split -> Train: {len(train_scenes)}, Val: {len(val_scenes)}, Test: {len(test_scenes)}"
    )

    # 3) Datasets / loaders
    train_ds = SSL4EOPatchDataset(
        train_scenes, hr_patch=HR_PATCH, upscale=UPSCALE,
        patches_per_scene=PATCHES_PER_SCENE_TRAIN, mode="train"
    )
    val_ds = SSL4EOPatchDataset(
        val_scenes, hr_patch=HR_PATCH, upscale=UPSCALE,
        patches_per_scene=PATCHES_PER_SCENE_VAL, mode="val"
    )
    test_ds = SSL4EOPatchDataset(
        test_scenes, hr_patch=HR_PATCH, upscale=UPSCALE,
        patches_per_scene=PATCHES_PER_SCENE_TEST, mode="test"
    )

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, pin_memory=(DEVICE.type == "cuda")
    )
    val_loader = DataLoader(
        val_ds, batch_size=1, shuffle=False,
        num_workers=0, pin_memory=(DEVICE.type == "cuda")
    )
    test_loader = DataLoader(
        test_ds, batch_size=1, shuffle=False,
        num_workers=0, pin_memory=(DEVICE.type == "cuda")
    )

    # 4) Model, optimizer, loss
    model = DualEDSRPlus(n_resgroups=4, n_rcab=4, n_feats=64, upscale=UPSCALE).to(DEVICE)

    # optional warm-start from ECOSTRESS later:
    # partial_load_weights(model, ECO_BEST, verbose=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    mse_loss  = nn.MSELoss()

    BEST_PATH = os.path.join(MODELS_DIR, "ssl4eo_best.pth")
    LAST_PATH = os.path.join(MODELS_DIR, "ssl4eo_last.pth")

    # --------- RESUME LOGIC ---------
    best_val_psnr = -1e9
    start_epoch = 1

    if os.path.exists(LAST_PATH):
        try:
            ckpt = torch.load(LAST_PATH, map_location=DEVICE)
            if isinstance(ckpt, dict) and "model_state" in ckpt:
                model.load_state_dict(ckpt["model_state"])
                if "optimizer_state" in ckpt:
                    optimizer.load_state_dict(ckpt["optimizer_state"])
                if "epoch" in ckpt:
                    start_epoch = ckpt["epoch"] + 1
                best_val_psnr = ckpt.get("best_val_psnr", -1e9)
                logger.info(
                    f"Resuming from checkpoint {LAST_PATH}: "
                    f"start_epoch={start_epoch}, best_val_psnr={best_val_psnr:.3f}"
                )
            else:
                logger.warning("Checkpoint format unexpected; starting from scratch.")
        except Exception as e:
            logger.warning(f"Failed to load {LAST_PATH} ({e}); starting from scratch.")
    else:
        logger.info("No LAST checkpoint found; starting from epoch 1.")

    logger.info(f"Training epochs: {start_epoch} -> {num_epochs}")
    # --------------------------------

    # 5) Training loop
    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        running = 0.0
        it = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} (train)")
        for lr_t, hr_rgb, hr_t in pbar:
            lr_t   = lr_t.to(DEVICE)      # (B,1,LR,LR)
            hr_rgb = hr_rgb.to(DEVICE)    # (B,3,HR,HR)
            hr_t   = hr_t.to(DEVICE)      # (B,1,HR,HR)

            optimizer.zero_grad()
            pred_hr = model(lr_t, hr_rgb)

            # --- Data fidelity loss (per-pixel MSE at HR) ---
            loss_fid = mse_loss(pred_hr, hr_t)

            # --- Physics-aware loss: coarse thermal consistency ---
            pred_lr = F.interpolate(
                pred_hr, size=(lr_t.shape[2], lr_t.shape[3]),
                mode="area"
            )
            loss_phys = mse_loss(pred_lr, lr_t)

            loss = loss_fid + PHYS_LAMBDA * loss_phys
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running += float(loss.item())
            it += 1
            pbar.set_postfix(loss=running / max(1, it))

        avg_train_loss = running / max(1, it)
        logger.info(f"Epoch {epoch} TRAIN loss={avg_train_loss:.6f}")

        # 6) Validation
        model.eval()
        ps_sum = ss_sum = rm_sum = 0.0
        cnt = 0
        with torch.no_grad():
            for lr_t, hr_rgb, hr_t in tqdm(val_loader, desc=f"Epoch {epoch} (val)"):
                lr_t   = lr_t.to(DEVICE)
                hr_rgb = hr_rgb.to(DEVICE)
                hr_t   = hr_t.to(DEVICE)

                out = model(lr_t, hr_rgb)
                pred = out.cpu().squeeze().numpy()
                tgt  = hr_t.cpu().squeeze().numpy()
                ps, ss, rm = compute_metrics(pred, tgt)
                ps_sum += ps; ss_sum += ss; rm_sum += rm; cnt += 1
        if cnt > 0:
            avg_ps = ps_sum / cnt
            avg_ss = ss_sum / cnt
            avg_rm = rm_sum / cnt
            logger.info(
                f"Epoch {epoch} VAL PSNR={avg_ps:.3f} dB, SSIM={avg_ss:.4f}, RMSE={avg_rm:.6f}"
            )
            if avg_ps > best_val_psnr:
                best_val_psnr = avg_ps
                torch.save(
                    {
                        "model_state": model.state_dict(),
                        "epoch": epoch,
                        "best_val_psnr": best_val_psnr,
                    },
                    BEST_PATH
                )
                logger.info(f"Saved BEST model -> {BEST_PATH} (PSNR={avg_ps:.3f})")

        # always save last (FULL checkpoint for resume)
        torch.save(
            {
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "epoch": epoch,
                "best_val_psnr": best_val_psnr,
            },
            LAST_PATH
        )
        logger.info(f"Saved LAST model -> {LAST_PATH} (epoch={epoch})")

    # 7) Final test evaluation with best model
    if os.path.exists(BEST_PATH):
        ckpt = torch.load(BEST_PATH, map_location=DEVICE)
        model.load_state_dict(ckpt["model_state"])
        logger.info(f"Loaded BEST model from {BEST_PATH} for TEST evaluation.")

    model.eval()
    ps_sum = ss_sum = rm_sum = 0.0
    cnt = 0
    with torch.no_grad():
        for lr_t, hr_rgb, hr_t in tqdm(test_loader, desc="TEST"):
            lr_t   = lr_t.to(DEVICE)
            hr_rgb = hr_rgb.to(DEVICE)
            hr_t   = hr_t.to(DEVICE)

            out = model(lr_t, hr_rgb)
            pred = out.cpu().squeeze().numpy()
            tgt  = hr_t.cpu().squeeze().numpy()
            ps, ss, rm = compute_metrics(pred, tgt)
            ps_sum += ps; ss_sum += ss; rm_sum += rm; cnt += 1
    if cnt > 0:
        avg_ps = ps_sum / cnt
        avg_ss = ss_sum / cnt
        avg_rm = rm_sum / cnt
        logger.info(
            f"TEST SUMMARY (SSL4EO, RGB-guided, physics-aware): "
            f"PSNR={avg_ps:.3f} dB, SSIM={avg_ss:.4f}, RMSE={avg_rm:.6f}"
        )
    return model


In [None]:
if __name__ == "__main__":
    train_ssl4eo()

2025-12-08 07:13:23,007 INFO:infranova_ssl4eo: Discovered 25000 SSL4EO scenes with all_bands.tif
2025-12-08 07:13:23,015 INFO:infranova_ssl4eo: Starting SSL4EO training with DualEDSRPlus + physics-aware loss.
Scene split -> Train: 17500, Val: 3750, Test: 3750
2025-12-08 07:13:26,882 INFO:infranova_ssl4eo: Resuming from checkpoint models\ssl4eo_last.pth: start_epoch=9, best_val_psnr=39.370
2025-12-08 07:13:26,883 INFO:infranova_ssl4eo: Training epochs: 9 -> 50
Epoch 9/50 (train): 100%|██████████| 17500/17500 [1:23:51<00:00,  3.48it/s, loss=0.000171]
2025-12-08 08:37:18,165 INFO:infranova_ssl4eo: Epoch 9 TRAIN loss=0.000171
Epoch 9 (val): 100%|██████████| 7500/7500 [05:27<00:00, 22.93it/s]
2025-12-08 08:42:45,201 INFO:infranova_ssl4eo: Epoch 9 VAL PSNR=39.584 dB, SSIM=0.9759, RMSE=0.011122
2025-12-08 08:42:45,260 INFO:infranova_ssl4eo: Saved BEST model -> models\ssl4eo_best.pth (PSNR=39.584)
2025-12-08 08:42:45,409 INFO:infranova_ssl4eo: Saved LAST model -> models\ssl4eo_last.pth (epoch=

KeyboardInterrupt: 