In [None]:
import json
import math
import os
import time
import zipfile
from collections import defaultdict
from datetime import datetime, timezone, timedelta
from pathlib import Path

import rasterio
import requests
from dotenv import load_dotenv
from pyproj import Transformer
from rasterio.mask import mask
from shapely.geometry import box, mapping

# Load token
# load_dotenv() reads a local .env file into the process environment; the
# bearer token is then looked up from the environment.
load_dotenv()
token = os.getenv("EARTHDATA_BEARER")
if not token:
    raise RuntimeError("EARTHDATA_BEARER not set in .env file.")

# Paths
input_dir = Path("../data/processed/sentinel2")
output_dir = Path("../data/masks/dem_srtmgl1")
output_dir.mkdir(parents=True, exist_ok=True)

# SRTM source
SRTM_BASE_URL = "https://e4ftl01.cr.usgs.gov/MEASURES/SRTMGL1.003/2000.02.11/"
# always_xy=True: transform() takes (x, y) and returns (lon, lat).
# NOTE(review): assumes the patch bboxes are stored in EPSG:3857 — confirm
# against whatever writes the Sentinel-2 JSON sidecars.
transformer = Transformer.from_crs("EPSG:3857", "EPSG:4326", always_xy=True)

# Group files by patch_id
# The patch id is the first two "_"-separated tokens of the file stem; the
# remainder is presumably a timestamp (TODO confirm naming scheme).
# glob() yields entries in filesystem order, so sort to make both the
# processing order and the choice of each group's first file reproducible.
patch_groups = defaultdict(list)
for json_path in sorted(input_dir.glob("*.json")):
    name_parts = json_path.stem.split("_")
    patch_id = "_".join(name_parts[:2])
    patch_groups[patch_id].append(json_path)

N = len(patch_groups)
if N == 0:
    print(f"⚠️ No *.json files found in {input_dir}")
start_time = time.time()

def download_srtm_tile(lat, lon, output_dir):
    tile_lat = f"{'N' if lat >= 0 else 'S'}{abs(int(lat)):02d}"
    tile_lon = f"{'E' if lon >= 0 else 'W'}{abs(int(lon)):03d}"
    tile_name = f"{tile_lat}{tile_lon}"
    zip_filename = f"{tile_name}.SRTMGL1.hgt.zip"
    tif_filename = f"{tile_name}.hgt"

    zip_url = f"{SRTM_BASE_URL}{zip_filename}"
    local_zip = output_dir / zip_filename
    local_tif = output_dir / tif_filename

    if local_tif.exists():
        return local_tif

    print(f"⬇️ Downloading {zip_filename}...")
    headers = {"Authorization": f"Bearer {token}"}
    r = requests.get(zip_url, stream=True, headers=headers)
    if r.status_code != 200:
        raise Exception(f"Download failed: {zip_url} (status {r.status_code})")

    with open(local_zip, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            f.write(chunk)

    with zipfile.ZipFile(local_zip, "r") as z:
        z.extractall(output_dir)
    os.remove(local_zip)
    return local_tif

def save_metadata(patch_id, out_name, bounds):
    """Write the JSON sidecar for one exported DEM patch.

    Args:
        patch_id: Identifier stored under ``"patch_id"``.
        out_name: File stem; the record goes to ``output_dir/<out_name>.json``.
        bounds: Sequence indexed as ``(west, east, south, north)``.

    The record also carries a fixed ``"success"`` status, the elevation source
    and resolution, and the current UTC time as ``"query_date"``.
    """
    bbox_wgs84 = dict(
        north=bounds[3],
        south=bounds[2],
        east=bounds[1],
        west=bounds[0],
    )
    queried_at = datetime.now(timezone.utc).isoformat()

    record = {
        "patch_id": patch_id,
        "status": "success",
        "elevation_source": "SRTMGL1 (NASA Earthdata)",
        "elevation_resolution_m": 30,
        "bbox_wgs84": bbox_wgs84,
        "query_date": queried_at,
    }

    sidecar_path = output_dir / f"{out_name}.json"
    with open(sidecar_path, "w") as handle:
        json.dump(record, handle, indent=2)

# Process each patch group
# For every patch: read its footprint, fetch the SRTM tile under the patch
# centre, crop the DEM to the footprint and write one GeoTIFF + JSON sidecar
# per input file of the group. Failures are reported and the loop continues.
for idx, (patch_id, files) in enumerate(patch_groups.items(), 1):
    print(f"\n[{idx}/{N}] Processing patch group: {patch_id}")

    try:
        # The first JSON of the group supplies the footprint for every file in
        # it (assumes all variants of a patch share one bbox — TODO confirm).
        first_json = files[0]
        with open(first_json) as f:
            meta = json.load(f)

        # Get bounds. Corners are unpacked as (y, x) and handed to the
        # transformer as (x, y), which returns (lon, lat).
        uly, ulx = meta["bbox"]["upper_left"]
        lry, lrx = meta["bbox"]["bottom_right"]
        west, north = transformer.transform(ulx, uly)
        east, south = transformer.transform(lrx, lry)
        if west > east:
            west, east = east, west
        if south > north:
            south, north = north, south
        # Tuple order expected by save_metadata(): (west, east, south, north).
        bounds = (west, east, south, north)

        # Download SRTM tile
        # NOTE(review): only the tile under the patch centre is used, so a
        # patch straddling a tile edge is cropped to that single tile.
        lat_center = (south + north) / 2
        lon_center = (west + east) / 2
        tile_path = download_srtm_tile(lat_center, lon_center, output_dir)

        # Mask the DEM
        # shapely's box() takes (minx, miny, maxx, maxy); unpacking `bounds`
        # directly would swap east and south and crop the wrong area.
        patch_geom = box(west, south, east, north)
        with rasterio.open(tile_path) as src:
            out_image, out_transform = mask(src, [mapping(patch_geom)], crop=True)
            profile = src.profile.copy()

        # Keep the source dtype: casting signed elevations to uint16 wraps
        # negative heights and void values, and the nodata inherited from the
        # source profile (presumably -32768 for SRTM) does not fit in uint16.
        profile.update({
            "driver": "GTiff",
            "dtype": out_image.dtype.name,
            "count": 1,
            "height": out_image.shape[1],
            "width": out_image.shape[2],
            "transform": out_transform
        })

        # Write one file for each timestamped variant
        prefix = patch_id + "_"
        for json_path in files:
            stem = json_path.stem
            # Strip only the leading patch prefix; str.replace would also
            # remove any later occurrence inside the timestamp part.
            ts_part = stem[len(prefix):] if stem.startswith(prefix) else stem
            out_name = f"dem_{patch_id}_{ts_part}"
            out_tif = output_dir / f"{out_name}.tiff"
            with rasterio.open(out_tif, "w", **profile) as dst:
                dst.write(out_image)
            save_metadata(patch_id, out_name, bounds)
            print(f"   🗺️ Saved: {out_tif.name}")

    except Exception as e:
        # Best effort: report the failing patch and move on to the next one.
        print(f"❌ Error for patch {patch_id}: {e}")

    # Logging ETA
    elapsed = time.time() - start_time
    avg_time = elapsed / idx
    eta = avg_time * (N - idx)
    # Guard the rate against a zero elapsed time on very coarse clocks.
    rate = 1 / avg_time if avg_time > 0 else float("inf")
    print(f"   ⏱ {rate:.2f} patches/sec | ETA: {str(timedelta(seconds=int(eta)))}")
    # Fixed pause between patches (presumably to be gentle on the server).
    time.sleep(1)


In [None]:
import matplotlib.pyplot as plt
import rasterio
from pathlib import Path

# Change this to any tile you downloaded.
# SRTM tiles are named after their south-west corner: N08E115 starts at
# 8°N, 115°E. NOTE(review): a location in Bali (~8.4°S) would fall in
# S09E115 instead — confirm which tile was intended here.
tile_path = Path("../data/masks/dem_srtmgl1/N08E115.hgt")

with rasterio.open(tile_path) as src:
    # masked=True hides pixels equal to the dataset's nodata value so voids
    # do not stretch the colour scale.
    dem = src.read(1, masked=True)

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(dem, cmap="terrain")
fig.colorbar(im, ax=ax, label="Elevation (m)")
# Title follows the file name so it cannot go stale when tile_path changes.
ax.set_title(f"Raw SRTMGL1 Tile: {tile_path.stem}")
ax.set_xlabel("Pixels")
ax.set_ylabel("Pixels")
plt.show()
