[Bug] Potential Bug in GP Regression with KeOps Kernels #2453

Open
kayween opened this issue Dec 7, 2023 · 0 comments
kayween (Collaborator) commented Dec 7, 2023

🐛 Bug

During the backward pass of the exact marginal log likelihood, GPyTorch throws a CUDA error. This happens when using the KeOps kernel on a large dataset with $n = 1{,}000{,}000$ data points.

The same KeOps kernel works fine with smaller datasets (e.g., $n = 500{,}000$).
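For a sense of scale, assuming float32 storage: a dense $n \times n$ kernel matrix would need about 4 TB at $n = 1{,}000{,}000$ and 1 TB at $n = 500{,}000$. That is why a matrix-free KeOps kernel is used here at all. The arithmetic, as a tiny Python helper (not GPyTorch code):

```python
# Memory needed to materialize a dense n x n kernel matrix.
# Assumes float32 storage (4 bytes per entry); a KeOps kernel avoids
# storing this matrix by computing kernel-matrix products on the fly.

BYTES_PER_FLOAT32 = 4


def dense_kernel_bytes(n: int, bytes_per_entry: int = BYTES_PER_FLOAT32) -> int:
    """Bytes needed to store an n x n kernel matrix."""
    return n * n * bytes_per_entry


for n in (500_000, 1_000_000):
    tb = dense_kernel_bytes(n) / 1e12  # decimal terabytes
    print(f"n = {n:>9,}: {tb:.1f} TB")
# n =   500,000: 1.0 TB
# n = 1,000,000: 4.0 TB
```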

To reproduce

import torch

import gpytorch


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.keops.RBFKernel()
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


if __name__ == "__main__":
    # n = 500_000  # works okay
    n = 1_000_000  # error
    d = 10

    device = "cuda:0"

    train_x = torch.randn(n, d, device=device)
    train_y = torch.randn(n, device=device)

    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    # likelihood.noise_covar.initialize(noise=10.)

    model = ExactGPModel(train_x, train_y, likelihood).to(device)

    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    training_iter = 50
    for i in range(training_iter):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Output from model
        output = model(train_x)
        # Calc loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
            i + 1, training_iter, loss.item(),
            model.covar_module.base_kernel.lengthscale.item(),
            model.likelihood.noise.item()
        ))
        optimizer.step()
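For reference, the model in the script has a constant mean $c$, a scaled RBF covariance with a single lengthscale $\ell$ and outputscale $s$, and Gaussian observation noise $\sigma^2$. Written out, the loss being differentiated is as follows, assuming no priors are registered and that `ExactMarginalLogLikelihood` normalizes by $n$ (as it does in current GPyTorch releases):

$$
k(\mathbf{x}, \mathbf{x}') = s \exp\!\left(-\frac{\lVert \mathbf{x} - \mathbf{x}' \rVert^2}{2\ell^2}\right),
\qquad \widehat{K} = K_{XX} + \sigma^2 I,
$$

$$
\text{loss} = -\frac{1}{n}\log p(\mathbf{y} \mid X)
= \frac{1}{2n}\left[(\mathbf{y} - c\mathbf{1})^\top \widehat{K}^{-1} (\mathbf{y} - c\mathbf{1}) + \log\det \widehat{K} + n \log 2\pi\right].
$$

At this scale, neither term is computed with a dense Cholesky factorization: above its `max_cholesky_size` setting, GPyTorch uses iterative, matrix-free routines (conjugate gradients with stochastic trace estimation) that only need matrix-vector products with $\widehat{K}$.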

Stack trace/error message

[KeOps] Generating code for formula Sum_Reduction(Exp(-Sum((Var(0,10,0)-Var(1,10,1))**2)/2)*Var(2,11,1),0) ... OK
[KeOps] Generating code for formula Sum_Reduction(-((2*(Var(0,10,0)-Var(1,10,1)))*(((Var(3,11,0)|Var(2,11,1))*Exp(-Sum((Var(0,10,0)-Var(1,10,1))**2)/2))/2)),0) ... OK
[KeOps] Generating code for formula Sum_Reduction(-(-((2*(Var(0,10,0)-Var(1,10,1)))*(((Var(3,11,0)|Var(2,11,1))*Exp(-Sum((Var(0,10,0)-Var(1,10,1))**2)/2))/2))),1) ... OK
Traceback (most recent call last):
  File "test.py", line 51, in <module>
    loss.backward()
  File "/home/kaiwen/anaconda3/envs/altproj/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/kaiwen/anaconda3/envs/altproj/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasStrsmBatched( handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount)`
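As I read the KeOps formula syntax (`Var(k, dim, 0)` is indexed by $i$, `Var(k, dim, 1)` by $j$), the first formula in the log is a kernel-matrix product $\mathbf{o}_i = \sum_j \exp(-\lVert \mathbf{x}_i - \mathbf{y}_j \rVert^2 / 2)\, \mathbf{b}_j$, with 10-dimensional inputs and an 11-column right-hand side. The other two formulas are its gradients with respect to $\mathbf{x}$ and $\mathbf{y}$. The lengthscale does not appear, which suggests the inputs are divided by it before KeOps is called. A small NumPy reference, for example to sanity-check KeOps output on a subset of the data, might look like this. The function names are hypothetical.

```python
import numpy as np


def rbf_matmul_dense(x, y, b):
    """out_i = sum_j exp(-||x_i - y_j||^2 / 2) * b_j, via the full kernel matrix."""
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # shape (n, m)
    return np.exp(-sq_dists / 2) @ b  # shape (n, k)


def rbf_matmul_chunked(x, y, b, chunk=128):
    """Same reduction, holding only a (chunk, m) block of the kernel at a time.

    This is a memory-bounded NumPy stand-in, not how KeOps works internally;
    KeOps avoids storing kernel entries in GPU global memory altogether.
    """
    out = np.empty((x.shape[0], b.shape[1]), dtype=np.result_type(x, b))
    for start in range(0, x.shape[0], chunk):
        xs = x[start:start + chunk]
        sq_dists = ((xs[:, None, :] - y[None, :, :]) ** 2).sum(-1)
        out[start:start + chunk] = np.exp(-sq_dists / 2) @ b
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((500, 10))  # like Var(0,10,0): i-indexed points
y = rng.standard_normal((400, 10))  # like Var(1,10,1): j-indexed points
b = rng.standard_normal((400, 11))  # like Var(2,11,1): j-indexed right-hand side
assert np.allclose(rbf_matmul_dense(x, y, b), rbf_matmul_chunked(x, y, b))
```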

Expected Behavior

The backward pass should be executed successfully.

System information


  • GPyTorch Version 1.12.dev24+g14705b9d (the latest commit on the main branch)
  • PyTorch Version 2.0.1
  • Ubuntu 20.04.5 LTS
  • CUDA 12.2
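When reporting versions, it can help to separate the CUDA version reported by the driver (e.g. by `nvidia-smi`) from the CUDA runtime PyTorch was built against (`torch.version.cuda`), since the two can differ. PyTorch's own `python -m torch.utils.collect_env` prints a full report. A minimal sketch that collects just the versions relevant here, treating missing packages as `None`:

```python
import importlib
import platform


def collect_versions():
    """Gather Python, OS, torch, gpytorch and pykeops version info.

    Packages that are not installed are reported as None instead of raising.
    """
    info = {"python": platform.python_version(), "os": platform.platform()}
    for name in ("torch", "gpytorch", "pykeops"):
        try:
            module = importlib.import_module(name)
        except ImportError:
            info[name] = None
            continue
        info[name] = getattr(module, "__version__", "unknown")
        if name == "torch":
            # CUDA version PyTorch was built with; may differ from the
            # driver-reported version.
            info["torch_cuda"] = module.version.cuda
            info["cuda_available"] = module.cuda.is_available()
    return info


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```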

Additional context

This issue is related to a recent commit that was intended to fix a bug in the KeOps kernel; I encountered the error above while testing the KeOps kernel after that commit.

Weirdly enough, the code above works fine when $n = 500{,}000$.
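Since $n = 500{,}000$ works and $n = 1{,}000{,}000$ fails, the failure point can be narrowed down by bisection on $n$. The following is a sketch, assuming the failure is monotone in $n$ and that a hypothetical `run(n)` callable builds the model above on `n` random points and performs one `loss.backward()`, raising `RuntimeError` on failure. Some CUDA errors can leave the process unusable, so running each trial in a fresh process may be more robust.

```python
def find_failure_threshold(run, lo, hi, tol=10_000):
    """Bisect on n, assuming run(n) succeeds for n < threshold and fails for n >= threshold.

    Preconditions: run(lo) succeeds and run(hi) raises RuntimeError.
    Returns (last_ok, first_fail) with first_fail - last_ok <= tol.
    """
    while hi - lo > tol:
        mid = (lo + hi) // 2
        try:
            run(mid)
        except RuntimeError:
            hi = mid  # mid fails, so the threshold is at or below mid
        else:
            lo = mid  # mid works, so the threshold is above mid
    return lo, hi


# Hypothetical stand-in for the real training step; the cut-off is made up.
def simulated_run(n, made_up_cutoff=600_000):
    if n >= made_up_cutoff:
        raise RuntimeError("simulated CUBLAS_STATUS_NOT_SUPPORTED")


last_ok, first_fail = find_failure_threshold(simulated_run, 500_000, 1_000_000)
print(f"works at n = {last_ok:,}, fails at n = {first_fail:,}")
```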

Note that GPyTorch 1.6.0 does not have this bug.

kayween added the bug label on Dec 7, 2023.
kayween changed the title from "[Bug] Potential Bug with GP Regression with KeOps Kernels" to "[Bug] Potential Bug in GP Regression with KeOps Kernels" on Dec 7, 2023.