Skip to content

Commit

Permalink
test(tests/test_kernels): add diagonal tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephen-huan committed Apr 2, 2024
1 parent 93ed7d2 commit 57cfe67
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 3 deletions.
31 changes: 30 additions & 1 deletion tests/test_kernels/test_approximations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Tuple

from cola.ops import Dense
from cola.ops import (
Dense,
Diagonal,
)
import jax
from jax import config
import jax.numpy as jnp
Expand Down Expand Up @@ -63,6 +66,32 @@ def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: i
assert jnp.all(evals > 0)


@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52])
@pytest.mark.parametrize("num_basis_fns", [2, 10, 20])
@pytest.mark.parametrize("n_dims", [1, 2, 5])
@pytest.mark.parametrize("n_data", [50, 100])
def test_diagonal(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int):
key = jr.key(123)
x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1)
if n_dims > 1:
x = jnp.hstack([x] * n_dims)
base_kernel = kernel(active_dims=list(range(n_dims)))
approximate = RFF(base_kernel=base_kernel, num_basis_fns=num_basis_fns)

linop = approximate.diagonal(x)

# Check the return type
assert isinstance(linop, Diagonal)

Kxx = linop.diag + _jitter

# Check that the shape is correct
assert Kxx.shape == (n_data,)

# Check that the diagonal is positive
assert jnp.all(Kxx > 0)


@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52])
@pytest.mark.parametrize("num_basis_fns", [2, 10, 20])
@pytest.mark.parametrize("n_dims", [1, 2, 5])
Expand Down
20 changes: 19 additions & 1 deletion tests/test_kernels/test_nonstationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from itertools import product
from typing import List

from cola.ops import LinearOperator
from cola.ops import (
Diagonal,
LinearOperator,
)
import jax
from jax import config
import jax.numpy as jnp
Expand Down Expand Up @@ -129,6 +132,21 @@ def test_gram(self, dim: int, n: int) -> None:
assert Kxx.shape == (n, n)
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)

@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
def test_diagonal(self, dim: int, n: int) -> None:
# Initialise kernel
kernel: AbstractKernel = self.kernel()

# Inputs
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Test diagonal
Kxx = kernel.diagonal(x)
assert isinstance(Kxx, Diagonal)
assert Kxx.shape == (n, n)
assert jnp.all(Kxx.diag + 1e-6 > 0.0)

@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}")
Expand Down
20 changes: 19 additions & 1 deletion tests/test_kernels/test_stationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from dataclasses import is_dataclass
from itertools import product

from cola.ops import LinearOperator
from cola.ops import (
Diagonal,
LinearOperator,
)
import jax
from jax import config
import jax.numpy as jnp
Expand Down Expand Up @@ -133,6 +136,21 @@ def test_gram(self, dim: int, n: int) -> None:
assert Kxx.shape == (n, n)
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)

@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
def test_diagonal(self, dim: int, n: int) -> None:
# Initialise kernel
kernel: AbstractKernel = self.kernel()

# Inputs
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Test diagonal
Kxx = kernel.diagonal(x)
assert isinstance(Kxx, Diagonal)
assert Kxx.shape == (n, n)
assert jnp.all(Kxx.diag + 1e-6 > 0.0)

@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}")
Expand Down

0 comments on commit 57cfe67

Please sign in to comment.