## Importing Packages

In [1]:
try:
  import distrax
except ModuleNotFoundError:
  %pip install distrax
  import distrax

try:
  import tinygp
except ModuleNotFoundError:
  %pip install tinygp
  import tinygp
try:
  import jax
except ModuleNotFoundError:
  %pip install jax 
  import jax

import jax.numpy as jnp
try:
  import matplotlib.pyplot as plt
except ModuleNotFoundError:
  %pip install matplotlib 
  import matplotlib.pyplot as plt

try:
  import GPy
except ModuleNotFoundError:
  %pip install GPy
  import GPy
try:
  from tqdm import tqdm
except ModuleNotFoundError:
  %pip install tqdm
  from tqdm import tqdm
try: 
  import jaxopt
except ModuleNotFoundError:
  %pip install jaxopt
  import jaxopt
try:
  import optax
except ModuleNotFoundError:
  %pip install optax
  import optax
try:
  import sklearn
except ModuleNotFoundError:
  %pip install sklearn
  import sklearn

try:
  import torch
except ModuleNotFoundError:
  %pip install torch
  import torch

try:
  import gpytorch
except ModuleNotFoundError:
  %pip install gpytorch
  import gpytorch

from gpytorch.kernels import ScaleKernel, RBFKernel
from torch.distributions import Normal, MultivariateNormal
from sklearn.neighbors import NearestNeighbors
import numpy as np
from tinygp import kernels, GaussianProcess

#### Hand-built NLL (distrax)

In [None]:
def NLL(theta,x,y):
  """Negative log marginal likelihood of a zero-mean GP, assembled by hand.

  The covariance exp(log_varf) * ExpSquared(exp(log_scale)) + exp(log_vary) * I
  is formed explicitly and scored with a distrax full-covariance Gaussian.

  Args:
    theta: dict with "log_varf" (log signal variance), "log_scale" (log
      ExpSquared scale) and "log_vary" (log noise variance).
    x: input locations passed to the tinygp kernel.
    y: observations; flattened to 1-D before the density is evaluated.

  Returns:
    Scalar -log p(y | x, theta).
  """
  signal_var = jnp.exp(theta["log_varf"])
  length_scale = jnp.exp(theta["log_scale"])
  noise_var = jnp.exp(theta["log_vary"])

  gram = (signal_var * kernels.ExpSquared(scale=length_scale))(x, x)
  cov = gram + noise_var * jnp.eye(len(x))
  prior = distrax.MultivariateNormalFullCovariance(jnp.zeros(y.shape[0]), cov)
  return -prior.log_prob(y.reshape(-1,))

#### TinyGP NLL

In [None]:
def build_gp(theta_, x):
  """Construct a tinygp GaussianProcess for the given log-hyperparameters.

  Args:
    theta_: dict with "log_varf", "log_scale" and "log_vary" entries.
    x: input locations.

  Returns:
    GaussianProcess with kernel exp(log_varf) * ExpSquared(scale=exp(log_scale))
    and diagonal noise exp(log_vary). No mean function is passed, so tinygp's
    default mean is used.
  """
  amplitude = jnp.exp(theta_["log_varf"])
  base_kernel = kernels.ExpSquared(scale=jnp.exp(theta_["log_scale"]))
  noise = jnp.exp(theta_["log_vary"])
  return GaussianProcess(amplitude * base_kernel, x, diag=noise)

def Nll(theta_, x, y):
  """Negative log marginal likelihood of y under build_gp(theta_, x).

  Args:
    theta_: dict of log-hyperparameters (see build_gp).
    x: input locations.
    y: observations with one value per row of x; any shape with N elements.

  Returns:
    Scalar -log p(y | x, theta_).
  """
  gp = build_gp(theta_, x)
  # tinygp's log_probability expects y of shape (N,). A column vector (N, 1)
  # would broadcast against the length-N mean (y - loc) into an (N, N) array
  # and silently distort the likelihood and its gradient, so flatten first.
  return -gp.log_probability(jnp.ravel(y))

In [1]:
def SGD(x, y, theta, batch_size, alpha, epochs):
  """Fit GP variances with Adam on nearest-neighbour mini-batches.

  Each step picks a centre point and takes its `batch_size` nearest neighbours
  in x (kd-tree). It then takes one Adam step on the mini-batch NLL `Nll`.
  Only "log_varf" and "log_vary" are optimised; "log_scale" is held at its
  initial value.

  Args:
    x: inputs, 2-D (NearestNeighbors requires it); each neighbour batch is
      reshaped to a column before being passed to Nll.
    y: targets with one entry per row of x.
    theta: dict of initial log-hyperparameters; it is copied, not mutated.
    batch_size: neighbours per mini-batch. Must be > 1 (log scaling below)
      and <= len(x) (kneighbors requirement).
    alpha: Adam learning rate.
    epochs: number of passes; each pass runs ceil(len(x) / batch_size) steps.

  Returns:
    (nll_epoch, var_signal, var_noise): per-step full-data NLL,
    exp(log_varf) and exp(log_vary).
  """
  theta1 = dict(theta)  # copy so the caller's dict is left untouched
  nll_gradient = jax.grad(Nll, argnums=0)
  n = len(x)
  num_batches = -(-n // batch_size)  # ceil(n / batch_size)

  # Precompute the batch_size nearest neighbours of every point once.
  neigh = NearestNeighbors(n_neighbors=batch_size, algorithm='kd_tree')
  neigh.fit(x)
  _, neigh_idx = neigh.kneighbors(x, batch_size)

  # The Adam state is created ONCE and carried across all steps. Re-creating
  # it every step would discard the moment estimates.
  trainable = ("log_varf", "log_vary")
  tx = optax.adam(alpha)
  opt_state = tx.init({name: theta1[name] for name in trainable})
  # Fixed rescaling applied to the log_varf gradient only, as in the
  # original experiment.
  varf_grad_scale = batch_size / (3 * jnp.log(batch_size))

  nll_epoch, var_signal, var_noise = [], [], []
  for _ in range(epochs):
    for k in range(num_batches):
      # Deterministic centres n-1, 0, 1, ..., num_batches-2 (same every epoch).
      center_idx = (k - 1) % n
      nn_batch_indices = neigh_idx[center_idx]
      nn_batch_X = x[nn_batch_indices].reshape(-1, 1)
      nn_batch_y = y[nn_batch_indices].reshape(-1)  # 1-D, as Nll expects

      grads = nll_gradient(theta1, nn_batch_X, nn_batch_y)
      step_grads = {
          "log_varf": grads["log_varf"] * varf_grad_scale,
          "log_vary": grads["log_vary"],
      }
      params = {name: theta1[name] for name in trainable}
      updates, opt_state = tx.update(step_grads, opt_state)
      theta1.update(optax.apply_updates(params, updates))

      nll_epoch.append(Nll(theta1, x.reshape(-1,), y.reshape(-1,)))
      var_signal.append(jnp.exp(theta1["log_varf"]))
      var_noise.append(jnp.exp(theta1["log_vary"]))

  print(Nll(theta1, x.reshape(-1,), y.reshape(-1,)))
  print(jnp.exp(theta1["log_scale"]), jnp.exp(theta1["log_varf"]), jnp.exp(theta1["log_vary"]))

  return nll_epoch, var_signal, var_noise


In [None]:
# ---- Experiment configuration ----------------------------------------------
batch_size = 64
N = 1024
N_RUNS = 10     # independent Y draws per starting point (torch seeds 1..N_RUNS)
N_EPOCHS = 50

# (signal variance, noise variance) starting points, one per subplot. This one
# list drives both the SGD runs and the subplot titles, so they cannot disagree.
INIT_VARIANCES = [(5.0, 3.0), (2.5, 3.5), (2.5, 0.7)]
INIT_LENGTHSCALE = 0.01
alpha_arr = [0.01, 0.01, 0.01]  # Adam learning rate per subplot

# Hyperparameters used to simulate the data; also drawn as reference lines.
TRUE_OUTPUTSCALE = 1.0
TRUE_LENGTHSCALE = 0.01
TRUE_NOISE = 0.1


def make_theta(varf, vary):
  """Return a fresh log-hyperparameter dict (SGD may modify the dict it gets)."""
  return {"log_varf": jnp.log(varf), "log_vary": jnp.log(vary), "log_scale": jnp.log(INIT_LENGTHSCALE)}


def simulate_datasets(n_runs):
  """Simulate one shared X and n_runs GP draws of Y, returned as JAX arrays.

  X ~ Normal(0, scale=5) with shape (N, 1), drawn under torch seed 0.
  Y_t ~ MVN(0, outputscale * RBF(X, X) + noise * I), drawn under seed t + 1.
  """
  torch.manual_seed(0)
  x_dist = Normal(torch.tensor([0.0]), torch.tensor([5.0]))
  X = x_dist.sample((N,))
  K = ScaleKernel(RBFKernel())
  K.base_kernel.lengthscale = TRUE_LENGTHSCALE
  K.outputscale = TRUE_OUTPUTSCALE
  cov = K(X, X) + TRUE_NOISE * torch.eye(len(X))
  dist = MultivariateNormal(torch.zeros(N), cov.evaluate())

  X_jax = jnp.array(np.asarray(X))
  datasets = []
  for t in range(n_runs):
    torch.manual_seed(t + 1)
    datasets.append((X_jax, jnp.array(np.asarray(dist.sample()))))
  return datasets


# ---- Run SGD from each starting point and plot the variance traces ----------
datasets = simulate_datasets(N_RUNS)
fig, ax = plt.subplots(1, 3, figsize=(60, 12))

for j, (varf0, vary0) in enumerate(INIT_VARIANCES):
  for X, Y in tqdm(datasets):
    _, param1, param2 = SGD(X, Y, make_theta(varf0, vary0), batch_size, alpha_arr[j], N_EPOCHS)
    ax[j].plot(param1, linestyle='dashed', linewidth=1, markersize=8)
    ax[j].plot(param2, linestyle='dotted', linewidth=1, markersize=6)

  ax[j].axhline(y = TRUE_OUTPUTSCALE, color = 'black', linestyle = 'dashed')
  ax[j].axhline(y = TRUE_NOISE, color = 'black', linestyle = ':')
  ax[j].set_xlabel('Iteration (k)')
  ax[j].set_ylabel('O(k)')
  ax[j].set_title(f'({round(varf0, 2)},{round(vary0, 2)})')
  # NOTE(review): each run records N_EPOCHS * ceil(N / batch_size) = 800 steps,
  # but the ticks stop at 200 -- confirm this is intended.
  ax[j].set_xticks([0,50,100,150,200])
  ax[j].set_yticks([0,1,2])
  ax[j].grid()

plt.savefig('figure1_adam0.01.png')