In [None]:
import re
from os.path import join
from string import ascii_lowercase as abcd

import cartopy.crs as ccrs
import hydromt
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib import colors, patheffects

# local script skill.py
from skill import skill

In [None]:
# Observation dates (YYYYMMDD) with an observed flood map, per event.
# The Idai list can be extended with '20190313' and '20190314'.
events = dict(
    idai=['20190319', '20190320'],
    eloise=['20210125', '20210126'],
)

# (start, end) window per event over which the simulated maximum flood depth
# is taken for comparison with the observed maximum extent.
tslice_max = dict(
    idai=('20190319', '20190320'),
    eloise=('20210125', '20210126'),
)

# Directories, relative to this notebook.
ddir = r'../../1_data/3_eo_rapid'   # input data
fdir = r'../../4_results/rebuttal'  # saved figures
mdir = r'../../3_models/SFINCS2'    # SFINCS models and flddph_*.nc output
mdir1 = r'../../3_models/CMF'       # CaMa-Flood flddph_*.nc output

# Flood-depth threshold [m] passed to skill() and used as lower bound of the depth maps.
hmin = 0.15

In [None]:
from hydromt_sfincs import SfincsModel

# Base 100 m SFINCS model, opened read-only; its static maps provide the
# flow directions and elevation used below.
root = r'../../3_models/SFINCS2/00_base_100m'
mod0 = SfincsModel(root, mode='r')

# Coordinates of every pit (outlet) in the model's flow-direction grid.
flwdir = hydromt.flw.flwdir_from_da(mod0.staticmaps['flwdir'])
xs, ys = flwdir.xy(flwdir.idxs_pit)

# Keep all outlets except those in the north-east corner
# (x >= 680000 and y >= 7820000) and delineate the basins draining to them.
in_ne_corner = np.logical_and(xs >= 680000, ys >= 7820000)
selection = ~in_ne_corner
basin_ids = flwdir.basins(xy=(xs[selection], ys[selection]))
mod0.set_staticmaps(basin_ids.astype('int16'), 'basins')
mod0.staticmaps['basins'].raster.set_nodata(0)

# Exclusion mask (1 = excluded): cells with dep > 0 whose basin id is 0,
# i.e. not assigned to any kept outlet.
outside_kept_basins = mod0.staticmaps['basins'] == 0
above_zero = mod0.staticmaps['dep'] > 0
da_msk_bas = np.logical_and(outside_kept_basins, above_zero).astype('int16')
da_msk_bas.raster.set_nodata(0)

# Polygons of the excluded area (drawn hatched on the maps).
gdf_bas = da_msk_bas.raster.vectorize()

In [None]:
# Mask of cells excluded from the skill scores, per model resolution:
# river cells, cells with dep <= 0 and the excluded basins (da_msk_bas).
msk, roots = {}, {}
for i, postfix in [(0, '_100m')]:  # (10, '_50m') would add the 50 m model
    root = join(mdir, f'{i:02d}_base{postfix}')
    print(root)
    # NOTE(review): da_msk_bas is on the 100 m base-model grid; a 50 m entry
    # would need a basin mask on its own grid.
    # River cells outside the excluded basins. da_msk_bas is int16 (0/1), so
    # compare with 0: `~` on integers is a bitwise NOT (0 -> -1, 1 -> -2),
    # both truthy, which made the basin exclusion a no-op.
    da_msk_riv = np.logical_and(
        hydromt.open_raster(join(root, 'gis', 'rivmsk.tif')) == 1,
        da_msk_bas == 0,
    )
    da_msk_dep = hydromt.open_raster(join(root, 'gis', 'dep.tif')) <= 0
    # Permanent water: rivers plus cells at or below 0 (also drawn grey in the maps).
    da_msk_wat = np.logical_or(da_msk_riv, da_msk_dep)
    da_msk = np.logical_or(da_msk_wat, da_msk_bas)
    msk[postfix] = da_msk
    roots[postfix] = root

da_msk.plot()

In [None]:
# Model run name -> table label for CaMa-Flood (cmf_rm) and SFINCS (sfx_rm).
# Keys are values of the `run` coordinate in the flddph_*_v2.nc files. The
# numbered labels sort into table order ('1. default' first), and identical
# labels pair the same sensitivity test across both models.
# pandas is already imported in the first cell, so it is not re-imported here.
cmf_rm = {
    '01_powlaw_06min': '7. spatial res: 200%',
    '05_powlaw_bf0_03min': '6b. bifurcations: 0% (off)',
    '04_powlaw_bf5_03min': '6a. bifurcations: 50%',
    '08_ro120_03min': '5b. P & Q forcing: 120%',
    '08_ro80_03min': '5a. P & Q forcing: 80%',
    '09_h120_03min': '4b. H forcing: 120%',
    '09_h80_03min': '4a. H forcing: 80%',
    '07_powlaw_fldman15_03min': '3b. land manning: 150%',
    '06_powlaw_fldman05_03min': '3a. land manning: 50%',
    '02_powlaw_hp405_03min': '2b. river depth: 150%',
    '03_powlaw_hp135_03min': '2a. river depth: 50%',
    '01_powlaw_03min': '1. default',
}
sfx_rm = {
    '01_rivpowlaw_qp_x120': '5b. P & Q forcing: 120%',
    '01_rivpowlaw_qp_x80': '5a. P & Q forcing: 80%',
    '01_rivpowlaw_h_x120': '4b. H forcing: 120%',
    '01_rivpowlaw_h_x80': '4a. H forcing: 80%',
    '03_rivpowlaw_lnd0.15': '3b. land manning: 150%',
    '03_rivpowlaw_lnd0.05': '3a. land manning: 50%',
    '02_rivpowlaw_hc0.405': '2b. river depth: 150%',
    '02_rivpowlaw_hc0.135': '2a. river depth: 50%',
    '01_rivpowlaw': '1. default',
}

In [None]:
def _run_skill_table(da_sim, da_obs, da_mask, run_labels):
    """Score each model run in `da_sim` against the observed flood map `da_obs`.

    Parameters
    ----------
    da_sim : xarray.DataArray
        Simulated flood depth with a `run` dimension (one map per run).
    da_obs : xarray.DataArray
        Observed flood map (sofala_floodmap_*.tif).
    da_mask : xarray.DataArray
        Mask passed to `skill`; True marks cells that are masked out.
    run_labels : dict
        Run name -> table label; selects and orders the rows.

    Returns
    -------
    pandas.DataFrame
        One row per run (renamed to its label) and one column per metric
        returned by `skill`, without 'E'.
    """
    da_skill, _ = skill(da_sim, da_obs, da_mask, hmin=hmin)
    df_skill = da_skill.reset_coords(drop=True).to_dataframe()
    # list(...) rather than dict_keys: pandas indexers should be list-like, not set-like
    return df_skill.loc[list(run_labels)].rename(run_labels).drop(columns='E')


sfx_skill, cmf_skill = {}, {}
postfix = '_100m'
gis_dir = join(roots[postfix], 'gis')

for event, dates in events.items():
    # Context managers close the netCDF files once this event has been scored.
    with xr.open_dataarray(join(mdir, f'flddph_{event}{postfix}_v2.nc')) as da_sfx, \
            xr.open_dataarray(join(mdir1, f'flddph_{event}_v2.nc')) as da_cmf:
        # Individual observation dates. SFINCS maps are flipped along y
        # (raster.flipud) before scoring; CaMa-Flood maps are used as stored.
        for date in dates:
            da_obs = hydromt.open_raster(join(gis_dir, f'sofala_floodmap_{date}.tif'))
            sfx_skill[date] = _run_skill_table(
                da_sfx.sel(time=date).squeeze().raster.flipud(), da_obs, msk[postfix], sfx_rm)
            cmf_skill[date] = _run_skill_table(
                da_cmf.sel(time=date).squeeze(), da_obs, msk[postfix], cmf_rm)

        # Maximum extent: simulated maximum over the tslice_max window.
        da_obs = hydromt.open_raster(join(gis_dir, f'sofala_floodmap_{event}_max.tif'))
        tslice = slice(*tslice_max[event])
        sfx_skill[event] = _run_skill_table(
            da_sfx.sel(time=tslice).max('time').raster.flipud(), da_obs, msk[postfix], sfx_rm)
        cmf_skill[event] = _run_skill_table(
            da_cmf.sel(time=tslice).max('time'), da_obs, msk[postfix], cmf_rm)

# Combined table: rows = run labels, columns = (date/event, metric, model).
tables = [
    pd.concat([cmf_skill[key], sfx_skill[key]], axis=1, keys=['CMF', 'SFINCS'])
    .swaplevel(0, 1, axis=1)
    .sort_index(axis=0)
    .sort_index(axis=1)
    for key in sfx_skill
]
df1 = pd.concat(tables, axis=1, keys=list(sfx_skill))

In [None]:
# Skill of the default runs. Row 0 of df1 is '1. default' after sorting.
# The result has one row per date/event and columns (metric, model), rounded to 2 decimals.
default_skill = df1.iloc[0, :].unstack(0).round(2).T
default_skill.to_clipboard()
default_skill

In [None]:
# Sensitivity table. The first row ('1. default') keeps the absolute skill and
# every other row holds its difference from that row. Only the maximum-extent
# columns ('idai', 'eloise') are shown; missing model/run combinations read 'N/A'.
df2 = df1.copy()
default_row = df1.iloc[0, :]
df2.iloc[1:, :] = df1.iloc[1:, :] - default_row
sensitivity = df2[['idai', 'eloise']].round(2).fillna('N/A')
sensitivity.to_clipboard()
sensitivity

In [None]:
import geopandas as gpd

gis_rivpowlaw = join(mdir, '01_rivpowlaw', 'gis')

# River polygons. Cells where dep equals -9999 (presumably the nodata value)
# are set to 0 (the river mask's nodata) before vectorizing.
dep_mask = hydromt.open_raster(join(gis_rivpowlaw, 'dep.tif')) == -9999
riv_mask = hydromt.open_raster(join(gis_rivpowlaw, 'rivmsk.tif'))
riv_mask.raster.set_nodata(0)
gdf_riv = riv_mask.where(~dep_mask, 0).raster.vectorize()

# Towns for map annotation, in the river layer's CRS.
gdf_towns = (
    gpd.read_file(r'../../1_data/towns.geojson')
    .set_index('index')
    .to_crs(gdf_riv.crs)
)

# Observation points; the index is shifted so the plotted labels start at 1.
gdf_points = gpd.read_file(join(gis_rivpowlaw, 'obs.geojson'))
gdf_points.index = gdf_points.index + 1

In [None]:
# Score only the default run of each model on every observation date and on
# the event maximum; the results feed the map figures below.
scores = {'cmf': {}, 'sfx': {}}    # skill() metrics (H, C, F, ...) per model and date
res = {'cmf': {}, 'sfx': {}}       # contingency map returned by skill() per model and date
hmax_sim = {'cmf': {}, 'sfx': {}}  # simulated depth with masked cells set to NaN
postfix = '_100m'
for event, dates in events.items():
    da_cmf = xr.open_dataarray(join(mdir1, f'flddph_{event}_v2.nc')).sel(run='01_powlaw_03min')
    # SFINCS output is flipped along y and loaded into memory.
    da_sfx = (
        xr.open_dataarray(join(mdir, f'flddph_{event}{postfix}_v2.nc'))
        .sel(run='01_rivpowlaw')
        .raster.flipud()
        .load()
    )
    for date in [*dates, f'{event}_max']:
        da_obs = hydromt.open_raster(join(roots[postfix], 'gis', f'sofala_floodmap_{date}.tif'))
        for mod, da_mod in (('cmf', da_cmf), ('sfx', da_sfx)):
            # '<event>_max' entries use the maximum over the tslice_max window.
            if 'max' in date:
                da_date = da_mod.sel(time=slice(*tslice_max[event])).max('time')
            else:
                da_date = da_mod.sel(time=date)
            scores[mod][date], res[mod][date] = skill(da_date, da_obs, msk[postfix], hmin=hmin)
            hmax_sim[mod][date] = da_date.where(~msk[postfix])

In [None]:
# Thin hatch lines for the hatched ('xxx') excluded-basin polygons in the maps
# below; `mpl` is imported in the first cell, so it is not re-imported here.
mpl.rcParams['hatch.linewidth'] = 0.2

In [None]:

# Cartopy projection from the CRS of `da_obs` (the last observation map read in
# the reference-run scoring cell). A UTM CRS is required.
wkt = da_obs.raster.crs.to_wkt()
zone_match = re.search(r"UTM zone (\d{1,2})([NS])", wkt)
if zone_match is None:
    raise ValueError("Model CRS UTM zone not found.")
zone_number, hemisphere = zone_match.groups()
utm_zone = f"{zone_number}{hemisphere}"  # e.g. '36S'; used in the axis labels
utm = ccrs.UTM(int(zone_number), hemisphere == "S")
# Map extent (xmin, xmax, ymin, ymax): observation bounds plus a 100 m buffer.
extent = np.array(da_obs.raster.box.buffer(100).total_bounds)[[0, 2, 1, 3]]
props = dict(facecolor='w', lw=0, alpha=0.8)  # text-box style for panel letters and scores

# Contingency-map classes: value -> (label, colour).
cm_dict = {
    1: ('false neg.', '#dd8452'),
    2: ('false pos.', '#c44e52'),
    3: ('true pos.', '#4c72b0'),
}
levels = list(cm_dict) + [4]
# Named cm_colors so it does not shadow the `matplotlib.colors` module imported in the first cell.
cm_colors = [colour for _, colour in cm_dict.values()]
ticklabs = [label for label, _ in cm_dict.values()]
cmap, norm = mpl.colors.from_levels_and_colors(levels, cm_colors)
ticks = np.array(levels[:-1]) + np.diff(levels) / 2.  # colorbar ticks centred on each class

# Town labels: small offset and a white halo for legibility.
ann_kwargs = dict(
    xytext=(3, 3),
    textcoords="offset points",
    zorder=4,
    path_effects=[
        patheffects.Stroke(linewidth=5, foreground="w"),
        patheffects.Normal(),
    ],
)

# One row per event, one column per model: contingency map of the maximum extent.
fig, axs = plt.subplots(
    figsize=(9, 12),
    nrows=2, ncols=2,
    subplot_kw={'projection': utm},
    sharex=True, sharey=True,
)
axs = axs.flatten()
d = {
    'sfx': 'SFINCS',
    'cmf': 'CaMa-Flood',
}

for row, event in enumerate(events):
    date = f'{event}_max'
    for col, mod in enumerate(['cmf', 'sfx']):
        i = row * 2 + col
        ax = axs[i]
        ax.set_extent(extent, crs=utm)

        da_cm = res[mod][date].load()
        da_skill = scores[mod][date].load()
        hr, csi, fr = (da_skill[key].item() for key in ('H', 'C', 'F'))

        da_cm.raster.set_crs(gdf_riv.crs)
        da_msk_wat.raster.set_crs(gdf_riv.crs)

        # Excluded basins (hatched), permanent water (grey), contingency classes.
        gdf_bas.plot(ax=ax, facecolor='w', edgecolor='k', lw=0.2, zorder=3, hatch='xxx')
        da_msk_wat.where(da_msk_wat).plot(ax=ax, cmap='gray', add_colorbar=False, alpha=0.5)
        cs = da_cm.where(da_cm > 0).plot(ax=ax, cmap=cmap, norm=norm, add_colorbar=False)
        # Towns.
        for label, grow in gdf_towns.iterrows():
            x, y = grow.geometry.x, grow.geometry.y
            ax.plot(x, y, '.k', markersize=4)
            ax.annotate(f'{label}', xy=(x, y), **ann_kwargs)
        ax.text(0.03, 0.95, abcd[i].upper(), fontsize=14, fontweight='bold', transform=ax.transAxes, bbox=props)
        # Scores are formatted to 2 decimals here; no separate rounding needed.
        ax.text(0.8, 0.88, f'C: {csi:.2f}\nH: {hr:.2f}\nF: {fr:.2f}', transform=ax.transAxes, bbox=props)

        # y axis only on the left column, x axis only on the bottom row.
        is_left = col == 0
        ax.yaxis.set_visible(is_left)
        ax.set_ylabel(f"y coordinate UTM zone {utm_zone} [m]" if is_left else '')

        ax.set_title(f'{d[mod]} - {event.capitalize()}')
        is_bottom = i >= len(axs) - 2
        ax.xaxis.set_visible(is_bottom)
        ax.set_xlabel(f"x coordinate UTM zone {utm_zone} [m]" if is_bottom else '')

ax.set_xticks(ax.get_xticks()[::2])
ax.set_extent(extent, crs=utm)
fig.subplots_adjust(wspace=0.04, hspace=0.06)

# Colorbar axis on the right-hand side of the figure.
cbar_ax = fig.add_axes([0.93, 0.33, 0.015, 0.3])
cbar = fig.colorbar(cs, cax=cbar_ax, orientation='vertical', ticks=ticks)
cbar_ax.set_yticklabels(ticklabs, va='center', rotation=90)

# round() instead of int(): int(hmin*100) truncates float error (int(0.29*100) == 28).
fig.savefig(join(fdir, f'validation_hmin{round(hmin*100)}_100m_max.png'), dpi=500, bbox_inches="tight")

In [None]:
# Contingency maps of the default runs for the individual observation dates of
# one event (rows) and both models (columns). Uses utm, extent, cm_dict, props,
# ann_kwargs, d and utm_zone from the maximum-extent figure cell.
event = 'idai'
event_dates = events[event]
fig, axs = plt.subplots(
    figsize=(9, 6 * len(event_dates)),
    nrows=len(event_dates), ncols=2,
    subplot_kw={'projection': utm},
    sharex=True, sharey=True,
)
axs = axs.flatten()

# Contingency colormap rebuilt from cm_dict so this cell does not depend on the
# current binding of the global `cmap` / `norm` / `ticks` names.
cm_levels = list(cm_dict) + [4]
cm_cmap, cm_norm = mpl.colors.from_levels_and_colors(
    cm_levels, [colour for _, colour in cm_dict.values()])
cm_ticks = np.array(cm_levels[:-1]) + np.diff(cm_levels) / 2.
cm_ticklabs = [label for label, _ in cm_dict.values()]

for row, date in enumerate(event_dates):
    for col, mod in enumerate(['cmf', 'sfx']):
        i = row * 2 + col
        ax = axs[i]
        ax.set_extent(extent, crs=utm)

        da_cm = res[mod][date].load()
        da_skill = scores[mod][date].load()
        hr, csi, fr = (da_skill[key].item() for key in ('H', 'C', 'F'))

        da_cm.raster.set_crs(gdf_riv.crs)
        da_msk_wat.raster.set_crs(gdf_riv.crs)

        # Excluded basins (hatched), permanent water (grey), contingency classes.
        gdf_bas.plot(ax=ax, facecolor='w', edgecolor='k', lw=0.2, zorder=3, hatch='xxx')
        da_msk_wat.where(da_msk_wat).plot(ax=ax, cmap='gray', add_colorbar=False, alpha=0.5)
        cs = da_cm.where(da_cm > 0).plot(ax=ax, cmap=cm_cmap, norm=cm_norm, add_colorbar=False)
        # Towns.
        for label, grow in gdf_towns.iterrows():
            x, y = grow.geometry.x, grow.geometry.y
            ax.plot(x, y, '.k', markersize=4)
            ax.annotate(f'{label}', xy=(x, y), **ann_kwargs)
        ax.text(0.03, 0.95, abcd[i].upper(), fontsize=14, fontweight='bold', transform=ax.transAxes, bbox=props)
        ax.text(0.8, 0.88, f'C: {csi:.2f}\nH: {hr:.2f}\nF: {fr:.2f}', transform=ax.transAxes, bbox=props)

        # y axis only on the left column, x axis only on the bottom row.
        is_left = col == 0
        ax.yaxis.set_visible(is_left)
        ax.set_ylabel(f"y coordinate UTM zone {utm_zone} [m]" if is_left else '')

        ax.set_title(f'{d[mod]} - {date}')
        is_bottom = i >= len(axs) - 2
        ax.xaxis.set_visible(is_bottom)
        ax.set_xlabel(f"x coordinate UTM zone {utm_zone} [m]" if is_bottom else '')

ax.set_xticks(ax.get_xticks()[::2])
ax.set_extent(extent, crs=utm)
fig.subplots_adjust(wspace=0.04, hspace=0.06)

# Colorbar axis on the right-hand side of the figure.
cbar_ax = fig.add_axes([0.93, 0.33, 0.015, 0.3])
cbar = fig.colorbar(cs, cax=cbar_ax, orientation='vertical', ticks=cm_ticks)
cbar_ax.set_yticklabels(cm_ticklabs, va='center', rotation=90)

# round() instead of int(): int(hmin*100) truncates float error (int(0.29*100) == 28).
fig.savefig(join(fdir, f'validation_hmin{round(hmin*100)}_100m_{event}.png'), dpi=500, bbox_inches="tight")

In [None]:
# Maximum simulated flood depth of the default runs per event (rows) and model
# (columns); depths <= hmin are hidden. Towns and numbered observation points
# are annotated. Uses utm, extent, props, ann_kwargs, d and utm_zone from the
# maximum-extent figure cell.
fig, axs = plt.subplots(
    figsize=(9, 12),
    nrows=2, ncols=2,
    subplot_kw={'projection': utm},
    sharex=True, sharey=True,
)
axs = axs.flatten()

# Local names, so the global `cmap` / `ticks` used by the contingency-map cells
# are not overwritten (re-running those cells afterwards would otherwise break).
depth_cmap = plt.cm.viridis.copy()
depth_cmap.set_under('white')
depth_vmax = 3.5  # upper limit of the depth colour scale [m]

for row, event in enumerate(events):
    date = f'{event}_max'
    for col, mod in enumerate(['cmf', 'sfx']):
        i = row * 2 + col
        ax = axs[i]
        ax.set_extent(extent, crs=utm)
        da_sim = hmax_sim[mod][date].load()

        da_sim.raster.set_crs(gdf_riv.crs)
        da_msk_wat.raster.set_crs(gdf_riv.crs)

        # Excluded basins (hatched), permanent water (grey), flood depth.
        gdf_bas.plot(ax=ax, facecolor='w', edgecolor='k', lw=0.2, zorder=3, hatch='xxx')
        da_msk_wat.where(da_msk_wat).plot(ax=ax, cmap='gray', add_colorbar=False, alpha=0.5)
        cs = da_sim.where(da_sim > hmin).plot(
            ax=ax, cmap=depth_cmap, vmin=hmin, vmax=depth_vmax, add_colorbar=False)
        # Towns.
        for label, grow in gdf_towns.iterrows():
            x, y = grow.geometry.x, grow.geometry.y
            ax.plot(x, y, '.k', markersize=4)
            ax.annotate(f'{label}', xy=(x, y), **ann_kwargs)
        # Observation points (white diamonds, labelled with their 1-based index).
        for label, grow in gdf_points.iterrows():
            x, y = grow.geometry.x, grow.geometry.y
            ax.plot(x, y, marker='d', markersize=8, markerfacecolor='w', color='k', zorder=4)
            ax.annotate(f'{label}', xy=(x, y), **ann_kwargs)
        ax.text(0.03, 0.95, abcd[i].upper(), fontsize=14, fontweight='bold', transform=ax.transAxes, bbox=props)

        # y axis only on the left column, x axis only on the bottom row.
        is_left = col == 0
        ax.yaxis.set_visible(is_left)
        ax.set_ylabel(f"y coordinate UTM zone {utm_zone} [m]" if is_left else '')

        ax.set_title(f'{d[mod]} - {event.capitalize()}')
        is_bottom = i >= len(axs) - 2
        ax.xaxis.set_visible(is_bottom)
        ax.set_xlabel(f"x coordinate UTM zone {utm_zone} [m]" if is_bottom else '')

ax.set_xticks(ax.get_xticks()[::2])
ax.set_extent(extent, crs=utm)
fig.subplots_adjust(wspace=0.04, hspace=0.06)

# Colorbar axis on the right-hand side of the figure.
cbar_ax = fig.add_axes([0.93, 0.33, 0.015, 0.3])
# Ticks at hmin and every 0.5 m up to the top of the colour scale.
depth_ticks = [hmin] + np.arange(0.5, depth_vmax + 0.01, 0.5).tolist()
cbar = fig.colorbar(cs, cax=cbar_ax, orientation='vertical', ticks=depth_ticks, extend='both')
cbar_ax.set_ylabel('maximum flood depth [m]')

# round() instead of int(): int(hmin*100) truncates float error (int(0.29*100) == 28).
fig.savefig(join(fdir, f'flddph_hmin{round(hmin*100)}_100m_max.png'), dpi=500, bbox_inches="tight")