# FATES SP LH analysis

In [None]:
import os
import copy

import xarray as xr
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import dask
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import gpflow

from dask_jobqueue import PBSCluster
from dask.distributed import Client
from esem import gp_model
from sklearn.metrics import mean_squared_error
from scipy import stats

from SALib.sample import fast_sampler
from SALib.analyze import fast

## PBS Cluster Setup

In [None]:
# Setup PBSCluster: describes the resources of a single PBS job; the number of
# jobs/workers is set separately with cluster.scale()
cluster = PBSCluster(
    cores=1,                                     # The number of cores you want
    memory='25GB',                               # Amount of memory
    processes=1,                                 # How many processes
    queue='casper',                              # The type of queue to utilize
    local_directory='/glade/work/afoster',       # Use your local directory
    resource_spec='select=1:ncpus=1:mem=25GB',   # Specify resources (matches cores/memory above)
    project='P93300041',                         # Input your project ID here
    walltime='04:00:00',                         # Amount of wall time
    interface='ext',                             # Interface to use
)

In [None]:
# scale the cluster to 30 workers (one core / 25GB each, per the setup above)
cluster.scale(30)

In [None]:
# attach a dask distributed client to the PBS cluster
client = Client(cluster)

## Helper Functions

In [None]:
def get_ensemble(files, whittaker_ds):
    '''
    Open a set of files as one dataset and attach biome information
    files:         paths of the files to open; they are concatenated, in the
                   order given, along a new 'ensemble' dimension
    whittaker_ds:  dataset providing the 'biome' and 'biome_name' variables
    '''

    # open everything as a single chunked dataset
    chunk_sizes = {'time': 60, 'ensemble': 100, 'gridcell': 200}
    ensemble = xr.open_mfdataset(files, combine='nested',
                                 concat_dim='ensemble', parallel=True,
                                 chunks=chunk_sizes)

    # copy the biome classification onto the ensemble dataset
    for name in ('biome', 'biome_name'):
        ensemble[name] = whittaker_ds[name]

    return ensemble

In [None]:
def annual_mean(da):
    '''
    Calculate annual values from a DataArray, weighting by days per month
    da:  DataArray with a datetime-like 'time' coordinate (assumed to be
         monthly); da.name must be a key of the module-level cfs dictionary

    Each time step is multiplied by the number of days in its month, the
    products are summed within each calendar year, and the sums are scaled by
    the variable's 'cf1' conversion factor. The result has a 'year' dimension
    and keeps the name of da.
    '''

    # look the factor up by key: unpacking .values() depended on the insertion
    # order of the inner dict and bound a second factor that was never used
    cf1 = cfs[da.name]['cf1']

    days_per_month = da['time.daysinmonth']
    ann_mean = cf1*(days_per_month*da).groupby('time.year').sum()
    ann_mean.name = da.name
    return ann_mean

In [None]:
def month_wts(nyears):
    '''
    Days in each month, repeated for nyears years along a 'time' dimension
    nyears:  number of years to repeat the 12 monthly values for

    February is always given 28 days, i.e. leap years are not represented.
    '''

    days_in_month = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
    repeated = np.tile(days_in_month, nyears)

    return xr.DataArray(repeated, dims='time')

In [None]:
def area_mean(ds, data_var, domain, cfs, land_area):
    '''
    Area-weighted aggregate of data_var across gridcells, globally or by biome
    ds:        dataset holding data_var and a 'biome' variable
    data_var:  name of the data variable to aggregate
    domain:    'global' for one value over all gridcells; any other value
               groups the result by biome
    cfs:       unit conversion factors, keyed by variable name
    land_area: land area of each gridcell
    '''

    by_globe = domain == 'global'

    # only the second conversion factor is applied in this function
    _, scale = cfs[data_var].values()
    if scale == 'intrinsic':
        # intrinsic quantities are divided by the land area of the domain
        if by_globe:
            total_area = land_area.sum()
        else:
            total_area = land_area.groupby(ds.biome).sum()
        scale = 1/total_area

    # weight by land area and index the gridcells by their biome
    weighted = land_area*ds[data_var]
    weighted['biome'] = ds.biome
    weighted = weighted.swap_dims({'gridcell': 'biome'})

    if by_globe:
        # a constant label of 1 places every gridcell in the same group
        groups = 1+0*weighted.biome
    else:
        groups = weighted.biome

    # sum the weighted values within each group and convert units
    result = scale*weighted.groupby(groups).sum()

    if by_globe:
        # collapse the single remaining group
        result = result.mean(dim='biome')

    result.name = data_var

    return result

In [None]:
def get_map(ds, da):
    '''
    Expand sparse-grid values onto the lat/lon grid of the cluster file
    ds:  dataset providing grid1d_lat and grid1d_lon for the sparse gridcells
    da:  DataArray with a 'gridcell' dimension to be mapped

    Returns da selected onto (lat, lon) and computed; cells that belong to no
    cluster are masked out.
    '''
    
    # cluster definition file (hard-coded location)
    thedir = '/glade/u/home/forrest/ppe_representativeness/output_v4/'
    thefile = 'clusters.clm51_PPEn02ctsm51d021_2deg_GSWP3V1_leafbiomassesai_PPE3_hist.annual+sd.400.nc'
    sg = xr.open_dataset(thedir + thefile)

    # start from an all-NaN map; cells stay NaN unless a cluster claims them
    out = np.zeros(sg.cclass.shape) + np.nan
    # o is compared with longitudes and a with latitudes below
    for c, (o, a) in enumerate(sg.rcent_coords):
        # index of the sparse gridcell lying within 0.1 of this centroid
        # NOTE(review): np.arange(400) assumes exactly 400 gridcells - confirm
        i = np.arange(400)[
            (abs(ds.grid1d_lat - a) < 0.1) &
            (abs(ds.grid1d_lon - o) < 0.1)]
        # cluster classes are numbered from 1, hence c + 1
        out[sg.cclass == c + 1] = i
    # NaN cannot be cast to int, so unassigned cells get index 0 here and are
    # masked again through 'notnan' further down
    cclass = out.copy()
    cclass[np.isnan(out)] = 0

    # lat/lon lookup table of gridcell indices plus a validity mask
    sgmap = xr.Dataset()
    sgmap['cclass'] = xr.DataArray(cclass.astype(int), dims=['lat', 'lon'])
    sgmap['notnan'] = xr.DataArray(~np.isnan(out), dims=['lat', 'lon'])
    sgmap['lat'] = sg.lat
    sgmap['lon'] = sg.lon
    
    # pick the gridcell for every map cell, blank unassigned cells and compute
    damap = da.sel(gridcell=sgmap.cclass).where(sgmap.notnan).compute()
    
    return damap

In [None]:
def normalize(var):
    '''
    Rescale var linearly so that its minimum maps to 0 and its maximum to 1
    var:  one-dimensional array-like supporting element-wise arithmetic
    '''
    lowest = min(var)
    highest = max(var)
    return (var - lowest)/(highest - lowest)

In [None]:
def unnormalize(norm_var, raw_var):
    '''
    Map normalized values back onto the range spanned by raw_var
    norm_var:  values on the normalized (0-1) scale
    raw_var:   one-dimensional values whose min and max define the range
    '''
    lowest = min(raw_var)
    span = np.array(max(raw_var) - lowest)
    return norm_var*span + np.array(lowest)

In [None]:
def split_dataset(var, params, n_test):
    '''
    Split parameters and target values into testing and training sets
    var:     target variable; its first entry (the default, index 0) is dropped
    params:  parameter sets, one row per remaining entry of var
    n_test:  number of leading rows to hold out for testing

    Returns X_test, X_train, y_test, y_train.
    '''

    # drop the default member so the targets line up with the parameter rows
    targets = var[1:].values

    # the first n_test rows are for testing, everything after is for training
    X_test = params[:n_test]
    X_train = params[n_test:]
    y_test = targets[:n_test]
    y_train = targets[n_test:]

    return X_test, X_train, y_test, y_train

In [None]:
def train_emulator(num_params, X_train, y_train):
    '''
    Build and train a Gaussian process emulator
    num_params:  number of parameters (input dimensions)
    X_train:     training parameter sets
    y_train:     training target values

    Returns the trained emulator.
    '''

    # both kernel components act on every parameter dimension
    dims = range(num_params)
    linear_part = gpflow.kernels.Linear(active_dims=dims, variance=1)
    matern_part = gpflow.kernels.Matern32(active_dims=dims, variance=1,
                                          lengthscales=np.tile(1, num_params))

    # the emulator uses the sum of the two kernels
    emulator = gp_model(np.array(X_train), np.array(y_train),
                        kernel=linear_part + matern_part)
    emulator.train()

    return emulator

In [None]:
def fourier_sensitivity(emulator):
    '''
    Fourier amplitude sensitivity test (FAST) using an emulator
    emulator:  trained emulator whose predict() returns (mean, variance)

    Parameter names come from the module-level ppe_params; every parameter is
    sampled on [0, 1]. Returns a DataFrame indexed by parameter name, sorted
    by first-order sensitivity (S1), largest first.
    '''

    names = list(ppe_params.columns)

    # SALib expects one [lower, upper] pair per variable; a single pair for
    # all variables only worked through NumPy broadcasting
    problem = {
        'names': names,
        'num_vars': len(names),
        'bounds': [[0, 1] for _ in names],
    }

    # sample parameter space and evaluate the emulator at the sample points
    sample = fast_sampler.sample(problem, 1000, M=4, seed=None)
    Y, _ = emulator.predict(sample)

    # M must match the value used for sampling
    FAST = fast.analyze(problem, Y, M=4, num_resamples=100, conf_level=0.95,
                        print_to_console=False, seed=None)
    sens = pd.DataFrame.from_dict(FAST)
    sens.index = sens.names
    df_sens = sens.sort_values(by=['S1'], ascending=False)

    return df_sens

In [None]:
def oaat_sens(ppe_params, emulator, default):
    '''
    Plot the emulated one-at-a-time response to each parameter
    ppe_params:  DataFrame whose columns name the parameters
    emulator:    trained emulator whose predict() returns (mean, variance)
    default:     not used by this function; kept so existing calls still work

    Each parameter in turn is swept over 21 evenly spaced values in [0, 1]
    while all others are held at 0.5. One panel per parameter shows the
    prediction with a band of three standard deviations.
    '''

    num_params = len(ppe_params.columns)

    # hold all parameters at the mid-point of the normalized range
    n = 21
    sample = pd.concat([pd.DataFrame(np.tile(0.5, n))]*num_params, axis=1)
    sample.columns = ppe_params.columns
    s = np.linspace(0, 1, n)

    # five panels per row; keep the original 7 rows and only add rows when
    # there are more than 35 parameters (a fixed 7x5 grid raised ValueError)
    ncols = 5
    nrows = max(7, -(-num_params // ncols))

    plt.figure(figsize=[18, 16])
    for i, p in enumerate(ppe_params.columns):
        sample[p] = s
        oaat, v = emulator.predict(sample.values)
        sample[p] = np.tile(0.5, n)  # set column back to the mid-point

        ax = plt.subplot(nrows, ncols, i + 1)
        ax.fill_between(s, oaat - 3.0*v**0.5, oaat + 3.0*v**0.5, color='peru',
                        alpha=0.4)  # shade three standard deviations
        ax.plot(s, oaat, c='k')
        ax.set_xlabel(p)
    plt.tight_layout()

In [None]:
def adjust_lon(ds, lon_name):
    '''
    Shift longitudes above 180 down by 360 and sort along the new values
    ds:        dataset (or DataArray) with a longitude dimension
    lon_name:  name of that longitude dimension/coordinate

    Note that the temporary '_longitude_adjusted' entry is written onto the
    object passed in before the re-indexed copy is built.
    '''

    tmp_name = '_longitude_adjusted'

    # values above 180 are moved down by 360, everything else is unchanged
    lon = ds[lon_name]
    ds[tmp_name] = xr.where(lon > 180, lon - 360, lon)

    # make the adjusted values the dimension coordinate, order the data by
    # them and drop the original longitudes
    ordered = sorted(ds._longitude_adjusted)
    swapped = ds.swap_dims({lon_name: tmp_name})
    ds = swapped.sel(**{tmp_name: ordered}).drop_vars(lon_name)

    # give the adjusted coordinate the original name
    return ds.rename({tmp_name: lon_name})

In [None]:
def ooat_sens(ppe_params, emulator, dir, var):
    '''
    Write the emulated one-at-a-time response to each parameter to a csv file
    ppe_params:  DataFrame whose columns name the parameters
    emulator:    trained emulator whose predict() returns (mean, variance)
    dir:         directory to write to (must already exist)
    var:         variable label used as the file name prefix

    Each parameter in turn is swept over 50 evenly spaced values in [0, 1]
    while all others are held at 0.5. The file {dir}/{var}_oaat_global.csv
    has columns sample, predict, variance and parameter, 50 rows per parameter.
    '''
    num_params = len(ppe_params.columns)

    # hold all parameters at the mid-point of the normalized range
    n = 50
    sample = pd.concat([pd.DataFrame(np.tile(0.5, n))]*num_params, axis=1)
    sample.columns = ppe_params.columns
    s = np.linspace(0, 1, n)

    # collect per-parameter pieces and join them once at the end, instead of
    # growing arrays with np.append (which copies everything on each call)
    predictions = []
    variances = []
    names = []
    for p in ppe_params.columns:
        sample[p] = s
        oaat, v = emulator.predict(sample.values)
        sample[p] = np.tile(0.5, n)  # set column back to the mid-point
        predictions.append(np.ravel(oaat))
        variances.append(np.ravel(v))
        names.extend([p]*n)

    # the empty float array keeps the float64 promotion np.append gave
    seed = [np.array([])]
    dataf = pd.DataFrame({'sample': np.tile(s, num_params),
                          'predict': np.concatenate(seed + predictions),
                          'variance': np.concatenate(seed + variances),
                          'parameter': names})
    dataf.to_csv(f'{dir}/{var}_oaat_global.csv')

## Directory Names and Conversion Factors

In [None]:
# fetch the sparsegrid landarea - needed for unit conversion
# (passed to area_mean below as the per-gridcell weight)
land_area_file = '/glade/work/afoster/FATES_calibration/CLM5PPE/postp/sparsegrid_landarea.nc'
land_area_dat = xr.open_dataset(land_area_file)
land_area = land_area_dat.landarea  # km2

In [None]:
# whittaker biomes: provides the biome and biome_name variables that
# get_ensemble attaches to the ensemble dataset
whit = xr.open_dataset('/glade/work/afoster/FATES_calibration/CLM5PPE/pyth/whit/whitkey.nc')

In [None]:
# directory whose files are read in as the ensemble (every file in it is used)
topdir = '/glade/work/afoster/FATES_calibration/FATES_SP_LH/hist'

In [None]:
# conversion factors, keyed by variable name
#   cf1: applied in annual_mean to the day-weighted annual sum
#   cf2: applied in area_mean to the area-weighted sum; 'intrinsic' means
#        divide by the land area of the domain instead of using a constant
cfs = {'GPP': {'cf1': 24*60*60, 'cf2': 1e-6},
       'EFLX_LH_TOT': {'cf1': 1/2.5e6*24*60*60, 'cf2': 1e-9},
       'ASA': {'cf1': 1/365, 'cf2': 'intrinsic'},
       'SOILWATER_10CM': {'cf1': 1/365, 'cf2': 1e-9},
       'FSH': {'cf1': 1/365, 'cf2': 'intrinsic'},
       'Temp': {'cf1': 1/365, 'cf2': 'intrinsic'}}
# units of each variable after both conversion factors have been applied
units = {'GPP': 'PgC/yr',
         'EFLX_LH_TOT': 'TtH2O/yr',
         'ASA': '0-1',
         'SOILWATER_10CM': 'TtH2O',
         'FSH': 'W/m2',
         'Temp': 'degrees C'}

In [None]:
# global observations dataset; the relative path is resolved against the
# current working directory
global_obs = xr.open_dataset('ILAMB_global_obs.nc')

## Read in ensemble

In [None]:
# every entry of topdir, as a full path, in sorted order
files = sorted(os.path.join(topdir, name) for name in os.listdir(topdir))

# open them as one dataset and attach the whittaker biome information
ds = get_ensemble(files, whit)

In [None]:
# parameter key for the ensemble (one column per parameter, see ppe_params)
#lhckey = '/glade/u/home/afoster/CLM_PPE_FATES/FATES_SP_Calib/lh_key.csv'
lhckey = '/glade/work/afoster/FATES_calibration/FATES_SP_LH/lh_key_300.csv'
df = pd.read_csv(lhckey)

In [None]:
# ppe_params refers to the same DataFrame object as df (no copy is made)
ppe_params = df #z.drop(columns=['ensemble'])
# every column of the key is treated as a parameter
num_params = len(ppe_params.columns)

In [None]:
# number of leading parameter sets held out for testing (see split_dataset)
n_test = 50

## Global Annual GPP

In [None]:
# global annual mean (PgC/yr): area-weighted global aggregate, then annual
# values, then the mean over all years
gpp_glob = annual_mean(area_mean(ds, 'GPP', 'global', cfs, land_area)).mean(dim='year')

In [None]:
# split dataset into training/testing (first n_test members are for testing)
params_test, params_train, gpp_test, gpp_train = split_dataset(gpp_glob, ppe_params, n_test)

In [None]:
# train emulator for global GPP on the training members
gpp_em = train_emulator(num_params, params_train, gpp_train)

In [None]:
# predict test points with emulator
gpp_pred, gpp_pred_var = gpp_em.predict(params_test)

# standard deviation of each prediction
st_dev = gpp_pred_var.flatten()**0.5

# root mean squared error on the test members; taken as the square root of
# the MSE because the squared=False keyword of mean_squared_error is
# deprecated and no longer accepted by recent scikit-learn releases
rms = np.sqrt(mean_squared_error(gpp_test, gpp_pred))

In [None]:
# save simulated vs. emulated GPP for the test members; the table is built
# without rebinding df, which holds the parameter key that ppe_params is
# taken from (re-running that cell would otherwise pick up this dict)
gpp_test_pred = pd.DataFrame({'gpp_test': gpp_test,
                              'gpp_pred': gpp_pred,
                              'gpp_var': gpp_pred_var})
gpp_test_pred.to_csv('LH_output_global/gpp_test_pred.csv')

In [None]:
# FAST sensitivity of emulated global GPP, sorted by first-order index
gpp_sens = fourier_sensitivity(gpp_em)

In [None]:
# save the GPP sensitivity table (the output directory must already exist)
gpp_sens.to_csv('LH_output_global/gpp_sensitivity.csv')

In [None]:
# one-at-a-time response plots for GPP; the last argument is not used by
# oaat_sens
oaat_sens(ppe_params, gpp_em, 0.0)

## Global Annual Evapotranspiration

In [None]:
# global annual mean (TtH2O/yr): area-weighted global aggregate, then annual
# values, then the mean over all years
et_glob = annual_mean(area_mean(ds, 'EFLX_LH_TOT', 'global', cfs, land_area)).mean(dim='year')

# default et (value of the first ensemble member; currently disabled)
#default_et = et_glob[0].values

In [None]:
# split dataset into training/testing (first n_test members are for testing)
params_test, params_train, et_test, et_train = split_dataset(et_glob, ppe_params, n_test)

# train emulator
et_em = train_emulator(num_params, params_train, et_train)

# predict test points with emulator
et_pred, et_pred_var = et_em.predict(params_test)

# standard deviation of each prediction
st_dev = et_pred_var.flatten()**0.5

# root mean squared error on the test members; taken as the square root of
# the MSE because the squared=False keyword of mean_squared_error is
# deprecated and no longer accepted by recent scikit-learn releases
rms = np.sqrt(mean_squared_error(et_test, et_pred))

In [None]:
# save simulated vs. emulated ET for the test members; the table is built
# without rebinding df, which holds the parameter key that ppe_params is
# taken from (re-running that cell would otherwise pick up this dict)
et_test_pred = pd.DataFrame({'et_test': et_test,
                             'et_pred': et_pred,
                             'et_var': et_pred_var})
et_test_pred.to_csv('LH_output_global/et_test_pred.csv')

In [None]:
# write the one-at-a-time ET response to LH_output/et_oaat_global.csv
# NOTE(review): the other outputs in this notebook go to 'LH_output_global' -
# confirm that 'LH_output' is the intended directory
ooat_sens(ppe_params, et_em, 'LH_output', 'et')

In [None]:
# FAST sensitivity of emulated global ET, sorted by first-order index, saved
# as csv (the output directory must already exist)
et_sens = fourier_sensitivity(et_em)
et_sens.to_csv('LH_output_global/et_sensitivity.csv')

## Global Annual Sensible Heat

In [None]:
# global annual mean (W/m2): land-area-weighted global mean of FSH, then
# annual values, then the mean over all years
h_glob = annual_mean(area_mean(ds, 'FSH', 'global', cfs, land_area)).mean(dim='year')

In [None]:
# split dataset into training/testing (first n_test members are for testing)
params_test, params_train, h_test, h_train = split_dataset(h_glob, ppe_params, n_test)

# train emulator
h_em = train_emulator(num_params, params_train, h_train)

# predict test points with emulator
h_pred, h_pred_var = h_em.predict(params_test)

# standard deviation of each prediction
st_dev = h_pred_var.flatten()**0.5

# root mean squared error on the test members; taken as the square root of
# the MSE because the squared=False keyword of mean_squared_error is
# deprecated and no longer accepted by recent scikit-learn releases
rms = np.sqrt(mean_squared_error(h_test, h_pred))

In [None]:
# plot predicted values against the simulated test values
plt.scatter(h_test, h_pred)
plt.plot([min(h_test), max(h_test)], [min(h_test), max(h_test)], c='k',
         linestyle='--', label='1:1 line')
# error bars span two standard deviations of the emulator prediction
plt.errorbar(h_test.flatten(), h_pred.flatten(), yerr=2*st_dev, fmt="o")
plt.text(min(h_test), max(h_test), 'RMSE = {}'.format(np.round(rms, 3)))
# this cell shows the FSH (sensible heat) emulator, so the axes are labelled
# SH; they previously read LH
plt.xlabel('FATES global mean annual mean SH (W m$^{-2}$)')
plt.ylabel('Emulated global mean annual mean SH (W m$^{-2}$)')
plt.legend(loc='lower right')

In [None]:
# default_h is not assigned in any preceding cell, so the call raised a
# NameError; take the value of the first ensemble member (index 0, the member
# split_dataset treats as the default). oaat_sens does not use the argument.
default_h = h_glob[0].values
oaat_sens(ppe_params, h_em, default_h)

In [None]:
# FAST sensitivity of emulated global sensible heat, sorted by first-order
# index
h_sens = fourier_sensitivity(h_em)