In [None]:
import sys
import os


# Make the parent directory importable so `bi_gp` (imported below) resolves.
# NOTE(review): assumes the notebook is run from one level below the repo root.
_parent_dir = os.path.abspath('..')
if _parent_dir not in sys.path:
  sys.path.insert(0, _parent_dir)

In [None]:
from tqdm.auto import tqdm
import torch
import gpytorch as gp
import altair as alt
import pandas as pd
import numpy as np

from bi_gp.bilateral_kernel import BilateralKernel

## GP Models

In [None]:
class ExactGPModel(gp.models.ExactGP):
    """Exact GP regressor: constant mean and a scaled RBF kernel with one
    lengthscale per input dimension (ARD).

    A new ``GaussianLikelihood`` is created and handed to ``ExactGP``.
    """

    def __init__(self, train_x, train_y):
        n_dims = train_x.size(-1)
        super().__init__(train_x, train_y, gp.likelihoods.GaussianLikelihood())
        self.mean_module = gp.means.ConstantMean()
        rbf = gp.kernels.RBFKernel(ard_num_dims=n_dims)
        self.covar_module = gp.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at ``x``."""
        return gp.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


class SGPRModel(gp.models.ExactGP):
    """Sparse GP regression built on gpytorch's ``InducingPointKernel``.

    ``base_covar_module`` is a scaled ARD RBF kernel; ``covar_module`` wraps it
    with the given ``inducing_points`` (which become a trainable parameter of
    the wrapper) and shares the model's Gaussian likelihood.
    """

    def __init__(self, train_x, train_y, inducing_points):
        lik = gp.likelihoods.GaussianLikelihood()
        super().__init__(train_x, train_y, lik)
        self.mean_module = gp.means.ConstantMean()
        self.base_covar_module = gp.kernels.ScaleKernel(
            gp.kernels.RBFKernel(ard_num_dims=train_x.size(-1))
        )
        self.covar_module = gp.kernels.InducingPointKernel(
            self.base_covar_module,
            inducing_points=inducing_points,
            likelihood=lik,
        )

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at ``x``."""
        return gp.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


class KISSGPModel(gp.models.ExactGP):
    """KISS-GP: a scaled ARD RBF kernel evaluated through ``GridInterpolationKernel``.

    ``grid_size`` must be an ``int``; any other value (e.g. ``None``) makes the
    model pick a size with ``gp.utils.grid.choose_grid_size(train_x)``.
    """

    def __init__(self, train_x, train_y, grid_size):
        likelihood = gp.likelihoods.GaussianLikelihood()
        super().__init__(train_x, train_y, likelihood)

        # Fall back to gpytorch's heuristic when no explicit integer size is given.
        if not isinstance(grid_size, int):
            grid_size = gp.utils.grid.choose_grid_size(train_x)

        n_dims = train_x.size(-1)
        interp_kernel = gp.kernels.GridInterpolationKernel(
            gp.kernels.RBFKernel(ard_num_dims=n_dims), grid_size=grid_size, num_dims=n_dims
        )
        self.mean_module = gp.means.ConstantMean()
        self.covar_module = gp.kernels.ScaleKernel(interp_kernel)

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at ``x``."""
        return gp.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


class SKIPGPModel(gp.models.ExactGP):
    """SKIP-GP: a 1-D grid-interpolated, scaled RBF kernel combined across the
    ``train_x.size(-1)`` input dimensions by ``ProductStructureKernel``.
    """

    def __init__(self, train_x, train_y, grid_size):
        super().__init__(train_x, train_y, gp.likelihoods.GaussianLikelihood())

        self.mean_module = gp.means.ConstantMean()
        self.base_covar_module = gp.kernels.RBFKernel()
        one_dim_kernel = gp.kernels.ScaleKernel(
            gp.kernels.GridInterpolationKernel(self.base_covar_module, grid_size=grid_size, num_dims=1)
        )
        self.covar_module = gp.kernels.ProductStructureKernel(one_dim_kernel, num_dims=train_x.size(-1))

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at ``x``."""
        return gp.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


class BilateralGPModel(gp.models.ExactGP):
    """Exact GP with constant mean and a scaled ``BilateralKernel`` (from
    ``bi_gp.bilateral_kernel``), configured with one ARD dimension per input feature.
    """

    def __init__(self, train_x, train_y):
        n_dims = train_x.size(-1)
        super().__init__(train_x, train_y, gp.likelihoods.GaussianLikelihood())
        self.mean_module = gp.means.ConstantMean()
        self.covar_module = gp.kernels.ScaleKernel(BilateralKernel(ard_num_dims=n_dims))

    def forward(self, x):
        """Return the prior ``MultivariateNormal`` evaluated at ``x``."""
        return gp.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

### Utils

In [None]:
def train(x, y, model, mll, optim):
  """Take one optimisation step that maximises ``mll(model(x), y)``.

  Args:
    x: training inputs passed to ``model``.
    y: training targets passed to ``mll``.
    model: module put into train mode before the forward pass.
    mll: objective to maximise (its negation is back-propagated).
    optim: optimiser over ``model``'s parameters.

  Returns:
    dict with ``'train/ll'``: the objective value computed before the step.
  """
  model.train()
  optim.zero_grad()

  objective = mll(model(x), y)
  neg_objective = -objective  # optimisers minimise, so descend on the negation
  neg_objective.backward()
  optim.step()

  return {'train/ll': -neg_objective.detach().item()}


def test(x, y, model, lanc_iter=100, pre_size=0):
  model.eval()

  with torch.no_grad():
#        gp.settings.max_preconditioner_size(pre_size), \
#        gp.settings.max_root_decomposition_size(lanc_iter), \
#        gp.settings.fast_pred_var():
      preds = model(x)

      pred_y = model.likelihood(model(x))
      rmse = (pred_y.mean - y).pow(2).mean(0).sqrt()

  return {
    'test/rmse': rmse.item()
  }


def train_util(model, x, y, lr=0.1, epochs=200):
  """Fit ``model`` by maximising the exact marginal log-likelihood with Adam.

  Args:
    model: GP model exposing ``.likelihood``; trained in place.
    x, y: training inputs and targets.
    lr: Adam learning rate.
    epochs: number of full-batch optimisation steps.

  Returns:
    list of per-epoch metric dicts as returned by ``train`` (previously these
    were discarded, making convergence impossible to inspect).
  """
  mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
  optim = torch.optim.Adam(model.parameters(), lr=lr)

  history = []
  for _ in tqdm(range(epochs)):
    history.append(train(x, y, model, mll, optim))
  return history


def generate(model, device=None, start=-1., end=7., step=0.05):
  """Evaluate a model's predictive distribution on a regular 1-D input grid.

  Args:
    model: GP model exposing ``.eval()`` and ``.likelihood``; put in eval mode.
    device: device the grid is moved to (should match the model's device).
    start, end, step: grid passed to ``torch.arange`` (``end`` exclusive);
      defaults reproduce the original hard-coded range.

  Returns:
    DataFrame with columns ``x`` (grid), ``y`` (predictive mean) and
    ``y_hi`` / ``y_lo`` (mean +/- 2 predictive standard deviations), where the
    predictive is ``model.likelihood(model(x))``.
  """
  model.eval()

  x = torch.arange(start, end, step).unsqueeze(-1).to(device)

  with torch.no_grad():
    # One forward pass only; the previous version also ran an unused model(x).
    pred_y = model.likelihood(model(x))

  # Transfer the predictive moments to host memory once and reuse them.
  mean = pred_y.mean.cpu().numpy()
  std = pred_y.variance.sqrt().cpu().numpy()

  return pd.DataFrame({
    'x': x.squeeze(-1).cpu().numpy(),
    'y': mean,
    'y_hi': mean + 2. * std,
    'y_lo': mean - 2. * std,
  })


def chart_util(model, color, device=None):
  """Build an Altair layer of a model's predictions from ``generate``.

  The predictive mean (``y``) is a dashed, fully opaque line; the ``y_lo`` and
  ``y_hi`` bounds are semi-transparent solid lines, all drawn in ``color``.
  """
  frame = generate(model, device=device)

  mean_line = alt.Chart(frame).mark_line(color=color, opacity=1.0, strokeDash=[5, 5]).encode(x='x', y='y')

  # Re-marking the mean chart swaps in a plain line mark for each bound.
  lower, upper = [
    mean_line.mark_line(color=color, opacity=0.5).encode(x='x', y=column)
    for column in ('y_lo', 'y_hi')
  ]
  return mean_line + lower + upper

## Snelson 1-D Dataset

In [None]:
# Use the GPU when torch can see one; otherwise ``None`` so `.to(device)` keeps tensors where they are.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = None

In [None]:
# Reproducible subsample of the Snelson dataset.
SEED = 0        # seeds the subsample below and torch's RNG (used e.g. for inducing-point init)
N_POINTS = 100  # number of training points kept from the CSV

dpath = 'snelson.csv'

torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)

snel_raw = pd.read_csv(dpath)
# Slicing a permutation (rather than DataFrame.sample) keeps all rows if the file has fewer than N_POINTS.
snel = snel_raw.iloc[rng.permutation(len(snel_raw))[:N_POINTS]]
snel_x = torch.from_numpy(snel.x.to_numpy()).unsqueeze(-1).float().to(device)
snel_y = torch.from_numpy(snel.y.to_numpy()).float().to(device)

snelc = alt.Chart(snel).mark_circle(color='black',opacity=0.6).encode(x='x', y='y')

### Exact GP

In [None]:
# Exact GP reference fit: data points, posterior mean line, and a shaded y_lo..y_hi band.
egp = ExactGPModel(snel_x, snel_y).float().to(device)
train_util(model=egp, x=snel_x, y=snel_y)

egp_gen = generate(egp, device=device)
snelc_egp = (
  alt.Chart(egp_gen)
  .mark_line(color='black', opacity=0.5)
  .encode(x='x', y='y')
)
snelc_egp_err = (
  snelc_egp
  .mark_area(opacity=0.2, color='grey')
  .encode(y='y_lo', y2='y_hi')
)

egp_chart = snelc + snelc_egp + snelc_egp_err

### Bilateral GP

In [None]:
# Fit the bilateral-kernel GP on the same subsample and chart it in red.
bigp = BilateralGPModel(snel_x, snel_y).float().to(device)
train_util(model=bigp, x=snel_x, y=snel_y)
bigp_chart = chart_util(model=bigp, color='red', device=device)

## Comparisons

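**NOTE**: The comparison cells below are optional. Enable (convert from raw to code) and run only the baselines you want. Each comparison overlays the baseline on the exact-GP fit and shows it side by side with the Bilateral GP.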

### Sparse GP (Titsias)

In [None]:
# Titsias sparse-GP baseline. Inducing points start uniformly on [0, 6), the same
# range used to clamp them for plotting in the next cell.
# NOTE(review): 500 inducing points exceeds the 100 training points, so this is not
# actually sparse — consider a smaller count.
sgp_init_inducing = (6. * torch.rand(500, 1)).float()
# `.to(device)` is required: snel_x / snel_y already live on `device`, and the
# inducing points are a model parameter that moves with the model.
sgp = SGPRModel(snel_x, snel_y, sgp_init_inducing).float().to(device)

train_util(sgp, snel_x, snel_y)

# Pass `device` so the plotting grid matches the model/training-data device.
sgp_chart = chart_util(sgp, color='blue', device=device)

In [None]:
# Learned inducing-point locations (clamped to [0, 6] for display), drawn as points at y = -2.5.
# `.cpu()` is required before `.numpy()` when the model lives on the GPU.
sgp_inducing_x = sgp.covar_module.inducing_points.detach().squeeze(-1).clamp(0.0, 6.0).cpu().numpy()
ind_chart = alt.Chart(pd.DataFrame({ 'x': sgp_inducing_x,
                                     'y': -2.5,  }))\
                      .mark_circle(color='blue').encode(x='x', y='y')
(egp_chart + sgp_chart + ind_chart).properties(title='Sparse GP') |\
(egp_chart + bigp_chart).properties(title='Bilateral GP')

### KISS-GP

In [None]:
# KISS-GP baseline with an explicit 30-point interpolation grid.
kgp = KISSGPModel(snel_x, snel_y, grid_size=30).float().to(device)
train_util(model=kgp, x=snel_x, y=snel_y)
kgp_chart = chart_util(model=kgp, color='blue', device=device)

In [None]:
# The ScaleKernel wraps the GridInterpolationKernel; plot its first grid axis at y = -2.5.
kgp_grid_x = kgp.covar_module.base_kernel.grid[0].cpu().numpy()
kgp_grid_chart = (
  alt.Chart(pd.DataFrame({'x': kgp_grid_x, 'y': -2.5}))
  .mark_circle(color='blue')
  .encode(x='x', y='y')
)

(egp_chart + kgp_chart + kgp_grid_chart).properties(title='KISS-GP') |\
(egp_chart + bigp_chart).properties(title='Bilateral GP')

### SKIP-GP

In [None]:
# SKIP-GP baseline with a 30-point grid per input dimension.
skipgp = SKIPGPModel(snel_x, snel_y, grid_size=30).float().to(device)
train_util(model=skipgp, x=snel_x, y=snel_y)
skipgp_chart = chart_util(model=skipgp, color='blue', device=device)

In [None]:
# ProductStructureKernel -> ScaleKernel -> GridInterpolationKernel; plot that grid at y = -2.5.
skipgp_grid_x = skipgp.covar_module.base_kernel.base_kernel.grid[0].cpu().numpy()
skipgp_grid_chart = (
  alt.Chart(pd.DataFrame({'x': skipgp_grid_x, 'y': -2.5}))
  .mark_circle(color='blue')
  .encode(x='x', y='y')
)

(egp_chart + skipgp_chart + skipgp_grid_chart).properties(title='SKIP-GP') |\
(egp_chart + bigp_chart).properties(title='Bilateral GP')