In [1]:
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from dataclasses import dataclass
import optax as ox
import jax
import cola

from jaxtyping import (
    Float,
)
from gpjax.typing import (
    Array,
    ScalarFloat,
)
import jax.numpy as jnp
import jax.random as jr
import jax.numpy as jnp
import gpjax as gpx
from jax import jit, config, grad, hessian

from kernels import arcSin

import tensorflow_probability.substrates.jax.bijectors as tfb
from gpjax.base import param_field, static_field
from functools import partial

config.update("jax_enable_x64", True)

# set jnp random key
key = jr.PRNGKey(42)

In [2]:
########################################## edit the parameters here ##########################################

# define the parameters of the model
dt = 0.001

# define noise
noise = 0.001

# number of test points
n_test = 31
n_boundary = 2

In [3]:
#read from the json file
with open('../../testdata/burger_eqn.json','r') as f:
    data = json.load(f)
    datax = jnp.array(data['x'])
    datau = jnp.array(data['u'])

# the first dimension is t, second dimension is x

In [4]:
# save the true solution
u = datau.copy()

# add noise to u_new
u_new = datau + jnp.sqrt(noise) * jr.normal(key, datau.shape)
x_new = datax

In [5]:
# select n random points that is spaced out 
n_points = 50
idx = np.random.choice(len(datax), n_points, replace=True)
idx = jnp.linspace(0, len(x_new)-1, n_points).astype(int)

In [6]:
# training data
x_train = jnp.hstack([x_new[idx].T, x_new[idx].T]).T
t_train = jnp.hstack([jnp.array([0]*n_points),jnp.array([1]*n_points)]).T
mu_train = jnp.hstack([u_new[0,idx], u_new[0,idx]]).T

X_train = jnp.vstack([x_train.T,t_train,mu_train]).T
u_train = jnp.hstack([u_new[0,idx].T, u_new[1,idx].T]).reshape(-1,1)

In [7]:
# testing data
x_test = jnp.hstack([x_new.T, x_new.T]).T
t_test = jnp.hstack([jnp.array([0]*len(x_new)),jnp.array([1]*len(x_new))]).T
mu_test = jnp.hstack([u_new[0,:], u_new[0,:]]).T

X_test = jnp.vstack([x_test.T,t_test,mu_test]).T
u_test = jnp.hstack([u_new[0,:].T, u_new[1,:].T]).reshape(-1,1)


In [8]:
dataset_train = gpx.Dataset(X_train, u_train)
dataset_test = gpx.Dataset(X_test, u_test)

In [9]:
@dataclass
class BurgerKernel(gpx.kernels.AbstractKernel):
    # define the base kernel
    kernel: gpx.kernels.AbstractKernel = arcSin(active_dims = [0])
    
    # add parameters for the kernel
    nu: ScalarFloat = param_field(jnp.array(0.01/jnp.pi), trainable = False, bijector=tfb.Softplus())
    delta_t: ScalarFloat = param_field(jnp.array(dt), trainable = False, bijector=tfb.Softplus())

    def __call__(
        self, 
        X: Float[Array, "1 D"], 
        Xp: Float[Array, "1 D"]
    ) -> Float[Array, "1"]:

        # compute the kernel matrix as a block matrix
        z = jnp.array(X[1], dtype=int)
        zp = jnp.array(Xp[1], dtype=int)

        # mu has the third dimension in X
        mu = jnp.array(X[2], dtype=jnp.float64)
        mu_p = jnp.array(Xp[2], dtype=jnp.float64)

        # X and Xp
        X = jnp.array(X[0], dtype=jnp.float64)
        Xp = jnp.array(Xp[0], dtype=jnp.float64)

        # switch_0_0 is 1 when z == zp == 0, 0 otherwise
        switch_0_0 = jnp.where((z == 0) & (zp == 0), 1, 0)
        switch_0_1 = jnp.where((z == 0) & (zp == 1), 1, 0)
        switch_1_0 = jnp.where((z == 1) & (zp == 0), 1, 0)
        switch_1_1 = jnp.where((z == 1) & (zp == 1), 1, 0)

        # gradient of the kernel
        grad_kernel = self.kernel.dX(X, Xp)
        grad_p_kernel = self.kernel.dXp(X, Xp)

        # hessian of the kernel
        hess_kernel = self.kernel.dX_dX(X, Xp)
        hess_p_kernel = self.kernel.dXp_dXp(X, Xp)
        grad_grad_p_kernel = self.kernel.dX_dXp(X, Xp)

        # hessian of the hessian of the kernel
        hess_hess_p_kernel = self.kernel.dX2_dXp2(X, Xp)
        
        # hessian of the gradient of the kernel
        hess_grad_p_kernel = self.kernel.dX2_dXp(X, Xp)
        hess_p_grad_kernel = self.kernel.dXp2_dX(X, Xp)

        
        # kernels
        kernel_1_1 = (self.kernel(X, Xp))

        kernel_1_0 = (self.kernel(X, Xp)
                        + self.delta_t * mu_p * grad_p_kernel
                        - self.delta_t * self.nu * hess_p_kernel 
                        )
        
        kernel_0_1 = (self.kernel(X, Xp)
                        + self.delta_t * mu * grad_kernel
                        - self.delta_t * self.nu * hess_kernel
                        )

        kernel_0_0 = (self.kernel(X, Xp)
                        + self.delta_t * mu_p * grad_p_kernel
                        - self.delta_t * self.nu * hess_p_kernel 
                        + self.delta_t * mu * grad_kernel
                        - self.delta_t * self.nu * hess_kernel
                        + self.delta_t**2 * mu * mu_p * grad_grad_p_kernel
                        - self.nu * self.delta_t**2 * mu_p * hess_grad_p_kernel
                        - self.nu * self.delta_t**2 * mu * hess_p_grad_kernel 
                        + self.nu**2 * self.delta_t**2 * hess_hess_p_kernel 
                        )

        return (switch_0_0 * kernel_0_0 + switch_0_1 * kernel_0_1 + switch_1_0 * kernel_1_0 + switch_1_1 * kernel_1_1).squeeze()

In [10]:
#define the mean and kernel functions
mean = gpx.mean_functions.Zero()
kernel = BurgerKernel()

# Construct GP prior
prior = gpx.gps.Prior(mean_function = mean, kernel = kernel)

# Construct GP likelihood
likelihood = gpx.likelihoods.Gaussian(num_datapoints = n_points, obs_stddev=jnp.sqrt(noise))

# Construct the posterior
posterior = prior * likelihood

# define the log marginal likelihood
negative_mll = gpx.objectives.ConjugateMLL(negative=True)

In [11]:
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_mll,
    train_data=dataset_train,
    optim=ox.adamw(learning_rate=1e-2),
    num_iters=10000,
    key=key,
    safe = False
)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [12]:
# boundary terms
bt_idx = jnp.linspace(0, len(u_new[0])-1, n_boundary).astype(int)
idx = jnp.linspace(0, len(x_new)-1, n_test).astype(int)

t_train = jnp.hstack([jnp.array([0]*n_test),jnp.array([1]*n_boundary)]).T
x_train = jnp.hstack([x_new[idx], x_new[bt_idx]]).T
mu_train = jnp.hstack([u_new[0,idx], u_new[0,bt_idx]]).T
u_train = jnp.hstack([u_new[0,idx], u_new[1,bt_idx]]).reshape(-1,1)

x_test = jnp.array(x_new[idx])
t_test = jnp.array([1]*n_test)
mu_test = jnp.array(u_new[0,idx]*n_test)
X_test = jnp.vstack([x_test,t_test,mu_test]).T

print (x_train.shape, t_train.shape, mu_train.shape)
print (x_test.shape, t_test.shape, mu_test.shape)

(33,) (33,) (33,)
(31,) (31,) (31,)


In [13]:
import time 
# intial u0:
u0 = u_new[0,idx]

# initialize the accumulated uncertainty
updated_std = jnp.array([0]*n_test)

# initialize the covariance matrix
updated_cov = jnp.eye(n_test) * noise

# initialize the error, u_mean, u_std
error, u_mean, u_std, u_true = [], [], [], []

for t in range(int(dt*1000), 501, int(dt*1000)):
    # start a timer
    tstart = time.time()

    # u_train is the previous time step u and the boundary conditions
    # u_train = jnp.hstack([u0, u_new[t,bt_idx].T]).reshape(-1,1)
    u_train = jnp.hstack([u0, jnp.array([0,0])]).reshape(-1,1)
    X_train = jnp.vstack([x_train.T,t_train,mu_train]).T 

    dataset_train = gpx.Dataset(X_train, u_train)

    # train the model 
    opt_posterior, history = gpx.fit(
        model=opt_posterior,
        objective=negative_mll,
        train_data=dataset_train,
        optim=ox.adamw(learning_rate=1e-2),
        num_iters=20,
        key = key,
        verbose= False,
        safe = False 
    )

    
    # latent_dist = opt_posterior.predict(X_test, train_data=dataset_train)
    latent_dist = opt_posterior.predict_with_prev_cov(X_test, train_data=dataset_train, prev_cov=updated_cov)
    predictive_dist = opt_posterior.likelihood(latent_dist)
    predictive_mean = predictive_dist.mean()
    predictive_std = predictive_dist.stddev()
    predictive_cov = predictive_dist.covariance()

    updated_std = predictive_std
    updated_cov = predictive_cov

    updated_std = jnp.sqrt(jnp.diag(updated_cov))

    # the u that is predicted, use this u0 for the next time step
    u0 = predictive_mean

    # save the error
    error.append((u[t,idx] - predictive_mean)**2)
    u_mean.append(predictive_mean)
    u_std.append(updated_std)
    u_true.append(u[t,idx])

    # update the mu_train, and hence, dataset_train
    mu_train = jnp.hstack([u0, u_new[t,bt_idx]]).T

    print (u_mean)

    #
    print (f"Time step {t} is done. Time taken: {time.time() - tstart}")

[Array([-0.00625449,  0.22181859,  0.38119524,  0.56658989,  0.75311322,
        0.85881087,  0.91189334,  0.9947518 ,  1.03663717,  0.92642941,
        0.82665584,  0.74991692,  0.55479482,  0.40773576,  0.22766946,
       -0.01424895, -0.20506263, -0.39126325, -0.57719359, -0.76704826,
       -0.86255345, -0.91663591, -0.95299016, -0.92766595, -0.93991527,
       -0.90277889, -0.77334156, -0.60269965, -0.41015267, -0.21124656,
        0.02008933], dtype=float64)]
Time step 1 is done. Time taken: 44.19962024688721
[Array([-0.00625449,  0.22181859,  0.38119524,  0.56658989,  0.75311322,
        0.85881087,  0.91189334,  0.9947518 ,  1.03663717,  0.92642941,
        0.82665584,  0.74991692,  0.55479482,  0.40773576,  0.22766946,
       -0.01424895, -0.20506263, -0.39126325, -0.57719359, -0.76704826,
       -0.86255345, -0.91663591, -0.95299016, -0.92766595, -0.93991527,
       -0.90277889, -0.77334156, -0.60269965, -0.41015267, -0.21124656,
        0.02008933], dtype=float64), Array([-0

In [None]:
# create the folder
import os

# file name 
foldername = f'Burger_arcsin_{n_test}_test_points_{noise}_noise_{dt}_timestep'
folderpath = f'../../result/{foldername}'

if not os.path.exists(folderpath):
    os.makedirs(folderpath)

# save the error, u_mean, u_std, and x
np.save(f'{folderpath}/error.npy', np.array(error))
np.save(f'{folderpath}/u_mean.npy', np.array(u_mean))
np.save(f'{folderpath}/u_std.npy', np.array(u_std))
np.save(f'{folderpath}/x.npy', np.array(x_test))
np.save(f'{folderpath}/u_true.npy', np.array(u_true))