# 4) Run GPEs #

Code for running the ECO-FAST and ORG-FAST analyses using Gaussian process regression. <br>
<br>
Inputs: <br>
Metrics_PFT.nc - created by notebook 3 and contains metrics for GPP, Stress, and Drought Sensitivity for the primary PFT of each grid cell (y vars in the ORG-FAST analysis) <br>
Metrics_GC.nc - created by notebook 3 and contains metrics for Drought Sensitivity for entire grid cell (y vars in the ECO-FAST analysis)<br>
Traits.nc - created by notebook 2 and contains the parameter values, both for individual PFTs and gridcell weighted means and coefficients of variation (x vars in the ORG-FAST and ECO-FAST analyses)<br>

pft_id - a helper dataset for mapping patches to their correct PFT<br>
nonrepresentative_pftarea.nc - the land area for each PFT<br>
<br>
Outputs:<br>
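ORGFAST_GPP.nc, ORGFAST_Stress.nc, ORGFAST_DroughtSens.nc - ORG-FAST results per grid cell (FAST sensitivity indices, GP slopes, and train/test R²)<br>
ECOFAST_DroughtSens.nc - ECO-FAST results per grid cell (same contents)<br>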

## Load packages ##

In [1]:
import importlib.util
import os
import pickle
import warnings
from math import pi
import cartopy.crs as ccrs
import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from matplotlib.colors import LinearSegmentedColormap
from pypalettes import load_cmap
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.tools.tools import add_constant
from tqdm import tqdm
import gpflow
from esem import gp_model
from esem.sampler import MCMCSampler
from esem.utils import get_random_params, leave_one_out, prediction_within_ci
from SALib.analyze import fast
from SALib.sample import fast_sampler
from scipy import stats
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import tensorflow as tf

repo_dir = os.getcwd()

2025-07-16 11:28:10.395240: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-16 11:28:10.719309: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-16 11:28:10.719361: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-16 11:28:10.737761: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-16 11:28:10.778955: I tensorflow/core/platform/cpu_feature_guar

## Functions ##

In [18]:
def gpe(data, xvar, yvar):
    """Fit a GP emulator of ``yvar`` from ``xvar`` and summarise its sensitivities.

    The predictors are z-scored (the response too, unless ``yvar`` is
    ``'Stress'``), rows are split into train/test sets by ensemble ID, and a
    GPflow ``GPR`` model with a Linear + Matern-3/2 kernel is fitted. The
    fitted model is then scored with R², analysed with SALib's FAST method,
    and differentiated with respect to every predictor.

    Parameters
    ----------
    data : xarray.Dataset
        Must hold the variables named in ``xvar`` and ``yvar`` and yield
        ``'ens'`` and ``'gridcell'`` columns after
        ``to_dataframe().reset_index()``. A ``'year'`` dimension, if present,
        is averaged out first.
    xvar : list of str
        Predictor variable names (must be a list; it is concatenated with
        ``[yvar]``).
    yvar : str
        Response variable name.

    Returns
    -------
    r2_test : float
        R² of the GP mean on the held-out ensemble members.
    r2_train : float
        R² of the GP mean on the training ensemble members.
    Si_df : pandas.DataFrame
        Output of ``SALib.analyze.fast.analyze`` indexed by ``'names'`` and
        sorted by ``'S1'`` in descending order.
    slope_series : pandas.Series
        Mean gradient of the GP mean per predictor (NaN rows ignored), sorted
        in descending order and named ``'GP_avg_slope'``.
    X_tf_xr : xarray.DataArray
        Scaled predictor values, dims ``('ens', 'trait')``.
    grads_xr : xarray.DataArray
        Per-row gradients of the GP mean, dims ``('ens', 'trait')``.
    """
    # Two scalers: one for X, one for y
    scaler_x = StandardScaler()
    scaler_y = StandardScaler()

    # Data prep: aggregate over 'year' if present, then convert to DataFrame
    subset = data.mean(dim='year', skipna=True) if 'year' in data.dims else data
    df = subset[xvar + [yvar]].to_dataframe().reset_index()
    # Treat infinities as missing. Rows are deliberately NOT dropped here so
    # the full set of rows is kept for the gradient output further down.
    df = df.replace([np.inf, -np.inf], np.nan)

    # Preserve ensemble ID, drop gridcell
    ens = df['ens']
    df = df.drop(['gridcell', 'ens'], axis=1)

    # Z-score X and y ('Stress' is left on its original scale)
    X_all = pd.DataFrame(scaler_x.fit_transform(df[xvar]), columns=xvar)
    if yvar == 'Stress':
        y_all = df[[yvar]]
    else:
        y_all = pd.DataFrame(scaler_y.fit_transform(df[[yvar]]),
                             columns=[yvar])

    df_scaled = pd.concat([X_all, y_all], axis=1)
    df_scaled['ens'] = ens.values

    # Split train and test by ensemble ID.
    # NOTE(review): assumes ensemble IDs run 1..500 — confirm against the data.
    ids = list(range(1, 501))
    train_ids, test_ids = train_test_split(ids, test_size=0.3, random_state=39)
    # Rows with any missing value are excluded from fitting and scoring only
    df_tr = df_scaled[df_scaled.ens.isin(train_ids)].dropna()
    df_te = df_scaled[df_scaled.ens.isin(test_ids)].dropna()

    X_train = df_tr[xvar].values
    y_train = df_tr[yvar].values.reshape(-1, 1)
    X_test = df_te[xvar].values
    y_test = df_te[yvar].values

    # Build & train the GP model
    D = len(xvar)
    kernel = (
        gpflow.kernels.Linear(active_dims=range(D), variance=1.0)
        + gpflow.kernels.Matern32(active_dims=range(D),
                                  variance=1.0,
                                  lengthscales=np.ones(D))
    )
    model = gpflow.models.GPR((X_train, y_train), kernel=kernel)
    opt = gpflow.optimizers.Scipy()
    opt.minimize(model.training_loss,
                 model.trainable_variables,
                 options=dict(maxiter=500))

    # Evaluate R² on both splits
    mu_tr, _ = model.predict_y(X_train)
    mu_te, _ = model.predict_y(X_test)
    r2_train = r2_score(y_train.flatten(), mu_tr.numpy().flatten())
    r2_test = r2_score(y_test, mu_te.numpy().flatten())

    # FAST sensitivity: sample within the 10th-90th percentile of each scaled
    # predictor; the small epsilon keeps the upper bound above the lower one.
    xdata = pd.DataFrame(df_scaled.drop(columns=[yvar, 'ens']), columns=xvar)
    bounds = [[xdata[c].quantile(0.1), xdata[c].quantile(0.9) + 1e-6]
              for c in xvar]
    problem = {'names': xvar, 'num_vars': D, 'bounds': bounds}
    sample = fast_sampler.sample(problem, 1000, M=4)
    Y, _ = model.predict_f(sample)
    FASTres = fast.analyze(problem, Y.numpy().flatten(), M=4,
                           num_resamples=100, conf_level=0.95,
                           print_to_console=False)
    Si_df = (pd.DataFrame.from_dict(FASTres)
             .set_index('names')
             .sort_values('S1', ascending=False))

    # Direct GP slopes via automatic differentiation of the predictive mean,
    # evaluated at every row (including rows that were dropped from fitting)
    X_tf = tf.convert_to_tensor(df_scaled[xvar].values, dtype=tf.float64)
    with tf.GradientTape() as tape:
        tape.watch(X_tf)
        mu, _ = model.predict_f(X_tf)       # [N,1]
    grads = tape.batch_jacobian(mu, X_tf)   # [N,1,D]
    grads = tf.squeeze(grads, axis=1)       # [N,D]
    grads_np = grads.numpy()

    grads_xr = xr.DataArray(grads_np, dims=['ens', 'trait'],
                            coords={'trait': xvar})
    X_tf_xr = xr.DataArray(X_tf.numpy(), dims=['ens', 'trait'],
                           coords={'trait': xvar})

    # Rows with missing predictors yield NaN gradients; a plain mean would
    # make every average slope NaN, so average over the finite rows only.
    avg_slopes = np.nanmean(grads_np, axis=0)

    slope_series = (pd.Series(avg_slopes, index=xvar, name='GP_avg_slope')
                    .sort_values(ascending=False))

    return r2_test, r2_train, Si_df, slope_series, X_tf_xr, grads_xr

def gpe_gridcell(i, data, xvar, yvar):
    """Run ``gpe`` for one grid cell and package the result as a dict.

    Parameters
    ----------
    i : int
        Grid cell label passed to ``data.sel(gridcell=i)``.
    data : xarray.Dataset
        Dataset with a ``gridcell`` dimension.
    xvar : list of str
        Predictor variable names forwarded to ``gpe``.
    yvar : str
        Response variable name forwarded to ``gpe``.

    Returns
    -------
    dict
        Keys ``'gridcell'``, ``'r2_test'``, ``'r2_train'``, ``'importances'``,
        ``'PDP_Slope'``, ``'XTest'`` and ``'XSlope'``. If anything fails, a
        ``RuntimeWarning`` describing the error is issued and every value
        except ``'gridcell'`` is ``np.nan``, so a batch run can continue.
    """
    try:
        if i % 25 == 0:
            print(i)  # coarse progress indicator
        xy = data.sel(gridcell=i)
        r2_test_gp, r2_train_gp, df_Si, df_slope, xr_Xtest, xr_PSlope = gpe(xy, xvar, yvar)
    except Exception as e:
        # Best-effort: keep the batch going, but do not hide why the cell failed
        warnings.warn(
            f"gpe_gridcell: fit failed for gridcell {i}: {e!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return {
            'gridcell': i,
            'r2_test': np.nan,
            'r2_train': np.nan,
            'importances': np.nan,
            'PDP_Slope': np.nan,
            'XTest': np.nan,
            'XSlope': np.nan
        }
    return {
        'gridcell': i,
        'r2_test': r2_test_gp,
        'r2_train': r2_train_gp,
        'importances': df_Si,
        'PDP_Slope': df_slope,
        'XTest': xr_Xtest,
        'XSlope': xr_PSlope
    }

def gpe_pft(i, data, xvars, yvar):
    """Run ``gpe`` for the primary PFT of one grid cell.

    The primary PFT is looked up in the module-level ``prim_pft`` dataset
    (``prim_pft.sel(gridcell=i).pft``) and used to subset ``data`` before the
    emulator is fitted.

    Parameters
    ----------
    i : int
        Grid cell label passed to ``sel(gridcell=i)``.
    data : xarray.Dataset
        Dataset with ``gridcell`` and ``pft`` dimensions.
    xvars : list of str
        Predictor variable names forwarded to ``gpe``.
    yvar : str
        Response variable name forwarded to ``gpe``.

    Returns
    -------
    dict
        Keys ``'gridcell'``, ``'r2_test'``, ``'r2_train'``, ``'importances'``,
        ``'PDP_Slope'``, ``'XTest'`` and ``'XSlope'``. If anything fails, a
        ``RuntimeWarning`` describing the error is issued and every value
        except ``'gridcell'`` is ``np.nan``, so a batch run can continue.
    """
    try:
        if i % 25 == 0:
            print(i)  # coarse progress indicator
        ds = data.sel(gridcell=i)
        prim_pft_gc = prim_pft.sel(gridcell=i).pft.values
        ds = ds.sel(pft=prim_pft_gc)

        # Average over years here when the response itself is year-resolved
        if 'year' in ds[yvar].dims:
            ds = ds.mean(dim='year')

        r2_test_gp, r2_train_gp, df_Si, df_slope, xr_Xtest, xr_PSlope = gpe(ds, xvars, yvar)
    except Exception as e:
        # Best-effort: keep the batch going, but do not hide why the cell failed
        warnings.warn(
            f"gpe_pft: fit failed for gridcell {i}: {e!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return {
            'gridcell': i,
            'r2_test': np.nan,
            'r2_train': np.nan,
            'importances': np.nan,
            'PDP_Slope': np.nan,
            'XTest': np.nan,
            'XSlope': np.nan
        }
    return {
        'gridcell': i,
        'r2_test': r2_test_gp,
        'r2_train': r2_train_gp,
        'importances': df_Si,
        'PDP_Slope': df_slope,
        'XTest': xr_Xtest,
        'XSlope': xr_PSlope
    }
        
def gppp_slope(r2_train_list, r2_test_list, slope_list, importances_list_gp,
               X_tf_xr_list, grads_xr_list, xvar, gridcells, n_ens=None):
    """Combine per-grid-cell GP results into a single ``xarray.Dataset``.

    Entries that are not of the expected type (e.g. the ``np.nan``
    placeholders returned for failed grid cells) are replaced with all-NaN
    arrays so that every grid cell contributes one slice.

    Parameters
    ----------
    r2_train_list, r2_test_list : sequence of float
        Train/test R² per grid cell.
    slope_list : sequence
        ``pandas.Series`` of average slopes per grid cell (index = trait).
    importances_list_gp : sequence
        ``pandas.DataFrame`` of FAST results per grid cell, indexed by
        ``'names'``.
    X_tf_xr_list, grads_xr_list : sequence
        ``xarray.DataArray`` with dims ``('ens', 'trait')`` per grid cell.
    xvar : list of str
        Trait names used to build the NaN placeholders.
    gridcells : sequence
        Grid cell labels used as the ``gridcell`` coordinate.
    n_ens : int, optional
        Length of the ``ens`` dimension for NaN placeholders. When omitted it
        is taken from the first valid array in ``X_tf_xr_list`` /
        ``grads_xr_list``, falling back to 500 if there is none.

    Returns
    -------
    xarray.Dataset
        The FAST result columns plus ``slope``, ``X``, ``grads``,
        ``r2_train`` and ``r2_test``, concatenated along ``gridcell``.
    """
    X_tf_xr_list = list(X_tf_xr_list)
    grads_xr_list = list(grads_xr_list)

    # Placeholder for grid cells without a FAST result
    nan_importances = pd.DataFrame(np.nan, index=xvar,
                                   columns=['S1', 'ST', 'S1_conf', 'ST_conf'])

    # Convert each FAST table to xarray, with 'trait' as the dimension name
    data_arrays = []
    for df in importances_list_gp:
        if isinstance(df, pd.DataFrame):
            data_array = df.to_xarray().rename({'names': 'trait'})
        else:
            data_array = nan_importances.to_xarray().rename({'index': 'trait'})
        data_arrays.append(data_array)

    # Concatenate along a new dimension 'gridcell'
    combined = xr.concat(data_arrays, dim='gridcell')

    # Average slopes: same treatment, one DataArray named 'slope' per cell
    nan_slope = pd.DataFrame(np.nan, index=xvar, columns=['slope'])
    slope_arrays = []
    for s in slope_list:
        if isinstance(s, pd.Series):
            # Ensure the name is 'slope' so to_xarray() yields a DataArray
            # called 'slope'
            s2 = s.copy()
            s2.name = 'slope'
            da = s2.to_xarray().rename({'index': 'trait'}).rename('slope')
        else:
            da = nan_slope.to_xarray().rename({'index': 'trait'}).slope
        slope_arrays.append(da)

    combined['slope'] = xr.concat(slope_arrays, dim='gridcell')

    # Size the NaN placeholder to match the real arrays rather than assuming
    # a fixed ensemble size; 500 remains the fallback.
    if n_ens is None:
        n_ens = next(
            (a.sizes['ens'] for a in (*X_tf_xr_list, *grads_xr_list)
             if isinstance(a, xr.DataArray)),
            500,
        )

    # Add X_tf_xr (input values) and grads_xr (GP slopes)
    nan_X = xr.DataArray(np.full((n_ens, len(xvar)), np.nan),
                         dims=['ens', 'trait'], coords={'trait': xvar})
    X_arrays, G_arrays = [], []
    for x, g in zip(X_tf_xr_list, grads_xr_list):
        X_arrays.append(x if isinstance(x, xr.DataArray) else nan_X)
        G_arrays.append(g if isinstance(g, xr.DataArray) else nan_X)
    combined['X'] = xr.concat(X_arrays, dim='gridcell')
    combined['grads'] = xr.concat(G_arrays, dim='gridcell')

    # Assigning these also attaches the grid cell labels as a coordinate
    combined['r2_train'] = xr.DataArray(r2_train_list, dims=['gridcell'],
                                        coords={'gridcell': gridcells})
    combined['r2_test'] = xr.DataArray(r2_test_list, dims=['gridcell'],
                                       coords={'gridcell': gridcells})

    return combined

## Input Data ##

In [8]:
# Trait (parameter) values, read from the repository's input directory
trait_data = xr.open_dataset(f"{repo_dir}/input/Traits.nc")

### PFT Level ###

In [9]:
# Trait names as stored in Traits.nc; some carry a '_Norm' suffix
xtraits_PFT = [
    'kmax_Norm',
    'leafcn_Norm',
    'medlynslope_Norm',
    'psi50_Norm',
    'slatop_Norm',
    'jmaxb0',
    'jmaxb1',
    'wc2wjb0',
]
# Primary PFT per grid cell (read by gpe_pft)
prim_pft = xr.open_dataset(f"{repo_dir}/utils/primaryPFT.nc")

Metrics_PFT = xr.open_dataset(f"{repo_dir}/input/Metrics_PFT.nc")

# Merge traits with metrics, then strip the '_Norm' suffix from every variable
inds_test_PFT = xr.merge([trait_data[xtraits_PFT], Metrics_PFT])
rename_dict = {var: var.replace("_Norm", "") for var in inds_test_PFT.data_vars}
inds_test_PFT = inds_test_PFT.rename(rename_dict)
# Keep the predictor list in step with the renamed variables
xtraits_PFT = [name.replace("_Norm", "") for name in xtraits_PFT]
# Stress is defined as the complement of B
inds_test_PFT['Stress'] = 1 - inds_test_PFT['B']
inds_test_PFT

### Grid Cell Level ###

In [10]:
# Grid-cell-level predictors: the 'CV' and 'Mean' variants of five traits,
# plus three traits used without a suffix
xtraits_GC = [
    'kmaxCV', 'leafcnCV', 'medlynslopeCV', 'psi50CV', 'slatopCV',
    'kmaxMean', 'leafcnMean', 'medlynslopeMean', 'psi50Mean', 'slatopMean',
    'jmaxb0', 'jmaxb1', 'wc2wjb0',
]
Metrics_GC = xr.open_dataset(f"{repo_dir}/input/Metrics_GC.nc")

# Merge the selected traits with the grid-cell metrics
inds_test_GC = xr.merge([trait_data[xtraits_GC], Metrics_GC])
inds_test_GC

## Run ORG-FAST Analysis ##

In [None]:
# ORG-FAST for GPP: fit one emulator per grid cell (labels 0..399)
results_list = []
for i in range(400):
    result = gpe_pft(i, inds_test_PFT, xtraits_PFT, 'GPP')
    results_list.append(result)

# Split the per-cell result dicts into one list per field
(
    r2_train_list_gp,
    r2_test_list_gp,
    importances_list_gp,
    slope_list_gp,
    XTest_list_gp,
    XSlope_list_gp,
) = (
    [r[key] for r in results_list]
    for key in ('r2_train', 'r2_test', 'importances', 'PDP_Slope', 'XTest', 'XSlope')
)

# Assemble everything into one dataset and write it out
gp_output = gppp_slope(
    r2_train_list_gp,
    r2_test_list_gp,
    slope_list_gp,
    importances_list_gp,
    XTest_list_gp,
    XSlope_list_gp,
    xtraits_PFT,
    range(400),
)
gp_output.to_netcdf(f"{repo_dir}/output/ORGFAST_GPP.nc")

In [None]:
# ORG-FAST for Stress: fit one emulator per grid cell (labels 0..399)
results_list = []
for i in range(400):
    result = gpe_pft(i, inds_test_PFT, xtraits_PFT, 'Stress')
    results_list.append(result)

# Split the per-cell result dicts into one list per field
(
    r2_train_list_gp,
    r2_test_list_gp,
    importances_list_gp,
    slope_list_gp,
    XTest_list_gp,
    XSlope_list_gp,
) = (
    [r[key] for r in results_list]
    for key in ('r2_train', 'r2_test', 'importances', 'PDP_Slope', 'XTest', 'XSlope')
)

# Assemble everything into one dataset and write it out
gp_output = gppp_slope(
    r2_train_list_gp,
    r2_test_list_gp,
    slope_list_gp,
    importances_list_gp,
    XTest_list_gp,
    XSlope_list_gp,
    xtraits_PFT,
    range(400),
)
gp_output.to_netcdf(f"{repo_dir}/output/ORGFAST_Stress.nc")

In [None]:
# ORG-FAST for DroughtSens: fit one emulator per grid cell (labels 0..399)
results_list = []
for i in range(400):
    result = gpe_pft(i, inds_test_PFT, xtraits_PFT, 'DroughtSens')
    results_list.append(result)

# Split the per-cell result dicts into one list per field
(
    r2_train_list_gp,
    r2_test_list_gp,
    importances_list_gp,
    slope_list_gp,
    XTest_list_gp,
    XSlope_list_gp,
) = (
    [r[key] for r in results_list]
    for key in ('r2_train', 'r2_test', 'importances', 'PDP_Slope', 'XTest', 'XSlope')
)

# Assemble everything into one dataset and write it out
gp_output = gppp_slope(
    r2_train_list_gp,
    r2_test_list_gp,
    slope_list_gp,
    importances_list_gp,
    XTest_list_gp,
    XSlope_list_gp,
    xtraits_PFT,
    range(400),
)
gp_output.to_netcdf(f"{repo_dir}/output/ORGFAST_DroughtSens.nc")

## Run ECO-FAST Analysis ##

In [None]:
# ECO-FAST for DroughtSens: fit one emulator per grid cell (labels 0..399)
# using the grid-cell-level predictors
results_list = []
for i in range(0,400):
    result = gpe_gridcell(i, inds_test_GC, xtraits_GC, 'DroughtSens')
    results_list.append(result)

# Convert results into lists
r2_train_list_gp = [r['r2_train'] for r in results_list]
r2_test_list_gp  = [r['r2_test'] for r in results_list]
importances_list_gp = [r['importances'] for r in results_list]
slope_list_gp = [r['PDP_Slope'] for r in results_list]
XTest_list_gp = [r['XTest'] for r in results_list]
XSlope_list_gp = [r['XSlope'] for r in results_list]

# The trait list given to gppp_slope must be the one the models were fitted
# with (xtraits_GC); otherwise the NaN placeholders for failed grid cells
# carry PFT-level trait names and add spurious traits to the output.
gp_output = gppp_slope(r2_train_list_gp, r2_test_list_gp, slope_list_gp, importances_list_gp, XTest_list_gp, XSlope_list_gp, xtraits_GC, range(0,400))
gp_output.to_netcdf(repo_dir+'/output/ECOFAST_DroughtSens.nc')