# Dask cluster setup

In [None]:
from dask.distributed import Client, LocalCluster
from dask_jobqueue import SLURMCluster
import os

In [None]:
# Select the compute environment; the string only needs to *contain*
# "merlin" (SLURM cluster) or "local" (single machine).
which_pc = "merlin_paper_gsa"
if 'merlin' in which_pc:
    path_dask_logs = '/data/user/kim_a/dask_logs'
    # Create the worker log/scratch directory; no error if it already exists.
    os.makedirs(path_dask_logs, exist_ok=True)
    # Each SLURM job provides 24 cores / 60GB, split across 6 worker processes
    # (dask-jobqueue `cores`/`processes` semantics).
    cluster = SLURMCluster(
        cores=24,
        processes=6,
        memory="60GB",
        walltime='12:00:00',
        interface='ib0',
        local_directory=path_dask_logs,
        log_directory=path_dask_logs,
        queue="daily",
    )
elif 'local' in which_pc:
    cluster = LocalCluster(memory_limit='7GB')
else:
    # Without this, `cluster` would be undefined and the next cell would fail
    # with a confusing NameError.
    raise ValueError(
        "Unknown which_pc {!r}: expected it to contain 'merlin' or 'local'".format(which_pc)
    )

In [None]:
# Connect a client to the cluster; it registers as dask's default scheduler,
# so the later `dask.compute` call runs on this cluster's workers.
client = Client(cluster)

In [None]:
# Request worker processes from the cluster.
# NOTE(review): for SLURMCluster, dask-jobqueue submits whole jobs of
# `processes` workers each, so the actual worker count may be rounded up — verify.
n_workers = 5
cluster.scale(n_workers)

In [None]:
# Display the client summary (dashboard link, workers, memory).
client

In [None]:
# Shut down the client and release the cluster's workers/SLURM jobs.
# NOTE(review): this cell sits *before* the validation section; under
# Restart & Run All it closes the cluster before `dask.compute` runs, so the
# computation would presumably fall back to dask's local scheduler instead of
# these workers. Run it manually only after the computation has finished.
client.close()
cluster.close()

# Validation

In [None]:
from gsa_framework.lca.lca_models import LCAModel
from gsa_framework.methods.correlations import CorrelationCoefficients
from gsa_framework.methods.saltelli_sobol import SaltelliSobol
from gsa_framework.methods.gradient_boosting import GradientBoosting
from gsa_framework.methods.delta_moment import DeltaMoment
from gsa_framework.validation import Validation
from pathlib import Path
import time
import brightway2 as bw
import dask

In [None]:
def validation_per_worker_base(iterations_validation):
    """Run the 10,000-parameter LCA model with all parameters varying (baseline).

    Meant to be wrapped in ``dask.delayed`` and executed on a worker, so it sets
    up the Brightway project and LCA model itself rather than using notebook
    globals.

    Parameters
    ----------
    iterations_validation : int
        Number of iterations passed to ``Validation``.

    Returns
    -------
    Whatever ``Validation.generate_Y_all_parameters_vary()`` returns
    (presumably the model outputs with every parameter varying — confirm).

    Raises
    ------
    ValueError
        If no activity with "Food" in its name exists in the database.
    """
    path_base = Path('/data/user/kim_a/paper_gsa/gsa_framework_files')

    # LCA model
    bw.projects.set_current("GSA for paper")
    co = bw.Database("CH consumption 1.0")
    # First activity whose name mentions "Food"; fail with a clear message if none.
    act = next((a for a in co if "Food" in a["name"]), None)
    if act is None:
        raise ValueError(
            'No activity with "Food" in its name in database "CH consumption 1.0"'
        )
    demand = {act: 1}
    method = ("IPCC 2013", "climate change", "GTP 100a")

    # Define some variables
    num_params = 10000
    write_dir = path_base / "lca_model_{}".format(num_params)
    model = LCAModel(demand, method, write_dir, num_params=num_params)

    validation_seed = 7043
    val = Validation(
        model=model,
        iterations=iterations_validation,
        seed=validation_seed,
        default_x_rescaled=None,
        write_dir=write_dir,
    )
    base_Y = val.generate_Y_all_parameters_vary()
    return base_Y

In [None]:
def validation_per_worker_correlations_lca(iterations_validation):
    """Validate Spearman-correlation GSA rankings on the 10,000-parameter LCA model.

    Obtains Spearman indices from ``CorrelationCoefficients`` (20,000
    iterations, seed 3403). It then runs ``Validation`` with only the 100
    top-ranked parameters varying. Meant to be wrapped in ``dask.delayed``.

    Parameters
    ----------
    iterations_validation : int
        Number of iterations passed to ``Validation``.

    Returns
    -------
    The result of ``Validation.get_influential_Y_from_gsa`` for tag
    ``"SpearmanIndex"``. The original version discarded it, so
    ``dask.compute`` only received ``None``.

    Raises
    ------
    ValueError
        If no activity with "Food" in its name exists in the database.
    """
    path_base = Path('/data/user/kim_a/paper_gsa/gsa_framework_files')

    # LCA model
    bw.projects.set_current("GSA for paper")
    co = bw.Database("CH consumption 1.0")
    # First activity whose name mentions "Food"; fail with a clear message if none.
    act = next((a for a in co if "Food" in a["name"]), None)
    if act is None:
        raise ValueError(
            'No activity with "Food" in its name in database "CH consumption 1.0"'
        )
    demand = {act: 1}
    method = ("IPCC 2013", "climate change", "GTP 100a")

    # Define some variables
    num_params = 10000
    write_dir = path_base / "lca_model_{}".format(num_params)
    model = LCAModel(demand, method, write_dir, num_params=num_params)
    gsa_seed = 3403

    # GSA indices
    iterations = 20000
    gsa = CorrelationCoefficients(
        iterations=iterations,
        model=model,
        write_dir=write_dir,
        seed=gsa_seed,
    )
    S_dict = gsa.generate_gsa_indices()
    gsa_indices = S_dict["spearman"]

    # Validation: vary only the `num_influential` top-ranked parameters.
    t0 = time.time()
    validation_seed = 7043
    num_influential = 100
    val = Validation(
        model=model,
        iterations=iterations_validation,
        seed=validation_seed,
        default_x_rescaled=None,
        write_dir=write_dir,
    )
    tag = "SpearmanIndex"
    influential_Y = val.get_influential_Y_from_gsa(
        gsa_indices, num_influential, tag=tag
    )
    t1 = time.time()
    print("Total validation time  -> {:8.3f} s \n".format(t1 - t0))
    return influential_Y

In [None]:
def validation_per_worker_saltelli_lca(iterations_validation):
    """Validate Saltelli total-order GSA rankings on the 10,000-parameter LCA model.

    Obtains total-order Sobol indices from ``SaltelliSobol`` (requested
    iterations: 100 * num_params). It then runs ``Validation`` with only the
    100 top-ranked parameters varying. Meant to be wrapped in ``dask.delayed``.

    Parameters
    ----------
    iterations_validation : int
        Number of iterations passed to ``Validation``.

    Returns
    -------
    The result of ``Validation.get_influential_Y_from_gsa`` for tag
    ``"SaltelliTotalIndex"``. The original version discarded it, so
    ``dask.compute`` only received ``None``.

    Raises
    ------
    ValueError
        If no activity with "Food" in its name exists in the database.
    """
    path_base = Path('/data/user/kim_a/paper_gsa/gsa_framework_files')

    # LCA model
    bw.projects.set_current("GSA for paper")
    co = bw.Database("CH consumption 1.0")
    # First activity whose name mentions "Food"; fail with a clear message if none.
    act = next((a for a in co if "Food" in a["name"]), None)
    if act is None:
        raise ValueError(
            'No activity with "Food" in its name in database "CH consumption 1.0"'
        )
    demand = {act: 1}
    method = ("IPCC 2013", "climate change", "GTP 100a")

    # Define some variables
    num_params = 10000
    write_dir = path_base / "lca_model_{}".format(num_params)
    model = LCAModel(demand, method, write_dir, num_params=num_params)

    # GSA indices. No seed is passed; this matches the cached
    # "S.saltelliGsa.saltelliSampling.990198.None.pickle" file name.
    iterations = 100 * num_params
    gsa = SaltelliSobol(
        iterations=iterations,
        model=model,
        write_dir=write_dir,
    )
    S_dict = gsa.generate_gsa_indices()
    gsa_indices = S_dict["Total order"]

    # Validation: vary only the `num_influential` top-ranked parameters.
    t0 = time.time()
    validation_seed = 7043
    num_influential = 100
    val = Validation(
        model=model,
        iterations=iterations_validation,
        seed=validation_seed,
        default_x_rescaled=None,
        write_dir=write_dir,
    )
    tag = "SaltelliTotalIndex"
    influential_Y = val.get_influential_Y_from_gsa(
        gsa_indices, num_influential, tag=tag
    )
    t1 = time.time()
    print("Total validation time  -> {:8.3f} s \n".format(t1 - t0))
    return influential_Y

In [None]:
def validation_per_worker_xgboost_lca(iterations_validation):
    """Validate XGBoost f-score GSA rankings on the 10,000-parameter LCA model.

    Obtains f-score importances from ``GradientBoosting`` (2 * num_params
    iterations, seed 3403, 400 boosting rounds). It then runs ``Validation``
    with only the 100 top-ranked parameters varying. Meant to be wrapped in
    ``dask.delayed``.

    Parameters
    ----------
    iterations_validation : int
        Number of iterations passed to ``Validation``.

    Returns
    -------
    The result of ``Validation.get_influential_Y_from_gsa`` for tag
    ``"FscoresIndex"``. The original version discarded it, so
    ``dask.compute`` only received ``None``.

    Raises
    ------
    ValueError
        If no activity with "Food" in its name exists in the database.
    """
    path_base = Path('/data/user/kim_a/paper_gsa/gsa_framework_files')

    # LCA model
    bw.projects.set_current("GSA for paper")
    co = bw.Database("CH consumption 1.0")
    # First activity whose name mentions "Food"; fail with a clear message if none.
    act = next((a for a in co if "Food" in a["name"]), None)
    if act is None:
        raise ValueError(
            'No activity with "Food" in its name in database "CH consumption 1.0"'
        )
    demand = {act: 1}
    method = ("IPCC 2013", "climate change", "GTP 100a")

    # Define some variables
    num_params = 10000
    write_dir = path_base / "lca_model_{}".format(num_params)
    model = LCAModel(demand, method, write_dir, num_params=num_params)
    gsa_seed = 3403

    # TODO: choose parameter indices for a convergence plot (placeholder was [1, 2, 3]).

    # GSA indices
    num_boost_round = 400
    tuning_parameters = {
        'max_depth': 6,
        'eta': 0.1,
        'objective': 'reg:squarederror',
        # NOTE(review): -1 lets XGBoost use every visible core. With several
        # dask worker processes per node this may oversubscribe CPUs — confirm.
        'n_jobs': -1,
        'refresh_leaf': True,
        'subsample': 0.6,
        'min_child_weight': 0.5,
    }
    iterations = 2 * num_params
    gsa = GradientBoosting(
        iterations=iterations,
        model=model,
        write_dir=write_dir,
        seed=gsa_seed,
        tuning_parameters=tuning_parameters,
        num_boost_round=num_boost_round,
        xgb_model=None,
    )
    S_dict, _, _ = gsa.generate_gsa_indices()
    gsa_indices = S_dict["fscores"]

    # Validation: vary only the `num_influential` top-ranked parameters.
    t0 = time.time()
    validation_seed = 7043
    num_influential = 100
    val = Validation(
        model=model,
        iterations=iterations_validation,
        seed=validation_seed,
        default_x_rescaled=None,
        write_dir=write_dir,
    )
    tag = "FscoresIndex"
    influential_Y = val.get_influential_Y_from_gsa(
        gsa_indices, num_influential, tag=tag
    )
    t1 = time.time()
    print("Total validation time  -> {:8.3f} s \n".format(t1 - t0))
    return influential_Y

In [None]:
def validation_per_worker_delta_lca(iterations_validation):
    """Validate delta-moment GSA rankings on the 10,000-parameter LCA model.

    Obtains delta indices from ``DeltaMoment`` (2 * num_params iterations,
    seed 3403, one resample). It then runs ``Validation`` with only the 100
    top-ranked parameters varying. Meant to be wrapped in ``dask.delayed``.

    Parameters
    ----------
    iterations_validation : int
        Number of iterations passed to ``Validation``.

    Returns
    -------
    The result of ``Validation.get_influential_Y_from_gsa`` for tag
    ``"DeltaIndex"``. The original version discarded it, so
    ``dask.compute`` only received ``None``.

    Raises
    ------
    ValueError
        If no activity with "Food" in its name exists in the database.
    """
    path_base = Path('/data/user/kim_a/paper_gsa/gsa_framework_files')

    # LCA model
    bw.projects.set_current("GSA for paper")
    co = bw.Database("CH consumption 1.0")
    # First activity whose name mentions "Food"; fail with a clear message if none.
    act = next((a for a in co if "Food" in a["name"]), None)
    if act is None:
        raise ValueError(
            'No activity with "Food" in its name in database "CH consumption 1.0"'
        )
    demand = {act: 1}
    method = ("IPCC 2013", "climate change", "GTP 100a")

    # Define some variables
    num_params = 10000
    write_dir = path_base / "lca_model_{}".format(num_params)
    model = LCAModel(demand, method, write_dir, num_params=num_params)
    gsa_seed = 3403

    # GSA indices
    num_resamples = 1
    iterations = 2 * num_params
    gsa = DeltaMoment(
        iterations=iterations,
        model=model,
        write_dir=write_dir,
        num_resamples=num_resamples,
        seed=gsa_seed,
    )
    S_dict = gsa.generate_gsa_indices()
    gsa_indices = S_dict["delta"]

    # Validation: vary only the `num_influential` top-ranked parameters.
    t0 = time.time()
    validation_seed = 7043
    num_influential = 100
    val = Validation(
        model=model,
        iterations=iterations_validation,
        seed=validation_seed,
        default_x_rescaled=None,
        write_dir=write_dir,
    )
    tag = "DeltaIndex"
    influential_Y = val.get_influential_Y_from_gsa(
        gsa_indices, num_influential, tag=tag
    )
    t1 = time.time()
    print("Total validation time  -> {:8.3f} s \n".format(t1 - t0))
    return influential_Y

In [None]:
# %%time
# # test
# iterations_validation = 6
# validation_per_worker_base(iterations_validation)

In [None]:
# %%time
# # test
# iterations_validation = 6
# validation_per_worker_correlations_lca(iterations_validation)

In [None]:
# %%time
# # test
# iterations_validation = 4
# validation_per_worker_saltelli_lca(iterations_validation)

In [None]:
# %%time
# # test
# iterations_validation = 4
# validation_per_worker_xgboost_lca(iterations_validation)

In [None]:
# %%time
# # test
# iterations_validation = 4
# validation_per_worker_delta_lca(iterations_validation)

In [None]:
# Validation sample size shared by every task below.
iterations_validation = 2000

# Lazily wrap each worker function; nothing executes until `dask.compute`.
task_per_worker_base = dask.delayed(validation_per_worker_base)
task_per_worker_corr = dask.delayed(validation_per_worker_correlations_lca)
task_per_worker_salt = dask.delayed(validation_per_worker_saltelli_lca)
task_per_worker_xgbo = dask.delayed(validation_per_worker_xgboost_lca)
task_per_worker_delt = dask.delayed(validation_per_worker_delta_lca)

# Bind the same argument to every task: one lazy result per method, in order
# baseline, correlations, Saltelli, XGBoost, delta.
model_evals = [
    task(iterations_validation)
    for task in (
        task_per_worker_base,
        task_per_worker_corr,
        task_per_worker_salt,
        task_per_worker_xgbo,
        task_per_worker_delt,
    )
]
(
    model_eval_base,
    model_eval_corr,
    model_eval_salt,
    model_eval_xgbo,
    model_eval_delt,
) = model_evals

In [None]:
%%time
dask.compute(model_evals)

# Validation for the combined set of influential inputs (intersection across GSA methods)

In [14]:
from gsa_framework.lca import LCAModel
from gsa_framework.methods.correlations import CorrelationCoefficients
from gsa_framework.methods.extended_FAST import eFAST
from gsa_framework.methods.saltelli_sobol import SaltelliSobol
from gsa_framework.methods.gradient_boosting import GradientBoosting
from gsa_framework.validation import Validation
from gsa_framework.convergence import Convergence
from pathlib import Path
import brightway2 as bw
import time
import numpy as np
from gsa_framework.plotting import histogram_Y1_Y2
from gsa_framework.utils import read_hdf5_array, read_pickle
from scipy.stats import wasserstein_distance, spearmanr

path_base = Path('/data/user/kim_a/paper_gsa/gsa_framework_files')

p_lca_10000 = path_base / "lca_model_10000"
p_lca_10000_arr = p_lca_10000 / "arrays"

# Validation outputs Y, one HDF5 file per tag. "100inf.2000.7043" presumably
# encodes num_influential, iterations_validation and validation_seed — confirm.
# NOTE(review): the CorrSaltXgbDelta file is produced by the last cell of this
# notebook, which itself needs this cell's variables. The file must therefore
# already exist on disk, or Restart & Run All fails here.
fp_corr, fp_salt, fp_xgbo, fp_delt, fp_vall = (
    p_lca_10000_arr / file_name
    for file_name in (
        "validation.Y.100inf.2000.7043.SpearmanIndex.hdf5",
        "validation.Y.100inf.2000.7043.SaltelliTotalIndex.hdf5",
        "validation.Y.100inf.2000.7043.FscoresIndex.hdf5",
        "validation.Y.100inf.2000.7043.DeltaIndex.hdf5",
        "validation.Y.100inf.2000.7043.CorrSaltXgbDelta.hdf5",
    )
)
corr, salt, xgbo, delt, vall = (
    read_hdf5_array(fp).flatten()
    for fp in (fp_corr, fp_salt, fp_xgbo, fp_delt, fp_vall)
)

# GSA indices, presumably cached by the per-method workers above.
fp_corr_gsa = p_lca_10000_arr / "S.correlationsGsa.randomSampling.20000.3403.pickle"
fp_salt_gsa = p_lca_10000_arr / "S.saltelliGsa.saltelliSampling.990198.None.pickle"
fp_xgbo_gsa = p_lca_10000_arr / "S.xgboostGsaN400D6E10S60.randomSampling.20000.3403.pickle"
fp_delt_gsa = p_lca_10000_arr / "S.deltaGsaNr1.latinSampling.20000.3403.pickle"

corr_gsa, salt_gsa, xgbo_gsa, delt_gsa = (
    read_pickle(fp_corr_gsa)["spearman"],
    read_pickle(fp_salt_gsa)["Total order"],
    # The xgboost pickle is indexed like a sequence whose first item holds the
    # index dict, mirroring the 3-tuple from GradientBoosting.generate_gsa_indices.
    read_pickle(fp_xgbo_gsa)[0]["fscores"],
    read_pickle(fp_delt_gsa)["delta"],
)

In [15]:
# LCA model
bw.projects.set_current("GSA for paper")
co = bw.Database("CH consumption 1.0")
# First activity whose name mentions "Food"; fail with a clear message if none.
act = next((a for a in co if "Food" in a["name"]), None)
if act is None:
    raise ValueError(
        'No activity with "Food" in its name in database "CH consumption 1.0"'
    )
demand = {act: 1}
method = ("IPCC 2013", "climate change", "GTP 100a")

# Define some variables (must match the settings used to write the validation files)
num_params = 10000
num_influential = num_params // 100
iterations_validation = 2000
write_dir = path_base / "lca_model_{}".format(num_params)
model = LCAModel(demand, method, write_dir, num_params=num_params)
gsa_seed = 3403
validation_seed = 7043
fig_format = ["html", "pickle"]

val = Validation(
    model=model,
    iterations=iterations_validation,
    seed=validation_seed,
    default_x_rescaled=None,
    write_dir=write_dir,
)

# Influential-only validation outputs, keyed by tag.
# NOTE(review): "All" comes from `params_choice` (the intersection cell below).
# Its file name contains "100inf", which suggests that intersection held 100
# parameters and matches `num_influential` — confirm.
tags = {
    "SpearmanIndex": corr,
    "SaltelliTotalIndex": salt,
    "FscoresIndex": xgbo,
    "DeltaIndex": delt,
    "All": vall,
}
for tag, influential_Y in tags.items():
    val.plot_histogram_Y_all_Y_inf(
        influential_Y, num_influential, tag=tag, fig_format=fig_format
    )
    val.plot_correlation_Y_all_Y_inf(
        influential_Y, num_influential, tag=tag, fig_format=fig_format
    )
    # Agreement between val.Y_all (presumably all parameters varying) and the
    # influential-only outputs. Label the line so each metric can be matched to its method.
    wdist = wasserstein_distance(val.Y_all, influential_Y)
    sr = spearmanr(val.Y_all, influential_Y)
    print(tag, wdist, sr)


2.598007784141515 SpearmanrResult(correlation=0.9733630173407543, pvalue=0.0)


2.495511736089809 SpearmanrResult(correlation=0.9793804983451245, pvalue=0.0)


3.0172477852599373 SpearmanrResult(correlation=0.971551331887833, pvalue=0.0)


5.419340664294521 SpearmanrResult(correlation=0.9379598014899503, pvalue=0.0)


3.477891332182186 SpearmanrResult(correlation=0.9615301578825394, pvalue=0.0)


In [8]:
# Number of top-ranked parameters taken from each GSA method. A separate name
# keeps this from clobbering `num_influential` (100), which the validation
# plots use.
# NOTE(review): 1880 is presumably tuned so the four-way intersection below
# yields 100 parameters (the CorrSaltXgbDelta file is named "100inf") — confirm.
num_top_per_method = 1880


def top_ranked_indices(scores, k):
    """Return indices of the ``k`` largest ``scores``, sorted in ascending index order."""
    return np.sort(np.argsort(scores)[::-1][:k])


influential_inds_corr = top_ranked_indices(corr_gsa, num_top_per_method)
influential_inds_salt = top_ranked_indices(salt_gsa, num_top_per_method)
influential_inds_xgbo = top_ranked_indices(xgbo_gsa, num_top_per_method)
influential_inds_delt = top_ranked_indices(delt_gsa, num_top_per_method)

In [12]:
# Parameters ranked among the top candidates by all four GSA methods
# (np.intersect1d returns the sorted unique common values).
corr_salt_common = None  # placeholder removed below; keeps intent explicit
params_choice = np.intersect1d(
    np.intersect1d(
        np.intersect1d(influential_inds_corr, influential_inds_salt),
        influential_inds_xgbo,
    ),
    influential_inds_delt,
)
del corr_salt_common

In [13]:
%%time
# Run validation varying only `params_choice`, the parameters that all four
# GSA methods rank among their top candidates. The tag appears in the output
# file name that is read back as
# "validation.Y.100inf.2000.7043.CorrSaltXgbDelta.hdf5" — presumably written here.
tag = "CorrSaltXgbDelta"
Yinf = val.get_influential_Y_from_parameter_choice(params_choice, tag)

CPU times: user 2h 58min 39s, sys: 5min 8s, total: 3h 3min 48s
Wall time: 16min 33s
