In [None]:
import xarray as xr
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os, sys, yaml
USER = os.getenv('USER')
from dscim.menu.simple_storage import Climate
from itertools import product
from p_tqdm import p_map, p_imap, p_uimap, p_umap
import gc

In [None]:
# Toggle between the USA-only results (True) and the global results (False).
USA = False

In [None]:
# Locations that depend on the USA toggle: where the SCC results live, which
# sectors to plot, which config file feeds `Climate`, and where the PNGs go.
if USA:
    results_root = "/mnt/CIL_integration/USA_SCC_rff2"
    sectors = ['mortality_USA', 'agriculture_USA', 'energy_USA', 'labor_USA']  # 'CAMEL_USA',
    config = f'/home/{USER}/repos/integration/configs/USA_SCC_rff.yaml'
    output = "/mnt/CIL_integration/plots/rff_diagnostics_USA/scc_subset_timeseries/"
else:
    results_root = "/mnt/CIL_integration/rff2"
    sectors = ['CAMEL', 'mortality', 'energy', 'labor', 'agriculture']
    config = f'/home/{USER}/repos/integration/configs/rff2_config_all_gases.yaml'
    output = "/mnt/CIL_integration/plots/rff_diagnostics/scc_subset_timeseries/"

# SCC subsets to plot; each one names a runids CSV under the results directory.
kinds = ['p99.99', 'p0.01']  # 'p75', 'p50',
# Mask label used in the runids file name, the figure title and the output name.
masks = ['unmasked']
# Gases to plot (names as used in the SCC result files).
gases = ["CO2_Fossil"]  # "CH4","N2O",

# Recipe and discounting labels; both are part of every result file name.
recipes = ['equity']
discs = ['constant', 'euler_ramsey']

# Components of the results directory: {results_root}/{sector}/{pulse_year}/{results_mask}
pulse_year = 2020
results_mask = "unmasked_None"

# eta -> rho pairs; each pair is iterated as one (eta, rho) combination.
eta_rhos = {
    2.0: 0.0,
    1.016010255: 9.149608e-05,
    1.244459066: 0.00197263997,
    1.421158116: 0.00461878399,
    1.567899395: 0.00770271076,
}

In [None]:
def _emissions_gas_label(gas):
    """Return the gas label used by the emissions file for an SCC gas name.

    Raises:
        ValueError: if `gas` is not one of the supported gas names.
    """
    labels = {"CO2_Fossil": "C", "CH4": "CH4", "N2O": "N2"}
    try:
        return labels[gas]
    except KeyError:
        raise ValueError(
            f"unsupported gas {gas!r}; expected one of {sorted(labels)}"
        ) from None


def _load_timeseries_frame(sector, mask, recipe, disc, eta, rho, gas, kind):
    """Load every plotted variable for one parameter combination as a DataFrame.

    Reads the module-level settings `results_root`, `pulse_year`,
    `results_mask` and `config`. The returned frame holds the columns SCC,
    marginal_damages, discount_factors, discounted_damages, emissions,
    cumulative_emissions, gmst and gmst_pulse_minus_control, restricted to the
    run ids listed in the runids CSV for this SCC subset.
    """
    from contextlib import ExitStack

    # Fail early with a clear message instead of after several files were opened.
    emissions_label = _emissions_gas_label(gas)

    results = f"{results_root}/{sector}/{pulse_year}/{results_mask}"
    stem = f"{recipe}_{disc}_eta{eta}_rho{rho}"

    # index: run ids belonging to the requested SCC subset
    rff_ids = pd.read_csv(
        f"{results}/runids/{stem}_{kind}_{mask}_runids.csv"
    ).set_index('runid').index

    # Every dataset opened here is closed when the stack exits, i.e. after the
    # values have been materialised into the returned DataFrame.
    with ExitStack() as stack:
        # runid rff_sp-simulation crosswalk
        cw = stack.enter_context(
            xr.open_dataset('/shares/gcp/integration/rff2/rffsp_fair_sequence.nc')
        )

        # sccs
        scc_ds = stack.enter_context(
            xr.open_dataset(f"{results}/{stem}_uncollapsed_sccs.nc4")
        )
        sccs = scc_ds.sel(
            weitzman_parameter='0.5', gas=gas, fair_aggregation='uncollapsed', drop=True
        ).uncollapsed_sccs.rename('SCC')

        # marginal damages
        damages_ds = stack.enter_context(
            xr.open_zarr(f"{results}/{stem}_uncollapsed_marginal_damages.zarr")
        )
        damages = damages_ds.sel(
            weitzman_parameter='0.5', gas=gas, drop=True
        ).marginal_damages

        # discount factors
        discount_ds = stack.enter_context(
            xr.open_zarr(f"{results}/{stem}_uncollapsed_discount_factors.zarr")
        )
        discount_factors = discount_ds.sel(
            weitzman_parameter='0.5', gas=gas, drop=True
        ).discount_factor.rename('discount_factors')

        # discounted damages
        discounted_damages = (damages * discount_factors).rename('discounted_damages')

        # emissions, re-indexed through the crosswalk's rff_sp variable
        emissions_ds = stack.enter_context(
            xr.open_dataset(
                "/shares/gcp/integration/rff2/climate/emissions/rff-sp_emissions_all_gases.nc"
            )
        )
        emissions = emissions_ds.sel(gas=emissions_label, drop=True).rename(
            {'Year': 'year',
             'simulation': 'rff_sp'
            }).emissions.sel(rff_sp=cw.rff_sp)

        # cumulative emissions
        c_emissions = emissions.cumsum('year').rename('cumulative_emissions')
        c_emissions['year'] = emissions.year

        # gmst
        # NOTE(review): yaml.full_load can build more than plain YAML types;
        # yaml.safe_load may be sufficient for this config - confirm.
        with open(config) as config_file:
            params = yaml.full_load(config_file)
        params['climate'].update({
            'gmst_fair_path': "/shares/gcp/integration/rff2/climate/ar6_rff_fair162_control_pulse_all_gases_2020-2030-2040-2050-2060-2070-2080_emis_conc_rf_temp_lambdaeff_ohc_emissions-driven_naturalfix_v5.03_Feb072022.nc"
        })
        # Build the Climate object once; the pulse temperature serves both as
        # the plotted 'gmst' series and as the minuend of pulse - control.
        climate = Climate(**params['climate'])
        gmst = climate.fair_pulse.temperature.sel(gas=gas, drop=True).rename('gmst')
        gmst_control = climate.fair_control.temperature.sel(gas=gas, drop=True).rename('gmst')
        gmst_pulse_minus_control = (gmst - gmst_control).rename("gmst_pulse_minus_control")

        data = xr.combine_by_coords([
            variable.to_dataset() for variable in
            [sccs, damages, discount_factors, discounted_damages,
             emissions, c_emissions, gmst, gmst_pulse_minus_control]
        ])
        return data.sel(runid=rff_ids).to_dataframe().reset_index()


def timeseries(param_dict):
    """Plot and save the diagnostic time series for one parameter combination.

    Draws one panel per variable (cumulative emissions, emissions, gmst,
    gmst pulse minus control, marginal damages, discount factors, discounted
    damages) against year, with lines coloured by SCC, and writes the figure
    to `{output}/{sector}/vars_timeseries_*.png`.

    Parameters
    ----------
    param_dict : dict
        Must contain the keys 'sector', 'mask', 'recipe', 'disc', 'eta',
        'rho', 'gas' and 'kind'.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If a required key is missing from `param_dict`.
    ValueError
        If `param_dict['gas']` is not a supported gas.

    Notes
    -----
    Relies on the module-level settings `results_root`, `pulse_year`,
    `results_mask`, `config` and `output`.
    """
    sector = param_dict['sector']
    mask = param_dict['mask']
    recipe = param_dict['recipe']
    disc = param_dict['disc']
    eta = param_dict['eta']
    rho = param_dict['rho']
    gas = param_dict['gas']
    kind = param_dict['kind']

    # All xarray objects live only inside the helper, so they are released on
    # return without a manual `del` for each of them.
    data = _load_timeseries_frame(sector, mask, recipe, disc, eta, rho, gas, kind)

    yvars = ['cumulative_emissions', 'emissions', 'gmst', 'gmst_pulse_minus_control',
             'marginal_damages', 'discount_factors', 'discounted_damages']

    fig, ax = plt.subplots(len(yvars), 1, figsize=(10, 15), sharex=True)
    try:
        for i, yvar in enumerate(yvars):
            sns.lineplot(data=data,
                         x='year',
                         y=yvar,
                         hue='SCC',
                         ax=ax[i]
                        )

            ax[i].set_title(yvar)
            # Keep the legend on the first panel only.
            if i > 0:
                legend = ax[i].get_legend()
                if legend is not None:
                    legend.remove()

        fig.subplots_adjust(top=0.9)
        fig.suptitle(f"mask={mask}, SCC subset={kind} \n {sector} {recipe} {disc}, eta={eta} rho={rho}")
        os.makedirs(f"{output}/{sector}", exist_ok=True)
        fig.savefig(
            f'{output}/{sector}/vars_timeseries_{gas}_{sector}_{recipe}_{disc}_{eta}_{rho}_{kind}_{mask}.png',
            bbox_inches='tight',
            dpi=300
        )
    finally:
        # Always release the figure and the frame, even when plotting fails,
        # so a long loop over combinations does not accumulate memory.
        plt.close(fig)
        del data
        # A full collection already covers the younger generations.
        gc.collect()

In [None]:
# Walk every (sector, mask, kind, (eta, rho), gas, recipe, discounting)
# combination, record its parameters and draw its figure right away.
param_dict_list = []
i = 0
for s, m, k, er, g, r, d in product(
    sectors, masks, kinds, eta_rhos.items(), gases, recipes, discs
):
    print(s,m, k, er, g, r, d)
    # `er` is one (eta, rho) item of `eta_rhos`.
    param_dict = {
        'sector': s,
        'mask': m,
        'kind': k,
        'eta': er[0],
        'rho': er[1],
        'gas': g,
        'recipe': r,
        'disc': d,
    }
    param_dict_list.append(param_dict)
    i += 1
    timeseries(param_dict)

# Parallel alternative over the collected parameter sets (disabled):
# p_umap(timeseries, param_dict_list, num_cpus = 30)