In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import sys
import pandas as pd
import time
from datetime import datetime
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error
import itertools
sys.path.append('../')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import gpytorch
import warnings
warnings.filterwarnings("ignore", message="The input matches the stored training data. Did you forget to call model.train()?") 
import gc
print('Using ', device)

Using  cuda


In [2]:
noise = 10 # ADJUST level of gaussian noise added to outputs
mod_type = 'gpr'
description = mod_type + '_noise-' + str(noise)
filename = '../datasets/fuchs_v3-2_seed-5_points_25000_noise_' + str(noise) + '.csv'  # CHANGE TO DESIRED DATA FILE
df = pd.read_csv(filename)

In [3]:
input_list = ['Intensity_(W_cm2)', 'Target_Thickness (um)', 'Focal_Distance_(um)'] # independent variables
output_list = ['Max_Proton_Energy_(MeV)', 'Total_Proton_Energy_(MeV)', 'Avg_Proton_Energy_(MeV)',
               'Max_Proton_Energy_Exact_(MeV)', 'Total_Proton_Energy_Exact_(MeV)', 'Avg_Proton_Energy_Exact_(MeV)'] # training outputs

X = df[input_list].copy()
y = df[output_list].copy()
# X.iloc[:, 0] = X.iloc[:,0].apply(lambda x: np.log(x)) # Apply Log Scale to Intensity
# y.iloc[:] = y.iloc[:].apply(lambda x: np.log(x)) # Apply Log Scale to Max,Tot,Avg Energy
X[X.columns[0]] = np.log(X[X.columns[0]])
for col in y.columns:
    y[col] = np.log(y[col])

dataType = torch.float32

X = X.to_numpy()
y = y.to_numpy()

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, shuffle = False)
y_train = y_train[:, 0:3]
y_test = y_test[:, 3:6]
pct = 25 # Using all 20,000 points in training/validation set
len_df = int(len(X_train)*(pct/100))
X_train = X_train[0:len_df]
y_train = y_train[0:len_df]

ss_in = StandardScaler()
ss_in.fit(X_train)
X_train_norm = ss_in.transform(X_train)
X_test_norm = ss_in.transform(X_test)

ss_out = StandardScaler()
ss_out.fit(y_train)
y_train_norm = ss_out.transform(y_train)

X_train_norm = torch.tensor(X_train_norm, dtype=dataType)
y_train_norm = torch.tensor(y_train_norm, dtype=dataType)

print('train ds length: ', len(X_train_norm))

train ds length:  5000


In [4]:
class Exact_GP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(Exact_GP, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [30]:
param_grid = {
    'lr': [5e-1, 2e-1, 1e-1, 5e-2],
    'num_epochs': [10, 20, 30],
}

param_nested_list = [param_grid[key] for key in param_grid.keys()]
param_list = list(itertools.product(*param_nested_list))

def k_fold_split(X, k=5):
    # Calculate the size of each fold
    fold_size = len(X) // k

    # Initialize a list to store the folds
    folds = []

    # Create each fold
    for i in range(k):
        # Calculate the start and end indices of the current fold
        start = i * fold_size
        end = (i + 1) * fold_size if i < k - 1 else len(X)

        # Get the fold data
        fold = X[start:end]

        # Add the fold to the list of folds
        folds.append(fold)

    return folds
    
def k_fold_cv(cv=5, lr=1e-3, num_epochs=50):
    print('starting CV for lr={}, epochs={}'.format(lr, num_epochs))
    folds_X = k_fold_split(X_train_norm, k=cv)
    folds_y = k_fold_split(y_train_norm, k=cv)
    num_outputs = len(y_train_norm[0])
    mse_list = np.zeros((cv, num_outputs))
    uncertainty_list = np.zeros((cv, num_outputs))

    
    for i in trange(cv, desc='CV'):
        idx_list = list(range(cv))
        del idx_list[i]
        
        X_train_cv = torch.concat([folds_X[k] for k in idx_list], axis=0)
        y_train_cv = torch.concat([folds_y[k] for k in idx_list], axis=0)
        X_val_cv = folds_X[i]
        y_val_cv = folds_y[i]
        

        likelihoods = [gpytorch.likelihoods.GaussianLikelihood() for j in range(num_outputs)]
        models = [Exact_GP(X_train_cv, y_train_cv[:, j], likelihoods[j]) for j in range(num_outputs)]
        optimizers = [torch.optim.Adam(models[j].parameters(), lr=lr) for j in range(num_outputs)]
        mlls = [gpytorch.mlls.ExactMarginalLogLikelihood(likelihoods[j], models[j]) for j in range(num_outputs)]
    
        current_it = 0
        while(current_it < num_epochs):
            for j in range(num_outputs):
                # Set models to train mode
                models[j].train()
                likelihoods[j].train()

                # Zero the Gradients
                optimizers[j].zero_grad()
                
                # Perform Forward Pass
                model_output = models[j](X_train_cv)
                
                # Compute Loss
                loss = -mlls[j](model_output, y_train_cv[:, j])
                
                # Perform Backward Pass
                loss.backward()
                
                # Optimization
                optimizers[j].step()

                # Set models to eval mode
                models[j].eval()
                models[j].eval()



            # Clear cache
            gc.collect()
            torch.cuda.empty_cache()
            current_it += 1
                
                
        
        model.eval()
        likelihood.eval()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            pred_dist_train = likelihood(model(X_train_cv))
            pred_dist_val = likelihood(model(X_val_cv))
            y_train_predict = pred_dist_train.mean
            y_val_predict = pred_dist_val.mean
            variance = pred_dist_val.variance
            mse_list[i, j] = mean_squared_error(y_val_cv[:, j], y_val_predict)
            uncertainty_list[i, j] = torch.median(variance)
            print(uncertainty_list[i, j])
            X_train_cv = torch.concat([X_train_cv, y_train_predict.reshape(-1, 1)], axis=1)
            X_val_cv = torch.concat([X_val_cv, y_val_predict.reshape(-1, 1)], axis=1)
        
    mse_list_energy_averaged = np.mean(mse_list, axis=1)
    mean_cv_mse = np.mean(mse_list_energy_averaged)
    std_cv_mse = np.std(mse_list_energy_averaged)
    mean_uncertainty_max = np.mean(uncertainty_list[:, 0])
    return [mean_cv_mse, std_cv_mse, mean_uncertainty_max]

def GridSearchCV(param_list, cv=5):
    mse_list = np.zeros(len(param_list))
    std_list = np.zeros(len(param_list))
    unc_list = np.zeros(len(param_list))
    for k, param in enumerate(param_list):
        lr = param[0]
        num_epochs = param[1]
        mse, std, σ = k_fold_cv(cv=cv, lr=lr, num_epochs=num_epochs)
        mse_list[k] = mse
        std_list[k] = std
        unc_list[k] = σ
    best_idx = np.argmin(mse_list)
    print('best (lowest) mse: ', mse_list[best_idx], ' with σ=', std_list[best_idx], ' and estimated uncertainty ', unc_list[best_idx])
    print('with params lr={}, epochs={}'.format(param_list[int(best_idx)][0], param_list[int(best_idx)][1]))
    return pd.DataFrame({'Params (lr, epochs)':param_list, 'Mean Squared Error':mse_list, 'Standard Deviation':std_list, 'Uncertainty': unc_list})
        
        
output_df = GridSearchCV(param_list)

starting CV for lr=0.5, epochs=10


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.006708519998937845
0.006592962890863419
0.006900375708937645


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.0067316084168851376
0.006584614980965853
0.006899978965520859


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.006821370217949152
0.006592955440282822
0.006894970778375864


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.006820884998887777
0.006573040969669819
0.006870026234537363


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.00683019682765007
0.006590529344975948
9.999999974752427e-07
starting CV for lr=0.5, epochs=20




CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.005342410411685705
0.00029034691397100687
0.0066732908599078655


output:   0%|          | 0/3 [00:00<?, ?it/s]



9.999999974752427e-07
0.0007598806405439973
0.0066291759721934795


output:   0%|          | 0/3 [00:00<?, ?it/s]



9.999999974752427e-07
0.0006670866278000176
9.999999974752427e-07




output:   0%|          | 0/3 [00:00<?, ?it/s]



9.999999974752427e-07
0.0007168573210947216
0.006279856897890568


output:   0%|          | 0/3 [00:00<?, ?it/s]



9.999999974752427e-07
0.00025309284683316946
9.999999974752427e-07
starting CV for lr=0.5, epochs=30




CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.007930418476462364




9.999999974752427e-07
0.008987504057586193


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.007942837662994862
0.0015684037934988737
0.008900975808501244


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.008119214326143265




9.999999974752427e-07
0.008797471411526203


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.007860089652240276
0.0035430362913757563
0.008798116818070412


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.007338345050811768




9.999999974752427e-07
0.008797519840300083
starting CV for lr=0.2, epochs=10


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.12577179074287415
0.12576746940612793
0.12593699991703033


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.1258154809474945
0.12580989301204681
0.12602557241916656


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.12580782175064087
0.12580552697181702
0.12599065899848938


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.12582749128341675
0.12581902742385864
0.12600962817668915


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.1257854402065277
0.12578414380550385
0.125980943441391
starting CV for lr=0.2, epochs=20


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.017435424029827118
0.015758967027068138
0.017351653426885605


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.017433125525712967
0.01701456494629383
0.017505966126918793


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.017414981499314308




0.0010039806365966797
0.017499733716249466


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.017421264201402664
0.01703294925391674
0.017480842769145966


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.017428524792194366
0.01703137531876564
0.01748817227780819
starting CV for lr=0.2, epochs=30


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.0034968892578035593
0.002537077059969306
0.0038796011358499527


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.0035664448514580727
0.0025046810042113066
0.003876074915751815


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.0035521320533007383
0.0025332916993647814
0.0038552882615476847


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.0035029726568609476
0.0025260767433792353
9.999999974752427e-07




output:   0%|          | 0/3 [00:00<?, ?it/s]

0.00352024519816041




9.999999974752427e-07
0.003831848269328475
starting CV for lr=0.1, epochs=10


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.31481319665908813
0.31560811400413513
0.31621819734573364


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.31493863463401794
0.31577497720718384
0.31647247076034546


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.31495699286460876
0.3157331645488739
0.3164100646972656


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.3150140345096588
0.3157911002635956
0.31648164987564087


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.31488946080207825
0.3156856298446655
0.3163365423679352
starting CV for lr=0.1, epochs=20


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.12395139038562775
0.12375164031982422
0.12405560910701752


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.1239708811044693
0.12380868196487427
0.12413470447063446


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.12397588044404984
0.12382130324840546
0.12410075962543488


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.12398044764995575
0.12382287532091141
0.12413612008094788


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.12395352125167847
0.12378571927547455
0.1241065040230751
starting CV for lr=0.1, epochs=30


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.04565177485346794
0.04515351727604866
0.04569306969642639


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.04567170888185501
0.04516629874706268
0.04573625698685646


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.04566651210188866
0.045174069702625275
0.045710861682891846


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.04566594585776329
0.04517252370715141
0.045686718076467514


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.04566759616136551
0.045162133872509
0.045690134167671204
starting CV for lr=0.05, epochs=10


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.4813026189804077
0.48467686772346497
0.4880453646183014


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.4813728928565979
0.48494991660118103
0.4884839355945587


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.48144811391830444
0.48493754863739014
0.4883747696876526


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.48152485489845276
0.48503682017326355
0.48821762204170227


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.4813496768474579
0.4846315383911133
0.48824965953826904
starting CV for lr=0.05, epochs=20


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.3134565055370331
0.31416672468185425
0.31486108899116516


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.31358060240745544
0.31435051560401917
0.31514057517051697


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.3136022686958313
0.314301997423172
0.3150677978992462


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.3136546015739441
0.3143708407878876
0.3150807023048401


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.3135280907154083
0.31423449516296387
0.31501150131225586
starting CV for lr=0.05, epochs=30


CV:   0%|          | 0/5 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

0.1974838227033615
0.1974511295557022
0.19787298142910004


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.19755956530570984
0.19754894077777863
0.1980285793542862


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.1975584179162979
0.1975450962781906
0.1979849636554718


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.1975746601819992
0.1975691169500351
0.19801214337348938


output:   0%|          | 0/3 [00:00<?, ?it/s]

0.19752663373947144
0.19751620292663574
0.19792334735393524
best (lowest) mse:  0.004479287828629215  with σ= 0.00020035324596718478  and estimated uncertainty  0.007838181033730508
with params lr=0.5, epochs=30


In [23]:
display(output_df)

Unnamed: 0,"Params (lr, epochs)",Mean Squared Error,Standard Deviation,Uncertainty
0,"(0.1, 5)",0.005782,0.000329,0.4841


In [None]:
output_df.to_csv('gpr_results/grid_search_1.csv', index=False)