# Compute climate indices
**Please clear all cell outputs before committing changes to git.**

In [None]:
%matplotlib inline  
import os, sys
import xarray as xr
import geopandas as gpd
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import scripts.clim_indices as ci
import scripts.utility as util

# Report the interpreter version (major.minor.micro) for reproducibility.
print("\nThe Python version: {}.{}.{}".format(*sys.version_info[:3]))

-------------------------
## 1. Setup <a id='setup'></a>

In [None]:
setup = util.load_yaml("./climate_index_setup.yaml")

catchment_name = setup["catchment_name"]  # Catchment case: CONUS_HUC12 or camels
serial = setup["serial"]  # True: run serially (no Dask cluster); False: start a Dask PBS cluster
saveCSV = setup["saveCSV"]  # True: save attributes for each HRU in csv
saveNetCDF = setup["saveNetCDF"]  # True: save attributes for each HRU in netcdf
remap = setup["remap"]  # True: remap meteorological time series to HRUs

# files and directories
src_dir = setup["src_dir"]  # directory containing the climate *.nc files
catch_gpkg = setup["catch_gpkg"]  # catchment polygons (used for plotting)
mapping_file = setup["mapping_file"]  # remapping weights (used only when remap is True)
aggregate_daily = setup["aggregate_daily"]  # True: resample the climate data to daily means

# climate variable meta: one entry per variable key (e.g. 'tair', 'precp', 'hru_id'),
# each holding the NetCDF variable 'name' plus the 'scale' and 'offset' applied on read.
variables = setup["climate_vars"]
nc_var = [meta['name'] for meta in variables.values()]  # NetCDF names kept when reading
remap_variables = variables  # NOTE(review): alias not used elsewhere in this notebook

# catchment meta: per-catchment_name settings, e.g. the polygon 'id' column
catch_attrs = setup["catch_attrs"]

print('-- Setup:')
print(f" Dask not enabled: {serial}")
print(f" catchment_name: {catchment_name}")
print(f" saveCSV: {saveCSV}")
print(f" saveNetCDF: {saveNetCDF}")
print(f" remap: {remap}")
print(f" climate data directory: {src_dir}")
print(f" catchment gpkg: {catch_gpkg}")
print(f" mapping file: {mapping_file}")
print(f" aggregate daily: {aggregate_daily}")

## Dask (optional distributed cluster)

In [None]:
# Start a Dask cluster on PBS only when the setup asks for non-serial execution.
# In serial mode `client` stays None; the bare expression at the end displays it.
client = None

if not serial:
    from dask.distributed import Client
    from dask_jobqueue import PBSCluster

    # Resources requested per PBS job.
    job_resources = dict(
        queue="casper",
        walltime="00:30:00",
        memory="50GB",
        cores=1,
        processes=1,
    )
    cluster = PBSCluster(**job_resources)
    cluster.scale(jobs=15)  # ask the scheduler for 15 PBS jobs
    client = Client(cluster)

client

## Reading climate data

In [None]:
def preprocess(ds):
    """Subset one file's dataset to the configured variables and rescale them.

    Used as the ``preprocess`` hook of ``xr.open_mfdataset``. Only the NetCDF
    variables listed in the module-level ``nc_var`` are kept. Each variable in
    ``variables`` is then transformed as ``value * scale + offset``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened from a single climate NetCDF file.

    Returns
    -------
    xarray.Dataset
        The subset dataset with scaled/offset variables.
    """
    subset = ds[nc_var]
    for meta in variables.values():
        name = meta['name']
        subset[name] = subset[name] * meta['scale'] + meta['offset']
    return subset

In [None]:
%%time
print(f'Reading climate data')
a = xr.open_mfdataset(os.path.join(src_dir, f'*.nc'), data_vars='minimal', preprocess=preprocess, parallel=True)    # WARNING: read all the netcdfs!!
if aggregate_daily:
    a = a.resample(time='D').mean()
    a[variables['hru_id']['name']] = a[variables['hru_id']['name']].isel(time=0, drop=True)
a = a.load()

## Re-mapping
- Remap the 7 climate variables to HRUs (only when `remap` is true in the setup).

**TODO: revise the YAML file to support reading gridded climate data. Currently the code reads HRU data, assuming the data have already been remapped.**

In [None]:
%%time
if remap:
    a = util.regrid_mean_timeSeries(xr.open_dataset(mapping_file), a, 
                                    xr.where(np.isnan(a[variables['tair']['name']].isel(time=0)),0,1), 
                                    list(variables.keys()))

In [None]:
# Use the HRU ids, cast to 64-bit integers, as the 'hru' coordinate.
hru_ids = a[variables['hru_id']['name']].astype(np.int64)
a = a.assign_coords(hru=hru_ids)

## Computing climate indices

In [None]:
%%time
pe = ci.Penman(a[variables['sw']['name']], a[variables['lw']['name']], a[variables['wind']['name']], a[variables['tair']['name']], a[variables['q']['name']], a[variables['p']['name']])

In [None]:
# Time-mean precipitation starts the output dataset; time-mean PET is added next to it.
precip = a[variables['precp']['name']]
b = precip.mean(dim='time').to_dataset(name='p_mean')
b['pe_mean'] = pe.mean(dim='time')

In [None]:
%%time
b = xr.merge([b, ci.seasonality_index(a[variables['tair']['name']], a[variables['precp']['name']])])

In [None]:
# Aridity index = mean PET / mean precipitation.
# Reuse the time-means already stored in b rather than reducing both full time series again.
b['aridity'] = b['pe_mean'] / b['p_mean']

In [None]:
%%time
ds1 = ci.high_p_freq_dur(a[variables['precp']['name']]) #, dayofyear='calendar'
b['high_prec_freq'] = ds1['high_prec_freq'].mean(dim='year')
b['high_prec_dur'] = ds1['high_prec_dur'].mean(dim='year')
# Apply the mode function along the 'year' dimension
b['high_prec_timing'] = xr.apply_ufunc(
    util.mode_func,
    ds1['high_prec_timing'],
    input_core_dims=[['year']],   # Specify the dimension along which to apply the function
    vectorize=True
)

In [None]:
%%time
ds2 = ci.low_p_freq_dur(a[variables['precp']['name']])
b['low_prec_dur'] = ds2['low_prec_dur'].mean(dim='year')
b['low_prec_freq'] = ds2['low_prec_freq'].mean(dim='year')
b['low_prec_timing'] = xr.apply_ufunc(
    util.mode_func,
    ds2['low_prec_timing'],
    input_core_dims=[['year']],   # Specify the dimension along which to apply the function
    vectorize=True
)

## Dataset to DataFrame

In [None]:
# Flatten the index dataset into a table: one row per coordinate of b, one column per index.
index_table = b.to_dataframe()
df = index_table

## Save as CSV or NetCDF

In [None]:
# Write the computed indices to ./output.
# The directory is created first because to_csv/to_netcdf do not create missing folders.
out_dir = 'output'
if saveCSV or saveNetCDF:
    os.makedirs(out_dir, exist_ok=True)
if saveCSV:
    # '%g' writes floats with 6 significant digits to keep the CSV compact.
    df.to_csv(os.path.join(out_dir, f'{catchment_name}_clim_test.csv'), float_format='%g')
if saveNetCDF:
    b.to_netcdf(os.path.join(out_dir, f'{catchment_name}_clim_test.nc'))

## Plotting

In [None]:
# Catchment polygons from the configured gpkg.
# The id column name comes from catch_attrs for this catchment case.
id_column = catch_attrs[catchment_name]['id']
gdf_camels = util.read_shps([catch_gpkg], [id_column])

In [None]:
# Join the index table (indexed by HRU id, presumably) onto the polygons via their id column.
# how='inner' matches the merge default: polygons without indices are dropped.
polygon_id = catch_attrs[catchment_name]['id']
gdf_camels = gdf_camels.merge(df, how='inner', left_on=polygon_id, right_index=True)

In [None]:
# Map of mean precipitation per catchment, with fixed color limits (vmin, vmax).
var_name = 'p_mean'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(0, 5),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}.png', dpi=300)

In [None]:
# Map of mean potential evapotranspiration per catchment, with fixed color limits (vmin, vmax).
var_name = 'pe_mean'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(0, 5),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}.png', dpi=300)

In [None]:
# Map of the precipitation seasonality index, with a symmetric color range [-1, 1].
var_name = 'p_seasonality'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(-1, 1),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}.png', dpi=300)

In [None]:
# Map of the snow fraction per catchment, with fixed color limits (vmin, vmax).
var_name = 'snow_frac'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(0, 0.6),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}.png', dpi=300)

In [None]:
# Map of the aridity index (mean PET / mean precipitation), with fixed color limits (vmin, vmax).
var_name = 'aridity'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(0.20, 3.0),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}.png', dpi=300)

In [None]:
# Map of the mean high-precipitation duration, with fixed color limits (vmin, vmax).
var_name = 'high_prec_dur'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(1.0, 1.8),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}_new.png', dpi=300)

In [None]:
# Map of the mean high-precipitation frequency, with fixed color limits (vmin, vmax).
var_name = 'high_prec_freq'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(5, 25),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}_new.png', dpi=300)

In [None]:
# Map of the mean low-precipitation (dry spell) duration, with fixed color limits (vmin, vmax).
var_name = 'low_prec_dur'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(1.0, 30),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}_new.png', dpi=300)

In [None]:
# Map of the mean low-precipitation frequency, with fixed color limits (vmin, vmax).
var_name = 'low_prec_freq'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
gdf_camels.plot(ax=ax, column=var_name, cmap='turbo',
                norm=colors.Normalize(200, 365),
                legend=True)
ax.set_title(var_name)
fig.tight_layout()
# savefig does not create directories, so make sure ./figures exists.
os.makedirs('./figures', exist_ok=True)
fig.savefig(f'./figures/{catchment_name}_clim_{var_name}_new.png', dpi=300)