In [1]:
# installing
!pip install torch torchvision torchgeo rasterio segmentation-models-pytorch 





[notice] A new release of pip is available: 24.3.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import math
import warnings
from pathlib import Path

import rasterio
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torch
import torch.nn.functional as F

# TorchGeo imports
from torchgeo.datasets import RasterDataset
try:
    from torchgeo.datasets import stack_samples  # newer
except Exception:
    from torchgeo.datasets.utils import stack_samples  # older

try:
    from torchgeo.datasets.geo import IntersectionDataset
except Exception:
    from torchgeo.datasets import IntersectionDataset

from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

# U-Net backbone (Segmentation Models PyTorch)
import segmentation_models_pytorch as smp

warnings.filterwarnings("ignore", category=UserWarning)


In [3]:
# # stack three single-band rasters into one 3-band GeoTIFF
# in_dir  = Path("READY_data/inputs")
# out_mb  = in_dir / "historic_3band_025deg.tif"

# in_lc   = in_dir / "2018_landcover_aligned_025deg.tif"
# in_gdp  = in_dir / "2019_gdp_aligned_025deg.tif"
# in_pop  = in_dir / "2020_pop_aligned_025deg.tif"

# with rasterio.open(in_lc) as src_lc, \
#      rasterio.open(in_gdp) as src_gdp, \
#      rasterio.open(in_pop) as src_pop:

#     # sanity checks
#     assert src_lc.crs == src_gdp.crs == src_pop.crs, "CRS mismatch"
#     assert src_lc.transform == src_gdp.transform == src_pop.transform, "transform mismatch"
#     assert (src_lc.width, src_lc.height) == (src_gdp.width, src_gdp.height) == (src_pop.width, src_pop.height), "size mismatch"

#     lc  = src_lc.read(1)
#     gdp = src_gdp.read(1)
#     pop = src_pop.read(1)

#     arr = np.stack([lc, gdp, pop], axis=0).astype(np.float32)  # (3, H, W)

#     profile = src_lc.profile
#     profile.update(count=3, dtype="float32", nodata=np.nan)

#     with rasterio.open(out_mb, "w", **profile) as dst:
#         dst.write(arr)
#         dst.set_band_description(1, "landcover")
#         dst.set_band_description(2, "gdp")
#         dst.set_band_description(3, "pop")

# print("Wrote:", out_mb)


In [4]:
# quick inspection
# import rasterio
# from pathlib import Path

# path = Path("READY_data/inputs/historic_3band_025deg.tif")
# with rasterio.open(path) as src:
#     print("Bands:", src.count)            # expect 3
#     print("Size :", src.width, "x", src.height)
#     print("CRS  :", src.crs)
#     print("Desc :", src.descriptions)     # ('landcover','gdp','pop')

In [23]:
# Args
class Args:
    pass

args = Args()
args.in_dir = "READY_data/inputs"
args.lab_dir = "READY_data/labels"
args.out_dir = "checkpoints"

args.patch = 64
args.train_windows = 500
args.val_stride_frac = 1.0
args.batch = 4
args.workers = 0

args.epochs = 5
args.lr = 3e-4
args.wd = 1e-2
args.huber_delta = 0.5
args.backbone = "resnet34"
args.amp = True

# 👇 add this line to fix the NaN normalization issue
args.use_quick_norm = True # was False; now compute train-only mean/std
args.norm_batches = 5


In [6]:
# Datasets
class InputsDataset(RasterDataset):
    """Multi-band inputs (e.g., population, GDP, land-use) as aligned GeoTIFFs."""
    filename_glob = "historic_3band_025deg.tif"  # <— only the multiband file
    is_image = True

class LabelsDataset(RasterDataset):
    """
    Single-band continuous target (e.g., CISI scaled to [0,1]).
    Use float32 for continuous targets.
    """
    filename_glob = "2024_CISI_025deg.tif"   # your label stays as before (single band)
    is_image = False
    dtype = torch.float32


In [7]:
# Normalization & Loss
class ChannelWiseNormalize(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        m = torch.as_tensor(mean).view(1, -1, 1, 1)
        s = torch.as_tensor(std).view(1, -1, 1, 1)
        self.register_buffer("mean", m.float())
        self.register_buffer("std", s.float())

    def forward(self, x):
        return (x - self.mean) / (self.std + 1e-6)


class MaskedHuber(nn.Module):
    """Huber regression loss that ignores NaN or masked values."""
    def __init__(self, delta=0.5):
        super().__init__()
        self.delta = delta

    def forward(self, pred, target, mask=None):
        if mask is None:
            mask = torch.isfinite(target)
        mask = mask.float()

        diff = pred - target
        abs_diff = diff.abs()
        delta = torch.tensor(self.delta, device=pred.device)
        quadratic = torch.minimum(abs_diff, delta)
        loss = 0.5 * quadratic**2 + delta * (abs_diff - quadratic) - 0.5 * (delta**2)
        loss = loss * mask
        denom = mask.sum().clamp_min(1.0)
        return loss.sum() / denom


In [22]:
# Mean/std stats block
def quick_channel_stats(train_loader, in_ch, norm_batches):
    """Rough channel-wise mean/std over a few TRAIN batches (NaN-safe)."""
    running_mean = torch.zeros(in_ch, dtype=torch.float64)
    running_m2   = torch.zeros(in_ch, dtype=torch.float64)  # for variance via Welford
    count        = torch.zeros(in_ch, dtype=torch.float64)

    with torch.no_grad():
        for i, batch in enumerate(train_loader):
            xb = batch["image"].float()  # (B,C,H,W)
            xb = torch.nan_to_num(xb, nan=0.0, posinf=0.0, neginf=0.0)
            B, C, H, W = xb.shape
            xb = xb.view(B, C, -1)

            # per-channel stats over all pixels in the batch
            batch_sum   = xb.sum(dim=(0, 2)).double()
            batch_count = torch.full((C,), B*H*W, dtype=torch.float64)
            batch_mean  = batch_sum / batch_count

            # update running mean and M2 (per channel)
            delta = batch_mean - running_mean / torch.clamp(count, min=1.0)
            running_mean += batch_sum
            count        += batch_count
            # recompute mean after merging
            new_mean = running_mean / torch.clamp(count, min=1.0)

            # approximate M2 by accumulating squared diffs within batch (cheap)
            diffs = (xb.double() - batch_mean.view(1, C, 1))**2
            running_m2 += diffs.sum(dim=(0, 2))

            # move on
            if i + 1 >= norm_batches:
                break

    mean = (running_mean / torch.clamp(count, min=1.0)).float().tolist()
    var  = (running_m2   / torch.clamp(count, min=1.0)).float()
    std  = torch.sqrt(torch.clamp(var, min=1e-12)).tolist()
    return mean, std


In [9]:
# Cell 6: Build the paired dataset (inputs & labels) and window count
inputs_ds = InputsDataset(args.in_dir)
labels_ds = LabelsDataset(args.lab_dir)
dataset   = inputs_ds & labels_ds

# Show basic info
print("Inputs res:", inputs_ds.res, "Labels res:", labels_ds.res)
print("Inputs CRS:", inputs_ds.crs, "Labels CRS:", labels_ds.crs)
print("Bounds (intersection):", dataset.bounds)

args.patch = 64  # was 256; must be <= the image min dimension (~195)

# Count candidate windows using a non-overlapping grid
from torchgeo.samplers import GridGeoSampler
probe_stride = args.patch
probe_grid   = GridGeoSampler(dataset, size=args.patch, stride=probe_stride)
num_windows  = sum(1 for _ in probe_grid)
print(f"[Info] Candidate windows (patch={args.patch}, stride={probe_stride}): {num_windows}")

if num_windows == 0:
    raise RuntimeError("No candidate windows at this patch size. Try args.patch = 128 and re-run this cell.")


Inputs res: (0.25, 0.25) Labels res: (0.25, 0.25)
Inputs CRS: EPSG:4326 Labels CRS: EPSG:4326
Bounds (intersection): BoundingBox(minx=-25.0, maxx=32.0, miny=32.25, maxy=81.0, mint=0.0, maxt=9.223372036854776e+18)
[Info] Candidate windows (patch=64, stride=64): 16


In [19]:
# Define bounds for train/test/val splits
from torchgeo.datasets import BoundingBox
import math

minx, maxx = dataset.bounds[0], dataset.bounds[1]
miny, maxy = dataset.bounds[2], dataset.bounds[3]
mint, maxt = float("-inf"), float("inf")

# Split longitudinally into 60% / 20% / 20% (train / val / test)
W = maxx - minx
x1 = minx + 0.60 * W
x2 = minx + 0.80 * W

train_roi = BoundingBox(minx=minx, maxx=x1, miny=miny, maxy=maxy, mint=mint, maxt=maxt)
val_roi   = BoundingBox(minx=x1,  maxx=x2, miny=miny, maxy=maxy, mint=mint, maxt=maxt)
test_roi  = BoundingBox(minx=x2,  maxx=maxx,miny=miny, maxy=maxy, mint=mint, maxt=maxt)

print("ROIs set:",
      f"\n  train x∈[{minx:.3f},{x1:.3f}]",
      f"\n  val   x∈[{x1:.3f},{x2:.3f}]",
      f"\n  test  x∈[{x2:.3f},{maxx:.3f}]")

ROIs set: 
  train x∈[-25.000,9.200] 
  val   x∈[9.200,20.600] 
  test  x∈[20.600,32.000]


In [20]:
from torchgeo.samplers import GridGeoSampler

# Train: 50% overlap; Val/Test: no overlap
train_sampler = GridGeoSampler(dataset, size=args.patch, stride=int(args.patch * 0.5), roi=train_roi)
val_sampler   = GridGeoSampler(dataset, size=args.patch, stride=args.patch,                 roi=val_roi)
test_sampler  = GridGeoSampler(dataset, size=args.patch, stride=args.patch,                 roi=test_roi)

print("Samplers: Grid(train|val|test) with disjoint ROIs")


Samplers: Grid(train|val|test) with disjoint ROIs


In [21]:
train_loader = DataLoader(dataset, batch_size=args.batch, sampler=train_sampler,
                          num_workers=args.workers, collate_fn=stack_samples,
                          pin_memory=False, persistent_workers=False)
val_loader   = DataLoader(dataset, batch_size=args.batch, sampler=val_sampler,
                          num_workers=args.workers, collate_fn=stack_samples,
                          pin_memory=False, persistent_workers=False)
test_loader  = DataLoader(dataset, batch_size=args.batch, sampler=test_sampler,
                          num_workers=args.workers, collate_fn=stack_samples,
                          pin_memory=False, persistent_workers=False)

print("DataLoaders ready (train/val/test).")


DataLoaders ready (train/val/test).


In [12]:
# Cell 9 — Smoke test a single batch
sample = next(iter(train_loader))
print("Keys:", sample.keys())
x0 = sample["image"]              # (B, C, H, W)
y0 = sample["mask"]               # (B, H, W)
print("X shape:", x0.shape, "Y shape:", y0.shape)


Keys: dict_keys(['crs', 'bounds', 'image', 'mask'])
X shape: torch.Size([4, 3, 64, 64]) Y shape: torch.Size([4, 64, 64])


In [None]:
# Cell 10 — Training function that accepts loaders (does NOT rebuild samplers/dataset)

def run_with_loaders(args, train_loader, val_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # Peek for channels
    sample = next(iter(train_loader))
    x0 = sample["image"]                    # (B,C,H,W)
    y0 = sample["mask"]                    # (B,H,W)
    print("Sample keys:", sample.keys())
    print("X shape:", x0.shape, "Y shape:", y0.shape)
    in_ch = x0.shape[1]
    print(f"[Info] Inferred input channels: {in_ch}")

    # Normalization
    if args.use_quick_norm:
        mean, std = quick_channel_stats(train_loader, in_ch, args.norm_batches)
    else:
        mean = [0.0] * in_ch
        std  = [1.0] * in_ch
    normalize = ChannelWiseNormalize(mean, std).to(device)
    print(f"[Info] mean: {mean}\n[Info] std : {std}")

    # Model
    model = smp.Unet(
        encoder_name=args.backbone,
        encoder_weights=None,
        in_channels=in_ch,
        classes=1,
        activation=None,
    ).to(device)

    # Optimizer, loss, AMP
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                            optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    criterion = MaskedHuber(delta=args.huber_delta)
    scaler = torch.amp.GradScaler('cuda', enabled=args.amp)  # new API

    best_val = float("inf")
    bad_epochs = 0
    for epoch in range(1, args.epochs + 1):
        # ---- Train
        model.train()
        tr_loss = 0.0
        tr_n = 0
        for batch in train_loader:
            x = batch["image"].float().to(device)
            y = batch["mask"].float().unsqueeze(1).to(device)  # (B,1,H,W)

            # sanitize inputs; skip if no valid labels
            x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).clamp(-1e6, 1e6)
            if not torch.isfinite(y).any():
                continue

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', enabled=args.amp):
                x = normalize(x)
                yhat = torch.sigmoid(model(x))
                loss = criterion(yhat, y)  # MaskedHuber ignores non-finite y
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # <-- unscale first
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()


            tr_loss += loss.item()
            tr_n += 1

        # ---- Validate
        model.eval()
        va_loss = 0.0
        va_n = 0
        with torch.no_grad(), torch.amp.autocast('cuda', enabled=args.amp):
            for batch in val_loader:
                x = batch["image"].float().to(device)
                y = batch["mask"].float().unsqueeze(1).to(device)
                x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).clamp(-1e6, 1e6)
                if not torch.isfinite(y).any():
                    continue

                x = normalize(x)
                yhat = torch.sigmoid(model(x))
                loss = criterion(yhat, y)
                va_loss += loss.item()
                va_n += 1

        tr_loss = tr_loss / max(tr_n, 1)
        va_loss = va_loss / max(va_n, 1)
        print(f"Epoch {epoch:03d} | train: {tr_loss:.4f} | val: {va_loss:.4f}")

        # Step scheduler on the actual validation loss
        scheduler.step(va_loss if va_n > 0 else tr_loss)

        if va_n > 0 and va_loss < best_val:
            best_val = va_loss
            ckpt = {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "norm_mean": getattr(normalize, "mean", None).flatten().tolist(),
                "norm_std": getattr(normalize, "std", None).flatten().tolist(),
                "in_channels": model.encoder.conv1.in_channels if hasattr(model.encoder, "conv1") else None,
                "backbone": args.backbone,
            }
            Path(args.out_dir).mkdir(parents=True, exist_ok=True)
            torch.save(ckpt, Path(args.out_dir) / "best_unet_regression.pt")
            print(f"  -> Saved best checkpoint (val={best_val:.4f})")
        else:
            bad_epochs += 1
            if bad_epochs >= 6:  # simple patience; align with scheduler patience if you want
                print("Early stopping.")
                break


    print("Done.")
    return model, normalize


In [14]:
# Cell 11 — Start training
model, normalize = run_with_loaders(args, train_loader, val_loader)


Sample keys: dict_keys(['crs', 'bounds', 'image', 'mask'])
X shape: torch.Size([4, 3, 64, 64]) Y shape: torch.Size([4, 64, 64])
[Info] Inferred input channels: 3
[Info] mean: [0.0, 0.0, 0.0]
[Info] std : [1.0, 1.0, 1.0]


  scaler = torch.cuda.amp.GradScaler(enabled=args.amp)


Epoch 001 | train: 0.0121 | val: 0.0402
  -> Saved best checkpoint (val=0.0402)
Epoch 002 | train: -0.0150 | val: -0.0245
  -> Saved best checkpoint (val=-0.0245)
Epoch 003 | train: -0.0370 | val: -0.0343
  -> Saved best checkpoint (val=-0.0343)
Epoch 004 | train: -0.0559 | val: -0.0605
  -> Saved best checkpoint (val=-0.0605)
Epoch 005 | train: -0.0709 | val: -0.0774
  -> Saved best checkpoint (val=-0.0774)
Done.


In [24]:
# Cell 12 — Eval helper (MAE, RMSE, R^2)
def evaluate(loader, model, normalize, device):
    model.eval()
    mae_sum = rmse_sum = r2_sum = n = 0
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=(device=='cuda')):
        for batch in loader:
            x = batch["image"].float().to(device)
            y = batch["mask"].float().unsqueeze(1).to(device)

            # match training sanitization
            x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).clamp(-1e6, 1e6)

            m = torch.isfinite(y)
            if not m.any():
                continue

            x = normalize(x)
            p = torch.sigmoid(model(x))

            yv, pv = y[m], p[m]
            mae  = torch.mean(torch.abs(pv - yv)).item()
            rmse = torch.sqrt(torch.mean((pv - yv)**2)).item()
            ymean = torch.mean(yv)
            denom = torch.sum((yv - ymean)**2).clamp_min(1e-12)
            r2 = (1 - torch.sum((pv - yv)**2) / denom).item()

            mae_sum += mae; rmse_sum += rmse; r2_sum += r2; n += 1
    return mae_sum/max(n,1), rmse_sum/max(n,1), r2_sum/max(n,1)


In [None]:
# Cell 13 — Evaluate
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = torch.load(Path(args.out_dir) / "best_unet_regression.pt", map_location=device)

eval_model = smp.Unet(
    encoder_name=args.backbone,
    encoder_weights=None,
    in_channels=ckpt["in_channels"],    # multiband input
    classes=1,
    activation=None
).to(device)
eval_model.load_state_dict(ckpt["model_state"])
eval_norm = ChannelWiseNormalize(ckpt["norm_mean"], ckpt["norm_std"]).to(device)

mae, rmse, r2 = evaluate(val_loader, eval_model, eval_norm, device)
print(f"Validation → MAE {mae:.4f} | RMSE {rmse:.4f} | R² {r2:.3f}")


  with torch.no_grad(), torch.cuda.amp.autocast():


Validation → MAE nan | RMSE nan | R² nan


In [25]:
# Cell 14 — Final TEST evaluation
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reuse the checkpoint and normalization already loaded in Cell 13
test_mae, test_rmse, test_r2 = evaluate(test_loader, eval_model, eval_norm, device)
print(f"TEST     → MAE {test_mae:.4f} | RMSE {test_rmse:.4f} | R² {test_r2:.3f}")


TEST     → MAE 0.0000 | RMSE 0.0000 | R² 0.000


In [26]:
# Cell 15 — Save a few QA figures from val & test
import numpy as np, matplotlib.pyplot as plt
from pathlib import Path

qa_dir = Path("qa_figs"); qa_dir.mkdir(parents=True, exist_ok=True)

def save_triptychs(loader, tag, limit_batches=2):
    eval_model.eval()
    done = 0
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=(device=='cuda')):
        for bidx, batch in enumerate(loader):
            x = batch["image"].float().to(device)
            y = batch["mask"].float().unsqueeze(1).to(device)
            x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).clamp(-1e6, 1e6)

            x = eval_norm(x)
            p = torch.sigmoid(eval_model(x))

            B = x.size(0); nshow = min(B, 4)
            for i in range(nshow):
                xi = x[i].detach().cpu()
                yi = y[i,0].detach().cpu()
                pi = p[i,0].detach().cpu()

                # visualize first 3 channels (or repeat if 1)
                rgb = xi.repeat(3,1,1) if xi.size(0)==1 else xi[:3]
                rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)

                fig, axs = plt.subplots(1, 3, figsize=(12, 4))
                axs[0].imshow(np.moveaxis(rgb.numpy(), 0, -1)); axs[0].set_title("Input");   axs[0].axis("off")
                im1 = axs[1].imshow(yi.numpy(), vmin=0, vmax=1); axs[1].set_title("Target"); axs[1].axis("off")
                im2 = axs[2].imshow(pi.numpy(), vmin=0, vmax=1); axs[2].set_title("Pred");   axs[2].axis("off")
                fig.colorbar(im2, ax=axs, fraction=0.025, pad=0.02)
                fig.tight_layout()
                fig.savefig(qa_dir / f"{tag}_b{bidx}_i{i}.png", dpi=160)
                plt.close(fig)

            # prediction histogram for the batch
            plt.figure(figsize=(5,4))
            plt.hist(p.detach().cpu().numpy().ravel(), bins=50)
            plt.title(f"Pred histogram — {tag} batch {bidx}")
            plt.tight_layout()
            plt.savefig(qa_dir / f"{tag}_b{bidx}_hist.png", dpi=160)
            plt.close()

            done += 1
            if done >= limit_batches:
                break

save_triptychs(val_loader,  tag="val",  limit_batches=2)
save_triptychs(test_loader, tag="test", limit_batches=2)
print(f"Saved QA figures to: {qa_dir.resolve()}")


Saved QA figures to: C:\Users\Gebruiker\Desktop\Thesis\FutureCISI-main\empy_repo_for_thesis\qa_figs


In [17]:
# Check out checkpoint
import torch
checkpoint = torch.load("checkpoints/best_unet_regression.pt", map_location="cpu")
print(checkpoint.keys())


dict_keys(['epoch', 'model_state', 'optimizer_state', 'norm_mean', 'norm_std', 'in_channels', 'backbone'])


In [18]:
# TEST