-
Notifications
You must be signed in to change notification settings - Fork 621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add default arg values to JAX decorator #5115
Conversation
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Signed-off-by: Albert Wolant <awolant@nvidia.com>
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build |
CI MESSAGE: [10415392]: BUILD STARTED |
CI MESSAGE: [10415392]: BUILD PASSED |
if 'num_threads' not in wrapper_kwargs: | ||
wrapper_kwargs['num_threads'] = 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This choice seems quite arbitrary. I'd recommend getting the number of processors (or perhaps the number of processors per GPU?).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did that because that's what we do it a lot in our codebase and especially in the tutorials. Do we have some benchmarks of what it does to the performance?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But you need to account for the number of CPUs that may need to be used for the JAX itself, the speed of the network, and things that may run in parallel.
I'm afraid that without an elaborate autotune we can't provide an accurate value here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a lot of middle ground between elaborate autotune and a fixed number.
Still, I'd at least like to have a big, fat comment saying that it's subject to change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, added docs, comments and tutorial section on num_threads
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment about the "4" not being special.
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build |
CI MESSAGE: [10473251]: BUILD STARTED |
CI MESSAGE: [10473251]: BUILD PASSED |
Category:
New feature
Description:
Adds default values for
device_id
andnum_threads
for decorated JAX iterator functions. Thanks to this change it is easier to applysharding
to scale up the computation in the most common pattern.Additional information:
Affected modules and functionalities:
JAX iterator decorator.
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-3671