In [None]:
import xarray as xr
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os, sys, yaml
USER = os.getenv('USER')
from dscim.menu.simple_storage import Climate
from itertools import product

In [None]:
# Toggle for the run below: True selects the USA_SCC results, config and plot
# directory; False selects the AR6 EPA menu results.
USA: bool = False

In [None]:
# Input/output locations and the sector list depend on the USA toggle.
if USA:
    results_root = "/mnt/CIL_integration/USA_SCC"
    sectors = ['CAMEL_USA']
    config = f'/home/{USER}/repos/integration/configs/USA_SCC_ssps.yaml'
    output = "/mnt/CIL_integration/plots/rff_diagnostics_USA/SSP_scc_subset_timeseries/"
else:
    results_root = "/mnt/CIL_integration/menu_results_AR6_epa"
    sectors = ['CAMEL']
    config = f'/home/{USER}/repos/integration/configs/epa_tool_config-histclim_AR6.yaml'
    output = "/mnt/CIL_integration/plots/rff_diagnostics/SSP_scc_subset_timeseries/"

# Scenario combinations to plot (socioeconomic scenario, growth model, emissions pathway).
ssps = ['SSP3', 'SSP2', 'SSP4']
models = ['IIASA GDP', 'OECD Env-Growth']
rcps = ['ssp370', 'ssp245', 'ssp460', 'ssp585']

# SCC subset labels; each selects a `*_{kind}_sim_ids.csv` file of simulation ids.
kinds = ['p99', 'p1']
# Result masks to loop over.
masks = ['unmasked']

# Pulse year; part of the results directory path.
pulse_year = 2020
# Default results mask directory.
results_mask = "unmasked"
# eta -> rho pairs. Each pair picks one set of discounting result files.
# NOTE(review): presumably eta is the elasticity of marginal utility and rho the
# pure rate of time preference of the Ramsey discounting — confirm.
eta_rhos = {
    2.0: 0.0,
    1.016010255: 9.149608e-05,
    1.244459066: 0.00197263997,
    1.421158116: 0.00461878399,
    1.567899395: 0.00770271076,
}

In [None]:
def timeseries(
    sector,
    mask,
    kind,
    ssp,
    rcp,
    model,
    recipe='risk_aversion',
    disc='euler_ramsey',
    eta=2.0,
    rho=0.0,
    close=True,
):
    """Plot per-simulation time series for one SCC subset and save the figure.

    For one (sector, mask, recipe, disc, eta, rho) run, this loads:

    - the simulation ids of the ``kind`` SCC subset;
    - the uncollapsed SCCs, marginal damages and discount factors;
    - discounted damages (damages * discount factors);
    - FaIR emissions and cumulative emissions;
    - the pulse temperature ("gmst").

    It restricts them to ``ssp``/``model``/``rcp`` and years 2000-2300. It then
    draws a six-panel figure with one line per SCC value and saves it as a PNG
    under ``{output}/{sector}/``.

    Reads the notebook globals ``results_root``, ``pulse_year``, ``config`` and
    ``output``.

    Parameters
    ----------
    sector : str
        Sector directory under ``results_root`` (e.g. ``"CAMEL"``).
    mask : str
        Results mask sub-directory (e.g. ``"unmasked"``). Also shown in the
        title and the output filename.
    kind : str
        SCC subset label selecting the ``*_{kind}_sim_ids.csv`` file
        (e.g. ``"p99"``).
    ssp, rcp, model : str
        Scenario coordinates to select.
    recipe, disc : str
        Recipe and discounting names used in the result file names. When
        ``disc == 'constant'``, only rows with ``discrate == 0.02`` are kept.
    eta, rho : float
        Parameter values embedded in the result file names.
    close : bool, default True
        Close the figure after saving, so calling this in a loop does not
        accumulate open figures. Pass ``False`` to keep it open for inline
        display.

    Returns
    -------
    pandas.DataFrame
        The long-format data that was plotted.
    """
    # Read from the requested mask's directory, so the data matches the mask
    # named in the title and the filename.
    results = f"{results_root}/{sector}/{pulse_year}/{mask}"
    prefix = f"{recipe}_{disc}_eta{eta}_rho{rho}"

    # Simulation ids in the requested SCC subset, restricted to this scenario.
    ids = pd.read_csv(f"{results}/sim_ids/{prefix}_{kind}_sim_ids.csv")
    ids = ids.loc[(ids.ssp == ssp) & (ids.model == model) & (ids.rcp == rcp)].set_index(['simulation']).index

    # Uncollapsed SCCs.
    sccs = (
        xr.open_dataset(f"{results}/{prefix}_uncollapsed_sccs.nc4")
        .sel(weitzman_parameter='0.5', gas='CO2_Fossil', fair_aggregation='uncollapsed', drop=True)
        .uncollapsed_sccs
        .rename('SCC')
    )

    # Marginal damages.
    damages = (
        xr.open_zarr(f"{results}/{prefix}_uncollapsed_marginal_damages.zarr")
        .sel(weitzman_parameter='0.5', gas='CO2_Fossil', drop=True)
        .marginal_damages
    )

    # Discount factors.
    discount_factors = (
        xr.open_zarr(f"{results}/{prefix}_uncollapsed_discount_factors.zarr")
        .sel(weitzman_parameter='0.5', gas='CO2_Fossil', drop=True)
        .discount_factor
        .rename('discount_factors')
    )

    # Discounted damages.
    discounted_damages = (damages * discount_factors).rename('discounted_damages')

    # Emissions from the AR6 FaIR control/pulse run.
    emissions_path = (
        "/shares/gcp/integration/float32/dscim_input_data/climate/AR6/"
        "ar6_fair162_control_pulse_2020-2030-2040-2050-2060-2070-2080_emis_conc_rf_temp_"
        "lambdaeff_emissions-driven_naturalfix_v4.0_Jan212022.nc"
    )
    emissions = xr.open_dataset(emissions_path).emissions.sel(gas='CO2_Fossil', drop=True)

    # Cumulative emissions; keep the original year coordinate.
    c_emissions = emissions.cumsum('year').rename('cumulative_emissions')
    c_emissions['year'] = emissions.year

    # gmst: temperature from Climate(...).fair_pulse, configured by the
    # `climate` section of the YAML config.
    with open(config) as config_file:
        params = yaml.full_load(config_file)
    gmst = Climate(**params['climate']).fair_pulse.temperature.sel(gas='CO2_Fossil', drop=True).rename('gmst')

    data = xr.combine_by_coords([
        arr.to_dataset()
        for arr in [sccs, damages, discount_factors, discounted_damages, emissions, c_emissions, gmst]
    ])
    data = (
        data.sel(simulation=ids, year=slice(2000, 2300), ssp=ssp, model=model, rcp=rcp)
        .to_dataframe()
        .reset_index()
    )

    if disc == 'constant':
        data = data.loc[data.discrate == 0.02]

    panels = ['cumulative_emissions', 'emissions', 'gmst',
              'marginal_damages', 'discount_factors', 'discounted_damages']
    fig, axes = plt.subplots(len(panels), 1, figsize=(10, 15), sharex=True)
    for i, (ax, yvar) in enumerate(zip(axes, panels)):
        sns.lineplot(data=data, x='year', y=yvar, hue='SCC', ax=ax)
        ax.set_title(yvar)
        # Keep only the top panel's legend. lineplot may not create one
        # (e.g. when there are no rows), so guard against None.
        legend = ax.get_legend()
        if i > 0 and legend is not None:
            legend.remove()

    fig.subplots_adjust(top=0.9)
    fig.suptitle(f"mask={mask}, SCC subset={kind} \n {sector} {recipe} {disc}, eta={eta} rho={rho} \n {ssp} {model} {rcp}")
    os.makedirs(f"{output}/{sector}", exist_ok=True)
    fig.savefig(
        f'{output}/{sector}/vars_timeseries_{sector}_{recipe}_{disc}_{eta}_{rho}_{kind}_{mask}_{ssp}_{rcp}_{model}.png',
        bbox_inches='tight',
        dpi=300,
    )
    if close:
        # Release the figure; many calls in one cell would otherwise pile up open figures.
        plt.close(fig)

    return data

In [None]:
# Render and save one diagnostics figure for every combination of settings.
# `dt` holds the data of the last combination plotted.
for sector, mask, kind, (eta, rho), ssp, model, rcp in product(
    sectors, masks, kinds, eta_rhos.items(), ssps, models, rcps
):
    print(sector, mask, kind, (eta, rho), ssp, model, rcp)
    dt = timeseries(
        sector=sector, mask=mask, kind=kind,
        eta=eta, rho=rho,
        ssp=ssp, model=model, rcp=rcp,
    )