Decouple kernel computation class initialisation from kernel #328

Merged · 6 commits · Jun 27, 2023
Changes from 4 commits
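The gist of the change: compute engines stop storing the kernel as a field and instead receive it as an explicit argument, so a single stateless engine instance can be shared across kernels. A minimal sketch of the before/after calling convention, using simplified stand-in classes rather than the real GPJax types:

```python
# Simplified stand-ins for the GPJax classes touched below; not the real API.
from dataclasses import dataclass, field

import jax.numpy as jnp


@dataclass
class DenseComputation:
    """Stateless engine: the kernel is now an argument, not a stored field."""

    def gram(self, kernel, x):
        # Dense N x N Gram matrix via pairwise kernel evaluations.
        return jnp.asarray([[kernel(xi, xj) for xj in x] for xi in x])


@dataclass
class ToyRBF:
    lengthscale: float = 1.0
    # New style: a shared default *instance*. The old code stored the class
    # and rebuilt an engine per call via `self.compute_engine(self)`.
    compute_engine: DenseComputation = field(default_factory=DenseComputation)

    def __call__(self, xi, xj):
        return jnp.exp(-0.5 * jnp.sum((xi - xj) ** 2) / self.lengthscale**2)

    def gram(self, x):
        return self.compute_engine.gram(self, x)  # kernel passed explicitly


x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
print(ToyRBF().gram(x).shape)  # (5, 5)
```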
6 changes: 4 additions & 2 deletions gpjax/kernels/approximations/rff.py
@@ -37,6 +37,9 @@ class RFF(AbstractKernel):
    base_kernel: AbstractKernel = None
    num_basis_fns: int = static_field(50)
    frequencies: Float[Array, "M 1"] = param_field(None, bijector=tfb.Identity())
+    compute_engine: BasisFunctionComputation = static_field(
+        BasisFunctionComputation(), repr=False
+    )
    key: KeyArray = static_field(PRNGKey(123))

    def __post_init__(self) -> None:
@@ -46,7 +49,6 @@ def __post_init__(self) -> None:
        set the computation engine to be the basis function computation engine.
        """
        self._check_valid_base_kernel(self.base_kernel)
-        self.compute_engine = BasisFunctionComputation

        if self.frequencies is None:
            n_dims = self.base_kernel.ndims
@@ -83,4 +85,4 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
        -------
        Float[Array, "N L"]: A $`N \times L`$ array of features where $`L = 2M`$.
        """
-        return self.compute_engine(self).compute_features(x)
+        return self.compute_engine.compute_features(self, x)
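With the engine supplied as a field default, `__post_init__` no longer has to overwrite `compute_engine`. A hedged usage sketch of the resulting API (assuming `gpx.kernels.RBF` exists alongside `RFF` on this branch; names per the diff above):

```python
import gpjax as gpx
import jax.numpy as jnp

# Sketch under the assumption of a GPJax-style RBF base kernel.
base = gpx.kernels.RBF()
rff = gpx.kernels.RFF(base_kernel=base, num_basis_fns=50)

x = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)
phi = rff.compute_features(x)  # (100, 2 * 50) random Fourier features
K_approx = rff.gram(x)         # low-rank approximation of base.gram(x)
```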
10 changes: 3 additions & 7 deletions gpjax/kernels/base.py
@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================

-
import abc
from dataclasses import dataclass
from functools import partial
@@ -22,7 +21,6 @@
    Callable,
    List,
    Optional,
-    Type,
    Union,
)
import jax.numpy as jnp
@@ -51,9 +49,7 @@
class AbstractKernel(Module):
    r"""Base kernel class."""

-    compute_engine: Type[AbstractKernelComputation] = static_field(
-        DenseKernelComputation
-    )
+    compute_engine: AbstractKernelComputation = static_field(DenseKernelComputation())
    active_dims: Optional[List[int]] = static_field(None)
    name: str = static_field("AbstractKernel")

@@ -62,10 +58,10 @@ def ndims(self):
        return 1 if not self.active_dims else len(self.active_dims)

    def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
-        return self.compute_engine(self).cross_covariance(x, y)
+        return self.compute_engine.cross_covariance(self, x, y)

    def gram(self, x: Num[Array, "N D"]):
-        return self.compute_engine(self).gram(x)
+        return self.compute_engine.gram(self, x)

    def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
        r"""Slice out the relevant columns of the input matrix.
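Since the engine is now stateless, `kernel.gram(x)` and a directly constructed engine are interchangeable. A sketch of the equivalence (assuming `DenseKernelComputation` is importable as in this diff, and that the returned linear operators expose `to_dense()`):

```python
import gpjax as gpx
import jax.numpy as jnp
from gpjax.kernels.computations import DenseKernelComputation

kernel = gpx.kernels.RBF()
x = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)

K1 = kernel.gram(x)                            # delegates to the default engine
K2 = DenseKernelComputation().gram(kernel, x)  # same engine, invoked directly
assert jnp.allclose(K1.to_dense(), K2.to_dense())
```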
19 changes: 13 additions & 6 deletions gpjax/kernels/computations/base.py
@@ -15,6 +15,7 @@

import abc
from dataclasses import dataclass
+import typing as tp

from jax import vmap
from jaxtyping import (
@@ -29,37 +30,40 @@
)
from gpjax.typing import Array

+Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821


@dataclass
class AbstractKernelComputation:
    r"""Abstract class for kernel computations."""

-    kernel: "gpjax.kernels.base.AbstractKernel"  # noqa: F821
-
    def gram(
        self,
+        kernel: Kernel,
        x: Num[Array, "N D"],
    ) -> LinearOperator:
        r"""Compute Gram covariance operator of the kernel function.

        Args:
+            kernel (AbstractKernel): the kernel function.
            x (Float[Array, "N N"]): The inputs to the kernel function.

        Returns
        -------
            LinearOperator: Gram covariance operator of the kernel function.
        """
-        Kxx = self.cross_covariance(x, x)
+        Kxx = self.cross_covariance(kernel, x, x)
        return DenseLinearOperator(Kxx)

    @abc.abstractmethod
    def cross_covariance(
-        self, x: Num[Array, "N D"], y: Num[Array, "M D"]
+        self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
    ) -> Float[Array, "N M"]:
        r"""For a given kernel, compute the NxM gram matrix on a pair
        of input matrices with shape NxD and MxD.

        Args:
+            kernel (AbstractKernel): the kernel function.
            x (Float[Array,"N D"]): The first input matrix.
            y (Float[Array,"M D"]): The second input matrix.

@@ -69,15 +73,18 @@ def cross_covariance(
"""
raise NotImplementedError

def diagonal(self, inputs: Num[Array, "N D"]) -> DiagonalLinearOperator:
def diagonal(
self, kernel: Kernel, inputs: Num[Array, "N D"]
) -> DiagonalLinearOperator:
r"""For a given kernel, compute the elementwise diagonal of the
NxN gram matrix on an input matrix of shape NxD.

Args:
kernel (AbstractKernel): the kernel function.
inputs (Float[Array, "N D"]): The input matrix.

Returns
-------
DiagonalLinearOperator: The computed diagonal variance entries.
"""
return DiagonalLinearOperator(diag=vmap(lambda x: self.kernel(x, x))(inputs))
return DiagonalLinearOperator(diag=vmap(lambda x: kernel(x, x))(inputs))
41 changes: 24 additions & 17 deletions gpjax/kernels/computations/basis_functions.py
@@ -1,69 +1,76 @@
from dataclasses import dataclass
+import typing as tp

import jax.numpy as jnp
-from jaxtyping import Float
+from jaxtyping import (
+    Array,
+    Float,
+)

from gpjax.kernels.computations.base import AbstractKernelComputation
from gpjax.linops import DenseLinearOperator
-from gpjax.typing import Array

+Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
[Review comment · Member] Nice!

[Reply · Contributor Author] This works with beartype, but static type checkers complain about this way of annotating the bound; all the alternative solutions I tried resulted in circular-import errors at runtime type checking, as explained in #293 (comment).
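For readers following this thread: the pattern is a `TypeVar` bound to a string forward reference, which keeps the computations module from importing the kernels module at import time. An illustrative two-module sketch (module names here are hypothetical, not GPJax's layout):

```python
# computations_sketch.py -- hypothetical module, illustrating the pattern only.
import typing as tp

# The bound is a *string*, so nothing is imported here. Even if the kernels
# module imports this one, no circular import can occur at runtime. Static
# checkers may flag the unresolved name (hence the noqa in the diff), while
# runtime checkers such as beartype resolve the forward reference lazily.
Kernel = tp.TypeVar("Kernel", bound="kernels_sketch.AbstractKernel")  # noqa: F821


class AbstractKernelComputation:
    def gram(self, kernel: Kernel, x):
        raise NotImplementedError
```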



@dataclass
class BasisFunctionComputation(AbstractKernelComputation):
    r"""Compute engine class for finite basis function approximations to a kernel."""

-    num_basis_fns: int = None

    def cross_covariance(
-        self, x: Float[Array, "N D"], y: Float[Array, "M D"]
+        self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
    ) -> Float[Array, "N M"]:
        r"""Compute an approximate cross-covariance matrix.

        For a pair of inputs, compute the cross covariance matrix between the inputs.

        Args:
+            kernel (Kernel): the kernel function.
            x: (Float[Array, "N D"]): A $`N \times D`$ array of inputs.
            y: (Float[Array, "M D"]): A $`M \times D`$ array of inputs.

        Returns:
            Float[Array, "N M"]: A $`N \times M`$ array of cross-covariances.
        """
-        z1 = self.compute_features(x)
-        z2 = self.compute_features(y)
-        return self.scaling * jnp.matmul(z1, z2.T)
+        z1 = self.compute_features(kernel, x)
+        z2 = self.compute_features(kernel, y)
+        return self.scaling(kernel) * jnp.matmul(z1, z2.T)

-    def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
+    def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
        r"""Compute an approximate Gram matrix.

        For the Gram matrix, we can save computations by computing only one matrix
        multiplication between the inputs and the scaled frequencies.

        Args:
+            kernel (Kernel): the kernel function.
            inputs (Float[Array, "N D"]): A $`N \times D`$ array of inputs.

        Returns:
            DenseLinearOperator: A dense linear operator representing the
                $`N \times N`$ Gram matrix.
        """
-        z1 = self.compute_features(inputs)
-        return DenseLinearOperator(self.scaling * jnp.matmul(z1, z1.T))
+        z1 = self.compute_features(kernel, inputs)
+        return DenseLinearOperator(self.scaling(kernel) * jnp.matmul(z1, z1.T))

-    def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
+    def compute_features(
+        self, kernel: Kernel, x: Float[Array, "N D"]
+    ) -> Float[Array, "N L"]:
        r"""Compute the features for the inputs.

        Args:
+            kernel (Kernel): the kernel function.
            x (Float[Array, "N D"]): A $`N \times D`$ array of inputs.

        Returns
        -------
            Float[Array, "N L"]: A $`N \times L`$ array of features where $`L = 2M`$.
        """
-        frequencies = self.kernel.frequencies
-        scaling_factor = self.kernel.base_kernel.lengthscale
+        frequencies = kernel.frequencies
+        scaling_factor = kernel.base_kernel.lengthscale
        z = jnp.matmul(x, (frequencies / scaling_factor).T)
        z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
        return z

-    @property
-    def scaling(self):
-        return self.kernel.base_kernel.variance / self.kernel.num_basis_fns
+    def scaling(self, kernel):
+        return kernel.base_kernel.variance / kernel.num_basis_fns
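To see why `scaling` is `variance / num_basis_fns`: with M frequencies drawn from the kernel's spectral density, the feature map phi(x) = [cos(x W^T / l), sin(x W^T / l)] satisfies k(x, y) ≈ (sigma^2 / M) phi(x) phi(y)^T. A standalone numerical sketch in plain JAX, independent of the GPJax classes above:

```python
import jax.numpy as jnp
from jax import random

num_basis_fns, lengthscale, variance = 50, 1.0, 1.0

# The RBF spectral density is Gaussian, so frequencies are standard normals.
frequencies = random.normal(random.PRNGKey(123), (num_basis_fns, 1))  # (M, D)

def compute_features(x):
    # Mirrors BasisFunctionComputation.compute_features: (N, D) -> (N, 2M).
    z = jnp.matmul(x, (frequencies / lengthscale).T)
    return jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)

x = jnp.linspace(-3.0, 3.0, 200).reshape(-1, 1)
z1 = compute_features(x)
K_approx = (variance / num_basis_fns) * z1 @ z1.T  # approx. RBF Gram matrix
```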
31 changes: 22 additions & 9 deletions gpjax/kernels/computations/constant_diagonal.py
@@ -13,59 +13,72 @@
# limitations under the License.
# ==============================================================================

+import typing as tp
+
from jax import vmap
import jax.numpy as jnp
-from jaxtyping import Float
+from jaxtyping import (
+    Array,
+    Float,
+)

-from gpjax.kernels.computations.base import AbstractKernelComputation
+from gpjax.kernels.computations import AbstractKernelComputation
from gpjax.linops import (
    ConstantDiagonalLinearOperator,
    DiagonalLinearOperator,
)
-from gpjax.typing import Array

+Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821


class ConstantDiagonalKernelComputation(AbstractKernelComputation):
-    def gram(self, x: Float[Array, "N D"]) -> ConstantDiagonalLinearOperator:
+    def gram(
+        self, kernel: Kernel, x: Float[Array, "N D"]
+    ) -> ConstantDiagonalLinearOperator:
        r"""Compute the Gram matrix.

        Compute Gram covariance operator of the kernel function.

        Args:
+            kernel (Kernel): the kernel function.
            x (Float[Array, "N N"]): The inputs to the kernel function.
        """
-        value = self.kernel(x[0], x[0])
+        value = kernel(x[0], x[0])

        return ConstantDiagonalLinearOperator(
            value=jnp.atleast_1d(value), size=x.shape[0]
        )

-    def diagonal(self, inputs: Float[Array, "N D"]) -> DiagonalLinearOperator:
+    def diagonal(
+        self, kernel: Kernel, inputs: Float[Array, "N D"]
+    ) -> DiagonalLinearOperator:
        r"""Compute the diagonal Gram matrix's entries.

        For a given kernel, compute the elementwise diagonal of the
        NxN gram matrix on an input matrix of shape $`N\times D`$.

        Args:
+            kernel (Kernel): the kernel function.
            inputs (Float[Array, "N D"]): The input matrix.

        Returns
        -------
            DiagonalLinearOperator: The computed diagonal variance entries.
        """
-        diag = vmap(lambda x: self.kernel(x, x))(inputs)
+        diag = vmap(lambda x: kernel(x, x))(inputs)

        return DiagonalLinearOperator(diag=diag)

    def cross_covariance(
-        self, x: Float[Array, "N D"], y: Float[Array, "M D"]
+        self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
    ) -> Float[Array, "N M"]:
        r"""Compute the cross-covariance matrix.

        For a given kernel, compute the NxM covariance matrix on a pair of input
        matrices of shape NxD and MxD.

        Args:
+            kernel (Kernel): the kernel function.
            x (Float[Array,"N D"]): The input matrix.
            y (Float[Array,"M D"]): The input matrix.

@@ -75,5 +88,5 @@ def cross_covariance(
"""
# TODO: This is currently a dense implementation. We should implement
# a sparse LinearOperator for non-square cross-covariance matrices.
cross_cov = vmap(lambda x: vmap(lambda y: self.kernel(x, y))(y))(x)
cross_cov = vmap(lambda x: vmap(lambda y: kernel(x, y))(y))(x)
return cross_cov
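The `gram` above relies on the diagonal being constant: one evaluation of `kernel(x[0], x[0])` determines every diagonal entry, and `ConstantDiagonalLinearOperator` stores a single scalar plus a size rather than an N x N matrix. A plain-JAX sketch with a hypothetical white-noise kernel, where k(x, x) = sigma^2 for all x:

```python
import jax.numpy as jnp

def white_noise(xi, xj, variance=2.0):
    # Hypothetical kernel: variance on the diagonal, zero elsewhere.
    return jnp.where(jnp.all(xi == xj), variance, 0.0)

x = jnp.linspace(0.0, 1.0, 1000).reshape(-1, 1)

value = white_noise(x[0], x[0])             # one evaluation instead of 1000
gram_diag = jnp.full((x.shape[0],), value)  # the whole diagonal, implicitly
```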
14 changes: 10 additions & 4 deletions gpjax/kernels/computations/dense.py
@@ -13,11 +13,16 @@
# limitations under the License.
# ==============================================================================

+import beartype.typing as tp
from jax import vmap
-from jaxtyping import Float
+from jaxtyping import (
+    Array,
+    Float,
+)

from gpjax.kernels.computations.base import AbstractKernelComputation
-from gpjax.typing import Array
+
+Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821


class DenseKernelComputation(AbstractKernelComputation):
@@ -26,20 +31,21 @@ class DenseKernelComputation(AbstractKernelComputation):
    """

    def cross_covariance(
-        self, x: Float[Array, "N D"], y: Float[Array, "M D"]
+        self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
    ) -> Float[Array, "N M"]:
        r"""Compute the cross-covariance matrix.

        For a given kernel, compute the NxM covariance matrix on a pair of input
        matrices of shape $`NxD`$ and $`MxD`$.

        Args:
+            kernel (Kernel): the kernel function.
            x (Float[Array,"N D"]): The input matrix.
            y (Float[Array,"M D"]): The input matrix.

        Returns
        -------
            Float[Array, "N M"]: The computed cross-covariance.
        """
-        cross_cov = vmap(lambda x: vmap(lambda y: self.kernel(x, y))(y))(x)
+        cross_cov = vmap(lambda x: vmap(lambda y: kernel(x, y))(y))(x)
return cross_cov
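The nested `vmap` is the idiomatic JAX lift of a scalar kernel k(x, y) to a full matrix: the inner `vmap` sweeps the rows of `y` for a fixed row of `x`, and the outer sweeps the rows of `x`. A self-contained sketch with an illustrative RBF in place of a GPJax kernel:

```python
import jax.numpy as jnp
from jax import vmap

def rbf(xi, xj, lengthscale=1.0):
    # Scalar kernel evaluated on a single pair of D-dimensional points.
    return jnp.exp(-0.5 * jnp.sum((xi - xj) ** 2) / lengthscale**2)

x = jnp.linspace(0.0, 1.0, 4).reshape(-1, 1)  # (N, D)
y = jnp.linspace(0.0, 1.0, 3).reshape(-1, 1)  # (M, D)

# Inner vmap: one row of x against all rows of y -> one row of the result.
# Outer vmap: repeat for every row of x -> the (N, M) cross-covariance.
cross_cov = vmap(lambda xi: vmap(lambda yj: rbf(xi, yj))(y))(x)
assert cross_cov.shape == (4, 3)
```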