In [None]:
import hydromt
import xarray as xr
import numpy as np
from os.path import join, basename

# local script skill.py
from skill import skill

In [None]:
# Flood events to evaluate: event name -> date string used to pick the observed
# flood map `flooding_{date}.tif` (dates look like YYYYMMDD).
events = {
    'idai': '20190320',
    'eloise': '20210125',
}

# data / model directories, relative to this notebook
ddir = r'../../1_data/3_eo_rapid'
mdir = r'../../3_models/SFINCS'   # SFINCS model setups and outputs
mdir1 = r'../../3_models/CMF'     # CaMa-Flood outputs

# passed to `skill(..., hmin=hmin)`; presumably the minimum simulated depth [m]
# counted as flooded -- TODO confirm against skill.py
hmin = 0.25


In [None]:
# Read, per SFINCS model resolution, the mask of cells excluded from the skill
# evaluation. A cell is masked when it is a river cell (rivmsk == 1), has an
# elevation <= 0 (dep.tif), or equals 1 in any of the stacked `dry_*.tif` layers.
# `roots` keeps the model root folder per resolution postfix.
msk, roots = {}, {}
for run_idx, postfix in [(0, '_100m'), (10, '_50m')]:
    root = join(mdir, f'{run_idx:02d}_base{postfix}')
    print(root)
    gis_dir = join(root, 'gis')
    ds_dry = hydromt.open_mfraster(join(gis_dir, 'dry_*.tif'), concat=True)
    is_dry = ds_dry['dry'].max('dim0') == 1
    is_riv = hydromt.open_raster(join(gis_dir, 'rivmsk.tif')) == 1
    is_low = hydromt.open_raster(join(gis_dir, 'dep.tif')) <= 0
    # NOTE: `da_msk` stays defined after the loop (last resolution) and is read later
    da_msk = np.logical_or(np.logical_or(is_riv, is_low), is_dry)
    msk[postfix] = da_msk
    roots[postfix] = root

In [None]:
# Skill of the SFINCS runs for each event at 100 m resolution; the scores are
# printed and saved as CSV in the SFINCS model directory.
postfix = '_100m'
for event, date in events.items():
    obs_path = join(roots[postfix], 'gis', f'flooding_{date}.tif')
    sim_path = join(mdir, f'flddph_{event}{postfix}.nc')
    csv_path = join(mdir, f'flddph_{event}{postfix}_skill_hmin{hmin}.csv')

    da_obs = hydromt.open_raster(obs_path)
    # SFINCS output is flipped vertically before the comparison, presumably to
    # match the y-orientation of the observed raster
    da_sim = xr.open_dataarray(sim_path).raster.flipud()

    da_skill, da_cm = skill(da_sim, da_obs, msk[postfix], hmin=hmin)
    df_skill = da_skill.reset_coords(drop=True).to_dataframe()
    df_skill.to_csv(csv_path)
    print(df_skill)

In [None]:
# Validate CaMa-Flood: skill per event against the observations and mask of the
# '_100m' SFINCS setup; the scores are printed and saved as CSV in the CMF directory.
postfix = '_100m'
for event, date in events.items():
    obs_tif = join(roots[postfix], 'gis', f'flooding_{date}.tif')
    sim_nc = join(mdir1, f'flddph_{event}.nc')
    skill_csv = join(mdir1, f'flddph_{event}_skill_hmin{hmin}.csv')

    da_obs = hydromt.open_raster(obs_tif)
    da_sim = xr.open_dataarray(sim_nc)
    da_skill, da_cm = skill(da_sim, da_obs, msk[postfix], hmin=hmin)

    df_skill = da_skill.reset_coords(drop=True).to_dataframe()
    df_skill.to_csv(skill_csv)
    print(df_skill)

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import colors, patheffects
import cartopy.crs as ccrs
from string import ascii_lowercase

# Contingency-map classes: integer value -> (legend label, colour).
# Only values > 0 are drawn in the figure, so there is no entry for 0.
cm_dict = {
    1: ('false neg.', '#dd8452'),
    2: ('false pos.', '#c44e52'),
    3: ('true pos.', '#4c72b0'),
}
# Bin edges for a discrete colormap with one bin per class: [1, 2, 3, 4].
# Assumes consecutive integer class values.
levels = list(cm_dict) + [max(cm_dict) + 1]
# Named `cm_colors` (not `colors`) so the `matplotlib.colors` module imported
# above is not shadowed by a list.
cm_colors = [color for _, color in cm_dict.values()]
ticklabs = [label for label, _ in cm_dict.values()]
cmap, norm = mpl.colors.from_levels_and_colors(levels, cm_colors)
# colorbar ticks in the middle of each class bin
ticks = np.array(levels[:-1]) + np.diff(levels) / 2.

In [None]:
# Skill scores (`scores`) and contingency maps (`res`) of CaMa-Flood and SFINCS,
# keyed by model and event date, for the figure below. `hmin` is taken from the
# configuration cell instead of being re-declared here.
cmf_postfix, cmf_run = '_100m', '00_default_03min'
# sfx_postfix, sfx_run = '_50m', '11_rivpowlaw'  # alternative: 50 m SFINCS setup
sfx_postfix, sfx_run = '_100m', '01_rivpowlaw'

scores = {'cmf': {}, 'sfx': {}}
res = {'cmf': {}, 'sfx': {}}
for event, date in events.items():
    da_obs = hydromt.open_raster(join(roots[cmf_postfix], 'gis', f'flooding_{date}.tif')).load()
    da_cmf = xr.open_dataarray(join(mdir1, f'flddph_{event}.nc')).sel(run=cmf_run)
    scores['cmf'][date], res['cmf'][date] = skill(da_cmf, da_obs, msk[cmf_postfix], hmin=hmin)

    # Only re-read the observations when SFINCS is evaluated at another resolution.
    # NOTE(review): reusing `da_obs` assumes `skill` does not modify its inputs in place.
    if sfx_postfix != cmf_postfix:
        da_obs = hydromt.open_raster(join(roots[sfx_postfix], 'gis', f'flooding_{date}.tif')).load()
    da_sfx = (
        xr.open_dataarray(join(mdir, f'flddph_{event}{sfx_postfix}.nc'))
        .sel(run=sfx_run)
        .raster.flipud()
        .load()
    )
    scores['sfx'][date], res['sfx'][date] = skill(da_sfx, da_obs, msk[sfx_postfix], hmin=hmin)

In [None]:

# Figure: contingency maps of CaMa-Flood (left column) and SFINCS (right column)
# against the observed flood extent, one row per event, with C/H/F scores.
# Uses `events`, `msk`, `res`, `scores`, `hmin`, `da_obs` and the colormap
# (`cmap`, `norm`, `ticks`, `ticklabs`) from the cells above.
import re

plot_postfix = '_100m'  # resolution of the plotted maps; also part of the file name
# Grey background mask, taken explicitly for `plot_postfix`. The global `da_msk`
# left by the mask-reading loop belongs to its last iteration ('_50m').
da_msk_plot = msk[plot_postfix]

# read crs and utm zone > convert to cartopy
wkt = da_obs.raster.crs.to_wkt()
utm_match = re.search(r"UTM zone (\d{1,2})([NS])", wkt)
if utm_match is None:
    raise ValueError("Model CRS UTM zone not found.")
utm_zone = ''.join(utm_match.groups())  # e.g. '36S'; used in the axis labels
utm = ccrs.UTM(int(utm_match.group(1)), utm_match.group(2) == 'S')
# map extent [xmin, xmax, ymin, ymax] of the observation raster, buffered by 100 m
extent = np.array(da_obs.raster.box.buffer(100).total_bounds)[[0, 2, 1, 3]]
props = dict(facecolor='w', alpha=0.8)  # background box of the score text

# annotation style for points of interest (not used yet, see TODO below)
ann_kwargs = dict(
    xytext=(3, 3),
    textcoords="offset points",
    zorder=4,
    path_effects=[
        patheffects.Stroke(linewidth=3, foreground="w"),
        patheffects.Normal(),
    ],
)

model_labels = {'cmf': 'CaMa-Flood', 'sfx': 'SFINCS'}  # key order = column order
nrows, ncols = len(events), len(model_labels)
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(10, 7.5 * nrows),
    subplot_kw={'projection': utm},
    sharex=True,
    sharey=True,
    squeeze=False,  # always 2D, also for a single event
)
axs = axs.flatten()

for row, (event, date) in enumerate(events.items()):
    for col, mod in enumerate(model_labels):
        i = row * ncols + col
        ax = axs[i]
        da_cm = res[mod][date].load()
        da_skill = scores[mod][date].load()
        hr, csi, fr = (da_skill[key].item() for key in ('H', 'C', 'F'))

        da_msk_plot.where(da_msk_plot).plot(ax=ax, cmap='gray', add_colorbar=False, alpha=0.5)
        cs = da_cm.where(da_cm > 0).plot(ax=ax, cmap=cmap, norm=norm, add_colorbar=False)
        # TODO add points of interest (e.g. `gdf_pnts`, annotated with `ann_kwargs`)

        ax.yaxis.set_visible(True)
        ax.xaxis.set_visible(True)
        ax.text(0.82, 0.88, f'C: {csi:.2f}\nH: {hr:.2f}\nF: {fr:.2f}', transform=ax.transAxes, bbox=props)

        # y label on the left column only, x label on the bottom row only
        ax.set_ylabel(f"y coordinate UTM zone {utm_zone} [m]" if col == 0 else '')
        ax.set_xlabel(f"x coordinate UTM zone {utm_zone} [m]" if row == nrows - 1 else '')
        ax.set_title(f'{ascii_lowercase[i]}) {model_labels[mod]} - {event} ({date})')

# ticks and extent are set on the last axis only; presumably they propagate to
# the other panels through sharex/sharey
ax.set_xticks(ax.get_xticks()[::2])
ax.set_extent(extent, crs=utm)
fig.subplots_adjust(wspace=0.05, hspace=0.06)

# vertical colorbar in a separate axis to the right of the panels
cbar_ax = fig.add_axes([0.93, 0.33, 0.015, 0.3])
cbar = fig.colorbar(cs, cax=cbar_ax, orientation='vertical', ticks=ticks)
cbar_ax.set_yticklabels(ticklabs)

# file name follows hmin and resolution, e.g. 'obs_vs_sim_hmin25_100m.png' for hmin=0.25
fig.savefig(f'obs_vs_sim_hmin{hmin * 100:.0f}{plot_postfix}.png', dpi=500, bbox_inches="tight")