From a03d862071bfc9e44811da98cec3f70bd680ce47 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 12 Aug 2025 11:29:37 +0100 Subject: [PATCH 1/4] Remove distrax from backend tests, and adapt test to still be meaningful --- ...ackends.py => test_backend_as_expected.py} | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) rename tests/test_distributions/{test_different_backends.py => test_backend_as_expected.py} (50%) diff --git a/tests/test_distributions/test_different_backends.py b/tests/test_distributions/test_backend_as_expected.py similarity index 50% rename from tests/test_distributions/test_different_backends.py rename to tests/test_distributions/test_backend_as_expected.py index e93c054..55d3e8f 100644 --- a/tests/test_distributions/test_different_backends.py +++ b/tests/test_distributions/test_backend_as_expected.py @@ -1,32 +1,31 @@ """Integration tests checking that the ``Distribution`` class is backend-agnostic.""" -import distrax import jax.numpy as jnp from numpyro.distributions.continuous import MultivariateNormal from causalprog.distribution.base import Distribution, SampleTranslator -def test_different_backends(rng_key) -> None: +def test_backend_matches_explicit(rng_key) -> None: """ - Test that ``Distribution`` can use different (but equivalent) backends. + Test that a ``Distribution`` operates identically to the backend it is supposed + to support.. In this integration test, we setup the same multivariate normal distribution - using both ``NumPyro`` and ``distrax`` as backends. We then use the - ``Distribution`` wrapper class to draw samples from each distribution using the - frontend ``sample`` method, and check the results are identical. + using both ``NumPyro``. We then use the ``Distribution`` wrapper class to draw + samples using the frontend ``sample`` method, and check the results are identical + to what we get from just directly sampling from the ``NumPyro`` object. """ n_dims = 2 mean = jnp.array([0.0] * n_dims) cov = jnp.diag(jnp.array([1.0] * n_dims)) sample_size = (10, 5) - distrax_normal = distrax.MultivariateNormalFullCovariance(mean, cov) - distrax_dist = Distribution(distrax_normal, SampleTranslator(rng_key="seed")) - distrax_samples = distrax_dist.sample(rng_key, sample_size) - npyo_normal = MultivariateNormal(mean, cov) npyo_dist = Distribution(npyo_normal, SampleTranslator(rng_key="key")) npyo_samples = npyo_dist.sample(rng_key, sample_size) - assert jnp.allclose(distrax_samples, npyo_samples) + npyo_samples_explicitly_drawn = MultivariateNormal(mean, cov).sample( + rng_key, sample_size + ) + assert jnp.allclose(npyo_samples_explicitly_drawn, npyo_samples) From e7bae30212421a69f126f65c800b3456f390a641 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 12 Aug 2025 11:30:06 +0100 Subject: [PATCH 2/4] Delete test_family.py, which tests the DistributionFamily class. This class will be removed from the codebase anyway, so there is no point preserving the tests for it. However, we should hold off merging this branch until DistributionFamily is well and truly gone. --- tests/test_distributions/test_family.py | 28 ------------------------- 1 file changed, 28 deletions(-) delete mode 100644 tests/test_distributions/test_family.py diff --git a/tests/test_distributions/test_family.py b/tests/test_distributions/test_family.py deleted file mode 100644 index 72a29e7..0000000 --- a/tests/test_distributions/test_family.py +++ /dev/null @@ -1,28 +0,0 @@ -import distrax -import pytest - -from causalprog.distribution.base import SampleTranslator -from causalprog.distribution.family import DistributionFamily - - -@pytest.mark.parametrize( - ("n_dim_std_normal"), - [pytest.param(2, id="2D normal")], - indirect=["n_dim_std_normal"], -) -def test_builder_matches_backend(n_dim_std_normal) -> None: - """ - Test that building from a family is equivalent - to building via the backend explicitly. - - """ - mnv = distrax.MultivariateNormalFullCovariance - - mnv_family = DistributionFamily(mnv, SampleTranslator(rng_key="seed")) - via_family = mnv_family.construct( - loc=n_dim_std_normal["mean"], covariance_matrix=n_dim_std_normal["cov"] - ) - via_backend = mnv(n_dim_std_normal["mean"], n_dim_std_normal["cov"]) - - assert via_backend.kl_divergence(via_family.get_dist()) == pytest.approx(0.0) - assert via_family.get_dist().kl_divergence(via_backend) == pytest.approx(0.0) From 66d935cc39cf45b21b4c33e3aa511152b478f56d Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 12 Aug 2025 11:31:16 +0100 Subject: [PATCH 3/4] distrax no longer required for tests --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b080730..8a761b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ optional-dependencies = {dev = [ "mkdocstrings", "mkdocstrings-python", ], test = [ - "distrax", "numpy", "pytest", "pytest-cov", From 71704b4af19cfc80d8f80801ac789e1783390749 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 12 Aug 2025 11:31:53 +0100 Subject: [PATCH 4/4] Unpin jax version --- pyproject.toml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8a761b4..9596f2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,13 +15,7 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Typing :: Typed", ] -dependencies = [ - "jax<0.7.0", - "networkx", - "numpy", - "numpyro", - "typing_extensions", -] +dependencies = ["jax", "networkx", "numpy", "numpyro", "typing_extensions"] description = "A Python package for causal modelling and inference with stochastic causal programming" dynamic = ["version"] keywords = []