In [None]:
import os, gc, time
# Current login name; the run loop uses it to build the config path /home/{USER}/...
# NOTE(review): os.getenv returns None when $USER is unset, which would produce a
# '/home/None/...' path — confirm this only runs in a login environment.
USER = os.getenv('USER')
import numpy as np
import pandas as pd
import xarray as xr
import dask

# Global dask setting (called as a plain function, not inside a `with` block):
# turn off 'array.slicing.split_large_chunks' for the rest of the session.
# It runs before the dscim import below; keep that order in case dscim reads
# dask config at import time (unverified).
dask.config.set(**{'array.slicing.split_large_chunks': False})

# Project package: the run loop instantiates ProWaiter per run and calls its
# menu_factory to build each run's menu item.
from dscim import ProWaiter
# NOTE(review): itertools, reduce and Path (like gc, time, np, pd, xr above) are
# not referenced in the visible cells.
import itertools
from functools import reduce
from pathlib import Path

In [None]:
from dask.distributed import Client, progress
from dscim.utils.functions import ce_func, mean_func

In [None]:
# from dask.distributed import Client

# client = Client("tcp://127.0.0.1:45097")
# client

In [None]:
# client = Client(n_workers=15, memory_limit="15G", threads_per_worker=1) 

In [None]:
# Run grid: `combos` (built at the bottom of this cell) holds one tuple per run,
# ordered (sector, pulse_year, discount_type, menu_option, mask, quantile, eta).
discount_types = ["euler_ramsey", "constant"]
menu_options = ["risk_aversion", "adding_up"]
# pulse_years = [2020]
pulse_years = range(2020, 2081, 10)  # 2020, 2030, ..., 2080 inclusive
# Passed whole to every run as its 'weitzman_parameter' list (not crossed into combos).
weitzman_values = [0.5, 1.0]  # [x / 10.0 for x in range(1, 11, 1)] + [0.25, 0.01, 0.001, 0.0001]
# eta -> rho: each eta is run only with its own paired rho.
eta_rhos = {
    2.0: 0.0,
    1.016010255: 9.149608e-05,
    1.244459066: 0.00197263997,
    1.421158116: 0.00461878399,
    1.567899395: 0.00770271076,
}

gas = "CO2_Fossil"
# Optional extra outputs written next to the uncollapsed SCCs.
factors = False           # also save uncollapsed discount factors
marginal_damages = False  # also save uncollapsed marginal damages
# Select the damage-function library directory damage_function_library_rff{v}{suffix}.
v = 3
suffix = '_coastalv19'

# Masking is applied ex post by rff_expost_masking.ipynb in this same directory,
# so only the unmasked run is enabled here.
mask_list = [
    # 'climate',
    # 'gdppc',
    # 'emissions',
    # 'gdppc_emissions',
    # 'gdppc_emissions_climate',
    'unmasked',
    # 'gmst',
]

# (lower, upper) quantile pairs, used only by masked runs; unmasked runs ignore
# them. Replace the None placeholder with real pairs before enabling a mask.
quantile_list = [
    None,
    # [0.001, 0.999],
    # [0.005, 0.995],
    # [0.01, 0.99],
    # [0.05, 0.95],
    # [0, 1],
]

# sectors = ["AMEL_clipped", "mortality_clipped", "labor_clipped", "energy_clipped", "agriculture_clipped"] #"CAMEL_clipped", "coastal"]
sectors = ["coastal", "AMEL_clipped"]
# more sectors: "mortality_clipped", "labor_clipped", "energy_clipped", "agriculture_clipped"

combos = [
    (s, p, d, m, k, q, e)
    for p in pulse_years
    for s in sectors
    for d in discount_types
    for m in menu_options
    for k in mask_list
    # An unmasked run does not depend on the quantile, so crossing it with several
    # quantile pairs would only repeat the same run into the same save path.
    for q in ([None] if k == 'unmasked' else quantile_list)
    for e in eta_rhos
]



In [None]:
combos  # preview the run list: one (sector, pulse_year, discount_type, menu_option, mask, quantile, eta) tuple per run

# Run all combos (resumable: each finished combo is removed from `combos`)

In [None]:

# Run every combo in `combos`. A combo is removed from the list only after it has
# finished, so if a run fails, re-running this cell resumes with the remaining
# combos; re-run the configuration cell to start again from the full grid.


def _mask_settings(mask, quantile):
    """Return ``(mask_name, mask_path)`` for one mask/quantile combination.

    ``'unmasked'`` ignores ``quantile`` and gives ``(None, None)``. Any other mask
    needs a ``(lower, upper)`` quantile pair, encoded as ``'q{lower}_q{upper}'``;
    ``'gmst'`` has no mask file, every other mask reads ``{mask}_based_masks.nc4``.

    Raises:
        ValueError: if a masked run gets anything but a 2-item pair. Without this
            check a placeholder such as the string ``"None"`` would be indexed
            character-by-character into the bogus name ``'qN_qo'``.
    """
    if mask == 'unmasked':
        return None, None
    if quantile is None or isinstance(quantile, str) or len(quantile) != 2:
        raise ValueError(
            f"mask {mask!r} needs a (lower, upper) quantile pair, got {quantile!r}"
        )
    mask_name = f'q{quantile[0]}_q{quantile[1]}'
    if mask == 'gmst':
        return mask_name, None
    # NOTE(review): the mask directory is fixed to CO2_Fossil rather than following
    # `gas`; confirm before running other gases.
    return mask_name, f"/shares/gcp/integration/rff/climate/masks/CO2_Fossil/{mask}_based_masks.nc4"


def _write_zarr(dataset, chunks, path):
    """Rechunk ``dataset`` to ``chunks`` and write it to ``path`` as a consolidated
    zarr store, overwriting any existing store (``mode="w"``)."""
    dataset.chunk(chunks).to_zarr(path, consolidated=True, mode="w")


# Chunk sizes for the optional extra outputs.
MARGINAL_DAMAGE_CHUNKS = {
    "discount_type": 1,
    "weitzman_parameter": 14,
    "rff_sp": 10000,
    "gas": 1,
    "simulation": 1,
    "year": 10,
}
DISCOUNT_FACTOR_CHUNKS = {
    "discount_type": 1,
    "weitzman_parameter": 14,
    "rff_sp": 10000,
    "gas": 1,
    "simulation": 1,
    "region": 1,
    "year": 10,
}

n_combos = len(combos)

while combos:
    combo = combos[0]
    print("=========================================================")
    print(combo)
    sector, pulse_year, discount_type, menu_option, mask, quantile, eta = combo
    rho = eta_rhos[eta]

    mask_name, mask_path = _mask_settings(mask, quantile)
    save_path = f'/mnt/CIL_integration/rff_all_gases/{gas}/{sector}/{pulse_year}/{mask}_{mask_name}'

    # A fresh ProWaiter is built for every combo, as in the original loop.
    w = ProWaiter(path_to_config=f'/home/{USER}/repos/integration/configs/rff_config_all_gases.yaml')
    kwargs = {'discounting_type': discount_type,
              'sector': sector,
              'gases': gas,
              'damage_function_path': f"/mnt/CIL_integration/damage_function_library/damage_function_library_rff{v}{suffix}/{sector}",
              'save_path': save_path,
              'save_files': ['uncollapsed_sccs'],
              'weitzman_parameter': weitzman_values,
              'pulse_year': pulse_year,
              'gmst_fair_path': f"/shares/gcp/integration/rff/climate/stacked/ar6_rff_iter0-19_fair162_all_gases_control_pulse_{pulse_year}_temp_v5.02_newformat_Jan222022.nc",
              'gmsl_fair_path': "/shares/gcp/integration/rff/climate/stacked/ar6_rff_iter0-19_fair162_control_pulse_2020-2030-2040-2050-2060-2070-2080_gmsl_emissions-driven_naturalfix_v5.02_Jan222022.zarr",
              'path_econ': '/shares/gcp/integration/rff/socioeconomics/rff_global_socioeconomics.nc4',
              'ecs_mask_path': mask_path,
              'ecs_mask_name': mask_name,
              'eta': eta,
              'rho': rho,
              }

    menu_item = w.menu_factory(menu_key=menu_option, sector=sector, kwargs=kwargs)
    menu_item.order_plate('scc')

    if marginal_damages:
        md = (menu_item.global_consumption_no_pulse - menu_item.global_consumption_pulse) * menu_item.climate.conversion
        _write_zarr(
            md.rename('marginal_damages').to_dataset(),
            MARGINAL_DAMAGE_CHUNKS,
            f'{save_path}/{menu_option}_{discount_type}_eta{menu_item.eta}_rho{menu_item.rho}_uncollapsed_marginal_damages.zarr',
        )
    if factors:
        discount_factors = menu_item.calculate_discount_factors(
            menu_item.global_consumption_no_pulse / menu_item.pop
        )
        _write_zarr(
            discount_factors.to_dataset(name="discount_factor"),
            DISCOUNT_FACTOR_CHUNKS,
            f'{save_path}/{menu_option}_{discount_type}_eta{menu_item.eta}_rho{menu_item.rho}_uncollapsed_discount_factors.zarr',
        )
        print("done saving discount factor")

    # Drop the finished combo only now, so a failure above leaves it queued.
    combos.pop(0)
    n_combos = len(combos)
    print(f"remaining combos: {n_combos}")

#     if (n_combos % 50 == 0):
#         client.restart()
#         time.sleep(30)