Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions src/causalprog/graph/node/parameter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
"""Graph nodes representing parameters."""

from __future__ import annotations

import typing

import numpy as np
import jax
import jax.numpy as jnp
import numpy.typing as npt
from typing_extensions import override

from .base import Node

if typing.TYPE_CHECKING:
import jax
import numpy.typing as npt


class ParameterNode(Node):
"""
Expand Down Expand Up @@ -44,15 +38,15 @@ def __init__(self, *, label: str) -> None:
def sample(
self,
parameter_values: dict[str, float],
sampled_dependencies: dict[str, npt.NDArray[float]],
sampled_dependencies: dict[str, npt.ArrayLike],
samples: int,
*,
rng_key: jax.Array,
) -> npt.NDArray[float]:
) -> npt.ArrayLike:
if self.label not in parameter_values:
msg = f"Missing input for parameter node: {self.label}."
raise ValueError(msg)
return np.full(samples, parameter_values[self.label])
return jnp.full(samples, parameter_values[self.label])

@override
def copy(self) -> Node:
Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures/numpyro/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Callable
from typing import Concatenate, TypeAlias

import numpy as np
import jax.numpy as jnp
import pytest
from jax import Array
from numpyro.infer import MCMC, NUTS
Expand Down Expand Up @@ -77,7 +77,7 @@ def _inner(left_mcmc: MCMC, right_mcmc: MCMC) -> None:
f"Samples on left ({sample_name}) not present on right"
)
# Confirm samples match.
assert np.allclose(sample_values, samples_r[sample_name]), (
assert jnp.allclose(sample_values, samples_r[sample_name]), (
f"Samples '{sample_name}' do not match"
)
for sample_name in samples_r:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_algorithms/test_moments.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for moment algorithms."""

import numpy as np
import jax.numpy as jnp
import pytest

from causalprog import algorithms
Expand All @@ -26,14 +26,14 @@ def test_expectation_stdev_single_normal_node(
graph = normal_graph(mean, stdev)

# Check within hand-computation
assert np.isclose(
assert jnp.isclose(
algorithms.expectation(
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
),
mean,
rtol=rtol,
)
assert np.isclose(
assert jnp.isclose(
algorithms.standard_deviation(
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
),
Expand Down Expand Up @@ -79,18 +79,18 @@ def test_mean_stdev_two_node_graph(

graph = two_normal_graph(mean=mean, cov=stdev, cov2=stdev2)

assert np.isclose(
assert jnp.isclose(
algorithms.expectation(
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
),
mean,
rtol=rtol,
)
assert np.isclose(
assert jnp.isclose(
algorithms.standard_deviation(
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
),
np.sqrt(stdev**2 + stdev2**2),
jnp.sqrt(stdev**2 + stdev2**2),
rtol=rtol,
)

Expand All @@ -108,7 +108,7 @@ def test_expectation(two_normal_graph, rng_key, samples, rtol):
pytest.xfail("Test currently too slow")
graph = two_normal_graph(1.0, 1.2, 0.8)

assert np.isclose(
assert jnp.isclose(
algorithms.expectation(
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
),
Expand All @@ -132,7 +132,7 @@ def test_stdev(two_normal_graph, rng_key, samples, rtol):
pytest.xfail("Test currently too slow")
graph = two_normal_graph(1.0, 1.2, 0.8)

assert np.isclose(
assert jnp.isclose(
algorithms.standard_deviation(
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
),
Expand Down
15 changes: 4 additions & 11 deletions tests/test_graph/test_parameters.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
"""Tests for graph module."""

from typing import Literal, TypeAlias
import jax.numpy as jnp

import numpy as np

from causalprog.graph import DistributionNode, ParameterNode

NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "X"]
NormalGraphNodes: TypeAlias = dict[
NormalGraphNodeNames, DistributionNode | ParameterNode
]
from causalprog.graph import ParameterNode


def test_parameter_node(rng_key, raises_context):
Expand All @@ -18,6 +11,6 @@ def test_parameter_node(rng_key, raises_context):
with raises_context(ValueError("Missing input for parameter")):
node.sample({}, {}, 1, rng_key=rng_key)

assert np.allclose(
node.sample({node.label: 0.3}, {}, 10, rng_key=rng_key)[0], [0.3] * 10
assert jnp.allclose(
node.sample({node.label: 0.3}, {}, 10, rng_key=rng_key), jnp.full((10,), 0.3)
)
10 changes: 5 additions & 5 deletions tests/test_utils/test_norms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable

import numpy as np
import jax.numpy as jnp
import pytest

from causalprog.utils.norms import PyTree, l2_normsq
Expand All @@ -11,15 +11,15 @@
[
pytest.param(1.0, l2_normsq, 1.0, id="l2^2, scalar"),
pytest.param(
np.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array"
jnp.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array"
),
pytest.param(
{"a": 1.0, "b": (np.arange(3), [2.0, (-1.0, 0.0)])},
{"a": 1.0, "b": (jnp.arange(3), [2.0, (-1.0, 0.0)])},
l2_normsq,
1.0 + (np.arange(3) ** 2).sum() + 4.0 + 1.0,
1.0 + (jnp.arange(3) ** 2).sum() + 4.0 + 1.0,
id="l2^2, PyTree",
),
],
)
def test_norm_value(pt: PyTree, norm: Callable[[PyTree], float], expected_value: float):
assert np.allclose(norm(pt), expected_value)
assert jnp.allclose(norm(pt), expected_value)