In [None]:
import os

import jax
import jax.numpy as jnp
from jax import random, jit

In [None]:
! git clone https://github.com/LukasEin/jaxgp.git

In [None]:
! pip install jaxopt

In [None]:
from timeit import repeat, timeit

from jaxgp.jaxgp.covar import full_covariance_matrix, sparse_covariance_matrix
from jaxgp.jaxgp.kernels import RBF


def fun(x, noise=0.0, key=random.PRNGKey(0)):
    """Himmelblau's function scaled by 1/800, with optional Gaussian noise.

    Args:
        x: array whose columns 0 and 1 are the two input coordinates; one row per point.
        noise: standard deviation multiplier applied to standard-normal noise.
        key: PRNG key used to draw the noise (one float32 sample per row of ``x``).

    Returns:
        1-D array with one (noisy) function value per row of ``x``.
    """
    x1 = x[:, 0]
    x2 = x[:, 1]
    first_term = (x1**2 + x2 - 11)**2
    second_term = (x1 + x2**2 - 7)**2
    eps = random.normal(key, (x.shape[0],), dtype=jnp.float32)
    return first_term / 800.0 + second_term / 800.0 + eps * noise

def grad(x, noise=0.0, key=random.PRNGKey(0)):
    """Gradient of the scaled Himmelblau function ``fun``, with optional Gaussian noise.

    Args:
        x: array whose columns 0 and 1 are the two input coordinates; one row per point.
        noise: standard deviation multiplier applied to standard-normal noise.
        key: PRNG key used to draw the noise (float32, same shape as ``x``).

    Returns:
        Array of the same shape as ``x``; row ``i`` holds ``(d/dx1, d/dx2)`` at point ``i``.
    """
    x1 = x[:, 0]
    x2 = x[:, 1]
    inner_a = x1**2 + x2 - 11
    inner_b = x1 + x2**2 - 7

    d_x1 = 4 * inner_a * x1 + 2 * inner_b
    d_x2 = 2 * inner_a + 4 * inner_b * x2

    gradient = jnp.stack((d_x1, d_x2), axis=1)
    return gradient / 800.0 + noise * random.normal(key, x.shape, dtype=jnp.float32)

# Constants
# Sampling interval [low, high] applied to both input coordinates in _train_data.
BOUNDS = jnp.array([-5.0, 5.0])
# Number of function-value observations in every generated training set.
NUM_F_VALS = 1
# Kernel object handed to the jaxgp covariance routines.
KERNEL = RBF()
# Kernel parameters passed alongside KERNEL (each log(2)).
# NOTE(review): presumably log-scale hyperparameters of RBF -- confirm against jaxgp.
KERNEL_PARAMS = jnp.ones(2)*jnp.log(2)
# Noise level: multiplies the Gaussian noise that fun/grad add to the training targets,
# and is also passed to the covariance routines.
NOISE = 0.02

# Calls per timing repetition (timeit's `number` argument); reported times are divided by it
# to give seconds per call.
REPEAT = 10

def _train_data(num_d_vals):
    """Draw a reproducible training set of noisy function values and gradients.

    Args:
        num_d_vals: number of points at which the gradient is observed.

    Returns:
        ``(X_split, Y_train)``. ``X_split = [x_func, x_der]`` holds the
        ``(NUM_F_VALS, 2)`` function-value locations and the ``(num_d_vals, 2)``
        gradient locations, both uniform in ``BOUNDS`` per coordinate.
        ``Y_train`` is the noisy function values followed by the row-major
        flattened noisy gradients (``d/dx1, d/dx2`` for each gradient point).
    """
    seed = 3  # fixed, so every benchmark of a given size sees the same data
    low, high = BOUNDS[0], BOUNDS[1]

    # Derive four subkeys by repeatedly splitting the parent key:
    # f locations, gradient locations, noise on f, noise on the gradients.
    parent = random.PRNGKey(seed)
    subkeys = []
    for _ in range(4):
        parent, child = random.split(parent)
        subkeys.append(child)
    key_xf, key_xd, key_yf, key_yd = subkeys

    x_func = random.uniform(key_xf, (NUM_F_VALS, 2), minval=low, maxval=high)
    x_der = random.uniform(key_xd, (num_d_vals, 2), minval=low, maxval=high)

    y_func = fun(x_func, NOISE, key_yf)
    y_der = grad(x_der, NOISE, key_yd)

    return [x_func, x_der], jnp.concatenate((y_func, jnp.ravel(y_der)))

def ref_from_data(X_split, num_ref_points):
    """Choose reference points as a random subset of the training locations.

    All arrays in ``X_split`` are stacked row-wise, the rows are shuffled with a
    fixed key (derived from ``PRNGKey(0)``), and the first ``num_ref_points``
    rows are returned. Requesting more points than exist returns all of them.
    """
    _, shuffle_key = random.split(random.PRNGKey(0))
    pool = jnp.vstack(X_split)
    return random.permutation(shuffle_key, pool)[:num_ref_points]

def _problem_sizes(start, stop=None, step=None):
    """Return the problem sizes (numbers of derivative observations) to benchmark.

    If both ``stop`` and ``step`` are given the sizes are ``range(start, stop, step)``;
    otherwise ``start`` is returned unchanged and must itself be an iterable of sizes.
    """
    if stop is None or step is None:
        return start
    return range(start, stop, step)


def _num_ref_points(X_split, percent):
    """Number of reference points for a fraction ``percent`` of all training locations.

    Counts the rows of every array in ``X_split``, scales by ``percent``, truncates
    and adds one (so at least one reference point is used). Values of ``percent``
    >= 1 therefore exceed the data size, and ``ref_from_data`` then returns every point.
    """
    return int(sum(len(x) for x in X_split) * percent) + 1


def _repeat_timing(fn, *args):
    """Time ``fn(*args)`` with ``timeit.repeat`` and drop the first repetition.

    Each repetition makes ``REPEAT`` calls and reports their total seconds. The
    first repetition is discarded because it also contains the JIT compilation
    for a new input shape.

    Returns:
        list of float: total seconds of each remaining repetition.
    """
    def call():
        out = fn(*args)
        # JAX dispatches asynchronously: wait for every output array, otherwise the
        # timer only measures how long it takes to enqueue the computation.
        for leaf in jax.tree_util.tree_leaves(out):
            if hasattr(leaf, "block_until_ready"):
                leaf.block_until_ready()

    return repeat(call, number=REPEAT)[1:]


def _save_avg_times(times, path):
    """Average per-repetition totals into seconds per call and save them with ``jnp.save``.

    ``times`` has one row per problem size and one column per kept repetition; the
    saved array has one entry per problem size. The target directory is created if
    missing, so the timing functions do not depend on a separate ``mkdir`` cell.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    avg_times = jnp.mean(jnp.array(times), axis=1) / REPEAT
    jnp.save(path, avg_times)


def full_timing(start, stop=None, step=None):
    """Benchmark ``full_covariance_matrix`` over several numbers of derivative observations.

    Args:
        start: first value of ``range(start, stop, step)`` or, when ``stop`` or
            ``step`` is None, an iterable of sizes.
        stop, step: optional range bounds.

    Saves the mean seconds per call for each size to
    ``./data/full_time_{start}_{stop}_{step}.npy``.
    """
    full_cov = jit(full_covariance_matrix)  # wrap once, outside the timed call
    times = []
    for num in _problem_sizes(start, stop, step):
        X_train, Y_train = _train_data(num)
        times.append(_repeat_timing(full_cov, X_train, Y_train, KERNEL, KERNEL_PARAMS, NOISE))
    _save_avg_times(times, f"./data/full_time_{start}_{stop}_{step}")


def sparse_timing_fixed_ref(start, stop=None, step=None, num_ref_points=50):
    """Benchmark ``sparse_covariance_matrix`` with a fixed number of reference points.

    Sizes are resolved as in ``full_timing``; for each size ``num_ref_points``
    reference points are drawn from the training locations with ``ref_from_data``.
    Saves the mean seconds per call to
    ``./data/sparse_time_{start}_{stop}_{step}_ref{num_ref_points}.npy``.
    """
    sparse_cov = jit(sparse_covariance_matrix)
    times = []
    for num in _problem_sizes(start, stop, step):
        X_train, Y_train = _train_data(num)
        X_ref = ref_from_data(X_train, num_ref_points)
        times.append(_repeat_timing(sparse_cov, X_train, Y_train, X_ref, KERNEL, KERNEL_PARAMS, NOISE))
    _save_avg_times(times, f"./data/sparse_time_{start}_{stop}_{step}_ref{num_ref_points}")


def sparse_timing_fixed_percent(start, stop=None, step=None, percent=0.1):
    """Benchmark ``sparse_covariance_matrix`` with reference points as a fixed fraction of the data.

    Sizes are resolved as in ``full_timing``; for each size the number of reference
    points is ``int(total_points * percent) + 1``. Saves the mean seconds per call to
    ``./data/sparse_time_{start}_{stop}_{step}_{percent}.npy``.
    """
    sparse_cov = jit(sparse_covariance_matrix)
    times = []
    for num in _problem_sizes(start, stop, step):
        X_train, Y_train = _train_data(num)
        X_ref = ref_from_data(X_train, _num_ref_points(X_train, percent))
        times.append(_repeat_timing(sparse_cov, X_train, Y_train, X_ref, KERNEL, KERNEL_PARAMS, NOISE))
    _save_avg_times(times, f"./data/sparse_time_{start}_{stop}_{step}_{percent}")


def sparse_timing_fixed_max(percentages, num_data):
    """Benchmark ``sparse_covariance_matrix`` on one dataset for several reference-point fractions.

    Args:
        percentages: fractions of the training locations to use as reference points
            (e.g. 0.1 for 10 %); any value >= 1 selects every point.
        num_data: number of derivative observations in the single training set.

    Saves the mean seconds per call for each fraction to ``./data/sparse_time_{num_data}.npy``.
    """
    X_train, Y_train = _train_data(num_data)
    sparse_cov = jit(sparse_covariance_matrix)
    times = []
    for percent in percentages:
        X_ref = ref_from_data(X_train, _num_ref_points(X_train, percent))
        times.append(_repeat_timing(sparse_cov, X_train, Y_train, X_ref, KERNEL, KERNEL_PARAMS, NOISE))
    _save_avg_times(times, f"./data/sparse_time_{num_data}")

In [None]:
! mkdir ./data

In [None]:
point_list = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]

In [None]:
full_timing(point_list)

In [None]:
sparse_timing_fixed_ref(point_list, num_ref_points=128)

In [None]:
sparse_timing_fixed_max(point_list[:-2], 4096)

In [None]:
! zip -r data.zip data/ 