In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import product

In [None]:
def _resolve_indir(kind):
    """Return the per-sector input-directory template for ``kind``.

    The returned string has a single ``{}`` placeholder filled with the sector name.

    Raises
    ------
    ValueError
        If ``kind`` is not 'epa_mask', 'gdppc_mask', or a string containing 'ssp'.
    """
    if kind == 'epa_mask':
        return "/mnt/CIL_integration/rff3_with_disc_factors/iter0-19/{}/2020/unmasked_None/"
    if kind == 'gdppc_mask':
        return '/mnt/CIL_integration/rff3_expost_mask/iter0-19/{}/2020/gdppc_emissions_q0.01_q0.99'
    if 'ssp' in kind:
        return '/mnt/CIL_integration/menu_results_AR6_bc39_epa_vsl_continuous/{}/2020/unmasked'
    # Fail fast here rather than crashing later on an undefined input directory.
    raise ValueError(
        f"unknown kind {kind!r}: expected 'epa_mask', 'gdppc_mask', "
        "or a kind containing 'ssp'"
    )


def _load_partial_sccs(path, rffsp_values=None, random=False, rng=None):
    """Load ``uncollapsed_sccs`` (weitzman_parameter '0.5') from one file as a DataFrame.

    Parameters
    ----------
    path : str
        netCDF file to open; the handle is closed before returning.
    rffsp_values : array-like, optional
        If given, restrict the ``rff_sp`` coordinate to these values.
    random : bool
        If true, keep a single randomly chosen entry along ``simulation``.
    rng : numpy Generator or the ``numpy.random`` module
        Source of uniform draws (must provide ``.random(shape)``); used only if ``random``.
    """
    with xr.open_dataset(path) as raw:
        selected = raw.sel(weitzman_parameter='0.5')
        if rffsp_values is not None:
            selected = selected.sel(rff_sp=rffsp_values)
        sccs = selected['uncollapsed_sccs']
        if random:
            # One uniform draw per cell; the argmax along 'simulation' picks a
            # random simulation for every combination of the other dimensions.
            draws = xr.DataArray(
                data=rng.random(sccs.shape),
                dims=sccs.dims,
                coords=sccs.coords,
            )
            sccs = sccs.sel(simulation=draws.idxmax('simulation'), drop=True)
        # to_dataframe() materialises the data, so it must run before the file closes.
        return sccs.to_dataframe()


def _split_subsets(data, kind):
    """Return ``{label: frame}`` panels: one per model/SSP pair for 'ssp_unique', else one."""
    if kind != 'ssp_unique':
        return {'': data}
    return {
        f"{model}_{ssp}": data.loc[(data.ssp == ssp) & (data.model == model)]
        for ssp, model in product(data.ssp.unique(), data.model.unique())
    }


def scc_hist(
    kind,
    sectors=('coastal', 'agriculture_clipped', 'mortality_clipped', 'energy_clipped', 'labor_clipped'),
    recipe='risk_aversion',
    disc='euler_ramsey',
    random=False,
    seed=None,
):
    """Plot and save box plots of partial (per-sector) SCCs for one result set.

    The figure has one row of panels per near-term Ramsey discount rate. It has one
    column per subset: a single column, or one per model/SSP pair when
    ``kind == 'ssp_unique'``. The figure is written to
    ``/mnt/CIL_integration/rff_diagnostics/v3/partial_sccs/``.

    Parameters
    ----------
    kind : str
        Result configuration: 'epa_mask', 'gdppc_mask', or any string containing
        'ssp' ('ssp_unique' additionally splits panels by model and SSP).
    sectors : sequence of str
        Sector sub-directories to load.
    recipe, disc : str
        Recipe and discounting names interpolated into the input filenames.
    random : bool
        If true, keep one randomly chosen ``simulation`` per combination of the
        other dimensions instead of all simulations.
    seed : int, optional
        Seed for the random selection. ``None`` draws from NumPy's global RNG,
        which was the previous behaviour.

    Returns
    -------
    matplotlib.figure.Figure
        The saved figure, so callers can display or close it.

    Raises
    ------
    ValueError
        If ``kind`` is not recognised.
    """
    indir = _resolve_indir(kind)

    # Ramsey rate -> (eta, rho). The values are interpolated verbatim into the
    # input filenames, so their exact text matters.
    near_term = {
        # 0.03: (1.567899395, 0.00770271076),
        0.025: (1.421158116, 0.00461878399),
        0.02: (1.244459066, 0.00197263997),
        0.015: (1.016010255, 9.149608e-05),
    }

    # The EPA-mask subset of rff_sp ids does not depend on rate or sector: read it once.
    rffsp_values = None
    if kind == 'epa_mask':
        rffsp_values = pd.read_csv(
            "/mnt/sacagawea_shares/gcp/integration/rff/rffsp_trials.csv"
        ).rffsp_id.unique()

    rng = np.random if seed is None else np.random.default_rng(seed)

    frames = []
    for rate, (eta, rho) in near_term.items():
        for sector in sectors:
            path = indir.format(sector) + f"/{recipe}_{disc}_eta{eta}_rho{rho}_uncollapsed_sccs.nc4"
            frame = _load_partial_sccs(path, rffsp_values, random, rng).reset_index()
            frame['Ramsey discount rate'] = rate
            frame['sector'] = sector
            frames.append(frame)

    data = pd.concat(frames, ignore_index=True)
    data['sector'] = pd.Categorical(data.sector)

    subsets = _split_subsets(data, kind)
    rates = data['Ramsey discount rate'].unique()
    # One colour per sector; a fixed-size palette would cycle and repeat colours.
    palette = sns.color_palette('husl', len(data['sector'].cat.categories))

    fig, axes = plt.subplots(
        len(rates), len(subsets),
        figsize=(8 * len(subsets), 5), sharex=True, squeeze=False,
    )

    for col, (label, subset) in enumerate(subsets.items()):
        for row, rate in enumerate(rates):
            ax = axes[row][col]
            sns.boxplot(
                data=subset.loc[subset['Ramsey discount rate'] == rate],
                x='uncollapsed_sccs',
                y='sector',
                hue='sector',
                orient='h',
                showfliers=False,
                whis=[5, 95],
                dodge=False,
                ax=ax,
                palette=palette,
                linewidth=0.5,
            )
            ax.set_title(f"{label} {rate}")
            ax.set_ylabel("")
            ax.set_xlabel("")
            ax.get_yaxis().set_visible(False)

            # Only the top-left panel keeps a legend; the others share it.
            if row == 0 and col == 0:
                ax.legend(bbox_to_anchor=(0, 0))
            else:
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()

    fig.subplots_adjust(wspace=0, hspace=0.5, top=0.8)
    fig.suptitle(f"Partial SCCs \n {kind}, random-{random}")
    fig.savefig(
        f'/mnt/CIL_integration/rff_diagnostics/v3/partial_sccs/{kind}_partial_sccs_random-{random}.png',
        dpi=300,
        bbox_inches='tight',
    )
    return fig

In [None]:
# Render and save one partial-SCC figure for each result configuration.
for kind in ('ssp_unique', 'ssp', 'epa_mask', 'gdppc_mask'):
    scc_hist(kind)