From b15cafc56a89f8cf6675f4eca3d60a2385b73095 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Tue, 13 Jun 2023 15:26:25 +0200 Subject: [PATCH 01/15] Add JAX multiprocess test Signed-off-by: Albert Wolant --- dali/test/python/jax/jax_client.py | 19 +++++++ dali/test/python/jax/jax_server.py | 66 ++++++++++++++++++++++++ dali/test/python/jax/test_integration.py | 4 +- qa/TL0_multigpu/test_body.sh | 6 +++ 4 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 dali/test/python/jax/jax_client.py create mode 100644 dali/test/python/jax/jax_server.py diff --git a/dali/test/python/jax/jax_client.py b/dali/test/python/jax/jax_client.py new file mode 100644 index 0000000000..6c3b50abed --- /dev/null +++ b/dali/test/python/jax/jax_client.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from jax_server import run_multiprocess_workflow + +if __name__ == "__main__": + run_multiprocess_workflow(process_id=1) diff --git a/dali/test/python/jax/jax_server.py b/dali/test/python/jax/jax_server.py new file mode 100644 index 0000000000..4e7b25bbfd --- /dev/null +++ b/dali/test/python/jax/jax_server.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax + +from test_integration import get_dali_tensor_gpu + +import nvidia.dali.types as types +import nvidia.dali.plugin.jax as dax + + +def print_devices(process_id): + print(f"PID {process_id}: Local devices = {jax.local_device_count()}, " + f"global devices = {jax.device_count()}") + + print(f"PID {process_id}: All devices: ") + print_devices_details(jax.devices(), process_id) + + print(f"PID {process_id}: Local devices:") + print_devices_details(jax.local_devices(), process_id) + + +def print_devices_details(devices_list, process_id): + for device in devices_list: + print(f"PID {process_id}: Id = {device.id}, host_id = {device.host_id}, " + f"process_id = {device.process_index}, kind = {device.device_kind}") + + +def test_lax_workflow(process_id): + array_from_dali = dax._to_jax_array(get_dali_tensor_gpu(1, (1), types.INT32)) + + assert array_from_dali.device() == jax.local_devices()[0], \ + "Array should be backed by the device local to current process." + + sum_across_devices = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(array_from_dali) + + assert sum_across_devices[0] == len(jax.devices()),\ + "Sum across devices should be equal to the number of devices as data per device = [1]" + + print(f"PID {process_id}: Passed lax workflow test") + + +def run_multiprocess_workflow(process_id=0): + jax.distributed.initialize( + coordinator_address="localhost:1234", + num_processes=2, + process_id=process_id) + + print_devices(process_id=process_id) + test_lax_workflow(process_id=process_id) + + +if __name__ == "__main__": + run_multiprocess_workflow(process_id=0) diff --git a/dali/test/python/jax/test_integration.py b/dali/test/python/jax/test_integration.py index da31075fa8..9767abc19c 100644 --- a/dali/test/python/jax/test_integration.py +++ b/dali/test/python/jax/test_integration.py @@ -29,7 +29,7 @@ from nose2.tools import cartesian_params -def get_dali_tensor_gpu(value, shape, dtype) -> TensorGPU: +def get_dali_tensor_gpu(value, shape, dtype, device_id=0) -> TensorGPU: """Helper function to create DALI TensorGPU. Args: @@ -47,7 +47,7 @@ def dali_pipeline(): return values - pipe = dali_pipeline(device_id=0) + pipe = dali_pipeline(device_id=device_id) pipe.build() dali_output = pipe.run() diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index 629162b413..f75410c556 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -46,6 +46,12 @@ test_pytorch() { test_jax() { ${python_new_invoke_test} -s jax test_integration_multigpu + + # Multiprocess tests + CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & + CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py + + } test_no_fw() { From 7d704ff28993b810a2f9d7ad4c3b67eea24e8419 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Tue, 13 Jun 2023 15:31:06 +0200 Subject: [PATCH 02/15] Remove empty lines Signed-off-by: Albert Wolant --- qa/TL0_multigpu/test_body.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index f75410c556..7087e0782c 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -50,8 +50,6 @@ test_jax() { # Multiprocess tests CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py - - } test_no_fw() { From 671caf892bb357159c39dd05db2e8c0339712196 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 14 Jun 2023 10:41:32 +0200 Subject: [PATCH 03/15] Add missing test run command Signed-off-by: Albert Wolant --- qa/TL0_multigpu/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/TL0_multigpu/test.sh b/qa/TL0_multigpu/test.sh index 948a3ddcbc..ed02153ec4 100755 --- a/qa/TL0_multigpu/test.sh +++ b/qa/TL0_multigpu/test.sh @@ -2,3 +2,4 @@ bash -e ./test_nofw.sh bash -e ./test_cupy.sh bash -e ./test_pytorch.sh +bash -e ./test_jax.sh From 0e63bada71bc63e1c6e75585c0edf6080b71df10 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 14 Jun 2023 11:09:11 +0200 Subject: [PATCH 04/15] Add logging with logger Signed-off-by: Albert Wolant --- dali/test/python/jax/jax_server.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/dali/test/python/jax/jax_server.py b/dali/test/python/jax/jax_server.py index 4e7b25bbfd..0bbbc928a5 100644 --- a/dali/test/python/jax/jax_server.py +++ b/dali/test/python/jax/jax_server.py @@ -15,6 +15,8 @@ import jax +import logging as log + from test_integration import get_dali_tensor_gpu import nvidia.dali.types as types @@ -22,20 +24,20 @@ def print_devices(process_id): - print(f"PID {process_id}: Local devices = {jax.local_device_count()}, " - f"global devices = {jax.device_count()}") + log.info(f"Local devices = {jax.local_device_count()}, " + f"global devices = {jax.device_count()}") - print(f"PID {process_id}: All devices: ") + log.info("All devices: ") print_devices_details(jax.devices(), process_id) - print(f"PID {process_id}: Local devices:") + log.info("Local devices:") print_devices_details(jax.local_devices(), process_id) def print_devices_details(devices_list, process_id): for device in devices_list: - print(f"PID {process_id}: Id = {device.id}, host_id = {device.host_id}, " - f"process_id = {device.process_index}, kind = {device.device_kind}") + log.info(f"Id = {device.id}, host_id = {device.host_id}, " + f"process_id = {device.process_index}, kind = {device.device_kind}") def test_lax_workflow(process_id): @@ -49,7 +51,7 @@ def test_lax_workflow(process_id): assert sum_across_devices[0] == len(jax.devices()),\ "Sum across devices should be equal to the number of devices as data per device = [1]" - print(f"PID {process_id}: Passed lax workflow test") + log.info("Passed lax workflow test") def run_multiprocess_workflow(process_id=0): @@ -58,6 +60,10 @@ def run_multiprocess_workflow(process_id=0): num_processes=2, process_id=process_id) + log.basicConfig( + level=log.INFO, + format=f"PID {process_id}: %(message)s") + print_devices(process_id=process_id) test_lax_workflow(process_id=process_id) From 52c90108306c616c269640e4ca858983267f34c7 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 14 Jun 2023 13:39:37 +0200 Subject: [PATCH 05/15] Review fix Signed-off-by: Albert Wolant --- dali/test/python/jax/jax_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dali/test/python/jax/jax_server.py b/dali/test/python/jax/jax_server.py index 0bbbc928a5..6d80ddc036 100644 --- a/dali/test/python/jax/jax_server.py +++ b/dali/test/python/jax/jax_server.py @@ -56,7 +56,7 @@ def test_lax_workflow(process_id): def run_multiprocess_workflow(process_id=0): jax.distributed.initialize( - coordinator_address="localhost:1234", + coordinator_address="localhost:12321", num_processes=2, process_id=process_id) From 2984f4a65bf471717e52105801543eae2fe4e397 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 14 Jun 2023 13:59:56 +0200 Subject: [PATCH 06/15] Fix review Signed-off-by: Albert Wolant --- dali/test/python/jax/test_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dali/test/python/jax/test_integration.py b/dali/test/python/jax/test_integration.py index 9767abc19c..e0458908fe 100644 --- a/dali/test/python/jax/test_integration.py +++ b/dali/test/python/jax/test_integration.py @@ -43,7 +43,7 @@ def get_dali_tensor_gpu(value, shape, dtype, device_id=0) -> TensorGPU: """ @pipeline_def(num_threads=1, batch_size=1) def dali_pipeline(): - values = fn.constant(idata=value, shape=shape, dtype=dtype, device='gpu') + values = types.Constant(value=value, shape=shape, dtype=dtype, device='gpu') return values From 97e58710c7e39a5161d0af7f77d7ebca1e1b1d71 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 14 Jun 2023 14:03:54 +0200 Subject: [PATCH 07/15] Add type to test Signed-off-by: Albert Wolant --- dali/test/python/jax/test_integration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dali/test/python/jax/test_integration.py b/dali/test/python/jax/test_integration.py index e0458908fe..c2bf56264a 100644 --- a/dali/test/python/jax/test_integration.py +++ b/dali/test/python/jax/test_integration.py @@ -69,7 +69,10 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value): # then assert jax.numpy.array_equal( jax_array, - jax.numpy.full(shape, value)) + jax.numpy.full( + shape, + value, + types.to_numpy_type(dtype))) # Make sure JAX array is backed by the GPU assert jax_array.device() == jax.devices()[0] From efe13833e8c8de26e5254086773385c51aa4dfd1 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 14 Jun 2023 14:32:25 +0200 Subject: [PATCH 08/15] Make test better Signed-off-by: Albert Wolant --- dali/test/python/jax/test_integration.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/dali/test/python/jax/test_integration.py b/dali/test/python/jax/test_integration.py index c2bf56264a..694c3640d0 100644 --- a/dali/test/python/jax/test_integration.py +++ b/dali/test/python/jax/test_integration.py @@ -43,7 +43,7 @@ def get_dali_tensor_gpu(value, shape, dtype, device_id=0) -> TensorGPU: """ @pipeline_def(num_threads=1, batch_size=1) def dali_pipeline(): - values = types.Constant(value=value, shape=shape, dtype=dtype, device='gpu') + values = types.Constant(value=np.full(shape, value, dtype), device='gpu') return values @@ -55,7 +55,7 @@ def dali_pipeline(): @cartesian_params( - (types.FLOAT, types.INT32), # dtypes to test + (np.float32, np.int32), # dtypes to test ([], [1], [10], [2, 4], [1, 2, 3]), # shapes to test (1, -99)) # values to test def test_dali_tensor_gpu_to_jax_array(dtype, shape, value): @@ -69,10 +69,7 @@ def test_dali_tensor_gpu_to_jax_array(dtype, shape, value): # then assert jax.numpy.array_equal( jax_array, - jax.numpy.full( - shape, - value, - types.to_numpy_type(dtype))) + jax.numpy.full(shape, value, dtype)) # Make sure JAX array is backed by the GPU assert jax_array.device() == jax.devices()[0] From 3d3a7fe3e59c361dc73096c65b87e6a118e4507d Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 14 Jun 2023 15:48:09 +0200 Subject: [PATCH 09/15] Add NCCL debug info Signed-off-by: Albert Wolant --- qa/TL0_multigpu/test_body.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index 7087e0782c..2b15ec9d0b 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -48,6 +48,8 @@ test_jax() { ${python_new_invoke_test} -s jax test_integration_multigpu # Multiprocess tests + export NCCL_DEBUG=INFO + CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py } From bdb82d718570d4c83392668dbaefdd7025cf7432 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 14 Jun 2023 20:19:09 +0200 Subject: [PATCH 10/15] Fix JAX server test Signed-off-by: Albert Wolant --- dali/test/python/jax/jax_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dali/test/python/jax/jax_server.py b/dali/test/python/jax/jax_server.py index 6d80ddc036..b34a41e231 100644 --- a/dali/test/python/jax/jax_server.py +++ b/dali/test/python/jax/jax_server.py @@ -14,12 +14,12 @@ import jax +import numpy as np import logging as log from test_integration import get_dali_tensor_gpu -import nvidia.dali.types as types import nvidia.dali.plugin.jax as dax @@ -41,7 +41,7 @@ def print_devices_details(devices_list, process_id): def test_lax_workflow(process_id): - array_from_dali = dax._to_jax_array(get_dali_tensor_gpu(1, (1), types.INT32)) + array_from_dali = dax._to_jax_array(get_dali_tensor_gpu(1, (1), np.int32)) assert array_from_dali.device() == jax.local_devices()[0], \ "Array should be backed by the device local to current process." From 3934ee848561b0dccfad356c2b02aa1fdac1c153 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Tue, 20 Jun 2023 13:18:01 +0200 Subject: [PATCH 11/15] Fix for CUDA 12 CI run Signed-off-by: Albert Wolant --- qa/TL0_multigpu/test_body.sh | 16 +++++++++++----- qa/setup_packages.py | 3 --- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index 2b15ec9d0b..55dc79d44c 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -47,11 +47,17 @@ test_pytorch() { test_jax() { ${python_new_invoke_test} -s jax test_integration_multigpu - # Multiprocess tests - export NCCL_DEBUG=INFO - - CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & - CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py + # Workaround for NCCL version mismatch + # TODO: Fix this in the CI setup_packages.py + echo "DALI_CUDA_VERSION_MAJOR=$DALI_CUDA_MAJOR_VERSION" + if [ "$DALI_CUDA_MAJOR_VERSION" == "12" ] + then + python -m pip uninstall -y jax jaxlib + python -m pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + + CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & + CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py + fi } test_no_fw() { diff --git a/qa/setup_packages.py b/qa/setup_packages.py index 0131f6ed26..2e6cc366b9 100755 --- a/qa/setup_packages.py +++ b/qa/setup_packages.py @@ -502,9 +502,6 @@ def get_pyvers_name(self, url, cuda_version): "whl/linux/mkl/avx/stable.html"), CudaPackageExtraIndex("jax", # name used in our test script, see the mxnet case {"113": [PckgVer("0.4.11", - python_min_ver="3.8", - dependencies=["jaxlib"])], - "121": [PckgVer("0.4.11", python_min_ver="3.8", dependencies=["jaxlib"])]}, # name used during installation From ada372d2fbc2d04cdd9b70d8250fe13aa90f4cc1 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Tue, 20 Jun 2023 15:31:43 +0200 Subject: [PATCH 12/15] Fix test run for CUDA 11 Signed-off-by: Albert Wolant --- qa/TL0_multigpu/test_body.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index 55dc79d44c..dde29c5b2e 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -54,10 +54,10 @@ test_jax() { then python -m pip uninstall -y jax jaxlib python -m pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - - CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & - CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py fi + + CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & + CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py } test_no_fw() { From 8f4289e0ab211e50c4597521cff7923d16d63c26 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Tue, 20 Jun 2023 16:36:33 +0200 Subject: [PATCH 13/15] Fix test run Signed-off-by: Albert Wolant --- qa/TL0_multigpu/test_body.sh | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index dde29c5b2e..c3060cb80e 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -50,11 +50,9 @@ test_jax() { # Workaround for NCCL version mismatch # TODO: Fix this in the CI setup_packages.py echo "DALI_CUDA_VERSION_MAJOR=$DALI_CUDA_MAJOR_VERSION" - if [ "$DALI_CUDA_MAJOR_VERSION" == "12" ] - then - python -m pip uninstall -y jax jaxlib - python -m pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - fi + + python -m pip uninstall -y jax jaxlib + python -m pip install --upgrade "jax[cuda${DALI_CUDA_MAJOR_VERSION}_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py From 5bebb2d2035c65c1077c346af3b68c3ff101ea40 Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 21 Jun 2023 10:20:32 +0200 Subject: [PATCH 14/15] Fix test run Signed-off-by: Albert Wolant --- qa/TL0_multigpu/test_body.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index c3060cb80e..4b7e68b19e 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -49,10 +49,9 @@ test_jax() { # Workaround for NCCL version mismatch # TODO: Fix this in the CI setup_packages.py - echo "DALI_CUDA_VERSION_MAJOR=$DALI_CUDA_MAJOR_VERSION" - + # or move this test to the L3 with JAX container as base python -m pip uninstall -y jax jaxlib - python -m pip install --upgrade "jax[cuda${DALI_CUDA_MAJOR_VERSION}_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + python -m pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py From ac3f02a30d0918575812e0437967e61266458d7a Mon Sep 17 00:00:00 2001 From: Albert Wolant Date: Wed, 21 Jun 2023 11:45:58 +0200 Subject: [PATCH 15/15] Fix test Signed-off-by: Albert Wolant --- qa/TL0_multigpu/test_body.sh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/qa/TL0_multigpu/test_body.sh b/qa/TL0_multigpu/test_body.sh index 4b7e68b19e..7f02b0ce4d 100644 --- a/qa/TL0_multigpu/test_body.sh +++ b/qa/TL0_multigpu/test_body.sh @@ -50,11 +50,15 @@ test_jax() { # Workaround for NCCL version mismatch # TODO: Fix this in the CI setup_packages.py # or move this test to the L3 with JAX container as base - python -m pip uninstall -y jax jaxlib - python -m pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + echo "DALI_CUDA_VERSION_MAJOR=$DALI_CUDA_MAJOR_VERSION" + if [ "$DALI_CUDA_MAJOR_VERSION" == "12" ] + then + python -m pip uninstall -y jax jaxlib + python -m pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & - CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py + CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & + CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py + fi } test_no_fw() {