### Deleted from Utils

In [None]:
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ----  User Selected ML Plotting Funct   ----
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Function to train the emulator; the plotting functions below build on its results
def train_emulator(param, var):
    '''Train a Gaussian process emulator and score it on a held-out split.

    The data are split 80/20 (``test_size=0.2``) with a fixed seed, a
    ``GaussianProcessRegressor(normalize_y=True)`` is fit on the training
    part, and predictions (mean and standard deviation) are made for the
    test part.

    Parameters
    ----------
    param : array-like
        Predictor values, passed to ``train_test_split`` as X.
        Presumably shaped (n_samples, n_parameters) -- confirm with callers.
    var : array-like
        Target values, passed to ``train_test_split`` as y.

    Returns
    -------
    pandas.DataFrame
        One row per test sample with columns ``y_pred``, ``y_std``,
        ``y_test`` and ``X_test`` (that sample's predictors as a list),
        plus ``R^2``, ``RMSE`` and ``Mean Absolute Error`` columns holding
        the test-set metrics repeated on every row.
    '''
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # ----      Split Data 80/20        ----
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # random_state is fixed so the split is reproducible between runs
    X_train, X_test, y_train, y_test = train_test_split(param,
                                                        var,
                                                        test_size=0.2,
                                                        random_state=0)

    gpr_model = GaussianProcessRegressor(normalize_y=True)
    gpr_model.fit(X_train, y_train)

    y_pred, y_std = gpr_model.predict(X_test, return_std=True)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # ----         Collect Metrics      ----
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # All three metrics are computed on the held-out test split
    mae = mean_absolute_error(y_test, y_pred)
    r2_test = r2_score(y_test, y_pred)
    rmse_test = np.sqrt(mean_squared_error(y_test, y_pred))

    # Create a DataFrame to store the results for plotting
    results_df = pd.DataFrame({
        'y_pred': y_pred,
        'y_std': y_std,
        'y_test': y_test,
        # np.asarray makes the row-wise conversion work for DataFrame input
        # too; iterating a DataFrame directly yields column labels, not rows.
        'X_test': np.asarray(X_test).tolist(),
    })

    # Add metrics to the DataFrame (scalars broadcast to every row)
    results_df['R^2'] = r2_test
    results_df['RMSE'] = rmse_test
    results_df['Mean Absolute Error'] = mae

    return results_df

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ---- Emulator Accuracy Plot ----------------
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def accuracy_plot(results_df):
    '''Plot emulated values against held-out values with 3-sigma error bars.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Must provide ``y_test``, ``y_pred`` and ``y_std`` columns.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the validation scatter, the dashed 1:1
        reference line and the R2 annotation.
    '''
    y_true = results_df.y_test
    y_emul = results_df.y_pred
    coef_deter = r2_score(y_true, y_emul)

    # Draw on the axes of the figure that is returned. plt.Figure() would
    # build a detached figure that the pyplot calls never touch, so the
    # caller would receive an empty figure.
    fig, ax = plt.subplots()
    ax.errorbar(y_true,
                y_emul,
                yerr=3*results_df.y_std,
                fmt="o",
                color='#134611')

    # Axes-relative placement keeps the label in the top-left corner
    # whatever the data range is.
    ax.text(0.05, 0.95,
            'R2_score = '+str(np.round(coef_deter,2)),
            fontsize=12,
            transform=ax.transAxes,
            verticalalignment='top')

    # 1:1 reference line: the same end points on both axes, spanning the
    # combined range of observed and emulated values.
    lo = min(np.min(y_true), np.min(y_emul))
    hi = max(np.max(y_true), np.max(y_emul))
    ax.plot([lo, hi],
            [lo, hi],
            linestyle='--',
            c='k')

    ax.set_xlabel('Variable Test')
    ax.set_ylabel('Emulated Variable')
    ax.set_title('Emulator Validation')

    return fig
    

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ---- Parameter Perturbation Plot -----------
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def emulator_plot(results_df, param, var, gpr_model=None, n_params=32, n_points=10):
    '''Plot the emulator response to perturbing a single parameter.

    Every parameter is held at 0.5 while column ``param`` is swept over
    evenly spaced values in [0, 1]. The emulator's mean prediction along
    the sweep is drawn with a shaded band of about three standard
    deviations.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Must provide ``y_test`` and ``y_pred`` columns; used only for the
        R2 annotation.
    param : int
        Column index of the parameter to perturb; also shown in the
        x-axis label.
    var : str
        Name of the emulated variable, shown in the y-axis label.
    gpr_model : fitted regressor, optional
        Object exposing ``predict(X, return_std=True)``. It is required in
        practice: an unfitted model cannot give a meaningful trend, so a
        ``ValueError`` is raised when it is missing.
    n_params : int, default 32
        Number of columns in the design matrix passed to the model.
    n_points : int, default 10
        Number of points in the sweep.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the sweep curve, uncertainty band and legend.

    Raises
    ------
    ValueError
        If ``gpr_model`` is not supplied.
    '''
    if gpr_model is None:
        raise ValueError(
            'emulator_plot requires the fitted emulator; pass it as gpr_model=...'
        )

    # Create an array that sets the value of every parameter to 0.5
    X_values = np.full((n_points, n_params), 0.5)
    # For the parameter of interest, replace 0.5 with a range between 0 and 1
    X_values[:, param] = np.linspace(0, 1, n_points)
    sweep = X_values[:, param]

    # Predictions along the sweep are what gets plotted; results_df holds
    # test-set predictions, which have a different length and meaning.
    y_pred, y_std = gpr_model.predict(X_values, return_std=True)
    coef_deter = r2_score(results_df.y_test, results_df.y_pred)

    # A single figure: everything is drawn on the axes of the returned fig
    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot the mean line
    ax.plot(sweep, y_pred, color='#134611', linestyle='-',
            label='Gaussian Process Regression Emulation')

    # z-score for the 99.7% two-sided interval (about three standard deviations)
    z_score = norm.ppf(0.99865)

    # Plot the shaded region for the 99.7% confidence interval
    ax.fill_between(sweep,
                    y_pred - z_score * y_std,
                    y_pred + z_score * y_std,
                    color='#9d6b53',
                    alpha=0.3,
                    label = 'Confidence Interval within 3 Standard Deviations')

    # Axes-relative placement keeps the label visible for any data range
    ax.text(0.05, 0.95,
            'R2_score = '+str(np.round(coef_deter,2)),
            fontsize=12,
            transform=ax.transAxes,
            verticalalignment='top')

    ax.set_xlabel(f'Perturbed Parameter: {param}')
    ax.set_ylabel(f'Variable: {var}')
    ax.set_title('Parameter Perturbation Uncertainty Estimation')

    ax.legend()
    return fig