In [None]:
import torch
import numpy as np
import pandas as pd
import gpytorch as gp
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Target the first CUDA device when torch reports one; otherwise `device` stays None.
device = None
if torch.cuda.is_available():
    device = 'cuda:0'

# Global seaborn look for every figure below: white grid, doubled font scale.
sns.set(style='whitegrid', font_scale=2.0)

In [None]:
class ExactGPModel(gp.models.ExactGP):
    """Exact GP with a zero mean, a scaled RBF kernel and a Gaussian likelihood.

    The likelihood is created inside the constructor and handed to the
    ``ExactGP`` base class together with the training inputs and targets.
    """

    def __init__(self, train_x, train_y):
        """Build the model around the training data ``train_x`` / ``train_y``."""
        super().__init__(train_x, train_y, gp.likelihoods.GaussianLikelihood())
        self.mean_module = gp.means.ZeroMean()
        self.covar_module = gp.kernels.ScaleKernel(gp.kernels.RBFKernel())

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel modules at ``x``."""
        return gp.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )

In [None]:
# Read the dataset and keep its first ten rows as training points.
df = pd.read_csv('snelson.csv')
x_column = df.x.values[:, np.newaxis]   # add a trailing feature axis to the inputs
y_column = df.y.values
train_x = torch.from_numpy(x_column).float().to(device)[:10]
train_y = torch.from_numpy(y_column).float().to(device)[:10]

# Standardise inputs and targets along dim 0; the 1e-6 keeps the divisor away from zero.
x_center = train_x.mean(dim=0, keepdim=True)
x_spread = train_x.std(dim=0, keepdim=True) + 1e-6
train_x = (train_x - x_center) / x_spread

y_center = train_y.mean(dim=0, keepdim=True)
y_spread = train_y.std(dim=0, keepdim=True) + 1e-6
train_y = (train_y - y_center) / y_spread

train_x.shape, train_y.shape

In [None]:
model = ExactGPModel(train_x, train_y).to(device)

# 200 evenly spaced evaluation inputs on [-3, 3], shaped as a single column.
all_x = torch.linspace(-3., 3., 200).unsqueeze(-1).to(device)

with torch.no_grad():
    # forward() is called directly, so only the model's mean and kernel
    # modules are evaluated on the grid; the result is kept as `prior`.
    prior = model.forward(all_x)

In [None]:
def train(x, y, model, mll, optim):
    """Take one optimiser step that increases ``mll(model(x), y)``.

    The model is switched to training mode, gradients are cleared, the
    negated objective is back-propagated and ``optim.step()`` is applied.

    Returns:
        dict with the single key ``'train/mll'`` holding the objective value
        (as a Python float) computed before the step.
    """
    model.train()
    optim.zero_grad()

    objective = mll(model(x), y)
    # Optimisers minimise, so back-propagate the negated objective.
    (-objective).backward()
    optim.step()

    return {'train/mll': objective.detach().item()}

optim = torch.optim.Adam(model.parameters(), lr=.1)
mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

# Report the latest metric on the progress bar itself instead of print()-ing
# one dict per iteration, which floods the cell output and breaks up the bar.
progress = tqdm(range(50))
for step in progress:
    progress.set_postfix(train(train_x, train_y, model, mll, optim))

# Switch to evaluation mode and evaluate the trained model on the grid,
# without tracking gradients.
model.eval()
with torch.no_grad():
    posterior = model(all_x)

In [None]:
def plot_fns(x, f):
    """Draw each row of ``f`` as one line over the shared inputs ``x``.

    Args:
        x: tensor holding the ``n_points`` input locations; it is flattened,
           so both ``(n_points, 1)`` and ``(n_points,)`` are accepted.
        f: tensor of shape ``(n_fns, n_points)``; row ``i`` holds the values
           of function ``i`` at the locations in ``x``.

    Returns:
        ``(fig, ax)`` of the newly created 11x7 figure.

    Raises:
        ValueError: if ``f`` is not 2-D or ``x`` does not hold exactly
            ``f.shape[1]`` values.
    """
    if f.dim() != 2:
        raise ValueError(f'f must be 2-D (n_fns, n_points), got shape {tuple(f.shape)}')
    # Take the sizes from the data instead of assuming 3 functions x 200 points.
    n_fns, n_points = f.shape

    grid = x.detach().reshape(-1)
    if grid.numel() != n_points:
        raise ValueError(f'x holds {grid.numel()} values but f has {n_points} points per function')

    # Long-format frame built in one shot: rows are ordered function by
    # function, and values are widened to float64 as the per-element .item()
    # reads previously produced.
    viz_data = pd.DataFrame({
        'id': torch.arange(n_fns).repeat_interleave(n_points).numpy(),
        'x': grid.repeat(n_fns).double().cpu().numpy(),
        'y': f.detach().reshape(-1).double().cpu().numpy(),
    })

    fig, ax = plt.subplots(figsize=(11,7))
    sns.lineplot(ax=ax, data=viz_data, x='x', y='y', hue='id', legend=False,
                 palette=sns.color_palette('Set1', n_fns), alpha=.7)

    return fig, ax

def _draw_summary(ax, dist, title, points=None):
    """Overlay mean, +/- 2 std band and axis formatting for ``dist`` on ``ax``.

    ``points``, when given, is an ``(x, y)`` pair of arrays drawn as red
    markers after the band and before the title is set.
    """
    grid = all_x.cpu().numpy().flatten()
    mean = dist.mean.cpu().numpy()
    with torch.no_grad():
        # The variance is read under no_grad, as in the band computation.
        two_sd = 2. * dist.variance.sqrt().cpu().numpy()

    ax.plot(grid, mean, linestyle=(0, (10,5)),
            color='black', alpha=.6, linewidth=3)
    ax.fill_between(grid, mean - two_sd, mean + two_sd,
                    color='grey', alpha=.15)
    if points is not None:
        pts_x, pts_y = points
        sns.scatterplot(ax=ax, x=pts_x, y=pts_y,
                        color='red', s=100, edgecolor='black', linewidth=1)
    ax.set_title(title)
    ax.set_yticks(np.arange(-2, 2.1))
    ax.set_ylim([-2.5,2.5])


# Three sampled functions plus mean/band for the prior.
fig_prior, ax_prior = plot_fns(all_x, prior.sample(torch.Size([3])))
_draw_summary(ax_prior, prior, 'Prior')

# Same for the posterior, with the training points marked.
fig_post, ax_post = plot_fns(all_x, posterior.sample(torch.Size([3])))
_draw_summary(ax_post, posterior, 'Posterior',
              points=(train_x.squeeze(-1).cpu().numpy(), train_y.cpu().numpy()))

fig_prior.savefig('prior.pdf', bbox_inches='tight')
fig_post.savefig('post.pdf', bbox_inches='tight')