-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add basic JAX multi process test #4906
Merged
Merged
Changes from 5 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
b15cafc
Add JAX multiprocess test
awolant 7d704ff
Remove empty lines
awolant 671caf8
Add missing test run command
awolant 0e63bad
Add logging with logger
awolant 52c9010
Review fix
awolant 2984f4a
Fix review
awolant 97e5871
Add type to test
awolant efe1383
Make test better
awolant 3d3a7fe
Add NCCL debug info
awolant ac56b1d
Merge remote-tracking branch 'nvidia/main' into add_jax_multiprocess_…
awolant bdb82d7
Fix JAX server test
awolant 3934ee8
Fix for CUDA 12 CI run
awolant ada372d
Fix test run for CUDA 11
awolant 8f4289e
Fix test run
awolant 5bebb2d
Fix test run
awolant ac3f02a
Fix test
awolant File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from jax_server import run_multiprocess_workflow | ||
|
||
if __name__ == "__main__": | ||
run_multiprocess_workflow(process_id=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import jax | ||
|
||
import logging as log | ||
|
||
from test_integration import get_dali_tensor_gpu | ||
|
||
import nvidia.dali.types as types | ||
import nvidia.dali.plugin.jax as dax | ||
|
||
|
||
def print_devices(process_id): | ||
log.info(f"Local devices = {jax.local_device_count()}, " | ||
f"global devices = {jax.device_count()}") | ||
|
||
log.info("All devices: ") | ||
print_devices_details(jax.devices(), process_id) | ||
|
||
log.info("Local devices:") | ||
print_devices_details(jax.local_devices(), process_id) | ||
|
||
|
||
def print_devices_details(devices_list, process_id): | ||
for device in devices_list: | ||
log.info(f"Id = {device.id}, host_id = {device.host_id}, " | ||
f"process_id = {device.process_index}, kind = {device.device_kind}") | ||
|
||
|
||
def test_lax_workflow(process_id): | ||
array_from_dali = dax._to_jax_array(get_dali_tensor_gpu(1, (1), types.INT32)) | ||
|
||
assert array_from_dali.device() == jax.local_devices()[0], \ | ||
"Array should be backed by the device local to current process." | ||
|
||
sum_across_devices = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(array_from_dali) | ||
|
||
assert sum_across_devices[0] == len(jax.devices()),\ | ||
"Sum across devices should be equal to the number of devices as data per device = [1]" | ||
|
||
log.info("Passed lax workflow test") | ||
|
||
|
||
def run_multiprocess_workflow(process_id=0): | ||
jax.distributed.initialize( | ||
coordinator_address="localhost:12321", | ||
num_processes=2, | ||
process_id=process_id) | ||
|
||
log.basicConfig( | ||
level=log.INFO, | ||
format=f"PID {process_id}: %(message)s") | ||
|
||
print_devices(process_id=process_id) | ||
test_lax_workflow(process_id=process_id) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_multiprocess_workflow(process_id=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
bash -e ./test_nofw.sh | ||
bash -e ./test_cupy.sh | ||
bash -e ./test_pytorch.sh | ||
bash -e ./test_jax.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,10 @@ test_pytorch() { | |
|
||
test_jax() { | ||
${python_new_invoke_test} -s jax test_integration_multigpu | ||
|
||
# Multiprocess tests | ||
CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & | ||
CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any ideas on how to achieve something like this better are welcome. |
||
} | ||
|
||
test_no_fw() { | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lax or jax?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lax: https://jax.readthedocs.io/en/latest/jax.lax.html
lax is one of the ways you can do multi gpu computations in JAX