In [9]:
import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution
from gpytorch.likelihoods import SoftmaxLikelihood
from gpytorch.mlls import VariationalELBO
from torch.utils.data import TensorDataset, DataLoader

# Generate some sample data
num_data = 100
num_features = 5
num_classes = 3
X_train = torch.randn(num_data, num_features)
y_train = torch.randint(0, num_classes, (num_data,))

# Define a multi-output GP model for multiclass classification
class MulticlassGPModel(ApproximateGP):
    def __init__(self, num_features, num_classes):
        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=50)
        variational_strategy = VariationalStrategy(self, torch.randn(50, num_features), variational_distribution, learn_inducing_locations=True)
        super(MulticlassGPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_classes]))
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_classes])), batch_shape=torch.Size([num_classes]))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# Instantiate the model and likelihood
model = MulticlassGPModel(num_features=num_features, num_classes=num_classes)
likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=None)

# Define the marginal log likelihood (MLL) objective
mll = VariationalELBO(likelihood, model, num_data=X_train.size(0))

# Setup training
optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': likelihood.parameters()}], lr=0.01)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Train the model
model.train()
likelihood.train()
for epoch in range(100):
    
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        loss.backward()
        optimizer.step()
    print("epoch: {}, loss: {:.3f}".format(epoch, loss.item()))
# Switch to evaluation mode
model.eval()
likelihood.eval()

# Predict on new data
# with torch.no_grad(), gpytorch.settings.num_likelihood_samples(10):
test_x = torch.randn(10, num_features)  # example test data
preds = likelihood(model(test_x))

# Extract predicted probabilities that support backpropagation
predicted_probabilities = preds.probs
print(predicted_probabilities)


epoch: 0, loss: 1.455
epoch: 1, loss: 1.216
epoch: 2, loss: 1.554
epoch: 3, loss: 1.334
epoch: 4, loss: 1.084
epoch: 5, loss: 1.164
epoch: 6, loss: 1.292
epoch: 7, loss: 1.273
epoch: 8, loss: 1.208
epoch: 9, loss: 1.114
epoch: 10, loss: 1.108
epoch: 11, loss: 1.297
epoch: 12, loss: 1.269
epoch: 13, loss: 1.075
epoch: 14, loss: 1.181
epoch: 15, loss: 1.209
epoch: 16, loss: 1.282
epoch: 17, loss: 1.222
epoch: 18, loss: 1.207
epoch: 19, loss: 1.104
epoch: 20, loss: 1.134
epoch: 21, loss: 1.209
epoch: 22, loss: 1.199
epoch: 23, loss: 1.211
epoch: 24, loss: 1.244
epoch: 25, loss: 1.243
epoch: 26, loss: 1.354
epoch: 27, loss: 1.201
epoch: 28, loss: 1.271
epoch: 29, loss: 1.203
epoch: 30, loss: 1.180
epoch: 31, loss: 1.234
epoch: 32, loss: 1.204
epoch: 33, loss: 1.159
epoch: 34, loss: 1.192
epoch: 35, loss: 1.078
epoch: 36, loss: 1.336
epoch: 37, loss: 1.146
epoch: 38, loss: 1.125
epoch: 39, loss: 0.965
epoch: 40, loss: 1.153
epoch: 41, loss: 1.161
epoch: 42, loss: 1.160
epoch: 43, loss: 1.28

[0;31mSignature:[0m       [0mpreds[0m[0;34m.[0m[0mprobs[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m            _lazy_property_and_property
[0;31mString form:[0m     <torch.distributions.utils._lazy_property_and_property object at 0x7fe3ef552a40>
[0;31mFile:[0m            ~/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/distributions/utils.py
[0;31mDocstring:[0m       <no docstring>
[0;31mClass docstring:[0m
We want lazy properties to look like multiple things.

* property when Sphinx autodoc looks
* lazy_property when Distribution validate_args looks

In [5]:
predicted_probabilities = preds.probs

In [6]:
predicted_probabilities.requires_grad

True

In [9]:
with torch.enable_grad(), gpytorch.settings.num_likelihood_samples(10):
    test_x = torch.randn(10, num_features)  # example test data
    preds = likelihood(model(test_x))
predicted_probabilities = preds.probs
predicted_probabilities.requires_grad

True

In [13]:
predicted_probabilities.shape

torch.Size([10, 10, 3])

In [15]:
test_x.shape

torch.Size([10, 5])

In [10]:
predicted_probabilities.mean(dim=1)

tensor([[0.2735, 0.3779, 0.3485],
        [0.3033, 0.4045, 0.2922],
        [0.2721, 0.3386, 0.3893],
        [0.2908, 0.3926, 0.3166],
        [0.3435, 0.3606, 0.2959],
        [0.2832, 0.3628, 0.3540],
        [0.3078, 0.3588, 0.3334],
        [0.2703, 0.3458, 0.3839],
        [0.3117, 0.3043, 0.3840],
        [0.3411, 0.3668, 0.2921]], grad_fn=<MeanBackward1>)