In [1]:
import itertools
import os
import pickle
import sys
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_percentage_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm, trange
# Make the parent directory importable.
sys.path.append('../')
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): nothing in the cells visible here moves tensors or models to
# `device` -- confirm whether GPU execution is actually intended.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import gpytorch
import warnings
# Silence the warning raised when a model in eval mode is evaluated on its own
# training inputs; the cross-validation code below does that on purpose.
warnings.filterwarnings("ignore", message="The input matches the stored training data. Did you forget to call model.train()?") 
import gc
print('Using ', device)
# Make float64 the default floating-point dtype for tensors created from here on.
dataType = torch.float64
torch.set_default_dtype(dataType)

Using  cuda


In [2]:
noise = 10 # ADJUST level of gaussian noise added to outputs
mod_type = 'gpr'
# Run label combining model type and noise level, e.g. 'gpr_noise-10'.
description = mod_type + '_noise-' + str(noise)
# The noise level is part of the file name, so it also selects which CSV is read.
filename = '../datasets/fuchs_v3-2_seed-5_points_25000_noise_' + str(noise) + '.csv'  # CHANGE TO DESIRED DATA FILE
df = pd.read_csv(filename)

In [3]:
input_list = ['Intensity_(W_cm2)', 'Target_Thickness (um)', 'Focal_Distance_(um)'] # independent variables
# The first three columns are the training targets; the last three ("Exact")
# are used only for the held-out split (presumably the noise-free values).
output_list = ['Max_Proton_Energy_(MeV)', 'Total_Proton_Energy_(MeV)', 'Avg_Proton_Energy_(MeV)',
               'Max_Proton_Energy_Exact_(MeV)', 'Total_Proton_Energy_Exact_(MeV)', 'Avg_Proton_Energy_Exact_(MeV)'] # training outputs

X = df[input_list].copy()
y = df[output_list].copy()
# Work in log space for the intensity (first input) and for every energy column.
X[X.columns[0]] = np.log(X[X.columns[0]])
y = np.log(y)

X = X.to_numpy()
y = y.to_numpy()

# shuffle=False: the first 80% of the rows become the training split and the
# last 20% the test split, in file order.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, shuffle = False)
y_train = y_train[:, 0:3]  # fit on the first three output columns
y_test = y_test[:, 3:6]    # evaluate against the "Exact" columns
pct = 25 # percentage of the training split that is actually used (100 = all of it)
len_df = int(len(X_train)*(pct/100))
X_train = X_train[0:len_df]
y_train = y_train[0:len_df]

# Standardise inputs with statistics from the (sub-sampled) training rows only.
ss_in = StandardScaler()
ss_in.fit(X_train)
X_train_norm = ss_in.transform(X_train)
X_test_norm = ss_in.transform(X_test)

# Same for the three training targets; ss_out is reused later to undo the scaling.
ss_out = StandardScaler()
ss_out.fit(y_train)
y_train_norm = ss_out.transform(y_train)

X_train_norm = torch.tensor(X_train_norm, dtype=dataType)
y_train_norm = torch.tensor(y_train_norm, dtype=dataType)

print('train ds length: ', len(X_train_norm))
print(X_train_norm.dtype)

train ds length:  5000
torch.float64


In [4]:
class Exact_GP(gpytorch.models.ExactGP):
    """Exact GP model with a constant mean and a scaled RBF covariance."""

    def __init__(self, train_x, train_y, likelihood):
        """Register the training data and likelihood, then build mean and kernel."""
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        """Return the multivariate normal defined by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

In [5]:
# Hyper-parameter grid: candidate learning rates and training-iteration counts.
param_grid = {
    'lr': [1.0, 0.5, 0.2, 0.1, 0.05],
    'num_epochs': [30, 50],
}
# Single-configuration grid for a quick smoke test:
#param_grid = {'lr': [5e-1], 'num_epochs': [1]}

# Cartesian product of the value lists, taken in the dict's key order, so each
# entry of param_list is an (lr, num_epochs) tuple.
param_nested_list = list(param_grid.values())
param_list = list(itertools.product(*param_nested_list))

def k_fold_split(X, k=5):
    """Cut ``X`` into ``k`` consecutive slices.

    The first ``k - 1`` slices each hold ``len(X) // k`` items; the final slice
    runs to the end of ``X`` and therefore absorbs any remainder.

    Parameters
    ----------
    X : sliceable sequence (list, array, tensor, ...)
    k : int
        Number of folds.

    Returns
    -------
    list
        The ``k`` slices of ``X``, in order.
    """
    n_items = len(X)
    base_size = n_items // k

    # Cut points: the start of every fold, followed by the end of the last one.
    cut_points = [idx * base_size for idx in range(k)]
    cut_points.append(n_items)

    return [X[lo:hi] for lo, hi in zip(cut_points[:-1], cut_points[1:])]
    
def k_fold_cv(cv=5, lr=1e-3, num_epochs=50):
    """Cross-validate a chain of exact GPs, one GP per output column.

    Reads the module-level ``X_train_norm`` / ``y_train_norm`` tensors and the
    fitted output scaler ``ss_out``. For every fold, one ``Exact_GP`` is
    trained per output column with Adam. After each GP is trained, its
    predictions are appended as an extra input feature, so the GP for output
    ``j + 1`` also sees the predictions for outputs ``0..j``.

    Parameters
    ----------
    cv : int
        Number of folds.
    lr : float
        Adam learning rate.
    num_epochs : int
        Number of optimisation steps per GP.

    Returns
    -------
    list
        ``[mean_cv_mse, std_cv_mse, mean_uncertainty_max, median_uncertainty_max]``.
        The MSE is computed on the normalised targets, averaged over outputs
        within a fold, then summarised across folds. The two uncertainty
        figures are percentages in the un-normalised, exponentiated space,
        for output column 0 only, averaged across folds.
    """
    print('starting CV for lr={}, epochs={}'.format(lr, num_epochs))
    folds_X = k_fold_split(X_train_norm, k=cv)
    folds_y = k_fold_split(y_train_norm, k=cv)
    num_outputs = y_train_norm.shape[1]
    mse_list = np.zeros((cv, num_outputs))
    mean_uncertainty_list = np.zeros((cv, num_outputs))
    median_uncertainty_list = np.zeros((cv, num_outputs))

    def _as_numpy(tensor):
        # scikit-learn / NumPy need a host-side array; make the conversion explicit.
        return tensor.detach().cpu().numpy()

    def _to_physical(scaled_col, col_idx):
        # ss_out was fitted on all outputs jointly, so a single column has to be
        # tiled across every output slot before inverse_transform; the column
        # of interest is then picked back out. exp() undoes the log applied to
        # the targets during data preparation.
        tiled = scaled_col.reshape(-1, 1).repeat(1, num_outputs)
        return np.exp(ss_out.inverse_transform(_as_numpy(tiled)))[:, col_idx]

    for i in trange(cv, desc='CV'):
        # Train on every fold except fold i, validate on fold i.
        idx_list = [fold for fold in range(cv) if fold != i]

        X_train_cv = torch.concat([folds_X[k] for k in idx_list], axis=0)
        y_train_cv = torch.concat([folds_y[k] for k in idx_list], axis=0)
        X_val_cv = folds_X[i]
        y_val_cv = folds_y[i]

        for j in trange(num_outputs, desc='output', leave=False, position=0):
            likelihood = gpytorch.likelihoods.GaussianLikelihood()
            model = Exact_GP(X_train_cv, y_train_cv[:, j], likelihood)

            model.train()
            likelihood.train()

            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

            for _ in range(num_epochs):
                optimizer.zero_grad()

                # Minimise the negative exact marginal log-likelihood.
                model_output = model(X_train_cv)
                loss = -mll(model_output, y_train_cv[:, j])
                loss.backward()

                # Release unreferenced Python objects and cached CUDA blocks
                # before the parameter update.
                gc.collect()
                torch.cuda.empty_cache()

                optimizer.step()

            model.eval()
            likelihood.eval()
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                pred_dist_train = likelihood(model(X_train_cv))
                pred_dist_val = likelihood(model(X_val_cv))
                y_train_predict = pred_dist_train.mean
                y_val_predict = pred_dist_val.mean

                lower, upper = pred_dist_val.confidence_region()
                lower_unscaled = _to_physical(lower, j)
                upper_unscaled = _to_physical(upper, j)
                y_val_predict_unscaled = _to_physical(y_val_predict, j)

                # confidence_region() spans two standard deviations either side
                # of the mean, so a quarter of its width is used as the
                # one-sigma band, expressed as a percentage of the prediction.
                rel_uncertainty_pct = (upper_unscaled - lower_unscaled)/(4*y_val_predict_unscaled)*100

                mse_list[i, j] = mean_squared_error(_as_numpy(y_val_cv[:, j]), _as_numpy(y_val_predict))
                median_uncertainty_list[i, j] = np.median(rel_uncertainty_pct)
                mean_uncertainty_list[i, j] = np.mean(rel_uncertainty_pct)
                print(median_uncertainty_list[i, j], mean_uncertainty_list[i, j])

                # Chain the outputs: this GP's predictions become an additional
                # input feature for the next output's GP.
                X_train_cv = torch.concat([X_train_cv, y_train_predict.reshape(-1, 1)], axis=1)
                X_val_cv = torch.concat([X_val_cv, y_val_predict.reshape(-1, 1)], axis=1)

    mse_list_energy_averaged = np.mean(mse_list, axis=1)
    mean_cv_mse = np.mean(mse_list_energy_averaged)
    std_cv_mse = np.std(mse_list_energy_averaged)
    mean_uncertainty_max = np.mean(mean_uncertainty_list[:, 0])
    median_uncertainty_max = np.mean(median_uncertainty_list[:, 0])
    return [mean_cv_mse, std_cv_mse, mean_uncertainty_max, median_uncertainty_max]

def GridSearchCV(param_list, cv=5):
    """Run ``k_fold_cv`` for every hyper-parameter combination and tabulate the results.

    Parameters
    ----------
    param_list : sequence
        Combinations to try; element 0 of each is the learning rate and
        element 1 the number of epochs.
    cv : int
        Number of cross-validation folds passed on to ``k_fold_cv``.

    Returns
    -------
    pandas.DataFrame
        One row per combination with its cross-validated MSE, the standard
        deviation of that MSE, and the mean / median uncertainty. The best
        (lowest-MSE) combination is also printed.
    """
    n_configs = len(param_list)
    mse_list = np.zeros(n_configs)
    std_list = np.zeros(n_configs)
    mean_unc_list = np.zeros(n_configs)
    median_unc_list = np.zeros(n_configs)

    for row, combo in enumerate(param_list):
        cv_mse, cv_std, unc_mean, unc_median = k_fold_cv(cv=cv, lr=combo[0], num_epochs=combo[1])
        mse_list[row] = cv_mse
        std_list[row] = cv_std
        mean_unc_list[row] = unc_mean
        median_unc_list[row] = unc_median

    best_idx = np.argmin(mse_list)
    best_combo = param_list[int(best_idx)]
    print('best (lowest) mse: ', mse_list[best_idx], ' with σ=', std_list[best_idx], ' and estimated uncertainty ', median_unc_list[best_idx])
    print('with params lr={}, epochs={}'.format(best_combo[0], best_combo[1]))

    results = {
        'Params (lr, epochs)': param_list,
        'Mean Squared Error': mse_list,
        'Standard Deviation': std_list,
        'Mean Uncertainty': mean_unc_list,
        'Median Uncertainty': median_unc_list,
    }
    return pd.DataFrame(results)
        
        
# Run the whole grid with 3-fold cross-validation; output_df holds one row per
# (lr, epochs) combination.
output_df = GridSearchCV(param_list, cv=3)

starting CV for lr=1.0, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

11.299093204743034 11.327267603898488
8.753550000104461 8.76265320838631
11.52476573292332 11.533655622652395


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.046664636608888 11.074110414771253
8.568949422385948 8.578168537163716
11.273872864859811 11.282491381047782


output:   0%|          | 0/3 [00:00<?, ?it/s]

10.95994421926057 10.987265055674495
8.45793043472608 8.467885304833771
11.127539382175598 11.136390578341675
starting CV for lr=1.0, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

11.192062036169347 11.217418901891195
10.674286841372965 10.68335898643542
11.086936935374005 11.094497462288345


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.095515496922935 11.120455205756771
10.509128658931619 10.519002421155266
10.971744333593719 10.97926191305321


output:   0%|          | 0/3 [00:00<?, ?it/s]

10.960754571967659 10.987977622144417
10.353991583094011 10.362962722335062
10.865849032190962 10.872950314428943
starting CV for lr=0.5, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

12.898719968176819 12.923756443571497
14.599536707080363 14.617825840190328
12.529035322204518 12.5426203594805


output:   0%|          | 0/3 [00:00<?, ?it/s]

12.777552321089372 12.802677525894634
14.401908521390196 14.419581142424414
12.414110663932863 12.427311628066978


output:   0%|          | 0/3 [00:00<?, ?it/s]

12.678995320434058 12.701170331332001
14.255852165800217 14.274020031770705
12.341984519270184 12.355589645752557
starting CV for lr=0.5, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

11.417513447535281 11.442747546162382
10.78235201012544 10.793494921099441
11.349781180546163 11.35953711373544


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.264642139552144 11.290296075295949
10.608512141206138 10.621155813330198
11.21461004624059 11.224739499089038


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.148077638949182 11.173669002467483
10.449576500533771 10.460447948412728
11.08841697661122 11.097849852346144
starting CV for lr=0.2, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

8.45508025049703 8.473747781328623
12.040334420222681 12.063123357624615
8.183655917408585 8.199339251342161


output:   0%|          | 0/3 [00:00<?, ?it/s]

8.38715282868354 8.405881056132717
12.017162902884746 12.04083916645892
8.111105709077108 8.127123241217273


output:   0%|          | 0/3 [00:00<?, ?it/s]

8.347938495055848 8.366822084806198
11.992443245158302 12.015302454466111
8.055341962329683 8.070638795223404
starting CV for lr=0.2, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

11.812219550624015 11.838038084447707
10.291598643598137 10.307544317844512
11.832810754797567 11.848515318877617


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.632757318531805 11.658683747435184
10.096322777417951 10.113101668017979
11.669251870245066 11.68565639927024


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.505010536853941 11.5297272516704
9.938095705579 9.954205182673512
11.539068128582652 11.554891999532368
starting CV for lr=0.1, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

31.765856298687257 31.82249295323597
59.05751775206947 59.18218513041553
29.30334550304523 29.357675785320385


output:   0%|          | 0/3 [00:00<?, ?it/s]

31.770102355754634 31.82614274042344
59.07309168631529 59.19561991407065
29.300357365566196 29.354776912691584


output:   0%|          | 0/3 [00:00<?, ?it/s]

31.77053488009046 31.82429188992405
59.06437411707759 59.18381210400407
29.289504346313098 29.341434290335062
starting CV for lr=0.1, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

12.023322047080965 12.04837431804616
18.791613896049874 18.829433643661076
11.373656039988747 11.396841865375446


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.977980594543538 12.003211588542264
18.778291529559553 18.816881025071968
11.323919433487793 11.347897963294592


output:   0%|          | 0/3 [00:00<?, ?it/s]

11.949779830117977 11.974424828669582
18.761221716434083 18.798803729767606
11.285661303836662 11.30858257973994
starting CV for lr=0.05, epochs=30


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

79.70126410938519 79.89961865208976
201.1735758639963 201.97150612731696
71.83495432748586 72.03438070344858


output:   0%|          | 0/3 [00:00<?, ?it/s]

79.72695893702554 79.91875205914478
201.287551235474 202.06282058993784
71.85312776257919 72.04287652489406


output:   0%|          | 0/3 [00:00<?, ?it/s]

79.72480013552129 79.91178261702059
201.27903128847737 202.02935890003894
71.85530342577425 72.04323631675992
starting CV for lr=0.05, epochs=50


CV:   0%|          | 0/3 [00:00<?, ?it/s]

output:   0%|          | 0/3 [00:00<?, ?it/s]

41.619596427744554 41.703630329127364
82.10633357012959 82.32092087444894
38.19210554150568 38.274709389802


output:   0%|          | 0/3 [00:00<?, ?it/s]

41.620335858171806 41.70326557159341
82.13950720499591 82.35141806570277
38.189108827927456 38.27210558952392


output:   0%|          | 0/3 [00:00<?, ?it/s]

41.614506432323495 41.695968026255734
82.11518336603395 82.32602070424623
38.192820790436755 38.27133151285379
best (lowest) mse:  0.0044761239907867385  with σ= 0.00017116592584453318  and estimated uncertainty  12.785089203233417
with params lr=0.5, epochs=30


In [7]:
# Show the combinations ranked from lowest to highest cross-validated MSE.
display(output_df.sort_values(by='Mean Squared Error', ascending=True))

Unnamed: 0,"Params (lr, epochs)",Mean Squared Error,Standard Deviation,Mean Uncertainty,Median Uncertainty
2,"(0.5, 30)",0.004476,0.000171,12.809201,12.785089
3,"(0.5, 50)",0.004483,0.00017,11.302238,11.276744
5,"(0.2, 50)",0.004485,0.000175,11.675483,11.649996
4,"(0.2, 30)",0.004508,0.000169,8.415484,8.396724
7,"(0.1, 50)",0.004509,0.000165,12.00867,11.983694
0,"(1.0, 30)",0.004512,0.000198,11.129548,11.101901
1,"(1.0, 50)",0.004519,0.000192,11.108617,11.082777
6,"(0.1, 30)",0.004645,0.000182,31.824309,31.768831
9,"(0.05, 50)",0.004741,0.000181,41.700955,41.618146
8,"(0.05, 30)",0.0053,0.000244,79.910051,79.717674


In [8]:
# to_csv does not create missing directories, so make sure the target folder
# exists before writing; otherwise the export fails after the search has run.
os.makedirs('gpr_cv_results', exist_ok=True)
output_df.to_csv('gpr_cv_results/grid_search.csv', index=False)