Skip to content

Commit

Permalink
Add batch_shape property to SingleTaskVariationalGP
Browse files Browse the repository at this point in the history
This enables the use of `SingleTaskVariationalGP` with certain botorch features (e.g. with entropy-based acquistion functions as requested in pytorch#1795).

This is a bit of a band-aid, the proper thing to do here is to fix up the PR upstreaming this to gpytorch (cornellius-gp/gpytorch#2307) to enable support for `batch_shape` on all approximate gpytorch models, and then just call that on the `model` in `ApproximateGPyTorchModel`.
  • Loading branch information
Balandat committed Apr 19, 2023
1 parent 71690a8 commit 8c1efd9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
11 changes: 11 additions & 0 deletions botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,17 @@ def __init__(

self.to(train_X)

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective. For a model with `m`
outputs, a `test_batch_shape x q x d`-shaped input `X` to the `posterior`
method returns a Posterior object over an output of shape
`broadcast(test_batch_shape, model.batch_shape) x q x m`.
"""
return self._input_batch_shape

def init_inducing_points(
self,
inputs: Tensor,
Expand Down
2 changes: 2 additions & 0 deletions test/models/test_approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def test_posterior(self):
model = SingleTaskVariationalGP(tx, ty, inducing_points=tx)
posterior = model.posterior(test)
self.assertIsInstance(posterior, GPyTorchPosterior)
# test batch_shape property
self.assertEqual(model.batch_shape, tx.shape[:-2])

def test_variational_setUp(self):
for dtype in [torch.float, torch.double]:
Expand Down

0 comments on commit 8c1efd9

Please sign in to comment.