Add multigpu JAX tutorial #4956
Conversation
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build
CI MESSAGE: [9121124]: BUILD STARTED
CI MESSAGE: [9121124]: BUILD FAILED
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build
CI MESSAGE: [9175574]: BUILD STARTED
CI MESSAGE: [9175574]: BUILD PASSED
@@ -0,0 +1,304 @@
{ |
Here we show how to run training from "Training neural network with DALI and JAX" ~~using~~ on multiple GPUs.
If you haven't already done so, it is best to start with the single GPU example to better understand the following content.
Done
@@ -0,0 +1,304 @@
{ |
(...) creating a pipeline definition function.
Note the new arguments passed to fn.readers.caffe2
(...) used to ~~controll~~ control sharding:
(...) sets the total number of shards
Also, (<-- comma) the device_id argument was removed from the decorator (not entirely sure about this one)
(...) ~~particualr~~ particular
batch_size_per_gpu as batch_size // jax.device_count()
^^^^ don't we want to round up?
Done.
When it comes to batch_size_per_gpu: for this test I set it up with batch_size equal to 200, so it is divisible by the common numbers of possible GPUs (2, 4, 8).
I wanted to keep this code as simple as possible.
I added a note to this part explaining that this may need some adjustment to make sure that you use all samples in every epoch.
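For reference, here is a minimal sketch of what the sharded pipeline definition discussed in this thread might look like. The function name, the preprocessing steps, and the data_path parameter are illustrative assumptions, not the tutorial's actual code; num_shards, shard_id, and the floor division for batch_size_per_gpu follow the discussion above.

```python
import jax
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali import pipeline_def

batch_size = 200
# Floor division assumes batch_size is divisible by the device count
# (200 works for 2, 4, or 8 GPUs). Round up with -(-batch_size // n)
# if dropping samples is a concern.
batch_size_per_gpu = batch_size // jax.device_count()

@pipeline_def(batch_size=batch_size_per_gpu, num_threads=4, seed=0)
def mnist_sharded_pipeline(data_path, num_shards, shard_id):
    # num_shards sets the total number of shards; shard_id selects
    # which shard this pipeline instance starts from.
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=True,
        num_shards=num_shards,
        shard_id=shard_id,
        name="mnist_caffe2_reader",
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
    )
    return images, labels
```

Note that device_id no longer appears in the decorator; it is passed per instance when the pipeline objects are created, one per GPU.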
@@ -0,0 +1,304 @@
{ |
Each of them will start the preprocessing from a different shard
Does it mean it will then proceed to the next shard? If they process only items belonging to a particular shard, then a better wording would be:
Each of them will process a different shard of the dataset
~~Similar as~~ Like in the single GPU example
or
~~Similar as in~~ Similarly to the single GPU example
(...). It will ~~encapsule~~ encapsulate (...) return a dictionary of JAX arrays (...)
Does it mean it will then proceed to the next shard? If they process only items belonging to a particular shard
This is controlled by the stick_to_shard argument. By default it is False, so in the next epoch the pipeline will move on to the next shard. I added a sentence with the information about this argument.
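To illustrate the default behavior described above, here is a sketch of the reader call with stick_to_shard spelled out explicitly; the surrounding arguments reuse the illustrative names from the earlier sketch, while stick_to_shard itself is a real argument of DALI readers.

```python
jpegs, labels = fn.readers.caffe2(
    path=data_path,
    random_shuffle=True,
    num_shards=num_shards,
    shard_id=shard_id,
    # False (the default): at each epoch boundary the reader advances to
    # the next shard. True: the reader stays on its initial shard.
    stick_to_shard=False,
    name="mnist_caffe2_reader",
)
```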
Rest done
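Putting this thread's points together, here is a hedged sketch of how the per-GPU pipelines might be created and wrapped in an iterator that returns a dictionary of JAX arrays. DALIGenericIterator comes from DALI's JAX plugin; the pipeline function, dataset path, and output names are the same illustrative assumptions as in the earlier sketches, so this is not necessarily the tutorial's exact code.

```python
import jax
from nvidia.dali.plugin.jax import DALIGenericIterator

data_path = "/data/MNIST/training"  # hypothetical dataset location

# One pipeline per GPU; device_id is passed at instantiation rather than
# in the decorator, and each instance starts from a different shard.
pipelines = [
    mnist_sharded_pipeline(
        data_path, num_shards=jax.device_count(), shard_id=i, device_id=i
    )
    for i in range(jax.device_count())
]

# The iterator encapsulates all pipelines; each step yields a dictionary
# of JAX arrays keyed by output_map.
iterator = DALIGenericIterator(
    pipelines,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
)

for batch in iterator:
    images, labels = batch["images"], batch["labels"]
```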
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build
CI MESSAGE: [9179529]: BUILD STARTED |
CI MESSAGE: [9179529]: BUILD FAILED |
CI MESSAGE: [9179529]: BUILD PASSED |
Adds a tutorial on how to train a neural network with DALI and JAX on multiple GPUs. Signed-off-by: Albert Wolant <awolant@nvidia.com>
Category:
New feature
Description:
Adds a tutorial on how to train a neural network with DALI and JAX on multiple GPUs.
Additional information:
Affected modules and functionalities:
JAX docs.
Key points relevant for the review:
Is this understandable? Spelling, grammar?
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: 3553