In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import logging

# Put the directory three levels above the working directory on the import
# path (presumably the repository root that contains `pvi` -- confirm).
module_path = os.path.abspath("../../..")
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
from pvi.models.sgp import SparseGaussianProcessModel
from pvi.utils.gaussian import mvstandard2natural, mvnatural2standard
from pvi.clients.synchronous_client import SynchronousClient
from pvi.distributions.exponential_family_distributions import MultivariateGaussianDistribution
from pvi.distributions.exponential_family_factors import MultivariateGaussianFactor

import torch
import numpy as np
import matplotlib.pyplot as plt
import gpytorch
import tqdm.auto as tqdm

from torch import nn
from gpytorch.kernels import ScaleKernel, RBFKernel

%matplotlib inline
# Make float64 the default dtype for floating-point tensors created from here
# on (e.g. by torch.zeros / torch.eye / torch.tensor on Python floats).
torch.set_default_dtype(torch.float64)

# Set up data and helper functions

In [None]:
# Seeded generator so the noisy targets are identical on every
# Restart Kernel -> Run All (the global np.random state was unseeded before).
rng = np.random.default_rng(0)

# 50 evenly spaced inputs on [-1, 1].
x = np.linspace(-1, 1, 50)
# Sinusoid plus uniform [0, 1) noise whose scale grows linearly with |x|.
y = 2 * np.sin(5*x) + 3 * np.abs(x) * rng.random(len(x))

# Add a trailing dimension so both tensors have shape (50, 1).
x = torch.tensor(x).unsqueeze(1)
y = torch.tensor(y).unsqueeze(1)

In [None]:
def plot_data(x, y):
    """Scatter-plot the observations ``y`` against the inputs ``x``.

    Args:
        x: Input locations, forwarded unchanged to ``plt.scatter``.
        y: Observed targets, forwarded unchanged to ``plt.scatter``.
    """
    plt.figure()
    # Pass the flag positionally: the ``b`` keyword was renamed to ``visible``
    # in Matplotlib 3.5 and is rejected by newer releases, whereas the
    # positional form works on both old and new versions.
    plt.grid(True)
    plt.scatter(x, y)
    plt.show()
    
def plot_results(x, y, model, q):
    """Plot the model's predictive distribution at ``x`` on top of the data.

    The figure shows 20 draws from the predictive distribution (faint black
    lines), the predictive mean, a mean +/- 1.96 standard deviation band
    (roughly a 95% interval if the predictive is Gaussian) and a scatter of
    the observations.

    Args:
        x: Input locations as a tensor; squeezed along its last dimension for
            the mean line and the shaded band.
        y: Observed targets, scattered against ``x``.
        model: Callable invoked as ``model(x, q)``. Its return value must
            expose ``mean``, ``variance`` and ``sample(shape)``.
        q: Passed through to ``model`` unchanged.
    """
    pp = model(x, q)
    # Detach so the statistics can be handed to matplotlib without autograd.
    mean = pp.mean.detach()
    std = pp.variance.detach() ** 0.5

    samples = pp.sample((20,))

    plt.figure()
    # Pass the flag positionally: the ``b`` keyword was renamed to ``visible``
    # in Matplotlib 3.5 and is rejected by newer releases.
    plt.grid(True)

    for sample in samples:
        plt.plot(x, sample, color='k', alpha=.1)

    plt.plot(x.squeeze(-1), mean)
    plt.fill_between(x.squeeze(-1), mean-1.96*std, mean+1.96*std, alpha=.25)
    plt.scatter(x, y)
    plt.show()
    
def plot_training(training_array):
    """Plot a per-step training curve.

    Args:
        training_array: Sequence of values, one per optimisation step. They
            are plotted against step numbers 1..len(training_array) and the
            y-axis is labelled 'ELBO Loss'.
    """
    # Steps are numbered from 1 rather than 0.
    x_vals = np.arange(1, len(training_array)+1)
    plt.figure()
    # Pass the flag positionally: the ``b`` keyword was renamed to ``visible``
    # in Matplotlib 3.5 and is rejected by newer releases.
    plt.grid(True)
    plt.plot(x_vals, training_array)
    plt.ylabel('ELBO Loss')
    plt.xlabel('Step')
    plt.show()
    
# Bundle inputs and targets under the keys "x" and "y"; this dict is what the
# client receives as its dataset.
data = dict(x=x, y=y)

In [None]:
# Visual sanity check of the synthetic dataset before any modelling.
plot_data(x, y)

# Construct SGP model

In [None]:
# Model and optimisation settings handed to the sparse GP model.
hyperparameters = dict(
    D=1,
    num_inducing=5,
    kernel_class=RBFKernel,
    kernel_params=dict(lengthscale=0.5),
    epochs=500,
    optimiser_params=dict(lr=1e-3),
    batch_size=50,
)

# Every 10th input is used as an inducing location (5 of the 50 points,
# matching "num_inducing" above); clone so the model gets its own copy.
inducing_locations = x[::10].clone()

model = SparseGaussianProcessModel(
    inducing_locations=inducing_locations,
    output_sigma=1.0,
    hyperparameters=hyperparameters,
)

# Initial q in natural parameters: np1 = 0 and np2 = -0.5 * I.
# NOTE(review): under the usual (precision @ mean, -0.5 * precision)
# convention this is a standard normal -- confirm against pvi's convention.
q = MultivariateGaussianDistribution(
    nat_params=dict(
        np1=torch.zeros(hyperparameters["num_inducing"]),
        np2=-0.5 * torch.eye(hyperparameters["num_inducing"]),
    )
)

In [None]:
# Predictive distribution under the initial q, before any client update.
plot_results(x, y, model, q)

# Fit data

In [None]:
# Initial factor with all-zero natural parameters: a length-M vector for np1
# and an M x M matrix for np2, where M is the model's number of inducing points.
t = MultivariateGaussianFactor(
    nat_params={
        "np1": torch.zeros(model.hyperparameters["num_inducing"]),
        "np2": torch.zeros(
            model.hyperparameters["num_inducing"],
            model.hyperparameters["num_inducing"],
        ),
    }
)

In [None]:
# Construct the synchronous client, handing it the dataset dict, the sparse GP
# model and the initial (all-zero) factor t.
client = SynchronousClient(data=data, model=model, t=t)

## Without optimising model (hyper-)parameters

In [None]:
# Run one client update starting from the initial q; the call returns a pair
# that is unpacked into q_new and t_new.
# NOTE(review): presumably q_new is the refined approximate posterior and
# t_new the client's updated factor -- confirm against SynchronousClient.update_q.
q_new, t_new = client.update_q(q)

In [None]:
# Predictive distribution after the update, using the client's model and q_new.
plot_results(x, y, client.model, q_new)