From cb214704862b07d398eeed4785ee53511fb2a022 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 11 Apr 2023 12:21:30 +0100 Subject: [PATCH] Add minimal pytree checks for linops. (Better testing needed in future). --- .../test_constant_linear_operator.py | 14 +++++ .../test_linops/test_dense_linear_operator.py | 17 +++++- .../test_diagonal_linear_operator.py | 14 +++++ .../test_identity_linear_operator.py | 13 +++++ tests/test_linops/test_linear_operator.py | 57 +++++++++++++------ .../test_linops/test_zero_linear_operator.py | 14 ++++- 6 files changed, 110 insertions(+), 19 deletions(-) diff --git a/tests/test_linops/test_constant_linear_operator.py b/tests/test_linops/test_constant_linear_operator.py index 0ec6707d2..a3f46c1af 100644 --- a/tests/test_linops/test_constant_linear_operator.py +++ b/tests/test_linops/test_constant_linear_operator.py @@ -17,6 +17,9 @@ import jax.numpy as jnp import jax.random as jr import pytest +import jax.tree_util as jtu +from dataclasses import is_dataclass + from jax.config import config @@ -38,9 +41,20 @@ def approx_equal(res: jnp.ndarray, actual: jnp.ndarray) -> bool: def test_init(n: int) -> None: value = jr.uniform(_PRNGKey, (1,)) constant_diag = ConstantDiagonalLinearOperator(value=value, size=n) + + # Check types. + assert isinstance(constant_diag, ConstantDiagonalLinearOperator) + assert is_dataclass(constant_diag) + + # Check properties. assert constant_diag.shape == (n, n) + assert constant_diag.dtype == jnp.float64 + assert constant_diag.ndim == 2 assert constant_diag.size == n + # Check pytree. + assert jtu.tree_leaves(constant_diag) == [value] # shape, dtype are static! + @pytest.mark.parametrize("n", [1, 2, 5]) def test_diag(n: int) -> None: diff --git a/tests/test_linops/test_dense_linear_operator.py b/tests/test_linops/test_dense_linear_operator.py index f46f688f6..fae714919 100644 --- a/tests/test_linops/test_dense_linear_operator.py +++ b/tests/test_linops/test_dense_linear_operator.py @@ -20,6 +20,9 @@ import pytest from jax.config import config +import jax.tree_util as jtu +from dataclasses import is_dataclass + # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -34,12 +37,24 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool: """Check if two arrays are approximately equal.""" return jnp.linalg.norm(res - actual) < 1e-6 - @pytest.mark.parametrize("n", [1, 2, 5]) def test_init(n: int) -> None: values = jr.uniform(_PRNGKey, (n, n)) dense = DenseLinearOperator(values) + + # Check types. + assert isinstance(dense, DenseLinearOperator) + assert is_dataclass(dense) + + # Check properties. assert dense.shape == (n, n) + assert dense.dtype == jnp.float64 + assert dense.ndim == 2 + + # Check pytree. + assert jtu.tree_leaves(dense) == [values] # shape, dtype are static! + + @pytest.mark.parametrize("n", [1, 2, 5]) diff --git a/tests/test_linops/test_diagonal_linear_operator.py b/tests/test_linops/test_diagonal_linear_operator.py index 5187af33a..dbbca4344 100644 --- a/tests/test_linops/test_diagonal_linear_operator.py +++ b/tests/test_linops/test_diagonal_linear_operator.py @@ -18,6 +18,9 @@ import jax.random as jr import pytest import jax +import jax.tree_util as jtu +from dataclasses import is_dataclass + from jax.config import config @@ -38,7 +41,18 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool: def test_init(n: int) -> None: values = jr.uniform(_PRNGKey, (n,)) diag = DiagonalLinearOperator(values) + + # Check types. + assert isinstance(diag, DiagonalLinearOperator) + assert is_dataclass(diag) + + # Check properties. assert diag.shape == (n, n) + assert diag.dtype == jnp.float64 + assert diag.ndim == 2 + + # Check pytree. + assert jtu.tree_leaves(diag) == [values] # shape, dtype are static! @pytest.mark.parametrize("n", [1, 2, 5]) diff --git a/tests/test_linops/test_identity_linear_operator.py b/tests/test_linops/test_identity_linear_operator.py index 5e10daaec..c97c5cf2e 100644 --- a/tests/test_linops/test_identity_linear_operator.py +++ b/tests/test_linops/test_identity_linear_operator.py @@ -19,6 +19,8 @@ import jax import pytest from jax.config import config +import jax.tree_util as jtu +from dataclasses import is_dataclass # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -38,9 +40,20 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool: @pytest.mark.parametrize("n", [1, 2, 5]) def test_init(n: int) -> None: id = IdentityLinearOperator(size=n) + + # Check types. + assert isinstance(id, ConstantDiagonalLinearOperator) + assert is_dataclass(id) + + # Check properties. assert id.shape == (n, n) + assert id.dtype == jnp.float64 + assert id.ndim == 2 assert id.size == n + # Check pytree. + assert jtu.tree_leaves(id) == [1.0] # shape, dtype are static! + @pytest.mark.parametrize("n", [1, 2, 5]) def test_diag(n: int) -> None: diff --git a/tests/test_linops/test_linear_operator.py b/tests/test_linops/test_linear_operator.py index 0729e8ecb..e2ae16391 100644 --- a/tests/test_linops/test_linear_operator.py +++ b/tests/test_linops/test_linear_operator.py @@ -17,20 +17,27 @@ import jax.tree_util as jtu import jax.numpy as jnp from gpjax.linops.linear_operator import LinearOperator -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from simple_pytree import static_field +def test_abstract_operator() -> None: -def test_covariance_operator() -> None: + # Test abstract linear operator raises an error. with pytest.raises(TypeError): LinearOperator() -@pytest.mark.parametrize("is_dataclass", [True, False]) + # Test dataclass wrapped abstract linear operator raise an error. + with pytest.raises(TypeError): + dataclass(LinearOperator)() + + + +@pytest.mark.parametrize("test_dataclass", [True, False]) @pytest.mark.parametrize("shape", [(1, 1), (2, 3), (4, 5, 6), [7, 8]]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64]) -def test_instantiate_no_attributes(is_dataclass, shape, dtype) -> None: - """Test if the abstract operator can be instantiated, given the abstract methods.""" +def test_instantiate_no_attributes(test_dataclass, shape, dtype) -> None: + # Test can instantiate a linear operator with the abstract methods defined. class DummyLinearOperator(LinearOperator): def diagonal(self, *args, **kwargs): pass @@ -57,25 +64,33 @@ def to_dense(self, *args, **kwargs): def from_dense(cls, *args, **kwargs): pass - if is_dataclass: - dataclass(DummyLinearOperator) + # Ensure we check dataclass case. + if test_dataclass: + DummyLinearOperator = dataclass(DummyLinearOperator) + # Initialise linear operator. linop = DummyLinearOperator(shape=shape, dtype=dtype) + + # Check types. assert isinstance(linop, DummyLinearOperator) assert isinstance(linop, LinearOperator) + + if test_dataclass: + assert is_dataclass(linop) + + # Check properties. assert linop.shape == shape assert linop.dtype == dtype assert linop.ndim == len(shape) - assert jtu.tree_leaves(linop) == [] # shape and dtype are static! - # if not is_dataclass: - # assert linop.__repr__() == f"DummyLinearOperator(shape={shape}, dtype={dtype})" + # Check pytree. + assert jtu.tree_leaves(linop) == [] # shape and dtype are static! -@pytest.mark.parametrize("is_dataclass", [True, False]) +@pytest.mark.parametrize("test_dataclass", [True, False]) @pytest.mark.parametrize("shape", [(1, 1), (2, 3), (4, 5, 6), [7, 8]]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64]) -def test_instantiate_with_attributes(is_dataclass, shape, dtype) -> None: +def test_instantiate_with_attributes(test_dataclass, shape, dtype) -> None: """Test if the covariance operator can be instantiated with attribute annotations.""" class DummyLinearOperator(LinearOperator): @@ -115,16 +130,24 @@ def to_dense(self, *args, **kwargs): def from_dense(cls, *args, **kwargs): pass - if is_dataclass: - dataclass(DummyLinearOperator) + # Ensure we check dataclass case. + if test_dataclass: + DummyLinearOperator = dataclass(DummyLinearOperator) + # Initialise linear operator. linop = DummyLinearOperator(shape=shape, dtype=dtype) + + # Check types. assert isinstance(linop, DummyLinearOperator) assert isinstance(linop, LinearOperator) + + if test_dataclass: + assert is_dataclass(linop) + + # Check properties. assert linop.shape == shape assert linop.dtype == dtype assert linop.ndim == len(shape) - assert jtu.tree_leaves(linop) == [1, 3] # b, shape, dtype are static! - # if not is_dataclass: - # assert linop.__repr__() == f"DummyLinearOperator(shape={shape}, dtype={dtype})" \ No newline at end of file + # Check pytree. + assert jtu.tree_leaves(linop) == [1, 3] # b, shape, dtype are static! \ No newline at end of file diff --git a/tests/test_linops/test_zero_linear_operator.py b/tests/test_linops/test_zero_linear_operator.py index 149f9a81d..5238f4c48 100644 --- a/tests/test_linops/test_zero_linear_operator.py +++ b/tests/test_linops/test_zero_linear_operator.py @@ -19,7 +19,8 @@ import jax import pytest from jax.config import config - +import jax.tree_util as jtu +from dataclasses import is_dataclass # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -38,7 +39,18 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool: @pytest.mark.parametrize("n", [1, 2, 5]) def test_init(n: int) -> None: zero = ZeroLinearOperator(shape=(n, n)) + + # Check types. + assert isinstance(zero, ZeroLinearOperator) + assert is_dataclass(zero) + + # Check properties. assert zero.shape == (n, n) + assert zero.dtype == jnp.float64 + assert zero.ndim == 2 + + # Check pytree. + assert jtu.tree_leaves(zero) == [] # shape, dtype are static! @pytest.mark.parametrize("n", [1, 2, 5])