In [None]:
import hydromt
import xarray as xr
import numpy as np
from os.path import join
import geopandas as gpd
import pandas as pd

In [None]:
# Events to analyse (name -> date string) and the root folder of the SFINCS runs.
events = {'Idai': '20190320', 'Eloise': '20210125'}
mdir = r'../../3_models/SFINCS'

# Static GIS layers are read from this model setup's 'gis' folder.
base = '01_rivpowlaw'
gis_dir = join(mdir, base, 'gis')

# True where the elevation raster equals -9999
# (NOTE(review): assumed to be the DEM nodata value - confirm against dep.tif).
dep_mask = hydromt.open_raster(join(gis_dir, 'dep.tif')) == -9999

# River mask with 0 declared as nodata, vectorized after zeroing cells inside dep_mask
# (presumably so that only river cells become polygons - confirm vectorize() semantics).
riv_mask = hydromt.open_raster(join(gis_dir, 'rivmsk.tif'))
riv_mask.raster.set_nodata(0)
gdf_riv = riv_mask.where(~dep_mask, 0).raster.vectorize()

# Town / city points; the 'index' column becomes the index, which is used as map label.
gdf_towns = gpd.read_file(r'../../1_data/towns.geojson').set_index('index')


In [None]:
# Thresholds and the model setup whose scenario runs are compared.
hmin = 0.15  # [m] cells with compound-run (qhp) depth <= hmin are masked out
dh = 0.05    # [m] min. gain of the compound run over the max. single driver to flag "compound"
base = '02_rivpowlaw_hc0.405'

# Per event:
#   da_sim[event]   : max. water depth per scenario ('scen' dim), masked to qhp > hmin
#   da_diff[event]  : qhp minus the max. over the single-driver runs q, h, p
#   da_cmpnd[event] : class code per cell -> 1/2 coastal (h), 3/4 fluvial (q),
#                     5/6 pluvial (p); odd = single driver, even = compound (da_diff > dh);
#                     0 = no strictly dominant driver or masked cell
da_sim, da_diff, da_cmpnd = {}, {}, {}
for event in events:
    da = xr.open_dataarray(join(mdir, f'flddph_{event.lower()}_100m_max.nc'))
    assert base in da.run.values, f'run {base!r} missing for event {event!r}'

    # Keep the runs of this setup (excluding glofas-forced runs) and rename them to
    # scenario names: '<base>_<suffix>' -> '<suffix>', '<base>' itself -> 'qhp'.
    runs = {
        r: r.replace(f'{base}_', '').replace(base, 'qhp')
        for r in da.run.values
        if r.startswith(base) and 'glofas' not in r
    }
    da = da.sel(run=list(runs))
    da = da.assign_coords(
        run=xr.IndexVariable(dims='run', data=np.array(list(runs.values())))
    ).rename({'run': 'scen'})
    da = da.where(da.sel(scen='qhp') > hmin)
    da_sim[event] = da

    # Water-depth difference: compound run minus the max. over the single-driver runs.
    da_single_max = da.sel(scen=['q', 'h', 'p']).max('scen')
    da1 = (da.sel(scen='qhp') - da_single_max).compute()
    da1.name = 'diff. in waterlevel\ncompound - max. single driver'
    da1.attrs.update(units='m')  # 'units' is the CF / xarray attribute name
    da_diff[event] = da1

    # Dominant single driver per cell. The strict '>' makes the masks disjoint;
    # ties fall in none of them. max('scen') skips NaN (xarray default).
    compound_mask = da1 > dh
    surge_mask = da.sel(scen='h') > da.sel(scen=['p', 'q']).max('scen')
    discharge_mask = da.sel(scen='q') > da.sel(scen=['h', 'p']).max('scen')
    precip_mask = da.sel(scen='p') > da.sel(scen=['h', 'q']).max('scen')
    # NOTE(review): only the pluvial class additionally requires da1 >= 0;
    # confirm this asymmetry with the coastal / fluvial classes is intended.
    precip_mask = np.logical_and(precip_mask, da1 >= 0)

    # Sanity check: a cell can have at most one dominant driver.
    overlaps = (
        np.logical_and(precip_mask, surge_mask).any(),
        np.logical_and(discharge_mask, surge_mask).any(),
        np.logical_and(discharge_mask, precip_mask).any(),
    )
    assert not any(bool(o) for o in overlaps), f'overlapping driver masks for {event!r}'

    # compound_mask is boolean, so "+ offset" yields offset (single) or offset + 1 (compound).
    da2 = (
        xr.where(surge_mask, compound_mask + 1, 0)
        + xr.where(discharge_mask, compound_mask + 3, 0)
        + xr.where(precip_mask, compound_mask + 5, 0)
    ).compute()
    da2.name = None
    da_cmpnd[event] = da2

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from matplotlib import colors, patheffects
from string import ascii_lowercase as abcd

import re

# Read the CRS of the simulated maps and convert its UTM zone to a cartopy projection.
# The last event's map is used (NOTE(review): assumes all events share one grid / CRS).
da_ref = da_sim[list(events)[-1]]
wkt = da_ref.raster.crs.to_wkt()
utm_match = re.search(r"UTM zone (\d{1,2})([NS])", wkt)
if utm_match is None:
    raise ValueError(f"Expected a UTM CRS, got: {wkt[:120]}...")
utm_zone = utm_match.group(1) + utm_match.group(2)  # e.g. '36S', used in axis labels
utm = ccrs.UTM(int(utm_match.group(1)), southern_hemisphere=utm_match.group(2) == "S")
# Map extent (xmin, xmax, ymin, ymax) in CRS units: raster bounding box buffered by 100.
extent = np.array(da_ref.raster.box.buffer(100).total_bounds)[[0, 2, 1, 3]]

# Discrete colormap for the six driver class codes 1..6 (bin edges 1..7), picking
# colours from tab20c. plt.get_cmap is used because mpl.cm.get_cmap was removed in
# Matplotlib 3.9; the array is not named `colors` to avoid shadowing matplotlib.colors.
levels = np.arange(1, 8)
driver_colors = np.array(plt.get_cmap('tab20c').colors)[[2, 0, 14, 12, 10, 8]]
cmap, norm = mpl.colors.from_levels_and_colors(levels, driver_colors)

# Town-label style: small offset, drawn above the data with a white halo.
ann_kwargs = dict(
    xytext=(4, 0),
    textcoords="offset points",
    zorder=4,
    path_effects=[
        patheffects.Stroke(linewidth=3, foreground="w"),
        patheffects.Normal(),
    ],
)

def _draw_background(ax):
    """Shade cells where dep_mask is True in semi-transparent grey on ``ax``."""
    dep_mask.where(dep_mask).plot(ax=ax, cmap='gray', add_colorbar=False, alpha=0.5)


def _draw_context(ax, panel_label):
    """Overlay river outlines, labelled towns and a bold panel letter on ``ax``."""
    gdf_riv.boundary.plot(ax=ax, ls='--', lw=0.5, color='k', alpha=0.5)
    gdf_towns.plot(ax=ax, marker='.', markersize=20, color="k", label='towns / cities', zorder=4)
    for town, town_row in gdf_towns.iterrows():
        ax.annotate(f'{town}', xy=(town_row.geometry.x, town_row.geometry.y), **ann_kwargs)
    ax.text(0.01, 0.95, panel_label, fontsize=14, fontweight='bold', transform=ax.transAxes)


# One column per event: top row = water-depth difference, bottom row = driver class.
n = len(events)
fig, axs = plt.subplots(
    nrows=2,
    ncols=n,
    figsize=(4.5 * n, 12),
    subplot_kw={'projection': utm},
    sharex=True,
    sharey=True,
)
axs = axs.flatten()  # row-major: top-row panels are 0..n-1, bottom-row panels n..2n-1

for col, event in enumerate(events):
    i, j = col, col + n  # flattened indices of this column's top / bottom panel

    # top: compound minus max. single-driver depth (symmetric diverging scale)
    _draw_background(axs[i])
    cs = da_diff[event].plot(ax=axs[i], add_colorbar=False, cmap='seismic', vmin=-0.20, vmax=0.20)
    _draw_context(axs[i], abcd[i].upper())

    # bottom: driver class codes 1..6 (code 0 is masked out and left blank)
    da_c = da_cmpnd[event]
    _draw_background(axs[j])
    da_c.where(da_c > 0).plot(ax=axs[j], cmap=cmap, norm=norm, add_colorbar=False)
    _draw_context(axs[j], abcd[j].upper())

    axs[i].set_title(event)
    axs[j].set_title('')
    for ax in (axs[i], axs[j]):
        ax.set_extent(extent, crs=utm)

    # y labels only on the first column, x labels only on the bottom row
    if col == 0:
        for ax in (axs[i], axs[j]):
            ax.yaxis.set_visible(True)
            ax.set_ylabel(f"y coordinate UTM zone {utm_zone} [m]")
    axs[j].xaxis.set_visible(True)
    axs[j].set_xlabel(f"x coordinate UTM zone {utm_zone} [m]")

# Keep every second x tick (starting at the second) on the last bottom panel.
axs[j].set_xticks(axs[j].get_xticks()[1::2])
fig.subplots_adjust(wspace=0.03, hspace=0.05)

# Colourbar for the water-depth difference, right of the last top-row panel.
pos0 = axs[i].get_position()
cax = fig.add_axes([pos0.x1 + 0.01, pos0.y0 + pos0.height * 0.15, 0.015, pos0.height * 0.7])
label = 'diff. waterlevel [m]\ncompound - max. single driver'
cbar = fig.colorbar(cs, cax=cax, orientation='vertical', label=label, extend='both')

# Driver-class legend right of the last bottom-row panel, sized from that panel's own
# position. It is drawn as a 3x2 image of the class codes 1..6 with the bottom-row
# cmap/norm: rows = coastal / fluvial / pluvial, columns = single / compound.
pos1 = axs[j].get_position()
cbar_ax = fig.add_axes([pos1.x1 + 0.01, pos1.y0 + pos1.height * 0.25, 0.05, pos1.height * 0.5])
legend_codes = np.arange(1, 7).reshape((3, 2))
cbar_ax.imshow(legend_codes, cmap=cmap, norm=norm, aspect='auto')
cbar_ax.yaxis.tick_right()
cbar_ax.set_yticks([0, 1, 2])
cbar_ax.set_yticklabels(['coastal', 'fluvial', 'pluvial'], va='center', rotation=90)
cbar_ax.set_xticks([0, 1])
cbar_ax.set_xticklabels(['single', 'compound'], ha='center', rotation=60)

# Save this figure explicitly (plt.savefig would save whichever figure is current).
fig.savefig(join('../../4_results', f'{base}_compound_analysis_sfincs.png'), dpi=225, bbox_inches="tight")