# Create coastal DEM relative to mean sea surface (MSS)

In [None]:
import subprocess

import dask.distributed as dd
import numpy as np
import pandas as pd
import pyinterp.backends.xarray as pbx
import rhg_compute_tools.gcs as rhgcs
import rhg_compute_tools.kubernetes as rhgk
import rhg_compute_tools.utils as rhgu
import xarray as xr
from shapely.geometry import box

from sliiders import settings as sset

Define elevation-processing functions

In [None]:
@rhgu.block_globals
def get_grid_at_tile(da, grid):
    """
    Interpolate a datum grid onto the pixel coordinates of an elevation tile.

    Parameters
    ----------
    da : xarray.DataArray
        Elevation tile with 1-D ``x`` (longitude) and ``y`` (latitude)
        coordinates. Only the coordinate values are used.
    grid : xarray.DataArray
        Datum grid with ``lon``/``lat`` coordinates. It is never modified in
        place, so the same object can be reused across calls (e.g. when it is
        scattered once to every dask worker).

    Returns
    -------
    xarray.DataArray
        Bicubic interpolation (via ``pbx.Grid2D``) of ``grid`` with dims
        ``("x", "y")`` and the same ``x``/``y`` coordinate values as ``da``.
    """
    # Margin (degrees) kept around the tile so the interpolation stencil has
    # neighbors beyond the tile edges.
    buffer = 0.2

    x_min, x_max = da.x.min().item(), da.x.max().item()
    y_min, y_max = da.y.min().item(), da.y.max().item()

    # Ensure tiles along the 180 meridian have coordinates defined contiguously.
    # Copy before shifting: ``.values`` can be a view of ``grid``'s coordinate
    # index, and shifting it in place would corrupt the caller's grid for every
    # later tile (or fail outright if the array is read-only).
    if x_max > 179:
        new_lons = grid.lon.values.copy()
        new_lons[new_lons < -179] += 360
        grid = grid.assign_coords({"lon": new_lons})
    elif x_min < -179:
        new_lons = grid.lon.values.copy()
        new_lons[new_lons > 179] -= 360
        grid = grid.assign_coords({"lon": new_lons})

    # Plain float bounds (not 0-d DataArrays) so no stray scalar x/y
    # coordinates get attached to the selection.
    grid = grid.isel(
        lon=(grid.lon >= x_min - buffer) & (grid.lon <= x_max + buffer),
        lat=(grid.lat >= y_min - buffer) & (grid.lat <= y_max + buffer),
    ).load()
    grid = grid.sortby("lon")

    # pyinterp's xarray backend uses these units attributes to recognize the
    # longitude/latitude axes.
    grid.lon.attrs["units"] = "degrees_east"
    grid.lat.attrs["units"] = "degrees_north"

    interpolator = pbx.Grid2D(grid, geodetic=True)

    # Evaluate at every (x, y) pair; "ij" indexing yields arrays shaped
    # (len(x), len(y)), matching the ("x", "y") dims of the result.
    mx, my = np.meshgrid(da.x.values, da.y.values, indexing="ij")
    values = interpolator.bicubic(dict(lon=mx.flatten(), lat=my.flatten()))

    return xr.DataArray(
        values.reshape(mx.shape),
        dims=("x", "y"),
        coords={"x": da.x.values, "y": da.y.values},
    )


@rhgu.block_globals
def get_bbox(tile_name):
    """
    Convert a 1-degree tile name into its bounding box.

    ``tile_name`` has the form "VXXHYYY" and names the tile's southwestern
    corner:
    - "V" is "N" or "S".
    - "XX" is the two-digit zero-padded number of degrees of latitude.
    - "H" is "E" or "W".
    - "YYY" is the three-digit zero-padded number of degrees of longitude.

    Any hemisphere letter other than "N"/"E" is treated as south/west.

    Returns
    -------
    Polygon
        Box (from ``shapely.geometry.box``) extending one degree east and one
        degree north of the southwestern corner.
    """
    lat_hemisphere, lat_degrees = tile_name[0], int(tile_name[1:3])
    lon_hemisphere, lon_degrees = tile_name[3], int(tile_name[4:])

    south = lat_degrees if lat_hemisphere == "N" else -lat_degrees
    west = lon_degrees if lon_hemisphere == "E" else -lon_degrees

    return box(west, south, west + 1, south + 1)


@rhgu.block_globals
def get_tile_path(tile):
    """Return the location of the raw CoastalDEM GeoTIFF ``<tile>.tif`` in ``sset.DIR_COASTALDEM``."""
    file_name = f"{tile}.tif"
    return sset.DIR_COASTALDEM / file_name


@rhgu.block_globals
def get_elev_tile(
    tile_name, bbox, use_coastal_dem=True, egm96_xgm2019e=None, mdt=None, cap=None
):
    """
    Get a 1-arcsec elevation tile relative to mean sea surface (MSS).

    CoastalDEM is used where available, with its null pixels filled from
    SRTM15+. When CoastalDEM is not used, SRTM15+ is interpolated onto a
    1/3600-degree grid covering the tile. Elevations are then moved to the MSS
    datum by adding the EGM96 -> XGM2019e geoid difference and subtracting the
    Mean Dynamic Topography (MDT) (both provided by Aviso).

    Parameters
    ----------
    tile_name : str
        Tile identifier; used to locate the CoastalDEM file.
    bbox : shapely geometry
        1x1-degree tile footprint; only ``bbox.bounds`` is used.
    use_coastal_dem : bool, optional
        If True, read the CoastalDEM tile for ``tile_name``; otherwise build the
        tile from SRTM15+ alone.
    egm96_xgm2019e, mdt : xarray.DataArray, optional
        Preloaded datum grids with ``lon``/``lat`` coordinates. Whichever one is
        None is read from ``sset.PATH_GEOG_DATUMS_GRID``.
    cap : float, optional
        If given, elevations above ``cap`` are set to ``cap``.

    Returns
    -------
    xarray.DataArray
        Elevations relative to MSS on ``x``/``y`` coordinates.
    """
    llon, llat, ulon, ulat = bbox.bounds

    if use_coastal_dem:
        # Load the tile into memory and release the raster file handle.
        with xr.open_rasterio(get_tile_path(tile_name)) as raw_tile:
            elev_tile = raw_tile.sel(band=1).drop("band").load()

        # Some tiles carry metadata describing only a 0.1-degree extent anchored
        # at their lower-left corner: stretch x/y 10x about their minimum so the
        # tile spans the full degree.
        if elev_tile["y"].values.max() - elev_tile["y"].values.min() < 0.9:
            elev_tile["y"] = (
                elev_tile["y"].values.min()
                + (elev_tile["y"].values - elev_tile["y"].values.min()) * 10
            )
            elev_tile["x"] = (
                elev_tile["x"].values.min()
                + (elev_tile["x"].values - elev_tile["x"].values.min()) * 10
            )

    # Open our "main DEM" (fills missing CoastalDEM pixels, or is the sole
    # source when CoastalDEM is not used).
    with xr.open_dataarray(sset.PATH_SRTM15_PLUS) as srtm:

        srtm_buffer = 0.01

        # Ensure tiles along the 180 meridian have coordinates defined contiguously.
        # Copy before shifting: ``.values`` can be a view of the coordinate index
        # (or read-only), so it must not be modified in place.
        if llon == 179:
            new_lons = srtm.lon.values.copy()
            new_lons[new_lons < -179] += 360
            srtm = srtm.assign_coords({"lon": new_lons})
        elif ulon == -179:
            new_lons = srtm.lon.values.copy()
            new_lons[new_lons > 179] -= 360
            srtm = srtm.assign_coords({"lon": new_lons})

        # Subset SRTM to the tile (plus a small buffer) before loading it.
        this_srtm = srtm.isel(
            lon=(srtm.lon >= llon - srtm_buffer) & (srtm.lon <= ulon + srtm_buffer),
            lat=(srtm.lat >= llat - srtm_buffer) & (srtm.lat <= ulat + srtm_buffer),
        )
        this_srtm = this_srtm.sortby("lon").load()

        if use_coastal_dem:
            srtm_interp = this_srtm.rename({"lon": "x", "lat": "y"}).interp_like(
                elev_tile, method="linear", assume_sorted=True
            )
            # -32767 means SRTM input to CoastalDEM was missing (we have previously
            # filled this in our master DEM).
            # -9999 means outside of a particular spatial domain for CoastalDEM.
            elev_tile = elev_tile.where(~elev_tile.isin([-32767, -9999])).fillna(
                srtm_interp
            )
            # 0 is where CoastalDEM is "underwater"
            elev_tile = elev_tile.where(elev_tile != 0, np.nan)
        else:
            # Pixel centers of a 3600x3600 (1-arcsec) grid over the tile,
            # latitudes descending.
            grid_width = 3600
            size = 1 / grid_width

            lons_small = np.arange(llon + (size / 2), ulon, size)
            lats_small = np.flip(np.arange(llat + (size / 2), ulat, size))

            elev_tile = this_srtm.rename({"lon": "x", "lat": "y"}).interp(
                {"x": lons_small, "y": lats_small}, method="linear", assume_sorted=True
            )

    # Datum transformations: only read the grids the caller did not supply.
    if (egm96_xgm2019e is not None) and (mdt is not None):
        egm96_xgm2019e_interp = get_grid_at_tile(elev_tile, egm96_xgm2019e)
        mdt_interp = get_grid_at_tile(elev_tile, mdt)
    else:
        # Interpolate while the store is still open: variables read from it are
        # lazy and are only loaded inside get_grid_at_tile.
        with xr.open_zarr(sset.PATH_GEOG_DATUMS_GRID, consolidated=True) as datum_grid:
            egm96_xgm2019e_interp = get_grid_at_tile(
                elev_tile,
                datum_grid.egm96_xgm2019e if egm96_xgm2019e is None else egm96_xgm2019e,
            )
            mdt_interp = get_grid_at_tile(
                elev_tile, datum_grid.mdt if mdt is None else mdt
            )

    elev_tile = elev_tile + egm96_xgm2019e_interp - mdt_interp

    # Bundle higher-than-coastal elevation values into one to simplify later data processing
    if cap is not None:
        elev_tile = xr.where(elev_tile > cap, cap, elev_tile)

    return elev_tile


@rhgu.block_globals
def process_tile(
    tile_name,
    egm96_xgm2019e=None,
    mdt=None,
):
    """
    Build the MSS-relative elevation tile for ``tile_name`` and write it to
    ``sset.DIR_MSS / f"{tile_name}.tif"`` as a DEFLATE-compressed float32 raster.

    CoastalDEM is used when its tile file exists; otherwise the tile is built
    from SRTM15+ alone (see ``get_elev_tile``).

    Parameters
    ----------
    tile_name : str
        Tile identifier in "VXXHYYY" form (see ``get_bbox``).
    egm96_xgm2019e, mdt : xarray.DataArray, optional
        Preloaded datum grids forwarded to ``get_elev_tile``.

    Raises
    ------
    RuntimeError
        If ``gdal_translate`` exits with a non-zero status. Any partial output
        is removed first, so the tile is not treated as finished on a re-run.
    """
    bbox = get_bbox(tile_name)

    # get CoastalDEM tile, filled with SRTM, relative to MSS
    elev_tile = get_elev_tile(
        tile_name,
        bbox,
        use_coastal_dem=get_tile_path(tile_name).exists(),
        egm96_xgm2019e=egm96_xgm2019e,
        mdt=mdt,
    ).astype(np.float32)

    path_out_tmp = sset.DIR_MSS / f"{tile_name}_tmp.tif"
    path_out = sset.DIR_MSS / f"{tile_name}.tif"

    # Write a temporary raster, then re-encode it with GDAL compression options.
    # The arguments are passed as a list (no splitting on spaces), and the
    # temporary file is removed even if writing or translating fails.
    try:
        elev_tile.rio.to_raster(path_out_tmp)
        result = subprocess.run(
            [
                "gdal_translate",
                str(path_out_tmp),
                str(path_out),
                "-co",
                "COMPRESS=DEFLATE",
                "-co",
                "PREDICTOR=3",
            ],
            capture_output=True,
            text=True,
        )
    finally:
        path_out_tmp.unlink(missing_ok=True)

    if result.returncode != 0:
        path_out.unlink(missing_ok=True)
        raise RuntimeError(
            f"gdal_translate failed for tile {tile_name} "
            f"(exit code {result.returncode}): {result.stderr}"
        )

Get list of tiles to process (skipping tiles already written to the output directory)

In [None]:
tile_meta = pd.read_parquet(sset.PATH_EXPOSURE_TILE_LIST)
coastal_tiles = tile_meta.loc[
    tile_meta["PROCESSING_SET"].isin(["CIAM", "WITHELEV"]), "tile_name"
].to_numpy()

sset.DIR_MSS.mkdir(exist_ok=True)
# Names of tiles already written, with the 4-character file suffix (presumably
# ".tif") stripped. A set makes each membership test below O(1).
finished_tiles = {t[:-4] for t in rhgcs.ls(sset.DIR_MSS)}

# Skip finished tiles so the notebook can be re-run to resume processing.
coastal_tiles = [t for t in coastal_tiles if t not in finished_tiles]

len(coastal_tiles)

Start up cluster

In [None]:
# Create a client/cluster pair and request 24 workers; the trailing expression
# displays the cluster in the notebook.
n_workers = 24
client, cluster = rhgk.get_micro_cluster()
cluster.scale(n_workers)

cluster

In [None]:
import os
import zipfile
from pathlib import Path

import sliiders

# Zip the local `sliiders` package and upload it to the workers so tasks can
# import it. Archive member names are relative to the package's parent
# directory (e.g. "sliiders/settings.py").
sliiders_dir = Path(sliiders.__file__).parent
with zipfile.ZipFile("sliiders.zip", "w", zipfile.ZIP_DEFLATED) as zipf:
    for root, dirs, files in os.walk(sliiders_dir):
        # Bytecode caches are not needed to import the package; don't ship them.
        dirs[:] = [d for d in dirs if d != "__pycache__"]
        for file in files:
            file_path = os.path.join(root, file)
            zipf.write(file_path, os.path.relpath(file_path, sliiders_dir.parent))
client.upload_file("sliiders.zip")

Load datum grids onto workers

In [None]:
# Read both datum grids into memory once, on the client...
with xr.open_zarr(sset.PATH_GEOG_DATUMS_GRID, consolidated=True) as datum_grid:
    egm96_xgm2019e = datum_grid["egm96_xgm2019e"].load()
    mdt = datum_grid["mdt"].load()

# ...then replicate them on every worker, so tasks receive them as futures
# instead of re-reading the store inside get_elev_tile.
egm96_xgm2019e_fut, mdt_fut = (
    client.scatter(datum, broadcast=True) for datum in (egm96_xgm2019e, mdt)
)

Run tiles on workers

In [None]:
# One `process_tile` task per remaining tile; the datum grids are passed as
# futures that are already resident on the workers.
fut = client.map(
    process_tile,
    coastal_tiles,
    egm96_xgm2019e=egm96_xgm2019e_fut,
    mdt=mdt_fut,
)
dd.progress(fut)

Close cluster

In [None]:
# Disconnect the client before shutting down the cluster it is connected to,
# so the client is not left trying to reach a scheduler that is already gone.
client.close()
cluster.close()