
Add multigpu JAX tutorial #4956

Merged: 12 commits into NVIDIA:main on Aug 1, 2023
Conversation

@awolant (Contributor) commented Jul 26, 2023

Category:

New feature

Description:

Adds a tutorial on how to train a neural network with DALI and JAX on multiple GPUs.

Additional information:

Affected modules and functionalities:

JAX docs.

Key points relevant for the review:

Is this understandable? Spelling, grammar?

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: 3553

awolant added 2 commits (Signed-off-by: Albert Wolant <awolant@nvidia.com>)
@awolant (Contributor, Author) commented Jul 26, 2023

!build

@review-notebook-app commented:

Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks.

@dali-automaton (Collaborator)
CI MESSAGE: [9121124]: BUILD STARTED

@dali-automaton (Collaborator)
CI MESSAGE: [9121124]: BUILD FAILED

awolant added 5 commits (Signed-off-by: Albert Wolant <awolant@nvidia.com>)
awolant marked this pull request as ready for review on July 31, 2023 at 08:10
@awolant (Contributor, Author) commented Jul 31, 2023

!build

@dali-automaton (Collaborator)
CI MESSAGE: [9175574]: BUILD STARTED

@dali-automaton (Collaborator)
CI MESSAGE: [9175574]: BUILD PASSED

@mzient (Contributor) commented on the notebook, Jul 31, 2023:
Here we show how to run training from "Training neural network with DALI and JAX" ~~using~~ on multiple GPUs.

If you haven't already done so, it is best to start with single GPU example to better understand following content.



@awolant (Contributor, Author) replied:
Done

@mzient (Contributor) commented on the notebook, Jul 31, 2023:

(...) creating a pipeline definition function.

Note the new arguments passed to the fn.readers.caffe2

(...) used to ~~controll~~ control sharding:

(...) sets the total number of shards

Also, (<-- comma) the device_id argument was removed from the decorator (not entirely sure about this one)

(...) ~~particualr~~ particular

batch_size_per_gpu as batch_size // jax.device_count()

^^^^ don't we want to round up?



@awolant (Contributor, Author) replied, Jul 31, 2023:

Done.

When it comes to batch_size_per_gpu: for this test I set batch_size to 200 so it is divisible by common GPU counts (2, 4, 8).

I wanted to make this code as simple as possible.

I added a note to this part to explain that this may need some adjustment to make sure that you use all samples in every epoch.
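For illustration, here is a minimal sketch of the sharding setup discussed in this thread, assuming the tutorial's Caffe2 reader; the dataset path, thread count, and pipeline structure are placeholder assumptions, not the tutorial's exact code:

```python
# Minimal sketch (not the tutorial's exact code): one sharded DALI pipeline
# definition, instantiated once per GPU.
import jax
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn

batch_size = 200  # divisible by common GPU counts (2, 4, 8), as noted above
num_shards = jax.device_count()
batch_size_per_gpu = batch_size // num_shards
# Rounding up instead (the reviewer's question) would be:
# batch_size_per_gpu = (batch_size + num_shards - 1) // num_shards


@pipeline_def(batch_size=batch_size_per_gpu, num_threads=4)
def sharded_pipeline(shard_id):
    jpegs, labels = fn.readers.caffe2(
        path="/path/to/dataset",  # placeholder path
        shard_id=shard_id,        # which shard this pipeline reads
        num_shards=num_shards,    # sets the total number of shards
        random_shuffle=True,
    )
    images = fn.decoders.image(jpegs, device="mixed")
    return images, labels


# device_id is no longer fixed in the decorator; it is passed per instance,
# one pipeline per GPU:
pipelines = [sharded_pipeline(shard_id=i, device_id=i) for i in range(num_shards)]
```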

@mzient (Contributor) commented on the notebook, Jul 31, 2023:

Each of them will start the preprocessing from a different shard

Does it mean it will then proceed to the next shard? If they process only items belonging to a particular shard, then better wording would be:

Each of them will process a different shard of the dataset

~~Similar as~~ Like in the single GPU example

or

~~Similar as in~~ Similarly to the single GPU example

(...) . It will ~~encapsule~~ encapsulate (...) return a dictionary of JAX arrays (...)



@awolant (Contributor, Author) replied:

Does it mean it will then proceed to the next shard? If they process only items belonging to a particular shard

This is controlled by the stick_to_shard argument. By default it is False, so in the next epoch the pipeline will move to the next shard. I added a sentence with the information about this argument.
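A short illustration of that argument, reusing the placeholder reader call from the sketch above (still an assumption, not the tutorial's exact code):

```python
# stick_to_shard=True keeps a pipeline on its own shard across epochs;
# the default (False) moves it to the next shard each epoch, as described above.
jpegs, labels = fn.readers.caffe2(
    path="/path/to/dataset",  # placeholder path
    shard_id=shard_id,
    num_shards=num_shards,
    stick_to_shard=True,
)
```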

@awolant (Contributor, Author) replied:

Rest done
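For context on the "dictionary of JAX arrays" wording above, here is a rough sketch of how the wrapped pipelines might be consumed, assuming DALI's JAX plugin iterator; the class name and arguments are assumptions, not confirmed by this thread:

```python
# Rough sketch, assuming DALI's JAX plugin iterator; names are assumptions.
from nvidia.dali.plugin.jax import DALIGenericIterator

iterator = DALIGenericIterator(pipelines, output_map=["images", "labels"])

for batch in iterator:
    # Each batch is a dictionary mapping output names to JAX arrays.
    images, labels = batch["images"], batch["labels"]
```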

awolant added 5 commits (Signed-off-by: Albert Wolant <awolant@nvidia.com>)
@awolant (Contributor, Author) commented Jul 31, 2023

!build

@dali-automaton (Collaborator)
CI MESSAGE: [9179529]: BUILD STARTED

@dali-automaton (Collaborator)
CI MESSAGE: [9179529]: BUILD FAILED

@dali-automaton (Collaborator)
CI MESSAGE: [9179529]: BUILD PASSED

awolant merged commit c147a2e into NVIDIA:main on Aug 1, 2023
5 checks passed
JanuszL pushed a commit to JanuszL/DALI that referenced this pull request Oct 13, 2023
Adds tutorial on how to train a neural network with DALI and JAX on multiple GPUs.

Signed-off-by: Albert Wolant <awolant@nvidia.com>