In [None]:
import os
import sys

import gpytorch as gp
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

# Put the parent directory first on the import path, presumably so local
# packages such as `gpytorch_lattice_kernel` resolve when the notebook is
# run from its own directory.
if os.path.abspath('..') not in sys.path:
    sys.path.insert(0, os.path.abspath('..'))

# Fall back to an explicit CPU device; `.to('cpu')` works everywhere `.to(None)` did.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# torch.cuda.set_device raises for non-CUDA devices (e.g. None on a CPU-only
# machine), so only pin the CUDA device when one is actually available.
if torch.cuda.is_available():
    torch.cuda.set_device(device)

In [None]:
from gpytorch_lattice_kernel import MaternLattice, RBFLattice

class SimplexGPModel(gp.models.ExactGP):
    """Exact GP with a constant mean and a scaled lattice kernel.

    The base kernel comes from ``gpytorch_lattice_kernel``: ``RBFLattice`` when
    ``nu`` is None, otherwise ``MaternLattice`` with the given smoothness ``nu``.
    ``order`` is forwarded to whichever lattice kernel is built.

    Args:
        train_x, train_y: training data handed to ``ExactGP``.
        nu: Matern smoothness, or None for the RBF lattice kernel.
        order: lattice order forwarded to the kernel constructor.
        min_noise: lower bound on the Gaussian likelihood noise.
    """

    def __init__(self, train_x, train_y, nu=None, order=1, min_noise=1e-4):
        noise_floor = gp.constraints.GreaterThan(min_noise)
        likelihood = gp.likelihoods.GaussianLikelihood(noise_constraint=noise_floor)
        super().__init__(train_x, train_y, likelihood)

        # Registration order (mean, base kernel, scaled kernel) is kept so that
        # parameters() / state_dict() ordering is unchanged.
        self.mean_module = gp.means.ConstantMean()
        if nu is None:
            self.base_covar_module = RBFLattice(order=order)
        else:
            self.base_covar_module = MaternLattice(nu=nu, order=order)
        self.covar_module = gp.kernels.ScaleKernel(self.base_covar_module)

    def forward(self, x):
        """Return the GP prior at ``x`` as a MultivariateNormal."""
        return gp.distributions.MultivariateNormal(self.mean_module(x),
                                                   self.covar_module(x))


class KeOpsModel(gp.models.ExactGP):
    """Exact GP baseline built on GPyTorch's KeOps kernels.

    Uses a constant mean and a scaled ``gp.kernels.keops`` kernel: ``RBFKernel``
    when ``nu`` is None, otherwise ``MaternKernel(nu=nu)``. Inputs must be
    contiguous, which is asserted both at construction and in ``forward``.

    Args:
        train_x, train_y: training data handed to ``ExactGP``.
        nu: Matern smoothness, or None for the RBF kernel.
        min_noise: lower bound on the Gaussian likelihood noise.
    """

    def __init__(self, train_x, train_y, nu=None, min_noise=1e-4):
        assert train_x.is_contiguous(), 'Need contiguous x for KeOps'

        noise_floor = gp.constraints.GreaterThan(min_noise)
        super().__init__(
            train_x, train_y,
            gp.likelihoods.GaussianLikelihood(noise_constraint=noise_floor))

        self.mean_module = gp.means.ConstantMean()
        if nu is None:
            self.base_covar_module = gp.kernels.keops.RBFKernel()
        else:
            self.base_covar_module = gp.kernels.keops.MaternKernel(nu=nu)
        self.covar_module = gp.kernels.ScaleKernel(self.base_covar_module)

    def forward(self, x):
        """Return the GP prior at ``x``; ``x`` must be contiguous."""
        assert x.is_contiguous(), 'Need contiguous x for KeOps'

        return gp.distributions.MultivariateNormal(self.mean_module(x),
                                                   self.covar_module(x))

In [None]:
# from experiments.utils import prepare_dataset

# data_iter = prepare_dataset('precipitation3d_all', uci_data_dir=None, device=device)
# _, train_x, train_y = next(data_iter)

## Toy Model on the Snelson Dataset

In [None]:
df = pd.read_csv('snelson.csv')
# Inputs become an (N, 1) column, targets a flat (N,) vector; both float32.
train_x = torch.from_numpy(df['x'].to_numpy()[:, None]).float().to(device)
train_y = torch.from_numpy(df['y'].to_numpy()).float().to(device)
train_x.shape, train_y.shape

In [None]:
def train(x, y, model, mll, optim, lanc_iter=100, pre_size=100,
          cg_tol=1e-2, log_grads=True):
    """Run a single optimisation step on the GP's parameters.

    Args:
        x, y: training inputs and targets passed to ``model`` / ``mll``.
        model: GP model; put in train mode before the forward pass.
        mll: objective; the step minimises ``-mll(model(x), y)``.
        optim: optimizer to step (expected to own ``model``'s parameters).
        lanc_iter: value for ``gp.settings.max_root_decomposition_size``.
        pre_size: value for ``gp.settings.max_preconditioner_size``.
        cg_tol: value for ``gp.settings.cg_tolerance``.
        log_grads: if True (default), print every parameter and its gradient
            after ``backward()`` so they can be compared against finite
            differences later in the notebook.

    Returns:
        dict with key ``'train/mll'``: the objective value evaluated *before*
        the optimizer step.
    """
    model.train()
    optim.zero_grad()

    with gp.settings.cg_tolerance(cg_tol), \
         gp.settings.max_preconditioner_size(pre_size), \
         gp.settings.max_root_decomposition_size(lanc_iter):
        output = model(x)
        loss = -mll(output, y)
        loss.backward()

        if log_grads:
            # Gradients here are evaluated at the pre-step parameters.
            for k, p in enumerate(model.parameters()):
                print(f'[{k}] [{p}] --> {p.grad}')

        optim.step()

    return {'train/mll': -loss.detach().item()}

In [None]:
from copy import deepcopy

# Lattice-kernel toy model; swap in the KeOps baseline to compare.
toy_model = SimplexGPModel(train_x, train_y).to(device)
# toy_model = KeOpsModel(train_x, train_y).to(device)
toy_mll = gp.mlls.ExactMarginalLogLikelihood(toy_model.likelihood, toy_model)
optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.1)

num_steps = 100
for i in tqdm(range(num_steps)):
    # Snapshot the parameters *before* this step. state_dict() tensors share
    # storage with the live parameters, so a deep copy is needed to freeze
    # them. After the loop, toy_state_dict holds the parameters at which the
    # last printed gradients were evaluated.
    with torch.no_grad():
        toy_state_dict = deepcopy(toy_model.state_dict())
    step_metrics = train(train_x, train_y, toy_model, toy_mll, optimizer)
    print(step_metrics)

## Autograd

Backpropagation through the lattice kernel currently uses an approximation to the gradient, which is itself computed as another set of filtering operations.

In [None]:
# with gp.settings.cg_tolerance(1e-2), \
#     gp.settings.max_preconditioner_size(100), \
#     gp.settings.max_root_decomposition_size(100):
#     model = BilateralGPModel(train_x, train_y, nu=1.5, order=1).to(device)
#     # model = KeOpsModel(train_x, train_y, nu=1.5).to(device)
#     model.base_covar_module.lengthscale = 1.0
#     mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

#     loss = -mll(model(train_x), train_y)
#     print(f'Loss: {loss.item()}')

#     loss.backward()

# for idx, p in enumerate(model.parameters()):
#     print(f'[{idx}] {p} ---> {p.grad}')

## Finite Difference

In [None]:
# import torch.nn.functional as F
# import torch.nn as nn

# def f(raw_ell):
#     with gp.settings.cg_tolerance(1e-2), \
#         gp.settings.max_preconditioner_size(100), \
#         gp.settings.max_root_decomposition_size(100), torch.no_grad():
#         model = BilateralGPModel(train_x, train_y, nu=1.5, order=1).to(device)
#         model.base_covar_module.raw_lengthscale = nn.Parameter(raw_ell)

#         mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
#         loss = -mll(model(train_x), train_y)

#     print(f'Raw Lengthscale: {model.base_covar_module.raw_lengthscale}; Loss: {loss}')
#     return loss

# raw_ell = model.base_covar_module.raw_lengthscale
# eps = 1e-4
# grads = []

# for _ in range(5):
#     grad = (f(raw_ell + eps) - f(raw_ell - eps)) / (2. * eps)
#     print(f'Finite Diff: {grad}')
#     grads.append(grad.item())
#     torch.cuda.empty_cache()

# print(f'{np.mean(grads)} +/- {2 * np.std(grads)}')

## JVP Checks

In [None]:
# from gpytorch.kernels.keops import RBFKernel, MaternKernel
# from gpytorch_lattice_kernel import MaternLattice, RBFLattice

# # K_gt = MaternKernel(nu=1.5).to(device)
# # f_gt = lambda x, y: K_gt(x, x) @ y

# K_lattice = RBFLattice(order=1).to(device)
# f_lattice = lambda x, y: (K_lattice(x, x) @ y).sum()
# train_x = torch.randn(5, 1).to(device).requires_grad_(True)
# train_y = torch.rand(5, 1).to(device).requires_grad_(False)
# torch.autograd.gradcheck(f_lattice, (train_x, train_y), eps=1e-4, rtol=1e-2, atol=1e-2)

## Toy Finite-Diff Gradients at Convergence

In [None]:
import torch.nn as nn

def f(raw_ell, state_dict, nu=None, order=1):
    """Negative exact MLL of the lattice GP at a given raw lengthscale.

    Rebuilds a ``SimplexGPModel``, loads ``state_dict`` into it, overrides the
    base kernel's raw lengthscale with ``raw_ell`` and evaluates the loss
    under the same GPyTorch settings as ``train``.

    Args:
        raw_ell: tensor to install as ``base_covar_module.raw_lengthscale``.
        state_dict: parameters previously saved from a ``SimplexGPModel``.
        nu, order: kernel configuration. These MUST match the model that
            produced ``state_dict``; the defaults match
            ``toy_model = SimplexGPModel(train_x, train_y)`` (RBF lattice).
            The original code built a Matern-1.5 model here, so the finite
            differences were taken on a different kernel than the trained one.

    Returns:
        The scalar loss tensor ``-mll(model(train_x), train_y)``.
    """
    with gp.settings.cg_tolerance(1e-2), \
        gp.settings.max_preconditioner_size(100), \
        gp.settings.max_root_decomposition_size(100), torch.no_grad():
        # `BilateralGPModel` does not exist in this notebook; SimplexGPModel is
        # the model class that produced the state dict being loaded.
        model = SimplexGPModel(train_x, train_y, nu=nu, order=order).to(device)
        model.load_state_dict(state_dict)
        model.base_covar_module.raw_lengthscale = nn.Parameter(raw_ell)

        print(list(model.parameters()))

        mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
        loss = -mll(model(train_x), train_y)

    print(f'Raw Lengthscale: {model.base_covar_module.raw_lengthscale}; Loss: {loss}')
    return loss

with torch.no_grad():
    # Raw lengthscale from the snapshot taken just before the last optimizer
    # step, i.e. where the last printed autograd gradients were evaluated.
    raw_ell = toy_state_dict['covar_module.base_kernel.raw_lengthscale']
    eps = 1e-5
    n_repeats = 5
    grads = []

    # Central difference. Divide by the step that is actually representable
    # in raw_ell's dtype (float32 by default) rather than the nominal 2 * eps.
    # The subtraction of the two nearby values is exact. The perturbed points
    # are loop-invariant, so build them once.
    # .item() assumes a single (non-ARD) lengthscale, as for the 1-D inputs.
    ell_plus = raw_ell + eps
    ell_minus = raw_ell - eps
    step = (ell_plus - ell_minus).item()

    for _ in range(n_repeats):
        grad = (f(ell_plus, toy_state_dict) - f(ell_minus, toy_state_dict)) / step
        print(f'Finite Diff: {grad}')
        grads.append(grad.item())
        torch.cuda.empty_cache()

    print(f'{np.mean(grads)} +/- {2 * np.std(grads)}')