# Illustration of the paper's core idea
This notebook generates Figure 1, which demonstrates the validity of the approach.

In [None]:
# Automatically reload edited modules (e.g. fasthgp) before each cell runs.
%load_ext autoreload
%autoreload 2
import jax
# NOTE(review): optax, gpjax, dataclass and field are not used in the cells visible here — confirm before removing.
import optax as ox
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
# Use 64-bit floats in JAX (default is 32-bit); set before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
import seaborn as sns
# Interactive matplotlib backend for notebooks (needs the ipympl package).
%matplotlib widget
# Apply seaborn's default plotting theme.
sns.set()
from dataclasses import dataclass, field
from fasthgp.utils import integrate
# Fixed seed so the generated data and the figure are reproducible.
key = jr.PRNGKey(13)

### Generate data

In [None]:
# Observation-noise variance and half-width of the input domain.
s2 = 0.1
L = 1

# Inputs: 100 uniform draws on [0, 2L), sorted and shifted to [-L, L).
key, subkey = jr.split(key)
x = jr.uniform(subkey, shape=(100,), maxval=2*L).sort() - L

# Fresh subkey for the observation noise.
key, subkey = jr.split(key)


def f(x):
    """Latent function tanh(2*pi*x) * sin(2*pi*x) that generates the data."""
    arg = x*2*jnp.pi
    return jnp.tanh(arg) * jnp.sin(arg)


# Noisy observations: y = f(x) + sqrt(s2) * standard normal noise.
y = f(x) + jnp.sqrt(s2)*jr.normal(subkey, shape=x.shape)

### RBF model
Fits a radial basis function (RBF) model and computes the posterior mean $m$ and covariance $V$ of its weights.

In [None]:
def phi(x, c, l):
    """Squared-exponential radial basis function centred at ``c``.

    Returns exp(-||x - c||^2 / l^2). The squared difference is summed, so
    ``x`` and ``c`` may be scalars or broadcast-compatible vectors.
    """
    sq_dist = jnp.sum((x - c)**2)
    return jnp.exp(-1/l**2 * sq_dist)

# Design-matrix builder: Phi(X, C, l)[i, j] = phi(X[i], C[j], l).
# The inner vmap maps over the centres, the outer vmap over the inputs.
Phi = jax.vmap(jax.vmap(phi, (None, 0, None), 0), (0, None, None), 0)
M = 10  # number of basis functions
c = jnp.linspace(-L, L, M)  # centres spread evenly over the data domain [-L, L]
l = 0.25  # RBF lengthscale
# Evaluate the training design matrix once and reuse it below.
Phi_x = Phi(x, c, l)
# Least-squares weights; with Gaussian noise of variance s2 this is the
# posterior mean of the weights under a flat prior.
m = jnp.linalg.lstsq(Phi_x, y)[0]
# Corresponding posterior covariance of the weights: s2 * (Phi^T Phi)^{-1}.
V = s2 * jnp.linalg.inv(Phi_x.T @ Phi_x)

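### Identify the relevant basis functions (BFs)
`int_ind` uses the integral criterion: each squared weight $|m_i|^2$ is multiplied by the integral of the squared basis function $\phi_i^2$ over $\Omega$ (`lims`).

`ind` uses the simplified criterion: basis functions are ranked by $|m_i|^2$ alone.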

In [None]:
# Integration domain Omega over which basis-function relevance is assessed.
lims = [-.5, 0]


def fun(x, c):
    """Squared RBF, i.e. phi(x, c, l)**2 written as a single exponential."""
    return jnp.exp(-2/l**2 * jnp.sum((x - c)**2))


# Numerical integral of each squared basis function over Omega (numerical for simplicity).
# NOTE(review): assumes fasthgp's `integrate` returns one value per centre in `c` — confirm.
sc = integrate(jax.vmap(jax.vmap(fun, (None, 0), 0), (0, None), 0), lims, args=[c])

cost = jnp.abs(m)**2  # simplified criterion: squared weight magnitude
int_cost = sc * cost  # integral criterion: squared weight scaled by its integral over Omega
M = 2  # number of basis functions to keep
# argsort is ascending, so the last M indices belong to the M largest costs.
ind, int_ind = (jnp.argsort(score)[-M:] for score in (cost, int_cost))

### Compute predictions on a test grid

In [None]:
def model(xtest, m, V, inds=None):
    """Predictive mean and covariance of the RBF model at ``xtest``.

    Uses the module-level centres ``c`` and lengthscale ``l``.

    Args:
        xtest: test inputs, evaluated through ``Phi``.
        m: weight mean, one entry per basis function.
        V: weight covariance, one row/column per basis function.
        inds: optional indices (array or list) of the basis functions to keep;
            if None, all basis functions are used.

    Returns:
        Tuple ``(Phit @ m, Phit @ V @ Phit.T)`` of predictive mean and covariance.
    """
    Phit = Phi(xtest, c, l)
    if inds is not None:
        # JAX rejects plain Python lists as indices, so normalise to an array.
        inds = jnp.asarray(inds)
        Phit = Phit[:, inds]
        m = m[inds]
        # Keep only the V[inds, inds] sub-block of the weight covariance.
        V = V[inds[:, None], inds[None, :]]
    return Phit @ m, Phit @ V @ Phit.T
    
# Test grid of 100 points on [-1, 1].
xtest = jnp.linspace(-1, 1, 100)
# Predictions from the full model and from the two reduced models.
(mu, S), (mu_ind, S_ind), (mu_int, S_int) = (
    model(xtest, m, V, subset) for subset in (None, ind, int_ind)
)

### Generate a mock-up of Figure 1

In [None]:
def plot_bfs(ax, c, **kwargs):
    """Draw each basis function in ``c`` as a small bump on ``ax``.

    For every centre, ``phi(., centre, l)`` is evaluated on 50 points spanning
    ``centre +/- 2l``, scaled by 0.1 and shifted up by 1.5. Extra keyword
    arguments are forwarded to ``ax.plot``.
    """
    phi_on_grid = jax.vmap(phi, (0, None, None), 0)
    for centre in c:
        grid = jnp.linspace(centre-2*l, centre+2*l, 50)
        bump = phi_on_grid(grid, centre, l)*0.1 + 1.5
        ax.plot(grid, bump, linewidth=.5, **kwargs)

def conf_int(ax, mu, S, *, x=None, **kwargs):
    """Plot a mean curve with a +/- one-standard-deviation band.

    Args:
        ax: axes providing ``plot`` and ``fill_between`` (matplotlib API).
        mu: predictive mean at the plotting inputs.
        S: predictive covariance; only its diagonal (the variances) is used.
        x: plotting inputs; defaults to the module-level ``xtest`` grid.
        **kwargs: forwarded to ``ax.plot`` (e.g. ``label``).

    Returns:
        The line drawn for the mean, so callers can reuse its colour.
    """
    if x is None:
        x = xtest
    std = jnp.sqrt(S.diagonal())
    # Named ``line`` (not ``l``) to avoid shadowing the global lengthscale ``l``.
    line = ax.plot(x, mu, **kwargs)[0]
    ax.fill_between(x, mu - std, mu + std, color=line.get_color(), alpha=.3)
    return line

plt.close("all")
fig, ax = plt.subplots()
# Ground-truth function in black.
ax.plot(xtest, f(xtest), 'k')
# Full model and the two reduced models, each with a +/- 1 std band.
conf_int(ax, mu, S, label='Full model')
l_ind = conf_int(ax, mu_ind, S_ind, label='Standard')
l_int = conf_int(ax, mu_int, S_int, label='Integral')
# Selected basis functions, drawn in the colour of the model that kept them.
plot_bfs(ax, c[ind], color=l_ind.get_color())
plot_bfs(ax, c[int_ind], color=l_int.get_color(), linestyle='--')
# Boundaries of the integration domain Omega.
ax.vlines(lims, ymin=-1.5, ymax=1.5, color='k', label='Integration limits')
ax.legend()
plt.show()

### Save to .csv
The figure in the paper is produced with pgfplots from these .csv files.

In [None]:
import pandas as pd

# One CSV per model (full, standard, integral) with columns x, f (mean) and std.
filenames = ["rbf_base.csv", "rbf_standard.csv", "rbf_integral.csv"]
fs = [mu, mu_ind, mu_int]
Ss = [S, S_ind, S_int]
for fname, mean, cov in zip(filenames, fs, Ss):
    # Pass the path so pandas opens the file itself; an open text handle
    # would need newline="" to avoid doubled line endings on Windows.
    pd.DataFrame(dict(x=xtest, f=mean, std=jnp.sqrt(cov.diagonal()))).to_csv(fname, index=False)