Add basic JAX multi process test #4906

Merged
merged 16 commits on Jun 21, 2023
19 changes: 19 additions & 0 deletions dali/test/python/jax/jax_client.py
@@ -0,0 +1,19 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from jax_server import run_multiprocess_workflow

if __name__ == "__main__":
    run_multiprocess_workflow(process_id=1)
72 changes: 72 additions & 0 deletions dali/test/python/jax/jax_server.py
@@ -0,0 +1,72 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import jax

import logging as log

from test_integration import get_dali_tensor_gpu

import nvidia.dali.types as types
import nvidia.dali.plugin.jax as dax


def print_devices(process_id):
    log.info(f"Local devices = {jax.local_device_count()}, "
             f"global devices = {jax.device_count()}")

    log.info("All devices:")
    print_devices_details(jax.devices(), process_id)

    log.info("Local devices:")
    print_devices_details(jax.local_devices(), process_id)


def print_devices_details(devices_list, process_id):
    for device in devices_list:
        log.info(f"Id = {device.id}, host_id = {device.host_id}, "
                 f"process_id = {device.process_index}, kind = {device.device_kind}")


def test_lax_workflow(process_id):
Contributor:
lax or jax?

Contributor Author:
lax: https://jax.readthedocs.io/en/latest/jax.lax.html
jax.lax is one of the ways you can do multi-GPU computation in JAX.
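
For readers unfamiliar with the collective used below, here is a minimal single-host sketch of jax.lax.psum under jax.pmap (assuming only a working JAX install, nothing from this PR; with N participating devices every entry of the result equals N):

# Minimal sketch: psum sums a value across all devices mapped over axis 'i'.
import jax
import jax.numpy as jnp

n = jax.local_device_count()
per_device = jnp.ones(n)  # one scalar per device

# Each device receives the sum of all devices' values along the named axis.
summed = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(per_device)
print(summed)  # [n, n, ..., n]: every device sees the full sum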

    array_from_dali = dax._to_jax_array(get_dali_tensor_gpu(1, (1,), types.INT32))

    assert array_from_dali.device() == jax.local_devices()[0], \
        "Array should be backed by the device local to the current process."

    sum_across_devices = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(array_from_dali)

    assert sum_across_devices[0] == len(jax.devices()), \
        "Sum across devices should equal the number of devices, since the data per device is [1]."

    log.info("Passed lax workflow test")


def run_multiprocess_workflow(process_id=0):
    jax.distributed.initialize(
        coordinator_address="localhost:12321",
        num_processes=2,
        process_id=process_id)

    log.basicConfig(
        level=log.INFO,
        format=f"PID {process_id}: %(message)s")

    print_devices(process_id=process_id)
    test_lax_workflow(process_id=process_id)


if __name__ == "__main__":
    run_multiprocess_workflow(process_id=0)
4 changes: 2 additions & 2 deletions dali/test/python/jax/test_integration.py
@@ -29,7 +29,7 @@
from nose2.tools import cartesian_params


def get_dali_tensor_gpu(value, shape, dtype) -> TensorGPU:
def get_dali_tensor_gpu(value, shape, dtype, device_id=0) -> TensorGPU:
"""Helper function to create DALI TensorGPU.

Args:
@@ -47,7 +47,7 @@ def dali_pipeline():

        return values

    pipe = dali_pipeline(device_id=0)
    pipe = dali_pipeline(device_id=device_id)
    pipe.build()
    dali_output = pipe.run()

1 change: 1 addition & 0 deletions qa/TL0_multigpu/test.sh
@@ -2,3 +2,4 @@
bash -e ./test_nofw.sh
bash -e ./test_cupy.sh
bash -e ./test_pytorch.sh
bash -e ./test_jax.sh
4 changes: 4 additions & 0 deletions qa/TL0_multigpu/test_body.sh
@@ -46,6 +46,10 @@ test_pytorch() {

test_jax() {
    ${python_new_invoke_test} -s jax test_integration_multigpu

    # Multiprocess tests
    CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py &
    CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py
Contributor Author (@awolant, Jun 13, 2023):
Any ideas on how to achieve this more cleanly are welcome.
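
One possibility, sketched here as an untested suggestion rather than as part of the PR, is a single Python launcher that starts both processes and propagates failures through their return codes (the script paths mirror the test layout; the launcher itself is hypothetical):

# Hedged sketch: start client and server from one driver and check both
# return codes, instead of backgrounding the client in the shell script.
import os
import subprocess

def launch(script, gpu):
    # Pin each process to one GPU, matching the CUDA_VISIBLE_DEVICES
    # pinning in the shell invocation above.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    return subprocess.Popen(["python", script], env=env)

procs = [launch("jax/jax_client.py", "1"), launch("jax/jax_server.py", "0")]

# Wait for both processes before judging, then fail if either one failed.
exit_codes = [p.wait() for p in procs]
assert all(code == 0 for code in exit_codes), "multiprocess JAX test failed"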

}

test_no_fw() {