diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 262765cdb..873d8944f 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -14,11 +14,12 @@ # ============================================================================== -from itertools import permutations +from itertools import permutations, product import jax import jax.numpy as jnp import jax.random as jr +import jax.tree_util as jtu import pytest import distrax as dx from jax.config import config @@ -30,8 +31,14 @@ Matern12, Matern32, Matern52, + White, + Periodic, + PoweredExponential, + RationalQuadratic, ) +from gpjax.kernels.computations import DenseKernelComputation, DiagonalKernelComputation from gpjax.kernels.stationary.utils import build_student_t_distribution +from gpjax.parameters.bijectors import Identity, Softplus # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -39,295 +46,178 @@ _jitter = 1e-6 -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - # Matern12(), - # Matern32(), - # Matern52(), - # RationalQuadratic(), - # White(), - ], -) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("n", [1, 2, 10]) -def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: - - # Gram constructor static method: - kernel.gram - - # Inputs x: - x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) - - # Test gram matrix: - Kxx = kernel.gram(x) - assert isinstance(Kxx, LinearOperator) - assert Kxx.shape == (n, n) - - -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - # Matern12(), - # Matern32(), - # Matern52(), - # RationalQuadratic(), - # White(), - ], -) -@pytest.mark.parametrize("num_a", [1, 2, 5]) -@pytest.mark.parametrize("num_b", [1, 2, 5]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -def test_cross_covariance( - kernel: AbstractKernel, num_a: int, num_b: int, dim: int -) -> None: - # Inputs a, b: - a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim) - b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) - - # Test cross covariance, Kab: - Kab = kernel.cross_covariance(a, b) - assert isinstance(Kab, jnp.ndarray) - assert Kab.shape == (num_a, num_b) - - -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - # Matern12(), - # Matern32(), - # Matern52(), - # White(), - ], -) -@pytest.mark.parametrize("dim", [1, 2, 5]) -def test_call(kernel: AbstractKernel, dim: int) -> None: - - # Datapoint x and datapoint y: - x = jnp.array([[1.0] * dim]) - y = jnp.array([[0.5] * dim]) - - # Test calling gives an autocovariance value of no dimension between the inputs: - kxy = kernel(x, y) - - assert isinstance(kxy, jax.Array) - assert kxy.shape == () - - -@pytest.mark.parametrize( - "kern", - [ - RBF, - # Matern12, - # Matern32, - # Matern52, - ], -) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def( - kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int -) -> None: - kern = kern( - active_dims=list(range(dim)), - lengthscale=jnp.array([ell]), - variance=jnp.array([sigma]), +class BaseTestKernel: + """A base class that contains all tests applied on stationary kernels.""" + + kernel: AbstractKernel + default_compute_engine = type + spectral_density_name: str + + def pytest_generate_tests(self, metafunc): + """This is called automatically by pytest""" + id_func = lambda x: "-".join([f"{k}={v}" for k, v in x.items()]) + funcarglist = metafunc.cls.params.get(metafunc.function.__name__, None) + + if funcarglist is None: + + return + else: + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [[funcargs[name] for name in argnames] for funcargs in funcarglist], + ids=id_func, + ) + + @pytest.mark.parametrize("dim", [None, 1, 3], ids=lambda x: f"dim={x}") + def test_initialization(self, fields: dict, dim: int) -> None: + + fields = {k: jnp.array([v]) for k, v in fields.items()} + + # number of dimensions + if dim is None: + kernel: AbstractKernel = self.kernel(**fields) + assert kernel.ndims == 1 + else: + kernel: AbstractKernel = self.kernel( + active_dims=[i for i in range(dim)], **fields + ) + assert kernel.ndims == dim + + # compute engine + assert kernel.compute_engine == self.default_compute_engine + + # properties + for field, value in fields.items(): + assert getattr(kernel, field) == value + + # pytree + leaves = jtu.tree_leaves(kernel) + assert len(leaves) == len(fields) + + # meta + meta_leaves = kernel._pytree__meta + assert meta_leaves.keys() == fields.keys() + for field in fields: + if field in ["variance", "lengthscale", "period", "alpha"]: + assert meta_leaves[field]["bijector"] == Softplus + if field in ["power"]: + assert meta_leaves[field]["bijector"] == Identity + assert meta_leaves[field]["trainable"] == True + + # call + x = jnp.linspace(0.0, 1.0, 10 * kernel.ndims).reshape(10, kernel.ndims) + kernel(x, x) + + @pytest.mark.parametrize("n", [1, 5], ids=lambda x: f"n={x}") + @pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}") + def test_gram(self, dim: int, n: int) -> None: + kernel: AbstractKernel = self.kernel() + kernel.gram + x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) + Kxx = kernel.gram(x) + assert isinstance(Kxx, LinearOperator) + assert Kxx.shape == (n, n) + assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense()) > 0.0) + + @pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}") + @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}") + @pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}") + def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None: + + kernel: AbstractKernel = self.kernel() + a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim) + b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim) + Kab = kernel.cross_covariance(a, b) + assert isinstance(Kab, jnp.ndarray) + assert Kab.shape == (n_a, n_b) + + def test_spectral_density(self): + + kernel: AbstractKernel = self.kernel() + + if self.kernel not in [RBF, Matern12, Matern32, Matern52]: + with pytest.raises(AttributeError): + kernel.spectral_density + else: + sdensity = kernel.spectral_density + assert sdensity.name == self.spectral_density_name + assert sdensity.loc == jnp.array(0.0) + assert sdensity.scale == jnp.array(1.0) + + +prod = lambda inp: [ + {"fields": dict(zip(inp.keys(), values))} for values in product(*inp.values()) +] + + +class TestRBF(BaseTestKernel): + kernel = RBF + fields = prod({"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + spectral_density_name = "Normal" + + +class TestMatern12(BaseTestKernel): + kernel = Matern12 + fields = prod({"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + spectral_density_name = "StudentT" + + +class TestMatern32(BaseTestKernel): + kernel = Matern32 + fields = prod({"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + spectral_density_name = "StudentT" + + +class TestMatern52(BaseTestKernel): + kernel = Matern52 + fields = prod({"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + spectral_density_name = "StudentT" + + +class TestWhite(BaseTestKernel): + kernel = White + fields = prod({"variance": [0.1, 1.0]}) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation + + +class TestPeriodic(BaseTestKernel): + kernel = Periodic + fields = prod( + {"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "period": [0.1, 1.0]} ) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -# @pytest.mark.parametrize("dim", [1, 2, 5]) -# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -# @pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) -# @pytest.mark.parametrize("n", [1, 2, 5]) -# def test_pos_def_rq(dim: int, ell: float, sigma: float, alpha: float, n: int) -> None: -# kern = RationalQuadratic(active_dims=list(range(dim))) -# # Gram constructor static method: -# kern.gram - -# # Create inputs x: -# x = jr.uniform(_initialise_key, (n, dim)) -# params = { -# "lengthscale": jnp.array([ell]), -# "variance": jnp.array([sigma]), -# "alpha": jnp.array([alpha]), -# } - -# # Test gram matrix eigenvalues are positive: -# Kxx = kern.gram(params, x) -# Kxx += identity(n) * _jitter -# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) -# assert (eigen_values > 0.0).all() - - -# @pytest.mark.parametrize("dim", [1, 2, 5]) -# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -# @pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) -# @pytest.mark.parametrize("n", [1, 2, 5]) -# def test_pos_def_periodic( -# dim: int, ell: float, sigma: float, period: float, n: int -# ) -> None: -# kern = Periodic(active_dims=list(range(dim))) -# # Gram constructor static method: -# kern.gram - -# # Create inputs x: -# x = jr.uniform(_initialise_key, (n, dim)) -# params = { -# "lengthscale": jnp.array([ell]), -# "variance": jnp.array([sigma]), -# "period": jnp.array([period]), -# } - -# # Test gram matrix eigenvalues are positive: -# Kxx = kern.gram(params, x) -# Kxx += identity(n) * _jitter -# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) -# # assert (eigen_values > 0.0).all() - - -# @pytest.mark.parametrize("dim", [1, 2, 5]) -# @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -# @pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) -# @pytest.mark.parametrize("n", [1, 2, 5]) -# def test_pos_def_power_exp( -# dim: int, ell: float, sigma: float, power: float, n: int -# ) -> None: -# kern = PoweredExponential(active_dims=list(range(dim))) -# # Gram constructor static method: -# kern.gram - -# # Create inputs x: -# x = jr.uniform(_initialise_key, (n, dim)) -# params = { -# "lengthscale": jnp.array([ell]), -# "variance": jnp.array([sigma]), -# "power": jnp.array([power]), -# } - -# # Test gram matrix eigenvalues are positive: -# Kxx = kern.gram(params, x) -# Kxx += identity(n) * _jitter -# eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) -# assert (eigen_values > 0.0).all() - - -# @pytest.mark.parametrize("kernel", -# [ -# RBF, -# #Matern12, -# #Matern32, -# #Matern52, -# ], -# ) -# @pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) -# def test_initialisation(kernel: AbstractKernel, dim: int) -> None: - -# if dim is None: -# kern = kernel() -# assert kern.ndims == 1 - -# else: -# kern = kernel(active_dims=[i for i in range(dim)]) -# params = kern.init_params(_initialise_key) - -# assert list(params.keys()) == ["lengthscale", "variance"] -# assert all(params["lengthscale"] == jnp.array([1.0] * dim)) -# assert params["variance"] == jnp.array([1.0]) - -# if dim > 1: -# assert kern.ard -# else: -# assert not kern.ard - - -# @pytest.mark.parametrize( -# "kernel", -# [ -# RBF, -# # Matern12, -# # Matern32, -# # Matern52, -# # RationalQuadratic, -# # Periodic, -# # PoweredExponential, -# ], -# ) -# def test_dtype(kernel: AbstractKernel) -> None: -# parameter_state = initialise(kernel(), _initialise_key) -# params, *_ = parameter_state.unpack() -# for k, v in params.items(): -# assert v.dtype == jnp.float64 -# assert isinstance(k, str) - - -@pytest.mark.parametrize( - "kernel", - [ - RBF, - # Matern12, - # Matern32, - # Matern52, - # RationalQuadratic, - ], -) -def test_active_dim(kernel: AbstractKernel) -> None: - dim_list = [0, 1, 2, 3] - perm_length = 2 - dim_pairs = list(permutations(dim_list, r=perm_length)) - n_dims = len(dim_list) - - # Generate random inputs - x = jr.normal(_initialise_key, shape=(20, n_dims)) - - for dp in dim_pairs: - # Take slice of x - slice = x[..., dp] - # Define kernels - ad_kern = kernel(active_dims=dp) - manual_kern = kernel(active_dims=[i for i in range(perm_length)]) +class TestPoweredExponential(BaseTestKernel): + kernel = PoweredExponential + fields = prod( + {"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "power": [0.1, 2.0]} + ) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation - # Compute gram matrices - ad_Kxx = ad_kern.gram(x) - manual_Kxx = manual_kern.gram(slice) - # Test gram matrices are equal - assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) +class TestRationalQuadratic(BaseTestKernel): + kernel = RationalQuadratic + fields = prod( + {"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "alpha": [0.1, 1.0]} + ) + params = {"test_initialization": fields} + default_compute_engine = DenseKernelComputation @pytest.mark.parametrize("smoothness", [1, 2, 3]) def test_build_studentt_dist(smoothness: int) -> None: dist = build_student_t_distribution(smoothness) assert isinstance(dist, dx.Distribution) - - -# @pytest.mark.parametrize( -# "kern, df", [(Matern12(), 1), (Matern32(), 3), (Matern52(), 5)] -# ) -# def test_matern_spectral_density(kern, df) -> None: -# sdensity = kern.spectral_density -# assert sdensity.name == "StudentT" -# assert sdensity.df == df -# assert sdensity.loc == jnp.array(0.0) -# assert sdensity.scale == jnp.array(1.0) - - -# def test_rbf_spectral_density() -> None: -# kern = RBF() -# sdensity = kern.spectral_density -# assert sdensity.name == "Normal" -# assert sdensity.loc == jnp.array(0.0) -# assert sdensity.scale == jnp.array(1.0)