In [1]:
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import jaxgp.regression as gpr
from jaxgp.kernels import RBF, Linear
import timeit

# Full GPR

## Benchmark memory of fitmatrix and fitvector

## Benchmarking with a constant function

In [2]:
data_sizes = [100, 200, 500, 1000, 2000, 3000, 4000, 5000]

kernel_1 = RBF()
kernel_2 = Linear()

### Timing of the training

In [3]:
for elem in data_sizes:
    data_split = (elem, )

    X_data = jnp.linspace(0.0, 1.0, data_split[0]).reshape(-1,1)
    Y_data =jnp.ones(data_split[0]).reshape(-1,1)

    def test_RBF():
        model = gpr.ExactGPR(kernel=kernel_1, data_split=data_split, noise=0.1)
        model.train(X_data=X_data, Y_data=Y_data)

    print(timeit.repeat(test_RBF, repeat=5, number=1))



[0.7717570389995672, 0.5988218300008157, 0.5611509289992682, 0.5543549270005315, 0.5437723270006245]
[0.570482028000697, 0.5201093259984191, 0.5697450280003977, 0.6846777339997061, 0.504502425001192]
[0.6331115319990204, 0.686985033999008, 0.6604375720016833, 0.651841343000342, 0.727478560000236]
[0.6877459509996697, 0.7139661569999589, 0.7595390660007979, 0.7005478540013428, 0.7348133610012155]
[1.0828829380006937, 0.9623849110012088, 0.9885806600013893, 0.9686283850005566, 0.9552003839999088]
[1.8677035640012036, 1.7806573559992103, 1.715139451000141, 1.649879239999791, 1.8217385799998738]
[3.0275515649991576, 2.90103315199849, 2.992188336998879, 3.1312867480010027, 3.4582399299997633]
[6.222947199999908, 5.875851210999826, 5.797133195999777, 5.959414672999628, 6.051635342000736]


### Timing of eval with constant X_grid for different training sizes

In [7]:
X = jnp.linspace(0.2, 0.8, 10000)

for elem in data_sizes:
    data_split = (elem, )

    X_data = jnp.linspace(0.0, 1.0, data_split[0]).reshape(-1,1)
    Y_data =jnp.ones(data_split[0]).reshape(-1,1)

    model = gpr.ExactGPR(kernel=kernel_1, data_split=data_split, noise=0.1)
    model.train(X_data=X_data, Y_data=Y_data)

    def test_RBF():
        model.eval(X, return_std=True)


    print(timeit.repeat(test_RBF, repeat=5, number=1))

[0.22007308999855013, 0.009095800000068266, 0.007168100000853883, 0.007208598999568494, 0.006715299999996205]
[0.22867819000020972, 0.017340999000225565, 0.016675999000653974, 0.017060398999092286, 0.01753669999925478]
[0.2619532890003029, 0.049606898001002264, 0.055834196999057895, 0.04348789899995609, 0.05943329699948663]
[0.37868828299906454, 0.14374369399956777, 0.15317559300092398, 0.15216089400018973, 0.1494772930000181]
[0.6759339679992991, 0.5475161750000552, 0.4536105780007347, 0.5185008759999619, 0.5095852760005073]
[1.4078279329987708, 1.1115338550007436, 1.042998388000342, 1.1854012859985232, 1.2884399850008776]
[2.8885470919994987, 2.4990912069988553, 2.4752440090014716, 2.626823667998906, 2.841024596000352]


: 

: 