Prevent dtype errors when torch's default_dtype is set to torch.float64
[Fixes #2225]
gpleiss committed Jan 17, 2023
1 parent 41a8386 commit 58af927
Showing 5 changed files with 33 additions and 38 deletions.
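For context, here is a minimal sketch of the failure mode this commit addresses (a hypothetical repro, not taken from issue #2225): once the global default is float64, tensors created with a hard-coded `dtype=torch.float` no longer match tensors that follow the default, and strict ops refuse to mix them.

    import torch

    torch.set_default_dtype(torch.float64)

    x = torch.randn(5, 2)                                 # float64: follows the global default
    hard_coded = torch.tensor([-1.0], dtype=torch.float)  # float32: frozen, as in the old code
    print(x.dtype, hard_coded.dtype)                      # torch.float64 torch.float32

    try:
        torch.matmul(x, torch.randn(2, 2, dtype=torch.float))
    except RuntimeError as err:
        print(err)  # e.g. "expected scalar type Double but found Float"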
4 changes: 2 additions & 2 deletions gpytorch/kernels/newton_girard_additive_kernel.py
@@ -90,14 +90,14 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
         s_k = kern_values.unsqueeze(kernel_dim - 1).pow(kvals).sum(dim=kernel_dim)

         # just the constant -1
-        m1 = torch.tensor([-1], dtype=torch.float, device=kern_values.device)
+        m1 = torch.tensor([-1], dtype=x1.dtype, device=kern_values.device)

         shape = [1 for _ in range(len(kern_values.shape))]
         shape[kernel_dim] = -1
         for deg in range(1, self.max_degree + 1):  # deg goes from 1 to R (it's 1-indexed!)
             # we avg over k [1, ..., deg] (-1)^(k-1)e_{deg-k} s_{k}

-            ks = torch.arange(1, deg + 1, device=kern_values.device, dtype=torch.float).reshape(*shape)  # use for pow
+            ks = torch.arange(1, deg + 1, device=kern_values.device, dtype=x1.dtype).reshape(*shape)  # use for pow
             kslong = torch.arange(1, deg + 1, device=kern_values.device, dtype=torch.long)  # use for indexing

             # note that s_k is 0-indexed, so we must subtract 1 from kslong
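The fix pattern in this kernel is to inherit dtype (and device) from the incoming data instead of hard-coding `torch.float`. A standalone sketch of the same idea (`signs` is a hypothetical helper, not part of this commit):

    import torch

    def signs(x1: torch.Tensor, max_degree: int) -> torch.Tensor:
        # Build the (-1)^(k-1) factors with the caller's dtype/device,
        # so float32 and float64 inputs both work unchanged.
        m1 = torch.tensor([-1.0], dtype=x1.dtype, device=x1.device)
        ks = torch.arange(1, max_degree + 1, dtype=x1.dtype, device=x1.device)
        return m1.pow(ks - 1)

    print(signs(torch.randn(3, dtype=torch.float64), 4))  # tensor([ 1., -1.,  1., -1.], dtype=torch.float64)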
5 changes: 4 additions & 1 deletion gpytorch/likelihoods/gaussian_likelihood.py
@@ -308,7 +308,10 @@ class DirichletClassificationLikelihood(FixedNoiseGaussianLikelihood):
         >>> pred_y = likelihood(gp_model(test_x), targets=labels)
     """

-    def _prepare_targets(self, targets, alpha_epsilon=0.01, dtype=torch.float):
+    def _prepare_targets(self, targets: torch.Tensor, alpha_epsilon: float = 0.01, dtype: Optional[torch.dtype] = None):
+        if dtype is None:
+            dtype = torch.get_default_dtype()
+
         num_classes = int(targets.max() + 1)
         # set alpha = \alpha_\epsilon
         alpha = alpha_epsilon * torch.ones(targets.shape[-1], num_classes, device=targets.device, dtype=dtype)
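The `dtype=None` default with a `torch.get_default_dtype()` fallback is the idiom this commit applies everywhere: callers who pass nothing get whatever precision the program is globally configured for. For reference:

    import torch

    print(torch.get_default_dtype())  # torch.float32 in a fresh interpreter
    torch.set_default_dtype(torch.float64)
    print(torch.get_default_dtype())  # torch.float64
    print(torch.ones(3).dtype)        # factory functions follow the default too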
12 changes: 5 additions & 7 deletions gpytorch/utils/grid.py
@@ -2,7 +2,7 @@

 import math
 import warnings
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch

@@ -130,8 +130,8 @@ def create_grid(
     grid_sizes: List[int],
     grid_bounds: List[Tuple[float, float]],
     extend: bool = True,
-    device="cpu",
-    dtype=torch.float,
+    device: str = "cpu",
+    dtype: Optional[torch.dtype] = None,
 ) -> List[torch.Tensor]:
     """
     Creates a grid represented by a list of 1D Tensors representing the
@@ -141,16 +141,14 @@
     which can be important for getting good grid interpolations.

     :param grid_sizes: Sizes of each grid dimension
     :type grid_sizes: List[int]
     :param grid_bounds: Lower and upper bounds of each grid dimension
     :type grid_sizes: List[Tuple[float, float]]
-    :param device: target device for output (default: cpu)
-    :type device: torch.device, optional
-    :param dtype: target dtype for output (default: torch.float)
-    :type dtype: torch.dtype, optional
     :return: Grid points for each dimension. Grid points are stored in a :obj:`torch.Tensor` with shape `grid_sizes[i]`.
     :rtype: List[torch.Tensor]
     """
+    if dtype is None:
+        dtype = torch.get_default_dtype()
     grid = []
     for i in range(len(grid_bounds)):
         grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_sizes[i] - 2)
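After this change, `create_grid` produces tensors in the ambient default dtype rather than always float32. A usage sketch (assuming only the signature shown above):

    import torch
    from gpytorch.utils.grid import create_grid

    torch.set_default_dtype(torch.float64)
    grid = create_grid(grid_sizes=[4, 4], grid_bounds=[(0.0, 1.0), (0.0, 1.0)])
    print([g.dtype for g in grid])  # [torch.float64, torch.float64] with this fix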
17 changes: 3 additions & 14 deletions test/examples/test_kissgp_gp_regression.py
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3

-import os
-import random
 import unittest
 from math import exp, pi

@@ -12,6 +10,7 @@
 from gpytorch.likelihoods import GaussianLikelihood
 from gpytorch.means import ConstantMean
 from gpytorch.priors import SmoothedBoxPrior
+from gpytorch.test.base_test_case import BaseTestCase
 from gpytorch.test.utils import least_used_cuda_device
 from torch import optim

@@ -44,18 +43,8 @@ def forward(self, x):
         return MultivariateNormal(mean_x, covar_x)


-class TestKISSGPRegression(unittest.TestCase):
-    def setUp(self):
-        if os.getenv("UNLOCK_SEED") is None or os.getenv("UNLOCK_SEED").lower() == "false":
-            self.rng_state = torch.get_rng_state()
-            torch.manual_seed(0)
-            if torch.cuda.is_available():
-                torch.cuda.manual_seed_all(0)
-            random.seed(0)
-
-    def tearDown(self):
-        if hasattr(self, "rng_state"):
-            torch.set_rng_state(self.rng_state)
+class TestKISSGPRegression(unittest.TestCase, BaseTestCase):
+    seed = 0

     def test_kissgp_gp_mean_abs_error(self):
         train_x, train_y, test_x, test_y = make_data()
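The deleted setUp/tearDown boilerplate is what `BaseTestCase` centralizes: subclasses declare a `seed` class attribute and the base class handles the seeding and RNG-state restore. Presumably along these lines (an assumed sketch mirroring the code removed above; see gpytorch/test/base_test_case.py for the real implementation):

    import os
    import random

    import torch

    class SeededTestCaseSketch:
        # Assumed: what a BaseTestCase-style mixin does with the `seed` attribute.
        seed = 0

        def setUp(self):
            if os.getenv("UNLOCK_SEED") is None or os.getenv("UNLOCK_SEED").lower() == "false":
                self.rng_state = torch.get_rng_state()
                torch.manual_seed(self.seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(self.seed)
                random.seed(self.seed)

        def tearDown(self):
            if hasattr(self, "rng_state"):
                torch.set_rng_state(self.rng_state)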
33 changes: 19 additions & 14 deletions test/examples/test_kissgp_kronecker_product_regression.py
@@ -2,8 +2,6 @@

 from math import pi

-import os
-import random
 import torch
 import unittest

@@ -13,6 +11,7 @@
 from gpytorch.means import ConstantMean
 from gpytorch.priors import SmoothedBoxPrior
 from gpytorch.distributions import MultivariateNormal
+from gpytorch.test.base_test_case import BaseTestCase
 from torch import optim

 # Simple training data: let's try to learn a sine function,
@@ -52,18 +51,8 @@ def forward(self, x):
         return MultivariateNormal(mean_x, covar_x)


-class TestKISSGPKroneckerProductRegression(unittest.TestCase):
-    def setUp(self):
-        if os.getenv("UNLOCK_SEED") is None or os.getenv("UNLOCK_SEED").lower() == "false":
-            self.rng_state = torch.get_rng_state()
-            torch.manual_seed(0)
-            if torch.cuda.is_available():
-                torch.cuda.manual_seed_all(0)
-            random.seed(0)
-
-    def tearDown(self):
-        if hasattr(self, "rng_state"):
-            torch.set_rng_state(self.rng_state)
+class TestKISSGPKroneckerProductRegression(unittest.TestCase, BaseTestCase):
+    seed = 0

     def test_kissgp_gp_mean_abs_error(self):
         likelihood = GaussianLikelihood()
@@ -99,6 +88,22 @@ def test_kissgp_gp_mean_abs_error(self):
         mean_abs_error = torch.mean(torch.abs(test_y - test_preds))
         self.assertLess(mean_abs_error.squeeze().item(), 0.2)

+        # Try drawing a sample - make sure there are no errors
+        with torch.no_grad(), gpytorch.settings.max_root_decomposition_size(100):
+            with gpytorch.settings.fast_pred_samples():
+                gp_model(train_x).rsample(torch.Size([1]))
+
+
+class TestKISSGPKroneckerProductRegressionDouble(TestKISSGPKroneckerProductRegression):
+    def setUp(self):
+        super().setUp()
+        self.default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.float64)
+
+    def tearDown(self):
+        super().tearDown()
+        torch.set_default_dtype(self.default_dtype)
+

 if __name__ == "__main__":
     unittest.main()
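The new `...Double` class reruns the entire suite under float64 by flipping the global default in setUp and restoring it in tearDown, so later tests still see float32. The same save/restore pattern can be packaged as a context manager (a hypothetical helper, not part of this commit):

    from contextlib import contextmanager

    import torch

    @contextmanager
    def default_dtype(dtype):
        # Temporarily switch torch's global default dtype; restore it even on error.
        saved = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        try:
            yield
        finally:
            torch.set_default_dtype(saved)

    # Usage:
    with default_dtype(torch.float64):
        assert torch.ones(2).dtype == torch.float64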
