In [2]:
try:
  import distrax
except ModuleNotFoundError:
  %pip install distrax
  import distrax
try:
  import jax
except ModuleNotFoundError:
  %pip install jax 
  import jax

import jax.numpy as jnp
try:
  import matplotlib.pyplot as plt
except ModuleNotFoundError:
  %pip install matplotlib 
  import matplotlib.pyplot as plt

jax.config.update("jax_enable_x64", True)

try:
  import GPy
except ModuleNotFoundError:
  %pip install GPy
  import GPy

try:
  from tqdm import tqdm
except ModuleNotFoundError:
  %pip install tqdm
  from tqdm import tqdm

try:
  import tinygp
except ModuleNotFoundError:
  %pip install tinygp
  import tinygp
  
try: 
  import jaxopt
except ModuleNotFoundError:
  %pip install jaxopt
  import jaxopt

import optax
from sklearn.neighbors import NearestNeighbors

try:
  from pyDOE import *
except ModuleNotFoundError:
  ! pip install pyDOE
  from pyDOE import *

try:
  from smt.sampling_methods import LHS
except ModuleNotFoundError:
   ! pip install smt
   from smt.sampling_methods import LHS

import pandas as pd
import numpy as np

import time
import sklearn
from tinygp import kernels, GaussianProcess
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from sklearn.metrics import mean_squared_error

In [3]:
# Report which platform JAX will run on ("cpu", "gpu" or "tpu").
# jax.default_backend() is the public API for this; the private
# jax.lib.xla_bridge module is deprecated/removed in newer JAX releases.
print(jax.default_backend())

gpu


In [4]:
dirname = "../Datasets/Levy"


def _load_split(csv_path):
  """Read one gzipped CSV, drop its first column, and return a squeezed jax array."""
  # presumably the first column is a saved row index — TODO confirm against the dataset
  frame = pd.read_csv(csv_path).iloc[:,1:]
  return jnp.asarray(frame.to_numpy()).squeeze()


X_train = _load_split(dirname+"/X_train.csv.gz")
X_test = _load_split(dirname+"/X_test.csv.gz")
y_train = _load_split(dirname+"/y_train.csv.gz")
y_test = _load_split(dirname+"/y_test.csv.gz")

print(X_train.shape, X_test.shape)

(6000, 4) (4000, 4)


In [None]:
def build_gp(theta_, x):
  """Construct a tinygp ``GaussianProcess`` over the inputs ``x``.

  ``theta_`` is a mapping with three entries:
    * ``"varf"``      -- multiplier applied to the ExpSquared kernel,
    * ``"len_scale"`` -- passed to ``kernels.ExpSquared`` as ``scale``,
    * ``"vary"``      -- passed to ``GaussianProcess`` as ``diag``.
  """
  signal_variance = theta_["varf"]
  base_kernel = kernels.ExpSquared(scale=theta_["len_scale"])
  scaled_kernel = signal_variance*base_kernel
  return GaussianProcess(scaled_kernel, x, diag=theta_["vary"])

def NLL(theta_, x, y):
  """Return the negated ``log_probability`` of ``y`` under ``build_gp(theta_, x)``."""
  log_prob = build_gp(theta_, x).log_probability(y)
  return -log_prob

In [None]:
# Initial parameters and optimisation settings
max_iters = int(100)
lr = 0.01
nll_iters = []
theta_init = {"varf": 1.,"vary": 1.,"len_scale": 1.}
batch_size = 16
# Number of points in each nearest-neighbour minibatch.
# NOTE(review): this differs from batch_size (16), which drives num_batches and
# the varf gradient scaling below — confirm which value is intended.
n_neighbors = 32

start = time.time()
# jit so the loss/gradient is compiled once per input signature instead of
# being re-traced on every optimisation step.
nll_gradient = jax.jit(jax.value_and_grad(NLL, argnums = 0))

# Adam over the whole parameter dict. The optimizer state is created ONCE,
# before the loops: re-initialising it on every step would reset the moment
# estimates so they could never accumulate.
tx1 = optax.adam(lr)
opt_state = tx1.init(theta_init)

## Nearest Neighbour calculation
X_train_np = np.asarray(X_train)
neigh = NearestNeighbors(n_neighbors=n_neighbors, algorithm='kd_tree')
neigh.fit(X_train_np)
_,neigh_idx = neigh.kneighbors(X_train_np, n_neighbors)

# Ceiling division: a trailing partial batch still counts as one batch.
num_batches = (len(X_train) + batch_size - 1) // batch_size

# Draw every minibatch centre up front from a single seeded key so each step
# gets a different (but reproducible) centre. minval=0 keeps the first
# training point eligible; maxval is exclusive.
key = jax.random.PRNGKey(42)
center_indices = np.asarray(
    jax.random.randint(key, (max_iters*num_batches,), 0, len(y_train))
)

step = 0
for i in tqdm(range(max_iters)):
  for k in range(num_batches):
    ## NN batches: one centre point together with its nearest neighbours.
    # Indexing with a scalar centre gives an (n_neighbors,) index vector, so
    # the batch is (n_neighbors, n_features) / (n_neighbors,) — the leading
    # axis is what the GP treats as the data-point axis.
    nn_batch_indices = neigh_idx[center_indices[step]]
    step += 1
    nn_batch_X = X_train[nn_batch_indices]
    nn_batch_y = y_train[nn_batch_indices]

    loss_nn,grads_nn = nll_gradient(theta_init, nn_batch_X, nn_batch_y)
    nll_iters.append(loss_nn)

    # Only the varf gradient is rescaled, by batch_size / (3 * log(batch_size)).
    grads_nn["varf"] = (batch_size*grads_nn["varf"])/(3*jnp.log(batch_size))
    updates,opt_state = tx1.update(grads_nn, opt_state)
    theta_init = optax.apply_updates(theta_init, updates)

end = time.time()
with open("train.log", "a") as f:
  f.write("Time taken for training model:SGDgp_levy"+"_"+str(max_iters)+".pt: "+str(end-start)+"\n")
print(loss_nn)

# Loss curve: one point per minibatch step.
fig, ax = plt.subplots(1,1,figsize=(14,6))
ax.plot([float(loss) for loss in nll_iters])
ax.set_xlabel("Minibatch step")
ax.set_ylabel("Negative log-likelihood")
fig.savefig("levy_loss.png")

print(theta_init["len_scale"], theta_init["varf"], theta_init["vary"])

In [None]:
# Condition the GP (using the trained hyperparameters) on the full training
# set and take the predictive mean at the test inputs.
gp = build_gp(theta_init, X_train)
y_pred = gp.condition(y_train, X_test).gp.loc.reshape(y_test.shape)

# One bulk device-to-host transfer instead of a per-element .item() loop.
y_pred_np = np.asarray(y_pred)
# NOTE(review): this writes to the absolute path "/results" while the log below
# uses the relative "./logs" — confirm the absolute path is intended.
pd.DataFrame(y_pred_np).to_csv("/results/table_levy"+"_"+str(max_iters)+".csv")

# scikit-learn's signature is (y_true, y_pred); MSE is symmetric, so the value
# is unchanged, but the conventional order avoids confusion.
rmse = mean_squared_error(np.asarray(y_test), y_pred_np)**0.5
with open("./logs/test.log", "a") as f:
  f.write("RMSE on test data for model  SGDgp_levy"+"_"+str(max_iters)+".pt: "+str(rmse)+"\n")