In [1]:
# pip install osmnx geopandas rasterio shapely

In [None]:
import os
import json
import time
import numpy as np
import tifffile
import warnings
from pathlib import Path
from datetime import timezone, timedelta, datetime
from shapely.geometry import box, mapping
from rasterio.features import rasterize
import rasterio.transform
from pyproj import Transformer
from ohsome import OhsomeClient
from shapely.geometry import mapping


# Suppress geopandas projection warnings
warnings.filterwarnings("ignore", message=".*GeoDataFrame.*")

# Initialize Ohsome client
client = OhsomeClient()

# Paths
metadata_dir = Path("../data/processed/sentinel2")
mask_base_dir = Path("../data/masks")
road_dir = mask_base_dir / "road"
building_dir = mask_base_dir / "building"
osm_meta_dir = mask_base_dir / "osm_metadata"
road_dir.mkdir(parents=True, exist_ok=True)
building_dir.mkdir(parents=True, exist_ok=True)
osm_meta_dir.mkdir(parents=True, exist_ok=True)

image_shape = (1024, 1024)
transformer = Transformer.from_crs("epsg:3857", "epsg:4326", always_xy=True)

def save_osm_metadata(patch_id, timestamp, base_name, status, roads_len, buildings_len):
    meta_record = {
        "patch_id": patch_id,
        "timestamp": timestamp,
        "osm_query_date": datetime.now(timezone.utc).isoformat(),
        "status": status,
        "features": {
            "roads": int(roads_len),
            "buildings": int(buildings_len)
        },
        "osm_source": "ohsome",
        "query_type": "elementsFullGeometry"
    }
    meta_path = osm_meta_dir / f"{base_name}_osm.json"
    with open(meta_path, "w") as f:
        json.dump(meta_record, f, indent=2)

def rasterize_features(gdf, bbox_polygon, shape):
    if gdf.empty:
        return np.zeros(shape, dtype=np.uint8)
    return rasterize(
        ((geom, 1) for geom in gdf.geometry),
        out_shape=shape,
        transform=rasterio.transform.from_bounds(*bbox_polygon.bounds, shape[1], shape[0]),
        fill=0,
        dtype=np.uint8
    )

def get_error_message(e):
    return str(e) if not e.args else str(e.args[0])

from requests.exceptions import RequestException
import random

MAX_RETRIES = 5

def query_ohsome_safe(client, bpolys, date_str, feature_filter):
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            return client.elements.geometry.post(
                bpolys=bpolys,
                time=date_str,
                filter=feature_filter,
                properties="tags"
            )
        except Exception as e:
            err_msg = get_error_message(e)
            wait = min(30, 2 ** attempt + random.uniform(0, 3))
            print(f"⚠️ Retry {attempt} after error: {err_msg} | waiting {wait:.1f}s")
            time.sleep(wait)
    raise RuntimeError(f"❌ Failed to get {feature_filter} after {MAX_RETRIES} attempts.")


# Process each patch
json_files = sorted(metadata_dir.glob("*.json"))
N = len(json_files)
start_time = time.time()

for idx, json_path in enumerate(json_files, start=1):
    with open(json_path) as f:
        meta = json.load(f)

    patch_id = meta["patch_id"]
    timestamp = meta["timestamp"]
    timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
    timestamp_str = timestamp.replace(":", "").replace("-", "").replace("T", "_").replace("+", "_").replace(".", "_")
    base_name = f"{patch_id}_{timestamp_str}"

    ulx, uly = meta["bbox"]["upper_left"]
    lrx, lry = meta["bbox"]["bottom_right"]
    ul_lon, ul_lat = transformer.transform(ulx, uly)
    lr_lon, lr_lat = transformer.transform(lrx, lry)
    bbox_polygon = box(ul_lon, lr_lat, lr_lon, ul_lat)
    bpolys = bbox_polygon  # ✅ Direct Shapely geometry


    out_road = road_dir / f"{base_name}_road_mask.tiff"
    out_building = building_dir / f"{base_name}_building_mask.tiff"

    if out_road.exists() and out_building.exists():
        print(f"[{idx}/{N}] Skipping (already exists): {base_name}")
        continue

    try:
        # Timestamp for Ohsome query
        date_str = timestamp_dt.strftime("%Y-%m-%d")

        # Throttle requests: prevent 503s
        time.sleep(3 + 0.1 * idx)

        # Perform Ohsome query
        roads = query_ohsome_safe(client, bbox_polygon, date_str, "highway=*")
        buildings = query_ohsome_safe(client, bbox_polygon, date_str, "building=*")
        # Handle responses
        if roads is None or buildings is None:
            raise ValueError("Ohsome returned None")

        roads_gdf = roads.as_geodataframe()
        buildings_gdf = buildings.as_geodataframe()

        if roads_gdf.empty and buildings_gdf.empty:
            print(f"[{idx}/{N}] No OSM features found for {base_name} at {date_str}")
            save_osm_metadata(patch_id, timestamp, base_name, "no_features_found", 0, 0)
            continue

    except Exception as e:
        err_msg = get_error_message(e)
        print(f"[{idx}/{N}] ❌ Ohsome query failed: {base_name} | {err_msg}")
        save_osm_metadata(patch_id, timestamp, base_name, f"query_failed: {err_msg}", 0, 0)
        time.sleep(10)
        continue

    # Rasterize and save
    road_mask = rasterize_features(roads_gdf, bbox_polygon, image_shape)
    building_mask = rasterize_features(buildings_gdf, bbox_polygon, image_shape)
    tifffile.imwrite(out_road, road_mask.astype(np.uint8))
    tifffile.imwrite(out_building, building_mask.astype(np.uint8))

    save_osm_metadata(patch_id, timestamp, base_name, "success", len(roads_gdf), len(buildings_gdf))

    # Progress
    elapsed = time.time() - start_time
    avg_time = elapsed / idx
    eta = avg_time * (N - idx)
    print(f"[{idx}/{N}] {base_name} saved | {1/avg_time:.2f} img/sec | ETA: {str(timedelta(seconds=int(eta)))}")


⚠️ Retry 1 after error: message | waiting 4.9s
⚠️ Retry 2 after error: Keyboard Interrupt: Query was interrupted by the user. | waiting 5.1s
⚠️ Retry 3 after error: message | waiting 9.8s
