In [None]:
import sys
import os
from collections import defaultdict


# Make modules in the notebook's parent directory importable (presumably where the
# local `bi_gp` package lives -- confirm). Guarded so re-running the cell does not
# push duplicate entries onto sys.path.
_parent_dir = os.path.abspath('..')
if _parent_dir not in sys.path:
    sys.path.insert(0, _parent_dir)

In [None]:
from tqdm.auto import tqdm
import torch
import gpytorch as gp
import altair as alt
import pandas as pd
import numpy as np

from bi_gp.bilateral_kernel import BilateralKernel

In [None]:
class ExactGPModel(gp.models.ExactGP):
    """Exact GP baseline: zero prior mean and a single-lengthscale RBF kernel.

    A fresh ``GaussianLikelihood`` is built per instance and handed to
    ``ExactGP.__init__``, which exposes it as ``self.likelihood``.
    """

    def __init__(self, train_x, train_y):
        super().__init__(train_x, train_y, gp.likelihoods.GaussianLikelihood())
        # No ScaleKernel and no ARD: k(x, x) is fixed at 1 and one lengthscale is
        # shared by every input dimension.
        self.mean_module = gp.means.ZeroMean()
        self.covar_module = gp.kernels.RBFKernel()

    def forward(self, x):
        """Return the prior N(mean(x), K(x, x)) evaluated at ``x``."""
        return gp.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
      
class BilateralGPModel(gp.models.ExactGP):
    """GP regression model whose covariance is ``BilateralKernel`` (from ``bi_gp``).

    Uses a zero prior mean and a fresh ``GaussianLikelihood`` per instance, which
    ``ExactGP.__init__`` exposes as ``self.likelihood``.
    """

    def __init__(self, train_x, train_y):
        super().__init__(train_x, train_y, gp.likelihoods.GaussianLikelihood())
        # Kernel constructed with its defaults (no ScaleKernel wrapper, no ARD).
        self.mean_module = gp.means.ZeroMean()
        self.covar_module = BilateralKernel()

    def forward(self, x):
        """Return the prior N(mean(x), K(x, x)) evaluated at ``x``."""
        return gp.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


class SKIPGPModel(gp.models.ExactGP):
    """SKIP GP model: a product of per-dimension grid-interpolated RBF kernels.

    A single RBF kernel is wrapped in a 1-D ``GridInterpolationKernel`` with
    ``grid_size`` grid points, and ``ProductStructureKernel`` applies that wrapped
    kernel to each of the ``train_x.size(-1)`` input dimensions and multiplies the
    results. Because the same RBF module is reused, all dimensions share one
    lengthscale.

    Args:
        train_x: training inputs; the last dimension is the number of input features.
        train_y: training targets.
        grid_size: number of grid points for each 1-D interpolation grid.
    """

    def __init__(self, train_x, train_y, grid_size):
        super().__init__(train_x, train_y, gp.likelihoods.GaussianLikelihood())

        self.mean_module = gp.means.ZeroMean()
        # Registered before `covar_module` so the shared RBF parameters are reported
        # by named_parameters() as `base_covar_module.*`.
        self.base_covar_module = gp.kernels.RBFKernel()
        per_dim_kernel = gp.kernels.GridInterpolationKernel(
            self.base_covar_module, grid_size=grid_size, num_dims=1
        )
        num_input_dims = train_x.size(-1)
        self.covar_module = gp.kernels.ProductStructureKernel(
            per_dim_kernel, num_dims=num_input_dims
        )

    def forward(self, x):
        """Return the prior N(mean(x), K(x, x)) evaluated at ``x``."""
        return gp.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )



In [None]:
def train(x, y, model, mll, optim):
    """Take one full-batch optimisation step that maximises ``mll``.

    Args:
        x, y: training inputs and targets (the whole batch is used every step).
        model: module to optimise; switched to training mode.
        mll: callable ``mll(output, y)`` returning the objective to maximise.
        optim: optimiser over ``model``'s parameters.

    Returns:
        ``{'train/ll': float}``, the objective evaluated *before* this step's update.
    """
    model.train()
    optim.zero_grad()

    # Optimisers minimise, so descend on the negated objective.
    neg_objective = -mll(model(x), y)
    neg_objective.backward()
    optim.step()

    return {'train/ll': -(neg_objective.detach().item())}


def test(x, y, model, lanc_iter=100, pre_size=0):
    model.eval()

    with torch.no_grad():
#        gp.settings.max_preconditioner_size(pre_size), \
#        gp.settings.max_root_decomposition_size(lanc_iter), \
#        gp.settings.fast_pred_var():
        preds = model(x)

        pred_y = model.likelihood(model(x))
        rmse = (pred_y.mean - y).pow(2).mean(0).sqrt()

    return { 'test/rmse': rmse.item() }


def train_util(model, x, y, lr=0.1, epochs=100):
    """Fit ``model`` by maximising the exact marginal log-likelihood with Adam.

    Args:
        model: exact GPyTorch model exposing ``.likelihood``.
        x, y: training inputs and targets (full batch every step).
        lr: Adam learning rate.
        epochs: number of optimisation steps.

    Returns:
        The metrics dict from the last step (``{'train/ll': float}``, measured
        before that step's update), or ``{}`` when ``epochs`` is 0 or negative.
    """
    mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    # Pre-initialise so that a zero-step run returns an empty dict instead of
    # raising UnboundLocalError.
    train_dict = {}
    for _ in tqdm(range(epochs), leave=False):
        train_dict = train(x, y, model, mll, optim)

    return train_dict

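## Toy 4-D GP

Synthetic regression task: 2000 inputs drawn uniformly from $[-1, 1]^4$. The targets are one draw from a zero-mean GP prior with a scaled RBF kernel, plus Gaussian noise with standard deviation 0.1. A random half of the points is used for training. The exact, bilateral, and SKIP GPs are each fit 10 times, and their learned hyperparameters and final training log-likelihoods are compared below.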

In [None]:
# Seed the global torch RNG so the synthetic dataset (and every later random draw in
# this notebook) is reproducible on Restart & Run All.
torch.manual_seed(0)

n = 2000
d = 4
x = 2. * torch.rand(n, d) - 1.  # n points uniform in [-1, 1]^d

with torch.no_grad():
  # Ground-truth kernel. Assign through the constrained properties: GPyTorch keeps
  # unconstrained `raw_*` parameters that are mapped through a softplus, so writing
  # log(value) into `raw_outputscale` / `raw_lengthscale` would NOT yield an
  # outputscale of 1.0 / lengthscale of 1.5.
  covar_module = gp.kernels.ScaleKernel(gp.kernels.RBFKernel())
  covar_module.outputscale = 1.0
  covar_module.base_kernel.lengthscale = 1.5

  # Zero-mean GP prior over all n inputs.
  covar = gp.distributions.MultivariateNormal(torch.zeros(n), covariance_matrix=covar_module(x))

# Targets: one prior draw plus N(0, 0.1^2) noise; keep a random half for training.
rperm = torch.randperm(n)[:n//2]
train_x = x[rperm]
train_y = (covar.sample() + 0.1 * torch.randn(x.size(0)))[rperm]

In [None]:
# Run log keyed by parameter / metric name. Each training loop appends one entry per
# run for every name it logs, plus the model label under 'kind'.
results: "defaultdict[str, list]" = defaultdict(list)

### Exact GP

In [None]:
for _ in tqdm(range(10)):  # 10 fits, each from a freshly constructed model
  egp = ExactGPModel(train_x, train_y).float()

  train_dict = train_util(egp, train_x, train_y)

  # Log detached snapshots, not the live Parameters: a stored Parameter keeps its
  # .grad and would silently change if the global `egp` were trained further later.
  for name, p in egp.named_parameters():
    results[name].append(p.detach().clone())
  results['kind'].append('Exact GP')
  for k, v in train_dict.items():
    results[k].append(v)

### Bilateral GP

In [None]:
for _ in tqdm(range(10)):  # 10 fits, each from a freshly constructed model
  bigp = BilateralGPModel(train_x, train_y).float()

  # Cap the root-decomposition size at 50 while training this model.
  with gp.settings.max_root_decomposition_size(50):
    train_dict = train_util(bigp, train_x, train_y)

  # Log detached snapshots, not the live Parameters: a stored Parameter keeps its
  # .grad and would silently change if the global `bigp` were trained further later.
  for name, p in bigp.named_parameters():
    results[name].append(p.detach().clone())
  results['kind'].append('Bilateral GP')

  for k, v in train_dict.items():
    results[k].append(v)

In [None]:
for _ in tqdm(range(10)):  # 10 fits, each from a freshly constructed model
  skipgp = SKIPGPModel(train_x, train_y, 100).float()  # 100 grid points per dimension

  # Cap the root-decomposition size at 50 while training this model.
  with gp.settings.max_root_decomposition_size(50):
    train_dict = train_util(skipgp, train_x, train_y)

  # Log detached snapshots, not the live Parameters: a stored Parameter keeps its
  # .grad and would silently change if the global `skipgp` were trained further later.
  for name, p in skipgp.named_parameters():
    results[name].append(p.detach().clone())
  results['kind'].append('SKIP-GP')

  for k, v in train_dict.items():
    results[k].append(v)

In [None]:
# Show which parameter / metric names were logged; these are the keys indexed when
# building the plotting table.
results.keys()

In [None]:
# GPyTorch stores positive hyperparameters as unconstrained `raw_*` tensors and maps
# them through a softplus (not exp). With the default constraints used by the models:
#   noise       = softplus(raw_noise) + 1e-4   (GaussianLikelihood: GreaterThan(1e-4))
#   lengthscale = softplus(raw_lengthscale)    (kernels: Positive())
# NOTE(review): assumes BilateralKernel keeps the default Positive() lengthscale
# constraint -- confirm in bi_gp.bilateral_kernel.
# 'ls' rows are Exact and Bilateral (`covar_module.*`) followed by SKIP
# (`base_covar_module.*`), which lines up with 'kind' when the fitting cells run in order.
data = {
  'obs_noise': [(torch.nn.functional.softplus(v) + 1e-4).item()
                for v in results['likelihood.noise_covar.raw_noise']],
  'ls': [torch.nn.functional.softplus(v).item()
         for v in results['covar_module.raw_lengthscale'] + results['base_covar_module.raw_lengthscale']],
  'train/ll': results['train/ll'],
  'kind': results['kind'],
}

In [None]:
# Learned observation noise per model kind: bar = mean ± 1 std over runs, dot = mean.
# Both layers share one explicit x title so the layered chart does not merge two
# auto-generated titles.
plot_df = pd.DataFrame(data)
noise_title = 'Learned observation noise (mean ± 1 std over runs)'

error_bars = alt.Chart(plot_df).mark_errorbar(extent='stdev').encode(
  x=alt.X('obs_noise:Q', scale=alt.Scale(zero=False), title=noise_title),
  y=alt.Y('kind:N', title='Model')
)

points = alt.Chart(plot_df).mark_point(filled=True, color='black').encode(
  x=alt.X('obs_noise:Q', aggregate='mean', title=noise_title),
  y=alt.Y('kind:N', title='Model'),
)

(error_bars + points).properties(width=800, height=100)

In [None]:
# Learned lengthscale per model kind: bar = mean ± 1 std over runs, dot = mean.
# Both layers share one explicit x title so the layered chart does not merge two
# auto-generated titles.
plot_df = pd.DataFrame(data)
ls_title = 'Learned lengthscale (mean ± 1 std over runs)'

error_bars = alt.Chart(plot_df).mark_errorbar(extent='stdev').encode(
  x=alt.X('ls:Q', scale=alt.Scale(zero=False), title=ls_title),
  y=alt.Y('kind:N', title='Model')
)

points = alt.Chart(plot_df).mark_point(filled=True, color='black').encode(
  x=alt.X('ls:Q', aggregate='mean', title=ls_title),
  y=alt.Y('kind:N', title='Model'),
)

(error_bars + points).properties(width=800, height=100)

In [None]:
# Training marginal log-likelihood at the last optimisation step, per model kind:
# bar = mean ± 1 std over runs, dot = mean. Both layers share one explicit x title so
# the layered chart does not merge two auto-generated titles.
plot_df = pd.DataFrame(data)
ll_title = 'Final training marginal log-likelihood (mean ± 1 std over runs)'

error_bars = alt.Chart(plot_df).mark_errorbar(extent='stdev').encode(
  x=alt.X('train/ll:Q', scale=alt.Scale(zero=False), title=ll_title),
  y=alt.Y('kind:N', title='Model')
)

points = alt.Chart(plot_df).mark_point(filled=True, color='black').encode(
  x=alt.X('train/ll:Q', aggregate='mean', title=ll_title),
  y=alt.Y('kind:N', title='Model'),
)

(error_bars + points).properties(width=800, height=100)