# Create MSS coastal DEM

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import dask.array as da
import geopandas as gpd
import numpy as np
import pandas as pd
import pyinterp.backends.xarray as pbx
import regionmask
import rioxarray
import xarray as xr
from dask_gateway import GatewayCluster
from shapely.geometry import box
from sliiders import settings as sset
from sliiders import spatial
from sliiders.dask import upload_sliiders
from sliiders.io import open_dataarray, open_zarr, save

  from distributed.utils import LoopRunner, format_bytes


In [3]:
# Upper bound used for adaptive scaling of the Dask cluster
N_WORKERS = 500
# Number of pixels along each edge of a 1-degree tile (grid spacing is 1/3601 degree,
# see METHOD below)
N_PIXELS_PER_TILE = 3601

# Global attributes attached to the output dataset
AUTHOR = "Ian Bolliger"
CONTACT = "ian.bolliger@blackrock.com"
DESCRIPTION = (
    "Blended DEM for coastal regions, relative to a Mean Sea Level DEM (1993-2012)"
)
METHOD = "CoastalDEM2.1 is the primary data source. Areas marked as missing or as water in CoastalDEM are infilled with SRTM15+ v2.4. AVISO MDT+ is used to convert from the orthometric datum to a MSL datum. Note: CoastalDEM uses 1/3601 degrees rather than 1/3600 (1 arc-second)."
HISTORY = """version 2.1: Associated with CoastalDEM 2.1 and SRTM15+ v2.4
version 2.1.1: Updated pixel source data to define SRTM15+ ocean vs. inland pixels. Dropped datum transformation from EGM96 to XGM2019_e b/c unsure of whether there is some smoothing done to SRTM specific to EGM that, upon transformation, would result in some spatial noise. Updated list of tiles to drop tiles that will wind up having elevations too high to matter for coastal regions and to include potentially-inland CoastalDEM tiles b/c we now will filter these later based on hydraulic connectivity. Added int_res field to denote pixels that had integer resolution (i.e. from SRTM15+) such that later on we can smooth the exposure in these pixels over a 1m distribution."""

# Per-variable attributes, keyed by the variable names used in the output dataset
OUT_ATTRS = {
    "z": {"units": "m", "long_name": "elevation relative to MSL 1993-2012"},
    "source": {
        "long_name": "data source",
        # integer codes stored in the `source` variable
        "description": """0: CoastalDEM
1: CoastalDEM water pixels replaced with SRTM15+
2: CoastalDEM missing or out-of-range pixels infilled with SRTM15+
3: Non-CoastalDEM tile (SRTM15+) ocean pixel
4: Non-CoastalDEM tile (SRTM15+) inland pixel""",
    },
    "int_res": {
        "long_name": "Integer resolution flag",
        "description": "True for pixels that have integer elevation resolution.",
    },
}

In [4]:
# Start a Dask Gateway cluster and connect a client to it
cluster = GatewayCluster(profile="micro", idle_timeout=1800)
client = cluster.get_client()
# presumably makes the local `sliiders` package available on the workers -- confirm
# in `sliiders.dask`
upload_sliiders(client)
# Scale adaptively between 7 and N_WORKERS workers
cluster.adapt(minimum=7, maximum=N_WORKERS)
# Last expression of the cell: display the cluster widget
cluster

VBox(children=(HTML(value='<h2>GatewayCluster</h2>'), HBox(children=(HTML(value='\n<div>\n<style scoped>\n    …

Define elevation-processing functions

In [5]:
def get_grid_at_tile(da, grid_in):
    """
    Get interpolated datum tile in the same shape as `da` using `pbx.Grid2D`.

    Parameters
    ----------
    da : xarray.DataArray
        Tile with 1-D ``x`` (longitude) and ``y`` (latitude) coordinates. Note that
        inside this function the name shadows the module-level ``dask.array`` alias.
    grid_in : xarray.DataArray
        Datum grid with ``lon`` and ``lat`` coordinates. It is left unmodified.

    Returns
    -------
    xarray.DataArray
        Bicubic interpolation of `grid_in` at every (x, y) pixel center of `da`, with
        dimensions ``("x", "y")`` and the same ``x``/``y`` coordinate values as `da`.
    """
    # margin (degrees) of datum grid kept around the tile for the interpolator
    buffer = 0.2

    grid = grid_in.copy()

    # Ensure tiles along the 180 meridian have coordinates defined contiguously.
    # The re-wrapped longitudes are built with `np.where` instead of being written
    # into `grid.lon.values`: that array can be backed by the coordinate's index,
    # which `copy()` may still share with `grid_in`, so an in-place write could alter
    # the caller's grid (and fails outright if the array is read-only).
    if da.x[-1].item() > 179:
        lons = grid.lon.values
        grid = grid.assign_coords({"lon": np.where(lons < -179, lons + 360, lons)})
    elif da.x[0].item() < -179:
        lons = grid.lon.values
        grid = grid.assign_coords({"lon": np.where(lons > 179, lons - 360, lons)})

    # keep only the part of the grid surrounding the tile and pull it into memory
    grid = grid.isel(
        lon=(grid.lon >= da.x.min().item() - buffer)
        & (grid.lon <= da.x.max().item() + buffer),
        lat=(grid.lat >= da.y.min().item() - buffer)
        & (grid.lat <= da.y.max().item() + buffer),
    ).load()

    # re-wrapping above can leave longitudes out of order
    grid = grid.sortby("lon")

    grid.lon.attrs["units"] = "degrees_east"
    grid.lat.attrs["units"] = "degrees_north"

    interpolator = pbx.Grid2D(grid, geodetic=True)

    # "ij" indexing -> arrays shaped (len(x), len(y)), hence the (x, y) output dims
    mx, my = np.meshgrid(da.x.values, da.y.values, indexing="ij")

    out = interpolator.bicubic(dict(lon=mx.flatten(), lat=my.flatten()))

    out = out.reshape(mx.shape)
    out = xr.DataArray(out).rename({"dim_0": "x", "dim_1": "y"})
    out["x"] = da.x.values
    out["y"] = da.y.values

    return out


def get_tile_path(tile, fuse=False):
    """
    Build the location of the raw CoastalDEM GeoTIFF for ``tile``.

    With ``fuse=False`` the path object under ``sset.DIR_COASTALDEM`` is returned
    unchanged. With ``fuse=True`` it is returned as a string in which every
    ``"gs:/"`` has been replaced by ``"/gcs"``.
    """
    tile_path = sset.DIR_COASTALDEM / f"{tile}.tif"
    if not fuse:
        return tile_path
    return str(tile_path).replace("gs:/", "/gcs")


def get_lonlat_range(lon_min, lat_min, lon_max, lat_max, n_pixels=None):
    """
    Get pixel-center longitudes and latitudes covering a bounding box.

    Parameters
    ----------
    lon_min, lat_min : float
        Lower-left corner of the box, in degrees.
    lon_max, lat_max : float
        Upper-right corner of the box, in degrees (exclusive).
    n_pixels : int, optional
        Number of pixels per degree. Defaults to ``N_PIXELS_PER_TILE``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(lons, lats)``: increasing 1-D arrays starting half a pixel inside the
        lower bound and spaced ``1 / n_pixels`` degrees apart.
    """
    if n_pixels is None:
        n_pixels = N_PIXELS_PER_TILE

    half_pixel = 0.5 / n_pixels
    pixel = 1 / n_pixels

    lons = np.arange(lon_min + half_pixel, lon_max, pixel)
    lats = np.arange(lat_min + half_pixel, lat_max, pixel)
    return lons, lats


def get_elev_tile(
    tile_name,
    llon,
    llat,
    ulon,
    ulat,
    use_coastal_dem=True,
    egm96_xgm2019e=None,
    ocean_geom=None,
    mdt=None,
    cap=None,
):
    """
    Get approximately 1-arcsec elevation tile relative to MSS.

    Use CoastalDEM where available, replacing null, out-of-range and water pixels
    with SRTM15+. Tiles without CoastalDEM are built entirely from SRTM15+
    interpolated onto the tile grid. Elevations are then moved to the MSS datum by
    subtracting interpolated Mean Dynamic Topography (MDT) from CoastalDEM pixels
    (source 0) and SRTM15+ inland pixels (source 4).

    Parameters
    ----------
    tile_name : str
        Tile identifier, used to locate the CoastalDEM GeoTIFF (only read when
        `use_coastal_dem` is True).
    llon, llat, ulon, ulat : int or float
        Lower and upper longitude/latitude bounds of the tile, in degrees.
    use_coastal_dem : bool, default True
        If False, the tile is built from SRTM15+ only.
    egm96_xgm2019e : optional
        Unused. Kept so that existing callers passing it keep working; no
        EGM96 -> XGM2019e transformation is applied.
    ocean_geom : geopandas object, optional
        Ocean geometries passed to ``regionmask.mask_geopandas``. Required when
        `use_coastal_dem` is False.
    mdt : xarray.DataArray, optional
        MDT grid with ``lon``/``lat`` coordinates. When None it is read from
        ``sset.PATH_GEOG_DATUMS_GRID``.
    cap : float, optional
        If given, elevations above `cap` are set to `cap`.

    Returns
    -------
    xarray.Dataset
        Variables ``z`` (elevation), ``source`` (uint8 code, 0-4) and ``int_res``
        (bool) on ``lon``/``lat`` dimensions.

    Raises
    ------
    AssertionError
        If any elevation is still null after infilling and datum adjustment.
    """
    if use_coastal_dem:
        # load tile
        tile_path = get_tile_path(tile_name, fuse=True)
        elev_tile = (
            rioxarray.open_rasterio(
                tile_path,
                mask_and_scale=True,
            )
            .squeeze(drop=True)
            .load()
        )

        # handle tiles with inaccurately bottom-left .1-degree metadata
        # (this was an issue with v1.1 for some tiles, I don't think it is for v2.1)
        if elev_tile["y"].values.max() - elev_tile["y"].values.min() < 0.9:
            elev_tile["y"] = (
                elev_tile["y"].values.min()
                + (elev_tile["y"].values - elev_tile["y"].values.min()) * 10
            )
            elev_tile["x"] = (
                elev_tile["x"].values.min()
                + (elev_tile["x"].values - elev_tile["x"].values.min()) * 10
            )

    # open our "main DEM" (to fill in missing pixels in CoastalDEM)
    with open_dataarray(sset.PATH_SRTM15_PLUS) as srtm:

        srtm_buffer = 0.01

        # Ensure tiles along the 180 meridian have coordinates defined contiguously
        # (new arrays are built rather than writing into the coordinate's own values)
        if llon == 179:
            lons = srtm.lon.values
            srtm = srtm.assign_coords({"lon": np.where(lons < -179, lons + 360, lons)})
        elif ulon == -179:
            lons = srtm.lon.values
            srtm = srtm.assign_coords({"lon": np.where(lons > 179, lons - 360, lons)})

        # fill NaNs with SRTM
        this_srtm = (
            srtm.isel(
                lon=(srtm.lon >= llon - srtm_buffer) & (srtm.lon <= ulon + srtm_buffer),
                lat=(srtm.lat >= llat - srtm_buffer) & (srtm.lat <= ulat + srtm_buffer),
            )
            .sortby("lon")
            .load()
        )

        if use_coastal_dem:
            srtm_interp = this_srtm.rename({"lon": "x", "lat": "y"}).interp_like(
                elev_tile, method="linear", assume_sorted=True
            )
            # -32767 means SRTM input to coastalDEM was missing (we have previously filled this in
            # our master DEM)
            # -9999 means outside of a particular spatial domain for coastalDEM
            pixel_src = xr.zeros_like(elev_tile, dtype="uint8")

            mask = elev_tile.notnull() & ~elev_tile.isin([-32767, -9999])

            pixel_src = pixel_src.where(mask, 2)
            elev_tile = elev_tile.where(mask, srtm_interp)

            # 0 is where coastalDEM is "underwater". Also fill these with SRTM15+
            pixel_src = pixel_src.where(elev_tile != 0, 1)
            elev_tile = elev_tile.where(elev_tile != 0, np.minimum(srtm_interp, 0))
        else:
            grid_width = N_PIXELS_PER_TILE
            size = 1 / grid_width

            lons_small = np.arange(llon + (size / 2), ulon, size)
            lats_small = np.arange(llat + (size / 2), ulat, size)

            srtm_interp = this_srtm.rename({"lon": "x", "lat": "y"}).interp(
                {"x": lons_small, "y": lats_small},
                method="linear",
                assume_sorted=True,
            )
            elev_tile = srtm_interp
            ocean_pixels = regionmask.mask_geopandas(
                ocean_geom, srtm_interp, lon_name="x", lat_name="y"
            ).notnull()
            # 3 = at-or-below-zero pixel inside the ocean geometry, 4 = everything else
            pixel_src = (xr.ones_like(elev_tile, dtype="uint8") * 3).where(
                ocean_pixels & (elev_tile <= 0), 4
            )

    # figure out pixels that had integer resolution. Will be used to smooth out elev
    # distribution of exposure later on. Needs to be done before converting to MSS
    int_res = (pixel_src != 0) & (np.isclose(elev_tile, elev_tile.astype("int")))

    # Datum transformations
    # Only read the datum grid when the caller did not supply an MDT grid. The unused
    # `egm96_xgm2019e` argument must not take part in this test, otherwise a supplied
    # `mdt` is thrown away and re-read from storage on every call.
    if mdt is None:
        with open_zarr(sset.PATH_GEOG_DATUMS_GRID, chunks=None) as datum_grid:
            mdt = datum_grid.mdt

    # `get_grid_at_tile` may re-wrap longitudes for tiles on the antimeridian; hand it
    # a grid with its own longitude array so that a caller-supplied `mdt` which is
    # reused across calls can never be altered.
    mdt_interp = get_grid_at_tile(
        elev_tile, mdt.assign_coords(lon=mdt.lon.values.copy())
    )
    # only adjust CoastalDEM and land-pixels of SRTM15+. Ocean pixels of SRTM15+ are
    # nominally in a MSL datum
    elev_tile -= (mdt_interp).where(pixel_src.isin([0, 4]), 0)

    # Bundle higher-than-coastal elevation values into one to simplify later data processing
    if cap is not None:
        elev_tile = xr.where(elev_tile > cap, cap, elev_tile)

    assert elev_tile.notnull().all()

    return xr.Dataset({"z": elev_tile, "source": pixel_src, "int_res": int_res}).rename(
        x="lon", y="lat"
    )


def process_tile(
    tile_name,
    mdt=None,
    ocean_geom=None,
    check=True,
):
    """
    Build the MSS-relative elevation tile for ``tile_name`` and write it into the
    matching region of the store at ``sset.PATH_ELEV_MSS``.

    Parameters
    ----------
    tile_name : str
        Tile identifier understood by ``spatial.get_ll`` and ``get_tile_path``.
    mdt, ocean_geom : optional
        Forwarded unchanged to ``get_elev_tile``.
    check : bool, default True
        When True, do nothing if the tile's region of the store already holds any
        non-null ``z`` value.

    Returns
    -------
    None
    """
    west, south = spatial.get_ll(tile_name)
    east, north = west + 1, south + 1
    expected_lons, expected_lats = get_lonlat_range(west, south, east, north)

    # index window occupied by this 1-degree tile within the global grid
    region = {
        "lat": slice(
            (90 + south) * N_PIXELS_PER_TILE, (91 + south) * N_PIXELS_PER_TILE
        ),
        "lon": slice(
            (180 + west) * N_PIXELS_PER_TILE, (181 + west) * N_PIXELS_PER_TILE
        ),
    }

    # skip tiles for which something has already been written
    if check:
        existing_z = open_zarr(sset.PATH_ELEV_MSS, chunks=None).z.isel(**region)
        if existing_z.notnull().any():
            return None

    # get coastalDEM tile, filled with SRTM, relative to MSS; fall back to SRTM-only
    # when no CoastalDEM file exists for this tile
    has_coastal_dem = get_tile_path(tile_name).is_file()
    tile_ds = get_elev_tile(
        tile_name,
        west,
        south,
        east,
        north,
        use_coastal_dem=has_coastal_dem,
        mdt=mdt,
        ocean_geom=ocean_geom,
    )

    # A size mismatch could happen with CoastalDEM v1.1 (where the tile was then
    # re-interpolated onto the expected lats/lons); it is not expected with v2.1, so
    # it is treated as an error here.
    assert tile_ds.lat.size == len(expected_lats)
    assert tile_ds.lon.size == len(expected_lons)

    tile_ds["z"] = tile_ds.z.astype("float32")

    # ensure lats are increasing
    tile_ds = tile_ds.sortby("lat")

    # ignore matching lats and lons when saving in case there is a floating point diff
    coordless = tile_ds.drop_vars(tile_ds.coords)
    save(coordless, sset.PATH_ELEV_MSS, region=region)

## Create template zarr

In [6]:
# Pixel-center coordinates of the global grid
lon, lat = get_lonlat_range(-180, -90, 180, 90)

# Lazy placeholder array for the whole globe, chunked so that each chunk is one
# 1-degree tile
out_arr = da.empty(
    (N_PIXELS_PER_TILE * 180, N_PIXELS_PER_TILE * 360),
    chunks=(N_PIXELS_PER_TILE, N_PIXELS_PER_TILE),
    dtype="float32",
)
# uint8 view of the placeholder, used for the `source` variable below
src_arr = out_arr.astype("uint8")

z = xr.DataArray(
    out_arr, dims=["lat", "lon"], coords={"lon": lon, "lat": lat}, attrs=OUT_ATTRS["z"]
)
src = xr.DataArray(
    src_arr,
    dims=["lat", "lon"],
    coords={"lon": lon, "lat": lat},
    attrs=OUT_ATTRS["source"],
)
int_res = xr.DataArray(
    out_arr.astype(bool),
    dims=["lat", "lon"],
    coords={"lon": lon, "lat": lat},
    attrs=OUT_ATTRS["int_res"],
)

In [7]:
# Write the template dataset (variables + global attributes) to the output store.
# NOTE(review): `compute=False` presumably writes only the layout/metadata and leaves
# the placeholder data unwritten -- confirm in `sliiders.io.save`.
save(
    xr.Dataset(
        data_vars=dict(z=z, source=src, int_res=int_res),
        attrs=dict(
            author=AUTHOR,
            contact=CONTACT,
            description=DESCRIPTION,
            method=METHOD,
            history=HISTORY,
        ),
    ),
    sset.PATH_ELEV_MSS,
    compute=False,
)

Get list of tiles to process

In [8]:
# Keep only rows whose PROCESSING_SET is "CIAM" or "WITHELEV"; the table's index
# values are used as the tile names to process
coastal_tiles = pd.read_parquet(
    sset.PATH_EXPOSURE_TILE_LIST,
    filters=[("PROCESSING_SET", "in", ["CIAM", "WITHELEV"])],
).index.values
# Last expression of the cell: show how many tiles were selected
len(coastal_tiles)

8926

Load datum grids onto cluster

In [9]:
def load_mdt():
    """Open the datum grid store and return its ``mdt`` variable loaded into memory."""
    datum_grid = open_zarr(sset.PATH_GEOG_DATUMS_GRID, chunks=None)
    return datum_grid.mdt.load()


def load_ocean_geom():
    """
    Read the ocean polygons at ``sset.PATH_NATEARTH_OCEAN_NOCASPIAN``, buffer them by
    1 (in the layer's coordinate units -- presumably degrees, confirm against the
    file's CRS) and clip the result to the box spanning -180..180 by -90..90.
    """
    world_bounds = box(-180, -90, 180, 90)
    ocean = gpd.read_parquet(sset.PATH_NATEARTH_OCEAN_NOCASPIAN)
    buffered = ocean.buffer(1)
    return buffered.clip(world_bounds)

In [10]:
# Run the loaders on the cluster; the resulting futures are passed to the per-tile
# tasks instead of shipping the data from this notebook
mdt_fut = client.submit(load_mdt)
ocean_fut = client.submit(load_ocean_geom)

Run on workers

In [11]:
# One `process_tile` task per tile name. `mdt` and `ocean_geom` are given as futures;
# tasks are submitted to the scheduler in batches of 1000
fut = client.map(
    process_tile,
    coastal_tiles,
    mdt=mdt_fut,
    ocean_geom=ocean_fut,
    batch_size=1000,
)

Wait for all tiles to finish, then close the cluster

In [14]:
# Block until every tile task is done. `finished` only becomes True if the gather
# returns, i.e. it stays False when gathering raises
finished = False
client.gather(fut)
finished = True

In [15]:
# Release cluster resources, then close the client connection
cluster.close()
client.close()