# Import modules

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
import functools
import cftime
import warnings
from matplotlib.dates import DateFormatter
import math
from datetime import datetime

# Ignore all warnings
warnings.filterwarnings("ignore")

def create_directory(directory_path):
    """Create a single directory and report the outcome on stdout.

    Uses ``os.mkdir`` (parents are not created). If the directory already
    exists, a message is printed instead of raising.
    """
    try:
        os.mkdir(directory_path)
    except FileExistsError:
        print(f"Directory '{directory_path}' already exists!")
    else:
        print(f"Directory '{directory_path}' created successfully!")
        
def select_from_vec(x,indices):
    """Return the entries of ``x`` at positions ``indices`` as a NumPy array."""
    values = np.array(x)
    positions = np.array(indices)
    return values[positions]

# Set paths

In [None]:
sim_root = '/glade/scratch/adamhb/archive/'
ahb_root = '/glade/u/home/adamhb/reports'
param_set = "params_070723_with_hybrid"
# Report figures for this parameter set are saved under <ahb_root>/<param_set>.
output_path = os.path.join(ahb_root,param_set)
make_report = False

if make_report:
    # Create the directory the plotting cells save into. This previously
    # referenced an undefined `report_path` and raised NameError.
    create_directory(output_path)

#choose a name for each scenario
scenarios = ["Dry MCF Pre-colonial",
             "Dry MCF 1981-2009",
             "Dry MCF 2010-2020",
             "Wet MCF Pre-colonial",
             "Wet MCF 1981-2009",
             "Wet MCF 2010-2020"]

# Per-scenario line color / style, parallel to `scenarios`
# (red for the Dry MCF scenarios, blue for the Wet MCF scenarios).
scenario_colors = ["r","r","r","b","b","b"]
scenario_line_styles = ["solid","dashed","dashed","solid","dashed","dashed"]  

#choose cases to represent the scenarios
# One case per entry of `scenarios`, same order. The 1981-2009 and 2010-2020
# scenarios of each site use the same case and differ only in `years` below.
cases = ['CZ2_HF_070723_CLM-17e2acb6a_FATES-74ade8b2',
         'CZ2-1980-2020_070723_CLM-17e2acb6a_FATES-74ade8b2',
         'CZ2-1980-2020_070723_CLM-17e2acb6a_FATES-74ade8b2',
         'stan_HF_070723_CLM-17e2acb6a_FATES-74ade8b2',
         'stan-1980-2020_070723_CLM-17e2acb6a_FATES-74ade8b2',
         'stan-1980-2020_070723_CLM-17e2acb6a_FATES-74ade8b2']

#choose the years to import for each scenario
# range() excludes its stop value, e.g. range(2010,2020) loads 2010-2019.
years = [list(range(2020, 2050)),
         list(range(1980,2010)),
         list(range(2010,2020)),
         list(range(2020, 2050)),
         list(range(1980,2010)),
         list(range(2010,2020))]

#choose climate scenarios to compare
# The *_indices arrays hold positions in `scenarios` of the named entries,
# in `scenarios` order (np.where over a membership mask).
climate_response_comparison = ["Dry MCF 1981-2009","Wet MCF 1981-2009"]
climate_response_comparison_indices = np.where([s in climate_response_comparison for s in scenarios])[0]

#choose fire scenarios to compare
fire_response_comparison = ["Dry MCF Pre-colonial","Dry MCF 1981-2009"]
fire_response_comparison_indices = np.where([s in fire_response_comparison for s in scenarios])[0]

#date for getting fraction of basal area for climate comparison
basal_area_year_climate_comparison = "1981-01-01"

#date for getting ba fraction for fire comparison. Pre-colonial value (spinup) first.
basal_area_year_fire_comparison = ["2048-01-01","1981-01-01"]




#choose observations to compare to each scenario above (choose indices of obs scenarios below)
obs_indices = [0,1,2,3,4,5]
obs_scenarios = ["Dry MCF Pre-colonial",
             "Dry MCF 1981-2009",
             "Dry MCF 2010-2020",
             "Wet MCF Pre-colonial",
             "Wet MCF 1981-2009",
             "Wet MCF 2010-2020"]

#fields to import
#keep first two no matter what
# History variables passed to `preprocess` when reading each monthly file.
fields = ['FATES_SEED_PROD_USTORY_SZ','FATES_VEGC_AP','FATES_BURNFRAC',
          'FATES_NPLANT_PF','FATES_FIRE_INTENSITY_BURNFRAC','FATES_IGNITIONS',
          'FATES_MORTALITY_FIRE_SZPF','FATES_BASALAREA_SZPF','FATES_CANOPYCROWNAREA_APPF',
          'FATES_CROWNAREA_APPF','FATES_FUEL_AMOUNT_APFC','FATES_NPLANT_SZPF',
          'FATES_PATCHAREA_AP','FATES_CROWNAREA_PF']

# Observations

In [None]:
##Observations
#Annual burned area
# One value per entry of `obs_scenarios` (same order), subset by `obs_indices`.
burnFrac_obs = select_from_vec([0.0426,0.0053,0.020328458893853,
 0.0321,0.00388307,0.02015],obs_indices)


#Uncertainty in burned area
#For pre-colonial these numbers come from range of estimates in the literature (Williams et al., 2023, Table 3)
#For the suppression era the standard deviation in the burned fraction is 0.009
#This is taken from variation in burned area seen in Williams et al 2013 Fig. 2A
# NOTE(review): the two citations above give different years (2023 vs 2013) - confirm which is intended.
# Lower / upper bounds of the observed burned fraction, parallel to burnFrac_obs.
burnFrac_uncertainty_low = select_from_vec([0.029, 0, 0.01092, 0.0143, 0, 0.01072], obs_indices)
burnFrac_uncertainty_high = select_from_vec([0.09, 0.01467, 0.02967, 0.067, 0.0132561, 0.029473118], obs_indices)
                            
                            
# Observed percentage high severity (PHS), one value per observation scenario.
PHS_obs = select_from_vec([6,25,43,
           8,30,37], obs_indices)



# Constants

In [None]:
#Define constants
# Plant functional types; names and colors are indexed by PFT position.
n_pfts = 5
pft_names = ["pine", "cedar", "fir", "shrub", "oak"]
pft_colors = ['gold', 'darkorange', 'darkolivegreen', 'brown', 'springgreen']

# Time conversions (365-day year, 30.4-day month).
s_per_day = 3600 * 24
s_per_yr = 365 * s_per_day
s_per_month = s_per_day * 30.4
months_per_yr = 12

# Area, mass and length conversions.
m2_per_ha = 1e4
m2_per_km2 = 1e6
g_per_kg = 1000
mm_per_m = 1000

# Functions

In [None]:
def preprocess(ds, fields):
    '''Trim one opened history file to the variables in ``fields``.

    Passed as the ``preprocess`` hook of ``xr.open_mfdataset``; selecting
    index 0 of ``lndgrid`` removes that dimension for the single-point runs.'''
    point_selection = {"lndgrid": 0}
    return ds[fields].sel(**point_selection)


def fix_time(ds):
    '''Rebuild the monthly time axis as month-start cftime dates.

    The new axis starts at the first calendar year found in ``ds.time`` and
    keeps the original number of time steps. ``ds`` is modified and returned.'''
    first_year = ds['time.year'][0].values
    ds['time'] = xr.cftime_range(str(first_year), periods=len(ds.time), freq='MS')
    return ds

def multiple_netcdf_to_xarray(path, case, years, fields):
    """Open the monthly h0 history files of one case as a single Dataset.

    Parameters
    ----------
    path : str
        Archive root; files are read from ``<path>/<case>/lnd/hist``.
    case : str or None
        Case name. When ``None`` nothing is read and ``None`` is returned.
    years : iterable of int
        Years to load; all 12 months of each year are opened.
    fields : list of str
        History variables to keep (applied per file through ``preprocess``).

    Returns
    -------
    xarray.Dataset or None
        Combined dataset whose time axis has been rebuilt by ``fix_time``.
    """
    if case is None:
        return None

    # One file per (year, month), e.g. "<case>.clm2.h0.1980-01.nc".
    file_names = [f"{case}.clm2.h0.{year}-{month:02d}.nc"
                  for year in years for month in range(1, 13)]
    full_paths = [os.path.join(path, case, 'lnd/hist', fname) for fname in file_names]

    # open the dataset -- this may take a bit of time
    ds = xr.open_mfdataset(full_paths, decode_times=True,
                           preprocess=functools.partial(preprocess, fields=fields))
    ds = fix_time(ds)

    print('-- your data have been read in -- ')
    return ds

def scpf_to_scls_by_pft(scpf_var, dataset):
    """Reshape a FATES multiplexed size-class x PFT variable into separate dimensions.

    Parameters
    ----------
    scpf_var : xarray.DataArray
        Variable carrying the ``fates_levscpf`` dimension.
    dataset : xarray.Dataset
        Supplies the ``fates_levscls`` and ``fates_levpft`` coordinates
        (possibly the dataset that contains ``scpf_var``).

    Returns
    -------
    xarray.DataArray
        ``fates_levscpf`` replaced by ``fates_levpft`` plus a trailing
        ``fates_levscls`` dimension. ``long_name`` / ``units`` are copied from
        ``scpf_var`` when present.
    """
    n_scls = len(dataset.fates_levscls)
    # Non-overlapping windows of n_scls consecutive entries: each kept window
    # is one PFT's run of size classes along the multiplexed axis.
    ds_out = (scpf_var.rolling(fates_levscpf=n_scls, center=False)
            .construct("fates_levscls")
            .isel(fates_levscpf=slice(n_scls-1, None, n_scls))
            .rename({'fates_levscpf':'fates_levpft'})
            .assign_coords({'fates_levscls':dataset.fates_levscls})
            .assign_coords({'fates_levpft':dataset.fates_levpft}))
    # Copy only the metadata that exists; indexing attrs directly raised
    # KeyError for variables without a long_name or units attribute.
    for key in ('long_name', 'units'):
        if key in scpf_var.attrs:
            ds_out.attrs[key] = scpf_var.attrs[key]
    return ds_out


def appf_to_ap_by_pft(appf_var, dataset):
    """Reshape a FATES multiplexed patch-age x PFT variable into separate dimensions.

    Parameters
    ----------
    appf_var : xarray.DataArray
        Variable carrying the ``fates_levagepft`` dimension.
    dataset : xarray.Dataset
        Supplies the ``fates_levage`` and ``fates_levpft`` coordinates.

    Returns
    -------
    xarray.DataArray
        ``fates_levagepft`` replaced by ``fates_levpft`` plus a trailing
        ``fates_levage`` dimension. Attributes are not copied explicitly.
    """
    n_age_bins = len(dataset.fates_levage)
    # Windows of n_age_bins consecutive entries; keeping every n_age_bins-th
    # window (starting at the first full one) yields one block per PFT.
    windowed = appf_var.rolling(fates_levagepft=n_age_bins, center=False).construct("fates_levage")
    blocks = windowed.isel(fates_levagepft=slice(n_age_bins - 1, None, n_age_bins))
    blocks = blocks.rename({'fates_levagepft': 'fates_levpft'})
    blocks = blocks.assign_coords({'fates_levage': dataset.fates_levage})
    return blocks.assign_coords({'fates_levpft': dataset.fates_levpft})


def getNBase(xarr, ticks=None):
    """Return the spacing, in years, between major x-axis ticks.

    Parameters
    ----------
    xarr : object with a ``time`` coordinate accepted by ``pd.to_datetime``
    ticks : int, optional
        Approximate number of ticks wanted. When omitted, the module-level
        ``n_ticks`` is used (it must then be defined), which preserves the
        original behavior.

    Returns
    -------
    int
        ``number_of_distinct_years // ticks``, but never less than 1.
    """
    if ticks is None:
        ticks = n_ticks  # legacy implicit notebook global
    nyears = len(np.unique(pd.to_datetime(xarr.time).year))
    return max(nyears // ticks, 1)


def plot_appf(xarr, xds, n_pfts, sup_title, ylabel, output_path):
    """Plot a patch-age x PFT variable, one panel per patch-age bin.

    In each panel, every PFT's series is divided by that age bin's patch area
    (``FATES_PATCHAREA_AP``) and drawn against time.

    Parameters
    ----------
    xarr : xarray.DataArray
        Multiplexed variable with a ``fates_levagepft`` dimension and ``time``.
    xds : xarray.Dataset
        Provides ``fates_levage``, ``fates_levpft`` and ``FATES_PATCHAREA_AP``.
    n_pfts : int
        Number of PFTs to draw (indexes ``pft_colors`` / ``pft_names``).
    sup_title, ylabel : str
        Figure title and per-panel y-axis label.
    output_path : str
        Currently unused; kept for interface compatibility.
    """
    xarr = appf_to_ap_by_pft(xarr, xds)

    n_age = len(xds.fates_levage)
    ncol, nrow = get_n_subplots(n_age)

    # squeeze=False always returns a 2-D array of Axes, so .ravel() also works
    # for a single age bin (1x1 grid), where a bare Axes has no .ravel().
    fig, axes = plt.subplots(ncols=ncol, nrows=nrow, figsize=(12, 10), squeeze=False)

    for age, ax in zip(range(n_age), axes.ravel()):
        cca = xarr.isel(fates_levage=age) / xds.FATES_PATCHAREA_AP.isel(fates_levage=age)

        for p in range(n_pfts):
            cca.isel(fates_levpft=p).plot(x="time", color=pft_colors[p], lw=3,
                                          add_legend=True, label=pft_names[p], ax=ax)

        ax.set_title('{} yr old patches'.format(xds.fates_levage.values[age]))
        ax.set_ylabel(ylabel, fontsize=int(12 * 0.75))
        ax.xaxis.set_major_formatter(DateFormatter('%Y'))

    plt.tight_layout()
    plt.subplots_adjust(hspace=1, wspace=0.2)
    fig.suptitle(sup_title, fontsize=12, y=0.99)


def make_output_path(base,case):
    """Return ``os.path.join(base, case)``, creating it (and parents) if missing.

    Prints whether the directory was newly created or already existed.
    """
    directory = os.path.join(base, case)
    if os.path.exists(directory):
        print("Directory already exists:", directory)
        return directory
    os.makedirs(directory)
    print("Directory created successfully:", directory)
    return directory


def get_ignition_success(ds, ignition_density):
    """Ratio of modeled successful ignitions to a prescribed ignition density.

    The time-mean of ``FATES_IGNITIONS`` is scaled by ``s_per_yr * m2_per_km2``
    (presumably per-second per-m2 to per-year per-km2; confirm the history
    units) and divided by ``ignition_density``. Rounded to 3 decimals.
    """
    mean_rate = ds.FATES_IGNITIONS.values.mean()
    successful_per_km2_yr = mean_rate * s_per_yr * m2_per_km2
    return np.round(successful_per_km2_yr / ignition_density, 3)


def get_mean_annual_burn_frac(ds, start_date=None, end_date=None):
    """Mean annual burned fraction between ``start_date`` and ``end_date``.

    ``FATES_BURNFRAC`` is averaged over the selected time steps and multiplied
    by ``s_per_yr`` (i.e. treated as a per-second rate). A ``None`` bound
    leaves that side of the window open. The defaults therefore average the
    whole record, which is how ``plot_mean_annual_burn_frac`` and
    ``write_fire_report`` call it (with ``ds`` only).

    Returns the value rounded to 3 decimals.
    """
    selected = ds.FATES_BURNFRAC.sel(time=slice(start_date, end_date))
    burnfrac = selected.values.mean() * s_per_yr
    return np.round(burnfrac, 3)

def plot_mean_annual_burn_frac(ds,case):
    """Save a histogram of per-calendar-year mean burned fractions for ``case``.

    The title reports the whole-record mean annual burned fraction. The figure
    is written to ``<output_path>/<case>_<title with spaces as '-'>.png``,
    using the module-level ``output_path``, and the current figure is then
    cleared.
    """
    burnfrac = ds.FATES_BURNFRAC  * s_per_yr
    # get_mean_annual_burn_frac takes a time window; (None, None) means the
    # whole record. Calling it with only ``ds`` used to raise TypeError.
    total_mean_annual_burnfrac = get_mean_annual_burn_frac(ds, None, None)

    annual_mean_burnfrac = burnfrac.groupby('time.year').mean(dim='time').values
    title = f"Mean annual burn fraction: {np.round(total_mean_annual_burnfrac,3)}"
    # Create a histogram of the distribution of annual means
    plt.hist(annual_mean_burnfrac, bins=20, edgecolor='black')
    plt.xlabel('Annual burn fraction')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.savefig(output_path + "/" + case + "_" + title.replace(" ","-") + ".png")
    plt.clf()

    
def get_awfi(ds):
    """Area-weighted fire line intensity time series (plotted as kW m-1 elsewhere).

    ``FATES_FIRE_INTENSITY_BURNFRAC`` is divided by the daily burned fraction
    (``FATES_BURNFRAC * s_per_day``) and then by 1000. Months with no burned
    area presumably yield NaN/inf from the division.
    """
    burned_per_day = ds.FATES_BURNFRAC * s_per_day
    return ds.FATES_FIRE_INTENSITY_BURNFRAC / burned_per_day / 1000

def plot_area_weighted_fire_intensity(ds,case):
    """Plot the ``get_awfi`` series and save it as ``<output_path>/<case>_Fire-Intensity.png``.

    Uses the module-level ``output_path`` and clears the current figure afterwards.
    """
    title = "Fire Intensity"
    get_awfi(ds).plot(marker="o", linewidth=0.5)
    plt.ylabel("Fire line intensity [kW m-1]")
    plt.title(title)
    figure_file = output_path + "/" + case + "_" + title.replace(" ", "-") + ".png"
    plt.savefig(figure_file)
    plt.clf()

def get_n_fire_months(ds):
    """Count the time steps with positive area-weighted fire intensity.

    NaN intensities (presumably months with no burned area, where
    ``get_awfi`` divides by zero) are set to 0 and so are not counted.
    """
    aw_fi = get_awfi(ds)
    aw_fi = aw_fi.where(~np.isnan(aw_fi), 0)
    # The previous version also computed len(aw_fi.values) into an unused local.
    return np.sum((aw_fi > 0).values)

def get_PHS_FLI_thresh(ds,FLI_thresh):
    """Percent of fire months whose area-weighted intensity exceeds ``FLI_thresh``.

    The numerator counts time steps with ``get_awfi(ds) > FLI_thresh``. The
    denominator is ``get_n_fire_months(ds)``, so with no fire months the
    result is NaN/inf rather than an exception.
    """
    intensity = get_awfi(ds)
    n_above = np.sum((intensity > FLI_thresh).values)
    return n_above / get_n_fire_months(ds) * 100

    
def _fire_mort_rate_on_burned_area(ds):
    """Per-capita monthly fire mortality per unit burned area, by PFT.

    ``FATES_MORTALITY_FIRE_SZPF`` is de-multiplexed, summed over size classes,
    and divided by ``FATES_NPLANT_PF``. It is then divided by ``months_per_yr``
    (presumably per-year to per-month) and by the monthly burned fraction, so
    the rate refers only to the area that burned. Result dims are assumed to
    be (time, fates_levpft) - TODO confirm.
    """
    #disentangle the multiplexed size class X pft dimension
    mort_fire_by_pft_and_scls = scpf_to_scls_by_pft(ds.FATES_MORTALITY_FIRE_SZPF, ds)
    #get the monthly burned area to calculate mortality rates just on the burned area
    monthly_burnfrac = ds.FATES_BURNFRAC * s_per_month
    #sum across size classes to get pft-level mort from fire
    mort_fire_by_pft = mort_fire_by_pft_and_scls.sum(axis=2)
    return mort_fire_by_pft / ds.FATES_NPLANT_PF / months_per_yr / monthly_burnfrac


def _percent_high_severity(ds, start_date, end_date, pft_slice):
    """Percent of fire months whose PFT-averaged fire mortality rate exceeds 0.95.

    ``pft_slice`` is a label slice on ``fates_levpft`` restricting the PFTs in
    the average, or None to use all PFTs. Returns ``round(fraction, 3) * 100``;
    with no fire months the result is NaN/inf.
    """
    ds = ds.sel(time=slice(start_date, end_date))
    n_fire_months = get_n_fire_months(ds)
    mort_rate = _fire_mort_rate_on_burned_area(ds)
    if pft_slice is not None:
        mort_rate = mort_rate.sel(fates_levpft=pft_slice)
    n_high_severity = np.sum((mort_rate.mean(axis=1) > 0.95).values)
    return np.round(n_high_severity / n_fire_months, 3) * 100


def get_PHS(ds, start_date=None, end_date=None):
    """Percentage high severity (PHS) over all PFTs.

    A fire month is high severity when the per-capita fire mortality on the
    burned area, averaged across PFTs, exceeds 0.95. A ``None`` date leaves
    that side of the time window open (the defaults use the whole record).
    """
    return _percent_high_severity(ds, start_date, end_date, None)


def get_PHS_conifer(ds, start_date=None, end_date=None):
    """Percentage high severity using only the conifer PFTs.

    Same as ``get_PHS`` but averages over the ``fates_levpft`` labels 1-3
    (the selection the original code called "conifer").
    """
    return _percent_high_severity(ds, start_date, end_date, slice(1, 3))


def get_frac_pft_level_basal_area(ds,pft_i,date,dbh_min = 0):
    """Fraction of total basal area held by PFT position ``pft_i`` on ``date``.

    Only size classes whose ``fates_levscls`` label is >= ``dbh_min``
    (label-based selection) contribute to both numerator and denominator.
    Returns the ratio as a NumPy value.
    """
    ba_by_size_pft = scpf_to_scls_by_pft(ds.FATES_BASALAREA_SZPF, ds)
    ba_on_date = ba_by_size_pft.sel(fates_levscls=slice(dbh_min, None)).sel(time=date)
    stand_total = ba_on_date.sum(axis=1).sum(axis=0)
    pft_total = ba_on_date.isel(fates_levpft=pft_i).sum(axis=0)
    return pft_total.values / stand_total.values


def write_fire_report(ds,ignition_density,output_path,case,FLI_thresh=3500):
    """Write a plain-text fire-regime summary to ``<output_path>/fire_report.txt``.

    Reports the mean annual burned fraction, the mean fire return interval
    (1 / burned fraction), percentage high severity by mortality
    (``get_PHS``) and by fire-line intensity (``get_PHS_FLI_thresh``), and the
    ignition success rate, all computed over the full record of ``ds``.

    Parameters
    ----------
    ds : xarray.Dataset
    ignition_density : float
        Prescribed ignition density passed to ``get_ignition_success``.
    output_path : str
        Directory for the report file.
    case : str
        Case name written at the top of the report.
    FLI_thresh : float, optional
        Fire-line-intensity threshold (kW m-1) for the intensity-based PHS.
        This was previously read from an undefined global; the default of
        3500 is the value named in the report label.
    """
    mean_burn_frac = get_mean_annual_burn_frac(ds, None, None)

    # Print straight into the file instead of swapping sys.stdout. The old
    # code used ``sys`` without importing it and would have left stdout
    # redirected if any metric raised.
    with open(os.path.join(output_path, 'fire_report.txt'), 'w') as f:
        print("case:", case, file=f)
        print("Mean annual burn frac:", mean_burn_frac, file=f)
        print("Mean FRI:", 1 / mean_burn_frac, file=f)
        print('PHS (> 95% mort):', get_PHS(ds, None, None), file=f)
        print(f'PHS (> {FLI_thresh} kW m-1):', get_PHS_FLI_thresh(ds, FLI_thresh), file=f)
        print('Ignition success:', get_ignition_success(ds, ignition_density), file=f)
        
        
def plus_minus_20_pct(number):
    """Return ``[number * 1.2, number * 0.8]`` (the +20% value first)."""
    return [number * scale for scale in (1.2, 0.8)]



def check_fire_regime(ds):
    """Print diagnostic fire-regime statistics for ``ds`` to stdout.

    Prints the mean annual burned fraction and fire return interval, the
    number of months and of fire months, and how many fire months exceed a
    fire-line intensity of ``FLI_threshold``. It also prints the fraction of
    fire months in which the mean per-capita fire mortality on the burned area
    exceeds 0.95, once for the conifer PFT labels 1-3 and once for all PFTs.
    Returns None.
    """

    # Fire-line-intensity threshold, in the kW m-1 units of aw_fi below.
    FLI_threshold = 3500

    burnfrac = ds.FATES_BURNFRAC  * s_per_yr
    print("Mean annual burn fraction",burnfrac.values.mean())
    fri = 1 / (ds.FATES_BURNFRAC.values.mean() * s_per_yr)
    print("Mean fire return interval (yrs):",fri)

    # Same expression as get_awfi(ds).
    aw_fi = ds.FATES_FIRE_INTENSITY_BURNFRAC / (ds.FATES_BURNFRAC * s_per_day) / 1000
    n_months = len(aw_fi.values)
    print("n months",n_months)
    # Treat NaN intensities (presumably months with no burned area) as zero.
    aw_fi = aw_fi.where(~np.isnan(aw_fi), 0)
    n_fire_months_boolean = aw_fi > 0
    n_fire_months = np.sum(n_fire_months_boolean.values)
    print("n fire months", n_fire_months)
    n_months_greater_than_thresh_boolean = aw_fi > FLI_threshold
    n_months_greater_than_thresh = np.sum(n_months_greater_than_thresh_boolean.values)
    print("n months > threshold",n_months_greater_than_thresh)

    # NOTE(review): with no fire months these ratios are NaN/inf (warnings are globally ignored).
    print("Fraction of fire months that burned hotter than X",n_months_greater_than_thresh / n_fire_months)

    #disentangle the multiplexed size class X pft dimension
    mort_fire_by_pft_and_scls = scpf_to_scls_by_pft(ds.FATES_MORTALITY_FIRE_SZPF, ds)

    #get the monthly burned area to calculate mortality rates just on the burned area
    monthly_burnfrac = ds.FATES_BURNFRAC  * s_per_month

    #sum across size classes to get pft-level mort from fire
    mort_fire_by_pft = mort_fire_by_pft_and_scls.sum(axis=2)

    #per capita mort per month per unit area that burned
    mort_fire_per_capita_per_month_per_burned_area = mort_fire_by_pft / ds.FATES_NPLANT_PF / months_per_yr / monthly_burnfrac
    # Conifer subset: fates_levpft labels 1-3 (label slice, inclusive), averaged over PFTs.
    greater_than_95_mort_bool = mort_fire_per_capita_per_month_per_burned_area.sel(fates_levpft = slice(1,3)).mean(axis = 1) > 0.95
    greater_than_95_mort = np.sum(greater_than_95_mort_bool.values)
    print("conifer mort greater_than_95_mort",greater_than_95_mort)

    greater_than_95_mort_bool_all_pfts = mort_fire_per_capita_per_month_per_burned_area.mean(axis = 1) > 0.95
    greater_than_95_mort_all_pfts = np.sum(greater_than_95_mort_bool_all_pfts.values)


    frac_greater_than_95_mort = greater_than_95_mort / n_fire_months
    print("frac_greater_than_95_mort",frac_greater_than_95_mort)

    frac_greater_than_95_mort_all_pfts = greater_than_95_mort_all_pfts / n_fire_months
    print("frac_greater_than_95_mort_all_pfts",frac_greater_than_95_mort_all_pfts)
    
def plot_ba(ds,ax,title):
    """Plot per-PFT basal area (summed over size classes) through time on ``ax``.

    ``FATES_BASALAREA_SZPF`` is de-multiplexed, summed over size classes and
    multiplied by ``m2_per_ha``. This presumably converts m2 m-2 to m2 ha-1;
    confirm the history-file units.
    """
    #disentangle the multiplexed size class X pft dimension
    basal_area = scpf_to_scls_by_pft(ds.FATES_BASALAREA_SZPF, ds)

    #sum across size classes to get pft-level ba
    basal_area_pf = basal_area.sum(axis=2)

    #plot pft-level basal area over time
    for p in range(n_pfts):
        ba_per_ha = basal_area_pf.isel(fates_levpft=p) * m2_per_ha
        ba_per_ha.plot(x="time", color=pft_colors[p], lw=5, add_legend=True, marker="o", ax=ax)

    ax.set_title(title)
    # Basal area is an area per unit ground area: m2 ha-1 (previously mislabeled "m-2 ha-1").
    ax.set_ylabel('BA [m2 ha-1]', fontsize=12)
        
        
    
    
def getFullFilePaths(case,start_year,end_year,root=None):
    """Build the monthly h0 history-file paths of ``case`` for years in [start_year, end_year).

    Parameters
    ----------
    case : str
    start_year, end_year : int or str
        Converted with ``int``; ``end_year`` is exclusive.
    root : str, optional
        Archive root. Defaults to the module-level ``archive_path``, which must
        then be defined. It is not defined in the notebook cells visible here
        (``sim_root`` is the analogous setting).

    Returns
    -------
    list of str
        ``<root>/<case>/lnd/hist/<case>.clm2.h0.<YYYY>-<MM>.nc``, year-major.
    """
    if root is None:
        root = archive_path  # legacy implicit global
    return [os.path.join(root, case, 'lnd/hist', f"{case}.clm2.h0.{year}-{month:02d}.nc")
            for year in range(int(start_year), int(end_year))
            for month in range(1, 13)]


def preprocess(ds, fields):
    '''Per-file hook for ``xr.open_mfdataset``: keep ``fields`` and drop ``lndgrid``.

    The simulations are single-point, so ``lndgrid`` is removed by selecting
    its index 0.'''
    wanted = ds[fields]
    single_point = wanted.sel(lndgrid=0)
    return single_point


def fix_time(ds):
    '''Replace the time axis with month-start pandas timestamps.

    The new axis starts at the beginning of the first calendar year present in
    ``ds.time`` and keeps the original number of steps. ``ds`` is modified and
    returned.'''
    start_year = ds['time.year'][0].values
    n_steps = len(ds.time)
    ds['time'] = pd.date_range(start=str(start_year), periods=n_steps, freq="MS")
    return ds


def scpf_to_scls_by_pft(scpf_var, dataset):
    """Split FATES' multiplexed size-class x PFT dimension into two dimensions.

    Parameters
    ----------
    scpf_var : xarray.DataArray
        Variable carrying the ``fates_levscpf`` dimension.
    dataset : xarray.Dataset
        Supplies the ``fates_levscls`` and ``fates_levpft`` coordinates (often
        the dataset that ``scpf_var`` came from).

    Returns
    -------
    xarray.DataArray
        ``fates_levscpf`` replaced by ``fates_levpft`` plus a trailing
        ``fates_levscls`` dimension. Attributes are not copied explicitly.
    """
    n_size = len(dataset.fates_levscls)
    # Non-overlapping windows of n_size consecutive entries; each kept window
    # holds one PFT's size classes (size class varies fastest in the multiplexed axis).
    windows = scpf_var.rolling(fates_levscpf=n_size, center=False).construct("fates_levscls")
    per_pft = windows.isel(fates_levscpf=slice(n_size - 1, None, n_size))
    per_pft = per_pft.rename({'fates_levscpf': 'fates_levpft'})
    per_pft = per_pft.assign_coords({'fates_levscls': dataset.fates_levscls})
    return per_pft.assign_coords({'fates_levpft': dataset.fates_levpft})

def agefuel_to_age_by_fuel(agefuel_var, dataset):
    """Split a multiplexed patch-age x fuel-class variable into two dimensions.

    ``fates_levagefuel`` becomes ``fates_levfuel`` plus a trailing
    ``fates_levage`` dimension. Fuel classes are labelled 1-6 (hard-coded).
    """
    n_age = len(dataset.fates_levage)
    windows = agefuel_var.rolling(fates_levagefuel=n_age, center=False).construct("fates_levage")
    by_fuel = windows.isel(fates_levagefuel=slice(n_age - 1, None, n_age))
    by_fuel = by_fuel.rename({'fates_levagefuel': 'fates_levfuel'})
    by_fuel = by_fuel.assign_coords({'fates_levage': dataset.fates_levage})
    by_fuel = by_fuel.assign_coords({'fates_levfuel': np.array([1, 2, 3, 4, 5, 6])})
    return by_fuel


def appf_to_ap_by_pft(appf_var, dataset):
    """Split FATES' multiplexed patch-age x PFT dimension into two dimensions.

    ``fates_levagepft`` is replaced by ``fates_levpft`` plus a trailing
    ``fates_levage`` dimension, with coordinates taken from ``dataset``.
    Attributes are not copied explicitly.
    """
    age_coord = dataset.fates_levage
    n_ap = len(age_coord)
    # Keep the last position of every full, non-overlapping window of n_ap
    # entries: one window per PFT, with age varying fastest inside it.
    stride = slice(n_ap - 1, None, n_ap)
    reshaped = (
        appf_var
        .rolling(fates_levagepft=n_ap, center=False)
        .construct("fates_levage")
        .isel(fates_levagepft=stride)
        .rename({'fates_levagepft': 'fates_levpft'})
    )
    reshaped = reshaped.assign_coords({'fates_levage': age_coord})
    return reshaped.assign_coords({'fates_levpft': dataset.fates_levpft})

def get_n_subplots(n_pfts): 
    """Return a ``(ncol, nrow)`` grid large enough for ``n_pfts`` panels.

    One panel gets a 1x1 grid. Otherwise two columns are used, and an odd
    panel count is rounded up to the next even number.
    """
    if n_pfts == 1:
        return (1, 1)
    n_panels = n_pfts if n_pfts % 2 == 0 else n_pfts + 1
    return (2, int(n_panels / 2))

def per_capita_rate(xarr,xds,unit_conversion):
    """Scale ``xarr`` by ``unit_conversion`` and divide by ``FATES_NPLANT_PF``.

    A (time, fates_levscpf) input is first de-multiplexed and summed over size
    classes, so the result is per PFT.
    """
    scaled = xarr * unit_conversion
    if scaled.dims != ('time', 'fates_levscpf'):
        return scaled / xds.FATES_NPLANT_PF
    per_pft = scpf_to_scls_by_pft(scaled, xds).sum(axis=2)  # collapse size classes
    return per_pft / xds.FATES_NPLANT_PF


def get_rate_table(xarr,xds,var_title,indices,index_title):
    """Format the time-mean of ``xarr`` as a psql-style text table via ``tabulate``.

    Behavior depends on ``xarr.dims``:
    - ('time', 'fates_levagepft'): de-multiplexed to PFT x age, divided by
      ``FATES_PATCHAREA_AP``, and returned as a table with rows labelled by
      ``indices``, one column per age bin, and a "Total" row.
    - ('time', 'levgrnd'): the ``indices`` soil layers are scaled by
      ``MPa_per_mmh2o`` and returned sorted ascending, labelled by depth.
    - anything else: returned sorted descending, with rows labelled by ``indices``.

    NOTE(review): ``tabulate`` and ``MPa_per_mmh2o`` are not imported or
    defined in this file's visible cells. Confirm they are provided elsewhere;
    otherwise these paths raise NameError.
    """
    
    if xarr.dims == ('time', 'fates_levage'):
        # NOTE(review): `series` built here is overwritten by the final else
        # branch below (this is `if`, not `elif`), so only the time slicing
        # takes effect. slice(-12,-1) also drops the last time step.
        xarr = xarr.isel(time = slice(-12,-1))
        series = pd.DataFrame(xarr.mean(axis = 0).values,
                     index=xds.fates_levage.values)

    if xarr.dims == ('time', 'fates_levagepft'):
        xarr = appf_to_ap_by_pft(xarr,xds)
        xarr = xarr / xds.FATES_PATCHAREA_AP
        series = pd.DataFrame(xarr.mean(axis = 0).values,
                     index = indices, columns=xds.fates_levage.values)
        series.loc["Total"] = series.sum()
        tab = tabulate(series, headers="keys", tablefmt="psql")
        return(tab)

    if xarr.dims == ('time', 'levgrnd'):
        grnd_depths = xds.levgrnd.values[indices]
        # NOTE(review): slice(12,-1) skips the first 12 steps and also the last one.
        xarr = xarr.isel(levgrnd = indices).isel(time = slice(12,-1)) * MPa_per_mmh2o
        series = pd.DataFrame(xarr.mean(axis = 0).values,
                          index = grnd_depths).sort_values(by = 0, ascending=True).reset_index()
    else:
        series = pd.DataFrame(xarr.mean(axis = 0).values,
                          index = indices).sort_values(by = 0, ascending=False).reset_index()
    
    # After reset_index, column 0 holds the labels and column 1 the time-mean values.
    my_dict = {index_title:list(series.iloc[:,0]), var_title:list(series.iloc[:,1])}
    my_df = pd.DataFrame.from_dict(my_dict).set_index(index_title)
    tab = tabulate(my_df, headers='keys', tablefmt='psql')
    return(tab)

def weighted_avg_par(par_stream,frac_in_canopy):
    """Blend two leaf layers of a PAR profile by canopy fraction.

    Returns ``layer0 * frac_in_canopy + layer30 * (1 - frac_in_canopy)``,
    where the layers are ``fates_levcnlf`` indices 0 and 30. These are
    presumably the canopy top and a deep/understory layer; confirm.
    """
    top_layer = par_stream.isel(fates_levcnlf=0)
    deep_layer = par_stream.isel(fates_levcnlf=30)
    return top_layer * frac_in_canopy + deep_layer * (1 - frac_in_canopy)


def frac_in_canopy(xds):
    """Canopy crown area divided by total crown area, per PFT.

    NOTE(review): ``FATES_CANOPYCROWNAREA_PF`` is not in ``fields``, so the
    default loaded datasets do not contain it.
    """
    canopy_area = xds.FATES_CANOPYCROWNAREA_PF
    total_area = xds.FATES_CROWNAREA_PF
    return canopy_area / total_area

def incident_par(xds):
    """Canopy-fraction-weighted PAR (direct + diffuse) with a 12-step centered rolling mean.

    Each of the direct and diffuse PAR profiles is blended with
    ``weighted_avg_par`` using ``frac_in_canopy(xds)``, and the two are summed.
    NOTE(review): the ``FATES_PARPROF_*_CLLL`` variables are not in ``fields``.
    """
    canopy_frac = frac_in_canopy(xds)
    direct = weighted_avg_par(xds.FATES_PARPROF_DIR_CLLL, canopy_frac)
    diffuse = weighted_avg_par(xds.FATES_PARPROF_DIF_CLLL, canopy_frac)
    combined = direct + diffuse
    return combined.rolling(time=12, center=True).mean()

def cca_by_patch_age(ds,ax,title):
    """Plot crown area per unit patch area against patch-age bin, one line per PFT.

    ``FATES_CROWNAREA_APPF`` is de-multiplexed and divided by
    ``FATES_PATCHAREA_AP``, then averaged over axis 0 (assumed to be time).
    """
    crown_area = appf_to_ap_by_pft(ds.FATES_CROWNAREA_APPF, ds) / ds.FATES_PATCHAREA_AP

    for p in range(n_pfts):
        profile = crown_area.isel(fates_levpft=p).mean(axis=0)
        profile.plot(x="fates_levage", color=pft_colors[p], linewidth=3, ax=ax, marker="o")

    ax.set_title(title)
    ax.set_xlabel("Patch age bin (yrs)")
    ax.set_ylabel("Total crown area [m2 m-2]")

    
def is_xarray_dataset(obj):
    """Return True when ``obj`` is an ``xarray.Dataset`` (or subclass)."""
    result = isinstance(obj, xr.Dataset)
    return result

def filter_data(ds,start,stop):
    """Return ``ds`` restricted to ``time`` in [start, stop], or None if ``ds`` is not a Dataset."""
    if not is_xarray_dataset(ds):
        return None
    return ds.sel(time=slice(start, stop))
    
    
def agefuel_to_age_by_fuel(agefuel_var, dataset): 
    """Split a multiplexed patch-age x fuel-class variable into two dimensions.

    Parameters
    ----------
    agefuel_var : xarray.DataArray
        Variable carrying the ``fates_levagefuel`` dimension.
    dataset : xarray.Dataset
        Supplies the ``fates_levage`` coordinate.

    Returns
    -------
    xarray.DataArray
        ``fates_levagefuel`` replaced by ``fates_levfuel`` (labelled 1..n_fuel)
        plus a trailing ``fates_levage`` dimension. ``n_fuel`` is taken from
        the data instead of being hard-coded to 6. A different number of fuel
        classes previously made ``assign_coords`` fail on a size mismatch.
    """
    n_age = len(dataset.fates_levage)
    ds_out = (agefuel_var.rolling(fates_levagefuel=n_age, center=False)
              .construct("fates_levage")
              .isel(fates_levagefuel=slice(n_age - 1, None, n_age))
              .rename({'fates_levagefuel': 'fates_levfuel'})
              .assign_coords({'fates_levage': dataset.fates_levage}))
    n_fuel = ds_out.sizes['fates_levfuel']
    return ds_out.assign_coords({'fates_levfuel': np.arange(1, n_fuel + 1)})

def get_area_weighed_FLI(ds):    
    """Area-weighted fire line intensity; the same expression as ``get_awfi``."""
    daily_burned = ds.FATES_BURNFRAC * s_per_day
    intensity = ds.FATES_FIRE_INTENSITY_BURNFRAC / daily_burned
    return intensity / 1000


def get_per_capita_fire_mort_by_scls(ds):
    """Fire mortality divided by plant number density, per size class and PFT."""
    nplant = scpf_to_scls_by_pft(ds.FATES_NPLANT_SZPF, ds)
    fire_mort = scpf_to_scls_by_pft(ds.FATES_MORTALITY_FIRE_SZPF, ds)
    return fire_mort / nplant

# Load data

In [None]:
fates_data = []

# Load one Dataset per scenario; entries are parallel to `scenarios`, `cases` and `years`.
# NOTE: the loop variables `i` and `c` stay defined after this cell and later
# cells may rely on them.
for i,c in enumerate(cases):
    fates_data.append(multiple_netcdf_to_xarray(sim_root, c, years[i], fields))

# Basal Area

In [None]:
# One panel per scenario: basal area by PFT through time.
fig, axes = plt.subplots(ncols=2,nrows=math.ceil(len(scenarios)/2),figsize=(12,10), sharey=True)

# NOTE: `s` and `ax` stay defined after this loop; the following cells reuse `s`.
for s,ax in zip(range(len(scenarios)),axes.ravel()):
    print(scenarios[s])
    plot_ba(fates_data[s], ax, scenarios[s])
    
plt.tight_layout()
plt.subplots_adjust(hspace=0.5,wspace=0.2)
plt.suptitle('Basal Area ({})'.format(param_set),y = 1.035)

# Saved to <output_path>/BA.png only when make_report is enabled.
if make_report == True:
    plt.savefig(os.path.join(output_path,"BA.png"))

# Crown area

In [None]:
# Quick look at the de-multiplexed crown area for scenario `s`
# (the loop variable left over from the basal-area cell).
crown_area = appf_to_ap_by_pft(fates_data[s].FATES_CROWNAREA_APPF,fates_data[s])
crown_area

In [None]:
# Raw multiplexed (age x PFT) crown area for the same scenario `s`.
fates_data[s].FATES_CROWNAREA_APPF

In [None]:
# One panel per scenario: crown area by PFT through time, summed over patch-age bins.
fig, axes = plt.subplots(ncols=2,nrows=math.ceil(len(scenarios)/2),figsize=(12,10), sharey=True)

for s,ax in zip(range(len(scenarios)),axes.ravel()):
    print(scenarios[s])
    crown_area = appf_to_ap_by_pft(fates_data[s].FATES_CROWNAREA_APPF,fates_data[s])
    # Sum over the trailing patch-age dimension once per scenario, not once per PFT.
    crown_area_by_pft = crown_area.sum(axis=2)
    for p in range(n_pfts):
        crown_area_by_pft.isel(fates_levpft = p).plot(ax = ax, color = pft_colors[p], label = pft_names[p])
    # Name each panel; otherwise the six scenario panels cannot be told apart.
    ax.set_title(scenarios[s])
    
plt.tight_layout()
plt.subplots_adjust(hspace=0.5,wspace=0.2)
plt.suptitle('Crown Area ({})'.format(param_set),y = 1.035)

if make_report == True:
    plt.savefig(os.path.join(output_path,"Crown_area.png"))

# Burned Area

In [None]:
# x positions (one per scenario); the PHS and FLI cells below reuse them.
positions = list(range(len(scenarios)))
max_position = np.max(np.array(positions))

fig, ax = plt.subplots()

# Per scenario: the observed burned fraction (red dot with its uncertainty
# range) next to the FATES mean annual burned fraction (blue triangle).
for x, x_label in enumerate(scenarios):
    # Only the first scenario's points carry legend labels, so each series is listed once.
    legend_label = "Observations" if x < 1 else None
    ax.scatter(x, burnFrac_obs[x], color='r', label=legend_label)
    ax.plot([x, x], [burnFrac_uncertainty_low[x], burnFrac_uncertainty_high[x]], color="r")

    legend_label = "FATES" if x < 1 else None
    print(x_label)
    modeled_burn_frac = get_mean_annual_burn_frac(fates_data[x], None, None)
    ax.scatter(x, modeled_burn_frac, color='b', marker="^", label=legend_label)

ax.set_xticks(positions)
ax.set_xticklabels(scenarios)
ax.set_xlim(-1, max_position + 1)

plt.xticks(rotation=45)
plt.ylabel("Annual burned fraction")
plt.title('Annual burned fraction ({})'.format(param_set))
plt.legend()
if make_report == True:
    plt.savefig(os.path.join(output_path, "Burned_Fraction.png"))
plt.show()

# Percent high severity

In [None]:
# Observed vs. modeled percentage high severity, one x position per scenario.
# Uses `positions` / `max_position` from the burned-fraction cell.
fig, ax = plt.subplots()

for x,x_label in enumerate(scenarios):
    
    #add observations
    # Only the first scenario's points carry legend labels, so each series is listed once.
    if x < 1:
        legend_label = "Observations"
    else:
        legend_label = None
    
    #add observations
    ax.scatter(x, PHS_obs[x], color='r', label = legend_label)
    
    #add fates prediction
    print(x_label)
    
    if x < 1:
        legend_label = "FATES"
    else:
        legend_label = None
    
    # Model value uses the conifer-only PHS over the whole record (None, None).
    ax.scatter(x, 
               get_PHS_conifer(fates_data[x],None,None), color='b', marker = "^", label = legend_label)
    
ax.set_xticks(positions)
ax.set_xticklabels(scenarios)
ax.set_xlim(-1, max_position + 1)

plt.xticks(rotation=45)
plt.ylabel("Percentage high severity")
plt.title('Percentage high severity ({})'.format(param_set))
plt.legend()
if make_report == True:
    plt.savefig(os.path.join(output_path,"PHS.png"))
plt.show()

# Percent of fire months above 3500 kW m-1

In [None]:
# Percentage of fire months whose area-weighted fire line intensity exceeds
# 3500 kW m-1, next to the observed PHS. Uses `positions` / `max_position`
# from the burned-fraction cell.
fig, ax = plt.subplots()

for x,x_label in enumerate(scenarios):
    
    # Only the first scenario's points carry legend labels, so each series is listed once.
    if x < 1:
        legend_label = "Observations"
    else:
        legend_label = None
    
    # NOTE(review): the observations are the mortality-based PHS_obs, while the
    # model value is an intensity-threshold fraction - confirm this comparison is intended.
    ax.scatter(x, PHS_obs[x], color='r', label = legend_label)
    
    #add fates prediction
    print(x_label)
    
    if x < 1:
        legend_label = "FATES"
    else:
        legend_label = None
    
    ax.scatter(x, 
               get_PHS_FLI_thresh(fates_data[x],3500), color='b', marker = "^", label = legend_label)
    
ax.set_xticks(positions)
ax.set_xticklabels(scenarios)
ax.set_xlim(-1, max_position + 1)

plt.xticks(rotation=45)
# Fire line intensity is per unit fire-front length (kW m-1), matching get_awfi's units.
plt.ylabel("Percentage fires > 3500 kW m-1")
plt.title('Percentage high severity FLI ({})'.format(param_set))
plt.legend()
if make_report == True:
    plt.savefig(os.path.join(output_path,"PHS_3500_kW.png"))
plt.show()

# PFT climate responses

In [None]:
# x positions now index PFTs (reused by the fire-response cell below).
positions = list(range(len(pft_names)))
max_position = np.max(np.array(positions))


fig, ax = plt.subplots()

# Per PFT: fraction of total basal area at the climate-comparison date for the
# two compared scenarios (dot = first, triangle = second).
for x,x_label in enumerate(pft_names):
    
    if x < 1:
        legend_label = scenarios[climate_response_comparison_indices[0]]
    else:
        legend_label = None
    
    #dry site
    ax.scatter(x, get_frac_pft_level_basal_area(fates_data[climate_response_comparison_indices[0]],x,basal_area_year_climate_comparison,0),
                                                color='black', label = legend_label)
    
    
    
    if x < 1:
        legend_label = scenarios[climate_response_comparison_indices[1]]
    else:
        legend_label = None
    
    # wet site
    ax.scatter(x, get_frac_pft_level_basal_area(fates_data[climate_response_comparison_indices[1]],x,basal_area_year_climate_comparison,0),
                                                color='black', marker = "^",label = legend_label)
    
ax.set_xticks(positions)
ax.set_xticklabels(pft_names)
ax.set_xlim(-1, max_position + 1)

plt.xticks(rotation=45)
plt.ylabel("Fraction of total basal area")
plt.title("PFT climate response ({})".format(param_set))
plt.legend()
if make_report == True:
    plt.savefig(os.path.join(output_path,"climate_response.png"))
plt.show()

# PFT fire responses

In [None]:
# Display the `scenarios` index of the second (fire-suppression era) fire-comparison scenario.
fire_response_comparison_indices[1]

In [None]:
# Compare each PFT's fraction of total basal area between the two fire
# scenarios in fire_response_comparison_indices, each read at its own year
# from basal_area_year_fire_comparison.

# One x position per PFT. Computed here rather than reused from an earlier
# cell, so the axis is correct regardless of cell execution order.
positions = list(range(len(pft_names)))
max_position = len(positions) - 1

fig, ax = plt.subplots()

# (scenario index, basal-area year, marker). The first series was noted as the
# fire-suppression case and uses the default marker; the second uses triangles.
fire_series = zip(fire_response_comparison_indices[:2],
                  basal_area_year_fire_comparison[:2],
                  [None, "^"])

for scenario_index, basal_area_year, marker in fire_series:
    scenario_data = fates_data[scenario_index]
    for pft_index in range(len(pft_names)):
        ax.scatter(pft_index,
                   get_frac_pft_level_basal_area(scenario_data, pft_index, basal_area_year, 0),
                   color='black', marker=marker,
                   # Label only the first PFT so each scenario appears once in the legend.
                   label=scenarios[scenario_index] if pft_index == 0 else None)

ax.set_xticks(positions)
ax.set_xticklabels(pft_names)
ax.set_xlim(-1, max_position + 1)

plt.xticks(rotation=45)
plt.ylabel("Fraction of total basal area")
plt.title("PFT response to fire ({})".format(param_set))
plt.legend()
# Write the figure to disk only when building a report.
if make_report:
    plt.savefig(os.path.join(output_path, "fire_response.png"))
plt.show()

# Post-fire regeneration crown area

In [None]:
# One cca_by_patch_age panel per scenario, laid out in two columns.
fig, axes = plt.subplots(ncols=2, nrows=math.ceil(len(scenarios) / 2), figsize=(12, 10))
regen_axes = axes.ravel()

for s, ax in enumerate(regen_axes[:len(scenarios)]):
    print(scenarios[s])
    cca_by_patch_age(fates_data[s], ax, scenarios[s])

# With an odd number of scenarios the last panel is unused; hide it instead of
# leaving an empty set of axes.
for ax in regen_axes[len(scenarios):]:
    ax.set_visible(False)

plt.tight_layout()
plt.subplots_adjust(hspace=0.5, wspace=0.2)
# y > 1 puts the title above the top edge of the figure canvas.
plt.suptitle('Regeneration ({})'.format(param_set), y=1.035)

if make_report:
    # bbox_inches="tight" grows the saved area to include the suptitle, which
    # sits outside the canvas and would otherwise be clipped from the PNG.
    plt.savefig(os.path.join(output_path, "regen.png"), bbox_inches="tight")
plt.show()

# Fire analysis

## Fuel load over time among scenarios

In [None]:
# Quick look at one fuel class / patch-age bin for a single scenario.
# The scenario is chosen explicitly. This cell used to index with a loop
# variable `i` left over from another cell, so its output depended on
# execution order. The last scenario is what a completed
# `for i, s in enumerate(scenarios)` loop leaves in `i`.
scenario_index = len(scenarios) - 1
age_by_fuel = agefuel_to_age_by_fuel(fates_data[scenario_index].FATES_FUEL_AMOUNT_APFC,
                                     fates_data[scenario_index])
age_by_fuel.isel(fates_levfuel=2).isel(fates_levage=2).plot()

In [None]:
# Fuel amount over time for each fuel class (one panel per class), with one
# line per scenario and a dashed reference value per class.
plt.rcParams.update({'axes.titlesize': 'large', 'axes.labelsize': 'large'})

fig, axes = plt.subplots(ncols=2, nrows=3, figsize=(10, 10), sharey=False)

fuel_class_names = ['twig', 'small branch', 'large branch', 'trunk', 'dead leaves', 'live grass']

# Reference fuel amount per class (presumably observations), drawn as a dashed
# horizontal line on the matching panel.
fuel_class_obs = [0.0525, 0.192, 0.27, 0.792, 1.32, 0.0]

# The age-by-fuel conversion does not depend on the fuel class being plotted,
# so compute it once per scenario instead of once per (class, scenario) pair.
# Summing over axis 2 presumably collapses patch age (TODO confirm dim order).
fuel_amount_by_class_per_scenario = []
for i in range(len(scenarios)):
    age_by_fuel = agefuel_to_age_by_fuel(fates_data[i].FATES_FUEL_AMOUNT_APFC, fates_data[i])
    fates_fuel_amount_by_class = age_by_fuel.sum(axis=2)
    fuel_amount_by_class_per_scenario.append(fates_fuel_amount_by_class)

for p, ax in enumerate(axes.ravel()[:len(fuel_class_names)]):
    for i, s in enumerate(scenarios):
        fuel_amount_by_class_per_scenario[i].isel(fates_levfuel=p).plot(
            x="time",
            lw=2,
            ax=ax,
            label=s,
            color=scenario_colors[i],
            linestyle=scenario_line_styles[i],
        )
    # Per-panel decorations are set once, after all scenario lines; this also
    # overrides the automatic title xarray's .plot() sets.
    ax.axhline(y=fuel_class_obs[p], color='black', linestyle="dashed")
    ax.set_ylabel("Fuel amount [Kg C m-2]")
    ax.set_title(fuel_class_names[p])

plt.tight_layout()
plt.legend()
plt.show()

## Fuel load over patch age among scenarios



The plots below show that fuel accumulates very quickly on young patches in the shrub-dominated scenarios.

In [None]:
# Inspect the axis-0 mean of fuel by age/class and of patch area for one scenario.
# The scenario is chosen explicitly instead of reading a loop variable `i`
# left over from another cell. The last scenario is what the preceding
# `for i, s in enumerate(scenarios)` loop leaves in `i`.
scenario_index = len(scenarios) - 1
age_by_fuel = agefuel_to_age_by_fuel(fates_data[scenario_index].FATES_FUEL_AMOUNT_APFC,
                                     fates_data[scenario_index])
# Axis 0 is presumably time (TODO confirm).
age_by_fuel = age_by_fuel.mean(axis=0)
mean_patch_area = fates_data[scenario_index].FATES_PATCHAREA_AP.mean(axis=0)
mean_patch_area.values

In [None]:
# Notebook display: mean fuel by age/class divided by mean patch area, using the values from the previous cell.
age_by_fuel / mean_patch_area

In [None]:
# Fuel amount per unit patch area versus patch age, one panel per fuel class
# and one line per scenario.
plt.rcParams.update({'axes.titlesize': 'large', 'axes.labelsize': 'large'})

fig, axes = plt.subplots(ncols=2, nrows=3, figsize=(10, 10), sharey=False)

fuel_class_names = ['twig', 'small branch', 'large branch', 'trunk', 'dead leaves', 'live grass']

# The normalization does not depend on the fuel class being plotted, so
# compute it once per scenario rather than once per (class, scenario) pair.
# Both means are over axis 0 (presumably time — TODO confirm).
normalized_fuel_per_scenario = []
for i in range(len(scenarios)):
    age_by_fuel = agefuel_to_age_by_fuel(fates_data[i].FATES_FUEL_AMOUNT_APFC, fates_data[i])
    age_by_fuel = age_by_fuel.mean(axis=0)
    mean_patch_area = fates_data[i].FATES_PATCHAREA_AP.mean(axis=0)
    age_by_fuel_normalized = age_by_fuel / mean_patch_area
    normalized_fuel_per_scenario.append(age_by_fuel_normalized)

for p, ax in enumerate(axes.ravel()[:len(fuel_class_names)]):
    for i, s in enumerate(scenarios):
        normalized_fuel_per_scenario[i].isel(fates_levfuel=p).plot(
            x="fates_levage", lw=2, ax=ax, label=s, color=scenario_colors[i],
            linestyle=scenario_line_styles[i], marker="o")
    # Set after plotting so it overrides the automatic title from .plot().
    ax.set_ylabel("Fuel amount [Kg C m-2]")
    ax.set_title(fuel_class_names[p])

plt.tight_layout()
plt.legend()
plt.show()

## Fire line intensity

In [None]:
# Area-weighted fire line intensity per scenario, one panel each in two columns.
# Round the row count up so an odd number of scenarios still gets a panel for
# the last one (int() floored it, and zip() then silently dropped it).
fli_nrows = math.ceil(len(scenarios) / 2)
fig, axes = plt.subplots(ncols=2, nrows=fli_nrows, figsize=(10, 10), sharey=True)
fli_axes = axes.ravel()

for s, ax in enumerate(fli_axes[:len(scenarios)]):
    print(scenarios[s])
    aw_fi = get_area_weighed_FLI(fates_data[s])
    aw_fi.plot(ax=ax, marker="o", linewidth=0.5, color=scenario_colors[s],
               linestyle=scenario_line_styles[s])
    ax.set_ylabel("Fire line intensity [kW m-1]")
    ax.set_title(scenarios[s])

# Hide the spare panel left over when the scenario count is odd.
for ax in fli_axes[len(scenarios):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

## Size-based fire mortality

In [None]:
# Per-capita fire mortality by size class, one panel per PFT and one line per
# scenario.
# Size the grid from n_pfts (two columns) so no PFT is silently dropped, as
# happened with the hard-coded 3 rows when n_pfts > 6.
mort_nrows = math.ceil(n_pfts / 2)
fig, axes = plt.subplots(ncols=2, nrows=mort_nrows, figsize=(10, 10), sharey=True)
mort_axes = axes.ravel()

# The mortality computation does not depend on the PFT panel, so do it once
# per scenario. The mean is over axis 0 (presumably time — TODO confirm).
mean_fire_mort_per_scenario = []
for i in range(len(scenarios)):
    per_capita_fire_mort_by_scls = get_per_capita_fire_mort_by_scls(fates_data[i])
    mean_fire_mort_per_scenario.append(per_capita_fire_mort_by_scls.mean(axis=0))

for p, ax in enumerate(mort_axes[:n_pfts]):
    for i, mean_mort in enumerate(mean_fire_mort_per_scenario):
        mean_mort.isel(fates_levpft=p).plot(
            x="fates_levscls", ax=ax, label=scenarios[i], color=scenario_colors[i],
            linestyle=scenario_line_styles[i], marker="o")

    ax.set_ylabel("[N per capita yr -1]")
    ax.set_title("{} fire mort.".format(pft_names[p]))
    # Build the legend once, on the first panel, after all scenario lines exist.
    if p == 0:
        ax.legend()

# Hide spare panels when n_pfts is odd.
for ax in mort_axes[n_pfts:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()