In [None]:
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# from geofeather.pygeos import to_geofeather, from_geofeather

import pandas as pd
import geopandas as gpd

#import pygeos
from rasterstats import zonal_stats
from scipy.stats import spearmanr
import shapely

from torch.utils.data import Dataset
import torch
import numpy as np
import rasterio
from rasterio.windows import Window

from shapely.geometry import mapping, shape
from shapely import wkb
from shapely.wkb import loads as from_wkb

import rasterio

from pathlib import Path

from rasterio.warp import reproject, Resampling
from rasterio.windows import Window
import torch.nn as nn

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch


from rasterio.windows import from_bounds
from rasterio.transform import from_origin

import fiona
from pathlib import Path
import numpy as np
import rasterio
from rasterio.warp import reproject
from rasterio.enums import Resampling as ResamplingEnums
from rasterio.features import rasterize

# Tiling

In [2]:
def tiling(input_path, tile_size, overlap, pad_value):
    """
    Generator that yields raster tiles in-memory, with padding.

    Args:
        input_path (str): Path to input GeoTIFF.
        tile_size (tuple): (width, height) of each tile in pixels.
        overlap (tuple): (x_overlap, y_overlap) in pixels.
        pad_value (int or float): Value used to pad edge tiles.

    Yields:
        dict: {
            "data": np.ndarray (bands, height, width),
            "transform": Affine transform of the tile,
            "indices": (top, left) pixel coordinates in source
        }
    """
    tile_width, tile_height = tile_size
    overlap_x, overlap_y = overlap

    with rasterio.open(input_path) as src:
        width, height = src.width, src.height
        num_bands = src.count
        step_x = tile_width - overlap_x
        step_y = tile_height - overlap_y

        for top in range(0, height, step_y):
            for left in range(0, width, step_x):
                win_width = min(tile_width, width - left)
                win_height = min(tile_height, height - top)
                window = Window(left, top, win_width, win_height)
                transform = src.window_transform(window)

                data = src.read(window=window)

                if win_width < tile_width or win_height < tile_height:
                    padded = np.full((num_bands, tile_height, tile_width), pad_value, dtype=data.dtype)
                    padded[:, :win_height, :win_width] = data
                    data = padded

                yield {
                    "data": data,
                    "transform": transform,
                    "indices": (top, left)
                }


In [6]:
# GDP tiles
gdp_dates_2030 = ("GDP2030_025_ssp1_clipped", "GDP2030_025_ssp2_clipped", "GDP2030_025_ssp3_clipped", "GDP2030_025_ssp4_clipped", "GDP2030_025_ssp5_clipped")
gdp_dates_2050 = ("GDP2050_025_ssp1_clipped", "GDP2050_025_ssp2_clipped", "GDP2050_025_ssp3_clipped", "GDP2050_025_ssp4_clipped", "GDP2050_025_ssp5_clipped")
gdp_dates_2100 = ("GDP2100_025_ssp1_clipped", "GDP2100_025_ssp2_clipped", "GDP2100_025_ssp3_clipped", "GDP2100_025_ssp4_clipped", "GDP2100_025_ssp5_clipped")


tiling(
    input_path="GDP_clipped_files_025d\GDP2030__025_ssp1_clipped.tif",
    tile_size=(64, 64),
    overlap=(0,0), 
    pad_value=0
) 

<generator object tiling at 0x000002D801E5A5A0>

# Model

### Dataset

In [None]:


class InfrastructureDataset(Dataset):
    def __init__(self, input_raster_path, label_raster_path,
                 tile_size=(64, 64), overlap=(0, 0), pad_value=0, transform=None):
        self.input_raster_path = input_raster_path
        self.label_raster_path = label_raster_path
        self.tile_size = tile_size
        self.overlap = overlap
        self.pad_value = pad_value
        self.transform = transform

        # Precompute tile indices
        self.tile_indices = []
        with rasterio.open(self.input_raster_path) as src:
            self.width, self.height = src.width, src.height
            self.num_bands = src.count
            step_x = tile_size[0] - overlap[0]
            step_y = tile_size[1] - overlap[1]

            for top in range(0, self.height, step_y):
                for left in range(0, self.width, step_x):
                    self.tile_indices.append((top, left))

    def __len__(self):
        return len(self.tile_indices)

    def __getitem__(self, idx):
        top, left = self.tile_indices[idx]
        tile_width, tile_height = self.tile_size

        # Read input tile
        with rasterio.open(self.input_raster_path) as src:
            window = Window(left, top, tile_width, tile_height)
            input_tile = src.read(window=window)

            if input_tile.shape[1] < tile_height or input_tile.shape[2] < tile_width:
                padded = np.full((self.num_bands, tile_height, tile_width), self.pad_value, dtype=input_tile.dtype)
                padded[:, :input_tile.shape[1], :input_tile.shape[2]] = input_tile
                input_tile = padded

        # Read label tile (single-band)
        with rasterio.open(self.label_raster_path) as lbl_src:
            label_tile = lbl_src.read(1, window=window)

            if label_tile.shape[0] < tile_height or label_tile.shape[1] < tile_width:
                padded_label = np.full((tile_height, tile_width), self.pad_value, dtype=label_tile.dtype)
                padded_label[:label_tile.shape[0], :label_tile.shape[1]] = label_tile
                label_tile = padded_label

        if self.transform:
            input_tile, label_tile = self.transform(input_tile, label_tile)

        return torch.tensor(input_tile, dtype=torch.float32), torch.tensor(label_tile, dtype=torch.long)


In [10]:
dataset = InfrastructureDataset(
    input_raster_path="cisi_index_pop_all_years.tif",       # input: predictor
    label_raster_path="CISI_label_file_025.tif",          # label: infrastructure presence in 2020
    tile_size=(64, 64),
    overlap=(0, 0),
    pad_value=0
)

loader = DataLoader(dataset, batch_size=16, shuffle=True)


In [None]:
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SimpleCNN, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)  # downsample
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, 1)  # output layer
        )

    def forward(self, x):
        x = self.encoder(x)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  # upsample
        x = self.decoder(x)
        return x  # Output: [B, out_channels, H, W]


In [None]:
# Training setup
# select model
model = SimpleCNN(in_channels=1, out_channels=1)  # Adjust in_channels if you have more bands

# move model to device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# loss function (regression)
criterion = nn.MSELoss()

# define optimizer
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=1e-3)




In [16]:
# Train the model
def train(model, loader, criterion, optimizer, num_epochs=10, device="cpu"):
    model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0

        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device).unsqueeze(1).float()  # [B, 1, H, W]

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss / len(loader):.4f}")


In [17]:
# Run the training model
train(model, loader, criterion, optimizer, num_epochs=20, device=device)


Epoch 1/20 - Loss: inf
Epoch 2/20 - Loss: inf
Epoch 3/20 - Loss: inf
Epoch 4/20 - Loss: inf
Epoch 5/20 - Loss: inf
Epoch 6/20 - Loss: inf
Epoch 7/20 - Loss: inf
Epoch 8/20 - Loss: inf
Epoch 9/20 - Loss: inf
Epoch 10/20 - Loss: inf
Epoch 11/20 - Loss: inf
Epoch 12/20 - Loss: inf
Epoch 13/20 - Loss: inf
Epoch 14/20 - Loss: inf
Epoch 15/20 - Loss: inf
Epoch 16/20 - Loss: inf
Epoch 17/20 - Loss: inf
Epoch 18/20 - Loss: inf
Epoch 19/20 - Loss: inf
Epoch 20/20 - Loss: inf


In [18]:
for x, y in loader:
    print("Input min/max:", x.min().item(), x.max().item())
    print("Label min/max:", y.min().item(), y.max().item())
    break


Input min/max: 0.0 0.7352721095085144
Label min/max: -9223372036854775808 0
