diff --git a/tests/distributions/test_abstract_hyperspherical_distribution.py b/tests/distributions/test_abstract_hyperspherical_distribution.py index bd242b770..7e9b9766b 100644 --- a/tests/distributions/test_abstract_hyperspherical_distribution.py +++ b/tests/distributions/test_abstract_hyperspherical_distribution.py @@ -7,6 +7,7 @@ from pyrecest.backend import array, linalg, log, pi, sqrt from pyrecest.distributions import ( AbstractHypersphericalDistribution, + HypersphericalUniformDistribution, VonMisesFisherDistribution, ) @@ -68,6 +69,16 @@ def test_mean_direction_numerical(self): vmf = VonMisesFisherDistribution(mu, kappa) self.assertLess(linalg.norm(vmf.mean_direction_numerical() - mu), 1e-6) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on this backend", + ) + def test_mean_direction_numerical_undefined_for_uniform_circle(self): + """Tests that undefined mean directions are reported explicitly.""" + uniform_circle = HypersphericalUniformDistribution(1) + with self.assertRaisesRegex(ValueError, "Mean direction is undefined"): + uniform_circle.mean_direction_numerical() + def test_plotting_error_free_1d(self): """Tests the plotting function for circular distributions."""