Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic JAX multi process test #4906

Merged
merged 16 commits into from
Jun 21, 2023

Conversation

awolant
Copy link
Contributor

@awolant awolant commented Jun 13, 2023

Category:

New feature

Description:

This PR adds basic test for JAX multi process. Main purpose of the PR is to establish that this test works and make sure that code that DALI uses for device placement works in multi process environment.

Additional information:

Subsequent tests will be added to files created in this PR to test next features.

Affected modules and functionalities:

JAX integration Python tests.

Key points relevant for the review:

Did added tests run in the CI?

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-3479

Signed-off-by: Albert Wolant <awolant@nvidia.com>
Comment on lines 50 to 52
# Multiprocess tests
CUDA_VISIBLE_DEVICES="1" python jax/jax_client.py &
CUDA_VISIBLE_DEVICES="0" python jax/jax_server.py
Copy link
Contributor Author

@awolant awolant Jun 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any ideas on how to achieve something like this better are welcome.

@awolant
Copy link
Contributor Author

awolant commented Jun 13, 2023

!build

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8627787]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8627787]: BUILD PASSED

@awolant awolant changed the title [WIP] Add JAX multi process test Add JAX multi process test Jun 13, 2023
@awolant awolant changed the title Add JAX multi process test Add basic JAX multi process test Jun 13, 2023
@awolant awolant marked this pull request as ready for review June 13, 2023 15:57
Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8638329]: BUILD STARTED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8638630]: BUILD STARTED

f"process_id = {device.process_index}, kind = {device.device_kind}")


def test_lax_workflow(process_id):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lax or jax?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lax: https://jax.readthedocs.io/en/latest/jax.lax.html
lax is one of the ways you can do multi gpu computations in JAX


def run_multiprocess_workflow(process_id=0):
jax.distributed.initialize(
coordinator_address="localhost:1234",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, perhaps, but this is a fairly low port number and it may be taken - especially on desktop machines (specifically, it's a default streaming port for VLC). Suggested alternative: 12321 - it was not listed as a port used by any known service.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8638630]: BUILD FAILED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8640720]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8640724]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8640720]: BUILD FAILED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8677647]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8677633]: BUILD FAILED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8677633]: BUILD STARTED

@awolant awolant mentioned this pull request Jun 19, 2023
18 tasks
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8677647]: BUILD PASSED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8687375]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8687374]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8687374]: BUILD PASSED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8687375]: BUILD PASSED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8688313]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8688311]: BUILD STARTED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8688820]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8688821]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8688821]: BUILD FAILED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8688820]: BUILD FAILED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8698347]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8698350]: BUILD STARTED

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8699075]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8699074]: BUILD STARTED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8699075]: BUILD PASSED

@dali-automaton
Copy link
Collaborator

CI MESSAGE: [8699074]: BUILD PASSED

@awolant awolant merged commit e96c3dc into NVIDIA:main Jun 21, 2023
4 checks passed
JanuszL pushed a commit to JanuszL/DALI that referenced this pull request Oct 13, 2023
* Add JAX multiprocess test

This PR adds basic test for JAX multi process. Main purpose of the PR is to establish that this test works and make sure that code that DALI uses for device placement works in multi process environment.

Subsequent tests will be added to files created in this PR to test next features.

Signed-off-by: Albert Wolant <awolant@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants