Fix support for multi node JAX sharding #5242
Conversation
Signed-off-by: Albert Wolant <awolant@nvidia.com>
@@ -172,7 +172,7 @@ def _next_impl(self):
         for category_id, category_name in enumerate(self.output_map):
             category_outputs = self._gather_outputs_for_category(pipelines_outputs, category_id)

-            if self._num_gpus == 1:
+            if self._num_gpus == 1 and self._sharding is None:
self._num_gpus
is equal to the number of pipelines run by this instance of the iterator, so we need to distinguish multi-node training with one GPU per node (process) from single-GPU training.
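The distinction the new condition makes can be sketched as follows (a minimal illustration with hypothetical names, not DALI's actual code):

```python
def is_plain_single_gpu(num_gpus, sharding):
    # num_gpus counts only the pipelines run by THIS process, so it is
    # also 1 in multi-node training with one GPU per process. A non-None
    # sharding argument is what marks that multi-node case.
    return num_gpus == 1 and sharding is None

print(is_plain_single_gpu(1, None))      # True: genuine single-GPU run
print(is_plain_single_gpu(1, object()))  # False: one GPU per process, multi-node
```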
if isinstance(self._sharding, NamedSharding):
    global_shape = (self._sharding.mesh.size * shard_shape[0], *shard_shape[1:])
else:
    global_shape = (self._sharding.shape[0] * shard_shape[0], *shard_shape[1:])
sharding
variants have inconsistent APIs when it comes to getting the global shape.
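The branch above can be illustrated with stand-in classes (hypothetical names, not the real JAX types) that mimic the two APIs: a NamedSharding-style object exposes the device count through its mesh, while a PositionalSharding-style object exposes it through its own shape.

```python
# Hypothetical stand-ins mimicking the two sharding APIs.
class FakeMesh:
    def __init__(self, size):
        self.size = size

class FakeNamedSharding:
    def __init__(self, mesh):
        self.mesh = mesh

class FakePositionalSharding:
    def __init__(self, shape):
        self.shape = shape

def global_shape(sharding, shard_shape):
    # The API inconsistency forces a type check to find the shard count.
    if isinstance(sharding, FakeNamedSharding):
        num_shards = sharding.mesh.size
    else:
        num_shards = sharding.shape[0]
    return (num_shards * shard_shape[0], *shard_shape[1:])

print(global_shape(FakeNamedSharding(FakeMesh(4)), (2, 3)))  # (8, 3)
print(global_shape(FakePositionalSharding((4,)), (2, 3)))    # (8, 3)
```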
!build
CI MESSAGE: [11566280]: BUILD STARTED
CI MESSAGE: [11566280]: BUILD PASSED
assert jax.local_device_count() == jax.device_count(), (
    "Iterator compatible with pmapped JAX functions does not support "
    "running in multiprocess mode. Use `sharding` argument instead."
)
Assertions are not error checking. If you intend this to be a proper error, please use an appropriate exception with an explicit raise.
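The suggested change could look like this (a sketch with a hypothetical function name): an explicit raise always runs, whereas an `assert` is stripped when Python runs with `-O`.

```python
def check_pmap_iterator_supported(local_device_count, device_count):
    # An assert would disappear under `python -O`; an explicit raise is
    # a real error check that cannot be optimized away.
    if local_device_count != device_count:
        raise RuntimeError(
            "Iterator compatible with pmapped JAX functions does not support "
            "running in multiprocess mode. Use `sharding` argument instead."
        )

check_pmap_iterator_supported(8, 8)  # equal counts: passes silently
```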
Done
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build
CI MESSAGE: [11573496]: BUILD STARTED
CI MESSAGE: [11573496]: BUILD PASSED
!build
CI MESSAGE: [11618023]: BUILD STARTED
CI MESSAGE: [11618023]: BUILD FAILED
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build
CI MESSAGE: [11946905]: BUILD STARTED
CI MESSAGE: [11946905]: BUILD PASSED
Signed-off-by: Albert Wolant <awolant@nvidia.com>
!build
CI MESSAGE: [12275250]: BUILD STARTED
CI MESSAGE: [12275250]: BUILD PASSED
Category:
Bug fix
Description:
In some situations,
data_iterator
for JAX did not work correctly in a multiprocess environment. This PR fixes that.
Additional information:
Affected modules and functionalities:
Iterator for JAX. Some adjustments were made on the code path where the
sharding
argument is provided.
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-3670