In [None]:
import os
from typing import Callable

import rioxarray
import geopandas as gpd
import shapely.geometry
import rasterio
import numpy as np
import rasterio.windows
import rasterio.transform
import torch
from rasterio.enums import Resampling
from tqdm import tqdm
from pyproj import CRS
from pystac_client import Client
import torchvision.transforms.v2 as T
from odc.stac import configure_rio, stac_load
from torchgeo.models import resnet18, ResNet18_Weights

# GDAL / AWS-style settings used when reading rasters from the S3 endpoint below.
os.environ["GDAL_HTTP_TCP_KEEPALIVE"] = "YES"
os.environ["AWS_S3_ENDPOINT"] = "eodata.dataspace.copernicus.eu"

# Credentials are never hard-coded: they must already be exported in the
# environment. Assigning `os.environ.get(...)` back to `os.environ` raises an
# opaque `TypeError` when the variable is unset, so check explicitly and fail
# with an actionable message instead.
_missing_credentials = [
    name
    for name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
    if name not in os.environ
]
if _missing_credentials:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_credentials)
        + ". Export them before starting the notebook kernel."
    )

os.environ["AWS_HTTPS"] = "YES"
os.environ["AWS_VIRTUAL_HOSTING"] = "FALSE"
# NOTE(review): this disables TLS certificate verification for GDAL HTTP
# requests. Kept for backward compatibility; confirm it is really required and
# remove it if the endpoint's certificate validates.
os.environ["GDAL_HTTP_UNSAFESSL"] = "YES"

### Test download patch

In [None]:
# Set up rasterio/GDAL for remote reads (requested with cloud_defaults=True).
cfg = configure_rio(cloud_defaults=True)

# Area of interest: the bounding box of the input geometries, in lon/lat.
path = "geometries/bexar.geojson"
gdf = gpd.read_file(path).to_crs("EPSG:4326")
geometry = shapely.geometry.box(*gdf.total_bounds)

url = "https://stac.dataspace.copernicus.eu/v1"
client = Client.open(url)

# Search the collection
# NOTE(review): `get_all_items()` is reported as deprecated in recent
# pystac-client releases in favour of `item_collection()` — confirm against the
# pinned version before switching.
items = client.search(
    collections=["sentinel-2-global-mosaics"],
    intersects=geometry,
    datetime="2024-06-01/2025-06-02",
    max_items=1,
).get_all_items()

# Fail early with a clear message; otherwise the load / `.isel(time=0)` below
# fails with a much less helpful error.
if len(items) == 0:
    raise RuntimeError(f"STAC search returned no items for {path!r}.")

# Load the RGB bands over the AOI on a 10 m Web-Mercator grid and materialise
# the first (only) time step in memory.
ds = (
    stac_load(
        list(items),
        bands=["B04", "B03", "B02"],
        intersects=geometry,
        crs="EPSG:3857",
        resolution=10,
        chunks={"x": 256, "y": 256},
        stac_cfg=cfg,
    )
    .isel(time=0)
    .compute()
)

# Write next to the notebook as "<geojson stem>.tif".
output_path = os.path.splitext(os.path.basename(path))[0] + ".tif"
# `compress` is the creation-option key (the same one the UTM export cell uses);
# the previous `compression=` key is not a recognised option. The explicit
# `transform=` kwarg was dropped: rioxarray derives the transform from the
# array's own coordinates when writing.
ds.to_array("band").rio.to_raster(
    output_path,
    driver="COG",
    dtype="uint16",
    compress="LZW",
)

### Test download UTM tiles

In [None]:
# Configure rasterio/GDAL for remote reads (cloud defaults requested) and keep
# the returned value so it can be forwarded to `stac_load` below.
cfg = configure_rio(cloud_defaults=True)
# Root of the STAC API that the per-zone searches below are issued against.
url = "https://stac.dataspace.copernicus.eu/v1"
client = Client.open(url)


def utm_zone_bounds(zone_number: int, north: bool = True):
    """Build the lon/lat footprint of a UTM zone as a shapely box (EPSG:4326).

    Args:
        zone_number: UTM zone number; zone 1 starts at longitude -180.
        north: If True, cover latitudes 0 to 84; otherwise -80 to 0.

    Returns:
        A shapely box spanning the zone's 6-degree longitude band and the
        latitude range of the chosen hemisphere.
    """
    zone_width_deg = 6  # every UTM zone spans 6 degrees of longitude
    west = -180 + (zone_number - 1) * zone_width_deg
    east = west + zone_width_deg
    if north:
        south_lat, north_lat = 0, 84
    else:
        south_lat, north_lat = -80, 0
    return shapely.geometry.box(west, south_lat, east, north_lat)


# Create the output directory once, up front, instead of on every iteration.
os.makedirs("utm_tiles", exist_ok=True)

# Export one COG per northern-hemisphere UTM zone in the range 10-18.
for zone in tqdm(range(10, 19)):
    utm_crs = CRS.from_dict(
        {"proj": "utm", "zone": zone, "datum": "WGS84", "south": False}
    )
    # Zone footprint in lon/lat (EPSG:4326).
    geom = utm_zone_bounds(zone)

    # STAC search and load restricted to zone
    items = client.search(
        collections=["sentinel-2-global-mosaics"],
        intersects=geom,
        datetime="2024-06-01/2025-06-02",
        max_items=1,
    ).get_all_items()

    if not items:
        print(f"Skipping UTM {zone} — no data found.")
        continue

    # Lazily load the RGB bands on a 10 m grid in the zone's own UTM CRS.
    ds_zone = stac_load(
        list(items),
        bands=["B04", "B03", "B02"],
        intersects=geom,
        crs=utm_crs,
        resolution=10,
        chunks={"x": 512, "y": 512},
        stac_cfg=cfg,
    ).isel(time=0)

    # Clip and reproject per UTM zone.
    # `geom` is expressed in lon/lat, so its CRS must be declared as EPSG:4326.
    # Declaring it as the dataset's (UTM) CRS would interpret degrees as metres
    # and clip against the wrong area.
    try:
        ds_clipped = ds_zone.rio.clip([geom], crs="EPSG:4326", drop=True)
    except rioxarray.exceptions.NoDataInBounds:
        # rioxarray signals "geometry does not overlap the data" by raising.
        print(f"Skipping UTM Zone {zone} – no data after clip")
        continue

    if ds_clipped.rio.width <= 1 or ds_clipped.rio.height <= 1:
        print(f"Skipping UTM Zone {zone} – no data after clip")
        continue

    # `Resampling` is the rasterio enum imported at the top of the notebook;
    # `rioxarray.rio.enums` is not an importable module path.
    ds_reproj = ds_clipped.rio.reproject(
        dst_crs=utm_crs,
        resolution=10,
        resampling=Resampling.bilinear,
    )

    out_path = f"utm_tiles/conus_utm_{zone}.tif"
    # Creation options are passed through to GDAL as strings, so the overview
    # resampling is given by name rather than as an enum member. `tiled` was
    # removed: the COG driver has no such option (its output is always tiled).
    ds_reproj.to_array("band").rio.to_raster(
        out_path,
        driver="COG",
        dtype="uint16",
        compress="LZW",
        overview_resampling="nearest",
    )

    print(f"✅ Wrote UTM Zone {zone} to {out_path}")

### Embedding

In [1]:
# A tile as (xmin, ymin, xmax, ymax) in pixel indices: x values are column
# offsets and y values are row offsets into the raster.
Box = tuple[int, int, int, int]
# A list of tile boxes.
Boxes = list[Box]


class TiledGeotiffDataset(torch.utils.data.Dataset):
    """Torch Dataset class that performs tiling with overlap on geotiffs.

    The raster is split into square tiles on a regular grid. Each item is read
    lazily from disk by re-opening the file, so nothing but the raster metadata
    and the tile list is held in memory. Tiles that run past the right/bottom
    edge are clamped to the raster, so edge tiles can be smaller than
    ``tile_size``.
    """

    def __init__(
        self,
        path: str,  # can also be a s3 url
        tile_size: int,
        tile_overlap: float = 0.0,
        transforms: Callable | None = None,
    ):
        """Initialize the dataset.

        Args:
            path: The path to the geotiff file.
            tile_size: The size of the tile.
            tile_overlap: The overlap between the tiles.
            transforms: The transforms to apply to the image.

        Raises:
            ValueError: If ``tile_size`` and ``tile_overlap`` give a stride
                smaller than one pixel.
            AssertionError: If no tiles could be generated.
        """
        self.path = path
        self.tile_size = tile_size
        self.tile_overlap = tile_overlap
        self.transforms = transforms

        # Only metadata is read here; pixel data is read per tile on demand.
        with rasterio.open(path) as src:
            self.width, self.height = src.width, src.height
            self.crs = src.crs
            self.transform = src.transform

        self.tiles = self.get_tiles(
            img_size=(self.width, self.height),
            tile_size=tile_size,
            tile_overlap=tile_overlap,
        )
        assert len(self) > 0, "No tiles were able to be generated"

    @staticmethod
    def get_tiles(
        img_size: tuple[int, int], tile_size: int, tile_overlap: float
    ) -> Boxes:
        """Enumerate tile boxes covering an image, row by row.

        Args:
            img_size: Image size as (width, height) in pixels.
            tile_size: Side length of the square tiles in pixels.
            tile_overlap: Fraction of the tile shared with its neighbour; the
                stride is ``int((1 - tile_overlap) * tile_size)``.

        Returns:
            Boxes as (xmin, ymin, xmax, ymax) pixel indices. Boxes are NOT
            clamped to the image, so the last row/column may extend past it.

        Raises:
            ValueError: If the resulting stride is smaller than one pixel.
        """
        img_w, img_h = img_size
        stride = int((1 - tile_overlap) * tile_size)
        if stride < 1:
            raise ValueError(
                "tile_size and tile_overlap must give a stride of at least 1 "
                f"pixel, got stride={stride} "
                f"(tile_size={tile_size}, tile_overlap={tile_overlap})"
            )
        # Iterate up to the full image size. Stopping at `size - 1` silently
        # dropped the last row/column whenever `size - 1` was a multiple of the
        # stride (e.g. a 225 px image with 224 px tiles).
        return [
            (x, y, x + tile_size, y + tile_size)
            for y in range(0, img_h, stride)
            for x in range(0, img_w, stride)
        ]

    def box_to_window(self, box: Box) -> rasterio.windows.Window:
        """Convert the bbox of the patch bounds to a rasterio window.

        The right and bottom edges are clamped to the raster size, so the
        returned window never extends past the image.
        """
        xmin, ymin, xmax, ymax = box
        xmax = min(xmax, self.width)
        ymax = min(ymax, self.height)
        window = rasterio.windows.Window(
            col_off=xmin, row_off=ymin, width=xmax - xmin, height=ymax - ymin
        )
        return window

    def bbox_to_coords(self, box: Box):
        """Convert a pixel box to its bounds in the raster's CRS.

        Args:
            box: Tile as (xmin, ymin, xmax, ymax) pixel indices; the right and
                bottom edges are clamped to the raster size.

        Returns:
            ``(x_min, y_min, x_max, y_max)`` in the units of ``self.crs``.
        """
        xmin, ymin, xmax, ymax = box
        xmax = min(xmax, self.width)
        ymax = min(ymax, self.height)
        # Use pixel *corners*: the upper-left corner of the first pixel and the
        # upper-left corner of the pixel just past the tile (xmax/ymax are
        # exclusive). Pixel centres would shift both corners by half a pixel.
        x_a, y_a = rasterio.transform.xy(
            self.transform, ymin, xmin, offset="ul"
        )
        x_b, y_b = rasterio.transform.xy(
            self.transform, ymax, xmax, offset="ul"
        )
        # Row indices commonly grow in the opposite direction to y, so order the
        # values explicitly to honour the (min, min, max, max) contract.
        return (min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b))

    def crop_image(
        self,
        box: Box,
        image_size: tuple[int, int] | None = None,
        interpolation: Resampling = Resampling.bilinear,
    ) -> np.ndarray:
        """Read the pixels of one tile from disk.

        Args:
            box: Tile as (xmin, ymin, xmax, ymax) pixel indices.
            image_size: Optional output shape forwarded to rasterio's
                ``out_shape``; ``None`` reads at native resolution.
            interpolation: Resampling method used when ``image_size`` is set.

        Returns:
            The array returned by ``src.read`` for the clamped window.
        """
        # Reuse the clamping logic so the window never runs past the raster.
        window = self.box_to_window(box)
        with rasterio.open(self.path) as src:
            x = src.read(
                boundless=False,
                window=window,
                fill_value=0,
                out_shape=image_size,
                resampling=interpolation,
            )
        return x

    def __len__(self):
        """Return the number of tiles."""
        return len(self.tiles)

    def __getitem__(self, index) -> dict[str, torch.Tensor]:
        """Load one tile.

        Returns:
            A dict with ``image`` (float tensor, after ``transforms`` if set),
            ``box`` (the unclamped tile box) and ``original_size`` (the last
            two dimensions of the array as read, before any transform).
        """
        box = self.tiles[index]
        x = self.crop_image(box)
        original_size = (x.shape[1], x.shape[2])
        # Clamp negatives, then convert to float32 in numpy: `torch.from_numpy`
        # rejects some numpy integer dtypes (e.g. uint16 on older torch
        # releases). `np.clip` already returns a fresh array, so no copy needed.
        x = torch.from_numpy(np.clip(x, 0, None).astype(np.float32))

        if self.transforms is not None:
            x = self.transforms(x)

        box = torch.tensor(box)
        original_size = torch.tensor(original_size)
        return dict(image=x, box=box, original_size=original_size)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
model = resnet18(weights=weights)
# Use the network as a feature extractor: drop the final classification layer so
# the forward pass returns the pooled backbone features.
# NOTE(review): the self-supervised checkpoint is expected to cover only the
# backbone, which would leave `fc` randomly initialised and make the saved
# "embeddings" a non-reproducible random projection — confirm against the
# torchgeo weights documentation.
model.fc = torch.nn.Identity()
model.eval()
# Divide pixel values by 10000, then centre-crop to the 224x224 model input.
transforms = T.Compose(
    [T.Normalize(mean=[0.0], std=[10000.0]), T.CenterCrop((224, 224))]
)

path = "bexar.tif"
dataset = TiledGeotiffDataset(
    path, tile_size=224, tile_overlap=0.0, transforms=transforms
)
# shuffle=False keeps batches in tile order, so embeddings line up with tiles.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=16,
    pin_memory=False,
    persistent_workers=True,
)
len(dataset), len(dataloader)

In [None]:
def embed(model, dataloader, device):
    """Run the model over every batch and collect embeddings with tile centroids.

    Args:
        model: Module mapping an image batch to one embedding per image.
        dataloader: Iterable of batches with ``"image"`` and ``"box"`` tensors.
            Its ``dataset`` attribute must provide ``bbox_to_coords(box)``.
        device: Device (``torch.device`` or string) to run inference on.

    Returns:
        ``(embeddings, geometries)``: the per-batch model outputs concatenated
        along axis 0 as a numpy array, and a list with one centroid geometry per
        tile, in the same order.
    """
    # Local import: only needed to make autocast optional on non-CUDA devices.
    from contextlib import nullcontext

    device = torch.device(device)
    model = model.to(device)
    # Take the dataset from the dataloader instead of relying on a notebook
    # global that happens to be called `dataset`.
    tile_dataset = dataloader.dataset

    # Half-precision autocast is a CUDA optimisation; skip it on other devices
    # rather than requesting CUDA autocast when the model runs elsewhere.
    if device.type == "cuda":
        amp_context = torch.amp.autocast(device_type="cuda", dtype=torch.float16)
    else:
        amp_context = nullcontext()

    embeddings, geometries = [], []
    with torch.inference_mode(), amp_context:
        for batch in tqdm(dataloader, total=len(dataloader)):
            image = batch["image"].to(device)
            emb = model(image).cpu().numpy()
            embeddings.append(emb)
            for box in batch["box"].numpy().tolist():
                coords = tile_dataset.bbox_to_coords(box)
                geom = shapely.geometry.box(*coords).centroid
                geometries.append(geom)

    embeddings = np.concatenate(embeddings, axis=0)
    return embeddings, geometries


# Embed every tile of the dataset; `geometries` holds one centroid per tile, in
# the same order as the rows of `embeddings`.
embeddings, geometries = embed(model, dataloader, device)

In [None]:
# Make sure the destination folder exists: writing the parquet file fails on a
# fresh checkout otherwise.
output_parquet = "aoi_embeddings/bexar-county/test.parquet"
os.makedirs(os.path.dirname(output_parquet), exist_ok=True)

# One row per tile: the embedding as a plain list plus the tile centroid.
# Geometries are created in the raster's CRS and reprojected to lon/lat; the
# chained `to_crs` replaces the in-place mutation so re-running the cell is
# idempotent.
gdf = gpd.GeoDataFrame(
    data={"embedding": [emb.tolist() for emb in embeddings]},
    geometry=geometries,
    crs=dataset.crs,
).to_crs(epsg=4326)
gdf.to_parquet(output_parquet)