diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index eb049b6..6863ccc 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -62,13 +62,13 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable del util try: - # jax>0.7.0 + # jax>=0.7.1 from jax.extend import backend # pylint: disable=g-import-not-at-top ifrt_proxy = backend.ifrt_proxy del backend except AttributeError: - # jax<=0.7.0 + # jax<0.7.1 from jax.lib import xla_extension # pylint: disable=g-import-not-at-top ifrt_proxy = xla_extension.ifrt_proxy @@ -76,15 +76,15 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable try: - # jax>=0.7.2 + # jax>=0.8.0 from jax.jaxlib import _pathways # pylint: disable=g-import-not-at-top jaxlib_pathways = _pathways del _pathways -except (ModuleNotFoundError, AttributeError): - # jax<0.7.2 +except ModuleNotFoundError: + # jax<0.8.0 - jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.7.2") + jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.8.0") del _FakeJaxModule