diff --git a/pathwaysutils/plugin_executable.py b/pathwaysutils/plugin_executable.py index 9a051e6..21f1b44 100644 --- a/pathwaysutils/plugin_executable.py +++ b/pathwaysutils/plugin_executable.py @@ -19,7 +19,7 @@ import jax from jax._src.interpreters import pxla -from jaxlib.xla_extension import ifrt_programs +from jax.extend.ifrt_programs import ifrt_programs class PluginExecutable: diff --git a/pathwaysutils/proxy_backend.py b/pathwaysutils/proxy_backend.py index 1cac7f9..9ed5b32 100644 --- a/pathwaysutils/proxy_backend.py +++ b/pathwaysutils/proxy_backend.py @@ -14,12 +14,12 @@ """Register the IFRT Proxy as a backend for JAX.""" import jax -from jax._src import xla_bridge -from jaxlib.xla_extension import ifrt_proxy +from jax.extend import backend +from jax.lib.xla_extension import ifrt_proxy def register_backend_factory(): - xla_bridge.register_backend_factory( + backend.register_backend_factory( "proxy", lambda: ifrt_proxy.get_client( jax.config.read("jax_backend_target"), diff --git a/pathwaysutils/test/proxy_backend_test.py b/pathwaysutils/test/proxy_backend_test.py index 0cd1274..08988de 100644 --- a/pathwaysutils/test/proxy_backend_test.py +++ b/pathwaysutils/test/proxy_backend_test.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for the proxy backend module.""" + +from unittest import mock + import jax from jax.extend import backend from jax.lib.xla_extension import ifrt_proxy -import mock from pathwaysutils import proxy_backend from absl.testing import absltest @@ -26,7 +29,9 @@ def setUp(self): super().setUp() jax.config.update("jax_platforms", "proxy") jax.config.update("jax_backend_target", "grpc://localhost:12345") + backend.clear_backends() + @absltest.skip("b/408025233") def test_no_proxy_backend_registration_raises_error(self): self.assertRaises(RuntimeError, backend.backends)