From 2fa06234edc0f74c4303c74bd52fa71100760868 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Thu, 25 Sep 2025 15:08:40 -0700 Subject: [PATCH] Fix the JAX version required for `jaxlib._pathways`. PiperOrigin-RevId: 811518044 --- pathwaysutils/jax/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index eb049b6..6863ccc 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -62,13 +62,13 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable del util try: - # jax>0.7.0 + # jax>=0.7.1 from jax.extend import backend # pylint: disable=g-import-not-at-top ifrt_proxy = backend.ifrt_proxy del backend except AttributeError: - # jax<=0.7.0 + # jax<0.7.1 from jax.lib import xla_extension # pylint: disable=g-import-not-at-top ifrt_proxy = xla_extension.ifrt_proxy @@ -76,15 +76,15 @@ def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable try: - # jax>=0.7.2 + # jax>=0.8.0 from jax.jaxlib import _pathways # pylint: disable=g-import-not-at-top jaxlib_pathways = _pathways del _pathways -except (ModuleNotFoundError, AttributeError): - # jax<0.7.2 +except ModuleNotFoundError: + # jax<0.8.0 - jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.7.2") + jaxlib_pathways = _FakeJaxModule("jax.jaxlib._pathways", "0.8.0") del _FakeJaxModule