Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update multivariate_normal.py to use loc instead of mean #2488

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 13 additions & 13 deletions gpytorch/distributions/multivariate_normal.py
Expand Up @@ -25,33 +25,33 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
Constructs a multivariate normal random variable, based on mean and covariance.
Can be multivariate, or a batch of multivariate normals

Passing a vector mean corresponds to a multivariate normal.
Passing a matrix mean corresponds to a batch of multivariate normals.
Passing a vector loc corresponds to a multivariate normal.
Passing a matrix loc corresponds to a batch of multivariate normals.

:param mean: `... x N` mean of mvn distribution.
:param loc: `... x N` mean of mvn distribution.
:param covariance_matrix: `... x N X N` covariance matrix of mvn distribution.
:param validate_args: If True, validate `mean` anad `covariance_matrix` arguments. (Default: False.)
:param validate_args: If True, validate `loc` anad `covariance_matrix` arguments. (Default: False.)

:ivar torch.Size base_sample_shape: The shape of a base sample (without
batching) that is used to generate a single sample.
:ivar torch.Tensor covariance_matrix: The covariance matrix, represented as a dense :class:`torch.Tensor`
:ivar ~linear_operator.LinearOperator lazy_covariance_matrix: The covariance matrix, represented
as a :class:`~linear_operator.LinearOperator`.
:ivar torch.Tensor mean: The mean.
:ivar torch.Tensor loc: The mean.
:ivar torch.Tensor stddev: The standard deviation.
:ivar torch.Tensor variance: The variance.
"""

def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
def __init__(self, loc: Union[Tensor, LinearOperator], covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
self._islazy = isinstance(loc, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
if self._islazy:
if validate_args:
ms = mean.size(-1)
ms = loc.size(-1)
cs1 = covariance_matrix.size(-1)
cs2 = covariance_matrix.size(-2)
if not (ms == cs1 and ms == cs2):
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
self.loc = mean
raise ValueError(f"Wrong shapes in {self._repr_sizes(loc, covariance_matrix)}")
self.loc = loc
self._covar = covariance_matrix
self.__unbroadcasted_scale_tril = None
self._validate_args = validate_args
Expand All @@ -62,7 +62,7 @@ def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator
# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
else:
super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
super().__init__(loc=loc, covariance_matrix=covariance_matrix, validate_args=validate_args)

def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
"""
Expand All @@ -78,8 +78,8 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size
return sample_shape + self._batch_shape + self.base_sample_shape

@staticmethod
def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"
def _repr_sizes(loc: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
return f"MultivariateNormal(loc: {loc.size()}, scale: {covariance_matrix.size()})"

@property
def _unbroadcasted_scale_tril(self) -> Tensor:
Expand Down