In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import seaborn as sns
import os, sys
import matplotlib.pyplot as plt
from dscim.diagnostics.damage_function import damage_function

In [None]:
# Region switch: True selects the 'CAMEL_USA' sector (and the USA_SCC paths
# used when plotting), False selects the full sector list.
USA = False

# One set of figures is produced for every sector listed here.
sectors = (
    ['CAMEL_USA']
    if USA
    else ['CAMEL', 'mortality', 'energy', 'labor', 'agriculture']
)

# eta -> rho pairs. They are iterated in this order when plotting, and both
# values are interpolated into the coefficient file names.
eta_rhos = dict([
    (2.0, 0.0),
    (1.016010255, 9.149608e-05),
    (1.244459066, 0.00197263997),
    (1.421158116, 0.00461878399),
    (1.567899395, 0.00770271076),
])

# When True, the RFF quantile and mean curves are overlaid on the SSP curves.
rff = True

In [None]:
def _load_damage_coefficients(library_dir, recipe, disc, eta, rho, years):
    """Read one damage-function coefficient file and return the requested years.

    The selection is loaded into memory before the file is closed, so the
    returned dataset does not keep a file handle open.
    """
    path = f"{library_dir}/{recipe}_{disc}_eta{eta}_rho{rho}_damage_function_coefficients.nc4"
    with xr.open_dataset(path) as ds:
        return ds.sel(year=years).load()


def _quadratic_fit_frame(coefs, temps):
    """Evaluate ``anomaly * T + np.power(anomaly, 2) * T**2`` as a long-form frame.

    ``coefs`` supplies the two coefficient variables; ``temps`` is the grid of
    temperature anomalies ``T``. The result has a ``fit`` column plus one column
    per dimension/coordinate of the broadcast product.
    """
    fit = coefs['anomaly'] * temps + coefs['np.power(anomaly, 2)'] * temps ** 2
    return fit.to_dataframe('fit').reset_index()


def plot_stacked(sector,
                 recipe='risk_aversion',
                 disc='constant',
                 eta=2.0,
                 rho=0.0,
                 xlim=(-1, 8),
                 years=None,
                 sharey=False,
                 close_fig=False,
                 ):
    """Plot SSP damage functions, optionally overlaid with RFF quantiles, and save a PNG.

    One panel is drawn per year (three panels per row) with one line per
    SSP/model combination. When the module-level flag ``rff`` is truthy, the
    RFF quantile curves and the RFF mean curve are added on top.

    The function reads two module-level flags: ``USA`` (selects the USA_SCC
    input/output directories) and ``rff`` (enables the RFF overlay).

    Parameters
    ----------
    sector : str
        Sector sub-directory of the damage function library.
    recipe, disc : str
        Interpolated into the coefficient file name.
    eta, rho : float
        Interpolated into the coefficient file name and the figure title.
    xlim : tuple of (float, float)
        Start (inclusive) and stop (exclusive) of the anomaly grid, step 0.1.
    years : list of int, optional
        Years selected from the coefficient files, one panel each.
        Defaults to ``[2020, 2050, 2090, 2100, 2200, 2300]``.
    sharey : bool
        Whether the panels share a y axis.
    close_fig : bool
        If True, close the figure after saving. Useful when calling this in a
        long loop so figures do not accumulate in memory; leave False to keep
        the figure open for inline display.

    Returns
    -------
    None
        The figure is written under the ``output`` directory as a side effect.
    """
    # Resolve the default here rather than in the signature to avoid a shared
    # mutable default argument.
    if years is None:
        years = [2020, 2050, 2090, 2100, 2200, 2300]

    if USA:
        coefs_rff=f'/mnt/CIL_integration/damage_function_library/damage_function_library_USA_SCC_rff2/{sector}/'
        coefs_ssp=f'/mnt/CIL_integration/damage_function_library/damage_function_library_USA_SCC/{sector}/'
        output = '/mnt/CIL_integration/plots/rff_diagnostics_USA/rff_ssp_stacked_damage_functions/'
    else:
        coefs_rff=f'/mnt/CIL_integration/damage_function_library/damage_function_library_rff2/{sector}/'
        coefs_ssp=f'/mnt/CIL_integration/damage_function_library/damage_function_library_epa/{sector}/'
        output = '/mnt/CIL_integration/plots/rff_diagnostics/rff_ssp_stacked_damage_functions/'

    quantiles = [0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99]
    columns = ['fit', 'year', 'model', 'anomaly']

    # Anomaly grid shared by every curve; dims is given explicitly so the
    # construction does not depend on xarray inferring it from the coords dict.
    grid = np.arange(xlim[0], xlim[1], 0.1)
    temps = xr.DataArray(grid, dims=['anomaly'], coords={'anomaly': grid})

    # SSP curves: one line per ssp/model combination.
    ds_ssp = _load_damage_coefficients(coefs_ssp, recipe, disc, eta, rho, years)
    fit_ssp = _quadratic_fit_frame(ds_ssp, temps)
    fit_ssp["model"] = fit_ssp.ssp + "-" + fit_ssp.model
    fit_ssp = fit_ssp[columns]

    frames = [fit_ssp]
    # Size the SSP palette from the data so the grey RFF colours appended below
    # always line up with the RFF curves, whatever the number of SSP lines.
    pal = sns.color_palette("Paired", fit_ssp['model'].nunique())

    if rff:
        # overlap ssp and rff: the RFF file is read once and only when needed.
        ds_rff = _load_damage_coefficients(coefs_rff, recipe, disc, eta, rho, years)

        fit_rff = _quadratic_fit_frame(ds_rff.quantile(quantiles, 'runid'), temps)
        fit_rff['model'] =  "quantile: " + fit_rff['quantile'].astype(str)

        fit_rff_mean = _quadratic_fit_frame(ds_rff.mean('runid'), temps)
        fit_rff_mean['model'] =  "RFF mean"

        # Order matters: hue levels follow order of appearance, so the
        # quantiles and then the mean take the grey colours in sequence.
        frames += [fit_rff[columns], fit_rff_mean[columns]]
        pal = pal + sns.color_palette("Greys", len(quantiles) + 1)

    fit = pd.concat(frames).set_index(['year','model','anomaly'])

    g = sns.relplot(
        data=fit,
        x='anomaly',
        y='fit',
        hue='model',
        col='year',
        col_wrap=3,
        kind='line',
        palette=pal,
        facet_kws={'sharey': sharey, 'sharex': True},
        legend='full',
    )

    g.fig.suptitle(f"{sector} {recipe} {disc} eta={eta} rho={rho}")

    # Leave head-room for the suptitle above the top row of panels.
    g.fig.subplots_adjust(top=0.85)

    os.makedirs(output, exist_ok=True)
    g.fig.savefig(f'{output}/stacked_{sector}_{recipe}_{disc}_{eta}_{rho}_xlim{xlim}_rff-{rff}_years{years}.png',
                  bbox_inches='tight',
                  dpi=300
                 )

    if close_fig:
        plt.close(g.fig)

In [None]:
# Render one stacked figure per sector / recipe / discounting / (eta, rho)
# combination. The generator yields combinations in nested-loop order:
# sector varies slowest, the (eta, rho) pair varies fastest.
for sector, recipe, disc, (eta, rho) in (
    (sector_name, recipe_name, disc_name, eta_rho)
    for sector_name in sectors
    for recipe_name in ['adding_up', 'risk_aversion']
    for disc_name in ['constant', 'euler_ramsey']
    for eta_rho in eta_rhos.items()
):
    plot_stacked(
        sector,
        recipe=recipe,
        disc=disc,
        eta=eta,
        rho=rho,
        xlim=(0,8),
        years=[2020, 2050, 2075, 2100, 2200, 2300],
    )