In [1]:
from itertools import product

import pandas as pd
from SALib import ProblemSpec
import numpy as np
from plotly.subplots import make_subplots

import plotly.graph_objects as go

In [2]:
cost_function_results = pd.read_parquet("./output_sobol_index.parquet")
cost_function_results

station,BARENTS,BARENTS,BARENTS,HOT,HOT,HOT,BATS,BATS,BATS,PAPA,PAPA,PAPA,GUAM,GUAM,GUAM
quantity_of_interest,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax
0,0.004000,3.479870e-06,205.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006814,7.913743e-07,287.0,0.001023,6.137193e-09,113.0
1,0.000044,4.303894e-10,205.0,0.000033,2.816304e-11,50.0,0.000043,2.916630e-10,97.0,0.000076,9.787696e-11,287.0,0.000011,7.590464e-13,113.0
2,0.003988,3.623789e-06,208.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006795,9.268995e-07,289.0,0.001023,6.137193e-09,113.0
3,0.003999,3.456675e-06,205.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006810,8.046640e-07,289.0,0.001023,6.137193e-09,113.0
4,0.008151,5.712565e-06,254.0,0.006751,5.638046e-07,79.0,0.008491,4.738705e-06,111.0,0.014415,2.375407e-06,300.0,0.002324,1.664002e-08,113.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
476995,0.000235,1.172422e-07,131.0,0.000107,5.361083e-10,48.0,0.000120,4.139368e-09,77.0,0.000244,9.989848e-09,108.0,0.000047,4.177072e-11,196.0
476996,0.000234,1.195310e-07,131.0,0.000107,5.361083e-10,48.0,0.000120,3.721616e-09,91.0,0.000244,1.047774e-08,108.0,0.000047,4.177072e-11,196.0
476997,0.000814,8.356571e-07,151.0,0.000107,9.215270e-10,43.0,0.000120,3.997736e-09,92.0,0.000806,1.144070e-07,119.0,0.000047,4.177072e-11,196.0
476998,0.000512,3.765561e-07,148.0,0.000114,1.009265e-09,43.0,0.000148,6.845959e-09,92.0,0.000708,4.240498e-08,119.0,0.000047,4.181343e-11,196.0


In [3]:
cost_function_results.describe().transpose().reset_index().groupby(["quantity_of_interest", "station"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,count,mean,std,min,25%,50%,75%,max
quantity_of_interest,station,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
argmax,BARENTS,479000.0,189.6562,41.77737,115.0,147.0,200.0,213.0,321.0
argmax,BATS,479000.0,97.5266,16.47366,14.0,91.0,94.0,100.0,365.0
argmax,GUAM,479000.0,163.6027,70.27309,2.0,113.0,196.0,196.0,365.0
argmax,HOT,479000.0,71.55415,54.94519,10.0,48.0,48.0,68.0,365.0
argmax,PAPA,479000.0,190.4802,77.58618,64.0,134.0,151.0,287.0,365.0
mean,BARENTS,479000.0,0.005216743,0.005828877,1.319749e-08,0.001093722,0.003089055,0.007170793,0.03682
mean,BATS,479000.0,0.003262964,0.008344485,1.373881e-09,0.0001352807,0.0002706443,0.001685395,0.089212
mean,GUAM,479000.0,0.001015853,0.003047637,5.411974e-10,4.401892e-05,8.808126e-05,0.0002918673,0.036947
mean,HOT,479000.0,0.002719892,0.007591823,1.228245e-09,0.000109601,0.0002193919,0.001058672,0.086528
mean,PAPA,479000.0,0.007484984,0.01018934,1.053651e-08,0.001157039,0.003408163,0.009263482,0.073865


In [4]:
quantity_of_interest = ["_".join(col).strip() for col in cost_function_results.columns.values]
quantity_of_interest

['BARENTS_mean',
 'BARENTS_variance',
 'BARENTS_argmax',
 'HOT_mean',
 'HOT_variance',
 'HOT_argmax',
 'BATS_mean',
 'BATS_variance',
 'BATS_argmax',
 'PAPA_mean',
 'PAPA_variance',
 'PAPA_argmax',
 'GUAM_mean',
 'GUAM_variance',
 'GUAM_argmax']

In [5]:
# Get unique value but keep order
stations = cost_function_results.columns.levels[0]
quantity_of_interest_function = cost_function_results.columns.levels[1]
print(stations)
print(quantity_of_interest_function)

Index(['BARENTS', 'BATS', 'GUAM', 'HOT', 'PAPA'], dtype='object', name='station')
Index(['argmax', 'mean', 'variance'], dtype='object', name='quantity_of_interest')


In [6]:
parameters_name = [
    "energy_transfert",
    "tr_0",
    "gamma_tr",
    "lambda_0",
    "gamma_lambda",
]

In [7]:
coupled_parameters_name = [
    f"({i}, {j})"
    for i, j in product(parameters_name, parameters_name)
    if parameters_name.index(i) < parameters_name.index(j)
]

In [8]:
sp = ProblemSpec(
    {
        "names": parameters_name,
        "groups": None,
        "bounds": [
            [0, 1],
            [0.1, 50],
            [-0.5, -0.0001],
            [0.1, 500],
            [-0.5, -0.0001],
        ],
        "outputs": quantity_of_interest,
    }
)

In [9]:
sp.set_results(cost_function_results.to_numpy()[: int(np.floor(np.sqrt(55000)) ** 2), ...])

Outputs:
	15 outputs: ['BARENTS_mean', 'BARENTS_variance', 'BARENTS_argmax', 'HOT_mean', 'HOT_variance', 'HOT_argmax', 'BATS_mean', 'BATS_variance', 'BATS_argmax', 'PAPA_mean', 'PAPA_variance', 'PAPA_argmax', 'GUAM_mean', 'GUAM_variance', 'GUAM_argmax']
	54756 evaluations


In [10]:
sp.analyze_sobol()

sobol_results = sp.to_df()

  names = list(pd.unique(groups))


---

# Plotting the Sobol results


In [11]:
def plot_total_analysis(sobol: ProblemSpec):
    figures = []
    for qoi in sobol.analysis:
        fig = make_subplots(
            rows=1, cols=3, subplot_titles=("1st-order Sobol index", "2nd-order Sobol index", "Total Sobol index")
        )
        fig.add_trace(
            go.Bar(
                x=parameters_name,
                y=sobol.analysis[qoi]["S1"],
                error_y={"type": "data", "array": sobol.analysis[qoi]["S1_conf"]},
                name="1st-order Sobol index",
            ),
            row=1,
            col=1,
        )
        fig.add_trace(
            go.Bar(
                x=coupled_parameters_name,
                y=list(filter(np.isfinite, sobol.analysis[qoi]["S2"].flatten())),
                error_y={"type": "data", "array": list(filter(np.isfinite, sobol.analysis[qoi]["S2_conf"].flatten()))},
                name="2nd-order Sobol index",
            ),
            row=1,
            col=2,
        )
        fig.add_trace(
            go.Bar(
                x=parameters_name,
                y=sobol.analysis[qoi]["ST"],
                error_y={"type": "data", "array": sobol.analysis[qoi]["ST_conf"]},
                name="Total Sobol index",
            ),
            row=1,
            col=3,
        )
        fig.update_layout(showlegend=False, title_text=f"Sobol indices for {qoi}")
        figures.append(fig)
    return figures

In [12]:
for fig in plot_total_analysis(sp):
    fig.show()

## Total Index


In [13]:
def plot_total_sobol_analysis(sobol: ProblemSpec, stations: np.ndarray, quantity_of_interest_function: np.ndarray):
    figures = []
    for st in stations:
        fig = make_subplots(rows=1, cols=3, subplot_titles=(quantity_of_interest_function))
        for i, qoif in enumerate(quantity_of_interest_function):
            fig.add_trace(
                go.Bar(
                    x=parameters_name,
                    y=sobol.analysis[f"{st}_{qoif}"]["ST"],
                    error_y={"type": "data", "array": sobol.analysis[f"{st}_{qoif}"]["ST_conf"]},
                    name=qoif,
                ),
                row=1,
                col=i + 1,
            )

            fig.update_layout(showlegend=False, title_text=f"Total Sobol indices for {st}")
        figures.append(fig)
    return figures

In [14]:
for fig in plot_total_sobol_analysis(sp, stations, quantity_of_interest_function):
    fig.show()