Skip to content

Commit

Permalink
Merge pull request #208 from JaxGaussianProcesses/new_refactor_kernels
Browse files Browse the repository at this point in the history
New refactor kernels
  • Loading branch information
daniel-dodd committed Apr 4, 2023
2 parents aa356e5 + d6a045b commit 10fd335
Show file tree
Hide file tree
Showing 50 changed files with 384 additions and 394 deletions.
4 changes: 2 additions & 2 deletions gpjax/base/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

import dataclasses
from copy import copy, deepcopy
from typing import Any, Callable, Dict, Iterable, Tuple, List
from typing_extensions import Self
from typing import Any, Callable, Dict, Iterable, List, Tuple

import jax
import jax.tree_util as jtu
from jax._src.tree_util import _registry
from simple_pytree import Pytree, static_field
from typing_extensions import Self

import tensorflow_probability.substrates.jax.bijectors as tfb

Expand Down
6 changes: 3 additions & 3 deletions gpjax/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,26 @@

"""JaxKern."""
from .approximations import RFF
from .base import ProductKernel, SumKernel, AbstractKernel
from .base import AbstractKernel, ProductKernel, SumKernel
from .computations import (
BasisFunctionComputation,
ConstantDiagonalKernelComputation,
DenseKernelComputation,
DiagonalKernelComputation,
EigenKernelComputation,
)
from .non_euclidean import GraphKernel
from .nonstationary import Linear, Polynomial
from .stationary import (
RBF,
Matern12,
Matern32,
Matern52,
RationalQuadratic,
Periodic,
PoweredExponential,
RationalQuadratic,
White,
)
from .non_euclidean import GraphKernel

__all__ = [
"AbstractKernel",
Expand Down
62 changes: 31 additions & 31 deletions gpjax/kernels/approximations/rff.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
import tensorflow_probability.substrates.jax.bijectors as tfb

from dataclasses import dataclass

from jax.random import KeyArray, PRNGKey
from jaxtyping import Array, Float
from simple_pytree import static_field

from ...base import param_field
from ..base import AbstractKernel
from ..computations import BasisFunctionComputation
from jax.random import KeyArray
from typing import Dict, Any


class RFF(AbstractKernel):

@dataclass
class AbstractFourierKernel:
base_kernel: AbstractKernel
num_basis_fns: int
frequencies: Float[Array, "M 1"] = param_field(None, bijector=tfb.Identity)
key: KeyArray = static_field(PRNGKey(123))


@dataclass
class RFF(AbstractKernel, AbstractFourierKernel):
"""Computes an approximation of the kernel using Random Fourier Features.
All stationary kernels are equivalent to the Fourier transform of a probability
Expand All @@ -23,39 +40,22 @@ class RFF(AbstractKernel):
AbstractKernel (_type_): _description_
"""

def __init__(self, base_kernel: AbstractKernel, num_basis_fns: int) -> None:
"""Initialise the Random Fourier Features approximation.
def __post_init__(self) -> None:
"""Post-initialisation function.
Args:
base_kernel (AbstractKernel): The kernel that is to be approximated. This kernel must be stationary.
num_basis_fns (int): The number of basis functions that should be used to approximate the kernel.
This function is called after the initialisation of the kernel. It is used to
set the computation engine to be the basis function computation engine.
"""
self._check_valid_base_kernel(base_kernel)
self.base_kernel = base_kernel
self.num_basis_fns = num_basis_fns
# Set the computation engine to be basis function computation engine
self._check_valid_base_kernel(self.base_kernel)
self.compute_engine = BasisFunctionComputation
# Inform the compute engine of the number of basis functions
self.compute_engine.num_basis_fns = num_basis_fns

def init_params(self, key: KeyArray) -> Dict:
"""Initialise the parameters of the RFF approximation.
if self.frequencies is None:
n_dims = self.base_kernel.ndims
self.frequencies = self.base_kernel.spectral_density.sample(
seed=self.key, sample_shape=(self.num_basis_fns, n_dims)
)

Args:
key (KeyArray): A pseudo-random number generator key.
Returns:
Dict: A dictionary containing the original kernel's parameters and the initial frequencies used in RFF approximation.
"""
base_params = self.base_kernel.init_params(key)
n_dims = self.base_kernel.ndims
frequencies = self.base_kernel.spectral_density.sample(
seed=key, sample_shape=(self.num_basis_fns, n_dims)
)
base_params["frequencies"] = frequencies
return base_params

def __call__(self, *args: Any, **kwds: Any) -> Any:
def __call__(self, x: Array, y: Array) -> Array:
pass

def _check_valid_base_kernel(self, kernel: AbstractKernel):
Expand Down
6 changes: 6 additions & 0 deletions gpjax/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from functools import partial
from simple_pytree import static_field
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Union

from ..base import Module, param_field
from .computations import AbstractKernelComputation, DenseKernelComputation
Expand Down Expand Up @@ -115,6 +117,10 @@ def __mul__(

return ProductKernel(kernels=[self, Constant(other)])

@property
def spectral_density(self) -> tfd.Distribution:
return None


@dataclass
class Constant(AbstractKernel):
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/computations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# ==============================================================================

from .base import AbstractKernelComputation
from .basis_functions import BasisFunctionComputation
from .constant_diagonal import ConstantDiagonalKernelComputation
from .dense import DenseKernelComputation
from .diagonal import DiagonalKernelComputation
from .eigen import EigenKernelComputation
from .basis_functions import BasisFunctionComputation

__all__ = [
"AbstractKernelComputation",
Expand Down
10 changes: 4 additions & 6 deletions gpjax/kernels/computations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
# ==============================================================================

import abc
from dataclasses import dataclass
from typing import Any

from jax import vmap
from gpjax.linops import (
DenseLinearOperator,
DiagonalLinearOperator,
LinearOperator,
)
from jaxtyping import Array, Float
from dataclasses import dataclass

from gpjax.linops import DenseLinearOperator, DiagonalLinearOperator, LinearOperator

Kernel = Any

Expand Down
18 changes: 10 additions & 8 deletions gpjax/kernels/computations/basis_functions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from dataclasses import dataclass

import jax.numpy as jnp
from jaxtyping import Array, Float
from .base import AbstractKernelComputation

from gpjax.linops import DenseLinearOperator

from dataclasses import dataclass
from .base import AbstractKernelComputation


@dataclass
class BasisFunctionComputation(AbstractKernelComputation):
"""Compute engine class for finite basis function approximations to a kernel."""

num_basis_fns = None
num_basis_fns: int = None

def cross_covariance(
self, x: Float[Array, "N D"], y: Float[Array, "M D"]
Expand All @@ -26,8 +28,8 @@ def cross_covariance(
"""
z1 = self.compute_features(x)
z2 = self.compute_features(y)
z1 /= self.num_basis_fns
return self.kernel.variance * jnp.matmul(z1, z2.T)
z1 /= self.kernel.num_basis_fns
return self.kernel.base_kernel.variance * jnp.matmul(z1, z2.T)

def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
"""For the Gram matrix, we can save computations by computing only one matrix multiplication between the inputs and the scaled frequencies.
Expand All @@ -41,8 +43,8 @@ def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
"""
z1 = self.compute_features(inputs)
matrix = jnp.matmul(z1, z1.T) # shape: (n_samples, n_samples)
matrix /= self.num_basis_fns
return DenseLinearOperator(self.kernel.variance * matrix)
matrix /= self.kernel.num_basis_fns
return DenseLinearOperator(self.kernel.base_kernel.variance * matrix)

def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
"""Compute the features for the inputs.
Expand All @@ -55,7 +57,7 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
Float[Array, "N L"]: A N x L array of features where L = 2M.
"""
frequencies = self.kernel.frequencies
scaling_factor = self.kernel.lengthscale
scaling_factor = self.kernel.base_kernel.lengthscale
z = jnp.matmul(x, (frequencies / scaling_factor).T)
z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
return z
8 changes: 3 additions & 5 deletions gpjax/kernels/computations/constant_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
# limitations under the License.
# ==============================================================================
import jax.numpy as jnp

from jax import vmap
from gpjax.linops import (
ConstantDiagonalLinearOperator,
DiagonalLinearOperator,
)
from jaxtyping import Array, Float

from gpjax.linops import ConstantDiagonalLinearOperator, DiagonalLinearOperator

from .base import AbstractKernelComputation


Expand Down
1 change: 1 addition & 0 deletions gpjax/kernels/computations/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from jax import vmap
from jaxtyping import Array, Float

from .base import AbstractKernelComputation

class DenseKernelComputation(AbstractKernelComputation):
Expand Down
6 changes: 3 additions & 3 deletions gpjax/kernels/computations/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# ==============================================================================

from jax import vmap
from gpjax.linops import (
DiagonalLinearOperator,
)
from jaxtyping import Array, Float

from gpjax.linops import DiagonalLinearOperator

from .base import AbstractKernelComputation


Expand Down
17 changes: 6 additions & 11 deletions gpjax/kernels/computations/eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,31 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import Dict

import jax.numpy as jnp
from jaxtyping import Array, Float

from .base import AbstractKernelComputation
from dataclasses import dataclass


@dataclass
class EigenKernelComputation(AbstractKernelComputation):
eigenvalues: Float[Array, "N"] = None
eigenvectors: Float[Array, "N N"] = None
num_verticies: int = None

def cross_covariance(
self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"]
self, x: Float[Array, "N D"], y: Float[Array, "M D"]
) -> Float[Array, "N M"]:
# Extract the graph Laplacian's eigenvalues
evals = self.eigenvalues
# Transform the eigenvalues of the graph Laplacian according to the
# RBF kernel's SPDE form.
S = jnp.power(
evals
self.kernel.eigenvalues
+ 2
* self.kernel.smoothness
/ self.kernel.lengthscale
/ self.kernel.lengthscale,
-self.kernel.smoothness,
)
S = jnp.multiply(S, self.num_vertex / jnp.sum(S))
S = jnp.multiply(S, self.kernel.num_vertex / jnp.sum(S))
# Scale the transform eigenvalues by the kernel variance
S = jnp.multiply(S, params["variance"])
S = jnp.multiply(S, self.kernel.variance)
return self.kernel(x, y, S=S)
Loading

0 comments on commit 10fd335

Please sign in to comment.