In [None]:
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# from geofeather.pygeos import to_geofeather, from_geofeather

import pandas as pd
import geopandas as gpd

#import pygeos
from rasterstats import zonal_stats
from scipy.stats import spearmanr
import shapely

from torch.utils.data import Dataset
import torch
import numpy as np
import rasterio
from rasterio.windows import Window

from shapely.geometry import mapping, shape
from shapely import wkb
from shapely.wkb import loads as from_wkb

import rasterio

from pathlib import Path

from rasterio.warp import reproject, Resampling
from rasterio.windows import Window
import torch.nn as nn

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch


from rasterio.windows import from_bounds
from rasterio.transform import from_origin

import fiona
from pathlib import Path
import numpy as np
import rasterio
from rasterio.warp import reproject
from rasterio.enums import Resampling as ResamplingEnums
from rasterio.features import rasterize

# Tiling

In [2]:
def tiling(input_path, tile_size, overlap, pad_value):
    """
    Generator that yields raster tiles in-memory, with padding.

    Args:
        input_path (str): Path to input GeoTIFF.
        tile_size (tuple): (width, height) of each tile in pixels.
        overlap (tuple): (x_overlap, y_overlap) in pixels.
        pad_value (int or float): Value used to pad edge tiles.

    Yields:
        dict: {
            "data": np.ndarray (bands, height, width),
            "transform": Affine transform of the tile,
            "indices": (top, left) pixel coordinates in source
        }
    """
    tile_width, tile_height = tile_size
    overlap_x, overlap_y = overlap

    with rasterio.open(input_path) as src:
        width, height = src.width, src.height
        num_bands = src.count
        step_x = tile_width - overlap_x
        step_y = tile_height - overlap_y

        for top in range(0, height, step_y):
            for left in range(0, width, step_x):
                win_width = min(tile_width, width - left)
                win_height = min(tile_height, height - top)
                window = Window(left, top, win_width, win_height)
                transform = src.window_transform(window)

                data = src.read(window=window)

                if win_width < tile_width or win_height < tile_height:
                    padded = np.full((num_bands, tile_height, tile_width), pad_value, dtype=data.dtype)
                    padded[:, :win_height, :win_width] = data
                    data = padded

                yield {
                    "data": data,
                    "transform": transform,
                    "indices": (top, left)
                }


In [6]:
# GDP tiles
gdp_dates_2030 = ("GDP2030_025_ssp1_clipped", "GDP2030_025_ssp2_clipped", "GDP2030_025_ssp3_clipped", "GDP2030_025_ssp4_clipped", "GDP2030_025_ssp5_clipped")
gdp_dates_2050 = ("GDP2050_025_ssp1_clipped", "GDP2050_025_ssp2_clipped", "GDP2050_025_ssp3_clipped", "GDP2050_025_ssp4_clipped", "GDP2050_025_ssp5_clipped")
gdp_dates_2100 = ("GDP2100_025_ssp1_clipped", "GDP2100_025_ssp2_clipped", "GDP2100_025_ssp3_clipped", "GDP2100_025_ssp4_clipped", "GDP2100_025_ssp5_clipped")


tiling(
    input_path="GDP_clipped_files_025d\GDP2030__025_ssp1_clipped.tif",
    tile_size=(64, 64),
    overlap=(0,0), 
    pad_value=0
) 

<generator object tiling at 0x000002D801E5A5A0>

# Model

### Dataset

In [None]:


class InfrastructureDataset(Dataset):
    def __init__(self, input_raster_path, label_raster_path,
                 tile_size=(64, 64), overlap=(0, 0), pad_value=0, transform=None):
        self.input_raster_path = input_raster_path
        self.label_raster_path = label_raster_path
        self.tile_size = tile_size
        self.overlap = overlap
        self.pad_value = pad_value
        self.transform = transform

        # Precompute tile indices
        self.tile_indices = []
        with rasterio.open(self.input_raster_path) as src:
            self.width, self.height = src.width, src.height
            self.num_bands = src.count
            step_x = tile_size[0] - overlap[0]
            step_y = tile_size[1] - overlap[1]

            for top in range(0, self.height, step_y):
                for left in range(0, self.width, step_x):
                    self.tile_indices.append((top, left))

    def __len__(self):
        return len(self.tile_indices)

    def __getitem__(self, idx):
        top, left = self.tile_indices[idx]
        tile_width, tile_height = self.tile_size

        # Read input tile
        with rasterio.open(self.input_raster_path) as src:
            window = Window(left, top, tile_width, tile_height)
            input_tile = src.read(window=window)

            if input_tile.shape[1] < tile_height or input_tile.shape[2] < tile_width:
                padded = np.full((self.num_bands, tile_height, tile_width), self.pad_value, dtype=input_tile.dtype)
                padded[:, :input_tile.shape[1], :input_tile.shape[2]] = input_tile
                input_tile = padded

        # Read label tile (single-band)
        with rasterio.open(self.label_raster_path) as lbl_src:
            label_tile = lbl_src.read(1, window=window)

            if label_tile.shape[0] < tile_height or label_tile.shape[1] < tile_width:
                padded_label = np.full((tile_height, tile_width), self.pad_value, dtype=label_tile.dtype)
                padded_label[:label_tile.shape[0], :label_tile.shape[1]] = label_tile
                label_tile = padded_label

        if self.transform:
            input_tile, label_tile = self.transform(input_tile, label_tile)

        return torch.tensor(input_tile, dtype=torch.float32), torch.tensor(label_tile, dtype=torch.long)


In [10]:
dataset = InfrastructureDataset(
    input_raster_path="cisi_index_pop_all_years.tif",       # input: predictor
    label_raster_path="CISI_label_file_025.tif",          # label: infrastructure presence in 2020
    tile_size=(64, 64),
    overlap=(0, 0),
    pad_value=0
)

loader = DataLoader(dataset, batch_size=16, shuffle=True)


In [None]:
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SimpleCNN, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)  # downsample
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, 1)  # output layer
        )

    def forward(self, x):
        x = self.encoder(x)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  # upsample
        x = self.decoder(x)
        return x  # Output: [B, out_channels, H, W]


In [None]:
# Training setup
# select model
model = SimpleCNN(in_channels=1, out_channels=1)  # Adjust in_channels if you have more bands

# move model to device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# loss function (regression)
criterion = nn.MSELoss()

# define optimizer
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=1e-3)




In [16]:
# Train the model
def train(model, loader, criterion, optimizer, num_epochs=10, device="cpu"):
    model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0

        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device).unsqueeze(1).float()  # [B, 1, H, W]

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss / len(loader):.4f}")


In [17]:
# Run the training model
train(model, loader, criterion, optimizer, num_epochs=20, device=device)


Epoch 1/20 - Loss: inf
Epoch 2/20 - Loss: inf
Epoch 3/20 - Loss: inf
Epoch 4/20 - Loss: inf
Epoch 5/20 - Loss: inf
Epoch 6/20 - Loss: inf
Epoch 7/20 - Loss: inf
Epoch 8/20 - Loss: inf
Epoch 9/20 - Loss: inf
Epoch 10/20 - Loss: inf
Epoch 11/20 - Loss: inf
Epoch 12/20 - Loss: inf
Epoch 13/20 - Loss: inf
Epoch 14/20 - Loss: inf
Epoch 15/20 - Loss: inf
Epoch 16/20 - Loss: inf
Epoch 17/20 - Loss: inf
Epoch 18/20 - Loss: inf
Epoch 19/20 - Loss: inf
Epoch 20/20 - Loss: inf


In [18]:
for x, y in loader:
    print("Input min/max:", x.min().item(), x.max().item())
    print("Label min/max:", y.min().item(), y.max().item())
    break


Input min/max: 0.0 0.7352721095085144
Label min/max: -9223372036854775808 0


In [None]:
"""
TorchGeo + U-Net (SMP) for continuous per-pixel regression (pixelwise regression)
- Inputs: multi-band georasters (e.g., pop, GDP, land-use)
- Labels: single-band continuous target in [0,1] (e.g., CISI), float32
- Pairing: IntersectionDataset (by spatial overlap)
- Loss: masked Huber (handles NoData via NaN)
- Mixed precision + grad clipping

Install:
    pip install torch torchvision torchgeo rasterio segmentation-models-pytorch albumentations
"""

import argparse
from pathlib import Path
import math
import warnings

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# -----------------------------
# TorchGeo imports (version-robust)
# -----------------------------
from torchgeo.datasets import RasterDataset

# stack_samples import moved across versions; try new then old
try:
    from torchgeo.datasets import stack_samples  # newer
except Exception:
    try:
        from torchgeo.datasets.utils import stack_samples  # older
    except Exception as e:
        raise ImportError("Could not import stack_samples from TorchGeo.") from e

# IntersectionDataset location varies slightly across versions
try:
    from torchgeo.datasets.geo import IntersectionDataset
except Exception:
    try:
        from torchgeo.datasets import IntersectionDataset
    except Exception as e:
        raise ImportError("Could not import IntersectionDataset from TorchGeo.") from e

from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

# -----------------------------
# Model (U-Net from SMP)
# -----------------------------
import segmentation_models_pytorch as smp

warnings.filterwarnings("ignore", category=UserWarning)


# -----------------------------
# Dataset definitions (TorchGeo)
# -----------------------------
class InputsDataset(RasterDataset):
    """
    Multi-band inputs (e.g., population, GDP, land-use) as aligned GeoTIFFs.
    """
    filename_glob = "*.tif"
    is_image = True  # tensors will be under key "image"


class LabelsDataset(RasterDataset):
    """
    Single-band continuous target (e.g., CISI scaled to [0,1]).
    IMPORTANT:
      - use float32 for continuous targets
      - set is_image=False so TorchGeo stores this under key "mask"
      - store NoData as NaN in your GeoTIFFs when possible
    """
    filename_glob = "*.tif"
    is_image = False          # tensors will be under key "mask"
    dtype = torch.float32


# -----------------------------
# Normalization module
# -----------------------------
class ChannelWiseNormalize(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        m = torch.as_tensor(mean).view(1, -1, 1, 1)
        s = torch.as_tensor(std).view(1, -1, 1, 1)
        self.register_buffer("mean", m.float())
        self.register_buffer("std", s.float())

    def forward(self, x):
        return (x - self.mean) / (self.std + 1e-6)


# -----------------------------
# Masked Huber regression loss
# -----------------------------
class MaskedHuber(nn.Module):
    def __init__(self, delta=0.5):
        super().__init__()
        self.delta = delta

    def forward(self, pred, target, mask=None):
        """
        pred, target: (B,1,H,W)
        mask: (B,1,H,W) boolean/float in {0,1}. If None, uses isfinite(target).
        """
        if mask is None:
            mask = torch.isfinite(target)
        mask = mask.float()

        diff = pred - target
        abs_diff = diff.abs()
        delta = torch.tensor(self.delta, device=pred.device)
        quadratic = torch.minimum(abs_diff, delta)
        loss = 0.5 * quadratic**2 + delta * (abs_diff - quadratic) - 0.5 * (delta**2)
        loss = loss * mask

        denom = mask.sum().clamp_min(1.0)
        return loss.sum() / denom


# -----------------------------
# Training / validation driver
# -----------------------------
def run(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # ---- Build base datasets (positional root arg)
    inputs_ds = InputsDataset(args.in_dir)
    labels_ds = LabelsDataset(args.lab_dir)

    # ---- Combine by spatial overlap (modern replacement for ZipDataset)
    dataset = IntersectionDataset(inputs_ds, labels_ds)   # equivalently: inputs_ds & labels_ds

    # ---- Samplers
    train_sampler = RandomGeoSampler(dataset, size=args.patch, length=args.train_windows)
    val_sampler = GridGeoSampler(dataset, size=args.patch, stride=int(args.patch * args.val_stride_frac))

    # ---- DataLoaders (use TorchGeo collate)
    train_loader = DataLoader(
        dataset, batch_size=args.batch, sampler=train_sampler, num_workers=args.workers,
        collate_fn=stack_samples, pin_memory=True, persistent_workers=args.workers > 0
    )
    val_loader = DataLoader(
        dataset, batch_size=args.batch, sampler=val_sampler, num_workers=args.workers,
        collate_fn=stack_samples, pin_memory=True, persistent_workers=args.workers > 0
    )

    # ---- Peek for channel count and sanity
    sample = next(iter(train_loader))
    x0 = sample[0]["image"]      # (B,C,H,W)
    y0 = sample[1]["mask"]       # (B,1,H,W)
    in_ch = x0.shape[1]
    print(f"[Info] Inferred input channels: {in_ch}")
    print(f"[Info] Label shape sample: {tuple(y0.shape)} (expect B,1,H,W)")

    # ---- Normalization
    if args.use_quick_norm:
        mean, std = quick_channel_stats(train_loader, in_ch, args.norm_batches)
    else:
        mean = [0.0] * in_ch
        std = [1.0] * in_ch
    print(f"[Info] Normalization mean: {mean}")
    print(f"[Info] Normalization std : {std}")

    normalize = ChannelWiseNormalize(mean, std).to(device)

    # ---- Model: SMP U-Net
    model = smp.Unet(
        encoder_name=args.backbone,     # e.g., "resnet34"
        encoder_weights=None,           # use "imagenet" only if inputs are truly RGB-like
        in_channels=in_ch,
        classes=1,                      # single continuous target channel
        activation=None,                # handle activation explicitly
    ).to(device)

    # ---- Optimizer, loss, AMP
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    criterion = MaskedHuber(delta=args.huber_delta)
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # ---- Training loop
    best_val = math.inf
    for epoch in range(1, args.epochs + 1):
        model.train()
        tr_loss, tr_n = 0.0, 0
        for batch in train_loader:
            x = batch[0]["image"].float().to(device)  # (B,C,H,W)
            y = batch[1]["mask"].float().to(device)   # (B,1,H,W) in [0,1]; NaN = NoData preferred

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=args.amp):
                x = normalize(x)
                yhat = model(x)
                yhat = torch.sigmoid(yhat)  # constrain to [0,1]
                loss = criterion(yhat, y)   # masked inside (NaNs ignored)

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            tr_loss += loss.item()
            tr_n += 1

        # ---- Validation
        model.eval()
        va_loss, va_n = 0.0, 0
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=args.amp):
            for batch in val_loader:
                x = batch[0]["image"].float().to(device)
                y = batch[1]["mask"].float().to(device)
                x = normalize(x)
                yhat = torch.sigmoid(model(x))
                loss = criterion(yhat, y)
                va_loss += loss.item()
                va_n += 1

        tr_loss /= max(tr_n, 1)
        va_loss /= max(va_n, 1)
        print(f"Epoch {epoch:03d} | train: {tr_loss:.4f} | val: {va_loss:.4f}")

        if va_loss < best_val:
            best_val = va_loss
            ckpt = {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "norm_mean": mean,
                "norm_std": std,
                "in_channels": in_ch,
                "backbone": args.backbone,
            }
            Path(args.out_dir).mkdir(parents=True, exist_ok=True)
            torch.save(ckpt, Path(args.out_dir) / "best_unet_regression.pt")
            print(f"  -> saved best checkpoint (val={best_val:.4f})")

    print("Done.")


def quick_channel_stats(train_loader, in_ch, norm_batches):
    """Rough channel-wise mean/std over a few batches."""
    running_mean = torch.zeros(in_ch, dtype=torch.float64)
    running_var = torch.zeros(in_ch, dtype=torch.float64)
    n = 0
    with torch.no_grad():
        for i, batch in enumerate(train_loader):
            xb = batch[0]["image"].float()  # (B,C,H,W)
            B, C, H, W = xb.shape
            xb = xb.view(B, C, -1)
            running_mean += xb.mean(dim=(0, 2)).double()
            running_var += xb.var(dim=(0, 2), unbiased=False).double()
            n += 1
            if i + 1 >= norm_batches:
                break
    mean = (running_mean / max(n, 1)).float().tolist()
    std = (running_var / max(n, 1)).sqrt().float().tolist()
    return mean, std


# -----------------------------
# CLI
# -----------------------------
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--in_dir", type=str, required=True, help="Path to inputs directory (multi-band GeoTIFFs)")
    p.add_argument("--lab_dir", type=str, required=True, help="Path to labels directory (single-band GeoTIFFs)")
    p.add_argument("--out_dir", type=str, default="checkpoints", help="Where to save checkpoints")

    p.add_argument("--patch", type=int, default=256, help="Patch size in pixels (H=W)")
    p.add_argument("--train_windows", type=int, default=2000, help="#random windows per epoch for training")
    p.add_argument("--val_stride_frac", type=float, default=1.0, help="Val stride as fraction of patch size (>=1)")
    p.add_argument("--batch", type=int, default=4)
    p.add_argument("--workers", type=int, default=4)

    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--wd", type=float, default=1e-2)
    p.add_argument("--huber_delta", type=float, default=0.5)
    p.add_argument("--backbone", type=str, default="resnet34")  # any SMP encoder
    p.add_argument("--amp", action="store_true", help="Enable mixed precision")

    p.add_argument("--use_quick_norm", action="store_true", help="Estimate mean/std over a few batches")
    p.add_argument("--norm_batches", type=int, default=10, help="#batches for quick norm estimation")
    return p.parse_args()


if __name__ == "__main__":
    class Args: pass
    args = Args()
    args.in_dir = "data/inputs"
    args.lab_dir = "data/labels"
    args.out_dir = "checkpoints"

    args.patch = 256
    args.train_windows = 500
    args.val_stride_frac = 1.0
    args.batch = 4
    args.workers = 2

    args.epochs = 5
    args.lr = 3e-4
    args.wd = 1e-2
    args.huber_delta = 0.5
    args.backbone = "resnet34"
    args.amp = True

    args.use_quick_norm = True
    args.norm_batches = 5

    run(args)
