Add basic jax.Sharding support for the iterator #4969

Merged: 8 commits, Aug 3, 2023

@awolant awolant commented Aug 2, 2023

Category:

New feature

Description:

Add basic jax.Sharding support for the iterator

Additional information:

Affected modules and functionalities:

The JAX iterator has a new argument, sharding. If it is provided, the output format is different.

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-3558

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@awolant awolant marked this pull request as ready for review August 2, 2023 09:56
@awolant awolant changed the title Add basic jax.Sharding support Add basic jax.Sharding support for the iterator Aug 2, 2023
awolant commented Aug 2, 2023

!build

@dali-automaton

CI MESSAGE: [9206164]: BUILD STARTED

@dali-automaton

CI MESSAGE: [9206164]: BUILD PASSED

@@ -140,13 +143,15 @@ def __init__(
 auto_reset=False,
 last_batch_padded=False,
 last_batch_policy=LastBatchPolicy.FILL,
-prepare_first_batch=True):
+prepare_first_batch=True,
+sharding=None):
Member

Suggested change
-sharding=None):
+sharding: jax.sharding.Sharding=None):

Small suggestion: you may consider adding a type hint here.

Contributor

What are the semantics of type hints? The documentation says it only needs to be jax.sharding.Sharding compatible, not necessarily an instance of this type.

Contributor Author

jax.sharding.Sharding is an abstract base class. This should accept NamedSharding, PositionalSharding, and maybe others in the future.

When it comes to inheritance and type hints, this should be done as:

from typing import Type
...

sharding: Type[jax.sharding.Sharding]=None):

I added a straightforward type assertion against NamedSharding and PositionalSharding, since these are the ones we test against. If a new type of sharding appears in the future, we will add it along with tests for it.

We are not using type hints anywhere else, so I didn't want to add them just here.
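
For reference, a minimal sketch of the two options discussed in this thread, written with stand-in classes so that it runs without JAX installed. `Sharding`, `NamedSharding`, `PositionalSharding`, and `OtherSharding` below are hypothetical placeholders for the `jax.sharding` classes, and `check_sharding` and `describe` are illustrative helpers, not DALI code. One thing worth keeping in mind: in the `typing` module, `Type[C]` annotates the class object itself, while a plain `C` annotation already covers instances of any subclass of `C`, so a hint for an instance argument would normally look like `Optional[Sharding]`.

```python
from abc import ABC
from typing import Optional, Type


# Stand-ins for jax.sharding.Sharding and two of its subclasses (assumption:
# placeholders only, so the sketch runs without JAX).
class Sharding(ABC):
    pass


class NamedSharding(Sharding):
    pass


class PositionalSharding(Sharding):
    pass


class OtherSharding(Sharding):
    """A sharding type that is not on the tested list."""


# Sharding types covered by tests; extend when a new type gets tests.
SUPPORTED_SHARDINGS = (NamedSharding, PositionalSharding)


def check_sharding(sharding: Optional[Sharding] = None) -> Optional[Sharding]:
    """Annotating with the base class accepts instances of any subclass.

    Hints are not enforced at runtime, so the explicit assertion does the check.
    """
    assert sharding is None or isinstance(sharding, SUPPORTED_SHARDINGS), (
        f"Unsupported sharding type: {type(sharding).__name__}")
    return sharding


def describe(sharding_cls: Type[Sharding]) -> str:
    """Type[Sharding] annotates a class object, not an instance."""
    return sharding_cls.__name__
```

With this layout, supporting a new sharding type amounts to adding it to `SUPPORTED_SHARDINGS` together with its tests.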

Comment on lines 113 to 115
sharding : ``jax.sharding.Sharding`` comaptible object that if present will be used to build
output jax.Array for each category. If ``None`` iterator returns values compatible
with pmaped JAX functions.
Contributor

@mzient mzient Aug 3, 2023

Suggested change
-sharding : ``jax.sharding.Sharding`` comaptible object that if present will be used to build
-output jax.Array for each category. If ``None`` iterator returns values compatible
-with pmaped JAX functions.
+sharding : ``jax.sharding.Sharding`` comaptible object that, if present, will be used to build an
+output jax.Array for each category. If ``None``, the iterator returns values compatible
+with pmapped JAX functions.

Not sure about "pmapped": it takes a double 'p' if it is derived from the verb "to map" (map -> mapped).

Contributor Author

Done

Contributor

What about the other changes?

Contributor Author

They weren't there yet :)
Done

Comment on lines 114 to 115
sharding : ``jax.sharding.Sharding`` comaptible object that if present will be used to build
output jax.Array for each category. If ``None`` iterator returns values compatible
Contributor

@mzient mzient Aug 3, 2023

These still apply:

Suggested change
-sharding : ``jax.sharding.Sharding`` comaptible object that if present will be used to build
-output jax.Array for each category. If ``None`` iterator returns values compatible
+sharding : ``jax.sharding.Sharding`` comaptible object that, if present, will be used to build an
+output jax.Array for each category. If ``None``, the iterator returns values compatible

Contributor Author

Done


return category_outputs

def _build_output_with_devices(self, next_output, category_name, category_outputs):
Contributor

I don't quite follow what this function does. An error message in L234 mentions sharding, but the function is invoked when _sharding is None.
A comment explaining the high-level functionality would be nice, to avoid making this code "write only".

Contributor Author

@awolant awolant Aug 3, 2023

Generally, there are two mechanisms for building sharded output: jax.sharding.Sharding and jax.device_put_sharded(). We want to support both. The error message mentioned a shard as a general concept, not the sharding object.

I changed the name to _build_output_with_device_put to indicate which function takes which path. This should be enough to distinguish them.
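
A minimal sketch of the resulting dispatch, assuming the iterator stores the argument in `_sharding` as mentioned in the review comment. `_build_output_with_device_put` is the name from this thread; `SketchIterator`, `_build_output_with_sharding`, `build_output`, and the returned tuples are hypothetical and only stand in for the real JAX calls.

```python
class SketchIterator:
    """Illustrative only: shows which helper handles which output path."""

    def __init__(self, sharding=None):
        self._sharding = sharding

    def _build_output_with_sharding(self, shards):
        # Path 1: a sharding object was given, so one global array per
        # category would be assembled according to it (placeholder result).
        return ("sharding", self._sharding, list(shards))

    def _build_output_with_device_put(self, shards):
        # Path 2: no sharding object; per-device shards would be handed to
        # jax.device_put_sharded(), which suits pmapped functions
        # (placeholder result).
        return ("device_put", None, list(shards))

    def build_output(self, shards):
        # The presence of a sharding object selects the path.
        if self._sharding is not None:
            return self._build_output_with_sharding(shards)
        return self._build_output_with_device_put(shards)
```

Keeping the mechanism in the helper name makes it clear at the call site which path runs when `_sharding` is `None`.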

awolant commented Aug 3, 2023

!build

@dali-automaton

CI MESSAGE: [9223789]: BUILD STARTED

@dali-automaton

CI MESSAGE: [9223789]: BUILD PASSED

@awolant awolant merged commit e92a94a into NVIDIA:main Aug 3, 2023
5 checks passed
JanuszL pushed a commit to JanuszL/DALI that referenced this pull request Oct 13, 2023
Add basic jax.Sharding support for the iterator

Signed-off-by: Albert Wolant <awolant@nvidia.com>