Add batch_shape property to models
A consistent API like this will be useful and avoid ad-hoc inferring of batch shapes. See pytorch#587 for more context.
Balandat committed Nov 19, 2020
1 parent 8a64f8b commit 4aee944
Showing 7 changed files with 78 additions and 8 deletions.
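The sketch below is not part of the commit; it illustrates the intended usage, assuming a standard batched SingleTaskGP, and the shapes in the comments follow the broadcast rule stated in the docstrings added by this diff.

import torch
from botorch.models import SingleTaskGP

# Batched training data: batch_shape = (3,), n = 10 points, d = 2 features, m = 1 output.
train_X = torch.rand(3, 10, 2)
train_Y = torch.rand(3, 10, 1)
model = SingleTaskGP(train_X, train_Y)

print(model.batch_shape)  # torch.Size([3]), read off the training inputs

# A `test_batch_shape x q x d` input broadcasts against model.batch_shape:
test_X = torch.rand(5, 1, 4, 2)  # test_batch_shape = (5, 1), q = 4
posterior = model.posterior(test_X)
print(posterior.event_shape)  # broadcast((5, 1), (3,)) x q x m = torch.Size([5, 3, 4, 1])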
37 changes: 37 additions & 0 deletions botorch/models/gpytorch.py
@@ -97,6 +97,18 @@ def _validate_tensor_args(
f" {Yvar.shape})."
)

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
to the `posterior` method returns a Posterior object over an output of
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
"""
return self.train_inputs[0].shape[:-2]

@property
def num_outputs(self) -> int:
r"""The number of outputs of the model."""
@@ -234,6 +246,18 @@ def _set_dimensions(self, train_X: Tensor, train_Y: Tensor) -> None:
train_X=train_X, train_Y=train_Y
)

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
to the `posterior` method returns a Posterior object over an output of
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
"""
return self._input_batch_shape

def _transform_tensor_args(
self, X: Tensor, Y: Tensor, Yvar: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
@@ -442,6 +466,19 @@ class ModelListGPyTorchModel(GPyTorchModel, ABC):
evaluation of submodels.
"""

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
to the `posterior` method returns a Posterior object over an output of
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
"""
# TODO: Either check that batch shapes match across models, or broadcast them
raise NotImplementedError

def posterior(
self,
X: Tensor,
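For model lists the property deliberately raises for now, since submodel batch shapes may differ (see the TODO above and the corresponding test further down). A minimal sketch of that behavior, assuming two plain SingleTaskGP submodels wrapped in a ModelListGP:

import torch
from botorch.models import ModelListGP, SingleTaskGP

m1 = SingleTaskGP(torch.rand(10, 2), torch.rand(10, 1))
m2 = SingleTaskGP(torch.rand(10, 2), torch.rand(10, 1))
model_list = ModelListGP(m1, m2)

try:
    model_list.batch_shape
except NotImplementedError:
    # Until submodel batch shapes are checked or broadcast, callers cannot
    # rely on batch_shape for model lists.
    pass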
14 changes: 14 additions & 0 deletions botorch/models/model.py
@@ -13,6 +13,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import torch
from botorch import settings
from botorch.posteriors import Posterior
from botorch.sampling.samplers import MCSampler
@@ -51,6 +52,19 @@ def posterior(
"""
pass # pragma: no cover

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
to the `posterior` method returns a Posterior object over an output of
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
"""
cls_name = self.__class__.__name__
raise NotImplementedError(f"{cls_name} does not define batch_shape property")

@property
def num_outputs(self) -> int:
r"""The number of outputs of the model."""
22 changes: 17 additions & 5 deletions botorch/models/pairwise_gp.py
@@ -190,11 +190,6 @@ def __deepcopy__(self, memo) -> PairwiseGP:
self.__deepcopy__ = dcp
return new_model

@property
def num_outputs(self) -> int:
r"""The number of outputs of the model."""
return self._num_outputs

def _has_no_data(self):
r"""Return true if the model does not have both datapoints and comparisons"""
return (
@@ -646,6 +641,23 @@ def _util_newton_updates(self, x0, max_iter=1, xtol=None) -> Tensor:

# ============== public APIs ==============

@property
def num_outputs(self) -> int:
r"""The number of outputs of the model."""
return self._num_outputs

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model (as e.g. in BatchedMultiOutputGPyTorchModel).
For a model with `m` outputs, a `test_batch_shape x q x d`-shaped input `X`
to the `posterior` method returns a Posterior object over an output of
shape `broadcast(test_batch_shape, model.batch_shape) x q x m`.
"""
return self.datapoints.shape[:-2]

def set_train_data(
self, datapoints: Tensor, comparisons: Tensor, update_model: bool = True
) -> None:
6 changes: 3 additions & 3 deletions botorch/optim/initializers.py
@@ -279,8 +279,8 @@ def gen_value_function_initial_conditions(
Returns:
A `num_restarts x batch_shape x q x d` tensor that can be used as initial
conditions for `optimize_acqf()`. Here `batch_shape` is the
`_input_batch_shape` of value function model.
conditions for `optimize_acqf()`. Here `batch_shape` is the batch shape
of value function model.
Example:
>>> fant_X = torch.rand(5, 1, 2)
@@ -325,7 +325,7 @@ def gen_value_function_initial_conditions(
},
)

batch_shape = acq_function.model._input_batch_shape
batch_shape = acq_function.model.batch_shape
# sampling from the optimizers
n_value = int((1 - frac_random) * raw_samples) # number of non-random ICs
if n_value > 0:
4 changes: 4 additions & 0 deletions test/models/test_gpytorch.py
@@ -86,6 +86,7 @@ def test_gpytorch_model(self):
# basic test
model = SimpleGPyTorchModel(train_X, train_Y, octf)
self.assertEqual(model.num_outputs, 1)
self.assertEqual(model.batch_shape, torch.Size())
test_X = torch.rand(2, 1, **tkwargs)
posterior = model.posterior(test_X)
self.assertIsInstance(posterior, GPyTorchPosterior)
@@ -181,6 +182,7 @@ def test_batched_multi_output_gpytorch_model(self):
# basic test
model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y)
self.assertEqual(model.num_outputs, 2)
self.assertEqual(model.batch_shape, torch.Size())
test_X = torch.rand(2, 1, **tkwargs)
posterior = model.posterior(test_X)
self.assertIsInstance(posterior, GPyTorchPosterior)
@@ -257,6 +259,8 @@ def test_model_list_gpytorch_model(self):
m2 = SimpleGPyTorchModel(train_X2, train_Y2)
model = SimpleModelListGPyTorchModel(m1, m2)
self.assertEqual(model.num_outputs, 2)
with self.assertRaises(NotImplementedError):
model.batch_shape
test_X = torch.rand(2, 1, **tkwargs)
posterior = model.posterior(test_X)
self.assertIsInstance(posterior, GPyTorchPosterior)
2 changes: 2 additions & 0 deletions test/models/test_model.py
@@ -24,6 +24,8 @@ def test_not_so_abstract_base_model(self):
model.condition_on_observations(None, None)
with self.assertRaises(NotImplementedError):
model.num_outputs
with self.assertRaises(NotImplementedError):
model.batch_shape
with self.assertRaises(NotImplementedError):
model.subset_output([0])
with self.assertRaises(NotImplementedError):
1 change: 1 addition & 0 deletions test/models/test_pairwise_gp.py
@@ -90,6 +90,7 @@ def test_pairwise_gp(self):
model.covar_module.outputscale_prior, SmoothedBoxPrior
)
self.assertEqual(model.num_outputs, 1)
self.assertEqual(model.batch_shape, batch_shape)

# test custom models
custom_m = PairwiseGP(**model_kwargs, covar_module=LinearKernel())
