In [105]:
import nll
from jax import numpy as np
import utils
from matplotlib import pyplot as plt
import dataset_sines_infinite
import dataset_sines_finite
import dataset_step_infinite
from jax import vmap
from jax import scipy
from jax.scipy.special import gamma

In [92]:
seed = 1655235988902897757
print(seed)

1655235988902897757


In [93]:
def gaussian_posterior_full(kernel_matrix, x_a, y_a, x_b, maddox_noise):
    """
    Computes the gaussian posterior with this kernel and this data, on the queried inputs.
    x_a is a (batch_size, input_dims) array (! has lost n_tasks)
    y_a is a (batch_size, reg_dim) array (! has lost n_tasks)
    Returns the posterior covariance matrix
    """
    y_a = np.reshape(y_a, (-1,))

    cov_a_a = kernel_matrix(x_a, x_a)
    cov_a_a = cov_a_a + maddox_noise ** 2 * np.eye(cov_a_a.shape[0])
    cov_b_a = kernel_matrix(x_b, x_a)
    cov_b_b = kernel_matrix(x_b, x_b)

    print(cov_a_a.shape)
    print(cov_b_a.shape)
    print(cov_b_b.shape)
    print(y_a.shape)

    L = scipy.linalg.cho_factor(cov_a_a)
    alpha = scipy.linalg.cho_solve(L, y_a)
    post_mean = cov_b_a @ alpha
    
    v = scipy.linalg.cho_solve(L, cov_b_a.T)
    post_cov = cov_b_b - cov_b_a @ v
    
    return post_mean, post_cov

In [94]:
def plot_gpr(x_a_all, y_a_all, x_b, y_b, kernel_matrix, K, dataset_provider):
    """
    Make an informative prediction plot in the singGP case (for the kernel specified)
    K is the number of context inputs
    Change dataset_provider to test on other datasets (e.g. dataset_sines_infinite)
    """
    y_min, y_max = np.min(y_b) - 0.5, np.max(y_b) + 0.5

    x_a = x_a_all[:K]
    y_a = y_a_all[:K]
    prediction, cov = gaussian_posterior_full(kernel_matrix, x_a, y_a, x_b, 0.05)

    error = dataset_provider.error_fn(prediction, y_b)
    loss = nll.nll(kernel_self_matrix, x_a, y_a, maddox_noise=0.05)

    variances = np.diag(cov)
    stds = np.sqrt(variances)

    plt.plot(x_b, y_b, "g--", label="Target")
    plt.plot(x_a, y_a, "ro", label="Context data")
    plt.plot(x_b, prediction, "b", label="Prediction")
    plt.fill_between(x_b, prediction - 1.96 * stds, prediction + 1.96 * stds, color='blue', alpha=0.1, label="+/- 1.96$\sigma$")
    plt.title(f"NLL={loss:.4f}, MSE={error:.4f} ($K$={K})")
    plt.legend()
    plt.gca().set_ylim([np.min(prediction), np.max(prediction)])
    plt.gca().set_xlabel("$x$")
    plt.gca().set_ylabel("$y$")
    plt.legend()

## Choice of kernels : 

In [95]:
l = 1

def RBF_kernel(x1, x2):
    # Now x1 and x2 are compatible for broadcasting
    # Compute squared Euclidean distance
    squared_diff = (x1 - x2) ** 2 / (2*l**2)
    return np.exp(-squared_diff)

In [96]:
def CosSim_kernel(x1, x2):
    normalized_factor = np.linalg.norm(x1)*np.linalg.norm(x2)
    return np.dot(x1, x2)

In [97]:
p = 2
c = 1

def polynomial_kernel(x1, x2):
    return (np.dot(x1, x2) + c)**p

In [99]:
length_scale = 1
nu = 2.5

def matern_kernel(x1, x2):
    # Euclidian distance
    sqdist = np.sum((x1 - x2) ** 2)
    r = np.sqrt(sqdist)

    # Scaling factor
    sqrt_2_nu_r_over_l = np.sqrt(2 * nu) * r / length_scale

    # Matérn kernel formula
    coefficient = (2 ** (1 - nu)) / gamma(nu)
    result = coefficient * (sqrt_2_nu_r_over_l ** nu) * bessel_k(nu, sqrt_2_nu_r_over_l)

    return result

In [100]:
#Choose here: 
kernel = matern_kernel

# Apply vmap to vectorize kernel function over pairs of inputs
kernel_matrix = vmap(vmap(kernel, in_axes=(None, 0)), in_axes=(0, None))

In [101]:
key = random.PRNGKey(0)

In [106]:
p = 2
c = 1

K = 100
x, y, fun = dataset_sines_infinite.get_fancy_test_batch(key, K=K, L=0, data_noise=0.05)

x_a_all = x[0, :K]
y_a_all = y[0, :K]
x_a_all = np.reshape(x_a_all, (-1,))
y_a_all = np.reshape(y_a_all, (-1,))
print(x_a_all.shape, y_a_all.shape)

x_b = np.linspace(-5, 5, 100)[:, np.newaxis]
y_b = fun(x_b)
x_b = np.reshape(x_b, (-1,))
y_b = np.reshape(y_b, (-1,))
print(x_b.shape, y_b.shape)

plot_gpr(x_a_all, y_a_all, x_b, y_b, kernel_matrix, K, dataset_sines_infinite)

(100,) (100,)
(100,) (100,)


AttributeError: module 'jax.numpy' has no attribute 'trapz'