In [3]:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax as ox
from jax import jit
from jax.config import config
from jaxutils import Dataset
import jaxkern as jk
from jax.scipy.stats import norm
from jax.scipy.optimize import minimize
import gpjax as gpx
import time
import pickle as pkl

In [4]:
# Define the latent function to optimize
def LatentFunction(x, sample_key, kernel):
    '''
    Draw one sample of a zero-mean GP prior with the given kernel at the inputs x.
    Input: evaluation points x, PRNG key sample_key, GP kernel
    Output: the single sampled path, transposed so that the sample axis comes last
    '''
    zero_mean = gpx.mean_functions.Zero()
    gp_prior = gpx.Prior(kernel=kernel, mean_function=zero_mean)

    # One draw (sample_shape=(1,)) from the prior evaluated at x
    single_draw = gp_prior(x).sample(seed=sample_key, sample_shape=(1,))

    return single_draw.T

In [5]:
def ExpectedLoss(mean, std, alt, cost):
    '''
    This function calculates E[(g-V)^+]-c for given values g, c and a normal random variable V~N(mu, sigma^2)
    Input: mean and standard deviation the normal random variable V, alternative g, cost c
    Output: the expected loss, broadcast elementwise over the inputs

    For std > 0 the closed form (g-mu)*Phi(z) + sigma*phi(z) with z = (g-mu)/sigma is used.
    For std == 0 the variable V is the constant mu, so the expectation is max(g-mu, 0).
    A NaN std still yields NaN, so callers can keep filtering such points with nan* reductions.
    '''

    diff = alt - mean
    degenerate = std == 0

    # Divide by a dummy scale of 1 where std == 0 so that no 0/0 or x/0 is ever formed;
    # the value computed on those entries is discarded by the final jnp.where.
    safe_std = jnp.where(degenerate, 1.0, std)
    z = diff / safe_std

    # nan_to_num guards the remaining inf/inf corner cases of z
    smooth_EI = diff * jnp.nan_to_num(norm.cdf(z)) + std * jnp.nan_to_num(norm.pdf(z))

    # Point-mass case: E[(g-V)^+] = (g-mu)^+
    EI = jnp.where(degenerate, jnp.maximum(diff, 0.0), smooth_EI)

    exp_loss = EI - cost

    return exp_loss

In [6]:
# Define the acquisition function of the expected improvement policy
def EI_acq(mean, std, best):
    '''
    Expected improvement E[(best - V)^+] for V ~ N(mean, std^2), evaluated elementwise.
    Input: predictive mean and standard deviation at each point, best observed value
    Output: array of expected improvements, one entry per point (not the maximising point;
    callers take the argmax themselves)
    '''

    # Expected improvement is the expected loss with zero query cost
    return ExpectedLoss(mean=mean, std=std, alt=best, cost=0)

In [7]:
# Define the acquisition function of the Gittins policy
def Gittins_acq(mean, std, cost, bound, tol=None, max_iter=100):
    '''
    Compute, for every point, the value g at which ExpectedLoss(mean, std, g, cost) is zero,
    using a vectorised bisection on the bracket [bound[:,0], bound[:,1]].
    Input: predictive mean and standard deviation at each point, query cost, bracket `bound`
           tol: stop once max |ExpectedLoss| over the non-NaN points is below tol
                (defaults to the notebook-level `eps`)
           max_iter: maximum number of bisection steps
    Output: array of bisection midpoints, one entry per point (callers take the argmin themselves)

    The bracket is assumed to contain the root for every point. If it does not, or if the
    interval collapses before the tolerance is met, the loop stops after max_iter steps and
    a warning is issued instead of looping forever.
    '''
    import warnings

    if tol is None:
        # Fall back to the tolerance defined at notebook level
        tol = eps

    size = jnp.size(mean)

    l = bound[:,0]*jnp.ones(size)
    h = bound[:,1]*jnp.ones(size)
    m = (h+l)/2

    # The loss at the midpoint is evaluated once per step and reused for both the
    # stopping test and the sign comparison.
    loss_m = ExpectedLoss(mean, std, m, cost)
    n_iter = 0

    while jnp.nanmax(jnp.abs(loss_m)) >= tol:
        if n_iter >= max_iter:
            warnings.warn("Gittins_acq: bisection did not reach the tolerance within max_iter steps")
            break

        sgn_l = jnp.sign(ExpectedLoss(mean, std, l, cost))
        sgn_m = jnp.sign(loss_m)
        sgn_h = jnp.sign(ExpectedLoss(mean, std, h, cost))

        # Exact root at the midpoint: collapse the bracket onto it
        l = jnp.where(sgn_m == 0, m, l)
        h = jnp.where(sgn_m == 0, m, h)
        # Otherwise move whichever end shares the midpoint's sign
        l = jnp.where(sgn_l == sgn_m, m, l)
        h = jnp.where(sgn_h == sgn_m, m, h)

        m = (h+l)/2
        loss_m = ExpectedLoss(mean, std, m, cost)
        n_iter += 1

    return m

In [8]:
def UCB_acq(mean, std):
    '''
    Confidence-bound score mean - beta*std, evaluated elementwise.
    Input: predictive mean and standard deviation at each point
    Output: array of scores, one entry per point

    `beta` is not a parameter; it is read from the enclosing (notebook) scope.
    NOTE(review): no assignment to `beta` appears in the cells shown here - confirm it is
    defined before any policy that uses this function is run.
    '''
    exploration_term = beta*std

    return mean - exploration_term

In [9]:
def _select_next_index(policy, key, dist, mean, std, best, visited, query_cost, bounds, num_points):
    '''
    Choose the grid index to query next under the given policy.
    Input: policy name, PRNG key, predictive distribution `dist` on the grid together with its
           mean and standard deviation, best observed value, indices already queried,
           query cost, search-space bounds, number of grid points
    Output: (next_index, key, acq_value, gittins_value)
            key is the PRNG key after any split consumed by the policy;
            acq_value / gittins_value are None for policies that do not log them.

    The 'Surrogate' policy reads the number of Thompson samples from the notebook-level
    variable `num_samples`.
    '''
    acq_value = None
    gittins_value = None

    if policy == 'EI':
        expected_improvement = EI_acq(mean=mean, std=std, best=best)
        next_index = jnp.nanargmax(expected_improvement)
        acq_value = jnp.nanmax(expected_improvement)

    elif policy == 'Gittins':
        gittins_index = Gittins_acq(mean=mean, std=std, cost=query_cost, bound=2*bounds)
        # Only points that have not been queried yet are eligible
        candidates = jnp.arange(num_points)
        if visited.size > 0:
            candidates = jnp.setdiff1d(candidates, visited)
        next_index = candidates[jnp.nanargmin(gittins_index[candidates])]
        acq_value = jnp.nanmin(gittins_index[candidates])

    elif policy == 'UCB':
        ucb = UCB_acq(mean=mean, std=std)
        next_index = jnp.nanargmin(ucb)

    elif policy == 'TS':
        (k2,key) = jr.split(key)
        thompson_sample = dist.sample(seed=k2, sample_shape=(1,)).T
        next_index = jnp.nanargmin(thompson_sample)

    elif policy == 'UCB+EI':
        ucb = UCB_acq(mean=mean, std=std)
        expected_improvement = EI_acq(mean=mean, std=std, best=best)
        ei_max = jnp.nanmax(expected_improvement)

        if ei_max < query_cost:
            # No point clears the cost threshold: fall back to the unfiltered score
            next_index = jnp.nanargmin(ucb)
        else:
            candidates = jnp.where(expected_improvement>=query_cost)[0]
            next_index = candidates[jnp.nanargmin(ucb[candidates])]

        acq_value = ei_max

    elif policy == 'TS+EI':
        (k2,key) = jr.split(key)
        thompson_sample = dist.sample(seed=k2, sample_shape=(1,)).T
        expected_improvement = EI_acq(mean=mean, std=std, best=best)
        ei_max = jnp.nanmax(expected_improvement)

        if ei_max < query_cost:
            # No point clears the cost threshold: fall back to the unfiltered sample
            next_index = jnp.argmin(thompson_sample)
        else:
            candidates = jnp.where(expected_improvement>=query_cost)[0]
            next_index = candidates[jnp.argmin(thompson_sample[candidates])]

        acq_value = ei_max

    elif policy == 'Surrogate':
        gittins_index = Gittins_acq(mean=mean, std=std, cost=query_cost, bound=bounds)
        if visited.size > 0:
            # Already-queried points have their index lowered by the query cost
            gittins_index = gittins_index.at[visited].set(gittins_index[visited]-query_cost)

        (k2,key) = jr.split(key)
        thompson_sample = dist.sample(seed=k2, sample_shape=(num_samples,))

        current_surrogate_matrix = jnp.maximum(gittins_index, thompson_sample)
        current_surrogate_sample = jnp.min(current_surrogate_matrix, axis=1)
        new_surrogate_matrix = jnp.minimum(current_surrogate_sample, thompson_sample.T).T
        new_surrogate_price = jnp.mean(new_surrogate_matrix, axis=0)

        next_index = jnp.nanargmin(new_surrogate_price)
        acq_value = jnp.nanmin(new_surrogate_price)
        gittins_value = jnp.nanmin(gittins_index)

    else:
        raise ValueError("unknown policy: " + str(policy))

    return next_index, key, acq_value, gittins_value


def BayesianOptimization(seed, policy, num_points, num_iterations, query_cost, kernel):
    '''
    Run one Bayesian-optimisation instance on a function sampled from a zero-mean GP prior.

    The latent function is drawn on a grid of num_points points over [-8, 8]. The first query
    is chosen from the prior; after each evaluation the noise-free GP posterior is recomputed
    on the grid and the policy picks the next query.

    Input: seed: integer seed for the PRNG key
           policy: one of 'EI', 'Gittins', 'UCB', 'TS', 'UCB+EI', 'TS+EI', 'Surrogate'
           num_points: number of grid points
           num_iterations: number of function evaluations
           query_cost: cost per query used by the cost-aware policies
           kernel: GP kernel shared by the latent function and the model
    Output: (regret, acq_m, gittins_m)
            regret: best observed value minus the grid minimum; the first entry corresponds
                    to the initial best value 0, followed by one entry per iteration
            acq_m: value logged at each selection step (empty for 'UCB' and 'TS')
            gittins_m: minimum Gittins index at each selection step ('Surrogate' only)
    '''
    # Generate a random key
    key = jr.PRNGKey(seed)

    # Define the search space
    bounds = jnp.array([[-8, 8]])

    # Define the grid points
    x_plot = jnp.linspace(bounds[:, 0], bounds[:, 1], num_points).reshape(-1, 1)

    # Generate the latent function
    (k1,key) = jr.split(key)
    y_plot = LatentFunction(x=x_plot, sample_key=k1, kernel=kernel)

    # Global minimum of the latent function on the grid
    global_minimum_value = jnp.min(y_plot)

    # Initialize the best observed value and the arrays for visited points and logged values
    X = jnp.array([]).reshape(-1,1)
    y = jnp.array([]).reshape(-1,1)
    visited = jnp.array([], dtype=int)
    best_observed_value = jnp.array([0])
    best_observed_arr = best_observed_value
    acq_m = jnp.array([])
    gittins_m = jnp.array([])

    # Initialize the Gaussian process model
    prior = gpx.Prior(kernel=kernel, mean_function = gpx.mean_functions.Zero())

    # Choose the first query from the prior predictive distribution
    prior_dist = prior(x_plot)
    next_index, key, acq_value, gittins_value = _select_next_index(
        policy, key, prior_dist, prior_dist.mean(), prior_dist.stddev(),
        best_observed_value, visited, query_cost, bounds, num_points)
    if acq_value is not None:
        acq_m = jnp.append(acq_m, acq_value)
    if gittins_value is not None:
        gittins_m = jnp.append(gittins_m, gittins_value)

    next_point = x_plot[next_index]
    visited = jnp.append(visited, next_index)

    # Perform the Bayesian optimization
    for i in range(1, num_iterations + 1):

        # Evaluate the latent function at the selected point
        next_value = y_plot[next_index]

        # Update the best observed value
        if next_value < best_observed_value:
            best_observed_value = next_value
        best_observed_arr = jnp.append(best_observed_arr, best_observed_value)

        # Update the Gaussian process model with the new observation
        X = jnp.vstack([X, next_point])
        y = jnp.vstack([y, next_value])
        D = Dataset(X=X, y=y)

        # Construct the posterior
        likelihood = gpx.Gaussian(num_datapoints=D.n, obs_noise=False)
        posterior = prior * likelihood

        # Predict the mean and standard deviation of the GP model on the plot range
        y_dist = posterior(x_plot, train_data=D)
        y_mean = y_dist.mean()
        y_std = y_dist.stddev()
        if jnp.isnan(y_std).any():
            print("instance "+str(seed)+" has nan in std at iteration "+str(i))

        # Choose the next query from the posterior predictive distribution
        next_index, key, acq_value, gittins_value = _select_next_index(
            policy, key, y_dist, y_mean, y_std,
            best_observed_value, visited, query_cost, bounds, num_points)
        if acq_value is not None:
            acq_m = jnp.append(acq_m, acq_value)
        if gittins_value is not None:
            gittins_m = jnp.append(gittins_m, gittins_value)

        next_point = x_plot[next_index]
        visited = jnp.append(visited, next_index)

    return best_observed_arr - global_minimum_value, acq_m, gittins_m

In [10]:
# Specify policy and kernel hyperparameters
num_instances, num_samples = 400, 400
num_points, num_iterations = 1000, 16
query_cost, eps = 0.05, 0.0001
kernel = gpx.kernels.Matern12(variance=jnp.array([2.0]), lengthscale=jnp.array([1.0]))

# Time a single Surrogate-policy instance (the timing includes printing the result)
run_args = dict(seed=68, policy='Surrogate', num_points=num_points,
                num_iterations=num_iterations, query_cost=query_cost, kernel=kernel)
start = time.time()
print(BayesianOptimization(**run_args))
end = time.time()
print("instance running time:", end-start)

(Array([3.3344145, 2.542779 , 2.542779 , 2.542779 , 2.4354377, 2.4354377,
       2.4354377, 2.4354377, 2.4354377, 2.4354377, 2.4354377, 2.4354377,
       2.4354377, 2.4354377, 2.2058485, 2.2058485, 2.2058485],      dtype=float32), Array([-2.0832534, -2.153221 , -2.1305835, -2.117302 , -2.160713 ,
       -2.12031  , -2.079728 , -2.048802 , -1.991608 , -2.0227199,
       -1.9868917, -1.975756 , -1.9606185, -1.947784 , -2.0252683,
       -2.0209954, -1.9594394], dtype=float32), Array([-2.0039062, -2.119751 , -2.119995 , -2.119995 , -2.15271  ,
       -2.15271  , -2.1068115, -2.1067505, -2.0187378, -2.0715942,
       -2.0715942, -2.0715942, -2.0715942, -2.0612183, -2.0612183,
       -2.0612183, -1.9783325], dtype=float32))
instance running time: 29.175341844558716


In [30]:
# Specify policy and kernel hyperparameters
num_instances = 400
num_samples = 400
num_points = 1000
num_iterations = 40
query_cost = 0.05
eps = 0.0001
kernel = gpx.kernels.Matern12(lengthscale = jnp.array([1.0]), variance = jnp.array([2.0]))

# Collect one row per instance in Python lists and stack once after the loop;
# re-concatenating the full result array on every iteration copies it each time.
regret_rows = []
price_min_rows = []
gittins_min_rows = []

for i in range(num_instances):
    start = time.time()

    Surrogate_regret, surrogate_price_min, surrogate_gittins_min = BayesianOptimization(seed=i, policy='Surrogate', num_points=num_points, num_iterations=num_iterations, query_cost = query_cost, kernel = kernel)
    regret_rows.append(Surrogate_regret)
    price_min_rows.append(surrogate_price_min)
    gittins_min_rows.append(surrogate_gittins_min)

    end = time.time()
    print(i, end-start)

# The leading empty block fixes the (0, num_iterations+1) shape even when num_instances == 0
empty_block = jnp.empty((0, num_iterations+1))
Surrogate_regret_arr = jnp.vstack([empty_block, *regret_rows])
surrogate_price_min_arr = jnp.vstack([empty_block, *price_min_rows])
surrogate_gittins_min_arr = jnp.vstack([empty_block, *gittins_min_rows])

# Write each result with a context manager so the file handle is flushed and closed
results_to_save = {
    "Surrogate regrets (Matern12, eps=0.0001, c=0.05).pkl": Surrogate_regret_arr,
    "minimum expected surrogate price (Matern12, eps=0.0001, c=0.05).pkl": surrogate_price_min_arr,
    "minimum surrogate gittins index (Matern12, eps=0.0001, c=0.05).pkl": surrogate_gittins_min_arr,
}
for out_path, out_arr in results_to_save.items():
    with open(out_path, "wb") as out_file:
        pkl.dump(out_arr, out_file)

0 11.167484045028687
1 13.15595006942749
2 10.749859094619751
3 12.535704851150513
4 13.3438880443573
5 11.10144567489624
6 11.306470155715942
7 11.364187955856323
8 12.741220951080322
9 11.800350904464722
10 11.796929836273193
11 12.464479923248291
12 12.906428098678589
13 11.325217008590698
14 11.446501970291138
15 11.714879989624023
16 11.674442052841187
17 11.26215410232544
18 11.779779195785522
19 11.41185212135315
20 11.152118682861328
21 11.500181913375854
22 11.88401484489441
23 13.173651933670044
24 16.09651470184326
25 13.350827693939209
26 12.259694814682007
27 13.847230911254883
28 12.105602979660034
29 12.791733026504517
30 12.685868978500366
31 11.162364959716797
32 11.439561367034912
33 11.815701007843018
34 12.097519159317017
35 11.957125186920166
36 12.148561239242554
37 12.550124645233154
38 12.42084288597107
39 11.584143161773682
40 12.623608827590942
41 13.473184823989868
42 13.409212112426758
43 12.32870602607727
44 12.823271751403809
45 12.01823902130127
46 12.523

364 12.141771078109741
365 12.6995530128479
366 12.358787059783936
367 12.175400972366333
368 12.715258121490479
369 14.043200969696045
370 14.037318229675293
371 14.077093124389648
372 12.209269046783447
373 13.01723313331604
374 12.177540063858032
375 11.966993808746338
376 12.209106922149658
377 11.606237888336182
378 13.310523986816406
379 12.243335247039795
380 11.71933126449585
381 12.266383171081543
382 11.737051725387573
383 11.860651016235352
384 12.139063119888306
385 11.693962097167969
386 12.024677276611328
387 12.182634115219116
388 12.665414810180664
389 11.798438310623169
390 14.156555652618408
391 12.807827949523926
392 17.28194308280945
393 12.195995807647705
394 12.626560926437378
395 11.963735103607178
396 12.006555080413818
397 12.010059833526611
398 12.038238048553467
399 12.19471287727356


In [36]:
# Row 370 of the regret array, i.e. the regret values returned for seed=370 in the sweep
Surrogate_regret_arr[370]

Array([4.1031523 , 3.4726357 , 3.4726357 , 3.0364294 , 3.0364294 ,
       3.0364294 , 0.4506507 , 0.4506507 , 0.4506507 , 0.4506507 ,
       0.4506507 , 0.4506507 , 0.28393698, 0.10778761, 0.10778761,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        ], dtype=float32)

In [31]:
# Row 370 of the logged acquisition values (second return value) for seed=370 in the sweep
surrogate_price_min_arr[370]

Array([-2.0798616, -2.1049137, -2.0716991, -2.2079206, -2.1104007,
       -2.219521 , -3.9100766, -3.8304431, -3.7522743, -3.733704 ,
       -3.7301826, -3.705711 , -3.8564153, -4.025452 , -4.0074368,
       -4.1141124, -4.116499 , -4.108186 , -4.11087  , -4.1103253,
       -4.1052876, -4.1050034, -4.1089587, -4.104845 , -4.103776 ,
       -4.1042924, -4.1057906, -4.1058264, -4.104532 , -4.103234 ,
       -4.1031284, -4.10671  , -4.1057115, -4.1031437, -4.103169 ,
       -4.103961 , -4.10308  , -4.1031084, -4.1031017, -4.103865 ,
       -4.10316  ], dtype=float32)