In [None]:
import xarray as xr
from os.path import join, basename
from datetime import timedelta, datetime
import pandas as pd
import numpy as np
import glob

In [None]:
# Project root and the sub-folder '<root>/data/3-model_output' that all
# cells below read from and write to.
root = '/scratch/compound_hotspots/'
ddir = join(root, *('data', '3-model_output'))

In [None]:
# t0, t1 = datetime(1980,1,2), datetime(2014,12,30)   # NOTE: cama date at 00hrs of 'next day' and missing GTSM values at 31-12-2014
# scen_rm = {
#     'cmpnd':   'surge', 
#     'runoff':  'seas',
#     'tide':    'tide',
#     'msl':     'msl'
#     # 'surge':   'surge',
# }

In [None]:
from dask.diagnostics import ProgressBar
# Chunk layout used for reading and for the zarr stores that are written:
# 3196 time steps per chunk and spatial tiles of 30 x 30 cells.
chunks = {'time': 3196, 'lat': 30, 'lon': 30}

# Convert the 'sfcelv*.nc' files of every scenario into one zarr store.
# (A regional subset could be taken with
#  .sel(lon=slice(xmin, xmax), lat=slice(ymax, ymin)) before writing.)
for scen in ['cmpnd', 'runoff']:
    # glob returns files in arbitrary order; sort for a reproducible input list
    fns = sorted(glob.glob(join(ddir, f'anu_mswep_{scen}_v362_1980-2014', 'sfcelv*.nc')))
    fn_out = join(ddir, f'anu_mswep_{scen}_v362_1980-2014.zarr')
    # Keep the dataset lazy. Calling .persist() here would compute and hold the
    # complete dataset in memory before a single chunk is written, and it would
    # do so outside of the ProgressBar context below.
    ds_sel = xr.open_mfdataset(fns, combine='by_coords', chunks=chunks)
    with ProgressBar():
        # re-chunk after concatenation so the store gets the layout of `chunks`
        ds_sel.chunk(chunks).to_zarr(fn_out)


In [None]:
%matplotlib inline
# Folder with the global 15 arcmin map files.
map_dir = '/home/dirk/models/cama-flood_bmi_v3.6.2_nc/map/global_15min'
# raster axis names -> names used for the model output dimensions
rm = dict(x='lon', y='lat')
# Read both single-band rasters the same way: drop the band coordinate,
# squeeze length-one dimensions and rename x/y to lon/lat.
elevtn, lecz = [
    xr.open_rasterio(join(map_dir, fn_tif)).drop('band').squeeze().rename(rm)
    for fn_tif in ('elevtn.tif', 'LEZC_10m_basin.tif')
]
# True where the elevation differs from -9999 (presumably the nodata value
# of elevtn.tif -- TODO confirm).
landmask = elevtn != -9999

In [None]:
from dask.diagnostics import ProgressBar

from peaks import get_peaks
from xlm_fit import xlm_fit 
from lmoments3 import distr
# --- settings ---------------------------------------------------------------
min_dist = 14                                   # passed to get_peaks as `min_dist`
chunks = dict(time=-1, lat=30, lon=30)          # layout of the gridded result
chunks2 = dict(latlon=500, time=-1)             # layout of the stacked cells
rps_out = np.array([1.1, 1.5, 2, 5, 10, 20, 30, 50, 100])  # passed to xlm_fit as `rp`
mask = lecz                                     # only cells with mask > 0 are analysed

with ProgressBar():
    # NOTE(review): the [1:] slice and the `break` at the end of the loop body
    # both limit this loop to the 'runoff' scenario; the 'cmpnd' result that is
    # read further down is not produced by this cell as written.
    for scen in ['cmpnd', 'runoff'][1:]:
        print(scen)
        scen_dir = join(ddir, f'anu_mswep_{scen}_v362_1980-2014')

        # lazily open all water level files of this scenario
        fns = glob.glob(join(scen_dir, 'sfcelv*.nc'))
        da = xr.open_mfdataset(fns, combine='by_coords')['sfcelv']

        # attach the mask as coordinate, flatten (lat, lon) to one 'latlon'
        # dimension and drop every cell outside the mask
        da.coords['mask'] = mask
        da_stacked = da.stack(latlon=('lat', 'lon'))
        inside_mask = da_stacked['mask'] > 0
        da_stacked = da_stacked.where(inside_mask, drop=True).chunk(chunks2)

        # annual maxima: yearly maximum of the peaks returned by get_peaks;
        # the resulting 'year' dimension is renamed to 'time'
        peaks_stacked = get_peaks(da_stacked, min_dist=min_dist, dim='time', chunks=chunks2)
        peaks_am_stacked = peaks_stacked.groupby('time.year').max('time')
        peaks_am_stacked = peaks_am_stacked.rename({'year': 'time'})

        # fit a Gumbel distribution (lmoments3 `distr.gum`) for the return
        # periods in rps_out
        ds_rp_stacked = xlm_fit(peaks_am_stacked, fdist=distr.gum, rp=rps_out)

        # combine annual maxima and fit result and bring them back on the
        # (lat, lon) grid of the mask
        peaks_am_stacked.name = 'sfcelv_am'
        ds_out = xr.merge([peaks_am_stacked, ds_rp_stacked])
        ds_out = ds_out.unstack().reindex_like(mask).chunk(chunks)

        fn_out = join(scen_dir, 'rp_sfcelv.nc')
        print(basename(fn_out))
        ds_out.to_netcdf(fn_out)
        break

In [None]:
import sys
import os
sys.path.append(os.path.abspath('../4-analyze/'))
from plot_tools import *

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cartopy.crs as ccrs
import seaborn as sns
# Colour range shared by the difference plots and the matching tick positions.
vmin, vmax, n = -0.2, 1.0, 7
cticks = np.linspace(vmin, vmax, n)

# Positions at which the turbo colour data is sampled: int(-vmin*20) samples
# ending at 0.55 for the negative part of the range, followed by
# int(vmax*20) samples from 0.6 to 1 for the positive part.
_x_neg = np.linspace(0.55 - (.4 * abs(vmin) / vmax) * 2, .55, int(-vmin * 20))
_x_pos = np.linspace(0.6, 1, int(vmax * 20))
cmap_turbo_div = ListedColormap(
    [interpolate_cmap(google_turbo_data, x) for x in np.hstack([_x_neg, _x_pos])]
)

In [None]:
# Display the content of the return-level file of the 'cmpnd' scenario.
xr.open_dataset(join(ddir, 'anu_mswep_cmpnd_v362_1980-2014', 'rp_sfcelv.nc'))

In [None]:
# Return-level files of both scenarios, keyed by scenario name.
_rp_paths = {
    s: join(ddir, f'anu_mswep_{s}_v362_1980-2014', 'rp_sfcelv.nc')
    for s in ('cmpnd', 'runoff')
}
swe_cmpnd = xr.open_dataset(_rp_paths['cmpnd'])['sfcelv']
swe_runoff = xr.open_dataset(_rp_paths['runoff'])['sfcelv']
# 'cmpnd' minus 'runoff' water level
swe_diff = swe_cmpnd - swe_runoff

In [None]:
# Bounding boxes as (xmin, ymin, xmax, ymax); the trailing comment gives the
# width x height of each box.
regions = dict(
    A=(-90.0, 41.0, -50.0, 65.0),   # 40 x 24
    B=(-12.0, 47.0, 18.0, 65.0),    # 30 x 18
    C=(70.0, -10.0, 130.0, 26.0),   # 60 x 36
)
# region used by the cells below
name = 'A'
xmin, ymin, xmax, ymax = regions[name]


In [None]:
import subprocess
import rasterio
from rasterio.transform import from_origin
import os 

map_dir = r'/home/dirk/models/cama-flood_bmi_v3.6.2_nc/map'
sdir = join(map_dir, 'downscale_flddph')
# NOTE: this changes the working directory of the kernel for all later cells.
os.chdir(sdir)

# coarse elevation map with x/y renamed to lon/lat
fn_elevtn = join(sdir, 'map', 'elevtn.tif')
rm = {'x':'lon', 'y':'lat'}
elevtn = xr.open_rasterio(fn_elevtn).drop('band').squeeze().rename(rm)

# Table of high-resolution areas; after transposing, each area becomes a dict
# of float attributes (the code below reads 'nx', 'ny', 'west', 'north', 'csize').
# sep=r'\s+' is the documented equivalent of the deprecated delim_whitespace=True.
fn_regions = join(sdir, 'map', 'hires', 'location.txt')
regions = pd.read_csv(fn_regions, sep=r'\s+', index_col=0).T \
            .set_index('area').astype(float).to_dict(orient='index')
regions = {k:regions[k] for k in ['af1'] }

# flood depth = water surface elevation minus ground elevation, floored at zero
fn = join(ddir, 'anu_mswep_runoff_v362_1980-2014', 'ev_map_sfcelv.nc')
swe_runoff = xr.open_dataset(fn)['sfcelv_ev']
flddph = swe_runoff - elevtn
flddph = xr.where(flddph<0,0,flddph)

# value written for missing cells in the binary input; cells equal to it in
# the downscaled result are mapped to -9999 below
mv=1e+20

# make sure the folder for the GeoTIFFs exists before rasterio writes to it
dir_out_tif = join(ddir, 'anu_mswep_runoff_v362_1980-2014', 'flddph')
os.makedirs(dir_out_tif, exist_ok=True)

# the first two entries of the T dimension are skipped
for T in flddph['T'].values[2:]:
    print(f'rp: {T:003.1f}')
    # select first, then fill: avoids filling the full cube on every iteration
    data = flddph.sel(T=T).fillna(mv).data

    # write the coarse flood depth as flat float32 binary for the downscaling tool
    fn_out_bin = join(sdir, f'flddph_T{T:03.0f}')
    data.astype('f4').tofile(fn_out_bin)

    try:
        for area in regions.keys():
            print(area)
            msg = ['./downscale_flddph', str(area), basename(fn_out_bin), '1']
            subprocess.call(msg, cwd=sdir, stderr=subprocess.STDOUT)

            # the tool is expected to leave its result in '<area>.flood'
            fn_fld = join(sdir, '{:s}.flood'.format(area))
            if os.path.isfile(fn_fld):
                ny, nx = int(regions[area]['ny']), int(regions[area]['nx'])
                # read by file name: the file is binary and must not be
                # opened in text mode
                data = np.fromfile(fn_fld, dtype='f4').reshape(ny, nx)
                data = np.where(data==mv, -9999, data)

                # write to geotiff
                fn_out_tif = join(dir_out_tif, f'{area}_T{T:03.0f}.tif')
                west, north, csize = regions[area]['west'], regions[area]['north'], regions[area]['csize']
                transform = from_origin(west, north, csize, csize)
                with rasterio.open(fn_out_tif, 'w', driver='GTiff', height=data.shape[0],
                            compress='lzw', width=data.shape[1], count=1, dtype=str(data.dtype),
                            crs='+proj=latlong', transform=transform, nodata=-9999) as dst:
                    dst.write(data, 1)

                # remove binary output
                os.unlink(fn_fld)
            else:
                # no output produced: show the command for manual inspection
                print(' '.join(msg))
    finally:
        # always remove the temporary binary input, also when an area fails
        os.unlink(fn_out_bin)
    # NOTE(review): this stops after the first processed return period, while
    # later cells read tiles of other return periods -- confirm whether the
    # break is a debugging leftover.
    break

In [None]:
import os
def build_vrt(fn_vrt, fns, bbox=None, nodata=None):
    """Build a GDAL virtual raster (VRT) file from a set of source rasters.

    Parameters
    ----------
    fn_vrt : str
        Path of the VRT file to write. An existing file is replaced.
    fns : str or iterable of str
        Source raster file name(s), handed to ``gdal.BuildVRT``.
    bbox : sequence of float, optional
        Passed as ``outputBounds`` to ``gdal.BuildVRTOptions``.
    nodata : float, optional
        Passed as both ``srcNodata`` and ``VRTNodata``.

    Raises
    ------
    ValueError
        If `fns` is empty. This is checked before anything on disk is
        touched, so an existing VRT is not deleted for an empty input list.
    RuntimeError
        If ``gdal.BuildVRT`` returns None.
    """
    # materialise iterables so the emptiness check cannot consume them
    if not isinstance(fns, str):
        fns = list(fns)
    if not fns:
        raise ValueError('No input files given to build "{}".'.format(fn_vrt))

    from osgeo import gdal
    vrt_options = gdal.BuildVRTOptions(
        outputBounds=bbox,
        srcNodata=nodata,
        VRTNodata=nodata
    )
    if os.path.isfile(fn_vrt):
        os.remove(fn_vrt)
    vrt = gdal.BuildVRT(fn_vrt, fns, options=vrt_options)
    if vrt is None:
        raise RuntimeError('Creating vrt not successfull, check input files.')
    vrt = None # to write to disk

In [None]:
# Mosaic all '*_T005.tif' tiles in the flddph folder into one VRT.
_dir_flddph = join(ddir, 'anu_mswep_runoff_v362_1980-2014', 'flddph')
fns = glob.glob(join(_dir_flddph, '*_T005.tif'))
fn_vrt = join(_dir_flddph, 'T005.vrt')
bbox = (-180, -90, 180, 90)
nodata = -9999
build_vrt(fn_vrt, fns, bbox=bbox, nodata=nodata)

In [None]:
# Stack the VRT mosaics of the listed return periods along a new 'T'
# dimension and store the result as compressed NetCDF.
rps = [2, 5]
chunks = dict(x=7200, y=7200)
ds_lst = []
for T in rps:
    fn_vrt = join(ddir, 'anu_mswep_runoff_v362_1980-2014', 'flddph', f'T{T:03.0f}.vrt')
    # open lazily, drop the band coordinate and rename x/y to lon/lat
    _da_T = xr.open_rasterio(fn_vrt, chunks=chunks).drop('band').squeeze()
    ds_lst.append(_da_T.rename({'x': 'lon', 'y': 'lat'}))

flddph = xr.concat(ds_lst, dim='T')
flddph.coords['T'] = xr.Variable('T', rps)
flddph.name = 'flddph'

fn_out = join(ddir, 'anu_mswep_runoff_v362_1980-2014', 'flddph_highres.nc')
flddph.to_netcdf(fn_out, encoding={'flddph': {'zlib': True}})

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15))
# NOTE: this rebinds `flddph` to the single raster behind `fn_vrt` and
# overwrites xmin/ymin/xmax/ymax for the cells that follow.
flddph = xr.open_rasterio(fn_vrt).squeeze().rename(dict(x='lon', y='lat'))
xmin, ymin, xmax, ymax = (-11.55, 49.74, 2.38, 59.64)
# lat is sliced from ymax down to ymin
flddph_sel = flddph.sel(lon=slice(xmin, xmax), lat=slice(ymax, ymin))
# blank out the nodata value before plotting
flddph_sel = flddph_sel.where(flddph_sel != -9999)
flddph_sel.plot(vmin=0, vmax=2, cmap='Blues')

In [None]:
%matplotlib inline 


fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 10))
# NOTE(review): when the notebook is run top to bottom, `flddph` at this point
# appears to be the single raster opened from `fn_vrt`, which has no 'T'
# dimension, and the stacked high-resolution array only holds T = [2, 5].
# Confirm which `flddph` this cell is meant to plot.
flddph.sel(T=50.).plot(vmin=0, vmax=4, ax=ax)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# subset of the water level difference for the current bounding box
# (lat is sliced from ymax down to ymin)
swe_diff0 = swe_diff.sel(lat=slice(ymax, ymin), lon=slice(xmin, xmax))
# Hide differences of at most 0.01 in absolute value. The mask is built from
# the subset itself (not the global array) and with the numpy ufunc, which
# works on DataArrays directly; `xr.ufuncs` is deprecated/unavailable in many
# xarray releases.
swe_diff0.where(np.fabs(swe_diff0) > 0.01).sel(T=100).plot(ax=ax, vmin=vmin, vmax=vmax, cmap=cmap_turbo_div)

In [None]:
# Load the regional return-level files of both scenarios and stack them along
# a new 'scen' dimension, in the order 'cmpnd', 'runoff'.
name = 'A'
_fns_region = [
    join(ddir, f'anu_mswep_{s}_v362_1980-2014', f'region{name}_rp.nc')
    for s in ('cmpnd', 'runoff')
]
ds = xr.concat([xr.open_dataset(fn_region) for fn_region in _fns_region], dim='scen')
ds

In [None]:
%matplotlib inline 
# diff('scen') yields second minus first scenario, i.e. 'runoff' - 'cmpnd' for
# the concat order used to build `ds`; the leading minus turns that into
# 'cmpnd' - 'runoff'. Only values > 0 take part in the difference.
da_swe_diff = -ds['sfcelv'].where(ds['sfcelv']>0).diff('scen').squeeze()
# Hide differences of at most 0.01 in absolute value. numpy ufuncs work on
# DataArrays directly; `xr.ufuncs` is deprecated/unavailable in many xarray
# releases.
da_swe_diff = da_swe_diff.where(np.fabs(da_swe_diff) > 0.01)
da_swe_diff.sel(T=20.).plot(cmap=cmap_turbo_div, vmin=vmin, vmax=vmax)