# Train and evaluate a PC

In [None]:
import random
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Device used for the circuit `pc`. NOTE: the GP training cell further down moves its
# model to CUDA on its own whenever CUDA is available, independently of this setting.
device = torch.device("cpu")  # The device to use, e.g., "cpu", "cuda", "cuda:1"

# IPython magics: enable the autoreload extension (mode 2) so edited library code is picked up.
%load_ext autoreload
%autoreload 2

In [None]:
# Reload the autoreload extension; appears redundant when the previous cell already loaded it.
%reload_ext autoreload

Set the random seeds.

In [None]:
# Seed Python's, NumPy's and PyTorch's random number generators with one shared value.
# Per the PyTorch docs, torch.manual_seed seeds the generators of all devices (CUDA
# included), so no separate CUDA seeding call is needed.
SEED = 4
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

## Load MNIST Dataset

Load the training and test splits of MNIST and preprocess them by flattening each image into a vector of integer pixel values in [0, 255].

In [None]:
from torchvision import transforms, datasets
def _to_flat_pixel_ints(img):
    """Flatten an image tensor in [0, 1] and rescale it to integer pixel values (0..255)."""
    return (255 * img.view(-1)).long()


transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(_to_flat_pixel_ints)])

# Training split first, then the test split; both use the same preprocessing.
data_train, data_test = (
    datasets.MNIST('datasets', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Each preprocessed image is a flat vector, so its length is the number of variables.
num_variables = data_train[0][0].shape[0]
height, width = 28, 28
print(f"Number of variables: {num_variables}")

In [None]:
# Show the first training image, reshaped from its flat vector back to the pixel grid.
first_image, first_label = data_train[0]
plt.matshow(first_image.reshape(height, width), cmap='gray')
plt.title(f"Class: {first_label}")
plt.show()

## Instantiating the region graph

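Initialize the region graph. A _QuadTree_ construction over the image grid is kept (commented out) for reference; here we use a _Fully Factorized_ region graph over 18 variables.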

In [None]:
from cirkit.region_graph.quad_tree import QuadTree
# region_graph = QuadTree(width, height, struct_decomp=False)
# region_graph = RandomBinaryTree(num_vars=11, depth=2, num_repetitions=1)

# FullyFactorized is otherwise only imported in a later cell; import it here so the
# notebook also runs top to bottom.
from cirkit.region_graph.fully_factorized import FullyFactorized

# 18 variables: the number of input features used by the GP (`input_dim`) below.
region_graph = FullyFactorized(num_vars=18)

In [None]:
# Display the region graph object.
region_graph

In [None]:
# Inspect the region graph's nodes (private attribute; for debugging only).
region_graph._nodes

Other available region graphs are _Poon Domingos_, _Random Binary Tree_ and _Fully Factorized_ (in addition to the _QuadTree_ imported above); their imports are shown below.

In [None]:
from cirkit.region_graph.poon_domingos import PoonDomingos
from cirkit.region_graph.random_binary_tree import RandomBinaryTree
from cirkit.region_graph.fully_factorized import FullyFactorized

## Choosing the layers

Now we have to choose both the input and inner layers of our circuit. As input layer we select the _RBFKernelLayer_ (the _CategoricalLayer_, e.g. with 256 categories for the pixel values, is also imported as an alternative). For the inner layers, we choose the _uncollapsed CP_ layer with rank 1.

In [None]:
from cirkit.layers.input.exp_family import CategoricalLayer
from cirkit.layers.sum_product import CPLayer
from cirkit.layers.input.rbf_kernel import RBFKernelLayer

# Input layer class (no extra constructor kwargs) and inner layer class (CP, rank 1).
efamily_cls, efamily_kwargs = RBFKernelLayer, {}
layer_cls, layer_kwargs = CPLayer, {'rank': 1}

## Building the tensorized PC

We can now build our tensorized PC by specifying the region graph and layers we chose previously. In addition, we can scale the architecture by increasing the number of input and inner units. We can also have circuits with multiple output units by choosing _num_classes > 1_. However, in this notebook we only estimate the distribution of the images and marginalize out the class variable.

To ensure weights are non-negative we reparametrize them; here we use a softmax (`ReparamSoftmax`). Other reparametrization functions are available, such as `ReparamExp` (exponentiation) and `ReparamLogSoftmax`.

In [None]:
from cirkit.reparams.leaf import ReparamExp, ReparamLogSoftmax, ReparamSoftmax
from cirkit.models.tensorized_circuit import TensorizedPC
# Keyword configuration of the tensorized circuit. `reparam` keeps the sum weights
# non-negative; ReparamLogSoftmax and ReparamExp are the imported alternatives.
pc_config = dict(
    num_inner_units=50,
    num_input_units=50,
    efamily_cls=efamily_cls,
    efamily_kwargs=efamily_kwargs,
    layer_cls=layer_cls,
    layer_kwargs=layer_kwargs,
    num_classes=1,
    reparam=ReparamSoftmax,
)
pc = TensorizedPC.from_region_graph(region_graph, **pc_config)
pc.to(device)
print(pc)

In [None]:
# List the shape of every trainable tensor of the circuit.
for weight in pc.parameters():
    print(weight.shape)

In [None]:
from cirkit.models.rbf_kernel import RBFCircuitKernel

# Build a kernel object from the circuit `pc`, with an empty (non-batched) batch shape.
circuit_kernel = RBFCircuitKernel(pc, batch_shape=torch.Size([]))

In [None]:
from cirkit.models.gp import CircuitGP, initial_values

In [None]:
import torch.nn.functional as F

from uci_datasets import Dataset

from ignite.engine import Events, Engine
from ignite.metrics import Average, Loss
from ignite.contrib.handlers import ProgressBar

import gpytorch
from gpytorch.mlls import VariationalELBO
from gpytorch.likelihoods import GaussianLikelihood

import pandas as pd
import numpy as np


In [None]:
# Load the UCI "elevators" regression dataset and take split index 9 as numpy train/test arrays.
data = Dataset("elevators")
x_train, y_train, x_test, y_test = data.get_split(split=9)

In [None]:
# Shapes of the train and test inputs.
x_train.shape, x_test.shape

In [None]:
# Keep the first `n_train_real` training rows for fitting and hold out the rest for
# validation. Targets are squeezed to 1-D vectors.
n_train_real = 13281
x_train_real, x_val = x_train[:n_train_real], x_train[n_train_real:]
y_train_real = y_train[:n_train_real].squeeze()
y_val = y_train[n_train_real:].squeeze()
y_test = y_test.squeeze()

In [None]:
# Shape of the held-out validation inputs.
x_val.shape

In [None]:
import torch.nn as nn

class IdentityMapping(nn.Module):
    """Parameter-free feature extractor that hands its input back unchanged.

    Used so the GP operates directly on the raw input features. The inherited
    ``nn.Module.__init__`` is sufficient, so no constructor is defined.
    """

    def forward(self, x):
        return x

In [None]:
# --- GP regression on the UCI split: data loaders, model, optimiser and ignite engines ---
np.random.seed(24)
torch.manual_seed(24)  # NOTE(review): Python's `random` module is not reseeded here.

batch_size = 128

ds_train = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train_real).float(), torch.from_numpy(y_train_real).float()
)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)

ds_val = torch.utils.data.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val).float())
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=512, shuffle=False)

ds_test = torch.utils.data.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test).float())
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=512, shuffle=False)

epochs = 50
print(f"Training with {len(x_train_real)} datapoints for {epochs} epochs")

# DUE = True selects the variational circuit GP below. No alternative (SNGP/Laplace)
# model is defined in this notebook, so DUE = False is rejected up front instead of
# failing later with a NameError on `gp_model`.
DUE = True
if not DUE:
    raise NotImplementedError("Only the DUE configuration (DUE = True) is implemented in this notebook.")

input_dim = 18  # number of input features seen by the GP
num_outputs = 1  # regression with 1D output

# No learned feature extractor: the raw inputs are fed to the GP directly.
feature_extractor = IdentityMapping()

n_inducing_points = 100
kernel = "HBF"  # NOTE(review): not passed to CircuitGP (which receives `circuit=pc`); kept for reference.

initial_inducing_points, initial_lengthscale = initial_values(
    ds_train, feature_extractor, n_inducing_points
)

gp_model = CircuitGP(
    num_outputs=num_outputs,
    num_features=input_dim,
    initial_lengthscale=initial_lengthscale,
    initial_inducing_points=initial_inducing_points,
    circuit=pc,
)

likelihood = GaussianLikelihood()
elbo_fn = VariationalELBO(likelihood, gp_model, num_data=len(ds_train))


def loss_fn(x, y):
    """Negative variational ELBO of the GP output `x` for targets `y` (minimised in training)."""
    return -elbo_fn(x, y)


if torch.cuda.is_available():
    gp_model = gp_model.cuda()
    likelihood = likelihood.cuda()

# learning rate
lr = 1e-3

parameters = [
    {"params": gp_model.parameters(), "lr": lr},
    {"params": likelihood.parameters(), "lr": lr},
]

optimizer = torch.optim.Adam(parameters)
pbar = ProgressBar()


def step(engine, batch):
    """Run one optimisation step on `batch` and return the training loss as a float."""
    gp_model.train()
    likelihood.train()

    optimizer.zero_grad()

    x, y = batch
    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()

    y_pred = gp_model(x)
    loss = loss_fn(y_pred, y)

    loss.backward()
    optimizer.step()

    return loss.item()


def eval_step(engine, batch):
    """Return the GP's output for `batch` together with its targets, without autograd tracking."""
    gp_model.eval()
    likelihood.eval()

    x, y = batch
    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()

    # Evaluation needs no gradients; building the autograd graph only costs memory and time.
    with torch.no_grad():
        y_pred = gp_model(x)

    return y_pred, y


trainer = Engine(step)
evaluator = Engine(eval_step)

# Running average of the per-batch training loss returned by `step`.
metric = Average()
metric.attach(trainer, "loss")
pbar.attach(trainer)

# Validation loss: negative mean expected log-likelihood of the targets under the GP output.
metric = Loss(lambda y_pred, y: -likelihood.expected_log_prob(y, y_pred).mean())
metric.attach(evaluator, "loss")


@trainer.on(Events.EPOCH_COMPLETED(every=epochs // 20 + 1))
def log_results(trainer):
    """Evaluate on the validation loader and print validation and training loss."""
    evaluator.run(dl_val)
    print(f"Results - Epoch: {trainer.state.epoch} - "
          f"Val Loss: {evaluator.state.metrics['loss']:.2f} - "
          f"Train Loss: {trainer.state.metrics['loss']:.2f}")

In [None]:
# Dump the shape and value of every trainable tensor of the GP model.
for tensor in gp_model.parameters():
    print(tensor.shape)
    print(tensor)

In [None]:
# Train the GP for `epochs` epochs over the shuffled training loader.
trainer.run(dl_train, max_epochs=epochs)

In [None]:
# Evaluate the trained GP on the test split: MSE of the predictive mean vs. the targets.
gp_model.eval()
if DUE:
    likelihood.eval()

num_test_chunks = 40  # test inputs are processed in chunks (presumably to bound memory)

# Per-chunk MSEs are kept in `all_mse` for inspection. `average_mse` is computed from the
# summed squared errors so every test point has equal weight; np.array_split may produce
# chunks whose sizes differ by one, which a plain mean of chunk MSEs would mis-weight.
all_mse = []
squared_error_sum = 0.0
num_test_points = 0
# The training cell moves gp_model to CUDA when available, so inputs must follow the model.
model_device = next(gp_model.parameters()).device

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(100):
    xx_split = np.array_split(x_test, num_test_chunks)
    yy_split = np.array_split(y_test, num_test_chunks)

    for xx_chunk, yy_chunk in zip(xx_split, yy_split):
        xx = torch.from_numpy(xx_chunk).float().to(model_device)
        yy = torch.from_numpy(yy_chunk).float()
        pred_test = gp_model(xx)
        ol = likelihood(pred_test)
        output = ol.mean.cpu()
        all_mse.append(F.mse_loss(output, yy))
        squared_error_sum += F.mse_loss(output, yy, reduction="sum").item()
        num_test_points += yy.numel()

average_mse = squared_error_sum / num_test_points
average_mse

In [None]:
# Shape of the input layer's raw parameter tensor.
pc.input_layer.params.param.shape
# (self.num_vars, self.num_output_units, self.num_replicas, self.num_suff_stats)
# NOTE(review): the layout above is the author's note — confirm against the cirkit layer.

In [None]:
# Shape of the scope layer's scope tensor.
pc.scope_layer.scope.shape

In [None]:
# Materialise the (reparametrised) input weights of the first inner layer.
pc.inner_layers[0].params_in() #.param #.shape #.param.shape
# (F, H, I, O)
# (fold count, arity, input, output)

In [None]:
from cirkit.models.rbf_kernel import RBFCircuitKernel

# Re-create `circuit_kernel` exactly as in the earlier cell (same `pc`, empty batch shape).
circuit_kernel = RBFCircuitKernel(pc, batch_shape=torch.Size([]))


In [None]:
# Evaluate the circuit kernel between x1 and x2 as a dense matrix.
# NOTE(review): x1 and x2 are only defined in a later cell (`x1 = torch.randn(3, 8, 1)`);
# run that cell first when executing the notebook top to bottom.
circuit_kernel(x1.squeeze(), x2.squeeze()).evaluate()

In [None]:
# Shape of x1 after dropping singleton dimensions.
x1.squeeze().shape

In [None]:
# set parameters
# Replace the input layer's parameter with a new Parameter shaped like
# `pc.input_layer.params` and filled with log(3.3).
# NOTE(review): presumably the layer exponentiates this into a lengthscale of 3.3, to match
# the gpytorch RBFKernel cells below (lengthscale 3.3) — confirm.

pc.input_layer.params.param = torch.nn.Parameter(torch.log(torch.ones(tuple(pc.input_layer.params.shape))*3.3))
# pc.inner_layers[0].params_in.param = torch.nn.Parameter(torch.log(0.25*torch.ones(tuple(pc.inner_layers[0].params_in.shape))))
# pc.inner_layers[0].params_in = torch.nn.Parameter(torch.ones(tuple(pc.inner_layers[0].params_in.shape))*3.3)
# pc.inner_layers[1].params_in = torch.nn.Parameter(torch.ones(tuple(pc.inner_layers[1].params_in.shape))*3.3)
# pc.inner_layers[2].params_in = torch.nn.Parameter(torch.ones(tuple(pc.inner_layers[2].params_in.shape))*3.3)
# pc.inner_layers[3].params_in = torch.nn.Parameter(torch.ones(tuple(pc.inner_layers[3].params_in.shape))*3.3)

In [None]:
# Re-inspect the first inner layer's input weights after the parameter edit.
pc.inner_layers[0].params_in() #.shape

In [None]:
# Two random test inputs of shape (3, 8, 1), drawn x1 first, then x2.
x1, x2 = (torch.randn(3, 8, 1) for _ in range(2))

In [None]:
# Evaluate the circuit on the pair (x1, x2) and drop singleton dimensions.
pc(x1, x2).squeeze()

In [None]:
def eval_pc(x1, x2):
    """Evaluate the global circuit `pc` on a pair of inputs.

    A trailing singleton dimension is appended to each input before the call, and the
    last dimension of the circuit's output is squeezed away afterwards.
    """
    lhs = x1.unsqueeze(-1)
    rhs = x2.unsqueeze(-1)
    return pc(lhs, rhs).squeeze(-1)


eval_pc(x1.squeeze(), x2.squeeze())

In [None]:
from gpytorch.kernels import RBFKernel

# x = torch.randn(3, 5)
# Reference gpytorch RBF kernel with lengthscale 3.3, evaluated between x1 and x2.
covar_module = RBFKernel()
covar_module.lengthscale = torch.tensor(3.3)
covar_module(x1.squeeze(), x2.squeeze()).evaluate()

In [None]:
# Shape of x1 after dropping singleton dimensions.
x1.squeeze().shape

In [None]:
from gpytorch.kernels import RBFKernel
x = torch.randn(3, 2)
# Keep a reference to the kernel: assigning the lengthscale on a temporary
# `RBFKernel()` (as before) discarded the kernel immediately, making it a no-op.
rbf_kernel = RBFKernel()
rbf_kernel.lengthscale = torch.tensor(3.3)

In [None]:
# Test RBF input output = RBF kernel 

In [None]:
from gpytorch.kernels import RBFKernel

# Reference gpytorch RBF kernel (lengthscale 3.3) of a random (3, 5) input with itself.
x = torch.randn(3, 5)
covar_module = RBFKernel()
covar_module.lengthscale = torch.tensor(3.3)
covar_module(x).evaluate()
# covar_module.lengthscale

In [None]:
from cirkit.layers.input.rbf_kernel import RBFKernelLayer
# Standalone RBF input layer over 5 variables; its `params` attribute is replaced by a
# (5, 1) Parameter filled with 3.3.
input_la = RBFKernelLayer(num_vars=5, num_output_units=1)

input_la.params = torch.nn.Parameter(torch.ones((5,1))*3.3)

# input_la(x1, x2).squeeze().shape

# input_la(x.unsqueeze(-1), x.unsqueeze(-1)).shape

# Exponentiate the layer output and take the product over dim 2.
# NOTE(review): presumably the layer returns per-variable log-kernel values, so this should
# match the gpytorch RBF matrix of `x` above — confirm.
torch.prod(torch.exp(input_la(x.unsqueeze(-1), x.unsqueeze(-1)).squeeze()), dim=2)

In [None]:
# RBF input layer over 20 variables with a (20, 1) Parameter filled with 3.3.
# NOTE(review): x1 is (3, 8, 1), i.e. 8 entries along dim 1, not 20; and unlike the previous
# cell the output is not exponentiated before the product — confirm both are intended.
input_la = RBFKernelLayer(num_vars=20, num_output_units=1)

input_la.params = torch.nn.Parameter(torch.ones((20,1))*3.3)

# input_la(x1, x2).squeeze().shape
torch.prod(input_la(x1, x1).squeeze(), dim=2)

In [None]:
# Shape of the gpytorch RBF covariance of x1 with itself.
covar_module(x1).evaluate().shape

In [None]:
# Shape of x1.
x1.shape

In [None]:
# Small fixed (3, 1) test input.
x_2 = torch.tensor([[-0.6281], [ 0.1011], [ 0.0664]])

In [None]:
from cirkit.layers.input.rbf_kernel import RBFKernelLayer
# RBF input layer over 2 variables evaluated on x_2 against itself.
# NOTE(review): `params` is set to shape (1, 1) although num_vars=2 — confirm intended.
input_la = RBFKernelLayer(num_vars=2, num_output_units=1)

input_la.params = torch.nn.Parameter(torch.ones((1,1))*3.3)

input_la(x_2.unsqueeze(-1), x_2.unsqueeze(-1)).squeeze()

In [None]:
# Current parameter of the scratch RBF layer.
input_la.params

In [None]:
# A (2, 1) tensor filled with 3.3 (candidate parameter for the 2-variable layer).
torch.ones((2,1))*3.3

In [None]:
# Evaluate the scratch RBF layer on `x` against itself.
input_la(x.unsqueeze(-1), x.unsqueeze(-1)).squeeze()

In [None]:
# Shape of x_2 with an extra trailing dimension.
x_2.unsqueeze(-1).shape

In [None]:
# Pairwise Euclidean distances between the rows of x1 and x2 (batched over dim 0).
torch.cdist(x1, x2, p=2)

In [None]:
from torch import optim
from torch.utils.data import DataLoader
# MNIST loaders (batch size 256) and an SGD optimiser over the circuit's parameters.
# NOTE: this rebinds the global `optimizer` that the GP `step` function uses; re-running the
# GP trainer afterwards would step this SGD optimiser over `pc.parameters()` instead of Adam.
# NOTE(review): `pc` was built on an 18-variable region graph while MNIST images are
# flattened to `num_variables` entries — confirm the shapes are compatible.
train_dataloader = DataLoader(data_train, shuffle=True, batch_size=256)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=256)
optimizer = optim.SGD(pc.parameters(), lr=0.1, momentum=0.9)

Since the constructed PC is not necessarily normalized, we construct the integral circuit that will compute the partition function. Note that parameters are shared and therefore there is no additional memory required.

In [None]:
from cirkit.models.functional import integrate
# Integral circuit of `pc` used to compute the partition function (shares `pc`'s parameters).
pc_pf = integrate(pc)

Finally, we optimize the parameters for 5 epochs by minimizing the negative log-likelihood.

In [None]:
# Fit the circuit by minimising the negative log-likelihood over the MNIST training set.
num_epochs = 5
for epoch_idx in range(num_epochs):
    # A later cell switches `pc` to eval mode; make sure (re-)training runs in train mode.
    pc.train()
    running_loss = 0.0
    for batch, _ in train_dataloader:
        batch = batch.to(device).unsqueeze(dim=-1)  # Add a channel dimension
        log_score = pc(batch)
        log_pf = pc_pf(batch)     # Compute the partition function
        lls = log_score - log_pf  # Compute the log-likelihood
        loss = -torch.mean(lls)   # The loss is the negative average log-likelihood
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Accumulate a Python float: adding the loss tensor itself would keep every
        # batch's autograd graph alive until the end of the epoch.
        running_loss += loss.item() * len(batch)
        # Clamp the parameters to ensure they are in the intended domain.
        # This is needed only if no reparametrization enforces non-negative parameters.
        # Here clamping is disabled because the weights are reparametrized (see `reparam`
        # in the circuit construction).
        #for layer in model.inner_layers:
        #    layer.clamp_params()
    print(f"Epoch {epoch_idx}: Average NLL: {running_loss / len(data_train):.3f}")

We then evaluate our model on test data by computing the average log-likelihood and bits per dimension.

In [None]:
with torch.no_grad():
    pc.eval()
    # The partition function does not depend on the data, so evaluate it once.
    log_pf = pc_pf(torch.empty((), device=device))
    # Total test log-likelihood, accumulated batch by batch as Python floats.
    test_lls = sum(
        ((pc(images.to(device).unsqueeze(dim=-1)) - log_pf).sum().item() for images, _ in test_dataloader),
        0.0,
    )
    average_ll = test_lls / len(data_test)
    # Bits per dimension: negative average LL converted from nats to bits, per variable.
    bpd = -average_ll / (num_variables * np.log(2.0))
    print(f"Average test LL: {average_ll:.3f}")
    print(f"Bits per dimension: {bpd}")

In [None]:
#!/usr/bin/env python3

import gpytorch

# from ..functions import RBFCovariance
# from ..settings import trace_mode
from gpytorch.kernels import Kernel


def postprocess_rbf(dist_mat):
    """Map squared (lengthscale-scaled) distances ``d`` to ``exp(-d / 2)``.

    The input tensor is modified in place and the same tensor object is returned.
    """
    dist_mat.mul_(-0.5)
    return dist_mat.exp_()


class TestRBFKernel(Kernel):
    r"""
    Local copy of the RBF (squared exponential) kernel, used in this notebook to compare
    against :class:`gpytorch.kernels.RBFKernel`:

    .. math::

       \begin{equation*}
          k_{\text{RBF}}(\mathbf{x_1}, \mathbf{x_2}) = \exp \left( -\frac{1}{2}
          (\mathbf{x_1} - \mathbf{x_2})^\top \Theta^{-2} (\mathbf{x_1} - \mathbf{x_2}) \right)
       \end{equation*}

    where :math:`\Theta` is the :attr:`lengthscale` parameter; both inputs are divided by it
    before squared distances are taken. No constructor is defined, so the arguments
    (``ard_num_dims``, ``batch_shape``, ``active_dims``, ``lengthscale_prior``,
    ``lengthscale_constraint``, ``eps``) are those of :class:`gpytorch.kernels.Kernel`.
    There is no outputscale; wrap the kernel in :class:`gpytorch.kernels.ScaleKernel` to add one.

    Example:
        >>> x = torch.randn(10, 5)
        >>> covar = TestRBFKernel()(x)  # lazy 10 x 10 covariance
        >>> batch_x = torch.randn(2, 10, 5)
        >>> batch_covar = TestRBFKernel(batch_shape=torch.Size([2]))(batch_x)  # 2 x 10 x 10
    """

    has_lengthscale = True

    def forward(self, x1, x2, diag=False, **params):
        # Scale both inputs by the lengthscale, then let covar_dist compute squared
        # distances and map them through exp(-d / 2).
        scaled_x1, scaled_x2 = (inputs.div(self.lengthscale) for inputs in (x1, x2))
        return self.covar_dist(
            scaled_x1,
            scaled_x2,
            square_dist=True,
            diag=diag,
            dist_postprocess_func=postprocess_rbf,
            postprocess=True,
            **params,
        )

In [None]:
# Local RBF kernel with the same lengthscale (3.3) as the gpytorch reference cells.
test_kernel = TestRBFKernel()
test_kernel.lengthscale = torch.tensor(3.3)

In [None]:
# Confirm the lengthscale that was set.
test_kernel.lengthscale

In [None]:
# Dense kernel matrix between x1 and x2 from the local RBF kernel.
test_kernel(x1.squeeze(),x2.squeeze()).evaluate()

In [None]:
# Shape of x1.
x1.shape

In [None]:
import logging
import math
from typing import Optional, Tuple, Union

import torch

from gpytorch.constraints import Interval, Positive
from gpytorch.priors import Prior
from gpytorch.kernels import Kernel

# Module-level logger (named after the module rather than the root logger, so its
# output can be filtered and configured independently).
logger = logging.getLogger(__name__)


class SpectralMixtureKernel(Kernel):
    r"""
    Computes a covariance matrix based on the Spectral Mixture Kernel
    between inputs :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`.

    It was proposed in `Gaussian Process Kernels for Pattern Discovery and Extrapolation`_.

    .. note::
        Unlike other kernels,

            * ard_num_dims **must equal** the number of dimensions of the data.
            * This kernel should not be combined with a :class:`gpytorch.kernels.ScaleKernel`.

    :param int num_mixtures: The number of components in the mixture.
    :param int ard_num_dims: Set this to match the dimensionality of the input.
        It should be `d` if x1 is a `... x n x d` matrix. (Default: `1`.)
    :param batch_shape: Set this if the data is batch of input data. It should
        be `b_1 x ... x b_j` if x1 is a `b_1 x ... x b_j x n x d` tensor. (Default: `torch.Size([])`.)
    :type batch_shape: torch.Size, optional
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints corresponds to the indices of the dimensions. (Default: `None`.)
    :type active_dims: float, optional
    :param eps: The minimum value that the lengthscale can take (prevents divide by zero errors). (Default: `1e-6`.)
    :type eps: float, optional

    :param mixture_scales_prior: A prior to set on the mixture_scales parameter
    :type mixture_scales_prior: ~gpytorch.priors.Prior, optional
    :param mixture_scales_constraint: A constraint to set on the mixture_scales parameter
    :type mixture_scales_constraint: ~gpytorch.constraints.Interval, optional
    :param mixture_means_prior: A prior to set on the mixture_means parameter
    :type mixture_means_prior: ~gpytorch.priors.Prior, optional
    :param mixture_means_constraint: A constraint to set on the mixture_means parameter
    :type mixture_means_constraint: ~gpytorch.constraints.Interval, optional
    :param mixture_weights_prior: A prior to set on the mixture_weights parameter
    :type mixture_weights_prior: ~gpytorch.priors.Prior, optional
    :param mixture_weights_constraint: A constraint to set on the mixture_weights parameter
    :type mixture_weights_constraint: ~gpytorch.constraints.Interval, optional

    :ivar torch.Tensor mixture_scales: The lengthscale parameter. Given
        `k` mixture components, and `... x n x d` data, this will be of size `... x k x 1 x d`.
    :ivar torch.Tensor mixture_means: The mixture mean parameters (`... x k x 1 x d`).
    :ivar torch.Tensor mixture_weights: The mixture weight parameters (`... x k`).

    Example:

        >>> # Non-batch
        >>> x = torch.randn(10, 5)
        >>> covar_module = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=4, ard_num_dims=5)
        >>> covar = covar_module(x)  # Output: LazyVariable of size (10 x 10)
        >>>
        >>> # Batch
        >>> batch_x = torch.randn(2, 10, 5)
        >>> covar_module = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=4, batch_size=2, ard_num_dims=5)
        >>> covar = covar_module(x)  # Output: LazyVariable of size (10 x 10)

    .. _Gaussian Process Kernels for Pattern Discovery and Extrapolation:
        https://arxiv.org/pdf/1302.4245.pdf
    """

    is_stationary = True  # kernel is stationary even though it does not have a lengthscale

    def __init__(
        self,
        num_mixtures: Optional[int] = None,
        ard_num_dims: Optional[int] = 1,
        batch_shape: Optional[torch.Size] = torch.Size([]),
        mixture_scales_prior: Optional[Prior] = None,
        mixture_scales_constraint: Optional[Interval] = None,
        mixture_means_prior: Optional[Prior] = None,
        mixture_means_constraint: Optional[Interval] = None,
        mixture_weights_prior: Optional[Prior] = None,
        mixture_weights_constraint: Optional[Interval] = None,
        **kwargs,
    ):
        """Create the spectral-mixture parameters and their constraints.

        Raises ``RuntimeError`` if ``num_mixtures`` is not given. Priors are not supported:
        passing any ``*_prior`` argument only logs a warning, and no prior is registered here.
        Each ``*_constraint`` defaults to ``Positive()``. Remaining ``kwargs`` are forwarded
        to :class:`gpytorch.kernels.Kernel`.

        Registers, all initialised to zeros:
            * ``raw_mixture_weights`` of shape ``(*batch_shape, num_mixtures)``
            * ``raw_mixture_means`` and ``raw_mixture_scales`` of shape
              ``(*batch_shape, num_mixtures, 1, ard_num_dims)``
        """
        if num_mixtures is None:
            raise RuntimeError("num_mixtures is a required argument")
        # Priors are accepted but unused; only a warning is emitted.
        if mixture_means_prior is not None or mixture_scales_prior is not None or mixture_weights_prior is not None:
            logger.warning("Priors not implemented for SpectralMixtureKernel")

        # This kernel does not use the default lengthscale
        super(SpectralMixtureKernel, self).__init__(ard_num_dims=ard_num_dims, batch_shape=batch_shape, **kwargs)
        self.num_mixtures = num_mixtures

        # Default every unspecified constraint to Positive().
        if mixture_scales_constraint is None:
            mixture_scales_constraint = Positive()

        if mixture_means_constraint is None:
            mixture_means_constraint = Positive()

        if mixture_weights_constraint is None:
            mixture_weights_constraint = Positive()

        # NOTE: parameters are registered in a fixed order (weights, means, scales); keep it,
        # since registration order presumably determines state_dict key order.
        self.register_parameter(
            name="raw_mixture_weights", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, self.num_mixtures))
        )
        # Means and scales share one shape; the singleton axis broadcasts over data points.
        ms_shape = torch.Size([*self.batch_shape, self.num_mixtures, 1, self.ard_num_dims])
        self.register_parameter(name="raw_mixture_means", parameter=torch.nn.Parameter(torch.zeros(ms_shape)))
        self.register_parameter(name="raw_mixture_scales", parameter=torch.nn.Parameter(torch.zeros(ms_shape)))

        self.register_constraint("raw_mixture_scales", mixture_scales_constraint)
        self.register_constraint("raw_mixture_means", mixture_means_constraint)
        self.register_constraint("raw_mixture_weights", mixture_weights_constraint)

    @property
    def mixture_scales(self):
        """Constrained mixture scales (``Positive`` by default), derived from ``raw_mixture_scales``."""
        constraint = self.raw_mixture_scales_constraint
        return constraint.transform(self.raw_mixture_scales)

    @mixture_scales.setter
    def mixture_scales(self, value: Union[torch.Tensor, float]):
        self._set_mixture_scales(value)

    def _set_mixture_scales(self, value: Union[torch.Tensor, float]):
        """Store ``value`` by writing its inverse-transformed form into ``raw_mixture_scales``."""
        raw = self.raw_mixture_scales
        if not torch.is_tensor(value):
            # Plain numbers are converted to the raw parameter's dtype and device.
            value = torch.as_tensor(value).to(raw)
        raw_value = self.raw_mixture_scales_constraint.inverse_transform(value)
        self.initialize(raw_mixture_scales=raw_value)

    @property
    def mixture_means(self):
        """Constrained mixture means (``Positive`` by default), derived from ``raw_mixture_means``."""
        constraint = self.raw_mixture_means_constraint
        return constraint.transform(self.raw_mixture_means)

    @mixture_means.setter
    def mixture_means(self, value: Union[torch.Tensor, float]):
        self._set_mixture_means(value)

    def _set_mixture_means(self, value: Union[torch.Tensor, float]):
        """Store ``value`` by writing its inverse-transformed form into ``raw_mixture_means``."""
        raw = self.raw_mixture_means
        if not torch.is_tensor(value):
            # Plain numbers are converted to the raw parameter's dtype and device.
            value = torch.as_tensor(value).to(raw)
        raw_value = self.raw_mixture_means_constraint.inverse_transform(value)
        self.initialize(raw_mixture_means=raw_value)

    @property
    def mixture_weights(self):
        """Constrained mixture weights (``Positive`` by default), derived from ``raw_mixture_weights``."""
        constraint = self.raw_mixture_weights_constraint
        return constraint.transform(self.raw_mixture_weights)

    @mixture_weights.setter
    def mixture_weights(self, value: Union[torch.Tensor, float]):
        self._set_mixture_weights(value)

    def _set_mixture_weights(self, value: Union[torch.Tensor, float]):
        """Store ``value`` by writing its inverse-transformed form into ``raw_mixture_weights``."""
        raw = self.raw_mixture_weights
        if not torch.is_tensor(value):
            # Plain numbers are converted to the raw parameter's dtype and device.
            value = torch.as_tensor(value).to(raw)
        raw_value = self.raw_mixture_weights_constraint.inverse_transform(value)
        self.initialize(raw_mixture_weights=raw_value)

    def initialize_from_data_empspect(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """
        Initialize mixture components based on the empirical spectrum of the data.
        This will often be better than the standard initialize_from_data method, but it assumes
        that your inputs are evenly spaced.

        The power spectrum of ``train_y`` is turned into a CDF over non-negative frequencies.
        ``1000 x ard_num_dims`` frequencies are drawn from it by inverse-transform sampling.
        A diagonal Gaussian mixture with ``num_mixtures`` components fitted to those samples
        supplies the mixture means, scales and weights.

        :param torch.Tensor train_x: Training inputs
        :param torch.Tensor train_y: Training outputs
        :raises RuntimeError: if ``train_x`` or ``train_y`` is not a tensor
        """

        import numpy as np
        from scipy.fftpack import fft

        try:
            # SciPy >= 1.6 name; the old `cumtrapz` alias was removed in SciPy 1.14.
            from scipy.integrate import cumulative_trapezoid
        except ImportError:  # very old SciPy
            from scipy.integrate import cumtrapz as cumulative_trapezoid

        # `np.trapz` is deprecated since NumPy 2.0 in favour of `np.trapezoid`.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz

        with torch.no_grad():
            if not torch.is_tensor(train_x) or not torch.is_tensor(train_y):
                raise RuntimeError("train_x and train_y should be tensors")
            if train_x.ndimension() == 1:
                train_x = train_x.unsqueeze(-1)
            if self.active_dims is not None:
                train_x = train_x[..., self.active_dims]

            # Flatten batch dimensions
            train_x = train_x.view(-1, train_x.size(-1))
            train_y = train_y.view(-1)

            # Empirical power spectrum of the targets.
            N = train_x.size(-2)
            emp_spect = np.abs(fft(train_y.cpu().detach().numpy())) ** 2 / N
            M = math.floor(N / 2)

            # FFT frequency grid; keep only the non-negative half.
            freq1 = np.arange(M + 1)
            freq2 = np.arange(-M + 1, 0)
            freq = np.hstack((freq1, freq2)) / N
            freq = freq[: M + 1]
            emp_spect = emp_spect[: M + 1]

            # Normalised cumulative spectrum, used as a CDF over frequencies.
            total_area = trapezoid(emp_spect, freq)
            spec_cdf = np.hstack((np.zeros(1), cumulative_trapezoid(emp_spect, freq)))
            spec_cdf = spec_cdf / total_area

            # Inverse-transform sampling: linearly invert the CDF within each bin.
            a = np.random.rand(1000, self.ard_num_dims)
            _, q = np.histogram(a, spec_cdf)
            bins = np.digitize(a, q)
            slopes = (spec_cdf[bins] - spec_cdf[bins - 1]) / (freq[bins] - freq[bins - 1])
            intercepts = spec_cdf[bins - 1] - slopes * freq[bins - 1]
            inv_spec = (a - intercepts) / slopes

            from sklearn.mixture import GaussianMixture

            GMM = GaussianMixture(n_components=self.num_mixtures, covariance_type="diag").fit(inv_spec)
            means = GMM.means_
            varz = GMM.covariances_
            weights = GMM.weights_

            dtype = self.raw_mixture_means.dtype
            device = self.raw_mixture_means.device
            self.mixture_means = torch.tensor(means, dtype=dtype, device=device).unsqueeze(-2)
            self.mixture_scales = torch.tensor(varz, dtype=dtype, device=device).unsqueeze(-2)
            self.mixture_weights = torch.tensor(weights, dtype=dtype, device=device)


    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs):
        """
        Initialize mixture components based on batch statistics of the data. You should use
        this initialization routine if your observations are not evenly spaced.

        Scales are set to ``1 / |N(0, 1) * max_dist|``, where ``max_dist`` is the per-dimension
        range of ``train_x``. Means are drawn uniformly from ``[0, 0.5 / min_dist)``, where
        ``min_dist`` is the smallest non-zero gap between sorted inputs per dimension. All
        weights are set to ``train_y.std() / num_mixtures``. Extra ``kwargs`` are ignored.

        :param torch.Tensor train_x: Training inputs
        :param torch.Tensor train_y: Training outputs
        :raises RuntimeError: if ``train_x`` or ``train_y`` is not a tensor
        """

        with torch.no_grad():
            if not torch.is_tensor(train_x) or not torch.is_tensor(train_y):
                raise RuntimeError("train_x and train_y should be tensors")
            if train_x.ndimension() == 1:
                train_x = train_x.unsqueeze(-1)
            if self.active_dims is not None:
                train_x = train_x[..., self.active_dims]

            # Compute maximum distance between points in each dimension
            train_x_sort = train_x.sort(dim=-2)[0]
            max_dist = train_x_sort[..., -1, :] - train_x_sort[..., 0, :]

            # Compute the minimum distance between points in each dimension
            dists = train_x_sort[..., 1:, :] - train_x_sort[..., :-1, :]
            # We don't want the minimum distance to be zero, so fill zero values with some large number
            dists = torch.where(dists.eq(0.0), torch.tensor(1.0e10, dtype=train_x.dtype, device=train_x.device), dists)
            sorted_dists = dists.sort(dim=-2)[0]
            min_dist = sorted_dists[..., 0, :]

            # Reshape min_dist and max_dist to match the shape of parameters
            # First add a singleton data dimension (-2) and a dimension for the mixture components (-3)
            min_dist = min_dist.unsqueeze_(-2).unsqueeze_(-3)
            max_dist = max_dist.unsqueeze_(-2).unsqueeze_(-3)
            # Compress any dimensions in min_dist/max_dist that correspond to singletons in the SM parameters
            # Walk outward from dim -3: dims beyond the parameter's rank are reduced away entirely
            # (which shrinks the tensor, so `dim` stays put); dims where the parameter has size 1
            # are reduced with keepdim; all other dims are left as they are.
            dim = -3
            while -dim <= min_dist.dim():
                if -dim > self.raw_mixture_scales.dim():
                    min_dist = min_dist.min(dim=dim)[0]
                    max_dist = max_dist.max(dim=dim)[0]
                elif self.raw_mixture_scales.size(dim) == 1:
                    min_dist = min_dist.min(dim=dim, keepdim=True)[0]
                    max_dist = max_dist.max(dim=dim, keepdim=True)[0]
                    dim -= 1
                else:
                    dim -= 1

            # Inverse of lengthscales should be drawn from truncated Gaussian | N(0, max_dist^2) |
            self.mixture_scales = torch.randn_like(self.raw_mixture_scales).mul_(max_dist).abs_().reciprocal_()
            # Draw means from Unif(0, 0.5 / minimum distance between two points)
            self.mixture_means = torch.rand_like(self.raw_mixture_means).mul_(0.5).div(min_dist)
            # Mixture weights should be roughly the stdv of the y values divided by the number of mixtures
            self.mixture_weights = train_y.std().div(self.num_mixtures)


    def _create_input_grid(
        self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This is a helper method for creating a grid of the kernel's inputs.
        Use this helper rather than maually creating a meshgrid.

        The grid dimensions depend on the kernel's evaluation mode.

        :param torch.Tensor x1: ... x n x d
        :param torch.Tensor x2: ... x m x d (for diag mode, these must be the same inputs)
        :param diag: Should the Kernel compute the whole kernel, or just the diag? (Default: True.)
        :type diag: bool, optional
        :param last_dim_is_batch: If this is true, it treats the last dimension
            of the data as another batch dimension.  (Useful for additive
            structure over the dimensions). (Default: False.)
        :type last_dim_is_batch: bool, optional

        :rtype: torch.Tensor, torch.Tensor
        :return: Grid corresponding to x1 and x2. The shape depends on the kernel's mode:
            * `full_covar`: (`... x n x 1 x d` and `... x 1 x m x d`)
            * `full_covar` with `last_dim_is_batch=True`: (`... x k x n x 1 x 1` and `... x k x 1 x m x 1`)
            * `diag`: (`... x n x d` and `... x n x d`)
            * `diag` with `last_dim_is_batch=True`: (`... x k x n x 1` and `... x k x n x 1`)
        """
        x1_, x2_ = x1, x2
        if last_dim_is_batch:
            x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
            if torch.equal(x1, x2):
                x2_ = x1_
            else:
                x2_ = x2_.transpose(-1, -2).unsqueeze(-1)

        if diag:
            return x1_, x2_
        else:
            return x1_.unsqueeze(-2), x2_.unsqueeze(-3)

    def forward(
        self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
    ) -> torch.Tensor:
        """
        Evaluate the spectral mixture kernel between ``x1`` and ``x2``.

        For each mixture component ``q`` and input dimension ``p`` this computes
        ``exp(-2 pi^2 (s_qp * (x1_p - x2_p))^2) * cos(2 pi * mu_qp * (x1_p - x2_p))``
        (``s`` = ``mixture_scales``, ``mu`` = ``mixture_means``). It weights each component
        by ``mixture_weights`` and sums over components. Unless ``last_dim_is_batch`` is set,
        it then multiplies over input dimensions.

        NOTE(review): the axis bookkeeping below assumes ``mixture_scales`` /
        ``mixture_means`` broadcast as ``... x k x 1 x d`` and ``mixture_weights`` is
        ``... x k`` — confirm against the parameter registration.

        :param torch.Tensor x1: ... x n x d
        :param torch.Tensor x2: ... x m x d (for diag mode, these must be the same inputs)
        :param diag: If True, only the diagonal of the kernel matrix is computed. (Default: False.)
        :param last_dim_is_batch: If True, the input dimensions are not multiplied
            together but returned as a leading batch dimension. (Default: False.)
        :rtype: torch.Tensor
        :return: Kernel values with shape
            * full: ``... x n x m``
            * full with ``last_dim_is_batch=True``: ``... x d x n x m``
            * diag: ``... x n``
            * diag with ``last_dim_is_batch=True``: ``... x d x n``
        :raises RuntimeError: If the last dimension of ``x1`` differs from ``ard_num_dims``.
        """
        _, num_dims = x1.shape[-2:]

        if num_dims != self.ard_num_dims:
            raise RuntimeError(
                "The SpectralMixtureKernel expected the input to have {} dimensionality "
                "(based on the ard_num_dims argument). Got {}.".format(self.ard_num_dims, num_dims)
            )

        # Expand x1 and x2 to account for the number of mixtures
        # Should make x1/x2 (... x k x n x d) for k mixtures
        x1_ = x1.unsqueeze(-3)
        x2_ = x2.unsqueeze(-3)

        # Compute distances - scaled by appropriate parameters
        x1_exp = x1_ * self.mixture_scales
        x2_exp = x2_ * self.mixture_scales
        x1_cos = x1_ * self.mixture_means
        x2_cos = x2_ * self.mixture_means

        # Create grids (feature dimension d stays last in every mode)
        x1_exp_, x2_exp_ = self._create_input_grid(x1_exp, x2_exp, diag=diag, **params)
        x1_cos_, x2_cos_ = self._create_input_grid(x1_cos, x2_cos, diag=diag, **params)

        # Compute the exponential and cosine terms
        exp_term = (x1_exp_ - x2_exp_).pow_(2).mul_(-2 * math.pi**2)
        cos_term = (x1_cos_ - x2_cos_).mul_(2 * math.pi)
        res = exp_term.exp_() * cos_term.cos_()

        # Sum over mixtures: weights get trailing singleton axes to line up with
        # the (n, d) or (n, m, d) axes of res.
        mixture_weights = self.mixture_weights.view(*self.mixture_weights.shape, 1, 1)
        if not diag:
            mixture_weights = mixture_weights.unsqueeze(-2)

        res = (res * mixture_weights).sum(-3 if diag else -4)

        # Product over dimensions
        if last_dim_is_batch:
            # Put feature-dimension in front of the data dimension(s).
            if diag:
                # res is ... x n x d here, so a transpose gives ... x d x n.
                res = res.transpose(-1, -2)
            else:
                # res is ... x n x m x d here; move d in front: ... x d x n x m.
                res = res.permute(*range(res.dim() - 3), -1, -3, -2)
        else:
            res = res.prod(-1)

        return res

In [None]:
import math


# Three 5-dimensional inputs (matches ard_num_dims=5 below).
train_x = torch.tensor(
    [
        [0.4124, 0.5949, 0.4964, 0.7655, 0.5808],
        [0.1466, 0.6540, 0.6794, 0.8892, 0.8310],
        [0.1754, 0.4290, 0.6012, 0.4222, 0.4456],
    ]
)

# Build a 4-component kernel, then overwrite its parameters with fixed values so
# the evaluation below is reproducible (values presumably copied from an earlier
# run — confirm their origin).
covar_module = SpectralMixtureKernel(num_mixtures=4, ard_num_dims=5)
covar_module.mixture_weights = torch.tensor([0.0334] * 4)

# One row of 5 per-dimension means for each of the 4 components: shape (4, 1, 5).
covar_module.mixture_means = torch.tensor(
    [
        [[11.8665, 2.0900, 3.1740, 3.5268, 1.1721]],
        [[1.7340, 8.3558, 3.6868, 0.3790, 3.3036]],
        [[9.2139, 5.0617, 4.0027, 3.9273, 1.8358]],
        [[16.3662, 2.0223, 3.3369, 1.4315, 2.6659]],
    ]
)

# Per-dimension scales for each component, same (4, 1, 5) layout as the means.
covar_module.mixture_scales = torch.tensor(
    [
        [[3.4405, 7.5435, 3.9373, 2.4702, 3.0673]],
        [[3.0207, 7.0498, 10.1843, 78.9824, 2.1191]],
        [[4.3353, 11.9640, 18.6242, 1.8301, 7.3544]],
        [[3.3139, 3.4624, 34.5551, 7.9503, 195.3149]],
    ]
)

# Last expression of the cell: the notebook displays the dense 3 x 3 kernel matrix.
covar_module(train_x).evaluate()

In [None]:
# Display the mixture scales currently held by covar_module (assigned explicitly in the cell above).
covar_module.mixture_scales

In [None]:
# Display the training inputs and targets.
# NOTE(review): train_y is not assigned in any visible cell — confirm it is created
# elsewhere in the notebook, otherwise this cell raises NameError.
train_x, train_y