In [None]:
# Import relevant libraries
import os
import pandas as pd
from ModularCirc.Models.NaghaviModel import NaghaviModel, NaghaviModelParameters, TEMPLATE_TIME_SETUP_DICT
from SALib.sample import saltelli
import json
from ModularCirc import BatchRunner
from comparative_gsa.simulate_data import simulate_data
from comparative_gsa.calculate_output_features import calculate_output_features
import pickle
import gc

In [None]:
# Location of the JSON file describing the model parameters to sample.
param_path = '../inputs/parameters_naghavi_constrained_fixed_T_v_tot_v_ref_lower_k_pas.json'

# Timing-parameter map handed to the batch runner's `map_sample_timings(map=...)`.
# Every timing parameter maps onto a one-element list containing its own name.
map_ = {
    timing_name: [timing_name]
    for timing_name in (
        'lv.t_tr',
        'la.t_tr',
        'la.delay',
        'lv.tau',
        'la.tau',
        'lv.t_max',
        'la.t_max',
    )
}

# Load the parameter definitions from the JSON file.
with open(param_path) as f:
    params = json.load(f)

# Base name of the parameter file (directory and extension stripped); it is
# embedded in the name of each simulation output directory.
param_filename = os.path.splitext(os.path.split(param_path)[1])[0]

# Base sample sizes handed to the Saltelli sampler, one simulation campaign each.
n_samples = [16, 32]

In [None]:
# Run one Saltelli-sampled simulation campaign per entry in `n_samples`.
# Each campaign writes its problem definition, input samples and results to
# its own directory under ../outputs/simulations_for_sa/.
for i_n_samples in n_samples:

    print(f"Running simulation with {i_n_samples} samples...")

    # Release the previous iteration's runner (if any) *before* building a new
    # one, so the old and new sample tables never coexist in memory.
    br = None
    # TODO(review): confirm why the 'Sobol' sampler is selected here, given that
    # the sampled columns are overwritten with Saltelli samples further down.
    br = BatchRunner('Sobol', 0)
    br.setup_sampler(param_path)
    br.sample(10**6)

    # Names of the parameters the batch runner samples; these are the inputs
    # of the sensitivity analysis.
    relevant_columns = list(br._parameters_2_sample.keys())

    # SALib problem definition. The bounds are the empirical min/max of each
    # sampled column in the runner's sample table, one (min, max) row per name.
    problem = {
        'num_vars': len(relevant_columns),
        'names': relevant_columns,
        'bounds' : br.samples[relevant_columns].describe().loc[['min', 'max']].T.values
    }

    param_values = saltelli.sample(problem, i_n_samples, calc_second_order=True)
    n_model_evals = param_values.shape[0]
    print(f'{n_model_evals} evaluations generated with Saltelli sampling.')

    # The runner's table must hold at least one row per Saltelli evaluation;
    # otherwise the design would be silently incomplete.
    if len(br._samples) < n_model_evals:
        raise ValueError(
            f'Batch runner holds {len(br._samples)} samples but the Saltelli '
            f'design needs {n_model_evals} rows.'
        )

    # Keep only as many rows as there are Saltelli evaluations (not i_n_samples).
    br._samples = br._samples.iloc[:n_model_evals].copy()

    param_values_df = pd.DataFrame(param_values, columns=relevant_columns)

    # Overwrite the columns present in both tables with the Saltelli values.
    # Assign by position (.to_numpy()) rather than by index label, so the
    # result does not depend on br._samples having a default 0..N-1 index.
    for col in param_values_df.columns:
        if col in br._samples.columns:
            br._samples[col] = param_values_df[col].to_numpy()

    simulation_out_path = f'../outputs/simulations_for_sa/n_samples_{i_n_samples}_n_evals_{n_model_evals}_{param_filename}/'
    # Make this directory if it doesn't exist
    os.makedirs(simulation_out_path, exist_ok=True)

    # Save the problem definition
    with open(os.path.join(simulation_out_path, 'problem.pkl'), 'wb') as f:
        pickle.dump(problem, f)

    # Save the Saltelli samples
    param_values_df.to_csv(os.path.join(simulation_out_path, 'saltelli_samples.csv'), index=False)

    # Finish setting up the batch runner

    # Map the sample timings
    br.map_sample_timings(
        ref_time=1000.,
        map=map_
        )

    # Map the vessel volumes
    br.map_vessel_volume()

    # Save the runner's samples, as they stand after the mapping steps above
    br.samples.to_csv(os.path.join(simulation_out_path,
                                    f'input_samples_{n_model_evals}.csv'),
                                    index=False)

    # Set up the model with the parameters and time setup
    br.setup_model(model=NaghaviModel, po=NaghaviModelParameters,
                    time_setup=TEMPLATE_TIME_SETUP_DICT)

    print('Starting to simulate data...')
    simulations, bool_indices = simulate_data(
        batch_runner=br,
        simulation_out_path=simulation_out_path,
        n_jobs=7
    )

    print('Calculating output features...')
    summary_df = calculate_output_features(
            simulations=simulations,
            simulation_out_path=simulation_out_path)

    # Drop this iteration's intermediates before the next campaign starts.
    # All five names are always bound at this point, so no guards are needed.
    del simulations, bool_indices, summary_df, param_values_df, param_values

    # Force garbage collection
    gc.collect()