In [None]:
import xarray
import numpy
import pandas
import pathlib
import time
import geopandas as gpd
import matplotlib.pyplot as plt

In [None]:
import matplotlib as mpl

# Figure styling for the notebook. All sizes are given as explicit point
# values; matplotlib also accepts sizes relative to the base font size
# ("xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large",
# "larger", "smaller").
mpl.rc('font', size=11, family='sans-serif', weight='normal', style='normal')
mpl.rc('axes', titlesize=16, labelsize=12)
mpl.rc('figure', titlesize=16)
mpl.rc('legend', fontsize=12)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
mpl.rc('text', usetex=False)

In [None]:
# function to reshape FVCOM array
def reshape_fvcom(fvcom_timeIJK, reshape_type):
    """ Reshape 2D (time, nodes) FVCOM output so that it can be reduced
    by day and/or by depth level.

    The target shape depends on ``reshape_type``:
      'days'       -> (365, 24, nodes); needs a time dimension of 8760
      'levels'     -> (time, 16012, 10); needs a node dimension of 160120
      'dayslevels' -> (365, 24, 16012, 10); needs both of the above

    param float fvcom_timeIJK: 2D FVCOM output array (time x nodes). Must
        provide ``.shape`` and, after ``[:,:]`` slicing, ``.data``.
    param string reshape_type: ['days','levels','dayslevels']
    return: Reorganized array
    raises ValueError: if reshape_type is not one of the three options
    raises TypeError: if the array size does not match the chosen option
    """
    n_time, n_nodes = fvcom_timeIJK.shape
    print(n_time, n_nodes)

    # Reject unknown options before looking at the array size
    if reshape_type not in ('days', 'levels', 'dayslevels'):
        raise ValueError(
            "options for reshape_type are: 'days','levels','dayslevels'"
        )

    time_ok = n_time == 8760
    nodes_ok = n_nodes == 160120

    # Pick the target shape for the requested layout
    if reshape_type == 'days':
        if not time_ok:
            raise TypeError(
                "FVCOM array must have a time dimension of 8760"
            )
        target_shape = (365, 24, n_nodes)
    elif reshape_type == 'levels':
        if not nodes_ok:
            raise TypeError(
                "FVCOM array must have a node dimension of 160120"
            )
        target_shape = (n_time, 16012, 10)
    else:  # 'dayslevels'
        if not (time_ok and nodes_ok):
            raise TypeError(
                "FVCOM array size must be 8760 x 160120"
            )
        target_shape = (365, 24, 16012, 10)

    return numpy.reshape(fvcom_timeIJK[:, :].data, target_shape)
def calc_fvcom_stat(fvcom_output, stat_type, axis):
    """ Reduce model output along one axis with a numpy function chosen by name.

    param float fvcom_output: model output array.
    param string stat_type: name of a numpy function that accepts an ``axis``
        keyword, e.g. 'min' or 'mean'.
    param int axis: axis of fvcom_output to apply the function over.

    return: result of ``numpy.<stat_type>(fvcom_output, axis=axis)``
    """
    # Look the function up on the numpy module (AttributeError if unknown)
    reducer = getattr(numpy, stat_type)
    return reducer(fvcom_output, axis=axis)

def extract_fvcom_level(gdf, fvcom_timeIJK, LevelNum):
    """ Extract model output at the dataframe's nodes for one depth level.

    For node ID ``n`` and level ``L`` the selected column of fvcom_timeIJK
    is ``n*10 - (11 - L)``, i.e. ``(n-1)*10 + (L-1)``: node IDs are
    1-based and each node occupies 10 consecutive columns, one per level.

    param dataframe gdf: geopandas dataframe of FVCOM nodes from 2D planar nodes
        with dimensions of 16012. Must have a 'node_id' column.
    param float fvcom_timeIJK: 3D-FVCOM output in dimensions of time x 160120.
        Must support ``[:, index_array]`` indexing (e.g. a numpy array; for an
        xarray variable pass its ``.data``).
    param int LevelNum: Integer from 1 (surface) to 10 (bottom)

    return fvcom_nodeIDs: model output at level in dimension of time x len(gdf)
    raises ValueError: if LevelNum is outside 1-10
    raises AttributeError: if the 'node_id' column cannot be read from gdf
    """
    if LevelNum not in range(1,11):
        raise ValueError("fvcom_LevelNum must be an integer value from 1-10")

    try:
        node_ids = gdf['node_id'].to_numpy()
    except Exception as err:
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit still propagate; chain the cause to keep the real error.
        raise AttributeError("missing 'node_id' column in dataframe") from err

    # Column of each requested node at the requested level
    ijk_index = node_ids * 10 - (11-LevelNum)
    fvcom_nodeIDs = fvcom_timeIJK[:,ijk_index]

    return fvcom_nodeIDs

In [None]:
# Lookup from variable name to its parameter ID ('Var_NN'); extend both
# lists together to add more variables
variable_name_list = ['DO', 'NH3', 'NO3', 'NPP', 'Temp', 'Salinity']
parameter_ID_list = ['Var_10', 'Var_14', 'Var_15', 'Var_17', 'Var_18', 'Var_19']
model_output_name = dict(zip(variable_name_list, parameter_ID_list))

# Variable that we want to plot
variable_name = "DO"

# Directory for saving netcdf output
output_directory = pathlib.Path('/mmfs1/gscratch/ssmc/USRS/PSI/Rachael/output/TS')

# Output directory of each scenario, in the order [existing, reference]
root_dir = pathlib.Path('/mmfs1/gscratch/ssmc/USRS/PSI/Adi/BS_WQM/')
data_paths = numpy.array([
    root_dir / scenario_dir / 'hotstart' / 'outputs'
    for scenario_dir in ('2014_SSM4_WQ_exist_orig', '2014_SSM4_WQ_ref_orig')
])

In [None]:
# Scenario (index into data_paths; 0 = Existing) and variable to process
si = 0
variable_name = 'DO'

# The scenario name is the third-from-last component of the outputs path
scenario_name = str(data_paths[si]).split('/')[-3]
# Output netcdf filename for this scenario and variable
output_file = output_directory / f'{scenario_name}_{variable_name}.nc'
# Input netcdf filename (model output at nodes)
file_path = data_paths[si] / 's_hy_base000_pnnl007_nodes.nc'

# Reduce the model output to one value per day and node:
# minimum over the 24 values of each day, then minimum over the 10 levels.
with xarray.open_dataset(file_path) as ds:
    # (8760, nodes) -> (365, 24, nodes)
    dailyDO = reshape_fvcom(
        ds[model_output_name[variable_name]][:, :].data, 'days'
    )
    # minimum within each day -> (365, nodes)
    dailyDO_tmin = calc_fvcom_stat(dailyDO, 'min', axis=1)
    # split the node axis -> (365, 16012, 10)
    dailyDO_tmin_rshp = reshape_fvcom(dailyDO_tmin, 'levels')
    # minimum across depth levels -> (365, 16012)
    dailyDO_tmin_zmin = calc_fvcom_stat(dailyDO_tmin_rshp, 'min', axis=2)

In [None]:
# Kevin's shapefile
shapefile_path = pathlib.Path(
    '/mmfs1/gscratch/ssmc/USRS/PSI/Rachael/projects/KingCounty/KingCounty-Rachael/kevin_shapefiles',
    'SSMGrid2_tce.shp',
)
# Keep only the columns used in this notebook and rename 'tce' to 'node_id'
gdf_k = (
    gpd.read_file(shapefile_path)
    .loc[:, ('tce','Basin','geometry')]
    .rename(columns={'tce':'node_id'})
)
# Rows of Kevin's shapefile whose Basin is 'SOG_Bays'; `gdf` is an
# independent copy so that columns can be added to it later
gdf_SOG_Nbays = gdf_k.loc[gdf_k['Basin'].eq('SOG_Bays')]
gdf = gdf_SOG_Nbays.copy()
gdf_k.head(2)

In [None]:
# dailyDO_tmin_zmin is (days, nodes) = (365, 16012), so the node IDs must
# index the second axis (indexing the first axis would select days and run
# out of bounds). Node IDs are 1-based; array positions are 0-based.
dailyDO_tmin_zmin_SOG = dailyDO_tmin_zmin[:, gdf['node_id'].to_numpy() - 1]

# Expected: (365, number of rows in gdf)
dailyDO_tmin_zmin_SOG.shape

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
gdir = pathlib.Path(
    '/mmfs1/gscratch/ssmc/USRS/PSI/Rachael/projects/CWA/CleanWaterAlliance/dev/minDO/')
fs_t=14
fs_a=12
# DO threshold [mg/l] used for the day count and the labels.
# NOTE(review): the value 5 is inferred from the 'DOlt5' column name -- confirm.
threshold = 5

# Daily minimum DO at the nodes in gdf, taken from the (days, nodes) array.
# Node IDs are 1-based, array positions 0-based.
sog_daily_min = dailyDO_tmin_zmin[:, gdf['node_id'].to_numpy() - 1]
# One value per node so that each fits in a single dataframe column.
# NOTE(review): 'minDO' is taken here as the lowest daily minimum over all
# days -- confirm this is the intended meaning of the column.
gdf['minDO'] = sog_daily_min.min(axis=0)
# Number of days on which the daily minimum DO is below the threshold
gdf['DOlt5'] = (sog_daily_min < threshold).sum(axis=0)

fig, axs = plt.subplots(1,1, figsize = (8,6))
# create `cax` for the colorbar
divider = make_axes_locatable(axs)
cax = divider.append_axes("right", size="5%", pad=0.1)
gdf.plot('DOlt5', ax=axs, cax=cax, legend=True,vmin=1, vmax=150)
cax.set_ylabel(f'Days with DO < {threshold}[mg/l]',fontsize=14)
# hide the tick labels on the map axes
axs.set(yticklabels='', xticklabels='')
#plt.savefig(gdir/f'SOGNB_minDO_lt{threshold}.jpeg',dpi=150)
plt.show()