Skip to content

Commit

Permalink
enable rocm/hip for multi_client_test
Browse files Browse the repository at this point in the history
  • Loading branch information
i-chaochen committed Jan 31, 2023
1 parent 8fb4e30 commit d29b6d6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tensorflow/dtensor/python/tests/multi_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def multi_client_main():

# No GPU visible to the flock controller.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

os.environ['HIP_VISIBLE_DEVICES'] = ''
# Python multiprocess module in OSS.
mp_context = test_backend_util.get_mp_context()

Expand Down Expand Up @@ -260,7 +260,7 @@ def run_client(idx, server_ports, additional_ports, num_devices):
Virtual devices are configured before test.main() is called.
Each client is configured to only have access to the physical GPU device
corresponding to its client id via CUDA_VISIBLE_DEVICES.
corresponding to its client id via CUDA_VISIBLE_DEVICES/HIP_VISIBLE_DEVICES.
Each client is configured to only have access to some TPU cores
corresponding to its client id via flags.
Expand Down
4 changes: 3 additions & 1 deletion tensorflow/dtensor/python/tests/test_backend_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ def slice_host_devices_for_multiworker(num_clients, client_id, ports):
if num_clients == 0:
# All GPUs are visible to the client.
del os.environ['CUDA_VISIBLE_DEVICES']
else:
del os.environ['HIP_VISIBLE_DEVICES']
else:
# Make the client_id-th GPU visible to the client.
os.environ['CUDA_VISIBLE_DEVICES'] = f'{client_id}'
os.environ['HIP_VISIBLE_DEVICES'] = f'{client_id}'
# Make the client_id-th (4x) TPU cores visible to the client.
os.environ['CLOUD_TPU_TASK_ID'] = f'{client_id}'
if 'tpu' in DTENSOR_TEST_UTIL_BACKEND.value:
Expand Down

0 comments on commit d29b6d6

Please sign in to comment.