In [4]:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

In [28]:
def _rbf_kernel(xa, xb, lengthscale, outputscale):
    """Squared-exponential (RBF) kernel matrix between the rows of xa and xb.

    xa has shape [n_a, d] and xb has shape [n_b, d]; the result is [n_a, n_b].
    """
    # ||a - b||^2 expanded as ||a||^2 + ||b||^2 - 2 a.b; the keepdims column
    # broadcasts against the row of xb norms to give the full [n_a, n_b] grid.
    sqdist = (
        jnp.sum(xa**2, axis=1, keepdims=True)
        + jnp.sum(xb**2, axis=1)
        - 2 * jnp.dot(xa, xb.T)
    )
    return outputscale**2 * jnp.exp(-0.5 * sqdist / lengthscale**2)


def gp_predict(x_test,x_train,y_train,lengthscale=1.,outputscale=1.,noise=1e-6):
    """
    Predicts the posterior mean of a Gaussian Process (RBF kernel) at test points.

    Only the mean is computed; the predictive variance is not returned.

    Parameters:
    - x_test: Test input points (shape: [n_test, d]). A single point of shape [d]
      (or a scalar when d == 1) is promoted to [1, d] with jnp.atleast_2d.
    - x_train: Training input points (shape: [n_train, d])
    - y_train: Training output values (shape: [n_train] or [n_train, 1]);
      flattened to [n_train] internally
    - lengthscale: Lengthscale of the kernel
    - outputscale: Output scale of the kernel (the kernel is multiplied by outputscale**2)
    - noise: Value added to the diagonal of the training covariance matrix

    Returns:
    - mean: Predicted mean at the test points. A scalar (shape: []) when there is
      exactly one test point, so the function can be passed directly to jax.grad;
      otherwise an array of shape [n_test].
    """
    x_test = jnp.atleast_2d(x_test)
    # Work with a flat target vector so the mean comes out as [n_test] instead
    # of [n_test, 1] when the targets are supplied as a column.
    y_train = jnp.ravel(y_train)

    # Train/train and test/train covariance matrices
    K = _rbf_kernel(x_train, x_train, lengthscale, outputscale)
    K12 = _rbf_kernel(x_test, x_train, lengthscale, outputscale)

    # Add noise to the diagonal
    K = K + noise * jnp.eye(K.shape[0])

    # Solve (K + noise * I) alpha = y_train through the lower Cholesky factor
    L = jnp.linalg.cholesky(K)
    alpha = jax.scipy.linalg.cho_solve((L, True), y_train)

    # Posterior mean at the test points
    mean = jnp.dot(K12, alpha)

    # Array shapes are static in JAX, so this Python-level branch is safe under
    # jax.grad / jax.jit tracing.
    if mean.shape[0] == 1:
        return jnp.reshape(mean, ())
    return mean

# Toy problem: n samples of sin(x) on an evenly spaced grid starting at 0
# with spacing 0.25, stored as a column of shape [n, 1].
n = 5
x_train = jnp.array([0.25 * i for i in range(n)])[:, None]
y_train = jnp.sin(x_train)

print("Training data:"  , x_train, y_train)

# Three query locations (multiples of 0.33), also stored as a column.
x_test = jnp.array([0.33 * i for i in range(1, 4)])[:, None]

print("Test data:"  , x_test)

# Kernel hyperparameters shared by every gp_predict call below.
gp_kwargs = dict(lengthscale=0.5, outputscale=1.0, noise=1e-6)


def grad(x):
    """Derivative of the GP mean with respect to the query point x, squeezed."""
    return jax.grad(gp_predict, argnums=0)(x, x_train, y_train, **gp_kwargs).squeeze()


# Differentiating `grad` once more yields the second derivative of the GP mean.
grad_of_grad = jax.grad(grad)

# Query one point at a time: gp_predict is called with a single point so that
# its output is a scalar, which jax.grad requires.
for x in x_test:
    pred = gp_predict(x, x_train, y_train, **gp_kwargs)
    print(f"GP prediction at {x}: {pred}")
    print(f"GP gradient at {x}: {grad(x)}")
    print(f"GP gradient of gradient at {x}: {grad_of_grad(x)} ")


Training data: [[0.  ]
 [0.25]
 [0.5 ]
 [0.75]
 [1.  ]] [[0.        ]
 [0.24740396]
 [0.47942554]
 [0.68163876]
 [0.84147098]]
Test data: [[0.33]
 [0.66]
 [0.99]]
GP prediction at [0.33]: 0.3255984553690249
GP gradient at [0.33]: 0.9558227774529487
GP gradient of gradient at [0.33]: [-0.58985326] 
GP prediction at [0.66]: 0.611425137763322
GP gradient at [0.66]: 0.797045382858282
GP gradient of gradient at [0.66]: [-0.33433219] 
GP prediction at [0.99]: 0.8368924779938347
GP gradient at [0.99]: 0.4682674105907383
GP gradient of gradient at [0.99]: [-2.10317375] 


In [None]:
from matplotlib import pyplot as plt


plt.plot(x_train, y_train, 'ro', label='Training data')
x_test = jnp.linspace(0, 1.5, 100).reshape(-1, 1)