In [3]:
from itertools import product

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from SALib import ProblemSpec

In [4]:
cost_function_results = pd.read_parquet("./output_sobol_index.parquet")
cost_function_results

station,BARENTS,BARENTS,BARENTS,HOT,HOT,HOT,BATS,BATS,BATS,PAPA,PAPA,PAPA,GUAM,GUAM,GUAM
quantity_of_interest,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax,mean,variance,argmax
0,0.004000,3.479870e-06,205.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006814,7.913743e-07,287.0,0.001023,6.137193e-09,113.0
1,0.000044,4.303894e-10,205.0,0.000033,2.816304e-11,50.0,0.000043,2.916630e-10,97.0,0.000076,9.787696e-11,287.0,0.000011,7.590464e-13,113.0
2,0.003988,3.623789e-06,208.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006795,9.268995e-07,289.0,0.001023,6.137193e-09,113.0
3,0.003999,3.456675e-06,205.0,0.002950,2.277094e-07,50.0,0.003824,2.358212e-06,97.0,0.006810,8.046640e-07,289.0,0.001023,6.137193e-09,113.0
4,0.008151,5.712565e-06,254.0,0.006751,5.638046e-07,79.0,0.008491,4.738705e-06,111.0,0.014415,2.375407e-06,300.0,0.002324,1.664002e-08,113.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1190695,0.000499,3.177044e-07,139.0,0.000066,2.351754e-10,166.0,0.000074,1.607287e-09,79.0,0.000495,3.933408e-08,112.0,0.000029,1.595918e-11,198.0
1190696,0.000507,3.413467e-07,130.0,0.000067,2.065802e-10,48.0,0.000074,1.434071e-09,91.0,0.000498,4.364409e-08,105.0,0.000029,1.609563e-11,196.0
1190697,0.002547,1.786764e-06,216.0,0.000066,3.636120e-10,139.0,0.000078,1.927438e-09,109.0,0.002297,4.228186e-07,155.0,0.000029,5.423783e-11,200.0
1190698,0.001458,1.194931e-06,191.0,0.000351,7.279864e-09,54.0,0.000544,9.299581e-08,101.0,0.002183,1.692077e-07,276.0,0.000101,1.518171e-10,76.0


In [5]:
stations_table = pd.read_json("../1_data_processing/1_3_Sensibility/stations_locations.json")
stations_table = stations_table.set_index("name").rename(columns={"primary production": "primary_production"})
stations_table

Unnamed: 0_level_0,longitude,latitude,temperature,primary_production
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
BARENTS,26.969,74.62,4.036164,121.380569
HOT,-158.004,22.752,23.839729,254.277267
BATS,-64.2,31.604,21.537741,265.166229
PAPA,-149.996,50.006,6.785365,276.715942
GUAM,149.995,13.001,27.390701,112.10244


In [6]:
cost_function_results.describe().transpose().reset_index().groupby(["quantity_of_interest", "station"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,count,mean,std,min,25%,50%,75%,max
quantity_of_interest,station,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
argmax,BARENTS,1197400.0,189.6542,41.78356,114.0,147.0,200.0,213.0,322.0
argmax,BATS,1197400.0,97.54293,16.39207,14.0,91.0,94.0,100.0,365.0
argmax,GUAM,1197400.0,163.6088,70.35997,2.0,113.0,196.0,196.0,365.0
argmax,HOT,1197400.0,71.66184,55.11521,10.0,48.0,48.0,68.0,365.0
argmax,PAPA,1197400.0,190.4992,77.57854,63.0,134.0,151.0,287.0,365.0
mean,BARENTS,1197400.0,0.005217319,0.005830048,1.319749e-08,0.001092206,0.003089444,0.007166852,0.03682
mean,BATS,1197400.0,0.003264357,0.008349648,1.373881e-09,0.000135329,0.0002706947,0.001688974,0.08922
mean,GUAM,1197400.0,0.001016397,0.003049677,5.411974e-10,4.401971e-05,8.806513e-05,0.000291709,0.037072
mean,HOT,1197400.0,0.002721216,0.007596781,1.228245e-09,0.0001096352,0.0002193715,0.001061145,0.086614
mean,PAPA,1197400.0,0.007486171,0.01019243,1.053651e-08,0.001158301,0.003405778,0.009268034,0.073865


In [7]:
quantity_of_interest = ["_".join(col).strip() for col in cost_function_results.columns.values]
quantity_of_interest

['BARENTS_mean',
 'BARENTS_variance',
 'BARENTS_argmax',
 'HOT_mean',
 'HOT_variance',
 'HOT_argmax',
 'BATS_mean',
 'BATS_variance',
 'BATS_argmax',
 'PAPA_mean',
 'PAPA_variance',
 'PAPA_argmax',
 'GUAM_mean',
 'GUAM_variance',
 'GUAM_argmax']

In [8]:
# Get unique value but keep order
stations = cost_function_results.columns.levels[0]
quantity_of_interest_function = cost_function_results.columns.levels[1]
print(stations)
print(quantity_of_interest_function)

Index(['BARENTS', 'BATS', 'GUAM', 'HOT', 'PAPA'], dtype='object', name='station')
Index(['argmax', 'mean', 'variance'], dtype='object', name='quantity_of_interest')


In [9]:
parameters_name = [
    "energy_transfert",
    "tr_0",
    "gamma_tr",
    "lambda_0",
    "gamma_lambda",
]

In [10]:
coupled_parameters_name = [
    f"({i}, {j})"
    for i, j in product(parameters_name, parameters_name)
    if parameters_name.index(i) < parameters_name.index(j)
]

In [11]:
sp = ProblemSpec(
    {
        "names": parameters_name,
        "groups": None,
        "bounds": [
            [0, 1],
            [0.1, 50],
            [-0.5, -0.0001],
            [0.1, 500],
            [-0.5, -0.0001],
        ],
        "outputs": quantity_of_interest,
    }
)

In [12]:
sp.set_results(cost_function_results.to_numpy()[: int(np.floor(np.sqrt(55000)) ** 2), ...])

Outputs:
	15 outputs: ['BARENTS_mean', 'BARENTS_variance', 'BARENTS_argmax', 'HOT_mean', 'HOT_variance', 'HOT_argmax', 'BATS_mean', 'BATS_variance', 'BATS_argmax', 'PAPA_mean', 'PAPA_variance', 'PAPA_argmax', 'GUAM_mean', 'GUAM_variance', 'GUAM_argmax']
	54756 evaluations


In [13]:
_ = sp.analyze_sobol()

  names = list(pd.unique(groups))


In [14]:
all_entries = []
for st in stations:
    for qoi in quantity_of_interest_function:
        index = pd.MultiIndex.from_product([[st], [qoi], parameters_name], names=["Station", "QoI", "Parameter"])
        data = pd.DataFrame(
            np.asarray([sp.analysis[f"{st}_{qoi}"]["ST"], sp.analysis[f"{st}_{qoi}"]["ST_conf"]]).T,
            index=index,
            columns=["ST", "ST_conf"],
        )

        all_entries.append(data)
sobol_total_order = pd.concat(all_entries).reset_index()
sobol_total_order["is_hot"] = sobol_total_order["Station"].isin(
    stations_table.reset_index().query("temperature > 10")["name"]
)
sobol_total_order["is_productive"] = sobol_total_order["Station"].isin(
    stations_table.reset_index().query("primary_production > 150")["name"]
)
sobol_total_order

Unnamed: 0,Station,QoI,Parameter,ST,ST_conf,is_hot,is_productive
0,BARENTS,argmax,energy_transfert,0.410912,0.024108,False,False
1,BARENTS,argmax,tr_0,0.357162,0.018326,False,False
2,BARENTS,argmax,gamma_tr,0.034118,0.002362,False,False
3,BARENTS,argmax,lambda_0,0.876789,0.034710,False,False
4,BARENTS,argmax,gamma_lambda,0.840164,0.035221,False,False
...,...,...,...,...,...,...,...
70,PAPA,variance,energy_transfert,0.488864,0.079866,False,True
71,PAPA,variance,tr_0,0.432385,0.090768,False,True
72,PAPA,variance,gamma_tr,0.002714,0.000569,False,True
73,PAPA,variance,lambda_0,0.816263,0.123220,False,True


In [15]:
all_entries = []
for st in stations:
    for qoi in quantity_of_interest_function:
        index = pd.MultiIndex.from_product([[st], [qoi], parameters_name], names=["Station", "QoI", "Parameter"])
        data = pd.DataFrame(
            np.asarray([sp.analysis[f"{st}_{qoi}"]["S1"], sp.analysis[f"{st}_{qoi}"]["S1_conf"]]).T,
            index=index,
            columns=["S1", "S1_conf"],
        )

        all_entries.append(data)
sobol_first_order = pd.concat(all_entries).reset_index()
sobol_first_order["is_hot"] = sobol_first_order["Station"].isin(
    stations_table.reset_index().query("temperature > 10")["name"]
)
sobol_first_order["is_productive"] = sobol_first_order["Station"].isin(
    stations_table.reset_index().query("primary_production > 150")["name"]
)
sobol_first_order

Unnamed: 0,Station,QoI,Parameter,S1,S1_conf,is_hot,is_productive
0,BARENTS,argmax,energy_transfert,0.031184,0.022209,False,False
1,BARENTS,argmax,tr_0,0.071883,0.023935,False,False
2,BARENTS,argmax,gamma_tr,-0.000500,0.007241,False,False
3,BARENTS,argmax,lambda_0,0.276074,0.034200,False,False
4,BARENTS,argmax,gamma_lambda,0.241149,0.034411,False,False
...,...,...,...,...,...,...,...
70,PAPA,variance,energy_transfert,0.044374,0.017506,False,True
71,PAPA,variance,tr_0,0.028525,0.017251,False,True
72,PAPA,variance,gamma_tr,0.000508,0.002336,False,True
73,PAPA,variance,lambda_0,0.088591,0.040932,False,True


---

# Plotting the Sobol results


In [16]:
def plot_total_analysis(sobol: ProblemSpec):
    figures = []
    for qoi in sobol.analysis:
        fig = make_subplots(
            rows=1, cols=3, subplot_titles=("1st-order Sobol index", "2nd-order Sobol index", "Total Sobol index")
        )
        fig.add_trace(
            go.Bar(
                x=parameters_name,
                y=sobol.analysis[qoi]["S1"],
                error_y={"type": "data", "array": sobol.analysis[qoi]["S1_conf"]},
                name="1st-order Sobol index",
            ),
            row=1,
            col=1,
        )
        fig.add_trace(
            go.Bar(
                x=coupled_parameters_name,
                y=list(filter(np.isfinite, sobol.analysis[qoi]["S2"].flatten())),
                error_y={"type": "data", "array": list(filter(np.isfinite, sobol.analysis[qoi]["S2_conf"].flatten()))},
                name="2nd-order Sobol index",
            ),
            row=1,
            col=2,
        )
        fig.add_trace(
            go.Bar(
                x=parameters_name,
                y=sobol.analysis[qoi]["ST"],
                error_y={"type": "data", "array": sobol.analysis[qoi]["ST_conf"]},
                name="Total Sobol index",
            ),
            row=1,
            col=3,
        )
        fig.update_layout(showlegend=False, title_text=f"Sobol indices for {qoi}")
        figures.append(fig)
    return figures

In [17]:
for fig in plot_total_analysis(sp):
    fig.show()

## Aggregated plots


### Total order Sobol indices


In [33]:
fig = px.bar(
    sobol_total_order,
    x="Station",
    facet_col="Parameter",
    y="ST",
    error_y="ST_conf",
    facet_row="QoI",
    color="is_hot",
    pattern_shape="is_productive",
    title="Total order Sobol index",
)

fig.update_layout(
    title="Total order Sobol index",
    width=1200,
    height=800,
)

fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.update_yaxes(range=[0, 1.5])
# je veux plus de ticks selon y
fig.update_yaxes(
    # ticks="outside",
    tickmode="linear",
    dtick=0.25,
)

In [19]:
fig.write_html("sobol_total_order.html")

### First order Sobol indices


In [22]:
fig = px.bar(
    sobol_first_order,
    x="Station",
    facet_col="Parameter",
    y="S1",
    error_y="S1_conf",
    facet_row="QoI",
    color="is_hot",
    pattern_shape="is_productive",
    title="First order Sobol index",
)

fig.update_layout(width=1200, height=800)

fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
# je veux ajouter des limites pour laxe y
fig.update_yaxes(range=[0, 0.51])

In [20]:
# fig.write_html("sobol_first_order.html")z