Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimal pytree checks for linops. (Better testing needed in future). #222

Merged
merged 2 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/test_linops/test_constant_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import jax.numpy as jnp
import jax.random as jr
import pytest
import jax.tree_util as jtu
from dataclasses import is_dataclass

from jax.config import config

# Enable Float64 for more stable matrix inversions.
Expand All @@ -38,9 +41,20 @@ def approx_equal(res: jnp.ndarray, actual: jnp.ndarray) -> bool:
def test_init(n: int) -> None:
value = jr.uniform(_PRNGKey, (1,))
constant_diag = ConstantDiagonalLinearOperator(value=value, size=n)

# Check types.
assert isinstance(constant_diag, ConstantDiagonalLinearOperator)
assert is_dataclass(constant_diag)

# Check properties.
assert constant_diag.shape == (n, n)
assert constant_diag.dtype == jnp.float64
assert constant_diag.ndim == 2
assert constant_diag.size == n

# Check pytree.
assert jtu.tree_leaves(constant_diag) == [value] # shape, dtype are static!


@pytest.mark.parametrize("n", [1, 2, 5])
def test_diag(n: int) -> None:
Expand Down
19 changes: 18 additions & 1 deletion tests/test_linops/test_dense_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
import pytest
from jax.config import config

import jax.tree_util as jtu
from dataclasses import is_dataclass



# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
_PRNGKey = jr.PRNGKey(42)
Expand All @@ -34,12 +39,24 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool:
"""Check if two arrays are approximately equal."""
return jnp.linalg.norm(res - actual) < 1e-6


@pytest.mark.parametrize("n", [1, 2, 5])
def test_init(n: int) -> None:
values = jr.uniform(_PRNGKey, (n, n))
dense = DenseLinearOperator(values)

# Check types.
assert isinstance(dense, DenseLinearOperator)
assert is_dataclass(dense)

# Check properties.
assert dense.shape == (n, n)
assert dense.dtype == jnp.float64
assert dense.ndim == 2

# Check pytree.
assert jtu.tree_leaves(dense) == [values] # shape, dtype are static!




@pytest.mark.parametrize("n", [1, 2, 5])
Expand Down
17 changes: 17 additions & 0 deletions tests/test_linops/test_diagonal_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
import jax.numpy as jnp
import jax.random as jr
import pytest

import jax
import jax.tree_util as jtu
from dataclasses import is_dataclass


from jax.config import config

# Enable Float64 for more stable matrix inversions.
Expand All @@ -37,7 +43,18 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool:
def test_init(n: int) -> None:
values = jr.uniform(_PRNGKey, (n,))
diag = DiagonalLinearOperator(values)

# Check types.
assert isinstance(diag, DiagonalLinearOperator)
assert is_dataclass(diag)

# Check properties.
assert diag.shape == (n, n)
assert diag.dtype == jnp.float64
assert diag.ndim == 2

# Check pytree.
assert jtu.tree_leaves(diag) == [values] # shape, dtype are static!


@pytest.mark.parametrize("n", [1, 2, 5])
Expand Down
13 changes: 13 additions & 0 deletions tests/test_linops/test_identity_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import jax.random as jr
import pytest
from jax.config import config
import jax.tree_util as jtu
from dataclasses import is_dataclass

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
Expand All @@ -39,9 +41,20 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool:
@pytest.mark.parametrize("n", [1, 2, 5])
def test_init(n: int) -> None:
id = IdentityLinearOperator(size=n)

# Check types.
assert isinstance(id, ConstantDiagonalLinearOperator)
assert is_dataclass(id)

# Check properties.
assert id.shape == (n, n)
assert id.dtype == jnp.float64
assert id.ndim == 2
assert id.size == n

# Check pytree.
assert jtu.tree_leaves(id) == [1.0] # shape, dtype are static!


@pytest.mark.parametrize("n", [1, 2, 5])
def test_diag(n: int) -> None:
Expand Down
63 changes: 42 additions & 21 deletions tests/test_linops/test_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,31 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass

import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
from simple_pytree import static_field

import jax.tree_util as jtu
import jax.numpy as jnp
from gpjax.linops.linear_operator import LinearOperator
from dataclasses import dataclass, is_dataclass
from simple_pytree import static_field

def test_abstract_operator() -> None:

def test_covariance_operator() -> None:
# Test abstract linear operator raises an error.
with pytest.raises(TypeError):
LinearOperator()

# Test dataclass wrapped abstract linear operator raise an error.
with pytest.raises(TypeError):
dataclass(LinearOperator)()



@pytest.mark.parametrize("is_dataclass", [True, False])
@pytest.mark.parametrize("test_dataclass", [True, False])
@pytest.mark.parametrize("shape", [(1, 1), (2, 3), (4, 5, 6), [7, 8]])
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64])
def test_instantiate_no_attributes(is_dataclass, shape, dtype) -> None:
"""Test if the abstract operator can be instantiated, given the abstract methods."""
def test_instantiate_no_attributes(test_dataclass, shape, dtype) -> None:

# Test can instantiate a linear operator with the abstract methods defined.
class DummyLinearOperator(LinearOperator):
def diagonal(self, *args, **kwargs):
pass
Expand All @@ -60,24 +64,33 @@ def to_dense(self, *args, **kwargs):
def from_dense(cls, *args, **kwargs):
pass

if is_dataclass:
dataclass(DummyLinearOperator)
# Ensure we check dataclass case.
if test_dataclass:
DummyLinearOperator = dataclass(DummyLinearOperator)

# Initialise linear operator.
linop = DummyLinearOperator(shape=shape, dtype=dtype)

# Check types.
assert isinstance(linop, DummyLinearOperator)
assert isinstance(linop, LinearOperator)

if test_dataclass:
assert is_dataclass(linop)

# Check properties.
assert linop.shape == shape
assert linop.dtype == dtype
assert linop.ndim == len(shape)
assert jtu.tree_leaves(linop) == [] # shape and dtype are static!

# if not is_dataclass:
# Check pytree.
assert jtu.tree_leaves(linop) == [] # shape and dtype are static!


@pytest.mark.parametrize("is_dataclass", [True, False])
@pytest.mark.parametrize("test_dataclass", [True, False])
@pytest.mark.parametrize("shape", [(1, 1), (2, 3), (4, 5, 6), [7, 8]])
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.float64])
def test_instantiate_with_attributes(is_dataclass, shape, dtype) -> None:
def test_instantiate_with_attributes(test_dataclass, shape, dtype) -> None:
"""Test if the covariance operator can be instantiated with attribute annotations."""

class DummyLinearOperator(LinearOperator):
Expand Down Expand Up @@ -117,16 +130,24 @@ def to_dense(self, *args, **kwargs):
def from_dense(cls, *args, **kwargs):
pass

if is_dataclass:
dataclass(DummyLinearOperator)
# Ensure we check dataclass case.
if test_dataclass:
DummyLinearOperator = dataclass(DummyLinearOperator)

# Initialise linear operator.
linop = DummyLinearOperator(shape=shape, dtype=dtype)

# Check types.
assert isinstance(linop, DummyLinearOperator)
assert isinstance(linop, LinearOperator)

if test_dataclass:
assert is_dataclass(linop)

# Check properties.
assert linop.shape == shape
assert linop.dtype == dtype
assert linop.ndim == len(shape)
assert jtu.tree_leaves(linop) == [1, 3] # b, shape, dtype are static!

# if not is_dataclass:
# assert linop.__repr__() == f"DummyLinearOperator(shape={shape}, dtype={dtype})"
# Check pytree.
assert jtu.tree_leaves(linop) == [1, 3] # b, shape, dtype are static!
13 changes: 13 additions & 0 deletions tests/test_linops/test_zero_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import jax.random as jr
import pytest
from jax.config import config
import jax.tree_util as jtu
from dataclasses import is_dataclass

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
Expand All @@ -37,7 +39,18 @@ def approx_equal(res: jax.Array, actual: jax.Array) -> bool:
@pytest.mark.parametrize("n", [1, 2, 5])
def test_init(n: int) -> None:
zero = ZeroLinearOperator(shape=(n, n))

# Check types.
assert isinstance(zero, ZeroLinearOperator)
assert is_dataclass(zero)

# Check properties.
assert zero.shape == (n, n)
assert zero.dtype == jnp.float64
assert zero.ndim == 2

# Check pytree.
assert jtu.tree_leaves(zero) == [] # shape, dtype are static!


@pytest.mark.parametrize("n", [1, 2, 5])
Expand Down