Skip to content

Commit

Permalink
Add minimal pytree checks for linops. (Better testing needed in futur…
Browse files Browse the repository at this point in the history
…e). (#222)

Signed-off-by: Daniel Dodd <daniel_dodd@icloud.com>
  • Loading branch information
daniel-dodd committed Apr 11, 2023
1 parent 0ec837c commit 6ec004e
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 22 deletions.
14 changes: 14 additions & 0 deletions tests/test_linops/test_constant_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import jax.numpy as jnp
import jax.random as jr
import pytest
import jax.tree_util as jtu
from dataclasses import is_dataclass

from jax.config import config

# Enable Float64 for more stable matrix inversions.
Expand All @@ -38,9 +41,20 @@ def approx_equal(res: jnp.ndarray, actual: jnp.ndarray) -> bool:
def test_init(n: int) -> None:
value = jr.uniform(_PRNGKey, (1,))
constant_diag = ConstantDiagonalLinearOperator(value=value, size=n)

# Check types.
assert isinstance(constant_diag, ConstantDiagonalLinearOperator)
assert is_dataclass(constant_diag)

# Check properties.
assert constant_diag.shape == (n, n)
assert constant_diag.dtype == jnp.float64
assert constant_diag.ndim == 2
assert constant_diag.size == n

# Check pytree.
assert jtu.tree_leaves(constant_diag) == [value] # shape, dtype are static!


@pytest.mark.parametrize("n", [1, 2, 5])
def test_diag(n: int) -> None:
Expand Down
19 changes: 18 additions & 1 deletion tests/test_linops/test_dense_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
import pytest
from jax.config import config

import jax.tree_util as jtu
from dataclasses import is_dataclass



# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
_PRNGKey = jr.PRNGKey(42)
Expand All @@ -34,12 +39,24 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool:
"""Check if two arrays are approximately equal."""
return jnp.linalg.norm(res - actual) < 1e-6


@pytest.mark.parametrize("n", [1, 2, 5])
def test_init(n: int) -> None:
values = jr.uniform(_PRNGKey, (n, n))
dense = DenseLinearOperator(values)

# Check types.
assert isinstance(dense, DenseLinearOperator)
assert is_dataclass(dense)

# Check properties.
assert dense.shape == (n, n)
assert dense.dtype == jnp.float64
assert dense.ndim == 2

# Check pytree.
assert jtu.tree_leaves(dense) == [values] # shape, dtype are static!




@pytest.mark.parametrize("n", [1, 2, 5])
Expand Down
17 changes: 17 additions & 0 deletions tests/test_linops/test_diagonal_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
import jax.numpy as jnp
import jax.random as jr
import pytest

import jax
import jax.tree_util as jtu
from dataclasses import is_dataclass


from jax.config import config

# Enable Float64 for more stable matrix inversions.
Expand All @@ -37,7 +43,18 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool:
def test_init(n: int) -> None:
values = jr.uniform(_PRNGKey, (n,))
diag = DiagonalLinearOperator(values)

# Check types.
assert isinstance(diag, DiagonalLinearOperator)
assert is_dataclass(diag)

# Check properties.
assert diag.shape == (n, n)
assert diag.dtype == jnp.float64
assert diag.ndim == 2

# Check pytree.
assert jtu.tree_leaves(diag) == [values] # shape, dtype are static!


@pytest.mark.parametrize("n", [1, 2, 5])
Expand Down
13 changes: 13 additions & 0 deletions tests/test_linops/test_identity_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import jax.random as jr
import pytest
from jax.config import config
import jax.tree_util as jtu
from dataclasses import is_dataclass

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
Expand All @@ -39,9 +41,20 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool:
@pytest.mark.parametrize("n", [1, 2, 5])
def test_init(n: int) -> None:
id = IdentityLinearOperator(size=n)

# Check types.
assert isinstance(id, ConstantDiagonalLinearOperator)
assert is_dataclass(id)

# Check properties.
assert id.shape == (n, n)
assert id.dtype == jnp.float64
assert id.ndim == 2
assert id.size == n

# Check pytree.
assert jtu.tree_leaves(id) == [1.0] # shape, dtype are static!


@pytest.mark.parametrize("n", [1, 2, 5])
def test_diag(n: int) -> None:
Expand Down
63 changes: 42 additions & 21 deletions tests/test_linops/test_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,31 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass

import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
from simple_pytree import static_field

import jax.tree_util as jtu
import jax.numpy as jnp
from gpjax.linops.linear_operator import LinearOperator
from dataclasses import dataclass, is_dataclass
from simple_pytree import static_field

def test_abstract_operator() -> None:

def test_covariance_operator() -> None:
# Test abstract linear operator raises an error.
with pytest.raises(TypeError):
LinearOperator()

# Test dataclass wrapped abstract linear operator raise an error.
with pytest.raises(TypeError):
dataclass(LinearOperator)()



@pytest.mark.parametrize("is_dataclass", [True, False])
@pytest.mark.parametrize("test_dataclass", [True, False])
@pytest.mark.parametrize("shape", [(1, 1), (2, 3), (4, 5, 6), [7, 8]])
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64])
def test_instantiate_no_attributes(is_dataclass, shape, dtype) -> None:
"""Test if the abstract operator can be instantiated, given the abstract methods."""
def test_instantiate_no_attributes(test_dataclass, shape, dtype) -> None:

# Test can instantiate a linear operator with the abstract methods defined.
class DummyLinearOperator(LinearOperator):
def diagonal(self, *args, **kwargs):
pass
Expand All @@ -60,24 +64,33 @@ def to_dense(self, *args, **kwargs):
def from_dense(cls, *args, **kwargs):
pass

if is_dataclass:
dataclass(DummyLinearOperator)
# Ensure we check dataclass case.
if test_dataclass:
DummyLinearOperator = dataclass(DummyLinearOperator)

# Initialise linear operator.
linop = DummyLinearOperator(shape=shape, dtype=dtype)

# Check types.
assert isinstance(linop, DummyLinearOperator)
assert isinstance(linop, LinearOperator)

if test_dataclass:
assert is_dataclass(linop)

# Check properties.
assert linop.shape == shape
assert linop.dtype == dtype
assert linop.ndim == len(shape)
assert jtu.tree_leaves(linop) == [] # shape and dtype are static!

# if not is_dataclass:
# Check pytree.
assert jtu.tree_leaves(linop) == [] # shape and dtype are static!


@pytest.mark.parametrize("is_dataclass", [True, False])
@pytest.mark.parametrize("test_dataclass", [True, False])
@pytest.mark.parametrize("shape", [(1, 1), (2, 3), (4, 5, 6), [7, 8]])
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64])
def test_instantiate_with_attributes(is_dataclass, shape, dtype) -> None:
def test_instantiate_with_attributes(test_dataclass, shape, dtype) -> None:
"""Test if the covariance operator can be instantiated with attribute annotations."""

class DummyLinearOperator(LinearOperator):
Expand Down Expand Up @@ -117,16 +130,24 @@ def to_dense(self, *args, **kwargs):
def from_dense(cls, *args, **kwargs):
pass

if is_dataclass:
dataclass(DummyLinearOperator)
# Ensure we check dataclass case.
if test_dataclass:
DummyLinearOperator = dataclass(DummyLinearOperator)

# Initialise linear operator.
linop = DummyLinearOperator(shape=shape, dtype=dtype)

# Check types.
assert isinstance(linop, DummyLinearOperator)
assert isinstance(linop, LinearOperator)

if test_dataclass:
assert is_dataclass(linop)

# Check properties.
assert linop.shape == shape
assert linop.dtype == dtype
assert linop.ndim == len(shape)
assert jtu.tree_leaves(linop) == [1, 3] # b, shape, dtype are static!

# if not is_dataclass:
# assert linop.__repr__() == f"DummyLinearOperator(shape={shape}, dtype={dtype})"
# Check pytree.
assert jtu.tree_leaves(linop) == [1, 3] # b, shape, dtype are static!
13 changes: 13 additions & 0 deletions tests/test_linops/test_zero_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import jax.random as jr
import pytest
from jax.config import config
import jax.tree_util as jtu
from dataclasses import is_dataclass

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
Expand All @@ -37,7 +39,18 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool:
@pytest.mark.parametrize("n", [1, 2, 5])
def test_init(n: int) -> None:
zero = ZeroLinearOperator(shape=(n, n))

# Check types.
assert isinstance(zero, ZeroLinearOperator)
assert is_dataclass(zero)

# Check properties.
assert zero.shape == (n, n)
assert zero.dtype == jnp.float64
assert zero.ndim == 2

# Check pytree.
assert jtu.tree_leaves(zero) == [] # shape, dtype are static!


@pytest.mark.parametrize("n", [1, 2, 5])
Expand Down

0 comments on commit 6ec004e

Please sign in to comment.