
[Bug] torch.float64 raises error in GridInterpolationKernel #2225

Closed · May be fixed by #2254

anjawa opened this issue Dec 14, 2022 · 2 comments

🐛 Bug

I want to sample from the prior distribution with precision torch.float64. However, during sampling with KISS-GP, a dtype error is raised if I use a manually designed kernel (very similar to the RBF kernel) wrapped in a GridInterpolationKernel.

If the test data size is reduced from 2500×2 to 100×2 (a 10×10 grid instead of the 50×50 grid used in the reproduction below), no error occurs:

```python
# 10 x 10 grid -> 100 x 2 test inputs (no error at this size)
x = torch.meshgrid(
    torch.linspace(0, 10 - 1, 10) * 1.,
    torch.linspace(0, 10 - 1, 10) * 1.,
    indexing="xy",
)
x = torch.cat(
    (
        x[0].contiguous().view(x[0].numel(), 1),
        x[1].contiguous().view(x[1].numel(), 1),
    ),
    dim=1,
)
```
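As an aside, the same 100 × 2 set of test points can be built in a single call. A compact equivalent (my illustration, not part of the original report; note that torch.cartesian_prod enumerates the rows in a different order than the meshgrid/cat construction above):

```python
# Same 100 grid points as above, modulo row order (illustration only)
ticks = torch.linspace(0, 9, 10)
x_alt = torch.cartesian_prod(ticks, ticks)  # shape (100, 2)
```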

To reproduce

```python
import torch
import gpytorch


torch.set_default_dtype(torch.float64)


def postprocess_rot(dist_mat):
    # Convert distances to covariances: k(d) = exp(-d)
    return dist_mat.mul_(-1.0).exp_()

class TestKernel(gpytorch.kernels.Kernel):
    # Exponential kernel exp(-d) with a fixed per-dimension input scaling.
    is_stationary = True

    def forward(self, x1, x2, **params):
        # Out-of-place div: the original used div_, which mutates the
        # caller's tensors in place on every call.
        x1_ = x1.div(torch.tensor([10., 1.]))
        x2_ = x2.div(torch.tensor([10., 1.]))
        return self.covar_dist(
            x1_, x2_, square_dist=False, dist_postprocess_func=postprocess_rot, **params
        )

class ExactGP(gpytorch.models.ExactGP):

    def __init__(self, **kwargs):
        # No training data: the model is only used to sample from the prior.
        super().__init__(None, None, gpytorch.likelihoods.GaussianLikelihood())
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
            TestKernel(ard_num_dims=2, **kwargs),
            grid_size=100,
            num_dims=2,
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


# 50 x 50 grid -> 2500 x 2 test inputs (this size triggers the error)
x = torch.meshgrid(
    torch.linspace(0, 50 - 1, 50) * 1.,
    torch.linspace(0, 50 - 1, 50) * 1.,
    indexing="xy",
)
x = torch.cat(
    (
        x[0].contiguous().view(x[0].numel(), 1),
        x[1].contiguous().view(x[1].numel(), 1),
    ),
    dim=1,
)

model = ExactGP()
model.eval()

# Sample from the prior using KISS-GP fast predictive sampling
with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.max_root_decomposition_size(100):
    with gpytorch.settings.fast_pred_samples():
        samples = model(x).rsample(torch.Size([1]))
```

Error message

```
expected scalar type Double but found Float
```
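For context, this is PyTorch's generic mixed-dtype error, which suggests a float32 tensor is being combined with float64 data somewhere along the interpolation path. A minimal standalone illustration of the same message (my example, not from the report):

```python
import torch

a = torch.randn(3, 3, dtype=torch.float64)
b = torch.randn(3, 3, dtype=torch.float32)
a @ b  # RuntimeError: expected scalar type Double but found Float
```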

System information

torch=1.13.0
gpytorch=1.9.0
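Until a fix lands, one kernel-side precaution is to derive every constant from the inputs rather than the global default dtype. This is an untested sketch (the class name is mine, and it reuses postprocess_rot from the reproduction); it hardens the user code against dtype drift but does not patch any float32 tensor created inside gpytorch itself:

```python
class DtypeSafeTestKernel(gpytorch.kernels.Kernel):
    is_stationary = True

    def forward(self, x1, x2, **params):
        # Build the scaling constant with the inputs' dtype/device instead
        # of relying on the global default dtype.
        scale = torch.tensor([10., 1.], dtype=x1.dtype, device=x1.device)
        return self.covar_dist(
            x1.div(scale), x2.div(scale),
            square_dist=False, dist_postprocess_func=postprocess_rot, **params
        )
```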

gpleiss (Member) commented Jan 17, 2023

Sorry for the slow response - I'm putting up a PR to fix this.

anjawa (Author) commented Jan 17, 2023

Thanks 👍 🙂

anjawa closed this as completed Jan 17, 2023