In [None]:
import os
import sys
import math
import logging
from glob import glob
from tqdm import tqdm
import earthaccess
import rasterio
from rasterio.enums import Resampling
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# --- Filesystem layout -------------------------------------------------------
# OUT_DIR receives extracted/resampled rasters plus model files; RAW_DIR is the
# download target. Both directories are created eagerly when this cell runs.
OUT_DIR = "data_processed"
RAW_DIR = "data_raw"
for _dir in (OUT_DIR, RAW_DIR):
    os.makedirs(_dir, exist_ok=True)

# --- Data selection ----------------------------------------------------------
# BBOX is forwarded as `bounding_box` and DATE_RANGE as `temporal` to
# earthaccess.search_data.
# NOTE(review): BBOX is presumably (west, south, east, north) in degrees -- confirm.
BBOX = (78.3, 17.2, 78.7, 17.6)
DATE_RANGE = ("2023-01-01", "2023-01-31")
MAX_GRANULES = 4  # per-collection cap on downloaded granules

# --- Patch geometry and optimisation settings --------------------------------
UPSCALE = 3
HR_PATCH = 128
LR_PATCH = HR_PATCH // UPSCALE  # floor division: 128 // 3 == 42
BATCH_SIZE = 4
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4

# --- Runtime -----------------------------------------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mvp_sr")

In [3]:
def earthdata_download_hls(bbox, date_range, max_granules=4, out_dir=RAW_DIR):
    """Search and download HLSL30 and HLSS30 granules from NASA Earthdata.

    Args:
        bbox: Bounding box forwarded as ``bounding_box`` to
            ``earthaccess.search_data``.
        date_range: Temporal filter forwarded as ``temporal``.
        max_granules: Maximum number of granules to download per collection.
            Values below zero are treated as zero.
        out_dir: Directory the files are downloaded into.

    Returns:
        ``(l30_files, s30_files)``: two lists holding whatever
        ``earthaccess.download`` returned for each collection (empty when
        nothing was found or nothing was returned).
        NOTE(review): the captured run log shows 60 files queued for 4
        granules, so these lists appear to contain one entry per band file
        rather than one per granule -- confirm against the earthaccess docs.
    """
    logger.info("Logging into Earthdata. You will be prompted for credentials if needed.")
    auth = earthaccess.login()  # interactive
    # Surface a failed login early instead of letting the downloads fail later.
    if not getattr(auth, "authenticated", True):
        logger.warning("Earthdata login did not report an authenticated session; downloads may fail.")

    limit = max(0, max_granules)  # a negative slice bound would silently drop granules from the end

    def _search_and_download(short_name, label, search_message):
        """Run one search + download cycle for a single collection."""
        logger.info(search_message)
        results = earthaccess.search_data(
            short_name=short_name, temporal=date_range, bounding_box=bbox
        )
        logger.info(f"Found {len(results)} {label} granules. Downloading up to {max_granules}...")
        selected = results[:limit]
        if not selected:
            # Nothing to fetch: skip the download call entirely.
            return []
        downloads = earthaccess.download(selected, out_dir)
        return list(downloads) if downloads is not None else []

    l30_files = _search_and_download(
        "HLSL30", "L30",
        "Searching HLSL30 (Landsat-based) -- includes thermal bands (B10)...",
    )
    s30_files = _search_and_download(
        "HLSS30", "S30",
        "Searching HLSS30 (Sentinel-2-based) -- optical B02/B03/B04...",
    )
    return l30_files, s30_files

In [4]:
def _split_granule_and_band(path):
    """Split ``<granule>.<BAND>.tif`` into ``(granule, BAND)``; ``(stem, None)`` if there is no dotted suffix."""
    stem = os.path.splitext(os.path.basename(str(path)))[0]
    granule, dot, band = stem.rpartition(".")
    if not dot:
        return stem, None
    return granule, band.upper()


def _group_by_granule(paths):
    """Group file paths into ``[(granule, {band: path}), ...]`` in first-seen granule order."""
    groups = {}
    for path in paths:
        granule, band = _split_granule_and_band(path)
        groups.setdefault(granule, {})[band] = path
    return list(groups.items())


def _find_band(granule, bands, code):
    """Return the path of band ``code`` for this granule: from the given files, else a sibling on disk, else None."""
    if code in bands:
        return bands[code]
    for known in bands.values():
        sibling = os.path.join(os.path.dirname(str(known)), f"{granule}.{code}.tif")
        if os.path.exists(sibling):
            return sibling
    return None


def _read_first_band(path):
    """Read band 1 of a raster and return ``(array, profile)``."""
    with rasterio.open(path) as src:
        return src.read(1), src.profile


def extract_and_save_bands(l30_files, s30_files, out_dir=OUT_DIR):
    """Extract one thermal band and three optical bands per granule pair and save them as GeoTIFFs.

    The input lists may contain many files per granule (names of the form
    ``<granule>.<BAND>.tif``, as seen in the download log). Files are grouped
    by granule first; for each L30 granule the ``B10`` file is used as the
    thermal band and for each S30 granule the ``B02``/``B03``/``B04`` files
    are used as the optical bands. Band files are looked up only within the
    same granule, never across the whole (shared) download directory.

    Backward-compatible fallback: when a granule consists of a single file
    and the named band file cannot be found, band 1 of that file is read as
    thermal (L30), or bands 1-3 as the optical bands (S30, needs >= 3 bands).

    NOTE(review): L30 and S30 granules are paired by position in first-seen
    order, as before; no check is made that a pair shares tile or date.

    Args:
        l30_files: Paths of downloaded HLSL30 files.
        s30_files: Paths of downloaded HLSS30 files.
        out_dir: Directory the per-sample GeoTIFFs are written to.

    Returns:
        List of dicts with keys ``"thermal30"``, ``"b02"``, ``"b03"``,
        ``"b04"`` pointing at the written files. Pairs that cannot be read
        are logged and skipped.
    """
    l30_granules = _group_by_granule(l30_files)
    s30_granules = _group_by_granule(s30_files)

    saved_pairs = []
    n = min(len(l30_granules), len(s30_granules))
    logger.info(f"Extracting bands for {n} matched pairs.")
    for i in range(n):
        l30_id, l30_bands = l30_granules[i]
        s30_id, s30_bands = s30_granules[i]

        # --- thermal band from the L30 granule ---
        thermal_src = _find_band(l30_id, l30_bands, "B10")
        if thermal_src is None and len(l30_bands) == 1:
            # Legacy input: a single file per granule, read its first band.
            thermal_src = next(iter(l30_bands.values()))
        if thermal_src is None:
            logger.error(f"Unable to read L30 file {l30_id}: no B10 band file found")
            continue
        try:
            thermal, prof = _read_first_band(thermal_src)
        except Exception as e:
            logger.error(f"Unable to read L30 file {thermal_src}: {e}")
            continue

        # --- optical bands from the S30 granule ---
        b02 = _find_band(s30_id, s30_bands, "B02")
        b03 = _find_band(s30_id, s30_bands, "B03")
        b04 = _find_band(s30_id, s30_bands, "B04")
        try:
            if b02 and b03 and b04:
                b02_arr, optical_profile = _read_first_band(b02)
                b03_arr, _ = _read_first_band(b03)
                b04_arr, _ = _read_first_band(b04)
            elif len(s30_bands) == 1:
                # Legacy input: one multi-band file holding the three bands.
                with rasterio.open(next(iter(s30_bands.values()))) as ssrc:
                    if ssrc.count < 3:
                        raise RuntimeError("S30 granule doesn't have RGB bands.")
                    b02_arr = ssrc.read(1)
                    b03_arr = ssrc.read(2)
                    b04_arr = ssrc.read(3)
                    optical_profile = ssrc.profile
            else:
                raise RuntimeError("S30 granule doesn't have RGB bands.")
        except Exception as e:
            logger.error(f"Unable to extract RGB bands for S30 granule {s30_id}: {e}")
            continue

        base = f"sample_{i}"
        t_path = os.path.join(out_dir, f"{base}_thermal_30m.tif")
        b02_path = os.path.join(out_dir, f"{base}_B02_10m.tif")
        b03_path = os.path.join(out_dir, f"{base}_B03_10m.tif")
        b04_path = os.path.join(out_dir, f"{base}_B04_10m.tif")

        # Re-use the source georeferencing, forcing single-band GeoTIFF output.
        t_prof = prof.copy()
        t_prof.update(driver="GTiff", count=1, dtype=thermal.dtype)
        with rasterio.open(t_path, "w", **t_prof) as dst:
            dst.write(thermal, 1)

        opt_prof = optical_profile.copy()
        opt_prof.update(driver="GTiff", count=1, dtype=b02_arr.dtype)
        for band_path, band_arr in ((b02_path, b02_arr), (b03_path, b03_arr), (b04_path, b04_arr)):
            with rasterio.open(band_path, "w", **opt_prof) as dst:
                dst.write(band_arr, 1)

        saved_pairs.append({
            "thermal30": t_path,
            "b02": b02_path,
            "b03": b03_path,
            "b04": b04_path
        })
        logger.info(f"Saved sample {i}: {t_path}, {b02_path}, {b03_path}, {b04_path}")

    return saved_pairs

In [5]:
def resample_thermal_to_optical(thermal_path, optical_reference_path, out_path):
    """Reproject the thermal raster onto the grid of an optical reference raster.

    Band 1 of ``thermal_path`` is bilinearly resampled to the CRS, transform,
    width and height of ``optical_reference_path`` and written to ``out_path``
    as a single-band float32 GeoTIFF.

    Args:
        thermal_path: Path of the source thermal raster.
        optical_reference_path: Path of the raster defining the target grid.
        out_path: Path of the GeoTIFF to write.

    Returns:
        ``out_path``.
    """
    from rasterio.warp import reproject, Resampling

    with rasterio.open(optical_reference_path) as ref:
        ref_transform = ref.transform
        ref_crs = ref.crs
        ref_width = ref.width
        ref_height = ref.height

    with rasterio.open(thermal_path) as src:
        data = src.read(1)
        src_transform = src.transform
        src_crs = src.crs
        src_nodata = src.nodata
        target_profile = src.profile.copy()

    target_profile.update({
        "crs": ref_crs,
        "transform": ref_transform,
        "width": ref_width,
        "height": ref_height,
        "driver": "GTiff",
        "count": 1,
        # The resampled data is float32; the profile must say so, otherwise
        # the write either fails or truncates the interpolated values.
        "dtype": "float32",
    })

    # Zero-initialised so pixels the warp never touches hold a defined value.
    dest = np.zeros((ref_height, ref_width), dtype=np.float32)
    reproject(
        source=data,
        destination=dest,
        src_transform=src_transform,
        src_crs=src_crs,
        # `data` is a plain array, so its nodata value has to be passed
        # explicitly; otherwise fill values are blended into valid pixels.
        src_nodata=src_nodata,
        dst_transform=ref_transform,
        dst_crs=ref_crs,
        dst_nodata=src_nodata,
        resampling=Resampling.bilinear
    )
    with rasterio.open(out_path, "w", **target_profile) as dst:
        dst.write(dest, 1)
    logger.info(f"Resampled thermal saved to {out_path}")
    return out_path


In [6]:
class ThermalOpticalPatchDataset(Dataset):
    """Random-patch dataset yielding (LR thermal, HR optical, HR thermal) triples.

    Each sample is a dict with the keys ``"thermal30"``, ``"b02"``, ``"b03"``
    and ``"b04"`` holding raster paths. Every item picks a random sample and
    a random ``hr_patch`` x ``hr_patch`` window; the requested index is not
    used for selection, so results are only reproducible through NumPy's
    global random state.

    Args:
        samples: List of sample dicts as described above.
        hr_patch: Side length of the high-resolution patch.
        upscale: Ratio between high- and low-resolution patch sizes.
        transform: Stored on the instance; not applied by this class.
        cache_arrays: Keep the normalised full-scene arrays of each sample in
            memory after the first read instead of re-reading the rasters for
            every patch. Set to False to trade speed for memory.
    """

    def __init__(self, samples, hr_patch=HR_PATCH, upscale=UPSCALE, transform=None,
                 cache_arrays=True):
        super().__init__()
        self.samples = samples
        self.hr_patch = hr_patch
        self.lr_patch = hr_patch // upscale
        self.upscale = upscale
        self.transform = transform
        self.cache_arrays = cache_arrays
        self._array_cache = {}  # sample position -> (lr_thermal, rgb, hr_thermal)

    def __len__(self):
        # Nominal epoch length: patches are drawn at random, so this only
        # controls how many draws make up one pass.
        return max(1000, len(self.samples) * 200)

    @staticmethod
    def _normalise(a):
        """Min-max scale an array to [0, 1]; a (near-)constant array becomes zeros."""
        mn = float(a.min()); mx = float(a.max())
        if mx - mn < 1e-6:
            return np.zeros_like(a, dtype=np.float32)
        return ((a - mn) / (mx - mn)).astype(np.float32)

    def _load_arrays(self, sample):
        """Read and normalise the full-scene arrays for one sample dict."""
        thermal30_path = sample["thermal30"]
        b02, b03, b04 = sample["b02"], sample["b03"], sample["b04"]

        with rasterio.open(b02) as sb2:
            r00 = sb2.read(1).astype(np.float32)
            opt_profile = sb2.profile
        with rasterio.open(b03) as sb3:
            r01 = sb3.read(1).astype(np.float32)
        with rasterio.open(b04) as sb4:
            r02 = sb4.read(1).astype(np.float32)
        # Channel order is B04, B03, B02.
        rgb = np.stack([r02, r01, r00], axis=-1)

        # Prefer the pre-resampled thermal raster when it exists; otherwise
        # warp the 30 m file onto the optical grid on the fly.
        thermal10_path_guess = thermal30_path.replace("_thermal_30m.tif", "_thermal_10m.tif")
        if os.path.exists(thermal10_path_guess):
            with rasterio.open(thermal10_path_guess) as t10f:
                thermal10 = t10f.read(1).astype(np.float32)
        else:
            from rasterio.warp import reproject, Resampling
            with rasterio.open(thermal30_path) as ts:
                dest = np.empty((opt_profile["height"], opt_profile["width"]), dtype=np.float32)
                reproject(
                    source=ts.read(1),
                    destination=dest,
                    src_transform=ts.transform,
                    src_crs=ts.crs,
                    dst_transform=opt_profile["transform"],
                    dst_crs=opt_profile["crs"],
                    resampling=Resampling.bilinear
                )
            thermal10 = dest

        rgb_n = self._normalise(rgb)
        thermal10_n = self._normalise(thermal10)

        H_hr, W_hr = thermal10_n.shape
        H_lr, W_lr = H_hr // self.upscale, W_hr // self.upscale
        lr_full = F.interpolate(
            torch.from_numpy(thermal10_n).unsqueeze(0).unsqueeze(0).float(),
            size=(H_lr, W_lr), mode='bilinear', align_corners=False,
        )
        # Index rather than squeeze so a size-1 spatial dimension is kept.
        lr_thermal = lr_full[0, 0].numpy()

        return lr_thermal.astype(np.float32), rgb_n.astype(np.float32), thermal10_n.astype(np.float32)

    def _read_arrays_for_index(self, sample_idx):
        """Return ``(lr_thermal, rgb, hr_thermal)`` full-scene arrays for a sample index.

        The index wraps around ``len(self.samples)``. With ``cache_arrays``
        enabled the arrays are read once per sample and then served from memory.
        """
        key = sample_idx % len(self.samples)
        if self.cache_arrays and key in self._array_cache:
            return self._array_cache[key]
        arrays = self._load_arrays(self.samples[key])
        if self.cache_arrays:
            self._array_cache[key] = arrays
        return arrays

    def __getitem__(self, idx):
        """Return ``(lr_t, hr_rgb, hr_t)`` float tensors for one random patch.

        Shapes: ``lr_t`` is (1, lr_patch, lr_patch), ``hr_rgb`` is
        (3, hr_patch, hr_patch), ``hr_t`` is (1, hr_patch, hr_patch).

        Raises:
            ValueError: If the scene is smaller than ``hr_patch`` in either dimension.
        """
        sample_idx = np.random.randint(0, len(self.samples))
        _, rgb_hr, hr_thermal = self._read_arrays_for_index(sample_idx)
        H_hr, W_hr, _ = rgb_hr.shape
        max_y = H_hr - self.hr_patch
        max_x = W_hr - self.hr_patch
        # A scene exactly hr_patch in size is valid (offset 0 is the only choice).
        if max_y < 0 or max_x < 0:
            raise ValueError("HR image smaller than hr_patch")
        y = np.random.randint(0, max_y + 1)
        x = np.random.randint(0, max_x + 1)

        # Copy the windows so returned tensors never alias the cached arrays.
        hr_rgb_patch = rgb_hr[y:y+self.hr_patch, x:x+self.hr_patch, :].copy()
        hr_t_patch = hr_thermal[y:y+self.hr_patch, x:x+self.hr_patch].copy()

        hr_rgb = torch.from_numpy(np.transpose(hr_rgb_patch, (2, 0, 1)))
        hr_t = torch.from_numpy(hr_t_patch).unsqueeze(0)

        # Build the low-resolution input from the HR patch itself so both
        # cover exactly the same footprint. Cropping a pre-downsampled scene
        # at a floor-divided offset would shift the LR patch by up to
        # upscale-1 pixels and cover lr_patch*upscale rather than hr_patch pixels.
        lr_t = F.interpolate(
            hr_t.unsqueeze(0).float(),
            size=(self.lr_patch, self.lr_patch), mode='bilinear', align_corners=False,
        ).squeeze(0)

        return lr_t.float(), hr_rgb.float(), hr_t.float()

In [7]:
class ResBlock(nn.Module):
    """Residual unit: conv3x3 -> ReLU -> conv3x3, added to the input with the branch scaled by 0.1.

    The channel count and spatial size of the input are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        residual = self.conv2(hidden)
        # Scale the residual branch before adding the skip connection.
        return x + residual * 0.1

In [8]:
class DualEDSR(nn.Module):
    """Two-branch network that fuses thermal and optical features into one thermal map.

    The 1-channel thermal input and the 3-channel optical input are each
    passed through a head convolution and a stack of ``ResBlock`` modules.
    Thermal features are bilinearly resized to the optical feature size,
    concatenated with the optical features, reduced by a 1x1 convolution,
    refined by two conv+ReLU stages and projected to a single channel.

    Args:
        n_resblocks: Number of residual blocks in each branch.
        n_feats: Feature channels used throughout the network.
        upscale: Stored as ``self.upscale``; the forward pass takes the output
            size from the optical input instead.
    """

    def __init__(self, n_resblocks=8, n_feats=64, upscale=3):
        super().__init__()
        self.upscale = upscale

        def conv3x3(in_ch, out_ch):
            return nn.Conv2d(in_ch, out_ch, 3, padding=1)

        def res_stack():
            return nn.Sequential(*(ResBlock(n_feats) for _ in range(n_resblocks)))

        # Modules are created in the same order as their attribute names
        # appear in the state dict: heads, branch bodies, fusion, refine, output.
        self.convT = conv3x3(1, n_feats)
        self.convO = conv3x3(3, n_feats)
        self.resBlocksT = res_stack()
        self.resBlocksO = res_stack()
        self.convFuse = nn.Conv2d(2 * n_feats, n_feats, 1)
        self.refine = nn.Sequential(
            conv3x3(n_feats, n_feats),
            nn.ReLU(inplace=True),
            conv3x3(n_feats, n_feats),
            nn.ReLU(inplace=True),
        )
        self.convOut = conv3x3(n_feats, 1)

    def forward(self, xT, xO):
        """Return a 1-channel map with the spatial size of ``xO``."""
        thermal_feats = self.resBlocksT(F.relu(self.convT(xT)))
        optical_feats = self.resBlocksO(F.relu(self.convO(xO)))

        # Bring the thermal features onto the optical feature grid.
        target_hw = (optical_feats.shape[2], optical_feats.shape[3])
        thermal_up = F.interpolate(thermal_feats, size=target_hw,
                                   mode='bilinear', align_corners=False)

        fused = torch.cat([thermal_up, optical_feats], dim=1)
        fused = F.relu(self.convFuse(fused))
        return self.convOut(self.refine(fused))


In [9]:
def compute_metrics(pred, target):
    """Compute PSNR, SSIM and RMSE between a prediction and its target.

    Both inputs are clipped to [0, 1] before comparison.

    Args:
        pred: Predicted image array.
        target: Reference image array of the same shape.

    Returns:
        ``(psnr, ssim, rmse)``. PSNR is capped at 100.0 when the mean squared
        error is at most 1e-12. SSIM falls back to 0.0 (with a logged warning)
        when it cannot be computed.

    Raises:
        ValueError: If the inputs are empty or their shapes differ.
    """
    pred = np.clip(np.asarray(pred), 0.0, 1.0)
    target = np.clip(np.asarray(target), 0.0, 1.0)
    # Fail loudly instead of letting NumPy broadcast mismatched images.
    if pred.shape != target.shape:
        raise ValueError(f"pred shape {pred.shape} does not match target shape {target.shape}")
    if pred.size == 0:
        raise ValueError("pred and target must not be empty")

    mse = float(np.mean((pred - target) ** 2))
    psnr_val = 10 * math.log10(1.0 / mse) if mse > 1e-12 else 100.0
    rmse_val = math.sqrt(mse)
    try:
        ssim_val = float(ssim(target, pred, data_range=1.0))
    except Exception as exc:
        # Best-effort metric: keep going, but do not hide the reason, and let
        # KeyboardInterrupt / SystemExit propagate (a bare except would not).
        logging.getLogger("mvp_sr").warning(f"SSIM could not be computed, using 0.0: {exc}")
        ssim_val = 0.0
    return psnr_val, ssim_val, rmse_val


In [10]:
def _evaluate_loader(model, loader):
    """Average PSNR/SSIM/RMSE of ``model`` over ``loader``; None when the loader yields no batches."""
    model.eval()
    ps_sum = ss_sum = rm_sum = 0.0
    cnt = 0
    with torch.no_grad():
        for lr_t, hr_rgb, hr_t in loader:
            lr_t, hr_rgb, hr_t = lr_t.to(DEVICE), hr_rgb.to(DEVICE), hr_t.to(DEVICE)
            out = model(lr_t, hr_rgb)
            ps, ss, rm = compute_metrics(out.cpu().squeeze().numpy(), hr_t.cpu().squeeze().numpy())
            ps_sum += ps; ss_sum += ss; rm_sum += rm; cnt += 1
    if cnt == 0:
        return None
    return ps_sum / cnt, ss_sum / cnt, rm_sum / cnt


def train_and_evaluate(samples, out_dir=OUT_DIR):
    """Train ``DualEDSR`` on the given samples, validate each epoch and evaluate on the test split.

    Samples are split in order into train/validation/test (about 70/20/10 for
    three or more samples; with fewer, the first sample trains and the second,
    if present, validates). A checkpoint is written to
    ``<out_dir>/checkpoint.pth`` after every epoch and training resumes from
    it when it already exists. The weights with the best validation PSNR are
    written to ``<out_dir>/best_model.pth``. After training, the test split
    (if any) is evaluated with the best weights when available and the
    metrics are logged.

    Args:
        samples: List of sample dicts accepted by ``ThermalOpticalPatchDataset``.
        out_dir: Directory for the checkpoint and best-model files.

    Returns:
        None.
    """
    n = len(samples)
    if n == 0:
        logger.error("No samples provided.")
        return

    # Split dataset
    if n >= 3:
        train_end = int(0.7 * n)
        val_end = int(0.9 * n)
        # For n == 3 both bounds are 2, which would leave validation empty
        # and no best model would ever be saved; hand one sample over.
        if train_end >= val_end:
            train_end = max(1, val_end - 1)
        train_samples = samples[:train_end]
        val_samples = samples[train_end:val_end]
        test_samples = samples[val_end:]
    else:
        train_samples = samples[:1]
        val_samples = samples[1:2]
        test_samples = []

    logger.info(f"Train={len(train_samples)}, Val={len(val_samples)}, Test={len(test_samples)}")

    # Datasets
    train_ds = ThermalOpticalPatchDataset(train_samples)
    val_ds = ThermalOpticalPatchDataset(val_samples) if len(val_samples) > 0 else None
    test_ds = ThermalOpticalPatchDataset(test_samples) if len(test_samples) > 0 else None

    # Dataloaders (explicit None checks: truthiness of a dataset/loader goes through __len__)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=0, pin_memory=(DEVICE.type == "cuda"))
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False) if val_ds is not None else None
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False) if test_ds is not None else None

    # Model, optimizer, loss
    model = DualEDSR(n_resblocks=8, n_feats=64, upscale=UPSCALE).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.MSELoss()

    # Checkpoint setup
    best_val = -1e9
    best_path = os.path.join(out_dir, "best_model.pth")
    ckpt_path = os.path.join(out_dir, "checkpoint.pth")
    loss_history = []
    start_epoch = 1

    # Resume if a checkpoint exists.
    # NOTE: torch.load unpickles the file; only resume from checkpoints you trust.
    if os.path.exists(ckpt_path):
        logger.info(f"Resuming training from checkpoint: {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        best_val = checkpoint["best_val"]
        start_epoch = checkpoint["epoch"] + 1
        loss_history = checkpoint.get("loss_history", [])
        logger.info(f"Resumed from epoch {checkpoint['epoch']} with best_val={best_val:.3f}")

    # Training loop
    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        model.train()
        running_loss = 0.0
        for lr_t, hr_rgb, hr_t in tqdm(train_loader, desc=f"Epoch {epoch} training"):
            lr_t, hr_rgb, hr_t = lr_t.to(DEVICE), hr_rgb.to(DEVICE), hr_t.to(DEVICE)
            optimizer.zero_grad()
            pred = model(lr_t, hr_rgb)
            loss = criterion(pred, hr_t)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / max(1, len(train_loader))
        loss_history.append(avg_loss)
        logger.info(f"Epoch {epoch} Train Loss: {avg_loss:.6f}")

        # Validation: keep the weights with the best average PSNR.
        if val_loader is not None:
            val_metrics = _evaluate_loader(model, val_loader)
            if val_metrics is not None:
                avg_ps, avg_ss, avg_rm = val_metrics
                logger.info(f"Epoch {epoch} VAL PSNR={avg_ps:.3f}, SSIM={avg_ss:.4f}, RMSE={avg_rm:.6f}")
                if avg_ps > best_val:
                    best_val = avg_ps
                    torch.save(model.state_dict(), best_path)
                    logger.info("Saved best model.")

        # Save checkpoint after every epoch
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_val": best_val,
            "loss_history": loss_history
        }, ckpt_path)
        logger.info(f"Checkpoint saved at epoch {epoch}")

    # Final evaluation on the held-out test split, which was previously built
    # but never used.
    if test_loader is not None:
        if os.path.exists(best_path):
            model.load_state_dict(torch.load(best_path, map_location=DEVICE))
            logger.info(f"Loaded best model from {best_path} for test evaluation.")
        test_metrics = _evaluate_loader(model, test_loader)
        if test_metrics is not None:
            test_ps, test_ss, test_rm = test_metrics
            logger.info(f"TEST PSNR={test_ps:.3f}, SSIM={test_ss:.4f}, RMSE={test_rm:.6f}")


In [11]:
if __name__ == "__main__":
    # Stage 1: fetch the raw HLS granules for the configured area and dates.
    logger.info("STEP 1: Download HLS granules")
    l30_files, s30_files = earthdata_download_hls(BBOX, DATE_RANGE, MAX_GRANULES, RAW_DIR)

    # Stage 2: write per-sample thermal and optical band files; stop if none came out.
    logger.info("STEP 2: Extract bands")
    samples = extract_and_save_bands(l30_files, s30_files, OUT_DIR)
    if not samples:
        logger.error("No samples extracted")
        sys.exit(1)

    # Stage 3: put each thermal raster on its optical (B02) grid and record
    # the resulting path on the sample under "thermal10".
    logger.info("STEP 3: Resample thermal 30m -> 10m")
    for sample in samples:
        thermal30_path = sample["thermal30"]
        thermal10_path = thermal30_path.replace("_thermal_30m.tif", "_thermal_10m.tif")
        resample_thermal_to_optical(thermal30_path, sample["b02"], thermal10_path)
        sample["thermal10"] = thermal10_path

    # Stage 4: train the model on the prepared samples.
    logger.info("STEP 4: Train & evaluate model")
    train_and_evaluate(samples, OUT_DIR)
    logger.info("Done. Check output folder for model and test images.")

INFO:mvp_sr:STEP 1: Download HLS granules
INFO:mvp_sr:Logging into Earthdata. You will be prompted for credentials if needed.
INFO:earthaccess.auth:You're now authenticated with NASA Earthdata Login
INFO:mvp_sr:Searching HLSL30 (Landsat-based) -- includes thermal bands (B10)...
INFO:earthaccess.api:Granules found: 8
INFO:mvp_sr:Found 8 L30 granules. Downloading up to 4...
INFO:earthaccess.store: Getting 4 granules, approx download size: 0.69 GB
QUEUEING TASKS | :   0%|          | 0/60 [00:00<?, ?it/s]INFO:earthaccess.store:File HLS.L30.T43QHV.2023003T050941.v2.0.B04.tif already downloaded
INFO:earthaccess.store:File HLS.L30.T43QHV.2023003T050941.v2.0.VZA.tif already downloaded
INFO:earthaccess.store:File HLS.L30.T43QHV.2023003T050941.v2.0.SAA.tif already downloaded
INFO:earthaccess.store:File HLS.L30.T43QHV.2023003T050941.v2.0.B10.tif already downloaded
QUEUEING TASKS | : 100%|██████████| 60/60 [00:00<00:00, 6147.00it/s]
INFO:earthaccess.store:File HLS.L30.T43QHV.2023003T050941.v2.0.VA

KeyboardInterrupt: 