-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add basic JAX multi process test #4906
Conversation
Signed-off-by: Albert Wolant <awolant@nvidia.com>
qa/TL0_multigpu/test_body.sh
Outdated
# Multiprocess tests | ||
CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py & | ||
CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any ideas on how to achieve something like this better are welcome.
!build |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8627787]: BUILD STARTED |
CI MESSAGE: [8627787]: BUILD PASSED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8638329]: BUILD STARTED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8638630]: BUILD STARTED |
f"process_id = {device.process_index}, kind = {device.device_kind}") | ||
|
||
|
||
def test_lax_workflow(process_id): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lax or jax?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lax: https://jax.readthedocs.io/en/latest/jax.lax.html
lax is one of the ways you can do multi gpu computations in JAX
dali/test/python/jax/jax_server.py
Outdated
|
||
def run_multiprocess_workflow(process_id=0): | ||
jax.distributed.initialize( | ||
coordinator_address="localhost:1234", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick, perhaps, but this is a fairly low port number and it may be taken - especially on desktop machines (specifically, it's a default streaming port for VLC). Suggested alternative: 12321
- it was not listed as a port used by any known service.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8638630]: BUILD FAILED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8640720]: BUILD STARTED |
CI MESSAGE: [8640724]: BUILD STARTED |
CI MESSAGE: [8640720]: BUILD FAILED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8677647]: BUILD STARTED |
CI MESSAGE: [8677633]: BUILD FAILED |
CI MESSAGE: [8677633]: BUILD STARTED |
CI MESSAGE: [8677647]: BUILD PASSED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8687375]: BUILD STARTED |
CI MESSAGE: [8687374]: BUILD STARTED |
CI MESSAGE: [8687374]: BUILD PASSED |
CI MESSAGE: [8687375]: BUILD PASSED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8688313]: BUILD STARTED |
CI MESSAGE: [8688311]: BUILD STARTED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8688820]: BUILD STARTED |
CI MESSAGE: [8688821]: BUILD STARTED |
CI MESSAGE: [8688821]: BUILD FAILED |
CI MESSAGE: [8688820]: BUILD FAILED |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [8698347]: BUILD STARTED |
CI MESSAGE: [8698350]: BUILD STARTED |
CI MESSAGE: [8699075]: BUILD STARTED |
CI MESSAGE: [8699074]: BUILD STARTED |
CI MESSAGE: [8699075]: BUILD PASSED |
CI MESSAGE: [8699074]: BUILD PASSED |
* Add JAX multiprocess test This PR adds basic test for JAX multi process. Main purpose of the PR is to establish that this test works and make sure that code that DALI uses for device placement works in multi process environment. Subsequent tests will be added to files created in this PR to test next features. Signed-off-by: Albert Wolant <awolant@nvidia.com>
Category:
New feature
Description:
This PR adds basic test for JAX multi process. Main purpose of the PR is to establish that this test works and make sure that code that DALI uses for device placement works in multi process environment.
Additional information:
Subsequent tests will be added to files created in this PR to test next features.
Affected modules and functionalities:
JAX integration Python tests.
Key points relevant for the review:
Did added tests run in the CI?
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-3479