# Evaluate land model output from perturbed parameter ensemble

This script evaluates model output from a set of ensemble members in a perturbed parameter experiment. It identifies the best-performing ensemble members

## Import modules

In [344]:
import os
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import functools
import netCDF4 as nc4

## Define paths and script parameters

In [447]:
# Case name(s) for the simulation to process
cases = ['conifer-allom-082323a_-17e2acb6a_FATES-55794e61',
        'conifer-allom-082323b_-17e2acb6a_FATES-55794e61',
        'conifer-allom-082323c_-17e2acb6a_FATES-55794e61',
        'conifer-allom-082323d_-17e2acb6a_FATES-55794e61'
        ]
conifer_only_case = 'conifer-allom-082323e_-17e2acb6a_FATES-55794e61'


# Benchmarking metrics
my_metrics = ["BA","AGB","TreeStemD","ShannonE","NPP","FailedPFTs"]

# Parameters perturbed (case-specific)
perturbed_params = {'conifer-allom-082323a_-17e2acb6a_FATES-55794e61':["fates_allom_d2ca_coefficient_min","fates_mort_scalar_cstarvation"],
                    'conifer-allom-082323b_-17e2acb6a_FATES-55794e61':["fates_allom_d2ca_coefficient_min","fates_mort_scalar_cstarvation"],
                    'conifer-allom-082323c_-17e2acb6a_FATES-55794e61':["fates_allom_blca_expnt_diff"],
                    'conifer-allom-082323d_-17e2acb6a_FATES-55794e61':['fates_allom_d2h1','fates_allom_d2h2','fates_allom_dbh_maxheight']}

perturbed_params_conifer = ["fates_allom_d2ca_coefficient_min","fates_mort_scalar_cstarvation"]

# Optional
case_path = None

# Path where case output lives
case_output_root = '/glade/scratch/adamhb/archive'

# Path to ensemble params
params_root = '/glade/u/home/adamhb/ahb_params/fates_api_25/ensembles'

# Path to put any processed output
processed_output_root = '/glade/scratch/adamhb/processed_output'

# Path to where the ensemble member parameter files are stored
params_files_path = '/glade/u/home/adamhb/ahb_params/fates_api_25/ensembles/conifer_allom_082323a'

# Last n years
last_n_years = 1

# Pfts
pft_names = np.array(["pine","cedar","fir","shrub","oak"])
n_pfts = len(pft_names)
pft_colors = ['gold','darkorange','darkolivegreen','brown','springgreen']

## Variables to import

In [226]:
# Keep first two no matter what. They are needed to unravel multi-plexed dimensions
fields = ['FATES_SEED_PROD_USTORY_SZ','FATES_VEGC_AP','FATES_BURNFRAC',
          'FATES_NPLANT_PF','FATES_FIRE_INTENSITY_BURNFRAC','FATES_IGNITIONS',
          'FATES_MORTALITY_FIRE_SZPF','FATES_BASALAREA_SZPF','FATES_CANOPYCROWNAREA_APPF',
          'FATES_CROWNAREA_APPF','FATES_FUEL_AMOUNT_APFC','FATES_NPLANT_SZPF',
          'FATES_PATCHAREA_AP','FATES_CROWNAREA_PF','FATES_VEGC_ABOVEGROUND','FATES_NPP_PF']

## Functions

In [448]:
def setup_benchmarking_data_structure(metrics,parameters):
    
    metrics_out = metrics
    
    # add pft-specific vars
    pft_specific_ba_metrics = ["BA_" + pft for pft in pft_names]  
    metrics_out.extend(pft_specific_ba_metrics)
    
    # add inst tag
    metrics_out.append("inst")
    
    # add parameters for each pft
    parameter_cols = []
    new_param_col = []
    for param in perturbed_params_conifer:
        for p in pft_names:
            new_param_col = param + "_" + p
            parameter_cols.append(new_param_col)
    
    metrics_out.extend(parameter_cols)
    
    benchmarking_dict = {}
    for i in metrics_out:
        benchmarking_dict[i] = []
    return benchmarking_dict

In [449]:
# pick up with integrating this into the script below
setup_benchmarking_data_structure(my_metrics,perturbed_params_conifer)

{'BA': [],
 'AGB': [],
 'TreeStemD': [],
 'ShannonE': [],
 'NPP': [],
 'FailedPFTs': [],
 'BA_pine': [],
 'BA_cedar': [],
 'BA_fir': [],
 'BA_shrub': [],
 'BA_oak': [],
 'inst': [],
 'fates_allom_d2ca_coefficient_min_pine': [],
 'fates_allom_d2ca_coefficient_min_cedar': [],
 'fates_allom_d2ca_coefficient_min_fir': [],
 'fates_allom_d2ca_coefficient_min_shrub': [],
 'fates_allom_d2ca_coefficient_min_oak': [],
 'fates_mort_scalar_cstarvation_pine': [],
 'fates_mort_scalar_cstarvation_cedar': [],
 'fates_mort_scalar_cstarvation_fir': [],
 'fates_mort_scalar_cstarvation_shrub': [],
 'fates_mort_scalar_cstarvation_oak': []}

## Get information about a case

In [223]:
full_case_path = esm_tools.get_path_to_sim('conifer-allom-082323a_-17e2acb6a_FATES-55794e61',
                                           case_output_root)

inst_tags = esm_tools.get_unique_inst_tags(full_case_path)
n_inst = len(inst_tags)
print("ninst:",n_inst)

ninst: 36


## Test inst data

In [195]:
inst_files = esm_tools.get_files_of_inst(full_case_path,inst_tags[0],last_n_years=last_n_years)
ds = esm_tools.multiple_netcdf_to_xarray(inst_files,fields)

In [382]:
def extract_variable_from_netcdf(file_path, variable_name,pft_index):
    """
    Extract a variable from a NetCDF file.

    Parameters:
    - file_path: The path to the NetCDF file.
    - variable_name: The name of the variable to extract.

    Returns:
    - The extracted variable data.
    """
    with nc4.Dataset(file_path, 'r') as dataset:
        # Check if the variable exists in the dataset
        if variable_name in dataset.variables:
            variable_data = dataset.variables[variable_name][:]
            return variable_data.data[pft_index]
        else:
            raise ValueError(f"'{variable_name}' not found in the NetCDF file.")

In [387]:
def get_parameter_file_of_inst(params_root,param_dir,inst):
    
    '''Inputs:
    1) root directory where param perturbation params are stored
    2) subdirectory for case of interest where instance-specific parameter files are stored
    3) the instance tag (e.g. 0001)
    
    Returns: full path to parameter file'''
    
    path_to_param_files = os.path.join(params_root,param_dir)
    
    substring = "_" + inst + ".nc"

    # Get the instance files
    file = esm_tools.find_files_with_substring(path_to_param_files, substring)
    
    full_file_path = os.path.join(path_to_param_files,file[0])
    
    return full_file_path

In [395]:
import importlib
importlib.reload(esm_tools)

<module 'esm_tools' from '/glade/u/home/adamhb/Earth-System-Model-Tools/process_output/esm_tools.py'>

In [396]:
file_path = esm_tools.get_parameter_file_of_inst(params_root,'conifer_allom_082323a',"0001")
esm_tools.extract_variable_from_netcdf(file_path,"fates_allom_d2ca_coefficient_min",pft_index=1)

0.4899647576565016

### Define which parameters were perturbed

## Benchmarking function

In [327]:
def get_benchmarks(case_name,metrics,last_n_years,parameters,test = False):
      
    print("Case:",case_name)
    
    # Get info about the case
    full_case_path = esm_tools.get_path_to_sim(case_name,case_output_root)
    inst_tags = esm_tools.get_unique_inst_tags(full_case_path)
    
    if test == True:
        inst_tags = inst_tags[:3]
    
    n_inst = len(inst_tags)
    print("ninst:",n_inst)
    
    # Set up the benchmarking data structure
    bench_dict = setup_benchmarking_data_structure(metrics,parameters)
    
    for inst in inst_tags:
        
        print("Working on ensemble memeber",inst,"of",len(inst_tags),"members")
        
        # Import the model output data for one ensemble member
        inst_files = esm_tools.get_files_of_inst(full_case_path,
                                                 inst,
                                                 last_n_years)
        
        ds = esm_tools.multiple_netcdf_to_xarray(inst_files,fields)
        
        bench_dict['inst'].append(inst)
        
        ## Basal area [m2 ha-1] ##
        if "BA" in bench_dict.keys():
            
            ## Pft-specific BA
            pft_level_ba = esm_tools.get_pft_level_basal_area(ds)
            
            for i in range(len(pft_names)):
                pft_name = pft_names[i]
                bench_dict['BA_' + pft_name].append(pft_level_ba[i])
            
            ## Shannon equitability index (wrt BA) ##
            bench_dict['ShannonE'].append(esm_tools.shannon_equitability(pft_level_ba))
            
            ## Number of failed pfts ##
            bench_dict['FailedPFTs'].append(esm_tools.get_n_failed_pfts(pft_level_ba,ba_thresh=0.1))
            
            ## Total BA
            bench_dict['BA'].append(pft_level_ba.sum())
                  
        ## Stem density [N ha-1] ##
        if "TreeStemD" in bench_dict.keys():
            
            ## Total tree stem density
            bench_dict["TreeStemD"].append(esm_tools.get_total_stem_den(ds,trees_only=True))
        
        ## AGB [kg C m-2]
        if "AGB" in bench_dict.keys():
            bench_dict["AGB"].append(esm_tools.get_AGB(ds))
        
        ## Total NPP [kg C m-2]
        if "NPP" in bench_dict.keys():
            bench_dict["NPP"].append(esm_tools.get_total_npp(ds))        
            
        
    return bench_dict

In [393]:
import importlib
importlib.reload(esm_tools)

<module 'esm_tools' from '/glade/u/home/adamhb/Earth-System-Model-Tools/process_output/esm_tools.py'>

## Apply benchmarking function to one case

In [398]:
test_benchmarks_dict = get_benchmarks('conifer-allom-082323a_-17e2acb6a_FATES-55794e61',metrics = my_metrics, last_n_years=1,test=True)

Case: conifer-allom-082323a_-17e2acb6a_FATES-55794e61
ninst: 3
Working on ensemble memeber 0001 of 3 members
Working on ensemble memeber 0002 of 3 members
Working on ensemble memeber 0003 of 3 members


### Add the perturbed parameters to the dataframe

In [399]:
test_benchmarks_df = pd.DataFrame(test_benchmarks_dict)

fates_allom_d2ca_coefficient_min_pine = []

full_case_path = esm_tools.get_path_to_sim('conifer-allom-082323a_-17e2acb6a_FATES-55794e61',case_output_root)
inst_tags = esm_tools.get_unique_inst_tags(full_case_path)

for inst in inst_tags:
    file_path = esm_tools.get_parameter_file_of_inst(params_root,'conifer_allom_082323a',inst)
    fates_allom_d2ca_coefficient_min_pine.append(esm_tools.extract_variable_from_netcdf(file_path,"fates_allom_d2ca_coefficient_min",pft_index=1))

In [400]:
fates_allom_d2ca_coefficient_min_pine

[0.4899647576565016,
 0.6591918798433293,
 0.5591396913673465,
 0.44387543429055387,
 0.4669036704609957,
 0.5378857997600446,
 0.703276823991224,
 0.5063956515778991,
 0.5523904766795659,
 0.5331186765507984,
 0.445589932605435,
 0.6040510370465719,
 0.4148754940792792,
 0.6158034102018053,
 0.6923019152295269,
 0.6682868080655677,
 0.7297392477381135,
 0.5727999652803268,
 0.5140650055888949,
 0.7155783184747277,
 0.5974075654535408,
 0.7953585351024293,
 0.40052906094147095,
 0.7447749617335179,
 0.7641969451066422,
 0.6861344558044802,
 0.7885421302995264,
 0.6471831583298981,
 0.7374701578622666,
 0.4596911083415758,
 0.5813836285699507,
 0.4795436369245589,
 0.7731539405842693,
 0.4257538697656328,
 0.6375091631883201,
 0.6224476527864025]

## Apply benchmarking function to multiple cases

In [330]:
multi_case_benchmarks = []
for c in cases:
     multi_case_benchmarks.append(get_benchmarks(c,metrics = my_metrics, last_n_years=1,test=True))
multi_case_benchmarks

Case: conifer-allom-082323a_-17e2acb6a_FATES-55794e61
ninst: 3
Working on ensemble memeber 0001 of 3 members
Working on ensemble memeber 0002 of 3 members
Working on ensemble memeber 0003 of 3 members
Case: conifer-allom-082323b_-17e2acb6a_FATES-55794e61
ninst: 3
Working on ensemble memeber 0001 of 3 members
Working on ensemble memeber 0002 of 3 members
Working on ensemble memeber 0003 of 3 members
Case: conifer-allom-082323c_-17e2acb6a_FATES-55794e61
ninst: 3
Working on ensemble memeber 0001 of 3 members
Working on ensemble memeber 0002 of 3 members
Working on ensemble memeber 0003 of 3 members
Case: conifer-allom-082323d_-17e2acb6a_FATES-55794e61
ninst: 3
Working on ensemble memeber 0001 of 3 members
Working on ensemble memeber 0002 of 3 members
Working on ensemble memeber 0003 of 3 members


[{'BA': [14.730284, 17.271265, 16.227358],
  'AGB': [10.960171699523926, 13.609286308288574, 10.616921424865723],
  'TreeStemD': [1023.1469571590424, 138.85902240872383, 989.0015423297882],
  'ShannonE': [0.8762627428896109, 0.8333740374169594, 0.8441429882930233],
  'NPP': [0.9784010217686046, 0.972694234917526, 1.0309568806974312],
  'FailedPFTs': [0, 1, 0],
  'BA_pine': [4.0757084, 5.8454432, 3.5551548],
  'BA_cedar': [1.9897865, 4.028065, 3.0028014],
  'BA_fir': [3.7301471, 5.048698, 3.0405746],
  'BA_shrub': [0.26549125, 0.011507053, 0.10902222],
  'BA_oak': [4.66915, 2.3375516, 6.5198045],
  'inst': ['0001', '0002', '0003']},
 {'BA': [15.095007, 17.223886, 16.246082],
  'AGB': [11.340022087097168, 13.750171661376953, 10.465803146362305],
  'TreeStemD': [1042.1913117170334, 137.93516904115677, 766.0796493291855],
  'ShannonE': [0.921175825362171, 0.8327862266573357, 0.8441771340541355],
  'NPP': [0.9896557250499427, 0.9617487575610539, 1.0589526946205297],
  'FailedPFTs': [0, 0, 0

### Functions

In [56]:
def agefuel_to_age_by_fuel(agefuel_var, dataset):
    n_age = len(dataset.fates_levage)
    ds_out = (agefuel_var.rolling(fates_levagefuel = n_age, center=False).construct("fates_levage")
          .isel(fates_levagefuel=slice(n_age-1, None, n_age))
          .rename({'fates_levagefuel':'fates_levfuel'})
          .assign_coords({'fates_levage':dataset.fates_levage})
          .assign_coords({'fates_levfuel':np.array([1,2,3,4,5,6])}))
    return ds_out
    #ds_out.attrs['long_name'] = agefuel_var['long_name']
    #ds_out.attrs['units'] = agefuel_var['units']

def scpf_to_scls_by_pft(scpf_var, dataset):
    """function to reshape a fates multiplexed size and pft-indexed variable to one indexed by size class and pft
    first argument should be an xarray DataArray that has the FATES SCPF dimension
    second argument should be an xarray Dataset that has the FATES SCLS dimension 
    (possibly the dataset encompassing the dataarray being transformed)
    returns an Xarray DataArray with the size and pft dimensions disentangled"""
    n_scls = len(dataset.fates_levscls)
    ds_out = (scpf_var.rolling(fates_levscpf=n_scls, center=False)
            .construct("fates_levscls")
            .isel(fates_levscpf=slice(n_scls-1, None, n_scls))
            .rename({'fates_levscpf':'fates_levpft'})
            .assign_coords({'fates_levscls':dataset.fates_levscls})
            .assign_coords({'fates_levpft':dataset.fates_levpft}))
    ds_out.attrs['long_name'] = scpf_var.attrs['long_name']
    ds_out.attrs['units'] = scpf_var.attrs['units']
    return(ds_out)



def get_last_file_of_sim(sim_path):
    files = sorted(os.listdir(sim_path))
    full_files = [os.path.join(sim_path,f) for f in files]
    last_file = full_files[-1]
    return last_file



def open_fates_hist_file(file):
    ds = xr.open_dataset(file)
    ds = ds.sel(lndgrid=0)
    return ds

def get_total_stem_den(file):
    ds = open_fates_hist_file(file)
    den_total = ds.FATES_NPLANT_PF.sum(dim="fates_levpft")
    den_shrub = ds.FATES_NPLANT_PF.isel(fates_levpft = 3)
    den_trees = den_total - den_shrub
    den_trees = den_trees.values * m2_per_ha
    return den_trees

def get_total_npp(file):
    ds = open_fates_hist_file(file)
    npp_total = ds.FATES_NPP_PF.sum(dim="fates_levpft").values * s_per_yr
    return npp_total

def get_AGB(file):
    ds = open_fates_hist_file(file)
    agb_total = ds.FATES_VEGC_ABOVEGROUND.values 
    return agb_total

def get_fuel(file):
    ds = open_fates_hist_file(file)
    age_by_fuel = agefuel_to_age_by_fuel(ds.FATES_FUEL_AMOUNT_APFC,ds)
    fates_fuel_amount_by_class = age_by_fuel.sum(dim = "fates_levage") #sum over patches
    trunk = fates_fuel_amount_by_class.isel(fates_levfuel = 3)
    total = fates_fuel_amount_by_class.sum(dim = "fates_levfuel")
    burnable = total - trunk
    return burnable.values[0]

def get_pft_level_basal_area(file,dbh_min):
    ds = open_fates_hist_file(file)
    basal_area = scpf_to_scls_by_pft(ds.FATES_BASALAREA_SZPF, ds)
    basal_area = basal_area.sel(fates_levscls = slice(dbh_min,None))
    basal_area_pf = basal_area.sum(axis=2).values * m2_per_ha
    return basal_area_pf

def get_total_tree_basal_area(file,dbh_min):
    ds = open_fates_hist_file(file)
    basal_area = scpf_to_scls_by_pft(ds.FATES_BASALAREA_SZPF, ds)
    basal_area = basal_area.sel(fates_levscls = slice(dbh_min,None))
    basal_area_pf = basal_area.sum(axis=2) * m2_per_ha
    basal_area_total = basal_area_pf.sum(axis = 1)
    basal_area_shrub = basal_area_pf.isel(fates_levpft = 3)
    basal_area_tree = basal_area_total-basal_area_shrub
    return basal_area_tree.values[0]

def get_size_class_distribution(file,dbh_min):
    ds = open_fates_hist_file(file)
    basal_area = scpf_to_scls_by_pft(ds.FATES_BASALAREA_SZPF, ds)
    basal_area = basal_area.sel(fates_levscls = slice(dbh_min,None))
    basal_area_pf = basal_area.values * m2_per_ha
    return basal_area_pf

def get_shrub_crown_area(file):
    ds = open_fates_hist_file(file)
    shrub_canopy_cover = ds.FATES_CANOPYCROWNAREA_PF.isel(fates_levpft = 3).values[0]
    return shrub_canopy_cover

def get_diff_table(site,var,d):
    d_site = d.loc[d["site"].str.contains(site),:]
    d_site['delta'] = d_site[var] - d_site.iloc[-1,1]
    d_site['delta_pct'] = d_site['delta'] / d_site[var]
    return d_site



### Get pft-level basal area at end of simulation at many sites

In [None]:
n_inst = 36
inst_tags = []
for i in range(n_inst):
    inst_tags.append(str(i+1).rjust(4, '0'))

In [None]:
get_last_file_of_inst(get_path_to_sim('QUKE-ensemble-081823_-17e2acb6a_FATES-55794e61'),"0036")

In [None]:
get_last_file_of_inst("/glade/scratch/adamhb/archive/conifer-allom-082223_-17e2acb6a_FATES-55794e61/run",t)

In [None]:
#PFT-level basal area in final timestep for each pft
inst = []
pft = []
ba = []


for i,t in enumerate(inst_tags):
    print("working on inst ",t)
    tmp_ba = get_pft_level_basal_area(get_last_file_of_inst("/glade/scratch/adamhb/archive/conifer-allom-082223_-17e2acb6a_FATES-55794e61/run",t),dbh_min=0)[0,:]
    
    #add shrub canopy cover in place of basal area. Multiply by 100 to be %, then divide by two to match the secondary axis transformation in R figure.
    tmp_ba[3] = get_shrub_crown_area(get_last_file_of_inst("/glade/scratch/adamhb/archive/conifer-allom-082223_-17e2acb6a_FATES-55794e61/run",t)) * (100 / 2)
    ba.extend(tmp_ba)
    pft.extend(pft_names)
    inst.extend([t] * len(pft_names))
    
d = pd.DataFrame({'inst':inst,'pft':pft,'ba':ba,'data_type':['fates'] * len(inst)})
#d.to_csv(os.path.join(path_for_output,batch_tag + "_BA.csv"))

In [None]:
#PFT-level basal area in final timestep for each pft
inst = []
pft = []
ba = []


for i,t in enumerate(inst_tags):
    print("working on inst ",t)
    tmp_ba = get_pft_level_basal_area(get_last_file_of_inst(get_path_to_sim(sim_names[0]),t),dbh_min=0)[0,:]
    
    #add shrub canopy cover in place of basal area. Multiply by 100 to be %, then divide by two to match the secondary axis transformation in R figure.
    tmp_ba[3] = get_shrub_crown_area(get_last_file_of_inst(get_path_to_sim(sim_names[0]),t)) * (100 / 2)
    ba.extend(tmp_ba)
    pft.extend(pft_names)
    inst.extend([t] * len(pft_names))
    
d = pd.DataFrame({'inst':inst,'pft':pft,'ba':ba,'data_type':['fates'] * len(inst)})
#d.to_csv(os.path.join(path_for_output,batch_tag + "_BA.csv"))

In [None]:
d.pft.value_counts()

In [None]:
d.loc[(d["pft"] == "pine")].head(36)

In [None]:
d.loc[(d["pft"] == "cedar") & (d["ba"] > 0.1)]

### Get total BA

In [None]:
#PFT-level basal area in final timestep for each pft
site = []
pft = []
ba = []
for i,c in enumerate(sim_names):
    print("working on scenario ",c)
    tmp_ba = get_total_tree_basal_area(get_last_file_of_sim(get_path_to_sim(c)),dbh_mins[i])
    print(tmp_ba)
    ba.append(tmp_ba)
    site.append(site_names[i])
    
d = pd.DataFrame({'site':site,'ba':ba})
d.to_csv(os.path.join(path_for_output,batch_tag + "_BA_total.csv"))

#### See total BA differences

In [None]:
def get_diff_table(site,var):
    d_site = d.loc[d["site"].str.contains(site),:]
    d_site['delta'] = d_site[var] - d_site.iloc[-1,1]
    d_site['delta_pct'] = d_site['delta'] / d_site[var]
    return d_site

In [None]:
d_CZ2 = d.loc[d["site"].str.contains("CZ2"),:]
d_CZ2['delta'] = d_CZ2['ba'] - d_CZ2.iloc[-1,1]
d_CZ2['delta_pct'] = d_CZ2['delta'] / d_CZ2['ba']
d_CZ2

get_diff_table(site,var):
    d_site = d.loc[d["site"].str.contains(site),:]
    d_site['delta'] = d_site[var] - d_site.iloc[-1,1]
    d_site['delta_pct'] = d_site['delta'] / d_site[var]
    return d_site

In [None]:
#PFT-level stem density in final timestep for each pft
site = []
den = []
for i,c in enumerate(sim_names):
    print("working on scenario ",c)
    tmp_den = get_total_stem_den(get_last_file_of_sim(get_path_to_sim(c)))[0]
    print(tmp_den)
    den.append(tmp_den)
    site.append(site_names[i])
    
d = pd.DataFrame({'site':site,'den':den})
d.to_csv(os.path.join(path_for_output,batch_tag + "_stem_den.csv"))

In [None]:
get_diff_table("CZ2","den")

In [None]:
#PFT-level basal area in final timestep for each pft
site = []
den = []
for i,c in enumerate(sim_names):
    print("working on scenario ",c)
    tmp_den = get_total_npp(get_last_file_of_sim(get_path_to_sim(c)))[0]
    print(tmp_den)
    den.append(tmp_den)
    site.append(site_names[i])
    
d = pd.DataFrame({'site':site,'NPP':den})
d.to_csv(os.path.join(path_for_output,batch_tag + "_NPP.csv"))

In [None]:
get_diff_table("CZ2","NPP")

In [None]:
#PFT-level basal area in final timestep for each pft
site = []
den = []
for i,c in enumerate(sim_names):
    print("working on scenario ",c)
    tmp_den = get_AGB(get_last_file_of_sim(get_path_to_sim(c)))[0]
    print(tmp_den)
    den.append(tmp_den)
    site.append(site_names[i])
    
d = pd.DataFrame({'site':site,'AGB':den})
d.to_csv(os.path.join(path_for_output,batch_tag + "_AGB.csv"))

In [None]:
get_diff_table("CZ2","AGB")

### Fuel loads

In [None]:
get_last_file_of_sim(get_path_to_sim(site_names[1]))
ds = open_fates_hist_file(file)
age_by_fuel = agefuel_to_age_by_fuel(ds.FATES_FUEL_AMOUNT_APFC,ds)
fates_fuel_amount_by_class = age_by_fuel.sum(dim = "fates_levage") #sum over patches
trunk = fates_fuel_amount_by_class.isel(fates_levfuel = 3)
total = fates_fuel_amount_by_class.sum(dim = "fates_levfuel")
burnable = total - trunk
burnable.values[0]
#print("total", total.dims)
#print(trunk.dims)


In [None]:
#PFT-level basal area in final timestep for each pft
site = []
den = []
my_dict = {}
for i,c in enumerate(sim_names):
    print("working on scenario ",c)
    tmp_fuel = get_fuel(get_last_file_of_sim(get_path_to_sim(c)))
    print(tmp_fuel)
    my_dict[c] = tmp_fuel
    #yourdict = {k:v for k,v in zip(keys, values)}

    
#d = pd.DataFrame({'site':site,'fuel':den})
#d.to_csv(os.path.join(path_for_output,batch_tag + "_AGB.csv"))

### Get the size class distribution at the end of the simulation at many sites

In [None]:
my_dict

In [None]:
fuel_df = pd.DataFrame.from_dict(my_dict, orient='index', columns=["burnable"])
fuel_df = fuel_df.rename_axis("site").reset_index()
get_diff_table("stan","burnable",d=fuel_df)

In [None]:
pd.DataFrame.from_dict(my_dict, orient='index', columns=["site","leaf","sm_br","lg_br","trunk","grass"])

In [None]:
c = sim_names[1]
print(c)
dbh_min = 0.

file = get_last_file_of_sim(get_path_to_sim(c))

ds = open_fates_hist_file(file)
basal_area = scpf_to_scls_by_pft(ds.FATES_BASALAREA_SZPF, ds)
basal_area = basal_area.sel(fates_levscls = slice(dbh_min,None))
basal_area_pf = basal_area * m2_per_ha
basal_area_pf = basal_area_pf[0,:,:]


# Sample data
data = basal_area_pf.sum(axis = 0)

# Number of bars
num_bars = len(data.values)

# Create bar positions
bar_positions = data.fates_levscls.values

# Plot the bars
plt.bar(bar_positions, data.values)

# Customize the chart
plt.xlabel('Size Class [cm]')
plt.ylabel('BA [m2 ha-1]')
plt.title('Stanislaus')
plt.xticks(data.fates_levscls.values)

# Display the chart
plt.show()
# fig, axes = plt.subplots(ncols=ncol,nrows=nrow,figsize=(12,10))
# for size,ax in zip(range(len(basal_area_pf.fates_levscls.values)),axes.ravel()):

#          cca = xarr.isel(fates_levage = age) / xds.FATES_PATCHAREA_AP.isel(fates_levage = age)

#          for p in range(n_pfts):
#              cca.isel(fates_levpft=p).plot(x = "time",
#                       color = pft_colors[p],lw = 3,add_legend = True,
#                       label = pft_names[p], ax = ax)

#              #plt.legend()
#          ax.set_title('{} yr old patches'.format(xds.fates_levage.values[age]))
#          ax.set_ylabel(ylabel,fontsize = int(12 * 0.75))
#          ax.xaxis.set_major_formatter(DateFormatter('%Y'))
#          #ax.xaxis.set_major_locator(mdates.YearLocator(base=nbase))

#     plt.tight_layout()
#     plt.subplots_adjust(hspace=1,wspace=0.2)
#     fig.suptitle(sup_title, fontsize=12,y=0.99)


In [None]:
site_names

### Stem density distribution

In [None]:
c = sim_names[1]
print(c)
dbh_min = 0.

file = get_last_file_of_sim(get_path_to_sim(c))

ds = open_fates_hist_file(file)
basal_area = scpf_to_scls_by_pft(ds.FATES_NPLANT_SZPF, ds)
basal_area = basal_area.sel(fates_levscls = slice(dbh_min,None))
basal_area_pf = basal_area * m2_per_ha
basal_area_pf = basal_area_pf[0,:,:]


# Sample data
data = basal_area_pf.sum(axis = 0)

# Number of bars
num_bars = len(data.values)

# Create bar positions
bar_positions = data.fates_levscls.values

# Plot the bars
plt.bar(bar_positions, data.values)

# Customize the chart
plt.xlabel('Size Class [cm]')
plt.ylabel('Density [N ha-1]')
plt.title('Stanislaus')
plt.xticks(data.fates_levscls.values)
plt.show()

### Size Class Distribution

### Basal Area by Size Class and PFT

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Sample data
data = basal_area_pf

num_groups, num_bars = data.shape
bar_width = 0.35
opacity = 0.8

# Create subplots for each group
fig, axes = plt.subplots(nrows=num_groups, ncols=1, figsize=(5, num_groups * 5), sharey=True)

# Iterate over the groups and create bar plots
for group_idx, ax in enumerate(axes):
    ax.bar(np.arange(num_bars), data[group_idx, :], width=bar_width, alpha=opacity)

    # Customize each subplot
    ax.set_title(f'PFT: {pft_names[group_idx]}')
    ax.set_xticks(np.arange(num_bars))
    ax.set_xticklabels(basal_area_pf.fates_levscls.values)
    ax.set_ylabel("BA [m2 ha-1]")
    ax.set_xlabel("Size Class [cm]")

# Set common labels
fig.text(0.5, 0.04, 'Bar', ha='center', va='center')
fig.text(0.06, 0.5, 'Values', ha='center', va='center', rotation='vertical')

# Display the facetted bar plots
plt.show()