In [1]:
import time
from pathlib import Path

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

import jaxgp.regression as gpr
from jaxgp.kernels import RBF, Linear
from jaxgp.tests import testfunctions, optimizertesting
from jaxgp.utils import Logger

### 1D example (disabled: every cell in this section is commented out)

In [13]:
# from jax import vmap, jit, grad
# sigma = 0.15

# def f(x):
#     return jnp.exp(-(x / sigma)**2) + 0.5*jnp.exp(-0.5*((x-0.5) / sigma)**2)

# def df(x):
#     return (-2*x*jnp.exp(-(x / sigma)**2) + -0.5*(x-0.5)*jnp.exp(-0.5*((x-0.5) / sigma)**2))/(sigma**2)

In [14]:
# seed = 0

# noise = 0.1
# ranges = jnp.array([0.0, 1.0])
# num_datapoints = 1000
# X_train = jnp.linspace(*ranges, num_datapoints).reshape(-1,1)
# y = f(X_train)
# dy = df(X_train)
# Y_train = jnp.hstack((y, dy))

# iters_per_optimizer = 1

# function_set_sizes = [1,]
# # derivative_set_sizes = [2,5,7,10,12,15,20, 100]
# derivative_set_sizes = [10,]

# kernel = RBF()
# param_shape = (2,)
# param_bounds = (1e-3, 10.0)

# grid = jnp.linspace(0,1,100).reshape(-1,1)

In [15]:
# # key = random.PRNGKey(int(time.time()))
# key = random.PRNGKey(0)

# means = []
# stds = []

# for fun_vals in function_set_sizes:
#     for der_vals in derivative_set_sizes:
#         # logger for each pair of function vals and derivative vals
#         logger = Logger(f"f{fun_vals}d{der_vals}")

#         key, subkey = random.split(key)
#         fun_perm = random.permutation(subkey, num_datapoints)[:fun_vals]
#         key, subkey = random.split(key)
#         d1_perm = random.permutation(subkey, num_datapoints)[:der_vals]

#         X_fun = X_train[fun_perm]
#         Y_fun = Y_train[fun_perm,0]
#         X_d1 = X_train[d1_perm]
#         Y_d1 = Y_train[d1_perm,1]

#         X = jnp.vstack((X_fun, X_d1))
#         Y = jnp.hstack((Y_fun, Y_d1))
#         data_split = jnp.array([fun_vals, der_vals])

#         for i in range(iters_per_optimizer):
#             key, subkey = random.split(key)
#             init_params = random.uniform(subkey, param_shape, minval=param_bounds[0], maxval=param_bounds[1])
#             logger.log(f"# iter {i+1}: init params {init_params}")

#             model = gpr.ExactGPR(kernel, init_params, noise, logger=logger)
#             model.train(X, Y, data_split=data_split)
#             m, s = model.eval(grid)
#             means.append(m)
#             stds.append(s)

In [16]:
# plt.plot(X_train, Y_train[:,0])
# for der,mean in zip(derivative_set_sizes,means):
#     plt.plot(grid, mean, label=f"{der} der obs")

# plt.grid()
# plt.legend()

### 2D: comparing optimizers on benchmark test functions

In [18]:
# Configuration for the 2D optimizer benchmark.

# Optimizers to compare. The run output reports ScipyMinimizeInfo, so these are
# scipy.optimize method names. "Nelder-Mead", "Powell" and "trust-constr" are
# currently disabled.
optimizers = ["L-BFGS-B", "TNC", "SLSQP"]

# Training-data settings (passed to optimizertesting.create_training_data_2D).
seed = 0
num_gridpoints = jnp.array([100, 100])  # presumably points per axis -- TODO confirm
noise = 0.1

# Passed as `iters` to optimizertesting.create_test_data_2D.
iters_per_optimizer = 5

# Passed to create_test_data_2D; presumably the number of function-value and
# derivative observations used per fit -- TODO confirm.
num_f_vals = 20
num_d_vals = 100

# GP kernel and the shape/bounds of its hyperparameter vector.
kernel = RBF(3)
param_shape = (3,)
param_bounds = (1e-3, 10.0)

# Test functions, listed in the same order as their names and domains.
names = ["franke", "himmelblau", "easom", "ackley", "sin"]
functions = [
    testfunctions.franke,
    testfunctions.himmelblau,
    testfunctions.easom,
    testfunctions.ackley,
    testfunctions.sin2d,
]

# Each domain is square: both axes share the same (low, high) interval.
ranges = [
    (jnp.array([low, high]), jnp.array([low, high]))
    for low, high in [(0.0, 1.0), (-5.0, 5.0), (-10.0, 10.0), (-5.0, 5.0), (0.0, 2 * jnp.pi)]
]

In [19]:
# Run every optimizer on every 2D test function.
# For each test function: build training data with
# optimizertesting.create_training_data_2D. Then, for each optimizer, call
# optimizertesting.create_test_data_2D, save the returned means/stds to
# ./prediction_files/, and keep the logger's iteration list in
# log_dict[name][optimizer].
# The test functions, their domains and all settings come from the config cell.
log_dict = {}

# jnp.savez does not create missing directories, so ensure the target exists.
prediction_dir = Path("./prediction_files")
prediction_dir.mkdir(parents=True, exist_ok=True)

# Evaluation points per axis; the prediction grid has num_evalpoints**2 points.
num_evalpoints = 100

for fun, ran, name in zip(functions, ranges, names):
    X_train, Y_train = optimizertesting.create_training_data_2D(seed, num_gridpoints, ran, noise, fun)

    # Evaluation grid over the function's domain, flattened to an (N, 2) array of points.
    # Use grid2 for the second axis; reusing grid1 is only correct when ran[0] == ran[1].
    grid1 = jnp.linspace(*ran[0], num_evalpoints)
    grid2 = jnp.linspace(*ran[1], num_evalpoints)
    grid = jnp.array(jnp.meshgrid(grid1, grid2)).reshape(2, -1).T

    log_dict[name] = {}

    for optimizer in optimizers:
        print(f"{name}: optimizer {optimizer}")
        logger = Logger(optimizer)

        # Use the configured seed instead of int(time.time()). This keeps the notebook
        # reproducible and gives every optimizer the same random draw, so the comparison
        # is like-for-like. (The seed presumably drives data subsets / initial params
        # inside create_test_data_2D -- confirm in optimizertesting.)
        means, stds = optimizertesting.create_test_data_2D(
            X_train=X_train,
            Y_train=Y_train,
            num_f_vals=num_f_vals,
            num_d_vals=num_d_vals,
            logger=logger,
            kernel=kernel,
            param_bounds=param_bounds,
            param_shape=param_shape,
            noise=noise,
            optimizer=optimizer,
            iters=iters_per_optimizer,
            evalgrid=grid,
            seed=seed,
        )

        jnp.savez(prediction_dir / f"{name}means{optimizer}", *means)
        jnp.savez(prediction_dir / f"{name}stds{optimizer}", *stds)
        log_dict[name][optimizer] = logger.iters_list

Optimizer L-BFGS-B
OptStep(params=DeviceArray([3.4470170e+00, 3.4372774e-03, 3.2955947e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.04120548, dtype=float32, weak_type=True), success=False, status=2, iter_num=3))
OptStep(params=DeviceArray([4.852865 , 3.0606432, 0.0094106], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(nan, dtype=float32, weak_type=True), success=False, status=2, iter_num=4))
OptStep(params=DeviceArray([7.613104 , 8.64159  , 5.0780387], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(nan, dtype=float32, weak_type=True), success=False, status=2, iter_num=1))
OptStep(params=DeviceArray([5.8078685e+00, 1.8793218e+00, 1.0000000e-03], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.04120548, dtype=float32, weak_type=True), success=False, status=2, iter_num=3))
OptStep(params=DeviceArray([0.054658  , 0.16858527, 0.18628271], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.00937389, dtype=float32, we



OptStep(params=DeviceArray([7.2194009e+00, 3.1388154e+00, 1.0000000e-03], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.06334427, dtype=float32, weak_type=True), success=True, status=0, iter_num=16))
OptStep(params=DeviceArray([9.621718  , 0.22480027, 0.29448816], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.00562126, dtype=float32, weak_type=True), success=True, status=0, iter_num=5))




OptStep(params=DeviceArray([9.062745  , 9.21687   , 0.12443336], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.20317098, dtype=float32, weak_type=True), success=True, status=0, iter_num=20))




OptStep(params=DeviceArray([7.0362573e+00, 4.5591097e+00, 1.0384790e-03], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(nan, dtype=float32, weak_type=True), success=False, status=5, iter_num=9))




OptStep(params=DeviceArray([7.6072702e+00, 5.5085268e+00, 1.3537814e-03], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.08193273, dtype=float32, weak_type=True), success=True, status=0, iter_num=9))
Optimizer L-BFGS-B
OptStep(params=DeviceArray([4.5974770e+04, 1.7813433e+00, 1.7397712e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.02644118, dtype=float32, weak_type=True), success=True, status=0, iter_num=45))
OptStep(params=DeviceArray([1.4960025e+04, 1.4150007e+00, 1.3618716e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.0384546, dtype=float32, weak_type=True), success=True, status=0, iter_num=40))
OptStep(params=DeviceArray([3.9594707e+04, 1.9188782e+00, 1.7220256e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.02799315, dtype=float32, weak_type=True), success=True, status=0, iter_num=37))
OptStep(params=DeviceArray([2.9822912e+04, 1.7217273e+00, 1.7739794e+00], dtype=float32), state=ScipyMinimizeInfo(fun_v



OptStep(params=DeviceArray([5.4326441e+04, 2.1410148e+00, 2.0517983e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.02537092, dtype=float32, weak_type=True), success=True, status=0, iter_num=49))
OptStep(params=DeviceArray([5.7613867e+04, 2.1647227e+00, 1.6945231e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.02533889, dtype=float32, weak_type=True), success=True, status=0, iter_num=62))
OptStep(params=DeviceArray([3.4772261e+03, 6.9429404e-01, 1.9717389e-01], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.06285886, dtype=float32, weak_type=True), success=True, status=0, iter_num=49))




OptStep(params=DeviceArray([5.6558164e+04, 1.9015549e+00, 1.8510669e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.02484203, dtype=float32, weak_type=True), success=True, status=0, iter_num=59))
Optimizer L-BFGS-B
OptStep(params=DeviceArray([1.0000000e-03, 1.8078594e+01, 1.4050544e+01], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01777602, dtype=float32, weak_type=True), success=True, status=0, iter_num=7))
OptStep(params=DeviceArray([1.0000000e-03, 9.4393902e+00, 6.9908743e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01777841, dtype=float32, weak_type=True), success=True, status=0, iter_num=7))
OptStep(params=DeviceArray([1.0000000e-03, 1.0937326e+01, 6.3914599e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01777956, dtype=float32, weak_type=True), success=True, status=0, iter_num=7))
OptStep(params=DeviceArray([1.0000000e-03, 5.2541356e+00, 7.9779072e+00], dtype=float32), state=ScipyMinimizeInfo(fun



OptStep(params=DeviceArray([1.0000000e-03, 1.9137666e+02, 2.1978125e+01], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01753219, dtype=float32, weak_type=True), success=True, status=0, iter_num=20))
OptStep(params=DeviceArray([6.920517e-03, 3.037643e+00, 8.516461e+00], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01765677, dtype=float32, weak_type=True), success=True, status=0, iter_num=8))
OptStep(params=DeviceArray([ 0.0138563,  7.8575034, 11.311414 ], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.0176114, dtype=float32, weak_type=True), success=True, status=0, iter_num=15))
OptStep(params=DeviceArray([4.568066 , 9.208293 , 4.2315516], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01589334, dtype=float32, weak_type=True), success=True, status=0, iter_num=1))
Optimizer L-BFGS-B
OptStep(params=DeviceArray([35.881165  ,  0.55907923,  0.61841625], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.02593165, 



OptStep(params=DeviceArray([37.727367  ,  0.5745152 ,  0.67891794], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.02528484, dtype=float32, weak_type=True), success=True, status=0, iter_num=36))




OptStep(params=DeviceArray([10.971128 ,  0.4688965,  0.4911654], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.02791846, dtype=float32, weak_type=True), success=True, status=0, iter_num=31))
Optimizer L-BFGS-B
OptStep(params=DeviceArray([8.780338 , 1.1745005, 2.151539 ], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.0127355, dtype=float32, weak_type=True), success=True, status=0, iter_num=10))
OptStep(params=DeviceArray([3.5295033, 1.0750915, 1.949005 ], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01284959, dtype=float32, weak_type=True), success=True, status=0, iter_num=36))
OptStep(params=DeviceArray([3.5174818, 1.0734063, 1.9510517], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01284961, dtype=float32, weak_type=True), success=True, status=0, iter_num=20))
OptStep(params=DeviceArray([5.4342446, 1.123    , 2.047586 ], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01282274, dtype=float32, weak_type=



OptStep(params=DeviceArray([4.364044 , 1.0404484, 2.09354  ], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01172265, dtype=float32, weak_type=True), success=True, status=0, iter_num=8))
OptStep(params=DeviceArray([3.1055763, 1.0402701, 1.9282653], dtype=float32), state=ScipyMinimizeInfo(fun_val=DeviceArray(-0.01177331, dtype=float32, weak_type=True), success=True, status=0, iter_num=21))


In [20]:
# means = mean_dict["SLSQP"]
# plt.pcolormesh(grid1, grid2, means[0].reshape(len(grid1), len(grid2)))
# plt.colorbar()

In [21]:
# plt.pcolormesh(grid1, grid2, Y_train[:,0].reshape(100,100))# - means[0].reshape(len(grid1), len(grid2)))
# plt.colorbar()

In [22]:
# markers = ["o", "x", "+"]

# for marker, (key, value) in zip(markers,log_dict.items()):
#     for i,elem in enumerate(value):
#         nums = jnp.arange(1,len(elem[1])+1)
#         plt.scatter(nums, elem[1],label=f"{key}$_{i}$", marker=marker)

# plt.legend()
# plt.grid()