### Grid Search for GPR Hyperparameters
- Uses a custom `GridSearchCV` function (k-fold cross-validation over learning rate and number of epochs) so that the search works with the multi-output exact GP model

In [1]:
# --- Imports -----------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import sys
import pandas as pd
import time
from datetime import datetime
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error
import itertools
# Add the parent directory to the module search path (must stay before the imports below).
sys.path.append('../')
# Select the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): in the cells visible here `device` is only printed; no tensor or model is moved to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import gpytorch
import warnings
# Suppress the warning with this exact message (presumably the one GPyTorch emits when
# an eval-mode model is called on its own training inputs, which the CV loop does on purpose).
warnings.filterwarnings("ignore", message="The input matches the stored training data. Did you forget to call model.train()?") 
import gc
print('Using ', device)
# Use double precision throughout: float64 becomes the default torch dtype.
dataType = torch.float64
torch.set_default_dtype(dataType)

Using  cpu


In [2]:
# Level of gaussian noise added to the outputs; ADJUST to pick a different dataset.
noise = 10
mod_type = 'gpr'
# Short tag identifying this run (model type + noise level).
description = ''.join([mod_type, '_noise-', str(noise)])
# CHANGE TO DESIRED DATA FILE
filename = ''.join(['../datasets/fuchs_v3-2_seed-5_points_25000_noise_', str(noise), '.csv'])
df = pd.read_csv(filename)

In [3]:
input_list = ['Intensity_(W_cm2)', 'Target_Thickness (um)', 'Focal_Distance_(um)'] # independent variables
# First three columns are the outputs used for training; the last three are their
# `_Exact_` counterparts, which are the ones kept for the test split below.
output_list = ['Max_Proton_Energy_(MeV)', 'Total_Proton_Energy_(MeV)', 'Avg_Proton_Energy_(MeV)',
               'Max_Proton_Energy_Exact_(MeV)', 'Total_Proton_Energy_Exact_(MeV)', 'Avg_Proton_Energy_Exact_(MeV)']

# Work in log space: log-transform the intensity (first input column) and every energy column.
X = df[input_list].copy()
X[input_list[0]] = np.log(X[input_list[0]])
y = np.log(df[output_list])

X = X.to_numpy()
y = y.to_numpy()

# Unshuffled 80/20 split: the first 80% of the rows train, the remaining 20% test.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, shuffle=False)
y_train = y_train[:, 0:3]  # train on the first three output columns
y_test = y_test[:, 3:6]    # evaluate against the three `_Exact_` columns

# Percentage of the training split that is actually used for the grid search.
# With pct = 25 only the first quarter of the training rows is kept (5000 rows for this dataset).
pct = 25
len_df = int(len(X_train)*(pct/100))
X_train = X_train[:len_df]
y_train = y_train[:len_df]

# Standardise the inputs using statistics of the retained training rows only.
ss_in = StandardScaler()
X_train_norm = ss_in.fit_transform(X_train)
X_test_norm = ss_in.transform(X_test)

# Standardise the (log) training outputs; `ss_out` is reused later to undo the scaling.
ss_out = StandardScaler()
y_train_norm = ss_out.fit_transform(y_train)

X_train_norm = torch.tensor(X_train_norm, dtype=dataType)
y_train_norm = torch.tensor(y_train_norm, dtype=dataType)

print('train ds length: ', len(X_train_norm))
print(X_train_norm.dtype)

train ds length:  5000
torch.float64


In [4]:
class Exact_GP(gpytorch.models.ExactGP):
    """Exact Gaussian-process model with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        """Store the training data/likelihood and build the mean and covariance modules."""
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        """Return the multivariate normal defined by the mean and covariance evaluated at `x`."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

In [5]:
# Hyperparameter grid: every learning rate is paired with every epoch count.
param_grid = dict(
    lr=[1.0, 0.5, 0.2, 0.1, 0.05],
    num_epochs=[30, 50],
)
# Tiny grid for a quick smoke test:
#param_grid = {'lr': [5e-1], 'num_epochs': [1]}

# One list of candidate values per hyperparameter, in the grid's key order.
param_nested_list = list(param_grid.values())
# Cartesian product -> list of (lr, num_epochs) tuples.
param_list = [combo for combo in itertools.product(*param_nested_list)]

def k_fold_split(X, k=5):
    """Cut `X` into `k` consecutive slices.

    Each of the first `k - 1` slices holds `len(X) // k` items; the final
    slice additionally absorbs whatever remainder is left over.

    Returns a list with the `k` slices in their original order.
    """
    total = len(X)
    chunk = total // k

    # (start, stop) pair for every fold; only the last fold runs to the end of X.
    bounds = [
        (idx * chunk, total if idx == k - 1 else (idx + 1) * chunk)
        for idx in range(k)
    ]
    return [X[lo:hi] for lo, hi in bounds]
    
def k_fold_cv(cv=5, lr=1e-3, num_epochs=50):
    """Run k-fold cross-validation of the chained exact-GP models for one setting.

    Reads the module-level `X_train_norm`, `y_train_norm` and `ss_out`. For every
    fold, one `Exact_GP` is trained per output column with Adam on the negative
    exact marginal log-likelihood. After each output is fitted, its predicted
    mean is appended as an extra input column for the next output's model
    (both for the training and the validation inputs of that fold).

    Parameters
    ----------
    cv : int
        Number of folds (at least 2, so that a training part remains).
    lr : float
        Adam learning rate.
    num_epochs : int
        Number of optimisation steps per model.

    Returns
    -------
    list
        `[mean_cv_mse, std_cv_mse, mean_uncertainty_max, median_uncertainty_max]`:
        mean and standard deviation over folds of the output-averaged validation
        MSE (computed on the normalised targets), followed by the fold-averaged
        mean and median relative uncertainty (in percent) of output column 0.

    Raises
    ------
    ValueError
        If `cv` is smaller than 2.
    """
    if cv < 2:
        raise ValueError('cv must be at least 2, got {}'.format(cv))

    print('starting CV for lr={}, epochs={}'.format(lr, num_epochs))
    folds_X = k_fold_split(X_train_norm, k=cv)
    folds_y = k_fold_split(y_train_norm, k=cv)
    num_outputs = len(y_train_norm[0])
    mse_list = np.zeros((cv, num_outputs))
    mean_uncertainty_list = np.zeros((cv, num_outputs))
    median_uncertainty_list = np.zeros((cv, num_outputs))

    def _to_physical(values_norm, col):
        # `ss_out` was fitted on all output columns at once, so the single column is
        # tiled to `num_outputs` columns before inverting the scaling; only column
        # `col` of the result is meaningful. exp() undoes the earlier log transform.
        tiled = values_norm.reshape(-1, 1).repeat(1, num_outputs)
        return np.exp(ss_out.inverse_transform(tiled))[:, col]

    for i in trange(cv, desc='CV'):
        # Every fold except fold i is used for training; fold i is held out.
        train_idx = [k for k in range(cv) if k != i]

        X_train_cv = torch.cat([folds_X[k] for k in train_idx], dim=0)
        y_train_cv = torch.cat([folds_y[k] for k in train_idx], dim=0)
        X_val_cv = folds_X[i]
        y_val_cv = folds_y[i]

        for j in trange(num_outputs, desc='output', leave=False, position=0):
            target_train = y_train_cv[:, j]
            likelihood = gpytorch.likelihoods.GaussianLikelihood()
            model = Exact_GP(X_train_cv, target_train, likelihood)

            model.train()
            likelihood.train()

            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

            for _ in range(num_epochs):
                optimizer.zero_grad()
                model_output = model(X_train_cv)
                # Minimise the negative marginal log-likelihood.
                loss = -mll(model_output, target_train)
                loss.backward()
                optimizer.step()

                # Release memory held by the finished step.
                gc.collect()
                torch.cuda.empty_cache()

            model.eval()
            likelihood.eval()
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                pred_dist_train = likelihood(model(X_train_cv))
                pred_dist_val = likelihood(model(X_val_cv))
                y_train_predict = pred_dist_train.mean
                y_val_predict = pred_dist_val.mean
                lower, upper = pred_dist_val.confidence_region()

                lower_unscaled = _to_physical(lower, j)
                upper_unscaled = _to_physical(upper, j)
                y_val_predict_unscaled = _to_physical(y_val_predict, j)

                mse_list[i, j] = mean_squared_error(y_val_cv[:, j], y_val_predict)

                # Width of the confidence region divided by 4, relative to the
                # prediction, in percent (the region is documented by GPyTorch as
                # two standard deviations either side of the mean).
                rel_uncertainty = (upper_unscaled - lower_unscaled)/(4*y_val_predict_unscaled)*100
                median_uncertainty_list[i, j] = np.median(rel_uncertainty)
                mean_uncertainty_list[i, j] = np.mean(rel_uncertainty)
                print(median_uncertainty_list[i, j], mean_uncertainty_list[i, j])

                # Chain the models: this output's prediction becomes an additional
                # input feature for the next output's GP.
                X_train_cv = torch.cat([X_train_cv, y_train_predict.reshape(-1, 1)], dim=1)
                X_val_cv = torch.cat([X_val_cv, y_val_predict.reshape(-1, 1)], dim=1)

    mse_list_energy_averaged = np.mean(mse_list, axis=1)
    mean_cv_mse = np.mean(mse_list_energy_averaged)
    std_cv_mse = np.std(mse_list_energy_averaged)
    mean_uncertainty_max = np.mean(mean_uncertainty_list[:, 0])
    median_uncertainty_max = np.mean(median_uncertainty_list[:, 0])
    return [mean_cv_mse, std_cv_mse, mean_uncertainty_max, median_uncertainty_max]

def GridSearchCV(param_list, cv=5):
    """Cross-validate every hyperparameter combination and tabulate the scores.

    `param_list` is a sequence of combinations whose first entry is the learning
    rate and whose second entry is the number of epochs. Each combination is
    evaluated with `k_fold_cv` using `cv` folds. The combination with the lowest
    mean squared error is printed, and a DataFrame with one row per combination
    (parameters, MSE, standard deviation, mean and median uncertainty) is returned.
    """
    n_combos = len(param_list)
    # One row per metric, one column per hyperparameter combination.
    scores = np.zeros((4, n_combos))
    for col, combo in enumerate(param_list):
        cv_mse, cv_std, unc_mean, unc_median = k_fold_cv(cv=cv, lr=combo[0], num_epochs=combo[1])
        scores[:, col] = (cv_mse, cv_std, unc_mean, unc_median)
    mse_list, std_list, mean_unc_list, median_unc_list = scores

    best_idx = int(np.argmin(mse_list))
    winner = param_list[best_idx]
    print('best (lowest) mse: ', mse_list[best_idx], ' with σ=', std_list[best_idx], ' and estimated uncertainty ', median_unc_list[best_idx])
    print('with params lr={}, epochs={}'.format(winner[0], winner[1]))

    summary = {
        'Params (lr, epochs)': param_list,
        'Mean Squared Error': mse_list,
        'Standard Deviation': std_list,
        'Mean Uncertainty': mean_unc_list,
        'Median Uncertainty': median_unc_list,
    }
    return pd.DataFrame(summary)
        
        
# Evaluate every (lr, num_epochs) combination in `param_list` with 3-fold cross-validation.
output_df = GridSearchCV(param_list, cv=3)

starting CV for lr=1.0, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

11.263386063229095 11.292660590549886
8.758189876952326 8.768135661467827
11.518935048899332 11.528095776724301


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.055850044141042 11.083047543141323
8.570801476856996 8.579923253360496
11.282215750957164 11.290814026297769


output:   0%|          | 0/3 [00:00<?, ?it/s]

10.975263053536834 11.00282066864438
8.459463770185522 8.46966214589579
11.118926826209393 11.127814182453086
starting CV for lr=1.0, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

11.188697839982655 11.213385762393957
10.670672223613863 10.67963632655081
11.084390755219363 11.091740860325137


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.088831668831489 11.113709785041747
10.509008554166133 10.51890129446633
10.96955827020117 10.97701593012026


output:   0%|          | 0/3 [00:00<?, ?it/s]

10.938219253233896 10.964369033459798
10.346048550292073 10.35494399662706
10.859856681919927 10.866959737429076
starting CV for lr=0.5, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

12.902341426558714 12.925711250860791
14.602866762355847 14.621762556611325
12.531160544591854 12.54504619407253


output:   0%|          | 0/3 [00:00<?, ?it/s]

12.777636870407097 12.801875223032027
14.403993133034106 14.422239783323079
12.422853026734359 12.43620027460947


output:   0%|          | 0/3 [00:00<?, ?it/s]

12.679080873187308 12.701454731331854
14.254807917340056 14.27328065911231
12.345421373375109 12.359239332441499
starting CV for lr=0.5, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

11.42030915519303 11.445039943567657
10.79027728981961 10.803571094135823
11.363845877321582 11.374339195822442


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.265205544715915 11.29046346676003
10.610174122949953 10.622532118726165
11.212229562464048 11.22247338549606


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.147066645776968 11.171762988093828
10.460626715196078 10.473963033824065
11.108140653639868 11.118608267600347
starting CV for lr=0.2, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

8.454500369012699 8.473506547460596
12.038331865990713 12.060990252363768
8.181780478792941 8.197406251297336


output:   0%|          | 0/3 [00:00<?, ?it/s]

8.387758637498665 8.406358941739096
12.014421246612077 12.038545689757514
8.111293861457998 8.127499042493504


output:   0%|          | 0/3 [00:00<?, ?it/s]

8.349547615893162 8.36818239594245
11.9942580145754 12.016910279094324
8.055759318324336 8.0711292862998
starting CV for lr=0.2, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

11.813089095770223 11.83882219801778
10.293741980297607 10.30958206929516
11.83293104874923 11.848584074969304


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.62851330076452 11.653788723124935
10.096207554073391 10.112638035839723
11.668468839377663 11.684567166171616


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.49975016486847 11.524537996369885
9.937107028420668 9.953148775607394
11.53539444690955 11.551110258428343
starting CV for lr=0.1, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

31.773778826065346 31.830588195900937
59.06943489958872 59.193548064957035
29.298662021700075 29.35303951304826


output:   0%|          | 0/3 [00:00<?, ?it/s]

31.772028841878388 31.82819735249921
59.06817010072299 59.19073436242173
29.30163347377165 29.355698572406638


output:   0%|          | 0/3 [00:00<?, ?it/s]

31.767028157742892 31.820218824113887
59.06861658272892 59.185486789003654
29.289593877751642 29.341355887629256
starting CV for lr=0.1, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

12.025315908163138 12.050270428914732
18.79157758804598 18.82935989618891
11.373006749705493 11.396646298869664


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.978501083794768 12.003830003631778
18.77158896610197 18.81050011897741
11.326411852231537 11.350421550726102


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.951511513912724 11.97608968113913
18.758657577731746 18.796368235233043
11.286609900269934 11.309694178195164
starting CV for lr=0.05, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

79.69872511740078 79.89885815493172
201.18094488270725 201.97332858410277
71.83326904325799 72.03335224249037


output:   0%|          | 0/3 [00:00<?, ?it/s]

79.72723442908011 79.91871342250714
201.2896534477087 202.0593300186903
71.85013635873912 72.0439645508763


output:   0%|          | 0/3 [00:00<?, ?it/s]

79.73232573883364 79.91895989991629
201.28568012853998 202.0480492888008
71.854614537771 72.04057568461131
starting CV for lr=0.05, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

41.6178693469607 41.702822473004005
82.12343882299947 82.33766871602272
38.19733689060815 38.28014711171427


output:   0%|          | 0/3 [00:00<?, ?it/s]

41.62213328722847 41.706047422909286
82.13573206639985 82.34928536495968
38.20070470371074 38.28368918876271


output:   0%|          | 0/3 [00:00<?, ?it/s]

41.61192686241033 41.69315028327739
82.10897745696016 82.31911864554108
38.19203117970075 38.27075167020817
best (lowest) mse:  0.004472033495938135  with σ= 0.00017372453771800754  and estimated uncertainty  12.786353056717706
with params lr=0.5, epochs=30


In [6]:
# Show the grid-search results ordered from lowest to highest cross-validated MSE.
display(output_df.sort_values(by='Mean Squared Error', ascending=True))

Unnamed: 0,"Params (lr, epochs)",Mean Squared Error,Standard Deviation,Mean Uncertainty,Median Uncertainty
2,"(0.5, 30)",0.004472,0.000174,12.80968,12.786353
5,"(0.2, 50)",0.004476,0.000173,11.672383,11.647118
3,"(0.5, 50)",0.004495,0.000176,11.302422,11.277527
1,"(1.0, 50)",0.0045,0.000187,11.097155,11.071916
4,"(0.2, 30)",0.00451,0.000164,8.416016,8.397269
7,"(0.1, 50)",0.004516,0.000159,12.010063,11.98511
0,"(1.0, 30)",0.004523,0.000198,11.126176,11.098166
6,"(0.1, 30)",0.004644,0.000182,31.826335,31.770945
9,"(0.05, 50)",0.004744,0.000181,41.700673,41.61731
8,"(0.05, 30)",0.005301,0.000245,79.912177,79.719428


In [7]:
import os

# Create the results folder if it is missing (no error when it already exists).
os.makedirs('gpr_cv_results', exist_ok=True)

# Write the grid-search table to disk without the DataFrame index column.
results_path = 'gpr_cv_results/grid_search.csv'
output_df.to_csv(results_path, index=False)