# FATES SP LH analysis

In [None]:
import os
import xarray as xr
import pandas as pd
import dask
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from dask_jobqueue import PBSCluster
from dask.distributed import Client
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point
import copy
import functools

## PBS Cluster Setup

In [None]:
# Configure a PBS-backed dask cluster; every job it submits requests the
# resources given below (1 core, 25GB, 4h wall time)
cluster = PBSCluster(
    cores=1,                                      # The number of cores you want
    memory='25GB',                                # Amount of memory
    processes=1,                                  # How many processes
    queue='casper',                               # The type of queue to utilize (/glade/u/apps/dav/opt/usr/bin/execcasper)
    local_directory='/glade/work/afoster',        # Use your local directory
    resource_spec='select=1:ncpus=1:mem=25GB',    # Specify resources (kept consistent with cores/memory above)
    project='P93300041',                          # Input your project ID here
    walltime='04:00:00',                          # Amount of wall time
    interface='ext',                              # Interface to use
)

In [None]:
# Ask the cluster to scale to 30 workers
cluster.scale(30)

In [None]:
# Set the URL template dask uses for the dashboard link so that it points
# at the JupyterHub proxy address
dask.config.set({'distributed.dashboard.link':'https://jupyterhub.hpc.ucar.edu/stable/user/{USER}/proxy/{port}/status'})

In [None]:
# Connect a distributed client to the PBS cluster created above
client = Client(cluster)

## Helper Functions

In [None]:
def find_files(fstring, topdir, paramkey):
    '''
    Build the history-file paths belonging to each parameter in the key table
    fstring:  filename prefix; files are named "<fstring>_<key>.nc"
    topdir:   directory that holds the files
    paramkey: table with a 'param' and a 'key' column
    returns:  one list of paths per unique parameter, in sorted parameter order
    '''

    file_groups = []
    for name in sorted(np.unique(paramkey.param)):
        member_keys = paramkey[paramkey.param == name]['key'].values

        # a parameter with a single key is paired with the 'FATES_OAAT_000' key
        if len(member_keys) == 1:
            member_keys = [member_keys[0], 'FATES_OAAT_000']

        file_groups.append(
            [os.path.join(topdir, f"{fstring}_{key}.nc") for key in member_keys]
        )

    return file_groups

In [None]:
def get_ensemble(files, whittaker_ds):
    '''
    Open a nested list of files as one dataset and attach biome information
    files:        nested list of file paths, concatenated along 'param' (outer)
                  and 'minmax' (inner)
    whittaker_ds: dataset providing the 'biome' and 'biome_name' variables
    '''

    ensemble = xr.open_mfdataset(
        files,
        combine='nested',
        concat_dim=['param', 'minmax'],
        parallel=True,
    )

    # copy the biome variables across from the whittaker dataset
    for name in ('biome', 'biome_name'):
        ensemble[name] = getattr(whittaker_ds, name)

    return ensemble

In [None]:
def get_map(ds, da):
    '''
    Expand per-gridcell values in da onto the lat/lon grid of the cluster file
    ds: dataset providing grid1d_lat / grid1d_lon; only param=0, minmax=0 is
        used for the coordinate lookup
    da: data array with a 'gridcell' dimension
    returns the computed (lat, lon) array, masked where no gridcell was assigned
    '''
    
    # cluster file: `cclass` labels each map cell and `rcent_coords` holds one
    # coordinate pair per cluster
    # NOTE(review): the dataset is opened on every call and never closed
    thedir  = '/glade/u/home/forrest/ppe_representativeness/output_v4/'
    thefile = 'clusters.clm51_PPEn02ctsm51d021_2deg_GSWP3V1_leafbiomassesai_PPE3_hist.annual+sd.400.nc'
    sg = xr.open_dataset(thedir+thefile)
    
    ds = ds.isel(param=0).isel(minmax=0)
    # start from an all-NaN array shaped like cclass
    out = np.zeros(sg.cclass.shape) + np.nan
    # each pair is unpacked as (o, a): `a` is compared with latitude, `o` with longitude
    for c,(o,a) in enumerate(sg.rcent_coords):
        # position(s) of the gridcell within 0.1 of this pair in both lat and lon
        # assumes ds has exactly 400 gridcells and that exactly one matches — TODO confirm
        i = np.arange(400)[
            (abs(ds.grid1d_lat - a) < 0.1) &
            (abs(ds.grid1d_lon - o) < 0.1)]
        # cluster labels are matched as c + 1, i.e. treated as 1-based
        out[sg.cclass == c + 1] = i
    # replace NaN by 0 so the array can be cast to int; `notnan` below remembers
    # which cells were really assigned
    cclass = out.copy()
    cclass[np.isnan(out)] = 0

    sgmap = xr.Dataset()
    sgmap['cclass'] = xr.DataArray(cclass.astype(int), dims=['lat','lon'])
    sgmap['notnan'] = xr.DataArray(~np.isnan(out), dims=['lat','lon'])
    sgmap['lat'] = sg.lat
    sgmap['lon'] = sg.lon
    
    # label-based selection with the index map, then mask the unassigned cells
    # assumes the gridcell labels of `da` equal the 0-based positions — TODO confirm
    damap = da.sel(gridcell=sgmap.cclass).where(sgmap.notnan).compute()
    
    return damap

In [None]:
def plot_fig(da, vmin, vmax, cmap_name, cbar_title, plot_title):
    '''
    Draw da on a Robinson-projection map with a colorbar
    da:         array with lon / lat coordinates and 2-D values
    vmin, vmax: color limits
    cmap_name:  matplotlib colormap name (resampled to 21 levels)
    cbar_title: title placed on the colorbar axes
    plot_title: title of the map
    '''

    figure = plt.figure(figsize=(13, 6))
    axes = figure.add_subplot(1, 1, 1, projection=ccrs.Robinson())
    axes.set_extent([-180, 180, -56, 85], crs=ccrs.PlateCarree())

    # background layers, added in this order
    water = '#CCFEFF'
    layers = [
        (cfeature.COASTLINE, {}),
        (cfeature.OCEAN, {'facecolor': water}),
        (cfeature.LAKES, {'facecolor': water}),
        (cfeature.LAND, {'facecolor': 'lightgray'}),
        (cfeature.RIVERS, {'edgecolor': water}),
    ]
    for feature, style in layers:
        axes.add_feature(feature, **style)

    # work on a copy of the colormap
    colormap = copy.copy(plt.get_cmap(cmap_name, 21))
    mesh = axes.pcolormesh(
        da.lon,
        da.lat,
        da.values,
        vmin=vmin,
        vmax=vmax,
        transform=ccrs.PlateCarree(),
        cmap=colormap,
    )

    colorbar = plt.colorbar(mesh)
    colorbar.ax.set_title(cbar_title)
    axes.set_title(plot_title)

In [None]:
def month_wts(nyears):
    ''' days-per-month weights (28-day February) repeated for nyears along 'time' '''
    month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
    weights = np.tile(month_lengths, nyears)
    return xr.DataArray(weights, dims='time')

In [None]:
def top_n(da, nx):
    '''
    return the nx entries along 'param' with the largest |max - min| effect,
    ordered by increasing effect
    (assumes the max/min difference is 1-D along 'param' — TODO confirm)
    '''
    high = da.sel(minmax='max')
    low = da.sel(minmax='min')
    effect = abs(high - low)

    # argsort is ascending, so the last nx positions are the largest effects
    order = effect.argsort()
    keep = order[-nx:].values

    return da.isel(param=keep)

In [None]:
def ensemble_means(ds, data_var, domain, cfs, land_area):
    '''
    Reduce data_var to its multi-year mean and interannual variability
    ds:        dataset
    data_var:  data variable
    domain:    'global' or 'biome' (passed on to area_mean)
    cfs:       unit conversion factors (passed on to area_mean)
    land_area: land area dataset
    returns a dataset with '<data_var>_mean' (mean over years) and
    '<data_var>_iav' (standard deviation over years), plus 'param' and 'minmax'
    units come from the module-level `units` dict
    '''

    annual = annual_mean(area_mean(ds, data_var, domain, cfs, land_area))

    # (suffix, reduced values) for each statistic that gets stored
    stats = (
        ('mean', annual.mean(dim='year')),
        ('iav', annual.std(dim='year')),
    )

    out = xr.Dataset()
    for suffix, values in stats:
        name = f'{data_var}_{suffix}'
        out[name] = values
        out[name].attrs = {'units': units[data_var],
                           'long_name': ds[data_var].attrs['long_name']}
    out['param'] = ds.param
    out['minmax'] = ds.minmax

    return out

In [None]:
def rank_plot(da, xdef, nx):
    '''
    Plot the min/max values of the nx most influential parameters, one row each
    da:   array with 'param' and 'minmax' (passed to top_n)
    xdef: value drawn as the dotted vertical 'default' line
    nx:   number of parameters to show
    '''

    ranked = top_n(da, nx)
    rows = range(nx)
    low_vals = ranked.sel(minmax='min')
    high_vals = ranked.sel(minmax='max')

    fig = plt.figure()
    ax = fig.add_subplot()
    ax.plot([xdef, xdef], [0, nx-1], 'k:', label='default')
    ax.scatter(low_vals, rows, marker='o', facecolors='none', edgecolors='r', label='low-val')
    ax.plot(high_vals, rows, 'ro', label='high-val')

    # connect the low and high value of every parameter with a line
    names = ranked.param
    for row in rows:
        member = ranked.sel(param=names[row])
        ax.plot([member.sel(minmax='min'), member.sel(minmax='max')], [row, row], 'r')

    ax.set_yticks(rows)
    ax.set_yticklabels(list(ranked.param.values))

In [None]:
def annual_mean(da):
    '''
    Day-weighted sum of da for each year, scaled by its first conversion factor
    da: named data array with a 'time' coordinate; da.name selects the entry in
        the module-level `cfs` dict
    returns the computed array with a 'year' dimension, named like da
    '''
    # only the first of the two conversion factors is used here
    cf1, _ = cfs[da.name].values()

    month_lengths = da['time.daysinmonth']
    weighted = month_lengths*da
    yearly_sum = weighted.groupby('time.year').sum().compute()

    result = cf1*yearly_sum
    result.name = da.name
    return result

In [None]:
def area_mean(ds, data_var, domain, cfs, land_area):
    '''
    Area-weighted aggregate of data_var across gridcells, globally or by biome
    ds:        dataset (must provide data_var and 'biome')
    data_var:  data variable
    domain:    'global'; any other value groups by biome
    cfs:       unit conversion factors; only the second factor is used here
    land_area: land area dataset
    returns the aggregated array, named data_var
    '''

    is_global = domain == 'global'

    # an 'intrinsic' factor is replaced by 1 / (total land area of the domain)
    _, cf2 = cfs[data_var].values()
    if cf2 == 'intrinsic':
        if is_global:
            cf2 = 1/land_area.sum()
        else:
            cf2 = 1/land_area.groupby(ds.biome).sum()

    # weight by land area, then index the result by biome instead of gridcell
    weighted = land_area*ds[data_var]
    weighted['biome'] = ds.biome
    weighted = weighted.swap_dims({'gridcell':'biome'})

    # for the global domain every gridcell is put in one group labelled 1
    if is_global:
        groups = 1+0*weighted.biome
    else:
        groups = weighted.biome

    result = cf2*weighted.groupby(groups).sum()

    # collapse the single remaining group for the global domain
    if is_global:
        result = result.mean(dim='biome')

    result.name = data_var

    return result

In [None]:
def get_all_vars(data_vars, ds, cfs, land_area, domain):
    '''
    Run ensemble_means for every variable and return the merged result as a dataframe
    data_vars: names of the variables to reduce
    ds:        dataset
    cfs:       unit conversion factors
    land_area: land area dataset
    domain:    'global' or 'biome'
    '''
    per_variable = [
        ensemble_means(ds, name, domain, cfs, land_area) for name in data_vars
    ]
    return xr.merge(per_variable).to_dataframe()

In [None]:
def plot_param_effect(ds, var, parameter):
    '''
    Map the effect of one parameter on one variable and save the figure
    ds:        ensemble dataset with 'param' and 'minmax' dimensions
    var:       data variable to plot
    parameter: parameter name selected along 'param'
    The figure is written to figs/<var>_<parameter>.png; the directory is
    created if it does not exist.
    '''
    ## get annual mean
    da = ds[var].sel(param=parameter)
    mean_da = annual_mean(da).mean(dim='year')

    # map to whole earth
    da_map = get_map(ds, mean_da)

    ## difference between the second and first 'minmax' entries
    # NOTE(review): which position is 'min' and which is 'max' is not visible here — confirm
    dsDiff = da_map.isel(minmax=1) - da_map.isel(minmax=0)

    # get cmap limits (symmetric about zero)
    max_abs = abs(dsDiff).max().values
    vval = max_abs.round(1)
    if vval == 0 and max_abs > 0:
        # rounding would collapse the color range to [0, 0] for small effects
        vval = max_abs

    ## plot difference
    # NOTE(review): units[var] is also used for the area-aggregated values; confirm
    # it is the right label for a per-gridcell map
    plot_fig(dsDiff, -1*vval, vval, 'bwr_r', f"{var} [{units[var]}]", f"Effect of {parameter} on {var}")

    # savefig fails if the output directory is missing
    os.makedirs('figs', exist_ok=True)
    plt.savefig(f"figs/{var}_{parameter}.png")

## Parameter values and directory names

In [None]:
# fetch the parameter information, including parameter names and their key values
paramkey_file = '/glade/work/afoster/FATES_calibration/FATES_SP_OAAT/FATES_SP_OAAT_param_key.csv' 
paramkey = pd.read_csv(paramkey_file)
# sorted unique parameter names (note: `params` is reassigned to a hand-written list in the mapping section)
params = sorted(np.unique(paramkey.param))

In [None]:
# fetch the sparsegrid landarea - needed for unit conversion
# (used as the area weight in area_mean)
land_area_file = '/glade/work/afoster/FATES_calibration/CLM5PPE/postp/sparsegrid_landarea.nc'
land_area = xr.open_dataset(land_area_file).landarea  #km2

In [None]:
## whittaker biomes
# provides the `biome` and `biome_name` variables that are attached to the datasets
whit = xr.open_dataset('/glade/work/afoster/FATES_calibration/CLM5PPE/pyth/whit/whitkey.nc')

In [None]:
# directory holding the history files and the filename prefix;
# find_files builds "<topdir>/<fstring>_<key>.nc" from these
topdir = '/glade/work/afoster/FATES_calibration/FATES_SP_OAAT/hist'
fstring = "ctsm51FATES_SP_OAAT_SatPhen_derecho_2000"

In [None]:
# variables to summarise; each one needs an entry in `cfs` and in `units`
data_vars = ['GPP', 'EFLX_LH_TOT', 'ASA', 'SOILWATER_10CM', 'FSH', 'Temp']

In [None]:
#conversion factors
# cf1 multiplies the day-weighted annual sum in annual_mean
# cf2 multiplies the area-weighted sum in area_mean; 'intrinsic' is replaced
# there by 1 / (total land area of the domain)
cfs={'GPP': {'cf1':24*60*60,'cf2':1e-6},
    'EFLX_LH_TOT': {'cf1':1/2.5e6*24*60*60,'cf2':1e-9},
    'ASA': {'cf1':1/365,'cf2':'intrinsic'},
    'SOILWATER_10CM': {'cf1':1/365,'cf2':1e-9},
    'FSH': {'cf1':1/365,'cf2':'intrinsic'},
    'Temp': {'cf1':1/365,'cf2':'intrinsic'}}
# units written to the output attributes and used in the plot labels
units={'GPP':'PgC/yr',
      'EFLX_LH_TOT': 'TtH2O/yr',
      'ASA': '0-1',
      'SOILWATER_10CM': 'TtH2O',
      'FSH': 'W/m2',
      'Temp': 'degrees C'}

## Test all variables

In [None]:
# one list of file paths per parameter, in sorted parameter order
files = find_files(fstring, topdir, paramkey)

In [None]:
# open every file group except the first; the first file of that group is
# opened separately as `ds_def`
ds = get_ensemble(files[1:], whit)

In [None]:
# first file of the first group, with the same biome variables attached
# (presumably the default-parameter run — TODO confirm)
ds_def = xr.open_dataset(files[0][0])
ds_def['biome'] = whit.biome
ds_def['biome_name'] = whit.biome_name

In [None]:
# global mean / iav of every variable for the ensemble, written to csv
global_df = get_all_vars(data_vars, ds, cfs, land_area, 'global')
global_df.to_csv('global_means.csv')

In [None]:
# per-biome mean / iav of every variable for the ensemble, written to csv
biome_df = get_all_vars(data_vars, ds, cfs, land_area, 'biome')
biome_df.to_csv('biome_means.csv')

In [None]:
# same global and per-biome summaries for the single `ds_def` dataset
global_def = get_all_vars(data_vars, ds_def, cfs, land_area, 'global')
biome_def = get_all_vars(data_vars, ds_def, cfs, land_area, 'biome')
global_def.to_csv('global_default.csv')
biome_def.to_csv('biome_default.csv')

## Do some mapping

In [None]:
# hand-written parameter list for the maps; replaces the sorted `params`
# built from the parameter key (the ordering criterion is not shown here)
params = ["fates_leaf_vcmax25top", "fates_leaf_theta_cj_c3", "fates_leaf_stomatal_intercept", "fates_rad_leaf_clumping_index",
          "fates_maintresp_leaf_atkin2017_baserate", "fates_allom_fnrt_prof_b", "fates_leaf_theta_cj_c4",
          "fates_rad_leaf_xl", "fates_rad_stem_rhovis", "fates_nonhydro_smpsc", "fates_leaf_stomatal_slope_medlyn", "fates_stoich_nitr",
          "fates_rad_leaf_rhonir", "fates_rad_leaf_taunir", "fates_turb_leaf_diameter", "fates_leaf_photo_temp_acclim_timescale",
          "fates_rad_stem_rhonir", "fates_nonhydro_smpso", "fates_rad_stem_tauvis", "fates_leaf_slatop", "fates_rad_leaf_rhovis",
          "fates_allom_fnrt_prof_a", "fates_rad_leaf_tauvis", "fates_turb_z0mr", "fates_turb_displar", "fates_rad_stem_taunir",
          "fates_allom_d2ca_coefficient_max", "fates_allom_crown_depth_frac", "fates_allom_d2bl2", "fates_allom_dbh_maxheight",
          "fates_leaf_photo_temp_acclim_thome_time", "fates_allom_d2h2", "fates_allom_blca_expnt_diff", "fates_leaf_slamax",
          "fates_allom_d2h1", "fates_allom_d2h3"] 

In [None]:
# single parameter for ad-hoc plotting; the plotting loop reuses the name
# `parameter` as its loop variable and overwrites this value
parameter = 'fates_leaf_vcmax25top'
# variables that get a map for every parameter
dvars = ['GPP', 'EFLX_LH_TOT', 'SOILWATER_10CM', 'FSH']

In [None]:
# display the parameter at index 23, the first one handled by the plotting loop
params[23]

In [None]:
# map the effect of every parameter from index 23 onwards on each variable in
# dvars; plot_param_effect saves one png per (variable, parameter) pair
for parameter in params[23:]:
    for var in dvars:
        plot_param_effect(ds, var, parameter)