Decouple kernel computation class initialisation from kernel (#328)
* decouple kernel init and computations

* pass all tests

* update docstrings

* make angry bear happy

* add missing docstring
frazane authored Jun 27, 2023
1 parent e98050e commit 65c786d
Showing 13 changed files with 118 additions and 72 deletions.
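The thrust of the change: kernels previously stored a computation engine *class* and instantiated it with themselves on every call; they now store a stateless engine *instance* and pass themselves as an explicit argument. A minimal before/after sketch of the call pattern (the RBF kernel and data below are illustrative, not part of the diff):

import jax.numpy as jnp
import gpjax as gpx

x = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)
kernel = gpx.kernels.RBF()

# Before #328: the engine was a class, re-instantiated on each call.
#   Kxx = kernel.compute_engine(kernel).gram(x)
# After #328: the engine is a shared instance; the kernel is an argument.
Kxx = kernel.compute_engine.gram(kernel, x)

# The kernel's own wrappers delegate the same way, so user-facing calls
# are unchanged:
Kxx = kernel.gram(x)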
6 changes: 4 additions & 2 deletions gpjax/kernels/approximations/rff.py
@@ -37,6 +37,9 @@ class RFF(AbstractKernel):
     base_kernel: AbstractKernel = None
     num_basis_fns: int = static_field(50)
     frequencies: Float[Array, "M 1"] = param_field(None, bijector=tfb.Identity())
+    compute_engine: BasisFunctionComputation = static_field(
+        BasisFunctionComputation(), repr=False
+    )
     key: KeyArray = static_field(PRNGKey(123))
 
     def __post_init__(self) -> None:
@@ -46,7 +49,6 @@ def __post_init__(self) -> None:
         set the computation engine to be the basis function computation engine.
         """
         self._check_valid_base_kernel(self.base_kernel)
-        self.compute_engine = BasisFunctionComputation
 
         if self.frequencies is None:
             n_dims = self.base_kernel.ndims
@@ -83,4 +85,4 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
         -------
             Float[Array, "N L"]: A $`N \times L`$ array of features where $`L = 2M`$.
         """
-        return self.compute_engine(self).compute_features(x)
+        return self.compute_engine.compute_features(self, x)
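With compute_engine now a static-field default on the dataclass, constructing an RFF kernel no longer mutates the engine inside __post_init__. A hedged usage sketch (base kernel, feature count, and key are illustrative choices):

import jax.numpy as jnp
from jax.random import PRNGKey
import gpjax as gpx
from gpjax.kernels import RFF

# Approximate an RBF kernel with M = 50 random Fourier features;
# frequencies are sampled in __post_init__ when none are supplied.
approx = RFF(base_kernel=gpx.kernels.RBF(), num_basis_fns=50, key=PRNGKey(0))

x = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
phi = approx.compute_features(x)  # shape (20, 100), since L = 2M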
10 changes: 3 additions & 7 deletions gpjax/kernels/base.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 
-
 import abc
 from dataclasses import dataclass
 from functools import partial
@@ -22,7 +21,6 @@
     Callable,
     List,
     Optional,
-    Type,
     Union,
 )
 import jax.numpy as jnp
@@ -51,9 +49,7 @@
 class AbstractKernel(Module):
     r"""Base kernel class."""
 
-    compute_engine: Type[AbstractKernelComputation] = static_field(
-        DenseKernelComputation
-    )
+    compute_engine: AbstractKernelComputation = static_field(DenseKernelComputation())
     active_dims: Optional[List[int]] = static_field(None)
     name: str = static_field("AbstractKernel")
 
@@ -62,10 +58,10 @@ def ndims(self):
         return 1 if not self.active_dims else len(self.active_dims)
 
     def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
-        return self.compute_engine(self).cross_covariance(x, y)
+        return self.compute_engine.cross_covariance(self, x, y)
 
     def gram(self, x: Num[Array, "N D"]):
-        return self.compute_engine(self).gram(x)
+        return self.compute_engine.gram(self, x)
 
     def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
         r"""Slice out the relevant columns of the input matrix.
19 changes: 13 additions & 6 deletions gpjax/kernels/computations/base.py
@@ -15,6 +15,7 @@
 
 import abc
 from dataclasses import dataclass
+import typing as tp
 
 from jax import vmap
 from jaxtyping import (
@@ -29,37 +30,40 @@
 )
 from gpjax.typing import Array
 
+Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
+
 
 @dataclass
 class AbstractKernelComputation:
     r"""Abstract class for kernel computations."""
 
-    kernel: "gpjax.kernels.base.AbstractKernel"  # noqa: F821
-
     def gram(
         self,
+        kernel: Kernel,
         x: Num[Array, "N D"],
     ) -> LinearOperator:
         r"""Compute Gram covariance operator of the kernel function.
         Args:
+            kernel (AbstractKernel): the kernel function.
             x (Float[Array, "N N"]): The inputs to the kernel function.
         Returns
         -------
             LinearOperator: Gram covariance operator of the kernel function.
         """
-        Kxx = self.cross_covariance(x, x)
+        Kxx = self.cross_covariance(kernel, x, x)
         return DenseLinearOperator(Kxx)
 
     @abc.abstractmethod
     def cross_covariance(
-        self, x: Num[Array, "N D"], y: Num[Array, "M D"]
+        self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
     ) -> Float[Array, "N M"]:
         r"""For a given kernel, compute the NxM gram matrix on a pair
         of input matrices with shape NxD and MxD.
         Args:
+            kernel (AbstractKernel): the kernel function.
             x (Float[Array,"N D"]): The first input matrix.
             y (Float[Array,"M D"]): The second input matrix.
@@ -69,15 +73,18 @@ def cross_covariance(
         """
         raise NotImplementedError
 
-    def diagonal(self, inputs: Num[Array, "N D"]) -> DiagonalLinearOperator:
+    def diagonal(
+        self, kernel: Kernel, inputs: Num[Array, "N D"]
+    ) -> DiagonalLinearOperator:
         r"""For a given kernel, compute the elementwise diagonal of the
         NxN gram matrix on an input matrix of shape NxD.
         Args:
+            kernel (AbstractKernel): the kernel function.
            inputs (Float[Array, "N D"]): The input matrix.
         Returns
         -------
             DiagonalLinearOperator: The computed diagonal variance entries.
         """
-        return DiagonalLinearOperator(diag=vmap(lambda x: self.kernel(x, x))(inputs))
+        return DiagonalLinearOperator(diag=vmap(lambda x: kernel(x, x))(inputs))
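Removing the kernel field makes an engine a stateless strategy object: one instance can serve any number of kernels, and a subclass only needs to supply cross_covariance with the kernel as a parameter. A hypothetical sketch of a custom engine under the new interface (the class below is illustrative, not part of the diff):

import typing as tp
from jax import vmap
from jaxtyping import Float, Num

from gpjax.kernels.computations.base import AbstractKernelComputation
from gpjax.typing import Array

Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821


class NaiveKernelComputation(AbstractKernelComputation):
    r"""Hypothetical engine: direct double-vmap evaluation of the kernel."""

    def cross_covariance(
        self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
    ) -> Float[Array, "N M"]:
        # The kernel arrives as an argument rather than living on `self`,
        # so this instance carries no state and can be shared freely.
        return vmap(lambda xi: vmap(lambda yj: kernel(xi, yj))(y))(x)

The inherited gram and diagonal then work unchanged, since both delegate to cross_covariance(kernel, ...) or call the kernel directly.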
44 changes: 29 additions & 15 deletions gpjax/kernels/computations/basis_functions.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+import typing as tp
 
 import jax.numpy as jnp
 from jaxtyping import Float
@@ -7,63 +8,76 @@
 from gpjax.linops import DenseLinearOperator
 from gpjax.typing import Array
 
+Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
+
 
 @dataclass
 class BasisFunctionComputation(AbstractKernelComputation):
     r"""Compute engine class for finite basis function approximations to a kernel."""
 
-    num_basis_fns: int = None
-
     def cross_covariance(
-        self, x: Float[Array, "N D"], y: Float[Array, "M D"]
+        self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
     ) -> Float[Array, "N M"]:
         r"""Compute an approximate cross-covariance matrix.
         For a pair of inputs, compute the cross covariance matrix between the inputs.
         Args:
+            kernel (Kernel): the kernel function.
             x (Float[Array, "N D"]): A $`N \times D`$ array of inputs.
             y (Float[Array, "M D"]): A $`M \times D`$ array of inputs.
         Returns:
             Float[Array, "N M"]: A $`N \times M`$ array of cross-covariances.
         """
-        z1 = self.compute_features(x)
-        z2 = self.compute_features(y)
-        return self.scaling * jnp.matmul(z1, z2.T)
+        z1 = self.compute_features(kernel, x)
+        z2 = self.compute_features(kernel, y)
+        return self.scaling(kernel) * jnp.matmul(z1, z2.T)
 
-    def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
+    def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
         r"""Compute an approximate Gram matrix.
         For the Gram matrix, we can save computations by computing only one matrix
         multiplication between the inputs and the scaled frequencies.
         Args:
+            kernel (Kernel): the kernel function.
             inputs (Float[Array, "N D"]): A $`N \times D`$ array of inputs.
         Returns:
             DenseLinearOperator: A dense linear operator representing the
                 $`N \times N`$ Gram matrix.
         """
-        z1 = self.compute_features(inputs)
-        return DenseLinearOperator(self.scaling * jnp.matmul(z1, z1.T))
+        z1 = self.compute_features(kernel, inputs)
+        return DenseLinearOperator(self.scaling(kernel) * jnp.matmul(z1, z1.T))
 
-    def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
+    def compute_features(
+        self, kernel: Kernel, x: Float[Array, "N D"]
+    ) -> Float[Array, "N L"]:
         r"""Compute the features for the inputs.
         Args:
+            kernel (Kernel): the kernel function.
             x (Float[Array, "N D"]): A $`N \times D`$ array of inputs.
         Returns
         -------
             Float[Array, "N L"]: A $`N \times L`$ array of features where $`L = 2M`$.
         """
-        frequencies = self.kernel.frequencies
-        scaling_factor = self.kernel.base_kernel.lengthscale
+        frequencies = kernel.frequencies
+        scaling_factor = kernel.base_kernel.lengthscale
         z = jnp.matmul(x, (frequencies / scaling_factor).T)
         z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
         return z
 
-    @property
-    def scaling(self):
-        return self.kernel.base_kernel.variance / self.kernel.num_basis_fns
+    def scaling(self, kernel: Kernel):
+        r"""Compute the scaling factor for the covariance matrix.
+        Args:
+            kernel (Kernel): the kernel function.
+        Returns
+        -------
+            Float[Array, ""]: A scalar array.
+        """
+        return kernel.base_kernel.variance / kernel.num_basis_fns
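Taken together, compute_features and scaling say the approximate Gram matrix is $`(\sigma^2 / M)\,\Phi \Phi^\top`$ with $`\Phi = [\cos(X W^\top / \ell), \sin(X W^\top / \ell)]`$, which should approach the exact base-kernel Gram as num_basis_fns grows. A quick numerical sketch of that identity (illustrative values; assumes the returned operator exposes to_dense()):

import jax.numpy as jnp
from jax.random import PRNGKey
import gpjax as gpx
from gpjax.kernels import RFF

base = gpx.kernels.RBF()
approx = RFF(base_kernel=base, num_basis_fns=2000, key=PRNGKey(42))

x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)

# Reproduce BasisFunctionComputation.gram by hand:
phi = approx.compute_features(x)              # (5, 2 * 2000)
scale = base.variance / approx.num_basis_fns  # scaling(kernel)
K_approx = scale * phi @ phi.T

K_exact = base.gram(x).to_dense()
print(jnp.max(jnp.abs(K_approx - K_exact)))   # small for large num_basis_fns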
25 changes: 18 additions & 7 deletions gpjax/kernels/computations/constant_diagonal.py
@@ -13,59 +13,70 @@
 # limitations under the License.
 # ==============================================================================
 
+import typing as tp
+
 from jax import vmap
 import jax.numpy as jnp
 from jaxtyping import Float
 
-from gpjax.kernels.computations.base import AbstractKernelComputation
+from gpjax.kernels.computations import AbstractKernelComputation
 from gpjax.linops import (
     ConstantDiagonalLinearOperator,
     DiagonalLinearOperator,
 )
 from gpjax.typing import Array
 
+Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
+
 
 class ConstantDiagonalKernelComputation(AbstractKernelComputation):
-    def gram(self, x: Float[Array, "N D"]) -> ConstantDiagonalLinearOperator:
+    def gram(
+        self, kernel: Kernel, x: Float[Array, "N D"]
+    ) -> ConstantDiagonalLinearOperator:
         r"""Compute the Gram matrix.
         Compute Gram covariance operator of the kernel function.
         Args:
+            kernel (Kernel): the kernel function.
             x (Float[Array, "N N"]): The inputs to the kernel function.
         """
-        value = self.kernel(x[0], x[0])
+        value = kernel(x[0], x[0])
 
         return ConstantDiagonalLinearOperator(
             value=jnp.atleast_1d(value), size=x.shape[0]
         )
 
-    def diagonal(self, inputs: Float[Array, "N D"]) -> DiagonalLinearOperator:
+    def diagonal(
+        self, kernel: Kernel, inputs: Float[Array, "N D"]
+    ) -> DiagonalLinearOperator:
         r"""Compute the diagonal Gram matrix's entries.
         For a given kernel, compute the elementwise diagonal of the
         NxN gram matrix on an input matrix of shape $`N \times D`$.
         Args:
+            kernel (Kernel): the kernel function.
             inputs (Float[Array, "N D"]): The input matrix.
         Returns
         -------
             DiagonalLinearOperator: The computed diagonal variance entries.
         """
-        diag = vmap(lambda x: self.kernel(x, x))(inputs)
+        diag = vmap(lambda x: kernel(x, x))(inputs)
 
         return DiagonalLinearOperator(diag=diag)
 
     def cross_covariance(
-        self, x: Float[Array, "N D"], y: Float[Array, "M D"]
+        self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
    ) -> Float[Array, "N M"]:
         r"""Compute the cross-covariance matrix.
         For a given kernel, compute the NxM covariance matrix on a pair of input
         matrices of shape NxD and MxD.
         Args:
+            kernel (Kernel): the kernel function.
             x (Float[Array,"N D"]): The input matrix.
             y (Float[Array,"M D"]): The input matrix.
@@ -75,5 +86,5 @@ def cross_covariance(
         """
         # TODO: This is currently a dense implementation. We should implement
         # a sparse LinearOperator for non-square cross-covariance matrices.
-        cross_cov = vmap(lambda x: vmap(lambda y: self.kernel(x, y))(y))(x)
+        cross_cov = vmap(lambda x: vmap(lambda y: kernel(x, y))(y))(x)
         return cross_cov
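ConstantDiagonalKernelComputation targets kernels whose Gram diagonal repeats a single value, so gram evaluates kernel(x[0], x[0]) once rather than N times. In GPJax this engine backs the White noise kernel; a short hedged usage sketch (data and variance are illustrative):

import jax.numpy as jnp
import gpjax as gpx

x = jnp.linspace(0.0, 1.0, 4).reshape(-1, 1)

# White noise: k(x, x) = variance and k(x, y) = 0 for x != y.
white = gpx.kernels.White(variance=2.0)

# gram delegates to ConstantDiagonalKernelComputation.gram(kernel, x),
# returning a constant-diagonal operator without an N x N evaluation.
Kxx = white.gram(x)
print(Kxx.to_dense())  # 2.0 on the diagonal, zeros elsewhere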
8 changes: 6 additions & 2 deletions gpjax/kernels/computations/dense.py
@@ -13,33 +13,37 @@
 # limitations under the License.
 # ==============================================================================
 
+import beartype.typing as tp
 from jax import vmap
 from jaxtyping import Float
 
 from gpjax.kernels.computations.base import AbstractKernelComputation
 from gpjax.typing import Array
 
+Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
+
 
 class DenseKernelComputation(AbstractKernelComputation):
     r"""Dense kernel computation class. Operations with the kernel assume
     a dense gram matrix structure.
     """
 
     def cross_covariance(
-        self, x: Float[Array, "N D"], y: Float[Array, "M D"]
+        self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
     ) -> Float[Array, "N M"]:
         r"""Compute the cross-covariance matrix.
         For a given kernel, compute the NxM covariance matrix on a pair of input
         matrices of shape $`N \times D`$ and $`M \times D`$.
         Args:
+            kernel (Kernel): the kernel function.
             x (Float[Array,"N D"]): The input matrix.
             y (Float[Array,"M D"]): The input matrix.
         Returns
         -------
             Float[Array, "N M"]: The computed cross-covariance.
         """
-        cross_cov = vmap(lambda x: vmap(lambda y: self.kernel(x, y))(y))(x)
+        cross_cov = vmap(lambda x: vmap(lambda y: kernel(x, y))(y))(x)
         return cross_cov
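Because engines no longer hold a kernel, a single instance can be shared across kernels or swapped at construction time. A small sketch (assumes compute_engine is accepted as an ordinary constructor argument, as the base.py diff above indicates):

import gpjax as gpx
from gpjax.kernels.computations import DenseKernelComputation

# One stateless engine instance shared by two kernels.
engine = DenseKernelComputation()
k1 = gpx.kernels.RBF(compute_engine=engine)
k2 = gpx.kernels.Matern52(compute_engine=engine)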
[Diff truncated: the remaining 7 of the 13 changed files are not shown.]
