In [None]:
import torch
import numpy as np
import pandas as pd
import gpytorch as gp
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Use the first GPU when one is available; otherwise `device` stays None.
device = 'cuda:0' if torch.cuda.is_available() else None

# The notebook draws random GP function samples and saves them as figures;
# seed the RNGs so Restart & Run All reproduces the same plots.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

sns.set(font_scale=2.0, style='whitegrid')

In [None]:
class KeOpsModel(gp.models.ExactGP):
    """Exact GP with a constant mean and a scaled KeOps kernel.

    The base kernel is a KeOps Matern kernel when ``nu`` is given and a KeOps
    RBF kernel otherwise. The Gaussian likelihood is built internally with its
    noise constrained to be greater than ``min_noise``.
    """

    def __init__(self, train_x, train_y, nu=None, min_noise=1e-4):
        """Build the likelihood, mean and kernel modules.

        Args:
            train_x: training inputs; must be contiguous.
            train_y: training targets.
            nu: smoothness passed to the Matern kernel; ``None`` selects RBF.
            min_noise: lower bound used for the likelihood noise constraint.
        """
        assert train_x.is_contiguous(), 'Need contiguous x for KeOps'

        noise_floor = gp.constraints.GreaterThan(min_noise)
        likelihood = gp.likelihoods.GaussianLikelihood(noise_constraint=noise_floor)
        super().__init__(train_x, train_y, likelihood)

        self.mean_module = gp.means.ConstantMean()
        if nu is None:
            self.base_covar_module = gp.kernels.keops.RBFKernel()
        else:
            self.base_covar_module = gp.kernels.keops.MaternKernel(nu=nu)
        self.covar_module = gp.kernels.ScaleKernel(self.base_covar_module)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        assert x.is_contiguous(), 'Need contiguous x for KeOps'

        return gp.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

In [None]:
df = pd.read_csv('snelson.csv')

# Inputs become a float32 column vector, targets a flat float32 vector.
inputs_all = torch.from_numpy(df['x'].values.reshape(-1, 1)).float()
targets_all = torch.from_numpy(df['y'].values).float()

# Move to the target device, then keep only the first 100 observations.
train_x = inputs_all.to(device)[:100]
train_y = targets_all.to(device)[:100]
train_x.shape, train_y.shape

In [None]:
def train(x, y, model, mll, optim):
    """Take a single optimiser step that maximises ``mll``.

    Args:
        x: inputs fed to ``model``.
        y: targets passed to ``mll`` alongside the model output.
        model: module to optimise; switched to training mode first.
        mll: callable ``mll(output, y)`` returning the scalar objective.
        optim: optimiser over the model parameters.

    Returns:
        dict with key ``'train/mll'`` holding the objective value (a Python
        float) evaluated before the parameter update.
    """
    model.train()
    optim.zero_grad()

    mll_value = mll(model(x), y)

    # Optimisers minimise, so back-propagate the negated objective.
    (-mll_value).backward()
    optim.step()

    return {'train/mll': mll_value.detach().item()}

def get_f_samples(model, n_points=200, n_samples=10, x_min=-1., x_max=7.):
    """Draw function samples from ``model`` on an evenly spaced 1-D grid.

    The model is switched to eval mode and left there.

    Args:
        model: callable GP model returning a distribution with ``.sample``.
        n_points: number of grid points between ``x_min`` and ``x_max``.
        n_samples: number of function draws.
        x_min: left end of the evaluation grid.
        x_max: right end of the evaluation grid.

    Returns:
        Tuple ``(x, f)``: ``x`` is the grid with shape ``(n_points, 1)`` and
        ``f`` is whatever ``pred.sample(torch.Size([n_samples]))`` returns.

    NOTE(review): in eval mode a GPyTorch ExactGP that holds training data
    presumably conditions on it, so these are posterior draws unless the
    caller enables ``gp.settings.prior_mode`` -- confirm against GPyTorch docs.
    """
    model.eval()

    x = torch.linspace(x_min, x_max, n_points).to(device).unsqueeze(-1)

    # The samples are only used for plotting, so skip building an autograd
    # graph for the predictive distribution.
    with torch.no_grad():
        pred = model(x)
        f = pred.sample(torch.Size([n_samples]))

    return x, f

In [None]:
model = KeOpsModel(train_x, train_y).to(device)
mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

# get_f_samples() puts the model in eval mode, where an ExactGP that holds
# training data returns the *posterior*. Enable prior mode so the samples
# plotted under the 'Prior' title really are draws from the GP prior rather
# than posterior draws under the untrained hyperparameters.
with gp.settings.prior_mode(True):
    prior_x, prior_f = get_f_samples(model)

In [None]:
optim = torch.optim.Adam(model.parameters(), lr=.1)

# Report the latest metrics on the progress bar instead of printing one line
# per step (printing inside a tqdm loop floods the output and breaks the bar);
# the full history stays available in `train_log`.
train_log = []
progress = tqdm(range(50))
for _ in progress:
    metrics = train(train_x, train_y, model, mll, optim)
    train_log.append(metrics)
    progress.set_postfix(metrics)

post_x, post_f = get_f_samples(model)

In [None]:
def plot_fns(x, f):
    """Plot each sampled function as its own coloured line.

    Args:
        x: grid tensor with ``n_points`` elements in total, e.g. shape
            ``(n_points, 1)``.
        f: tensor of shape ``(n_samples, n_points)``; row ``i`` holds the
            values of sample ``i`` on the grid.

    Returns:
        Tuple ``(fig, ax)`` of the created matplotlib figure and axes.

    Raises:
        ValueError: if ``f`` is not 2-D or ``x`` and ``f`` disagree on the
            number of grid points.
    """
    if f.dim() != 2:
        raise ValueError(f'f must be 2-D (n_samples, n_points), got shape {tuple(f.shape)}')

    # Sizes come from the data instead of being fixed at 10 samples / 200 points.
    n_samples, n_points = f.shape
    if x.numel() != n_points:
        raise ValueError(f'x has {x.numel()} points but f has {n_points} per sample')

    # One device-to-host transfer per tensor instead of one `.item()` call per
    # element; float64 matches what `.item()` produced in the long-form frame.
    xs = x.detach().reshape(-1).double().cpu().numpy()
    fs = f.detach().double().cpu().numpy()

    # Long-form frame: rows are ordered sample by sample, grid point by grid point.
    viz_data = pd.DataFrame({
        'id': np.repeat(np.arange(n_samples), n_points),
        'x': np.tile(xs, n_samples),
        'y': fs.reshape(-1),
    })

    fig, ax = plt.subplots(figsize=(11, 7))
    sns.lineplot(ax=ax, data=viz_data, x='x', y='y', hue='id', legend=False,
                 palette=sns.color_palette('husl', n_samples))

    return fig, ax

In [None]:
# Training observations overlaid on both figures.
obs_x = train_x.squeeze(-1).cpu().numpy()
obs_y = train_y.cpu().numpy()

fig_prior, ax_prior = plot_fns(prior_x, prior_f)
fig_post, ax_post = plot_fns(post_x, post_f)

# (figure, axes, title, colour of the data points, output file)
panels = [
    (fig_prior, ax_prior, 'Prior', 'black', 'prior.pdf'),
    (fig_post, ax_post, 'Posterior', 'gray', 'post.pdf'),
]
for panel_fig, panel_ax, panel_title, point_color, out_path in panels:
    sns.scatterplot(ax=panel_ax, x=obs_x, y=obs_y, color=point_color)
    panel_ax.set_title(panel_title)
    # Same y ticks on both panels so they can be compared directly.
    panel_ax.set_yticks(np.arange(-2, 1.1))
    panel_fig.savefig(out_path, bbox_inches='tight')