Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add quadratic-type means & implement linear operations for means #2428

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions gpytorch/means/__init__.py
Expand Up @@ -8,6 +8,12 @@
from .linear_mean_gradgrad import LinearMeanGradGrad
from .mean import Mean
from .multitask_mean import MultitaskMean
from .positive_quadratic_mean import PositiveQuadraticMean
from .positive_quadratic_mean_grad import PositiveQuadraticMeanGrad
from .positive_quadratic_mean_gradgrad import PositiveQuadraticMeanGradGrad
from .quadratic_mean import QuadraticMean
from .quadratic_mean_grad import QuadraticMeanGrad
from .quadratic_mean_gradgrad import QuadraticMeanGradGrad
from .zero_mean import ZeroMean

__all__ = [
Expand All @@ -18,6 +24,12 @@
"LinearMean",
"LinearMeanGrad",
"LinearMeanGradGrad",
"QuadraticMean",
"QuadraticMeanGrad",
"QuadraticMeanGradGrad",
"PositiveQuadraticMean",
"PositiveQuadraticMeanGrad",
"PositiveQuadraticMeanGradGrad",
"MultitaskMean",
"ZeroMean",
]
81 changes: 81 additions & 0 deletions gpytorch/means/mean.py
@@ -1,5 +1,11 @@
#!/usr/bin/env python3

from typing import Iterable, Union

from torch import Tensor

from torch.nn import ModuleList

from ..module import Module


Expand All @@ -22,3 +28,78 @@ def __call__(self, x):
res = super(Mean, self).__call__(x)

return res

def __add__(self, other: "Mean") -> "Mean":
    """
    Return the sum of this mean and ``other`` as a single flat :class:`AdditiveMean`.

    Operands that already are :class:`AdditiveMean` instances are unpacked so that the
    result does not nest. Both their added and their subtracted components are carried over.

    :param other: Mean to add to this one.
    :return: An :class:`AdditiveMean` over all collected components.
    """
    positive_means = []
    negative_means = []
    for entity in (self, other):
        if isinstance(entity, AdditiveMean):
            positive_means.extend(entity.means)
            # Keep the subtracted terms too; otherwise ``(a - b) + c`` would silently drop ``b``.
            if entity.negative_means is not None:
                negative_means.extend(entity.negative_means)
        else:
            positive_means.append(entity)
    # Pass ``None`` when nothing is subtracted, matching how a plain sum was built before.
    return AdditiveMean(*positive_means, negative_means=negative_means or None)

def __sub__(self, other: "Mean") -> "Mean":
    """
    Return the difference ``self - other`` as a single flat :class:`AdditiveMean`.

    Operands that already are :class:`AdditiveMean` instances are unpacked. The signs of the
    components of ``other`` are flipped: its added components are subtracted and its
    subtracted components are added back.

    :param other: Mean to subtract from this one.
    :return: An :class:`AdditiveMean` holding the added and subtracted components.
    """
    positive_means = []
    negative_means = []
    if isinstance(self, AdditiveMean):
        positive_means.extend(self.means)
        # Terms already subtracted on the left-hand side must stay subtracted.
        if self.negative_means is not None:
            negative_means.extend(self.negative_means)
    else:
        positive_means.append(self)
    if isinstance(other, AdditiveMean):
        negative_means.extend(other.means)
        # Subtracting a subtracted term adds it back: a - (b - c) == a - b + c.
        if other.negative_means is not None:
            positive_means.extend(other.negative_means)
    else:
        negative_means.append(other)
    return AdditiveMean(*positive_means, negative_means=negative_means)

def __mul__(self, other: Union[float, int]) -> "Mean":
    """
    Scale this mean by a constant.

    :param other: Scalar coefficient.
    :return: A :class:`HometetiaMean` wrapping this mean and the coefficient.
    :raises NotImplementedError: If ``other`` is not a ``float`` or an ``int``.
    """
    # A tuple of classes works on every Python version; ``isinstance`` only accepts
    # ``typing.Union[...]`` from Python 3.10 onwards and raises ``TypeError`` before that.
    if isinstance(other, (float, int)):
        return HometetiaMean(self, other)
    raise NotImplementedError(f"Multiplication between 'Mean' and '{type(other)}' is not supported.")

def __rmul__(self, other: Union[float, int, "Mean"]) -> "Mean":
    """
    Reflected multiplication (``other * self``); delegates to :meth:`__mul__`.

    NOTE(review): the annotation lists ``Mean`` as an accepted operand, but the call is
    forwarded unchanged to ``__mul__``, which only accepts ``float`` and ``int``.

    :param other: Left-hand operand of the product.
    :return: Whatever ``self.__mul__(other)`` returns.
    """
    return self.__mul__(other)

def __neg__(self) -> "Mean":
    """
    Return the negated mean, built as this mean scaled by ``-1.0``.

    :return: A :class:`HometetiaMean` wrapping this mean with coefficient ``-1.0``.
    """
    return HometetiaMean(mean=self, coefficient=-1.0)


class AdditiveMean(Mean):
    """
    A Mean that supports summing over multiple component means.

    Example:
        >>> mean_module = LinearMean(2) + PositiveQuadraticMean(2)
        >>> x1 = torch.randn(50, 2)
        >>> additive_mean_vector = mean_module(x1)

    :param means: Means to add together.
    :param negative_means: Means to subtract from the sum.
    """

    def __init__(self, *means: Mean, negative_means: Union[Iterable[Mean], None] = None):
        super(AdditiveMean, self).__init__()
        # ModuleList registers the components as sub-modules of this mean.
        self.means = ModuleList(means)
        self.negative_means = ModuleList(negative_means) if negative_means is not None else None

    def forward(self, x: Tensor) -> Tensor:
        """
        Evaluate every component on ``x`` and combine them.

        :param x: Input passed unchanged to every component mean.
        :return: Sum of ``means`` minus the sum of ``negative_means``. With no components at
            all this is the float ``0.0``.
        """
        res = 0.0
        # Accumulate out-of-place: an in-place ``+=`` would pin the result to the shape of the
        # first component and fail as soon as a later component broadcasts to a larger shape.
        for mean in self.means:
            res = res + mean(x)
        if self.negative_means is not None:
            for mean in self.negative_means:
                res = res - mean(x)
        return res


class HometetiaMean(Mean):
    """
    A Mean scaled by a constant coefficient.

    Example:
        >>> mean_module = (-2) * PositiveQuadraticMean(2)
        >>> x1 = torch.randn(50, 2)
        >>> scaled_mean_vector = mean_module(x1)

    :param mean: Mean module to scale.
    :param coefficient: Constant the output of ``mean`` is multiplied with.
    """

    def __init__(self, mean: Mean, coefficient: float):
        super().__init__()
        self.mean = mean
        # Stored under the short name ``c``; it is a plain attribute, not a parameter.
        self.c = coefficient

    def forward(self, x: Tensor) -> Tensor:
        """Evaluate the wrapped mean on ``x`` and multiply the result by the coefficient."""
        wrapped_output = self.mean(x)
        return self.c * wrapped_output
40 changes: 40 additions & 0 deletions gpytorch/means/positive_quadratic_mean.py
@@ -0,0 +1,40 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class PositiveQuadraticMean(Mean):
r"""
A positive quadratic prior mean function and its first derivative, i.e.:

.. math::

\mu(\mathbf x) &= \frac12 \mathbf x^\top \cdot A \cdot \mathbf x

where :math:`A = L L^\top` and L a lower triangular matrix.

:param input_size: dimension of input :math:`\mathbf x`.
:type input_size: int
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional

:var torch.Tensor cholesky: vector containing :math:`L` components.
"""

def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size()):
super().__init__()
self.dim = input_size
self.register_parameter(
name="cholesky", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size * (input_size + 1) // 2))
)

def forward(self, x):
xl = torch.zeros(*x.shape, device=x.device)
for i in range(x.shape[-2]):
for j in range(self.dim):
for k in range(j, self.dim):
xl[..., i, j] += self.cholesky[..., k * (k + 1) // 2 + j] * x[..., i, k]
res = xl.pow(2).sum(-1).div(2)
return res
46 changes: 46 additions & 0 deletions gpytorch/means/positive_quadratic_mean_grad.py
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class PositiveQuadraticMeanGrad(Mean):
r"""
A positive quadratic prior mean function and its first derivative, i.e.:

.. math::

\mu(\mathbf x) &= \frac12 \mathbf x^\top \cdot A \cdot \mathbf x \\
\nabla \mu(\mathbf x) &= \mathbf x \cdot A

where :math:`A = L L^\top` and L a lower triangular matrix.

:param input_size: dimension of input :math:`\mathbf x`.
:type input_size: int
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional

:var torch.Tensor cholesky: vector containing :math:`L` components.
"""

def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size()):
super().__init__()
self.dim = input_size
self.register_parameter(
name="cholesky", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size * (input_size + 1) // 2))
)

def forward(self, x):
xl = torch.zeros(*x.shape, device=x.device)
for i in range(x.shape[-2]):
for j in range(self.dim):
for k in range(j, self.dim):
xl[..., i, j] += self.cholesky[..., k * (k + 1) // 2 + j] * x[..., i, k]
res = xl.pow(2).sum(-1).div(2)
dres = torch.zeros(*x.shape, device=x.device)
for i in range(x.shape[-2]):
for j in range(self.dim):
for k in range(j + 1):
dres[..., i, j] += self.cholesky[..., j * (j + 1) // 2 + k] * xl[..., i, k]
return torch.cat((res.unsqueeze(-1), dres), -1)
50 changes: 50 additions & 0 deletions gpytorch/means/positive_quadratic_mean_gradgrad.py
@@ -0,0 +1,50 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class PositiveQuadraticMeanGradGrad(Mean):
r"""
A positive quadratic prior mean function and its first and second derivative, i.e.:

.. math::

\mu(\mathbf x) &= \frac12 \mathbf x^\top \cdot A \cdot \mathbf x \\
\nabla \mu(\mathbf x) &= \mathbf x \cdot A \\
\nabla^2 \mu(\mathbf x) &= \mathbf A \\

where :math:`A = L L^\top` and L a lower triangular matrix.

:param input_size: dimension of input :math:`\mathbf x`.
:type input_size: int
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional

:var torch.Tensor cholesky: vector containing :math:`L` components.
"""

def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size()):
super().__init__()
self.dim = input_size
self.register_parameter(
name="cholesky", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size * (input_size + 1) // 2))
)

def forward(self, x):
xl = torch.zeros(*x.shape, device=x.device)
for i in range(x.shape[-2]):
for j in range(self.dim):
for k in range(j, self.dim):
xl[..., i, j] += self.cholesky[..., k * (k + 1) // 2 + j] * x[..., i, k]
res = xl.pow(2).sum(-1).div(2)
dres = torch.zeros(*x.shape, device=x.device)
ddres = torch.zeros(*x.shape, device=x.device)
for i in range(x.shape[-2]):
for j in range(self.dim):
for k in range(j + 1):
c = self.cholesky[..., j * (j + 1) // 2 + k]
dres[..., i, j] += self.cholesky[..., j * (j + 1) // 2 + k] * xl[..., i, k]
ddres[..., i, j] += c**2
return torch.cat((res.unsqueeze(-1), dres, ddres), -1)
41 changes: 41 additions & 0 deletions gpytorch/means/quadratic_mean.py
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class QuadraticMean(Mean):
r"""
A quadratic prior mean function, i.e.:

.. math::

\mu(\mathbf x) &= \frac12 \mathbf x^\top \cdot A \cdot \mathbf x

where :math:`A` and L a square matrix.

:param input_size: dimension of input :math:`\mathbf x`.
:type input_size: int
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional

:var torch.Tensor A: a square matrix.
"""

def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size()):
super().__init__()
self.dim = input_size
self.register_parameter(
name="A", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, input_size))
)

def forward(self, x):
res = torch.zeros(*x.shape[:-1], device=x.device)
for i in range(x.shape[-2]):
for j in range(self.dim):
s = 0.0
for k in range(self.dim):
s += x[..., i, k] * self.A[..., k, j]
res[..., i] += x[..., i, j] * s
return res.div(2)
44 changes: 44 additions & 0 deletions gpytorch/means/quadratic_mean_grad.py
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class QuadraticMeanGrad(Mean):
r"""
A quadratic prior mean function and its first derivative, i.e.:

.. math::

\mu(\mathbf x) &= \frac12 \mathbf x^\top \cdot A \cdot \mathbf x \\
\nabla \mu(\mathbf x) &= \frac12 \mathbf x \cdot ( A + A^\top )

where :math:`A` and L a square matrix.

:param input_size: dimension of input :math:`\mathbf x`.
:type input_size: int
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional

:var torch.Tensor A: a square matrix.
"""

def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size()):
super().__init__()
self.dim = input_size
self.register_parameter(
name="A", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, input_size))
)

def forward(self, x):
res = torch.zeros(*x.shape[:-1], device=x.device)
dres = torch.zeros(*x.shape, device=x.device)
for i in range(x.shape[-2]):
for j in range(self.dim):
s = 0.0
for k in range(self.dim):
s += x[..., i, k] * self.A[..., k, j]
dres[..., i, j] += x[..., i, k] * (self.A[..., j, k] + self.A[..., k, j])
res[..., i] += x[..., i, j] * s
return torch.cat((res.div(2).unsqueeze(-1), dres.div(2)), -1)
47 changes: 47 additions & 0 deletions gpytorch/means/quadratic_mean_gradgrad.py
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

import torch

from .mean import Mean


class QuadraticMeanGradGrad(Mean):
r"""
A quadratic prior mean function and its first and second derivative, i.e.:

.. math::

\mu(\mathbf x) &= \frac12 \mathbf x^\top \cdot A \cdot \mathbf x \\
\nabla \mu(\mathbf x) &= \frac12 \mathbf x \cdot ( A + A^\top ) \\
\nabla^2 \mu(\mathbf x) &= \mathbf A \\

where :math:`A` and L a square matrix.

:param input_size: dimension of input :math:`\mathbf x`.
:type input_size: int
:param batch_shape: The batch shape of the learned constant(s) (default: []).
:type batch_shape: torch.Size, optional

:var torch.Tensor A: a square matrix.
"""

def __init__(self, input_size: int, batch_shape: torch.Size = torch.Size()):
super().__init__()
self.dim = input_size
self.register_parameter(
name="A", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, input_size))
)

def forward(self, x):
res = torch.zeros(*x.shape[:-1], device=x.device)
dres = torch.zeros(*x.shape, device=x.device)
ddres = torch.zeros(*x.shape, device=x.device)
for i in range(x.shape[-2]):
for j in range(self.dim):
s = 0.0
for k in range(self.dim):
s += x[..., i, k] * self.A[..., k, j]
dres[..., i, j] += x[..., i, k] * (self.A[..., j, k] + self.A[..., k, j])
res[..., i] += x[..., i, j] * s
ddres[..., i, j] = self.A[..., j, j]
return torch.cat((res.div(2).unsqueeze(-1), dres.div(2), ddres), -1)
2 changes: 1 addition & 1 deletion test/means/test_linear_mean.py
Expand Up @@ -12,7 +12,7 @@ class TestLinearMean(BaseMeanTestCase, unittest.TestCase):
def create_mean(self, input_size=1, batch_shape=torch.Size(), bias=True, **kwargs):
    # Build the LinearMean under test. Extra keyword arguments are accepted so callers can
    # pass them, but they are not forwarded to the constructor.
    return LinearMean(input_size=input_size, batch_shape=batch_shape, bias=bias)

def forward_vec(self):
def test_forward_vec(self):
n = 4
test_x = torch.randn(n)
mean = self.create_mean(input_size=1)
Expand Down