In [None]:
# installing
!pip install torch torchvision torchgeo rasterio segmentation-models-pytorch albumentations


Collecting albumentations
  Downloading albumentations-2.0.8-py3-none-any.whl.metadata (43 kB)
Collecting albucore==0.0.24 (from albumentations)
  Downloading albucore-0.0.24-py3-none-any.whl.metadata (5.3 kB)
Collecting opencv-python-headless>=4.9.0.80 (from albumentations)
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-win_amd64.whl.metadata (20 kB)
Collecting stringzilla>=3.10.4 (from albucore==0.0.24->albumentations)
  Downloading stringzilla-4.2.1-cp311-cp311-win_amd64.whl.metadata (112 kB)
Collecting simsimd>=5.9.2 (from albucore==0.0.24->albumentations)
  Downloading simsimd-6.5.3-cp311-cp311-win_amd64.whl.metadata (71 kB)
Collecting numpy (from torchvision)
  Downloading numpy-2.2.6-cp311-cp311-win_amd64.whl.metadata (60 kB)
INFO: pip is looking at multiple versions of pandas to determine which version is compatible with other requirements. This could take a while.
Collecting pandas>=1.5 (from torchgeo)
  Downloading pandas-2.3.3-cp311-cp311-win_amd64.whl.metadata (19

  You can safely remove it manually.
  You can safely remove it manually.
  You can safely remove it manually.
  You can safely remove it manually.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.19.0 requires numpy<2.2.0,>=1.26.0, but you have numpy 2.2.6 which is incompatible.

[notice] A new release of pip is available: 24.3.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
import math
import warnings
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# TorchGeo imports
from torchgeo.datasets import RasterDataset
try:
    from torchgeo.datasets import stack_samples  # newer
except Exception:
    from torchgeo.datasets.utils import stack_samples  # older

try:
    from torchgeo.datasets.geo import IntersectionDataset
except Exception:
    from torchgeo.datasets import IntersectionDataset

from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

# U-Net backbone (Segmentation Models PyTorch)
import segmentation_models_pytorch as smp

warnings.filterwarnings("ignore", category=UserWarning)


In [4]:
class InputsDataset(RasterDataset):
    """Co-registered multi-band input rasters (e.g. population, GDP, land use).

    Samples drawn from this dataset expose their pixels under the ``"image"`` key.
    """

    is_image = True
    filename_glob = "*.tif"  # file pattern matched inside the dataset directory


class LabelsDataset(RasterDataset):
    """Single-band continuous regression target (e.g. CISI rescaled to [0, 1]).

    Samples drawn from this dataset expose the target under the ``"mask"`` key.
    ``dtype`` is float32 because the target is a continuous value.
    """

    dtype = torch.float32
    is_image = False
    filename_glob = "*.tif"  # file pattern matched inside the dataset directory


In [5]:
class ChannelWiseNormalize(nn.Module):
    """Per-channel standardisation layer computing ``(x - mean) / (std + 1e-6)``.

    ``mean`` and ``std`` are kept as float32 buffers shaped (1, C, 1, 1), so they
    broadcast over dim 1 of a 4-D input, follow ``.to(device)`` and are stored in
    the module's ``state_dict``.

    Args:
        mean: per-channel means (sequence or tensor of length C).
        std: per-channel standard deviations (same length as ``mean``).
    """

    def __init__(self, mean, std):
        super().__init__()
        for buffer_name, values in (("mean", mean), ("std", std)):
            self.register_buffer(buffer_name, torch.as_tensor(values).view(1, -1, 1, 1).float())

    def forward(self, x):
        # The epsilon keeps a zero std from producing inf/NaN.
        denominator = self.std + 1e-6
        return (x - self.mean) / denominator


class MaskedHuber(nn.Module):
    """Mean Huber regression loss over valid target elements.

    Elements whose target is non-finite (NaN / +-inf) are always excluded. An
    optional ``mask`` (bool or float weights, broadcastable to ``target``) further
    restricts or weights the elements. The per-element loss is the standard Huber
    loss, identical to ``torch.nn.functional.huber_loss``::

        0.5 * d**2                    if |d| <= delta
        delta * (|d| - 0.5 * delta)   otherwise,        with d = pred - target

    Args:
        delta: threshold between the quadratic and the linear regime.
    """

    def __init__(self, delta=0.5):
        super().__init__()
        self.delta = delta

    def forward(self, pred, target, mask=None):
        """Return the weighted mean Huber loss as a 0-dim tensor.

        Args:
            pred: predictions; must have exactly the same shape as ``target``.
            target: regression targets; non-finite entries are ignored.
            mask: optional validity mask or weights broadcastable to ``target``.

        Raises:
            ValueError: if ``pred`` and ``target`` shapes differ. Broadcasting them
                (e.g. (B, 1, H, W) against (B, H, W)) would silently pair the wrong
                elements.
        """
        if pred.shape != target.shape:
            raise ValueError(
                f"MaskedHuber: pred shape {tuple(pred.shape)} != "
                f"target shape {tuple(target.shape)}"
            )

        finite = torch.isfinite(target)
        weight = finite.float() if mask is None else mask.float() * finite
        # Replace invalid targets *before* doing arithmetic. A NaN loss times a zero
        # weight is still NaN, and it would also turn the gradient into NaN.
        target = torch.where(finite, target, torch.zeros_like(target))

        abs_diff = (pred - target).abs()
        quadratic = abs_diff.clamp(max=self.delta)
        loss = 0.5 * quadratic**2 + self.delta * (abs_diff - quadratic)

        # clamp_min avoids 0/0 when no element is valid (the loss is then 0).
        denom = weight.sum().clamp_min(1.0)
        return (loss * weight).sum() / denom


In [6]:
def quick_channel_stats(train_loader, in_ch, norm_batches):
    """Estimate per-channel mean and std of the input imagery from a few batches.

    Statistics are pooled over every finite pixel of the first ``norm_batches``
    batches (at least one batch is always read when available). Batches are
    merged with Chan et al.'s pairwise update, so the result is the exact
    population mean/std of those pixels. Averaging per-batch variances instead
    would ignore the spread between batch means.

    Args:
        train_loader: iterable of batches. Each batch is either a dict with an
            ``"image"`` tensor (as produced by ``stack_samples``) or a tuple whose
            first element is such a dict. Channels are on dim 1 of the image.
        in_ch: number of channels C.
        norm_batches: maximum number of batches to read.

    Returns:
        ``(mean, std)``, each a list of ``in_ch`` Python floats. Channels with
        zero spread or no finite pixels get ``std = 1.0``, so normalising by them
        is a plain shift rather than a division by ~0.
    """
    count = torch.zeros(in_ch, dtype=torch.float64)
    mean = torch.zeros(in_ch, dtype=torch.float64)
    m2 = torch.zeros(in_ch, dtype=torch.float64)  # sum of squared deviations from mean
    with torch.no_grad():
        for i, batch in enumerate(train_loader):
            images = batch["image"] if isinstance(batch, dict) else batch[0]["image"]
            images = images.double()
            # (B, C, ...) -> (C, B*...): each row holds one channel's pixels.
            flat = images.transpose(0, 1).reshape(images.shape[1], -1)
            finite = torch.isfinite(flat)  # skip NaN / inf pixels (e.g. nodata)
            zeros = torch.zeros_like(flat)

            n_b = finite.sum(dim=1).double()
            mean_b = torch.where(finite, flat, zeros).sum(dim=1) / n_b.clamp_min(1.0)
            dev = torch.where(finite, flat - mean_b[:, None], zeros)
            m2_b = (dev * dev).sum(dim=1)

            # Pairwise merge of the running (count, mean, m2) with this batch.
            total = count + n_b
            weight_b = n_b / total.clamp_min(1.0)
            delta = mean_b - mean
            mean = mean + delta * weight_b
            m2 = m2 + m2_b + delta * delta * count * weight_b
            count = total

            if i + 1 >= norm_batches:
                break

    std = (m2 / count.clamp_min(1.0)).sqrt()
    std = torch.where(std > 0, std, torch.ones_like(std))
    return mean.float().tolist(), std.float().tolist()


In [7]:
def _unpack_batch(batch):
    """Return the ``(image, mask)`` tensors of one collated batch.

    With ``collate_fn=stack_samples``, a batch from the ``IntersectionDataset``
    built in ``run`` is a single dict holding both ``"image"`` and ``"mask"``.
    A ``(inputs_dict, labels_dict)`` pair is also accepted.
    """
    if isinstance(batch, dict):
        return batch["image"], batch["mask"]
    return batch[0]["image"], batch[1]["mask"]


def _image_only_batches(loader):
    """Yield each batch of ``loader`` re-wrapped as ``({"image": tensor},)``.

    ``quick_channel_stats`` reads ``batch[0]["image"]``. Handing it this tuple
    form keeps the stats pass independent of how the loader collates samples.
    """
    for batch in loader:
        image, _ = _unpack_batch(batch)
        yield ({"image": image},)


def _batch_loss(model, normalize, criterion, batch, device, use_amp):
    """Forward one batch and return ``criterion(sigmoid(model(x)), y)``.

    If the target has one fewer dimension than the prediction (e.g. (B, H, W)
    vs the (B, 1, H, W) output of a 1-class U-Net), it gets a singleton channel
    axis so the two line up instead of broadcasting.
    """
    image, target = _unpack_batch(batch)
    x = image.float().to(device)
    y = target.float().to(device)
    with torch.autocast(device_type=device.type, enabled=use_amp):
        yhat = torch.sigmoid(model(normalize(x)))
    if y.dim() == yhat.dim() - 1:
        y = y.unsqueeze(1)
    # Evaluate the loss in float32; under AMP `yhat` may be float16.
    return criterion(yhat.float(), y)


def _train_one_epoch(model, loader, normalize, criterion, optimizer, scaler,
                     device, use_amp, max_grad_norm=1.0):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    total, n_batches = 0.0, 0
    for batch in loader:
        optimizer.zero_grad(set_to_none=True)
        loss = _batch_loss(model, normalize, criterion, batch, device, use_amp)
        scaler.scale(loss).backward()
        # The gradients are still multiplied by the loss scale at this point. Unscale
        # them so clip_grad_norm_ compares the true norm against max_grad_norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        total += loss.item()
        n_batches += 1
    return total / max(1, n_batches)


@torch.no_grad()
def _evaluate(model, loader, normalize, criterion, device, use_amp):
    """Return the mean per-batch loss over ``loader`` without updating the model."""
    model.eval()
    total, n_batches = 0.0, 0
    for batch in loader:
        total += _batch_loss(model, normalize, criterion, batch, device, use_amp).item()
        n_batches += 1
    return total / max(1, n_batches)


def run(args):
    """Train a U-Net regressor on paired input/label rasters and keep the best epoch.

    ``args`` is an attribute bag with: in_dir, lab_dir, out_dir, patch,
    train_windows, val_stride_frac, batch, workers, epochs, lr, wd, huber_delta,
    backbone, amp, use_quick_norm and norm_batches. It may also have ``seed``
    (int), in which case torch's global RNG is seeded first.

    After every epoch whose validation loss improves on the best so far, a
    checkpoint (weights, optimizer state, normalisation stats, channel count,
    backbone) is written to ``<out_dir>/best_unet_regression.pt``.
    """
    seed = getattr(args, "seed", None)
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True
    # autocast/GradScaler are only used on CUDA; on CPU they would just warn and no-op.
    use_amp = bool(args.amp) and device.type == "cuda"

    # ---- Datasets: pair input and label rasters by spatial overlap
    inputs_ds = InputsDataset(args.in_dir)
    labels_ds = LabelsDataset(args.lab_dir)
    dataset = IntersectionDataset(inputs_ds, labels_ds)

    # ---- Samplers
    # NOTE(review): validation tiles the same rasters the training windows are drawn
    # from (no spatial hold-out), so the val loss is optimistic. Consider a spatial split.
    train_sampler = RandomGeoSampler(dataset, size=args.patch, length=args.train_windows)
    val_stride = max(1, int(args.patch * args.val_stride_frac))  # never a zero stride
    val_sampler = GridGeoSampler(dataset, size=args.patch, stride=val_stride)

    # ---- DataLoaders
    train_loader = DataLoader(dataset, batch_size=args.batch, sampler=train_sampler,
                              num_workers=args.workers, collate_fn=stack_samples)
    val_loader = DataLoader(dataset, batch_size=args.batch, sampler=val_sampler,
                            num_workers=args.workers, collate_fn=stack_samples)

    # ---- Peek at one batch for shape info
    x0, y0 = _unpack_batch(next(iter(train_loader)))
    in_ch = x0.shape[1]
    print(f"[Info] Inferred input channels: {in_ch}")
    print(f"[Info] Label shape: {tuple(y0.shape)}")

    # ---- Normalization
    if args.use_quick_norm:
        mean, std = quick_channel_stats(_image_only_batches(train_loader), in_ch,
                                        args.norm_batches)
    else:
        mean, std = [0.0] * in_ch, [1.0] * in_ch
    normalize = ChannelWiseNormalize(mean, std).to(device)
    print(f"[Info] mean: {mean}\n[Info] std: {std}")

    # ---- Model: 1-output U-Net; a sigmoid is applied in _batch_loss
    model = smp.Unet(
        encoder_name=args.backbone,
        encoder_weights=None,
        in_channels=in_ch,
        classes=1,
        activation=None,
    ).to(device)

    # ---- Optimizer, loss, AMP
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    criterion = MaskedHuber(delta=args.huber_delta)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # ---- Training loop with best-on-validation checkpointing
    out_dir = Path(args.out_dir)
    best_val = math.inf
    for epoch in range(1, args.epochs + 1):
        tr_loss = _train_one_epoch(model, train_loader, normalize, criterion,
                                   optimizer, scaler, device, use_amp)
        va_loss = _evaluate(model, val_loader, normalize, criterion, device, use_amp)
        print(f"Epoch {epoch:03d} | train: {tr_loss:.4f} | val: {va_loss:.4f}")

        if va_loss < best_val:
            best_val = va_loss
            ckpt = {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "norm_mean": mean,
                "norm_std": std,
                "in_channels": in_ch,
                "backbone": args.backbone,
            }
            out_dir.mkdir(parents=True, exist_ok=True)
            torch.save(ckpt, out_dir / "best_unet_regression.pt")
            print(f"  -> Saved best checkpoint (val={best_val:.4f})")

    print("Done.")


In [8]:
import multiprocessing


class Args:
    """Attribute bag standing in for command-line options; passed to ``run``."""


args = Args()

# Paths (edit to your data locations)
args.in_dir = "data/inputs"    # directory of input GeoTIFFs (InputsDataset)
args.lab_dir = "data/labels"   # directory of label GeoTIFFs (LabelsDataset)
args.out_dir = "checkpoints"   # where best_unet_regression.pt is saved

# Sampling / data loading
args.patch = 256               # window `size` given to both samplers
args.train_windows = 500       # random windows per training epoch (RandomGeoSampler length)
args.val_stride_frac = 1.0     # val grid stride = patch * frac; 1.0 -> non-overlapping tiles
args.batch = 4
# DataLoader workers started with "spawn"/"forkserver" (spawn is the default on
# Windows and macOS) must re-import InputsDataset/LabelsDataset. That fails for
# classes defined in notebook cells. Only "fork" workers inherit them, so load
# in-process otherwise.
args.workers = 2 if multiprocessing.get_all_start_methods()[0] == "fork" else 0

# Model / optimizer / training
args.epochs = 5
args.lr = 3e-4                 # AdamW learning rate
args.wd = 1e-2                 # AdamW weight decay
args.huber_delta = 0.5         # Huber quadratic/linear threshold
args.backbone = "resnet34"     # segmentation_models_pytorch encoder name
args.amp = True                # mixed precision (autocast + GradScaler)

# Normalization
args.use_quick_norm = True     # estimate per-channel mean/std from training batches
args.norm_batches = 5          # number of batches read for that estimate


In [9]:
# Train with the settings above; the best-validation checkpoint is written to
# <args.out_dir>/best_unet_regression.pt.
run(args)


DatasetNotFoundError: Dataset not found in `paths='data/inputs'` and cannot be automatically downloaded, either specify a different `paths` or manually download the dataset.