In [None]:
import numpy as np
import torch
from torch.autograd import Variable
from torch.nn import Parameter
from torch.nn.functional import softplus
from torch.distributions import MultivariateNormal
from torch.distributions.kl import kl_divergence
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal


def RBFKernel(input_dim, variance=1., lengthscale=1., X2=None):
    """
    Evaluate the RBF (squared-exponential) kernel

        k(x, x') = variance * exp(-0.5 * ||x - x'||^2 / lengthscale^2)

    Parameters
    ----------
    input_dim : torch.Tensor
        The input points themselves (despite the name, this is NOT a
        dimension count). It is passed to ``torch.cdist``, so points are
        expected along the second-to-last axis and features along the last.
    variance : float or torch.Tensor
        Multiplicative scale applied to the kernel values.
    lengthscale : float or torch.Tensor
        Inputs are divided by this value before distances are computed.
    X2 : torch.Tensor, optional
        Second set of points. When given, the cross-covariance between
        ``input_dim`` and ``X2`` is returned; when omitted the kernel is
        evaluated between ``input_dim`` and itself (the original behaviour).

    Returns
    -------
    torch.Tensor
        Kernel matrix with one row per point of ``input_dim`` and one column
        per point of ``X2`` (or of ``input_dim`` when ``X2`` is None).
    """
    # Rescale once and reuse it for both arguments in the symmetric case.
    scaled = input_dim / lengthscale
    scaled_other = scaled if X2 is None else X2 / lengthscale
    return variance * torch.exp(-0.5 * torch.cdist(scaled, scaled_other) ** 2)



class SVGP():
    def __init__(self, X, Y, Z, batchsize=None, kernel=None):
        """
        Stochastic Variational GP.

        For Gaussian Likelihoods

        Based on Gaussian Processes for Big data, Hensman, Fusi and Lawrence, UAI 2013,

        The variational parameters, the kernel hyperparameters and the noise
        variance are optimised jointly with ordinary (Adam) gradient steps on
        the negative evidence lower bound; see ``fit``.

        Parameters
        ----------
        X : torch.Tensor
            Training inputs, one row per data point.
        Y : torch.Tensor
            Training targets, one row per data point and one column per
            latent function.
        Z : torch.Tensor
            Inducing inputs, one row per inducing point. They are kept fixed
            (they are not optimised by ``fit``).
        batchsize : int, optional
            Size of the minibatches drawn from ``X``/``Y``. When None the
            full data set is used at every step.
        kernel : callable, optional
            ``kernel(inputs, variance=..., lengthscale=...)`` returning the
            covariance matrix of ``inputs`` with itself. Defaults to
            ``RBFKernel``.
        """
        self.kernel = RBFKernel if kernel is None else kernel
        self.batchsize = batchsize
        self.X_all, self.Y_all = X, Y
        # Z has to be in place before set_data(), which validates X against it.
        self.Z = Z
        self.mean_function = lambda x: 0
        # Added to the diagonal of K(Z, Z) so that its Cholesky factor exists.
        self.jitter = 1e-6

        #assume the number of latent functions is one per col of Y
        num_inducing, num_latent = Z.shape[0], Y.shape[1]
        dtype = Z.dtype if Z.is_floating_point() else torch.get_default_dtype()
        opts = dict(dtype=dtype, device=Z.device)

        # softplus(raw_one) == 1, so every positive quantity below starts at 1.
        raw_one = torch.ones((), **opts).expm1().log()

        # q(u) = N(q_u_mean, L L^T), with one independent Gaussian per latent
        # function. L is assembled in ``q_u_chol`` from ``chol_var`` (raw
        # diagonal) and ``chol_covar`` (strictly-lower-triangular entries).
        self.q_u_mean = Parameter(torch.randn(num_inducing, num_latent, **opts))
        self.chol_var = Parameter(raw_one.expand(num_latent, num_inducing).clone())
        self.chol_covar = Parameter(torch.zeros(num_latent, num_inducing, num_inducing, **opts))

        # Unconstrained hyperparameters; the properties below map them to
        # positive values with softplus.
        self.raw_kernel_variance = Parameter(raw_one.clone().reshape(1))
        self.raw_kernel_lengthscale = Parameter(raw_one.clone().reshape(1))
        self.raw_noise_variance = Parameter(raw_one.clone().reshape(1))

        if batchsize is None:
            X_batch, Y_batch = X, Y
        else:
            self.slicer = self._batch_slices()
            X_batch, Y_batch = self.new_batch()
        self.set_data(X_batch, Y_batch)

    @property
    def kernel_variance(self):
        """Positive kernel variance."""
        return softplus(self.raw_kernel_variance)

    @property
    def kernel_lengthscale(self):
        """Positive kernel lengthscale."""
        return softplus(self.raw_kernel_lengthscale)

    @property
    def noise_variance(self):
        """Positive variance of the Gaussian observation noise."""
        return softplus(self.raw_noise_variance)

    @property
    def q_u_chol(self):
        """Lower-triangular factors L of the q(u) covariances, one per latent function."""
        strictly_lower = torch.tril(self.chol_covar, diagonal=-1)
        # softplus keeps the diagonal positive, so L is a valid Cholesky factor.
        return strictly_lower + torch.diag_embed(softplus(self.chol_var))

    @property
    def q_u_covar(self):
        """Covariance matrices L L^T of q(u), one per latent function."""
        chol = self.q_u_chol
        return chol @ chol.transpose(-1, -2)

    def parameters(self):
        """
        Return the list of tensors optimised by ``fit``
        """
        return [
            self.q_u_mean,
            self.chol_var,
            self.chol_covar,
            self.raw_kernel_variance,
            self.raw_kernel_lengthscale,
            self.raw_noise_variance,
        ]

    def _batch_slices(self):
        """
        Yield consecutive slices of length ``batchsize`` covering one pass over the data
        """
        for i in range(0, len(self.X_all), self.batchsize):
            yield slice(i, i + self.batchsize)

    def _grads(self, parameters):
        """
        Compute the gradients of the objective with respect to ``parameters``
        """
        #compute the gradients of the parameters
        grads = torch.autograd.grad(self.objective(), parameters, create_graph=True)
        return grads


    def fit(self, num_iter=1000, lr=0.01):
        """
        Minimise the objective with Adam for ``num_iter`` steps.

        When a batch size was given, a fresh minibatch is loaded before every
        step. The loss of every iteration is printed.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        for i in range(num_iter):
            if self.batchsize is not None:
                self.set_data(*self.new_batch())
            optimizer.zero_grad()
            loss = self.objective()
            loss.backward()
            optimizer.step()
            print(f'Iter {i}, Loss: {loss.item()}')

    def objective(self):
        """
        Compute the objective function

        Returns the negative evidence lower bound on the current batch:
        minus the expected log likelihood under q(f), rescaled to the size of
        the full data set, plus KL(q(u) || p(u)).
        """
        X, Y, Z = self.X, self.Y, self.Z
        num_inducing = Z.shape[0]
        noise = self.noise_variance

        # The kernel callable only evaluates a point set against itself, so
        # evaluate it once on [Z; X] and read the needed blocks out of it.
        K_joint = self.kernel(
            torch.cat([Z, X], dim=0),
            variance=self.kernel_variance,
            lengthscale=self.kernel_lengthscale,
        )
        eye = torch.eye(num_inducing, dtype=K_joint.dtype, device=K_joint.device)
        Kmm = K_joint[:num_inducing, :num_inducing] + self.jitter * eye
        Kmn = K_joint[:num_inducing, num_inducing:]
        Knn_diag = torch.diagonal(K_joint)[num_inducing:]

        Lm = torch.linalg.cholesky(Kmm)
        # A = Lm^{-1} Kmn and B = Kmm^{-1} Kmn, both (num_inducing, batch).
        A = torch.linalg.solve_triangular(Lm, Kmn, upper=False)
        B = torch.linalg.solve_triangular(Lm.t(), A, upper=True)

        #compute the mean of the latent function: m(X) + Knm Kmm^{-1} q_u_mean
        mu = self.mean_function(X) + B.t() @ self.q_u_mean

        #compute the marginal variance of the latent function:
        # diag(Knn - Knm Kmm^{-1} Kmn + Knm Kmm^{-1} S Kmm^{-1} Kmn), S = L L^T
        Ls = self.q_u_chol
        LsT_B = Ls.transpose(-1, -2) @ B
        f_var = (Knn_diag - (A ** 2).sum(dim=0) + (LsT_B ** 2).sum(dim=1)).t()

        #compute the expected log likelihood under q(f)
        expected_ll = self.likelihood(mu, noise).log_prob(Y) - 0.5 * f_var / noise

        #compute the KL divergence; the latent functions are the batch dimension
        q_u = MultivariateNormal(self.q_u_mean.t(), scale_tril=Ls)
        p_u = MultivariateNormal(torch.zeros_like(self.q_u_mean.t()), scale_tril=Lm)
        kl = kl_divergence(q_u, p_u)

        # Rescale the batch term so it estimates the sum over the whole data set.
        scale = len(self.X_all) / len(X)
        return -scale * expected_ll.sum() + kl.sum()

    def likelihood(self, mu, Sigma):
        """
        Compute the normal likelihood of the data given the latent function

        ``mu`` is a tensor of latent function values and ``Sigma`` the
        observation VARIANCE (a float or a tensor broadcastable against
        ``mu``). ``Normal`` is parameterised by a standard deviation, hence
        the square root.
        """
        Sigma = torch.as_tensor(Sigma, dtype=mu.dtype, device=mu.device)
        return Normal(mu, Sigma.sqrt())

    def set_data(self, X, Y):
        """
        Set the data (batch) that ``objective`` is evaluated on.

        Nothing is recomputed here; the next call to ``objective`` simply
        uses the new ``X`` and ``Y``.
        """
        assert X.shape[1]==self.Z.shape[1], "X and Z must have the same number of columns"
        self.X, self.Y = X, Y

    def new_batch(self):
        """
        Return a new batch of X and Y by taking a chunk of data from the complete X and Y

        Batches are taken in order; once the data is exhausted the next pass
        starts again from the beginning. Without a batch size the complete
        data set is returned.
        """
        if self.batchsize is None:
            return self.X_all, self.Y_all
        try:
            i = next(self.slicer)
        except StopIteration:
            # One full pass is done: begin the next epoch.
            self.slicer = self._batch_slices()
            i = next(self.slicer)
        return self.X_all[i], self.Y_all[i]

    def stochastic_grad(self, parameters):
        """
        Load a new batch and return the gradients of the objective on it
        """
        self.set_data(*self.new_batch())
        return self._grads(parameters)

