In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os, sys, yaml
# Login name from the environment (None if unset); used below to build per-user config paths.
USER = os.environ.get('USER')
from dscim.menu.simple_storage import Climate

In [None]:
# Sectors to plot. To run every sector, use instead:
# sectors = ['agriculture_clipped', 'mortality_clipped', 'energy_clipped',
#            'labor_clipped', 'AMEL_clipped', 'CAMEL_clipped']
sectors = [
    'CAMEL_clipped',
]

In [None]:
def scc_hist(
    sector,
    kind,
    recipe='risk_aversion',
    disc = 'euler_ramsey',
    flip=False,
    log=True,
):

    # parameters
    if 'ssp' in kind:
        config = f'/home/{USER}/repos/integration/configs/epa_tool_config-histclim_AR6.yaml'
        socioec = xr.open_zarr(
            '/shares/gcp/integration/float32/dscim_input_data/econvars/zarrs/integration-econ-bc39.zarr'
        ).sum('region').sel(year=2100, ssp=['SSP2', 'SSP3', 'SSP4'], drop=True)
        gdppc = socioec.gdp/socioec.pop
        idx=['rcp', 'simulation']
        
    else:
        config = f'/home/{USER}/repos/integration/configs/rff_config.yaml'
        socioec = xr.open_dataset(
        "/shares/gcp/integration/rff/socioeconomics/rff_global_socioeconomics.nc4"
        ).sel(region='world', year=2300, drop=True)
        gdppc = socioec.gdp/socioec.pop
        idx=['rff_sp', 'simulation']

    with open(config) as config_file:
        params = yaml.full_load(config_file)

    anomalies = (Climate(**params['climate'])
                 .fair_pulse
                 .sel(year=2300)
                 .to_dataframe()
                 .reset_index()
                )

    if kind=='epa_mask':
        indir = f"/mnt/CIL_integration/rff3_with_disc_factors/iter0-19/{sector}/2020/unmasked_None/"
    elif kind=='gdppc_mask':
        indir=f'/mnt/CIL_integration/rff3_expost_mask/iter0-19/{sector}/2020/gdppc_emissions_q0.01_q0.99'
    elif 'ssp' in kind:
        indir = f'/mnt/CIL_integration/menu_results_AR6_bc39_epa_vsl_continuous/{sector}/2020/unmasked'

    near_term = {
        # 0.03: (1.567899395, 0.00770271076),
        0.025 : (1.421158116, 0.00461878399),
        0.02 : (1.244459066, 0.00197263997),
        0.015 : (1.016010255, 9.149608e-05)
    }

    sccs = []
    if flip==False:
        x='temperature'
        hue='log_gdppc'
    else:
        x='log_gdppc'
        hue='temperature'

    for rate, eta_rho in near_term.items():
        df = xr.open_dataset(
            f"{indir}/{recipe}_{disc}_eta{eta_rho[0]}_rho{eta_rho[1]}_uncollapsed_sccs.nc4"
        ).sel(weitzman_parameter='0.5')

        if kind=='epa_mask':
            rffsp_values = pd.read_csv("/mnt/sacagawea_shares/gcp/integration/rff/rffsp_trials.csv").rffsp_id.unique()
            df = df.sel(rff_sp=rffsp_values)

        df = xr.combine_by_coords([gdppc.rename('gdppc').to_dataset(), df])
        df = df.to_dataframe()
        df['Ramsey discount rate'] = rate
        sccs.append(df.reset_index())

    data = pd.concat(sccs).reset_index()

    data = data.merge(anomalies, on=idx)
    data['log_gdppc'] = np.log(data.gdppc)
    
    print(len(data))
    print(len(data.loc[data.uncollapsed_sccs.isnull()]))

    if kind == 'ssp_unique':
        data['ssp_model'] = data.ssp + '_' + data.model
        g=sns.relplot(data=data,
                    x=x,
                    y='uncollapsed_sccs',
                    col='Ramsey discount rate',
                    row='ssp_model',
                    kind='scatter',
                    hue=hue,
                    edgecolor='face',
                    s=1
                   )
    else:
        g=sns.relplot(data=data,
                    x=x,
                    y='uncollapsed_sccs',
                    col='Ramsey discount rate',
                    kind='scatter',
                    hue=hue,
                    edgecolor='face',
                    s=1
                   )

    if log == True:
        g.set(yscale="symlog")
    plt.subplots_adjust(top=0.85)
    g.fig.suptitle(f"SCC, GMST, log(GDPpc) \n {sector}, {kind}")
    plt.savefig(f'/mnt/CIL_integration/rff_diagnostics/v3/scc_gmst_gdppc_scatter/{sector}_{kind}_uncollapsed_scc_gmst_gdppc_scatter_flip-{flip}_log-{log}.png', dpi=300, bbox_inches='tight')

In [None]:
# Draw and save one scatter figure per sector and SCC source, using only the
# flipped orientation (x = log GDPpc, hue = temperature) on a linear y-axis.
for sector in sectors:
    for kind in ('ssp', 'ssp_unique', 'epa_mask', 'gdppc_mask'):
        for flip in (True,):
            data = scc_hist(sector, kind, flip=flip, log=False)