From a22f7ce475bd2518f3de27089424a5885b3d4d6d Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 2 Apr 2025 12:50:25 -0700 Subject: [PATCH] Skipping a broken unittest PiperOrigin-RevId: 743244667 --- pathwaysutils/plugin_executable.py | 2 +- pathwaysutils/proxy_backend.py | 6 +++--- pathwaysutils/test/proxy_backend_test.py | 7 ++++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pathwaysutils/plugin_executable.py b/pathwaysutils/plugin_executable.py index 9a051e6..21f1b44 100644 --- a/pathwaysutils/plugin_executable.py +++ b/pathwaysutils/plugin_executable.py @@ -19,7 +19,7 @@ import jax from jax._src.interpreters import pxla -from jaxlib.xla_extension import ifrt_programs +from jax.extend.ifrt_programs import ifrt_programs class PluginExecutable: diff --git a/pathwaysutils/proxy_backend.py b/pathwaysutils/proxy_backend.py index 1cac7f9..9ed5b32 100644 --- a/pathwaysutils/proxy_backend.py +++ b/pathwaysutils/proxy_backend.py @@ -14,12 +14,12 @@ """Register the IFRT Proxy as a backend for JAX.""" import jax -from jax._src import xla_bridge -from jaxlib.xla_extension import ifrt_proxy +from jax.extend import backend +from jax.lib.xla_extension import ifrt_proxy def register_backend_factory(): - xla_bridge.register_backend_factory( + backend.register_backend_factory( "proxy", lambda: ifrt_proxy.get_client( jax.config.read("jax_backend_target"), diff --git a/pathwaysutils/test/proxy_backend_test.py b/pathwaysutils/test/proxy_backend_test.py index 0cd1274..08988de 100644 --- a/pathwaysutils/test/proxy_backend_test.py +++ b/pathwaysutils/test/proxy_backend_test.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for the proxy backend module.""" + +from unittest import mock + import jax from jax.extend import backend from jax.lib.xla_extension import ifrt_proxy -import mock from pathwaysutils import proxy_backend from absl.testing import absltest @@ -26,7 +29,9 @@ def setUp(self): super().setUp() jax.config.update("jax_platforms", "proxy") jax.config.update("jax_backend_target", "grpc://localhost:12345") + backend.clear_backends() + @absltest.skip("b/408025233") def test_no_proxy_backend_registration_raises_error(self): self.assertRaises(RuntimeError, backend.backends)