diff --git a/src/causalprog/graph/node/parameter.py b/src/causalprog/graph/node/parameter.py index 469ecd7..56db0cb 100644 --- a/src/causalprog/graph/node/parameter.py +++ b/src/causalprog/graph/node/parameter.py @@ -1,18 +1,12 @@ """Graph nodes representing parameters.""" -from __future__ import annotations - -import typing - -import numpy as np +import jax +import jax.numpy as jnp +import numpy.typing as npt from typing_extensions import override from .base import Node -if typing.TYPE_CHECKING: - import jax - import numpy.typing as npt - class ParameterNode(Node): """ @@ -44,15 +38,15 @@ def __init__(self, *, label: str) -> None: def sample( self, parameter_values: dict[str, float], - sampled_dependencies: dict[str, npt.NDArray[float]], + sampled_dependencies: dict[str, npt.ArrayLike], samples: int, *, rng_key: jax.Array, - ) -> npt.NDArray[float]: + ) -> npt.ArrayLike: if self.label not in parameter_values: msg = f"Missing input for parameter node: {self.label}." raise ValueError(msg) - return np.full(samples, parameter_values[self.label]) + return jnp.full(samples, parameter_values[self.label]) @override def copy(self) -> Node: diff --git a/tests/fixtures/numpyro/mcmc.py b/tests/fixtures/numpyro/mcmc.py index ce00be1..aa694b7 100644 --- a/tests/fixtures/numpyro/mcmc.py +++ b/tests/fixtures/numpyro/mcmc.py @@ -3,7 +3,7 @@ from collections.abc import Callable from typing import Concatenate, TypeAlias -import numpy as np +import jax.numpy as jnp import pytest from jax import Array from numpyro.infer import MCMC, NUTS @@ -77,7 +77,7 @@ def _inner(left_mcmc: MCMC, right_mcmc: MCMC) -> None: f"Samples on left ({sample_name}) not present on right" ) # Confirm samples match. - assert np.allclose(sample_values, samples_r[sample_name]), ( + assert jnp.allclose(sample_values, samples_r[sample_name]), ( f"Samples '{sample_name}' do not match" ) for sample_name in samples_r: diff --git a/tests/test_algorithms/test_moments.py b/tests/test_algorithms/test_moments.py index 8deb39b..cfac010 100644 --- a/tests/test_algorithms/test_moments.py +++ b/tests/test_algorithms/test_moments.py @@ -1,6 +1,6 @@ """Tests for moment algorithms.""" -import numpy as np +import jax.numpy as jnp import pytest from causalprog import algorithms @@ -26,14 +26,14 @@ def test_expectation_stdev_single_normal_node( graph = normal_graph(mean, stdev) # Check within hand-computation - assert np.isclose( + assert jnp.isclose( algorithms.expectation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), mean, rtol=rtol, ) - assert np.isclose( + assert jnp.isclose( algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), @@ -79,18 +79,18 @@ def test_mean_stdev_two_node_graph( graph = two_normal_graph(mean=mean, cov=stdev, cov2=stdev2) - assert np.isclose( + assert jnp.isclose( algorithms.expectation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), mean, rtol=rtol, ) - assert np.isclose( + assert jnp.isclose( algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), - np.sqrt(stdev**2 + stdev2**2), + jnp.sqrt(stdev**2 + stdev2**2), rtol=rtol, ) @@ -108,7 +108,7 @@ def test_expectation(two_normal_graph, rng_key, samples, rtol): pytest.xfail("Test currently too slow") graph = two_normal_graph(1.0, 1.2, 0.8) - assert np.isclose( + assert jnp.isclose( algorithms.expectation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), @@ -132,7 +132,7 @@ def test_stdev(two_normal_graph, rng_key, samples, rtol): pytest.xfail("Test currently too slow") graph = two_normal_graph(1.0, 1.2, 0.8) - assert np.isclose( + assert jnp.isclose( algorithms.standard_deviation( graph, outcome_node_label="X", samples=samples, rng_key=rng_key ), diff --git a/tests/test_graph/test_parameters.py b/tests/test_graph/test_parameters.py index a6ad7db..589ddbe 100644 --- a/tests/test_graph/test_parameters.py +++ b/tests/test_graph/test_parameters.py @@ -1,15 +1,8 @@ """Tests for graph module.""" -from typing import Literal, TypeAlias +import jax.numpy as jnp -import numpy as np - -from causalprog.graph import DistributionNode, ParameterNode - -NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "X"] -NormalGraphNodes: TypeAlias = dict[ - NormalGraphNodeNames, DistributionNode | ParameterNode -] +from causalprog.graph import ParameterNode def test_parameter_node(rng_key, raises_context): @@ -18,6 +11,6 @@ def test_parameter_node(rng_key, raises_context): with raises_context(ValueError("Missing input for parameter")): node.sample({}, {}, 1, rng_key=rng_key) - assert np.allclose( - node.sample({node.label: 0.3}, {}, 10, rng_key=rng_key)[0], [0.3] * 10 + assert jnp.allclose( + node.sample({node.label: 0.3}, {}, 10, rng_key=rng_key), jnp.full((10,), 0.3) ) diff --git a/tests/test_utils/test_norms.py b/tests/test_utils/test_norms.py index fa0abe6..f8673df 100644 --- a/tests/test_utils/test_norms.py +++ b/tests/test_utils/test_norms.py @@ -1,6 +1,6 @@ from collections.abc import Callable -import numpy as np +import jax.numpy as jnp import pytest from causalprog.utils.norms import PyTree, l2_normsq @@ -11,15 +11,15 @@ [ pytest.param(1.0, l2_normsq, 1.0, id="l2^2, scalar"), pytest.param( - np.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array" + jnp.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array" ), pytest.param( - {"a": 1.0, "b": (np.arange(3), [2.0, (-1.0, 0.0)])}, + {"a": 1.0, "b": (jnp.arange(3), [2.0, (-1.0, 0.0)])}, l2_normsq, - 1.0 + (np.arange(3) ** 2).sum() + 4.0 + 1.0, + 1.0 + (jnp.arange(3) ** 2).sum() + 4.0 + 1.0, id="l2^2, PyTree", ), ], ) def test_norm_value(pt: PyTree, norm: Callable[[PyTree], float], expected_value: float): - assert np.allclose(norm(pt), expected_value) + assert jnp.allclose(norm(pt), expected_value)