In [29]:
import jax.numpy as jnp
from jax import random

from jaxgp.utils import Logger
from jaxgp.regression import SparseGPR, ExactGPR
from jaxgp.kernels import RBF

# from jaxgp.jaxgp.utils import Logger
# from jaxgp.jaxgp.regression import SparseGPR, ExactGPR
# from jaxgp.jaxgp.kernels import RBF

In [23]:
def fun(x, noise=0.0, key = random.PRNGKey(0)):
    """Evaluate Himmelblau's function, scaled by 1/800, with optional noise.

    Args:
        x: 2-D array indexed as ``x[:, 0]`` and ``x[:, 1]`` (one point per row).
        noise: factor multiplying the standard-normal samples added to the
            result; ``0.0`` gives the noise-free function.
        key: JAX PRNG key used to draw the noise samples.

    Returns:
        1-D array with one function value per row of ``x``.
    """
    x0, x1 = x[:, 0], x[:, 1]

    # the two squared residuals of Himmelblau's function, each scaled on its own
    first_term = (x0**2 + x1 - 11)**2 / 800.0
    second_term = (x0 + x1**2 - 7)**2 / 800.0

    # one noise sample per input row
    jitter = random.normal(key, (len(x),), dtype=jnp.float32) * noise

    return first_term + second_term + jitter

def grad(x, noise=0.0, key = random.PRNGKey(0)):
    """Evaluate the analytic gradient of the scaled Himmelblau function.

    Args:
        x: 2-D array indexed as ``x[:, 0]`` and ``x[:, 1]`` (one point per row).
        noise: factor multiplying the standard-normal samples added to every
            gradient component; ``0.0`` gives the noise-free gradient.
        key: JAX PRNG key used to draw the noise samples.

    Returns:
        Array with the same shape as ``x``; column 0 holds the derivative
        with respect to ``x[:, 0]``, column 1 the one with respect to
        ``x[:, 1]``.
    """
    x0, x1 = x[:, 0], x[:, 1]

    # residuals shared by both partial derivatives
    res_a = x0**2 + x1 - 11
    res_b = x0 + x1**2 - 7

    d_x0 = 4 * res_a * x0 + 2 * res_b
    d_x1 = 2 * res_a + 4 * res_b * x1

    # one noise sample per gradient component
    jitter = random.normal(key, x.shape, dtype=jnp.float32) * noise

    return jnp.stack((d_x0, d_x1), axis=1) / 800.0 + jitter

In [24]:
def create_training_data(num_f_vals, num_d_vals, dims, bounds, noise, fun, grad, seed=0):
    """Sample random training inputs and evaluate noisy labels on them.

    Args:
        num_f_vals: number of points at which ``fun`` is evaluated.
        num_d_vals: number of points at which ``grad`` is evaluated.
        dims: number of columns of every sampled point.
        bounds: ``bounds[0]`` / ``bounds[1]`` are passed as ``minval`` /
            ``maxval`` of the uniform sampler.
        noise: forwarded as second argument to ``fun`` and ``grad``.
        fun: callable ``fun(x, noise, key)`` producing function labels.
        grad: callable ``grad(x, noise, key)`` producing derivative labels.
        seed: integer seed of the PRNG key chain.

    Returns:
        ``((x_func, x_der), y)`` where ``y`` is the function labels followed
        by the flattened derivative labels.
    """
    lower, upper = bounds[0], bounds[1]

    def _sample_points(sample_key, count):
        # uniform draw of `count` points inside the given bounds
        return random.uniform(sample_key, (count, dims), minval=lower, maxval=upper)

    # The keys are split one after another from the same chain; the order of
    # the four draws below determines which key each quantity receives.
    chain, k_x_func = random.split(random.PRNGKey(seed))
    x_func = _sample_points(k_x_func, num_f_vals)

    chain, k_x_der = random.split(chain)
    x_der = _sample_points(k_x_der, num_d_vals)

    chain, k_y_func = random.split(chain)
    y_func = fun(x_func, noise, k_y_func)

    chain, k_y_der = random.split(chain)
    y_der = grad(x_der, noise, k_y_der)

    labels = jnp.hstack((y_func, y_der.reshape(-1)))
    return (x_func, x_der), labels

def create_reference_points(X_train, subset_size, seed):
    """Pick a random subset of the training inputs as reference points.

    Args:
        X_train: sequence ``(x_func, x_der)`` of 2-D arrays with the same
            number of columns, in the form returned by
            ``create_training_data``.
        subset_size: fraction of all training points to select; the number
            of selected points is ``int(N * subset_size + 1)`` (capped at
            ``N`` by the slice below), with ``N`` the total number of rows.
        seed: integer seed of the PRNG key used for the permutation.

    Returns:
        2-D array whose rows are the selected training points.
    """
    # X_train is a tuple of arrays, which cannot be indexed with an index
    # array directly; stack function and derivative inputs into one array.
    X_all = jnp.vstack(X_train)

    num_training_points = X_all.shape[0]
    num_ref_points = int(num_training_points * subset_size + 1)

    key = random.PRNGKey(seed)
    # random order of all row indices, truncated to the requested count
    ref_perm = random.permutation(key, num_training_points)[:num_ref_points]

    return X_all[ref_perm]

In [27]:
# ---- output naming ----
# prefix of every saved result file
name = "him"
# directory the result files are written into
in_dir = "."

# ---- randomness ----
seed = 0
key = random.PRNGKey(seed)

# ---- sampling domain ----
# bounds[0] is the lower corner, bounds[1] the upper corner of the sampling box
bounds = jnp.array([[-5.0, -5.0], [5.0, 5.0]])
# number of columns of every data point
dims = 2

# ---- experiment grid ----
# numbers of function observations to try
list_f_vals = [1, 5, 20, 50]
# numbers of derivative observations to try
list_d_vals = [200, 400, 800, 1500, 2000, 3000]
# noise level passed to the data generators and to the models
noise = 0.1
# optimizer names passed as `optimize_method` to the models
optimizers = ["L-BFGS-B", "TNC", "SLSQP"]

# ---- evaluation grid ----
# 100 equally spaced values per dimension, one row per dimension ...
grid_axes = jnp.linspace(bounds[0], bounds[1], 100).T
# ... combined into all 100*100 grid points, one point per row
eval_grid = jnp.array(jnp.meshgrid(*grid_axes)).reshape(2, -1).T

# ---- model setup ----
# shape and sampling range of the randomly drawn initial parameters
param_shape = (2,)
param_bounds = (1e-3, 10.0)
kernel = RBF()

# ---- sparsification ----
# when True, SparseGPR with a random subset of reference points is used
sparse = False
# fraction of the training points used as reference points
subset_size = 0.1

In [28]:
# Sweep over every (function count, derivative count, optimizer) combination,
# fit a GPR model for each and store predictions plus the logged iterates.
separator = "-" * 80

for i, num_f_vals in enumerate(list_f_vals):
    print(separator)
    print(f"Number function evals: {num_f_vals}")

    for j, num_d_vals in enumerate(list_d_vals):
        print(separator)
        print(f"Number derivative evals: {num_d_vals}")

        # training set for this pair of counts, seeded with i*j
        X_train, Y_train = create_training_data(
            num_f_vals, num_d_vals, dims, bounds, noise, fun, grad, i*j
        )

        # draw initial parameters from the running key; all optimizers below
        # start from the same values
        key, subkey = random.split(key)
        init_params = random.uniform(
            subkey, param_shape, minval=param_bounds[0], maxval=param_bounds[1]
        )

        # reference points are only needed (and only defined) in sparse mode
        if sparse:
            X_ref = create_reference_points(X_train, subset_size, i*j)

        for optimizer in optimizers:
            print(separator)
            print(f"Optimizer: {optimizer}")

            logger = Logger(optimizer)

            model = (
                SparseGPR(kernel, init_params, noise, X_ref, optimize_method=optimizer, logger=logger)
                if sparse
                else ExactGPR(kernel, init_params, noise, optimize_method=optimizer, logger=logger)
            )

            model.train(X_train, Y_train)
            means, stds = model.eval(eval_grid)

            # common file stem; sparse runs get an extra suffix
            fname = f"{in_dir}/{name}_d{num_d_vals}_f{num_f_vals}_{optimizer}"
            if sparse:
                fname = f"{fname}_sparse{subset_size}"

            # NOTE(review): the star-unpacking stores every leading-axis entry
            # as its own positional array in the archive - confirm that this
            # file layout is what the loading code expects.
            jnp.savez(f"{fname}_means.npz", *means)
            jnp.savez(f"{fname}_stds.npz", *stds)

            # entries recorded by the logger during training
            params = list(logger.iters_list)
            jnp.savez(f"{fname}_params.npz", *params)

(10000, 2)