Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent dtype errors when torch's default_dtype is set to torch.float64 #2254

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions gpytorch/kernels/newton_girard_additive_kernel.py
Expand Up @@ -90,14 +90,14 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
s_k = kern_values.unsqueeze(kernel_dim - 1).pow(kvals).sum(dim=kernel_dim)

# just the constant -1
m1 = torch.tensor([-1], dtype=torch.float, device=kern_values.device)
m1 = torch.tensor([-1], dtype=x1.dtype, device=kern_values.device)

shape = [1 for _ in range(len(kern_values.shape))]
shape[kernel_dim] = -1
for deg in range(1, self.max_degree + 1): # deg goes from 1 to R (it's 1-indexed!)
# we avg over k [1, ..., deg] (-1)^(k-1)e_{deg-k} s_{k}

ks = torch.arange(1, deg + 1, device=kern_values.device, dtype=torch.float).reshape(*shape) # use for pow
ks = torch.arange(1, deg + 1, device=kern_values.device, dtype=x1.dtype).reshape(*shape) # use for pow
kslong = torch.arange(1, deg + 1, device=kern_values.device, dtype=torch.long) # use for indexing

# note that s_k is 0-indexed, so we must subtract 1 from kslong
Expand Down
5 changes: 4 additions & 1 deletion gpytorch/likelihoods/gaussian_likelihood.py
Expand Up @@ -308,7 +308,10 @@ class DirichletClassificationLikelihood(FixedNoiseGaussianLikelihood):
>>> pred_y = likelihood(gp_model(test_x), targets=labels)
"""

def _prepare_targets(self, targets, alpha_epsilon=0.01, dtype=torch.float):
def _prepare_targets(self, targets: torch.Tensor, alpha_epsilon: float = 0.01, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor, int]:
if dtype is None:
dtype = torch.get_default_dtype()

num_classes = int(targets.max() + 1)
# set alpha = \alpha_\epsilon
alpha = alpha_epsilon * torch.ones(targets.shape[-1], num_classes, device=targets.device, dtype=dtype)
Expand Down
12 changes: 5 additions & 7 deletions gpytorch/utils/grid.py
Expand Up @@ -2,7 +2,7 @@

import math
import warnings
from typing import List, Tuple
from typing import List, Optional, Tuple

import torch

Expand Down Expand Up @@ -130,8 +130,8 @@ def create_grid(
grid_sizes: List[int],
grid_bounds: List[Tuple[float, float]],
extend: bool = True,
device="cpu",
dtype=torch.float,
device: str = "cpu",
dtype: Optional[torch.dtype] = None,
) -> List[torch.Tensor]:
"""
Creates a grid represented by a list of 1D Tensors representing the
Expand All @@ -141,16 +141,14 @@ def create_grid(
which can be important for getting good grid interpolations.

:param grid_sizes: Sizes of each grid dimension
:type grid_sizes: List[int]
:param grid_bounds: Lower and upper bounds of each grid dimension
:type grid_sizes: List[Tuple[float, float]]
:param device: target device for output (default: cpu)
:type device: torch.device, optional
:param dtype: target dtype for output (default: torch.float)
:type dtype: torch.dtype, optional
:return: Grid points for each dimension. Grid points are stored in a :obj:`torch.Tensor` with shape `grid_sizes[i]`.
:rtype: List[torch.Tensor]
"""
if dtype is None:
dtype = torch.get_default_dtype()
grid = []
for i in range(len(grid_bounds)):
grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_sizes[i] - 2)
Expand Down
17 changes: 3 additions & 14 deletions test/examples/test_kissgp_gp_regression.py
@@ -1,7 +1,5 @@
#!/usr/bin/env python3

import os
import random
import unittest
from math import exp, pi

Expand All @@ -12,6 +10,7 @@
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.priors import SmoothedBoxPrior
from gpytorch.test.base_test_case import BaseTestCase
from gpytorch.test.utils import least_used_cuda_device
from torch import optim

Expand Down Expand Up @@ -44,18 +43,8 @@ def forward(self, x):
return MultivariateNormal(mean_x, covar_x)


class TestKISSGPRegression(unittest.TestCase):
def setUp(self):
if os.getenv("UNLOCK_SEED") is None or os.getenv("UNLOCK_SEED").lower() == "false":
self.rng_state = torch.get_rng_state()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
random.seed(0)

def tearDown(self):
if hasattr(self, "rng_state"):
torch.set_rng_state(self.rng_state)
class TestKISSGPRegression(unittest.TestCase, BaseTestCase):
seed = 0

def test_kissgp_gp_mean_abs_error(self):
train_x, train_y, test_x, test_y = make_data()
Expand Down
33 changes: 19 additions & 14 deletions test/examples/test_kissgp_kronecker_product_regression.py
Expand Up @@ -2,8 +2,6 @@

from math import pi

import os
import random
import torch
import unittest

Expand All @@ -13,6 +11,7 @@
from gpytorch.means import ConstantMean
from gpytorch.priors import SmoothedBoxPrior
from gpytorch.distributions import MultivariateNormal
from gpytorch.test.base_test_case import BaseTestCase
from torch import optim

# Simple training data: let's try to learn a sine function,
Expand Down Expand Up @@ -52,18 +51,8 @@ def forward(self, x):
return MultivariateNormal(mean_x, covar_x)


class TestKISSGPKroneckerProductRegression(unittest.TestCase):
def setUp(self):
if os.getenv("UNLOCK_SEED") is None or os.getenv("UNLOCK_SEED").lower() == "false":
self.rng_state = torch.get_rng_state()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
random.seed(0)

def tearDown(self):
if hasattr(self, "rng_state"):
torch.set_rng_state(self.rng_state)
class TestKISSGPKroneckerProductRegression(unittest.TestCase, BaseTestCase):
seed = 0

def test_kissgp_gp_mean_abs_error(self):
likelihood = GaussianLikelihood()
Expand Down Expand Up @@ -99,6 +88,22 @@ def test_kissgp_gp_mean_abs_error(self):
mean_abs_error = torch.mean(torch.abs(test_y - test_preds))
self.assertLess(mean_abs_error.squeeze().item(), 0.2)

# Try drawing a sample - make sure there's no errors
with torch.no_grad(), gpytorch.settings.max_root_decomposition_size(100):
with gpytorch.settings.fast_pred_samples():
gp_model(train_x).rsample(torch.Size([1]))


class TestKISSGPKroneckerProductRegressionDouble(TestKISSGPKroneckerProductRegression):
def setUp(self):
super().setUp
self.default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)

def tearDown(self):
super().tearDown()
torch.set_default_dtype(self.default_dtype)


if __name__ == "__main__":
unittest.main()