In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# standard library
import sys
import time

# third-party
import corner
import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize
from getdist import plots,MCSamples,loadMCSamples
from jax import random,vmap, grad
from jax.lax import while_loop, map  # NOTE: this `map` shadows the Python builtin
from jaxns import NestedSampler
from scipy.stats import multivariate_normal
from scipy.stats import qmc

# modules resolved through ./JaxBo (path is relative to the working directory);
# the path tweak must run before these imports
sys.path.append('./JaxBo')
from acquisition import EI, IPV, optim_scipy_bh
from fb_gp import saas_fbgp
from nested_sampler import nested_sampling_jaxns, nested_sampling_Dy


# Notebook-wide matplotlib defaults: serif font at size 16, legend font size 16,
# and a default figure size of (6, 3).
for rc_group, rc_settings in (
    ('font', {'size': 16, 'family': 'serif'}),
    ('legend', {'fontsize': 16}),
    ('figure', {'figsize': (6, 3)}),
):
    matplotlib.rc(rc_group, **rc_settings)

ModuleNotFoundError: No module named 'fb_gp'

In [None]:
# Build a synthetic target (log-density of a correlated Gaussian on the unit cube),
# evaluate it on a Sobol initial design, and fit the GP to those points.
np_seed = 10004118
np.random.seed(np_seed) # fixed for reproducibility (covers the np.random.* draws below)
ninit = 16
ndim = 4

# scipy's QMC engines build their own generator and are not controlled by
# np.random.seed, so the scrambled Sobol design needs an explicit seed to be
# reproducible across runs.
train_x = qmc.Sobol(ndim, scramble=True, seed=np_seed).random(ninit)

# Target Gaussian: mean truncated to the first `ndim` entries, covariance made
# symmetric by adding the random matrix to its own transpose.
f_mean = np.array([0.41,0.63,0.58,0.45])[:ndim]
cov1 = 0.1*np.eye(ndim) + 0.2*np.random.uniform(0.,0.5,(ndim,ndim))
cov = 0.05*(cov1 + np.transpose(cov1))
print(cov)

mnorm = multivariate_normal(mean=f_mean, cov=cov)


def f(x):
    """Return the target log-density at `x`, reshaped to a column (-1, 1)."""
    return mnorm.logpdf(x).reshape(-1,1)


train_y = f(train_x)
print(train_x.shape,train_y.shape)

# NOTE: train_yvar is not passed to saas_fbgp below; the GP receives the same
# value through the scalar `noise` argument instead.
train_yvar = 1e-6*jnp.ones_like(train_y)

print("Testing lightweight implementation")

gp = saas_fbgp(train_x,train_y,noise=1e-6)
seed = 0
# Only the first of the two split keys is used for fitting.
rng_key, _ = random.split(random.PRNGKey(seed), 2)
gp.fit(rng_key,warmup_steps=512,num_samples=512,thinning=16,verbose=True)



In [None]:
# Uniform random test points in the unit cube; the column count follows `ndim`
# so it stays consistent with the training inputs if the dimension is changed.
test_x = np.random.uniform(0,1,size=(9999,ndim))

In [None]:
# Time three ways of evaluating the GP posterior on `test_x` and check that they agree.
#
# JAX dispatches work asynchronously, so every timed region waits on its result
# with jax.block_until_ready(...) before the clock is read; otherwise the timer
# could stop before the computation has finished. Each region is timed once, so
# the numbers may include one-off tracing/compilation cost.

batch_size = 10  # shared by the lax.map and the manual batching variants


def posterior_single(x):
    """Call gp.posterior(x, single=True) and return its (mean, var) result."""
    return gp.posterior(x, single=True)


# direct method, not recommended
start = time.perf_counter()
mean1, var1 = jax.block_until_ready(posterior_single(test_x))
print(f"direct took {time.perf_counter()-start:.4f} s")

# with jax.lax.map (spelled out in full: the bare name `map` is shadowed by an import)
start = time.perf_counter()
mean2, var2 = jax.block_until_ready(
    jax.lax.map(posterior_single, test_x, batch_size=batch_size)
)
print(f"Map took {time.perf_counter()-start:.4f} s")


# in batches
start = time.perf_counter()
num_inputs = len(test_x)
num_batches = (num_inputs + batch_size - 1 ) // batch_size  # ceil division
input_arrays = (test_x,)
# index ranges per batch; the last one is clipped to num_inputs
batch_idxs = [jnp.arange( i*batch_size, min( (i+1)*batch_size,num_inputs  )) for i in range(num_batches)]
res = [posterior_single(*(arr[idx] for arr in input_arrays)) for idx in batch_idxs]
print(len(res))
nres = len(res[0])
# now combine results across batches and function outputs to return a tuple (num_outputs, num_inputs, ...)
results = jax.block_until_ready(
    tuple(jnp.concatenate([out[i] for out in res]) for i in range(nres))
)
print(f"Batchwise took {time.perf_counter()-start:.4f} s")

mean3, var3 = results

# Agreement checks; the lax.map outputs carry a trailing axis that is squeezed away.
print(np.allclose(mean1,mean2.squeeze(-1)))
print(np.allclose(mean1,mean3))

print(np.allclose(var1,var2.squeeze(-1)))
print(np.allclose(var1,var3))

In [None]:
# Difference of each batched posterior mean from the direct evaluation.
# Both curves use the same sign convention (batched minus direct) and are
# labelled so they can be told apart.
fig, ax = plt.subplots()
point_idx = range(len(test_x))
ax.plot(point_idx, mean2.squeeze(-1) - mean1, label="lax.map - direct")
ax.plot(point_idx, mean3 - mean1, label="batchwise - direct")
ax.set_xlabel("test point index")
ax.set_ylabel("posterior mean difference")
ax.legend()
plt.show()
