# DASK

In [None]:
from dask.distributed import Client, LocalCluster
from dask_jobqueue import SLURMCluster
from pathlib import Path
import os
import dask

In [None]:
# Choose the Dask cluster backend from the machine label `which_pc`.
which_pc = "merlin_paper_gsa"
if 'merlin' in which_pc:
    # SLURM-backed cluster: the same directory is used for the workers'
    # local files and for the job logs.
    path_dask_logs = Path('/data/user/kim_a/dask_logs')
    path_dask_logs.mkdir(parents=True, exist_ok=True)
    cluster = SLURMCluster(
        cores=8,
        memory="60GB",
        walltime='10:00:00',
        interface='ib0',
        local_directory=path_dask_logs.as_posix(),
        log_directory=path_dask_logs.as_posix(),
        queue="daily",
    )
elif 'local' in which_pc:
    cluster = LocalCluster(memory_limit='7GB')
else:
    # Fail here with a clear message instead of leaving `cluster` unbound
    # (or stale from a previous run) for the cells that follow.
    raise ValueError(
        "Unknown which_pc {!r}: expected a label containing 'merlin' or 'local'".format(which_pc)
    )

In [None]:
# Connect a distributed client to the cluster selected in the previous cell.
client = Client(cluster)

In [None]:
# Ask the cluster to scale to `n_workers` workers.
# NOTE: `n_workers` is rebound later, in the loop that builds the model runs.
n_workers = 2
cluster.scale(n_workers)

In [None]:
# Bare expression: shows the client as the output of this notebook cell.
client

In [None]:
# client.close()
# cluster.close() 

# 1. GSA setups

In [None]:
from setups_paper_gwp import *

In [None]:
# Number of model parameters, and the number of model evaluations used by
# each GSA method, expressed as a multiple of the number of parameters.
num_params = 10000
iter_corr = 4*num_params
iter_salt = 40*num_params
iter_delt = 4*num_params
iter_xgbo = 4*num_params

# Number of chunks / parallel workers used for each method.
n_workers_corr = 4
n_workers_salt = 39
n_workers_delt = 4
n_workers_xgbo = 4

# Per-method settings, keyed by the method label; iterated when the delayed
# model runs are built.
options = {
    'corr': {
        "iterations": iter_corr,
        "n_workers": n_workers_corr,
    },
    'salt': {
        "iterations": iter_salt,
        "n_workers": n_workers_salt,
    },
    'delt': {
        "iterations": iter_delt,
        "n_workers": n_workers_delt,
    },
    'xgbo': {
        # Use the xgboost-specific settings (previously the delta ones were
        # copied here, which only worked while both happened to be equal).
        "iterations": iter_xgbo,
        "n_workers": n_workers_xgbo,
    },
}
# Create one GSA object per method. Each setup_* call receives the number of
# parameters, that method's iteration count and the LCA-model setup function.
gsa_corr = setup_corr(num_params, iter_corr, setup_lca_model_paper)
gsa_salt = setup_salt(num_params, iter_salt, setup_lca_model_paper)
gsa_delt = setup_delt(num_params, iter_delt, setup_lca_model_paper)
gsa_xgbo = setup_xgbo(num_params, iter_xgbo, setup_lca_model_paper)

# 2. Model runs

In [None]:
# Call write_X_chunks for every method together with its worker count.
# NOTE(review): the xgboost call is disabled below although the model runs are
# still built for 'xgbo' -- confirm that this is intended.
for _gsa, _n_chunks in (
    (gsa_corr, n_workers_corr),
    (gsa_salt, n_workers_salt),
    (gsa_delt, n_workers_delt),
):
    write_X_chunks(_gsa, _n_chunks)
# write_X_chunks(gsa_xgbo, n_workers_xgbo)

In [None]:
# Build one delayed model-evaluation task per (method, chunk index) pair.
# Nothing is executed here; the tasks are only collected in `model_evals`.
task_per_worker = dask.delayed(compute_scores_per_worker)
model_evals = []
for option, settings in options.items():
    iterations, n_workers = settings["iterations"], settings["n_workers"]
    for i in range(n_workers):
        print(option, num_params, iterations, i, n_workers)
        model_evals.append(
            task_per_worker(option, num_params, iterations, i, n_workers, setup_lca_model_paper)
        )

In [None]:
%%time
# Execute all delayed model evaluations collected in `model_evals`.
dask.compute(model_evals)

## 2.5. Collect model Y chunks into one array

In [None]:
def generate_model_output_from_chunks(gsa, n_workers):
    """Join the per-worker model-output chunks into one array and store it.

    Reads ``gsa.dirpath_Y / "<i>.<n_workers>.pickle"`` for every ``i`` in
    ``range(n_workers)``, concatenates the chunks in that order, writes the
    result to ``gsa.filepath_Y`` with ``write_hdf5_array`` and returns it.

    Parameters
    ----------
    gsa
        Object providing the ``dirpath_Y`` and ``filepath_Y`` paths.
    n_workers : int
        Number of chunk files to read.

    Returns
    -------
    numpy.ndarray
        The concatenated model outputs (empty if ``n_workers`` is 0).
    """
    chunks = [
        read_pickle(gsa.dirpath_Y / "{}.{}.pickle".format(i, n_workers))
        for i in range(n_workers)
    ]
    # Join everything in one call rather than re-copying the growing array for
    # every chunk. The leading empty float array keeps the result defined
    # (and 1-D) when there are no chunks at all.
    Y = np.hstack(
        [np.zeros(shape=(0,)), *chunks]
    )  # TODO change to vstack for multidimensional output
    write_hdf5_array(Y, gsa.filepath_Y)
    return Y

In [None]:
# Assemble and store the full model-output array of every method from its
# per-worker chunk files.
Ycorr = generate_model_output_from_chunks(gsa=gsa_corr, n_workers=n_workers_corr)
Ysalt = generate_model_output_from_chunks(gsa=gsa_salt, n_workers=n_workers_salt)
Ydelt = generate_model_output_from_chunks(gsa=gsa_delt, n_workers=n_workers_delt)
Yxgbo = generate_model_output_from_chunks(gsa=gsa_xgbo, n_workers=n_workers_xgbo)

# 3. Run GSA

In [None]:
# Generate the unit-cube and rescaled samples for the delta method; the
# sample arrays themselves are not returned (return_X=False).
gsa_delt.generate_unitcube_samples(return_X=False)
gsa_delt.generate_rescaled_samples(return_X=False)

In [None]:
# Generate the unit-cube and rescaled samples for the xgboost method; the
# sample arrays themselves are not returned (return_X=False).
gsa_xgbo.generate_unitcube_samples(return_X=False)
gsa_xgbo.generate_rescaled_samples(return_X=False)

In [None]:
%%time
# Run the correlation-based GSA directly in the notebook process.
gsa_corr.perform_gsa()

In [None]:
%%time
# Run the Saltelli GSA directly in the notebook process.
gsa_salt.perform_gsa()

In [None]:
# Wrap the delta and xgboost GSA runs as delayed tasks so that they can be
# executed through dask; nothing runs until `dask.compute` is called.
worker_delt, worker_xgbo = (
    dask.delayed(gsa_delt.perform_gsa),
    dask.delayed(gsa_xgbo.perform_gsa),
)
model_eval_delt, model_eval_xgbo = worker_delt(), worker_xgbo()
model_evals = [model_eval_delt, model_eval_xgbo]

In [None]:
%%time
# Execute the delayed GSA tasks collected in `model_evals`.
dask.compute(model_evals)

# 4. Validation

In [None]:
from gsa_framework.validation import Validation
import dask

# Validation settings: seed, number of inputs treated as influential, and
# number of validation iterations.
validation_seed = 23467
num_influential = 60
iterations_validation = 2000

# NOTE(review): the GSA setups in this notebook are given
# `setup_lca_model_paper`; confirm that `setup_lca_model` is the intended
# model constructor here.
model, write_dir, gsa_seed = setup_lca_model(num_params)

validation_kwargs = {
    "model": model,
    "iterations": iterations_validation,
    "seed": validation_seed,
    "default_x_rescaled": None,
    "write_dir": write_dir,
}
val = Validation(**validation_kwargs)

In [None]:
# Delayed wrapper around the validation routine, used to build dask tasks.
worker_validation =  dask.delayed(val.get_influential_Y_from_gsa)

In [None]:
# Sensitivity indices of every method, taken from the dictionary returned by
# its generate_gsa_indices().
Scorr = abs(gsa_corr.generate_gsa_indices()['spearman'])
Ssalt = gsa_salt.generate_gsa_indices()['Total order']
Sdelt = np.array(gsa_delt.generate_gsa_indices()['delta'])
S_dict = gsa_xgbo.generate_gsa_indices()
Sxgbo = S_dict['fscores']

# Tags passed to the validation routine for each method.
tag_corr, tag_salt, tag_xgbo = "SpearmanIndex", "TotalIndex", "FscoresIndex"
tag_delt = f"DeltaIndexNr{gsa_delt.num_resamples}"

# Delayed validation task per method (nothing is computed yet).
model_eval_corr = worker_validation(Scorr, num_influential, tag_corr)
model_eval_salt = worker_validation(Ssalt, num_influential, tag_salt)
model_eval_delt = worker_validation(Sdelt, num_influential, tag_delt)
model_eval_xgbo = worker_validation(Sxgbo, num_influential, tag_xgbo)

In [None]:
# Collect the delayed validation tasks of all four methods.
model_evals = [
    model_eval_corr,
    model_eval_salt,
    model_eval_delt, 
    model_eval_xgbo,
]

In [None]:
# %%time
# dask.compute(model_evals)

In [None]:
# For every method: compute the model output obtained with the
# `num_influential` inputs selected from that method's indices, then plot its
# histogram via val.plot_histogram_Y_all_Y_inf. These calls run directly in
# the notebook process, not through the delayed tasks built earlier.
fig_format = ['pickle']

# Spearman correlation
influential_Y_corr = val.get_influential_Y_from_gsa(Scorr, num_influential, tag=tag_corr)
val.plot_histogram_Y_all_Y_inf(
    influential_Y_corr, num_influential, tag=tag_corr, fig_format=fig_format
)

# Saltelli total-order index
influential_Y_salt = val.get_influential_Y_from_gsa(Ssalt, num_influential, tag=tag_salt)
val.plot_histogram_Y_all_Y_inf(
    influential_Y_salt, num_influential, tag=tag_salt, fig_format=fig_format
)

# Delta index
influential_Y_delt = val.get_influential_Y_from_gsa(Sdelt, num_influential, tag=tag_delt)
val.plot_histogram_Y_all_Y_inf(
    influential_Y_delt, num_influential, tag=tag_delt, fig_format=fig_format
)

# XGBoost f-scores
influential_Y_xgbo = val.get_influential_Y_from_gsa(Sxgbo, num_influential, tag=tag_xgbo)
val.plot_histogram_Y_all_Y_inf(
    influential_Y_xgbo, num_influential, tag=tag_xgbo, fig_format=fig_format
)

# Stability

In [None]:
def compute_per_worker_delt(iterations_current, stability_seed, write_dir_stability):
    """Compute, or load from cache, the delta indices of one bootstrap sample.

    The result is stored in
    ``write_dir_stability / "S.step<iterations_current>.seed<stability_seed>.pickle"``.
    If that file already exists it is returned immediately, without reading
    the model outputs or rebuilding the input sample.

    Parameters
    ----------
    iterations_current : int
        Number of rows drawn for this bootstrap sample.
    stability_seed : int
        Seed used to draw the row subset; also part of the file names.
    write_dir_stability : pathlib.Path
        Directory holding the ``Y.step*.seed*.pickle`` inputs and the
        ``S.step*.seed*.pickle`` results.

    Returns
    -------
    The object returned by ``delta_moment_stability`` (or its cached copy).
    """
    filepath_S = write_dir_stability / "S.step{}.seed{}.pickle".format(iterations_current, stability_seed)
    # Check the cache first: everything below, in particular the full
    # iter_delt x num_params random matrix, is only needed to compute a
    # result that is not stored yet.
    if filepath_S.exists():
        return read_pickle(filepath_S)

    num_params = 10000
    iter_delt = 4*num_params
    filepath_Y = write_dir_stability / "Y.step{}.seed{}.pickle".format(iterations_current, stability_seed)
    Y = read_pickle(filepath_Y).flatten()
    # NOTE(review): elsewhere in this notebook setup_delt receives a third
    # argument (setup_lca_model_paper) -- confirm the two-argument call.
    gsa_delt = setup_delt(num_params, iter_delt)
    # Draw the full uniform sample from the GSA seed; presumably this
    # reproduces the sample the GSA object itself uses -- TODO confirm.
    np.random.seed(gsa_delt.seed)
    X = np.random.rand(iter_delt, num_params)
    # Select the rows of this bootstrap sample from the stability seed.
    np.random.seed(stability_seed)
    choice = np.random.choice(np.arange(iter_delt), iterations_current, replace=False)
    Xr = gsa_delt.model.rescale(X[choice, :])
    del X  # release the full sample matrix before the GSA computation
    S_dict = delta_moment_stability(
        Y, Xr, num_resamples=1
    )
    write_pickle(S_dict, filepath_S)
    return S_dict

def compute_per_worker_xgbo(iterations_current, stability_seed, write_dir_stability):
    """Compute, or load from cache, the xgboost scores of one bootstrap sample.

    The result is stored in
    ``write_dir_stability / "S.step<iterations_current>.seed<stability_seed>.pickle"``.
    If that file already exists it is returned immediately, without reading
    the model outputs or rebuilding the input sample.

    Parameters
    ----------
    iterations_current : int
        Number of rows drawn for this bootstrap sample.
    stability_seed : int
        Seed used to draw the row subset; also part of the file names.
    write_dir_stability : pathlib.Path
        Directory holding the ``Y.step*.seed*.pickle`` inputs and the
        ``S.step*.seed*.pickle`` results.

    Returns
    -------
    The object returned by ``xgboost_scores_stability`` (or its cached copy).
    """
    filepath_S = write_dir_stability / "S.step{}.seed{}.pickle".format(iterations_current, stability_seed)
    # Check the cache first: everything below, in particular the full
    # iter_xgbo x num_params random matrix, is only needed to compute a
    # result that is not stored yet.
    if filepath_S.exists():
        return read_pickle(filepath_S)

    num_params = 10000
    iter_xgbo = 4*num_params
    filepath_Y = write_dir_stability / "Y.step{}.seed{}.pickle".format(iterations_current, stability_seed)
    Y = read_pickle(filepath_Y).flatten()
    # NOTE(review): elsewhere in this notebook setup_xgbo receives a third
    # argument (setup_lca_model_paper) -- confirm the two-argument call.
    gsa_xgbo = setup_xgbo(num_params, iter_xgbo)
    # Draw the full uniform sample from the GSA seed; presumably this
    # reproduces the sample the GSA object itself uses -- TODO confirm.
    np.random.seed(gsa_xgbo.seed)
    X = np.random.rand(iter_xgbo, num_params)
    # Select the rows of this bootstrap sample from the stability seed.
    np.random.seed(stability_seed)
    choice = np.random.choice(np.arange(iter_xgbo), iterations_current, replace=False)
    Xr = gsa_xgbo.model.rescale(X[choice, :])
    del X  # release the full sample matrix before the GSA computation
    S_dict = xgboost_scores_stability(
        Y,
        Xr,
        tuning_parameters=gsa_xgbo.tuning_parameters,
        num_boost_round=gsa_xgbo.num_boost_round,
    )
    write_pickle(S_dict, filepath_S)
    return S_dict

In [None]:
# num_steps is handed to Convergence; num_bootstrap is the number of seeds
# (bootstrap samples) drawn for every convergence step.
num_steps = 50
num_bootstrap = 60

# Select the GSA object and the matching per-sample worker function.
option = 'delta'
if option == 'delta':
    gsa = gsa_delt
    compute_per_worker = compute_per_worker_delt
elif option == 'xgboost':
    gsa = gsa_xgbo
    compute_per_worker = compute_per_worker_xgbo
else:
    # Without this branch an unknown option would silently reuse whatever
    # `gsa` / `compute_per_worker` were bound to before (or fail later with
    # a NameError).
    raise ValueError(
        "Unknown option {!r}: expected 'delta' or 'xgboost'".format(option)
    )

task_per_worker = dask.delayed(compute_per_worker)

In [None]:
# Directory for the intermediate stability files of the selected method.
write_dir_scratch = Path("/shared-scratch/kim_a")
write_dir_stability = write_dir_scratch / 'stability_intermediate_{}'.format(gsa.gsa_label)
write_dir_stability.mkdir(parents=True, exist_ok=True)

# Convergence object; its `iterations_for_convergence` gives the sample size
# used at every step.
conv = Convergence(
    gsa.filepath_Y,
    gsa.num_params,
    gsa.generate_gsa_indices,
    gsa.gsa_label,
    write_dir_scratch,
    num_steps=num_steps,
)

# One seed per (step, bootstrap sample), drawn reproducibly from the GSA seed.
np.random.seed(gsa.seed)
seeds_shape = (len(conv.iterations_for_convergence), num_bootstrap)
stability_seeds = np.random.randint(low=0, high=2147483647, size=seeds_shape)

# NOTE(review): X_rescaled is loaded only for its number of rows.
X_rescaled = read_hdf5_array(gsa.filepath_X_rescaled)
Y = read_hdf5_array(gsa.filepath_Y).flatten()
num_samples_total = X_rescaled.shape[0]

# model_evals[i][j] is the delayed task of bootstrap sample j at step i.
model_evals = []
for iterations_current, seeds_for_step in zip(conv.iterations_for_convergence, stability_seeds):
    model_evals_bootstrap_j = []
    for stability_seed in seeds_for_step:
        np.random.seed(stability_seed)
        choice = np.random.choice(np.arange(num_samples_total), iterations_current, replace=False)
        # Write the model outputs of this sample unless they already exist.
        filepath_Y_ij = write_dir_stability / "Y.step{}.seed{}.pickle".format(iterations_current, stability_seed)
        if not filepath_Y_ij.exists():
            write_pickle(Y[choice], filepath_Y_ij)
        # Delayed computation of the indices for this sample.
        model_evals_bootstrap_j.append(
            task_per_worker(iterations_current, stability_seed, write_dir_stability)
        )
    model_evals.append(model_evals_bootstrap_j)

In [None]:
# Run the delayed stability tasks one convergence step at a time: each
# dask.compute call evaluates the bootstrap tasks of a single step.
for model_evals_bootstrap_j in model_evals:
    dask.compute(model_evals_bootstrap_j)

# Archived

# 1. Construct LCA model

In [None]:
from gsa_framework.lca import LCAModel
from gsa_framework.methods.correlations import CorrelationCoefficients
from gsa_framework.methods.extended_FAST import eFAST
from gsa_framework.methods.saltelli_sobol import SaltelliSobol
from gsa_framework.methods.gradient_boosting import GradientBoosting
from gsa_framework.validation import Validation
from pathlib import Path
import brightway2 as bw
import time
import numpy as np
from gsa_framework.plotting import histogram_Y1_Y2
from gsa_framework.utils import read_hdf5_array

if __name__ == "__main__":

    # Base directory for the written files (alternative local path kept below).
#     path_base = Path(
#         "/Users/akim/PycharmProjects/gsa_framework/dev/write_files/paper_gsa/"
#     )
    path_base = Path('/data/user/kim_a/paper_gsa/gsa_framework_files')

    # LCA model: demand of one unit of the first activity whose name
    # contains the food sector label, assessed with IPCC 2013 GWP 100a.
    bw.projects.set_current("GSA for paper")
    co = bw.Database("CH consumption 1.0")
    food_sector_acts = [
        act for act in co if "Food and non-alcoholic beverages sector" in act['name']
    ]
    demand_act = food_sector_acts[0]
    print(demand_act)
    demand = {demand_act: 1}
    method = ("IPCC 2013", "climate change", "GWP 100a")

    # Settings
    num_params = 162299
    iterations_validation = 2000
    gsa_seed = 3403
    validation_seed = 7043
    fig_format = ["html", "pickle"]
    write_dir = path_base / f"lca_model_food_{num_params}"
    model = LCAModel(demand, method, write_dir) # TODO add num_params later

    # Make sure that the chosen num_params in LCA are appropriate
    # NOTE(review): the seed below is hard-coded; `validation_seed` is not used.
    val = Validation(
        model=model,
        iterations=iterations_validation,
        seed=4444,
        default_x_rescaled=model.default_uncertain_amounts,
        write_dir=write_dir,
    )
    num_params_paper = 10000
    tag = f"numParams{num_params_paper}"
    scores_dict = model.get_lsa_scores_pickle(model.write_dir / "LSA_scores")
    uncertain_tech_params_where_subset, _ = model.get_nonzero_params_from_num_params(
        scores_dict, num_params_paper
    )
    # Position of every selected parameter within
    # model.uncertain_tech_params_where; each must match exactly once.
    parameter_choice = []
    for param_where in uncertain_tech_params_where_subset:
        positions = np.where(model.uncertain_tech_params_where == param_where)[0]
        assert len(positions) == 1
        parameter_choice.append(positions[0])
    parameter_choice.sort()

In [None]:
# Model output obtained with the selected parameter subset, followed by its
# histogram plot via val.plot_histogram_Y_all_Y_inf.
Y_subset = val.get_influential_Y_from_parameter_choice(parameter_choice=parameter_choice, tag=tag)
val.plot_histogram_Y_all_Y_inf(Y_subset, num_influential=num_params_paper)