In [1]:
import math
import os

import gpytorch
import torch
import tqdm
import tqdm.notebook
from matplotlib import pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2



In [2]:
# Seed the global torch RNG so the noisy targets below (and, when the notebook is run
# top to bottom, later random draws such as the initial inducing points) are reproducible.
SEED = 0
torch.manual_seed(SEED)

# 100 evenly spaced training inputs on [0, 1]
train_x = torch.linspace(0, 1, 100)

# Two outputs (tasks), stacked as columns: one period of sin and cos over [0, 1],
# each with additive Gaussian noise of standard deviation 0.2.
train_y = torch.stack([
    torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
], -1)

In [3]:
class MultitaskGPModel(gpytorch.models.ApproximateGP):
    """Approximate (variational) GP with an independent latent GP per task.

    The variational distribution, the mean and the kernel are all batched over a
    task dimension of size ``num_tasks``. As a result, every task learns its own
    variational parameters and hyperparameters. The ``VariationalStrategy`` is
    wrapped in a ``MultitaskVariationalStrategy`` so that calling the model yields
    a single ``MultitaskMultivariateNormal`` rather than a batch of outputs.

    Args:
        inducing_points: initial inducing locations handed to
            ``VariationalStrategy``. They are optimized during training because
            ``learn_inducing_locations=True``. This notebook passes a
            ``(num_tasks x m x 1)`` tensor so that each task has its own set;
            ``m`` is read from ``inducing_points.size(-2)``.
        num_tasks: number of outputs. Defaults to 2, the value this model was
            originally hard-coded for.
    """

    def __init__(self, inducing_points, num_tasks=2):
        # Single source of truth for the task batch shape shared by every
        # per-task component below.
        task_batch = torch.Size([num_tasks])

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=task_batch
        )

        # We have to wrap the VariationalStrategy in a MultitaskVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.MultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ), num_tasks=num_tasks
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters for each task
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=task_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=task_batch),
            batch_shape=task_batch
        )

    def forward(self, x):
        """Return the task-batched MultivariateNormal given by the mean and kernel at ``x``.

        This is written as if each output dimension were handled in batch; the
        multitask variational strategy turns that batch into a multitask output.
        """
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


# The shape of the inducing points should be (2 x m x 1), here with m = 16: one set of
# inducing points per output, so that each task learns its own inducing locations.
inducing_points = torch.rand(2, 16, 1)
model = MultitaskGPModel(inducing_points)

# We're going to use a multitask likelihood with this model. Its num_tasks must match
# the number of outputs: the leading 2 of the inducing points and the two columns of train_y.
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)

In [4]:
# Sanity check: evaluate the model on the training inputs. The result is a single
# MultitaskMultivariateNormal over all 100 x 2 outputs (its repr shows a flattened loc of size 200).
untrained_output = model(train_x)
untrained_output


MultitaskMultivariateNormal(loc: torch.Size([200]))

In [5]:
# Draw one reparameterized sample from the model output; its shape is
# (number of inputs, number of tasks).
output_sample = model(train_x).rsample()
output_sample.size()


torch.Size([100, 2])

In [8]:
# Inspect the likelihood's current noise value.
# NOTE(review): this cell ran as In [8], after the training cell (In [6]) further down,
# so the value shown may not match what a top-to-bottom run would produce at this point.
current_noise = likelihood.noise
current_noise

tensor([0.0330], grad_fn=<AddBackward0>)

In [6]:
# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
num_epochs = 1 if smoke_test else 20
# Minibatch size for the DataLoader below. If left unset, DataLoader uses batch_size=1,
# i.e. one very noisy optimizer step per data point. Larger batches mean fewer steps
# per epoch, so raise num_epochs if the loss has not levelled off.
batch_size = 10


model.train()
likelihood.train()

# We use Adam to jointly optimize the model's parameters (variational parameters,
# inducing locations, mean/kernel hyperparameters) and the likelihood's parameters.
optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.01)

# Training loader. train_x comes from linspace and is therefore sorted, so shuffle
# to avoid feeding the minibatches in input order on every epoch.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_x, train_y),
    batch_size=batch_size,
    shuffle=True,
)

# Our loss object. We're using the VariationalELBO, which essentially just computes the ELBO.
# num_data is the total number of training points, which the ELBO needs because it is
# evaluated on minibatches rather than the full dataset.
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))

# Outer progress bar tracks epochs; the inner one (leave=False) tracks minibatches.
epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # Within each iteration, we will go over each minibatch of data
    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc="Minibatch", leave=False)
    for x_batch, y_batch in minibatch_iter:
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        minibatch_iter.set_postfix(loss=loss.item())
        loss.backward()
        optimizer.step()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=20.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

HBox(children=(FloatProgress(value=0.0, description='Minibatch', style=ProgressStyle(description_width='initia…

KeyboardInterrupt: 