### Imports and configuration

In [None]:
import numpy as np
import os
import pandas as pd
import rasterio as rio

from glob import glob
from pathlib import Path

### Helper function

In [None]:
def couple_transition_anim(surge_inun, riv_inun, ratio_array, mask, alpha):
    """Classify each cell into an inundation zone and wrap it in a masked array.

    Zone codes: 0 = coastal, 1 = hydrologic, 2 = transition. Cells that meet
    none of the conditions are NaN. The conditions are tested in that order
    and the first one that matches wins.

    Parameters
    ----------
    surge_inun : array_like
        Surge inundation values (compared against 0 to decide "wet").
    riv_inun : array_like
        Riverine inundation values, same shape as ``surge_inun``.
    ratio_array : array_like
        Cell-wise ratio of surge to riverine inundation. The caller in this
        notebook passes ``surge_inun / riv_inun``.
    mask : array_like of bool
        Mask applied to the returned array (True = masked).
    alpha : float
        Scale factor on ``riv_inun`` that sets the hydrologic threshold.

    Returns
    -------
    numpy.ma.MaskedArray
        Float array of zone codes (NaN where unclassified), masked by ``mask``.
    """
    surge_wet = surge_inun != 0
    river_wet = riv_inun != 0
    hydro_limit = alpha * riv_inun

    is_coastal = (ratio_array > 1) & surge_wet
    is_hydro = (ratio_array <= hydro_limit) & river_wet
    is_transition = (
        (ratio_array <= 1) & (ratio_array > hydro_limit) & surge_wet & river_wet
    )

    # np.select takes the first condition that holds, matching the priority of
    # a nested np.where; unmatched cells fall through to the NaN default.
    zones = np.select(
        [is_coastal, is_hydro, is_transition],
        [0, 1, 2],
        default=np.nan,
    )
    return np.ma.masked_array(zones, mask)

### Classify inundation zones and save data for plots

In [None]:
# For every (riverine raster, surge raster, alpha) combination: classify the
# inundation zones, record the pixel count of each zone, and write the
# classified raster to disk. Paths are built with pathlib so the cell does not
# depend on Windows path separators.
sensitivity_dir = Path('data_github') / 'sensitivity_data'
transformed_dir = sensitivity_dir / 'transformed'
coupled_dir = sensitivity_dir / 'coupled'
coupled_dir.mkdir(parents=True, exist_ok=True)

surge_paths = sorted(glob(str(transformed_dir / '*_m_gdal.tif')))
riv_paths = sorted(glob(str(transformed_dir / '*_map_gdal.tif')))
if not surge_paths or not riv_paths:
    raise FileNotFoundError(
        f"no input rasters found in {transformed_dir} "
        f"({len(surge_paths)} surge, {len(riv_paths)} riverine)"
    )
# Characters 7-16 of each riverine file name hold its datetime stamp
# (YYYYMMDDHH).
riv_names = [Path(riv_path).name[7:17] for riv_path in riv_paths]

alphalist = [0.1, 0.4]

cols = ['riv_datetime', 'surge_h', 'alpha_val', 'n_coastal', 'n_hydro', 'n_trans']
df = pd.DataFrame(columns=cols)

# One row per (riv_datetime, surge_h, alpha_val) combination, in the same order
# as the nested loops below (riverine outermost, alpha innermost).
df['riv_datetime'] = np.repeat(riv_names, len(surge_paths) * len(alphalist))

tmp_path = sensitivity_dir / 'tmp.csv'
msk = None    # boolean nodata mask (True = nodata) from the first surge raster
w_msk = None  # same mask as a uint8 band: 255 = valid, 0 = nodata
i = 0
for riv_name, riv_path in zip(riv_names, riv_paths):
    riv_inun = None  # release the previous raster before reading the next one
    with rio.open(riv_path) as ds_in:
        riv_inun = ds_in.read(1, masked=True)
    riv_inun[riv_inun > 3.4e+38] = np.nan  # set nodata
    riv_inun[riv_inun < 0] = np.nan

    for surge_path in surge_paths:
        surge_name = Path(surge_path).name
        # surge-height code sliced from the file name (two characters)
        surge_h_val = surge_name[-13:-11]

        # open and clean raster datasets
        with rio.open(surge_path) as ds_in:
            surge_inun = ds_in.read(1, masked=True)
            surge_inun_profile = ds_in.profile
        if w_msk is None:
            # Capture the nodata mask once, before the NaN assignments below
            # can alter it. getmaskarray always returns a full boolean array,
            # even when the raster has no masked cells (mask == nomask).
            # NOTE(review): assumes every surge raster shares this footprint.
            msk = np.ma.getmaskarray(surge_inun).copy()
            w_msk = np.where(msk, 0, 255).astype('uint8')
        surge_inun[surge_inun > 3.4e+38] = np.nan  # set nodata
        surge_inun[surge_inun < 0] = np.nan

        if surge_inun.shape != riv_inun.shape:
            raise ValueError(
                f"raster shape mismatch: {surge_path} {surge_inun.shape} "
                f"vs {riv_path} {riv_inun.shape}"
            )

        ratio_array = surge_inun / riv_inun

        for alpha in alphalist:
            coupled = couple_transition_anim(
                surge_inun=surge_inun,
                riv_inun=riv_inun,
                ratio_array=ratio_array,
                mask=msk,
                alpha=alpha,
            )

            # Count pixels per zone code among unmasked cells. Counting by
            # value keeps each column correct when a zone is absent, and
            # unclassified (NaN) cells are not counted in any zone.
            zone_values = coupled.compressed()
            df.loc[i, 'surge_h'] = surge_h_val
            df.loc[i, 'alpha_val'] = alpha
            df.loc[i, 'n_coastal'] = int(np.count_nonzero(zone_values == 0))
            df.loc[i, 'n_hydro'] = int(np.count_nonzero(zone_values == 1))
            df.loc[i, 'n_trans'] = int(np.count_nonzero(zone_values == 2))

            # checkpoint progress in case the run is interrupted
            df.to_csv(tmp_path, sep=',')

            alpha_tag = str(alpha).replace('.', '')
            out_path = coupled_dir / f"{riv_name}_{surge_h_val}_{alpha_tag}.tif"
            with rio.Env(GDAL_TIFF_INTERNAL_MASK=True):
                with rio.open(str(out_path), 'w', **surge_inun_profile) as ds_out:
                    ds_out.write(coupled, 1)
                    ds_out.write_mask(w_msk)

            del coupled
            i += 1

        # free the large arrays before the next surge raster is read
        del surge_inun, ratio_array

# save final dataframe
df.to_csv(sensitivity_dir / 'full_sensitivity_analysis.csv', sep=',')
