In [None]:
import os
import pandas as pd
from autoemulate.core.compare import AutoEmulate
import torch
figsize = (9, 5)

import matplotlib.pyplot as plt

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [None]:
# Parameter-set identifier: selects which aggregated Sobol results are loaded
# and names the sub-folder the convergence figures are written to.
param_filename = 'parameters_naghavi_constrained_fixed_T_v_tot_v_ref_lower_k_pas_further'

# Root folder for all sensitivity-analysis plots; figures for this run go in a sub-folder.
plots_root = '../../outputs/sa_results/plots'
save_dir = plots_root + '/' + param_filename

In [None]:
# Load the combined Sobol dataframe written by the aggregation step for this parameter set
sobol_csv_path = os.path.join('../../outputs/sa_results/aggregated_dfs/', param_filename, 'combined_sobol_df.csv')
combined_sobol_df = pd.read_csv(sobol_csv_path)

In [None]:
# Restrict the Sobol results to rows whose 'data_type' column equals this value
data_type = 'simulations'

is_selected_type = combined_sobol_df['data_type'] == data_type
sims_df = combined_sobol_df.loc[is_selected_type]

In [None]:
index_to_plot = 'S1'

# Set style once outside the loop
sns.set_style("whitegrid")

# Plot the index convergence for simulations: keep only rows for the chosen Sobol index
df = sims_df[sims_df['index'] == index_to_plot]

# Model outputs to draw one convergence figure for
output_to_plot = ['p_ao_min',
                  'p_ao_max']

# Define different markers for each parameter (cycled if there are more parameters than markers)
markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h', 'H', '+', 'x']

# Fixed y-limits per index so figures for different outputs are directly comparable;
# an index not listed here falls back to matplotlib's autoscaling (same as before).
index_ylims = {'S1': (-0.2, 0.6), 'ST': (0, 0.5)}

font_size = 20

# savefig does not create missing directories, so make sure the target folder exists.
os.makedirs(save_dir, exist_ok=True)

for i_feature in output_to_plot:

    print('Plotting for ', i_feature)

    i_df = df[df['output'] == i_feature]
    if i_df.empty:
        # Nothing to draw (e.g. output name absent from the results); don't save an empty figure.
        print(f'No {index_to_plot} results found for {i_feature}; skipping')
        continue

    fig, ax = plt.subplots(figsize=(10, 6))

    # One line per input parameter, ordered by sample size along the x-axis
    for i, param in enumerate(i_df['parameter'].unique()):
        param_data = i_df[i_df['parameter'] == param].sort_values('n_model_evals')

        x = param_data['n_model_evals']
        y = param_data['value']
        ci = param_data['confidence']

        # Shaded band of total width `ci` centred on the estimate.
        # NOTE(review): if 'confidence' is already a half-width (as SALib's *_conf outputs are),
        # the band should be y ± ci rather than y ± ci/2 — confirm against the aggregation code.
        y_lower = y - ci / 2
        y_upper = y + ci / 2

        marker = markers[i % len(markers)]
        (line,) = ax.plot(x, y, marker=marker, label=param, linewidth=2, markersize=6)

        # Tie the band colour to its line explicitly instead of relying on the colour
        # cycles used by plot() and fill_between() staying in step.
        ax.fill_between(x, y_lower, y_upper, alpha=0.2, color=line.get_color())

    # Sample sizes are powers of two, so a base-2 log axis spaces them evenly
    ax.set_xscale('log', base=2)

    if index_to_plot in index_ylims:
        ax.set_ylim(*index_ylims[index_to_plot])

    # One tick per evaluated sample size, labelled "n_saltelli_samples\n(n_model_evals)"
    tick_positions = sorted(i_df['n_model_evals'].unique())
    tick_labels = [
        f"{int(i_df.loc[i_df['n_model_evals'] == n_evals, 'n_saltelli_samples'].iloc[0])}"
        f"\n({int(n_evals):,})"
        for n_evals in tick_positions
    ]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)

    ax.set_xlabel('Samples (N Model Evals) \n(log scale)', size=font_size)
    ax.set_ylabel(f'{index_to_plot} Sensitivity Index', size=font_size)
    ax.set_title(f'{index_to_plot} Index Convergence for {i_feature}', size=font_size)
    # Legend sits to the left of the axes; bbox_inches='tight' keeps it in the saved image
    ax.legend(bbox_to_anchor=(-0.1, 1), loc='upper right')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(f'{save_dir}/gsa_{index_to_plot}_convergence_{i_feature}.png', dpi=300, bbox_inches='tight')

    plt.show()
    # Release the figure so repeated runs of the loop don't accumulate open figures
    plt.close(fig)