In [None]:
import xarray as xr
import pandas as pd
import numpy as np
from p_tqdm import p_map
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Directory holding the uncollapsed SCC files.
# Alternative run kept for reference:
#   "/mnt/CIL_integration/menu_results_AR6_epa/CAMEL/2020/unmasked"
root = "/mnt/CIL_integration/menu_results_96k_integration/integration_AMEL/2020/unmasked/"

# Recipe, discounting scheme and (eta, rho) together identify the input file name.
recipe = "risk_aversion"
disc = "euler_ramsey"
eta, rho = 2.0, 0.0  # alternatives noted earlier: eta = 1.567899395, rho = 0.00770271076

# Scenario slice the simulations are resampled from.
weitzman_parameter = ['0.001', '0.01']
model = 'IIASA GDP'
ssp = 'SSP3'
rcp = 'rcp85'

# Discount rate selected only when disc == 'constant'.
discrate = 0.02

# Number of resampling draws per sample size.
no_draws = 100

In [None]:
# Open the uncollapsed SCC file matching the configuration above.
ds = xr.open_dataset(f"{root}/{recipe}_{disc}_eta{eta}_rho{rho}_uncollapsed_sccs.nc4")

# For constant discounting, keep only the configured discount rate.
if disc == 'constant':
    ds = ds.sel(discrate=discrate)

# Relabel simulations as 0..N-1 so integer draws can be used directly as labels.
ds['simulation'] = range(ds.sizes['simulation'])

In [None]:
# Seed NumPy's legacy global RNG (the generator behind np.random.randint).
# NOTE(review): p_map worker processes may each inherit this same state,
# so on its own this seed does not make the parallel resampling independent.
np.random.seed(seed=30)

In [None]:
def sample_dist(sample, ds=ds, seed=30):
    """Resample simulations and record the mean of each resample.

    For each of ``no_draws`` draws, pick ``sample`` simulation positions
    uniformly *with replacement* from the configured
    ``weitzman_parameter`` / ``model`` / ``rcp`` / ``ssp`` slice of ``ds``,
    then average over the ``simulation`` dimension.

    Parameters
    ----------
    sample : int
        Number of simulations drawn (with replacement) in each draw.
    ds : xarray.Dataset
        Dataset with a ``simulation`` dimension and ``weitzman_parameter``,
        ``model``, ``rcp`` and ``ssp`` coordinates. Defaults to the dataset
        loaded earlier in the notebook.
    seed : int
        Base seed. The generator is seeded from ``(seed, sample)``, so each
        sample size gets its own reproducible stream.

    Returns
    -------
    xarray.Dataset
        Per-draw means with two new leading dimensions: ``sample`` (length 1,
        holding ``sample``) and ``draw`` (length ``no_draws``).
    """
    # Use a private generator per call instead of NumPy's global state.
    # p_map runs calls in worker processes that may all inherit the same
    # global RNG state. With that state, draws for different sample sizes
    # started from identical indices and depended on task scheduling.
    rng = np.random.default_rng([seed, sample])

    # The scenario selection is identical for every draw, so do it once.
    subset = ds.sel(
        weitzman_parameter=weitzman_parameter,
        model=model,
        rcp=rcp,
        ssp=ssp,
    )
    n_simulations = subset.sizes['simulation']

    draws = []
    for draw in range(no_draws):
        # Positional indices drawn with replacement (bootstrap-style).
        picks = rng.integers(0, n_simulations, sample)
        draws.append(
            subset.isel(simulation=picks)
            .mean('simulation')
            .expand_dims({'sample': [sample], 'draw': [draw]})
        )

    return xr.combine_by_coords(draws)

In [None]:
# Resample at each size in parallel; the final size equals the full simulation count.
samples = xr.combine_by_coords(
    p_map(
        sample_dist,
        [50, 100, 500, 1000, 1500, 2000, 2500, 3000, 5000, 10000, 50000, len(ds.simulation)],
    )
)

In [None]:
# One histogram grid per Weitzman parameter, with one panel per sample size.
# The theme is global and identical for every figure, so set it once.
sns.set_theme(style="darkgrid")

for wp in weitzman_parameter:
    grid = sns.displot(
        samples.sel(weitzman_parameter=wp).to_dataframe().reset_index(),
        x="uncollapsed_sccs",
        col="sample",
        col_wrap=4,
        bins=50,
    )
    # Operate on the grid's own figure rather than pyplot's implicit "current figure".
    fig = grid.figure

    # Leave headroom at the top for the two-line title.
    fig.subplots_adjust(top=0.9)
    fig.suptitle(f"{recipe} {disc} {model} {ssp} {rcp} \n eta = {eta}, rho = {rho}, weitzman parameter = {wp}, # draws = {no_draws}")
    fig.savefig(
        f'/mnt/CIL_integration/plots/integration_paper_diagnostics/AR5_samples_{recipe}_{disc}_{model}_{ssp}_{rcp}_eta{eta}_rho{rho}_wp{wp}.png',
        dpi=300,
        bbox_inches='tight',
    )