# Vegetation indices
**Please clear all cell outputs before committing changes to git.**
- use MODIS

In [None]:
%matplotlib inline  
import os, sys
import xarray as xr
import geopandas as gpd
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from scipy.optimize import fsolve

import scripts.utility as util

# Report the interpreter version the notebook is running on.
major, minor, micro = sys.version_info[:3]
print("\nThe Python version: %s.%s.%s" % (major, minor, micro))

-------------------------
## 1. Setup <a id='setup'></a>

In [None]:
# Load the notebook configuration from YAML.
setup = util.load_yaml("./vege_index_setup.yaml")

# Run options
catchment_name = setup["catchment_name"]  # catchment case, e.g. CONUS_HUC12, camels, camelsx-v3
# Echoed below under the label "Dask not enabled".
# NOTE(review): an earlier comment said "True: dask distributed enabled", which
# contradicts that label - confirm which meaning is intended.
serial = setup["serial"]
saveCSV = setup["saveCSV"]  # True: write the per-HRU attributes to csv
saveNetCDF = setup["saveNetCDF"]  # True: write the per-HRU attributes to netcdf
remap = setup["remap"]  # remapping switch read from the setup file

# Input files and directories
src_dir = setup["src_dir"]
catch_gpkg = setup["catch_gpkg"]
mapping_file = setup["mapping_file"]

# Vegetation data metadata
nc_vege_var = setup["nc_var"]
IGBP = setup["IGBP"]

# Catchment metadata (indexed by catchment_name; holds at least the 'id' column name)
catch_attrs = setup["catch_attrs"]

# Echo the configuration; every item ends up on its own output line.
print(
    '-- Setup:',
    f" Dask not enabled: {serial}",
    f" catchment_name: {catchment_name}",
    f" saveCSV: {saveCSV}",
    f" saveNetCDF: {saveNetCDF}",
    f" remap: {remap}",
    f" vegetation data directory: {src_dir}",
    f" catchment gpkg: {catch_gpkg}",
    f" mapping_file: {mapping_file}",
    sep="\n",
)

## Reading hru geopackage

In [None]:
%%time
# Name of the geopackage column that identifies each HRU for this catchment case.
hru_id_col = catch_attrs[catchment_name]['id']

# Read the catchment polygons, keeping the id column.
gdf_catch = util.read_shps([catch_gpkg], [hru_id_col])
# Declare the coordinates as EPSG:4326 (set_crs labels the data, it does not reproject).
gdf_catch.set_crs(epsg=4326, inplace=True)
# Use a common 'hruId' index regardless of the source column name.
gdf_catch = gdf_catch.rename(columns={hru_id_col: 'hruId'}).set_index('hruId')

## Reading MODIS landcover and lai
- IGBP land cover class MCD12Q1 (aggregated 600m resolution)
- LAI MOD15A2 (aggregated 600m resolution)

In [None]:
%%time
print('Reading MODIS data')
# load_dataset reads the whole file into memory and closes the file handle,
# whereas open_dataset(...).load() leaves the handle open.
ds = xr.load_dataset(os.path.join(src_dir, 'veg_NLDAS_float.nc'))

# Land-cover codes <= 0 are treated as missing.
ds['landcover'] = ds['landcover'].where(ds['landcover'] > 0, np.nan)

# Binary forest flag: 1 where the class is 1-5, 0 everywhere else.
is_forest = (ds['landcover'] >= 1) & (ds['landcover'] <= 5)
ds['forest'] = ds['landcover'].where(is_forest, 0)
ds['forest'] = ds['forest'].where(ds['forest'] == 0, 1)
# NOTE(review): cells with missing land cover fail the 1-5 test and therefore
# get forest = 0 (not NaN), so a NaN-based mask built from 'forest' never
# excludes them. Confirm this is the intended treatment of no-data cells.

In [None]:
## Compute root depth at the 50th and 99th percentile based on IGBP class

def func(x, a, b, Y):
    """Residual of the two-exponential cumulative root-fraction profile.

    Returns ``exp(a*x) + exp(b*x) + 2*(Y - 1)``, which is zero exactly when
    ``Y == 1 - 0.5*(exp(a*x) + exp(b*x))``.

    Parameters
    ----------
    x : float or array-like
        Depth coordinate; the roots solved for below are negative.
    a, b : float
        Shape coefficients of the two exponential terms.
    Y : float
        Target cumulative root fraction (0.5 and 0.99 are used below).
    """
    return np.exp(a*x) + np.exp(b*x) + 2.0*(Y-1.0)


def _solve_root_depth(a, b, Y, x0=-0.1):
    """Return x such that func(x, a, b, Y) == 0, raising if fsolve fails."""
    solution, _info, ier, message = fsolve(func, x0, args=(a, b, Y), full_output=True)
    # ier == 1 is fsolve's only success code; without this check a failed
    # solve would just emit a warning and a wrong depth would flow downstream.
    if ier != 1:
        raise RuntimeError(
            f"root-depth solve did not converge for a={a}, b={b}, Y={Y}: {message}"
        )
    return solution[0]


# Per-class profile parameters, keyed by land-cover class code.
#   a, b     : coefficients passed to func
#   dr       : presumably a reference rooting depth - TODO confirm (not used in this cell)
#   y05, y99 : filled in below with the solved x for Y=0.5 and Y=0.99
# Classes without a root profile keep NaN everywhere.
root_param = {
    1: {'a':6.706, 'b':2.175, 'dr':1.8, 'y05':np.nan, 'y99':np.nan},      # Evergreen needleleaf tree
    2: {'a':7.344, 'b':1.303, 'dr':3.0, 'y05':np.nan, 'y99':np.nan},      # Evergreen broadleaf tree
    3: {'a':7.066, 'b':1.953, 'dr':2.0, 'y05':np.nan, 'y99':np.nan},      # Deciduous needleleaf tree
    4: {'a':5.990, 'b':1.955, 'dr':2.0, 'y05':np.nan, 'y99':np.nan},      # Deciduous broadleaf tree
    5: {'a':4.453, 'b':1.631, 'dr':2.4, 'y05':np.nan, 'y99':np.nan},      # Mixed forest
    6: {'a':6.326, 'b':1.567, 'dr':2.5, 'y05':np.nan, 'y99':np.nan},      # Closed shrubland
    7: {'a':7.718, 'b':1.262, 'dr':3.1, 'y05':np.nan, 'y99':np.nan},      # Open shrubland
    8: {'a':7.604, 'b':2.300, 'dr':1.7, 'y05':np.nan, 'y99':np.nan},      # Woody Savanna
    9: {'a':8.235, 'b':1.627, 'dr':2.4, 'y05':np.nan, 'y99':np.nan},      # Savanna
    10:{'a':10.74, 'b':2.608, 'dr':1.5, 'y05':np.nan, 'y99':np.nan},      # Grassland
    11:{'a':np.nan, 'b':np.nan, 'dr':np.nan, 'y05':np.nan, 'y99':np.nan}, # Permanent wetland
    12:{'a':5.558, 'b':2.614, 'dr':1.5, 'y05':np.nan, 'y99':np.nan},      # Cropland
    13:{'a':5.558, 'b':2.614, 'dr':1.5, 'y05':np.nan, 'y99':np.nan},      # Urban and built-up land
    14:{'a':5.558, 'b':2.614, 'dr':1.5, 'y05':np.nan, 'y99':np.nan},      # Cropland/natural vegetation
    15:{'a':np.nan, 'b':np.nan, 'dr':np.nan, 'y05':np.nan, 'y99':np.nan}, # snow and ice
    16:{'a':4.372, 'b':0.978, 'dr':4.0, 'y05':np.nan, 'y99':np.nan},      # Barren
    17:{'a':np.nan, 'b':np.nan, 'dr':np.nan, 'y05':np.nan, 'y99':np.nan}, # water bodies
}

for lc, param in root_param.items():
    # Skip classes that have no profile coefficients (wetland, snow/ice and
    # water in the table above); deriving this from the table keeps the loop
    # in sync if coefficients are added or removed later.
    if np.isnan(param['a']) or np.isnan(param['b']):
        continue
    param['y05'] = _solve_root_depth(param['a'], param['b'], 0.5)
    param['y99'] = _solve_root_depth(param['a'], param['b'], 0.99)

In [None]:
%%time
# Build the gridded root-depth fields from the land-cover grid: each output
# variable is paired with the root_param entry that util.map_param reads.
for depth_var, depth_key in (('rd05', 'y05'), ('rd99', 'y99')):
    ds[depth_var] = util.map_param(ds['landcover'], root_param, depth_key)

In [None]:
# Show the gridded dataset as the cell output for a quick visual check.
ds

## Re-mapping
- Monthly LAI
- Landcover class
- Root depth

In [None]:
%%time
# monthly weighted average
# Grid cells whose LAI is NaN in the first month get mask 0, all others 1;
# the same mask is reused for every month.
dr_mask = xr.where(np.isnan(ds['lai'].isel(month=0)), 0, 1)

# Open the mapping file once and reuse it, instead of re-opening it for
# every month inside the loop.
ds_mapping = xr.open_dataset(mapping_file)

# One remapped LAI dataset per month; the month count comes from the data
# rather than a hard-coded 12.
a1 = [
    util.regrid_mean(ds_mapping, ds.isel(month=mon), dr_mask, ['lai'], verbose=False)
    for mon in range(ds.sizes['month'])
]
a = xr.concat(a1, dim="month")

In [None]:
# Weighted average of the two root-depth fields.
# Grid cells without an rd05 value get mask 0, all others 1.
rd_vars = ['rd05', 'rd99']
dr_mask = xr.where(np.isnan(ds['rd05']), 0, 1)
ds_mapping = xr.open_dataset(mapping_file)
a1 = util.regrid_mean(ds_mapping, ds, dr_mask, rd_vars, verbose=False)
# Add the remapped root depths to the accumulated result.
a = xr.merge([a, a1])

In [None]:
# Dominant land-cover class, computed with util.regrid_mode.
# Grid cells without a land-cover value get mask 0, all others 1.
lc_vars = ['landcover']
dr_mask = xr.where(np.isnan(ds['landcover']), 0, 1)
ds_mapping = xr.open_dataset(mapping_file)
a1 = util.regrid_mode(ds_mapping, ds, dr_mask, lc_vars)
# Add the dominant-class variables to the accumulated result.
a = xr.merge([a, a1])

In [None]:
# Forest fraction: the remapped mean of the 0/1 forest flag.
dr_mask = xr.where(np.isnan(ds['forest']), 0, 1)
ds_mapping = xr.open_dataset(mapping_file)
a1 = util.regrid_mean(ds_mapping, ds, dr_mask, ['forest'])
# Store the mean under the name 'forest_frac' in the accumulated result.
forest_frac = a1['forest'].rename('forest_frac')
a = xr.merge([a, forest_frac])

In [None]:
# Summaries of the monthly LAI: the maximum over months and the
# max-minus-min range over months.
lai_monthly = a['lai']
a['lai_max'] = lai_monthly.max(dim='month')
a['lai_diff'] = a['lai_max'] - lai_monthly.min(dim='month')

In [None]:
# Convert the dominant class to integers, using -999 as the fill value for
# both 'N/A' entries and NaNs.
lc_name = '1st_dominant_landcover'
lc_float = a[lc_name].where(a[lc_name] != 'N/A', -999).astype(float)
a[lc_name] = lc_float
a[lc_name] = a[lc_name].where(~np.isnan(a[lc_name]), -999.0).astype(int)

In [None]:
# Convert the second-ranked class to integers, using -999 as the fill value
# for both 'N/A' entries and NaNs.
lc_name = '2nd_dominant_landcover'
lc_float = a[lc_name].where(a[lc_name] != 'N/A', -999).astype(float)
a[lc_name] = lc_float
a[lc_name] = a[lc_name].where(~np.isnan(a[lc_name]), -999.0).astype(int)

In [None]:
# Show the remapped per-HRU dataset as the cell output for a quick visual check.
a

## Fill missing values

In [None]:
# HRUs whose dominant land cover carries the -999 fill value are treated as
# having no vegetation data (this test is specific to the vegetation data).
is_missing = a['1st_dominant_landcover'] == -999
missing_huc = a['hru'].where(is_missing, drop=True).values

In [None]:
# Reproject so that centroids and distances are computed in planar units.
# NOTE(review): EPSG:3785 is a deprecated code (superseded by EPSG:3857);
# kept unchanged so results stay identical - consider switching.
gdf_catch_proj = gdf_catch.to_crs('epsg:3785')
centroid = gdf_catch_proj.copy(deep=True)
centroid.geometry = centroid['geometry'].centroid

# Set membership is O(1); testing against the numpy array scanned every
# missing id for each HRU and for each candidate neighbour.
missing_set = set(missing_huc.tolist())

print('index missing_huc12, target_closest_huc12')
count = 0
neighbor_huc12 = {}  # missing HRU id -> id of the nearest HRU that has data
for index, row in centroid.iterrows():
    if index not in missing_set:
        continue
    count += 1
    # Polygons ordered by distance to this centroid; the first entry is
    # skipped (expected to be the HRU itself) and 49 candidates are examined.
    geo_ix = gdf_catch_proj.geometry.distance(row.geometry).sort_values().index[1:50]
    nearest = next((ix for ix in geo_ix if ix not in missing_set), None)
    if nearest is None:
        # Previously this case crashed with a KeyError in the progress print
        # below; report it and leave the HRU unfilled instead.
        print(f'WARNING: no HRU with data among the {len(geo_ix)} nearest to {index}; left unfilled')
        continue
    neighbor_huc12[index] = nearest
    if count % 10 == 0:
        print(count, index, nearest)

In [None]:
# Variables whose values are copied from the paired neighbour HRU.
veg_vars = [
    '1st_dominant_landcover',
    '1st_dominant_landcover_fraction',
    'forest_frac',
    'rd05',
    'rd99',
]
# For every variable, overwrite each missing HRU with its neighbour's value.
for var in veg_vars:
    for huc12_missing, huc12_target in neighbor_huc12.items():
        donor_values = a[var].loc[{'hru': huc12_target}].values
        a[var].loc[{'hru': huc12_missing}] = donor_values

## Dataset to Dataframe

In [None]:
# Attributes exported to the table (the monthly 'lai' series is not included).
var_list = [
    'forest_frac',
    'lai_max',
    'lai_diff',
    '1st_dominant_landcover',
    '1st_dominant_landcover_fraction',
    'rd05',
    'rd99',
]
attr_subset = a[var_list]
df = attr_subset.to_dataframe()

## Save in csv or netcdf

In [None]:
# Write the attributes in the formats requested in the setup file.
out_dir = 'output'
if saveCSV or saveNetCDF:
    # to_csv / to_netcdf do not create missing directories; make sure the
    # target exists so the run does not fail at the very last step.
    os.makedirs(out_dir, exist_ok=True)
if saveCSV:
    df.to_csv(os.path.join(out_dir, f'{catchment_name}_veg.csv'), float_format='%g')
if saveNetCDF:
    a.to_netcdf(os.path.join(out_dir, f'{catchment_name}_veg.nc'))

## Plotting

In [None]:
# Attach the attribute table to the catchment polygons by index
# (inner join: only ids present in both are kept).
gdf_catch = gdf_catch.merge(df, how='inner', left_index=True, right_index=True)

In [None]:
# Long-form LAI table with the index levels turned into ordinary columns.
df = a[['lai']].to_dataframe().reset_index()

# Wide form: one row per hru, one column per month label, with 'hru' as a
# regular column and no name left on the column axis.
df_pivot = (
    df.pivot(index='hru', columns='month', values='lai')
    .reset_index()
    .rename_axis(None, axis=1)
)

# Month labels 0..11 become lai1..lai12; any other label is left untouched.
df_pivot = df_pivot.rename(columns={mon: f'lai{mon+1}' for mon in range(12)})

In [None]:
# Attach the monthly LAI columns: the polygon index is matched against the
# 'hru' column of the pivot table (inner join).
gdf_catch = gdf_catch.merge(df_pivot, how='inner', left_index=True, right_on='hru')

In [None]:
# Map of rd05 per catchment polygon, saved as a PNG.
var_name = 'rd05'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
# Fixed colour range so figures are comparable between runs.
color_scale = colors.Normalize(-0.25, -0.1)
gdf_catch.plot(column=var_name, ax=ax, cmap='turbo', norm=color_scale, legend=True)
ax.set_title(var_name)
fig.tight_layout()
fig.savefig(f'./figures/{catchment_name}_veg_{var_name}.png', dpi=300)

In [None]:
# Map of rd99 per catchment polygon, saved as a PNG.
var_name = 'rd99'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
# Fixed colour range so figures are comparable between runs.
gdf_catch.plot(ax=ax, column=var_name, cmap='turbo',
               norm=colors.Normalize(-3, -1),
               legend=True)
ax.set_title(var_name)
# Tighten the layout before saving, as the rd05 and forest_frac figures do.
fig.tight_layout()
fig.savefig(f'./figures/{catchment_name}_veg_{var_name}.png', dpi=300)

In [None]:
# Map of the dominant land-cover class per catchment polygon, saved as a PNG.
var_name = '1st_dominant_landcover'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
# Colour range spans class codes 1-17.
gdf_catch.plot(ax=ax, column=var_name, cmap='turbo',
               norm=colors.Normalize(1, 17),
               legend=True)
ax.set_title(var_name)
# Tighten the layout before saving, as the rd05 and forest_frac figures do.
fig.tight_layout()
fig.savefig(f'./figures/{catchment_name}_veg_{var_name}.png', dpi=300)

In [None]:
# Map of lai_max per catchment polygon, saved as a PNG.
var_name = 'lai_max'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
# Fixed colour range so figures are comparable between runs.
gdf_catch.plot(ax=ax, column=var_name, cmap='turbo',
               norm=colors.Normalize(0, 6),
               legend=True)
ax.set_title(var_name)
# Tighten the layout before saving, as the rd05 and forest_frac figures do.
fig.tight_layout()
fig.savefig(f'./figures/{catchment_name}_veg_{var_name}.png', dpi=300)

In [None]:
# Map of lai_diff per catchment polygon, saved as a PNG.
var_name = 'lai_diff'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
# Fixed colour range so figures are comparable between runs.
gdf_catch.plot(ax=ax, column=var_name, cmap='turbo',
               norm=colors.Normalize(0, 6),
               legend=True)
ax.set_title(var_name)
# Tighten the layout before saving, as the rd05 and forest_frac figures do.
fig.tight_layout()
fig.savefig(f'./figures/{catchment_name}_veg_{var_name}.png', dpi=300)

In [None]:
# Map of forest_frac per catchment polygon, saved as a PNG.
var_name = 'forest_frac'
fig, ax = plt.subplots(figsize=(6, 3), dpi=150)
# The colour range covers the full 0-1 span.
color_scale = colors.Normalize(0, 1)
gdf_catch.plot(column=var_name, ax=ax, cmap='turbo', norm=color_scale, legend=True)
ax.set_title(var_name)
fig.tight_layout()
fig.savefig(f'./figures/{catchment_name}_veg_{var_name}.png', dpi=300)