# Test a Single PFT to Ensure Our Emulation and Calibration Work

In [None]:
import os
import pandas as pd
import numpy as np
import fates_calibration.train_emulators as tr
import fates_calibration.emulation_functions as emf
from fates_calibration.FATES_calibration_constants import FATES_PFT_IDS, FATES_INDEX, IMPLAUS_TOL

import matplotlib.pyplot as plt

## Set Up
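Set the file paths, the variables to emulate, the test/train split size, and load the observations used in the rest of the notebook.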

In [None]:
# mesh file and land mask files used in the simulation
mesh_file_dir = '/glade/work/afoster/FATES_calibration/mesh_files'
land_mask_file = os.path.join(mesh_file_dir, 'dominant_pft_grid_update.nc')
mesh_file = os.path.join(mesh_file_dir, 'dominant_pft_grid_update_mesh.nc')

# post-processed ensemble
ensemble_dir = '/glade/work/afoster/FATES_calibration/history_files'
ensemble_file = os.path.join(ensemble_dir, 'fates_lh_dominant_gs1.nc')

# latin hypercube key (CSV; read again below to obtain the parameter names)
lhc_key = '/glade/work/afoster/FATES_calibration/parameter_files/fates_param_lh/fates_lh_key.csv'

# emulator directory (passed to training and to emulator loading)
emulator_dir = '/glade/u/home/afoster/FATES_Calibration/pft_output_gs1/emulators'

# where to get grid_1dlat and grid_1dlon
# NOTE(review): the directory name contains 'gso' (letter o) while the file
# name contains 'gs0' (zero) - confirm this path is as intended.
ds0_file = '/glade/work/afoster/FATES_calibration/history_files/fates_lh_dominant_gso_vcmax/ctsm60SP_fates_dominant_pft_gs0_vcmax_FATES_LH_000.nc'

# variables to emulate and test/train split
# NOTE(review): `vars` shadows the Python builtin of the same name; it is kept
# because later cells read and rebind it.
vars = ['GPP', 'EFLX_LH_TOT', 'FSH', 'EF']
n_test = 50

# observations (one table for all PFTs; filtered to a single PFT further down)
obs_file = '/glade/work/afoster/FATES_calibration/mesh_files/dominant_pft_grid_update.csv'
obs_df_all = pd.read_csv(obs_file)
# unique values of the `pft` column in the observation table
pfts = np.unique(obs_df_all.pft)

## PFT
Select the PFT to check.

In [None]:
# name of the PFT under test; used below to filter the observations and to
# index the per-PFT constants (FATES_PFT_IDS, IMPLAUS_TOL)
pft = 'broadleaf_evergreen_tropical_tree'
# identifier for this PFT; used in output directory names and when loading emulators
pft_id = FATES_PFT_IDS[pft]

In [None]:
# Per-PFT output directory for figures.
# makedirs(exist_ok=True) is a no-op when the directory already exists and also
# creates missing parent directories, replacing the racy isdir() + mkdir() pair.
fig_dir = f'/glade/u/home/afoster/FATES_Calibration/pft_output_gs1/{pft_id}_outputs'
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# Keep only the observation rows that belong to the selected PFT.
obs_df = obs_df_all.loc[obs_df_all['pft'] == pft]

# The Latin hypercube key holds one column per parameter plus an 'ensemble'
# column; dropping the latter leaves the parameter names as the column index.
lhkey_df = pd.read_csv(lhc_key).drop(columns=['ensemble'])
param_names = lhkey_df.columns

## Train and Test Emulators

### First Train

In [None]:
# Train the emulators for the selected PFT (arguments are positional; order matters).
tr.train(
    pft,
    land_mask_file,
    mesh_file,
    ensemble_file,
    vars,
    lhc_key,
    n_test,
    emulator_dir,
    fig_dir,
    ds0_file,
)

### Now Load Back in and Test

In [None]:
# load the emulators for this PFT and the variables in `vars` back from emulator_dir
emulators = emf.load_all_emulators(pft_id, emulator_dir, vars)

In [None]:
## Sensitivity Analyses
# Returns two tables (sens_df, oaat_df); fig_dir is passed as the output
# directory and plotting is switched on.
# NOTE(review): presumably figures are written under fig_dir - confirm in emf.
sens_df, oaat_df = emf.sensitivity_analysis(emulators, param_names, pft_id, fig_dir, plot_figs=True)

In [None]:
# Sample Emulator
# 10000 is passed in the sample-count position (the same slot that receives
# n_samp in calibration_wave). The next cell reads a '<var>_implausibility'
# column from the result for every variable in `vars`.
sample_df = emf.sample_emulators(emulators, param_names, 10000, obs_df, fig_dir, pft_id, plot_figs=True)

In [None]:
# Check variables that should be calibrated
# A variable's implausibility column is kept when some, but fewer than 90%, of
# the emulated samples have an implausibility below 1.0 for it.
pft_vars = []
for var in (f"{name}_implausibility" for name in vars):
    # where() blanks every row failing the threshold; dropna() then removes them
    sample_sub = sample_df.where(sample_df[var] < 1.0).dropna()

    prop_in = len(sample_sub) / len(sample_df) * 100.0
    print(f"{prop_in}% of emulated sample falls within observational tolerance for {var}.")
    if 0.0 < prop_in < 90.0:
        pft_vars.append(var)

In [None]:
# display the implausibility columns selected for calibration
pft_vars

## Calibrate

In [None]:
def choose_params(sample_df, sens_df, vars, implausibility_tol, sens_tol):
    """Pick the best parameter sets, restricted to the sensitive parameters.

    Args:
        sample_df: DataFrame of emulator samples holding a
            ``<var>_implausibility`` column for every entry of ``vars``.
            An ``implaus_sum`` column is added to it in place.
        sens_df: sensitivity table handed to ``emf.find_sensitive_parameters``.
        vars: names of the emulated variables.
        implausibility_tol: tolerance handed to ``emf.subset_sample``.
        sens_tol: tolerance handed to ``emf.find_sensitive_parameters``.

    Returns:
        A DataFrame (index reset) with the sensitive-parameter columns of the
        best parameter sets, or ``None`` when the implausibility sums span 0.5
        or less, when no sample survives the subset, or when no parameter is
        sensitive.
    """
    # subset out anything over implausibility tolerance
    implaus_vars = [f"{var}_implausibility" for var in vars]
    sample_df['implaus_sum'] = emf.calculate_implaus_sum(sample_df, implaus_vars)

    # nothing to gain from another selection if all samples score alike
    implaus_diff = np.max(sample_df.implaus_sum) - np.min(sample_df.implaus_sum)
    if implaus_diff <= 0.5:
        return None

    sample_sub = emf.subset_sample(sample_df, implaus_vars, implausibility_tol)
    if sample_sub.isnull().values.any():
        # Unexpected missing values: dump the inputs for inspection and carry
        # on (best effort). DataFrame.to_csv is the pandas writer; there is no
        # pd.write_csv, so the previous call raised AttributeError here.
        print("ERROR ERROR ERROR")
        sample_df.to_csv('sample_df.csv')
        sens_df.to_csv('sens_df.csv')

    # grab only the sensitive parameters
    sensitive_pars = emf.find_sensitive_parameters(sens_df, vars, sens_tol)

    if sample_sub.shape[0] > 0 and len(sensitive_pars) > 0:
        best_sample = emf.find_best_parameter_sets(sample_sub)
        sample_out = best_sample.loc[:, sensitive_pars]

        return sample_out.reset_index(drop=True)

    return None
    
def calibration_wave(emulators, param_names, n_samp, obs_df, pft_id, out_dir, wave,
                     implausibility_tol, sens_tol, update_vars=None, default_pars=None,
                     plot_figs=False):
    """Run one calibration wave and return the chosen parameter sets.

    The wave runs a sensitivity analysis, samples the emulators, and then
    passes both results to ``choose_params``.

    Args:
        emulators: mapping whose keys are the emulated variable names.
        param_names: parameter names forwarded to the ``emf`` helpers.
        n_samp: sample-count argument forwarded to ``emf.sample_emulators``.
        obs_df: observations forwarded to ``emf.sample_emulators``.
        pft_id: PFT identifier forwarded to the ``emf`` helpers.
        out_dir: output directory forwarded to the ``emf`` helpers.
        wave: wave index forwarded to ``emf.sensitivity_analysis``.
        implausibility_tol: forwarded to ``choose_params``.
        sens_tol: forwarded to ``choose_params``.
        update_vars: forwarded unchanged to both ``emf`` helpers.
        default_pars: forwarded unchanged to both ``emf`` helpers.
        plot_figs: forwarded unchanged to both ``emf`` helpers.

    Returns:
        Whatever ``choose_params`` returns: a DataFrame of parameter sets or
        ``None``.
    """
    # keyword arguments shared by the two emf calls
    shared_kwargs = dict(update_vars=update_vars, default_pars=default_pars,
                         plot_figs=plot_figs)

    # the second table returned by the sensitivity analysis is not needed here
    sensitivity, _ = emf.sensitivity_analysis(emulators, param_names, pft_id,
                                              out_dir, wave, **shared_kwargs)

    samples = emf.sample_emulators(emulators, param_names, n_samp, obs_df,
                                   out_dir, pft_id, **shared_kwargs)

    return choose_params(samples, sensitivity, list(emulators.keys()),
                         implausibility_tol, sens_tol)


def find_best_parameters(num_waves, emulators, param_names, n_samp, obs_df, pft_id, out_dir,
        implausibility_tol, sens_tol, default_pars=None):
    """Run up to ``num_waves`` calibration waves, accumulating chosen parameters.

    Each wave calls ``calibration_wave`` with the parameters chosen so far
    (``update_vars``). The sample returned by a wave is appended column-wise to
    the accumulated set. The search stops early as soon as a wave returns
    ``None``.

    Args:
        num_waves: maximum number of waves to run.
        emulators, param_names, n_samp, obs_df, pft_id, out_dir,
        implausibility_tol, sens_tol, default_pars: forwarded unchanged to
            ``calibration_wave``.

    Returns:
        ``(update_vars, wave)``. ``update_vars`` is the column-wise
        concatenation of every non-``None`` wave result (``None`` if there was
        none). ``wave`` is the index at which the search stopped: one past the
        wave that returned ``None`` (capped at ``num_waves - 1``), or
        ``num_waves - 1`` when every wave produced a sample. It is ``0`` when
        ``num_waves <= 0``.
    """
    update_vars = None
    wave = 0  # defined up front so num_waves <= 0 returns instead of raising

    for wave in range(num_waves):
        best_sample = calibration_wave(emulators, param_names, n_samp,
                                       obs_df, pft_id, out_dir, wave,
                                       implausibility_tol, sens_tol,
                                       update_vars=update_vars,
                                       default_pars=default_pars)
        if best_sample is None:
            # Nothing further to calibrate. Report the same wave value callers
            # have always received for an early stop.
            return update_vars, min(wave + 1, num_waves - 1)

        # Merge right after the wave so the final wave's result is kept too
        # (it used to be computed and then dropped).
        if update_vars is None:
            update_vars = best_sample
        else:
            update_vars = pd.concat([update_vars, best_sample], axis=1)

    return update_vars, wave

In [None]:
# root of the per-PFT outputs
top_dir = "/glade/u/home/afoster/FATES_Calibration/pft_output_gs1"
# per-PFT output directory (resolves to the same path as fig_dir above)
out_dir = os.path.join(top_dir, f"{pft_id}_outputs")
# NOTE(review): sample_dir is not referenced by any other visible cell
sample_dir = os.path.join(out_dir, 'samples')

In [None]:
# Table listing, per PFT, the implausibility columns to calibrate against.
pft_var_file = '/glade/u/home/afoster/FATES_Calibration/pft_vars_dompft_gs1.csv'
var_dat = pd.read_csv(pft_var_file)

# Entries of the 'vars' column for the selected PFT ...
vars_pft = var_dat.loc[var_dat['pft'] == pft, 'vars'].tolist()
# ... with the suffix stripped to recover the plain variable names.
# This rebinds `vars`, replacing the list defined in the setup cell.
vars = [name.replace('_implausibility', '') for name in vars_pft]

In [None]:
# reload the emulators, this time only for the variables selected in the previous cell
emulators = emf.load_all_emulators(pft_id, emulator_dir, vars)

In [None]:
# Run the calibration: 10 is num_waves (maximum number of waves), 100000 is
# n_samp, IMPLAUS_TOL[pft] is the implausibility tolerance for this PFT and
# 0.1 is sens_tol.
best_param_set, wave = find_best_parameters(10, emulators, param_names, 100000,
                                              obs_df, pft_id, out_dir, IMPLAUS_TOL[pft],
                                              0.1, default_pars=None)

In [None]:
# display the accumulated parameter sets returned by find_best_parameters
best_param_set

In [None]:
# display the wave index returned by find_best_parameters
wave