In [2]:
import jax.numpy as jnp
from jax.scipy.stats import norm
from jax.scipy.optimize import minimize
import gpjax as gpx

## Method 1: Quadratic expected loss + BFGS

In [93]:
def Expected_loss(mean, std, alt, cost):
    '''
    Squared gap (E[(g-V)^+] - c)^2 for a normal random variable V ~ N(mean, std^2),
    where g = alt and c = cost.

    Pandora's box reading: g is the outside value (termination cost), V is the box
    value and c is the inspection cost.
    Bayesian optimization reading: g is the best value observed so far, V is drawn
    from the predictive distribution and c is the query cost.

    Returns the squared gap with singleton dimensions squeezed out.
    '''
    gap = alt - mean
    z_score = gap / std
    # Gaussian closed form: E[(g-V)^+] = (g - mu) * Phi(z) + sigma * phi(z)
    shortfall = gap * norm.cdf(z_score) + std * norm.pdf(z_score)
    residual = shortfall - cost
    return jnp.squeeze(residual ** 2)

In [40]:
def EI(mean, std, best):
    '''
    EI policy: the c^* that minimizes (E[(g-V)^+] - c)^2 given the best observed
    value g = best and a normal random variable V ~ N(mean, std^2).

    The objective is a parabola in c, so its minimizer is exactly
    c^* = E[(g-V)^+] = (g - mu) * Phi(z) + sigma * phi(z), with z = (g - mu) / sigma.
    It is evaluated in closed form for all points at once instead of running one
    numerical optimization per point.

    Input: mean, std and best may be scalars or arrays; each is flattened and the
           three are broadcast against each other.
    Output: 1-D array with one expected-improvement value per point.
    '''
    mean = jnp.ravel(jnp.asarray(mean))
    std = jnp.ravel(jnp.asarray(std))
    best = jnp.ravel(jnp.asarray(best))

    gap = best - mean

    # A point with zero variance is deterministic, so E[(g-V)^+] = max(g - mu, 0).
    # Dividing by a placeholder std of 1 there keeps z finite (no 0/0 -> NaN).
    degenerate = std == 0
    safe_std = jnp.where(degenerate, 1.0, std)
    z = gap / safe_std

    smooth = gap * norm.cdf(z) + std * norm.pdf(z)
    return jnp.where(degenerate, jnp.maximum(gap, 0.0), smooth)

In [94]:
def Gittins(mean, std, cost):
    '''
    Gittins policy: for every point, find the g^* that minimizes
    (E[(g-V)^+] - c)^2 given the inspection/query cost c = cost and a normal
    random variable V ~ N(mean, std^2). Each point is solved with its own BFGS run.

    Input: mean, std and cost may be scalars or arrays; each is flattened, and std
           and cost are broadcast to the flattened shape of mean (a ValueError is
           raised if they cannot be broadcast).
    Output: 1-D array with one index value per point.

    BFGS is started at the point's own mean rather than at 0. The gradient of the
    squared objective is 2 * (E[(g-V)^+] - c) * Phi((g - mu) / sigma); far below
    the mean Phi underflows to 0, so a start at 0 makes BFGS report convergence
    immediately whenever the mean lies well above 0.
    '''
    mean = jnp.ravel(jnp.asarray(mean, dtype=float))
    std = jnp.broadcast_to(jnp.ravel(jnp.asarray(std, dtype=float)), mean.shape)
    cost = jnp.broadcast_to(jnp.ravel(jnp.asarray(cost, dtype=float)), mean.shape)

    gi = jnp.zeros(mean.size)

    # Plain Python range: iterating over a jnp.arange would yield one device array
    # per step just to serve as an index.
    for i in range(mean.size):
        mu_i, sigma_i, c_i = mean[i], std[i], cost[i]
        g0 = jnp.reshape(mu_i, (1,))  # minimize expects a 1-D floating-point start
        solution = minimize(
            lambda g: Expected_loss(mean=mu_i, std=sigma_i, alt=g, cost=c_i),
            g0,
            method="BFGS",
        )
        gi = gi.at[i].set(solution.x[0])

    return gi

In [5]:
# Two identical boxes: zero mean, unit standard deviation, unit cost, best value 1
mu = jnp.zeros(2)
sigma = c = g = jnp.ones(2)

# Evaluate both policies on the same inputs
gi = Gittins(mean=mu, std=sigma, cost=c)
ei = EI(mean=mu, std=sigma, best=g)

# Print the result
print("The solution for Gittins Index g* is:", gi)
print("The expected improvement is:", ei)

The solution for Gittins Index g* is: [0.89947134 0.89947134]
The expected improvement is: [1.0833155 1.0833155]


In [95]:
# Single box with zero mean, unit standard deviation and a small cost
mu = jnp.array([0])
sigma = jnp.array([1.0])
c = jnp.array([0.001])

# Only the Gittins index is computed in this cell
gi = Gittins(mean=mu, std=sigma, cost=c)

# Print the index together with the objective value Expected_loss attains there
print("The solution for Gittins Index g* is:", gi)
print("The expected loss of Gittin Index g* is:", Expected_loss(mean=mu, std=sigma, alt=gi, cost=c))

The solution for Gittins Index g* is: [-2.5976803]
The expected loss of Gittin Index g* is: 2.2536209e-07


## Method 2: Absolute expected loss + Bisection

In [87]:
def ExpectedLoss(mean, std, alt, cost):
    '''
    This function calculates E[(g-V)^+]-c for given values g, c and a normal random variable V~N(mu, sigma^2)
    Input: mean and standard deviation of the normal random variable V, alternative g, cost c
           (scalars or arrays that broadcast against each other)
    Output: the signed difference E[(g-V)^+]-c

    For sigma > 0 the closed form (g-mu)*Phi(z) + sigma*phi(z), z = (g-mu)/sigma, is used.
    For sigma == 0 the variable is deterministic, so E[(g-V)^+] = max(g-mu, 0).
    '''

    gap = alt - mean

    # Divide by a placeholder std of 1 at degenerate points so z stays finite
    # there (no 0/0); those entries are overwritten by the final where below.
    degenerate = std == 0
    safe_std = jnp.where(degenerate, 1.0, std)
    z = gap / safe_std

    smooth = gap * jnp.nan_to_num(norm.cdf(z)) + std * jnp.nan_to_num(norm.pdf(z))

    # With zero variance the expected improvement is just the positive part of the
    # gap; returning 0 regardless of alt would leave the loss stuck at -cost.
    expected_improvement = jnp.where(degenerate, jnp.maximum(gap, 0.0), smooth)

    exp_loss = expected_improvement - cost

    return exp_loss

In [88]:
# Define the acquisition function of the Gittins policy
def Gittins_acq(mean, std, cost, bounds=None, tol=None, max_iter=200):
    '''
    Gittins index via bisection: for every point, search for the g at which
    ExpectedLoss(mean, std, g, cost) = E[(g-V)^+] - c changes sign.

    Input: mean, std, cost - per-point parameters passed straight to ExpectedLoss
           bounds   - optional [[low, high]] bracket (one row, or one row per point);
                      when None the notebook-level variable `bound` is used
           tol      - optional stopping tolerance on |ExpectedLoss|;
                      when None the notebook-level variable `eps` is used
           max_iter - maximum number of bisection steps
    Output: array with the final bracket midpoint of every point

    The search stops as soon as the largest |ExpectedLoss| over all points (NaNs
    ignored) falls below the tolerance. If that has not happened after max_iter
    steps - for example because a bracket does not contain a sign change, or the
    tolerance is below what floating point can resolve - a RuntimeWarning is issued
    and the last midpoints are returned instead of looping forever.
    '''
    import warnings

    size = jnp.size(mean)

    # Explicit arguments win; otherwise keep reading the notebook globals so that
    # existing calls Gittins_acq(mean=..., std=..., cost=...) behave as before.
    search_bounds = bound if bounds is None else jnp.atleast_2d(jnp.asarray(bounds))
    threshold = eps if tol is None else tol

    l = search_bounds[:, 0] * jnp.ones(size)
    h = search_bounds[:, 1] * jnp.ones(size)
    m = (h + l) / 2

    # Bisection method
    for _ in range(max_iter):
        loss_m = ExpectedLoss(mean, std, m, cost)
        # Written as "not >=" so that an all-NaN loss also ends the search.
        if not (jnp.nanmax(jnp.abs(loss_m)) >= threshold):
            return m

        sgn_l = jnp.sign(ExpectedLoss(mean, std, l, cost))
        sgn_m = jnp.sign(loss_m)
        sgn_h = jnp.sign(ExpectedLoss(mean, std, h, cost))

        # Exact root at the midpoint: collapse the bracket onto it.
        l = jnp.where(sgn_m == 0, m, l)
        h = jnp.where(sgn_m == 0, m, h)
        # Otherwise move whichever end shares the midpoint's sign.
        l = jnp.where(sgn_l == sgn_m, m, l)
        h = jnp.where(sgn_h == sgn_m, m, h)
        m = (h + l) / 2

    if jnp.nanmax(jnp.abs(ExpectedLoss(mean, std, m, cost))) >= threshold:
        warnings.warn(
            "Gittins_acq: bisection did not reach the tolerance within max_iter "
            "steps; returning the last midpoints.",
            RuntimeWarning,
        )

    return m

In [92]:
# Single box with zero mean, unit standard deviation and a small cost
mu = jnp.array([0])
sigma = jnp.array([1.0])
c = jnp.array([0.001])

# Notebook-level bisection settings that Gittins_acq reads: search bracket and tolerance
bound = jnp.array([[-10, 10]])
eps = 1e-5

# Compute the Gittins index by bisection
gi = Gittins_acq(mean=mu, std=sigma, cost=c)

# Print the index together with the signed loss ExpectedLoss attains there
print("The solution for Gittins Index g* is:", gi)
print("The expected loss of Gittin Index g* is:", ExpectedLoss(mean=mu, std=sigma, alt=gi, cost=c))

The solution for Gittins Index g* is: [-2.7148438]
The expected loss of Gittin Index g* is: [9.7734155e-06]
