In [None]:
import os
import pandas as pd
from autoemulate.core.compare import AutoEmulate
import torch
figsize = (9, 5)
from tqdm import tqdm

In [None]:
import gc  # explicit garbage collection between training-set sizes

# Train one GP emulator per summary-statistic column, for each training-set size in
# `n_samples`, using the first N rows of a single `max_samples`-row simulation run.
# Each best result is saved under emulators/<output>/trained_on_<N>_samples/.
param_filename = 'parameters_naghavi_constrained_fixed_T_v_tot_v_ref_lower_k_pas_further'

# Training-set sizes to evaluate; each uses a prefix of the simulation rows.
n_samples = [1024, 2048]

# Size of the simulation run to load from (part of its output directory name).
max_samples = 8192
simulation_out_path = f'../outputs/simulations/output_{max_samples}_samples_{param_filename}/'
parameters_json_file = os.path.join(simulation_out_path, 'parameters.json')  # not used in this cell

# Load the simulation input parameters
max_samples_input_params = pd.read_csv(os.path.join(simulation_out_path, f'input_samples_{max_samples}.csv'))

# Load the summary statistics (every column is emulated separately below)
max_samples_summary_stats = pd.read_csv(os.path.join(simulation_out_path, 'simulations_summary.csv'))

# Emulator inputs: every input column up to and including 'lv.k_pas'
parameter_names = list(max_samples_input_params.columns[:max_samples_input_params.columns.get_loc('lv.k_pas') + 1])

# `.iloc[:N]` silently returns fewer rows when N exceeds the data, which would save
# emulators under a mislabeled 'trained_on_N_samples' directory. Fail fast instead.
n_available = min(len(max_samples_input_params), len(max_samples_summary_stats))
too_large = [n for n in n_samples if n > n_available]
if too_large:
    raise ValueError(
        f'Requested sample sizes {too_large} exceed the {n_available} rows available '
        f'in {simulation_out_path}'
    )

for i_n_samples in tqdm(n_samples, desc="Sample sizes"):

    # tqdm.write prints without breaking the active progress bars
    tqdm.write(f'Training emulators on {i_n_samples} samples')

    # Inputs and summary statistics are paired by row position
    # (NOTE(review): assumes both CSVs list simulations in the same order — confirm).
    i_input_params = max_samples_input_params.iloc[:i_n_samples]
    i_summary_stats = max_samples_summary_stats.iloc[:i_n_samples]

    # float32 tensor of shape (i_n_samples, len(parameter_names))
    x = torch.tensor(i_input_params[parameter_names].to_numpy(), dtype=torch.float32)

    for i_output in tqdm(i_summary_stats.columns, desc="Output features", leave=False):

        tqdm.write(f'Processing {i_output}')

        # One directory per (output feature, training-set size)
        emulators_path = os.path.join(simulation_out_path, 'emulators', i_output, f'trained_on_{i_n_samples}_samples')
        os.makedirs(emulators_path, exist_ok=True)

        # Target variable for this output feature
        Y = i_summary_stats[i_output].to_numpy()
        y = torch.tensor(Y, dtype=torch.float32)

        # Restrict the comparison to the GP model family.
        # NOTE(review): presumably AutoEmulate fits/compares on construction — confirm
        # against the installed autoemulate version.
        ae = AutoEmulate(x=x, y=y, models=["GP"])

        best = ae.best_result()

        # Save the best model at a fixed (non-timestamped) path
        ae.save(model_obj=best,
                path=emulators_path,
                use_timestamp=False)

        # Release per-output objects before fitting the next emulator
        del ae, y, Y, best
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Release per-sample-size objects before moving to the next size
    del x, i_input_params, i_summary_stats
    gc.collect()