Add JAX iterator decorator #5050

Merged Oct 5, 2023 (61 commits)

@awolant awolant commented Sep 12, 2023

Category:

New feature

Description:

This PR adds a decorator API for JAX iterators. With it, you can write code like this:

@data_iterator(...)
def pipeline_def(...):
    ...
    return output


# This is an iterator returning JAX compatible outputs
iterator = pipeline_def(...)

Additional information:

There are two iterators for JAX: DALIGenericIterator and DALIGenericPeekableIterator. Decorators were added for both.

To keep the two APIs from diverging, tests check that the decorators accept the same arguments as the object-oriented way of creating the iterators.
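
One way such a parity check could look is sketched below, using `inspect.signature` from the standard library. This is only an illustration under stated assumptions: `ToyIterator`, `toy_data_iterator`, and `config_args` are hypothetical stand-ins, not the classes or tests from this PR.

```python
import inspect


class ToyIterator:
    """Hypothetical stand-in for an iterator class; not the DALI implementation."""

    def __init__(self, pipelines, output_map, size=-1, reader_name=None):
        self.pipelines = pipelines
        self.output_map = output_map
        self.size = size
        self.reader_name = reader_name


def toy_data_iterator(pipeline_fn=None, output_map=(), size=-1, reader_name=None):
    """Hypothetical decorator that mirrors ToyIterator's configuration arguments."""

    def decorator(func):
        def create_iterator(*args, **kwargs):
            return ToyIterator([func(*args, **kwargs)], list(output_map), size, reader_name)

        return create_iterator

    return decorator(pipeline_fn) if pipeline_fn is not None else decorator


def config_args(callable_obj, skip):
    """Return the parameter names of callable_obj, minus the names listed in skip."""
    params = inspect.signature(callable_obj).parameters
    return [name for name in params if name not in skip]


# The class takes `pipelines` (and `self`); the decorator takes `pipeline_fn` instead.
iterator_args = config_args(ToyIterator.__init__, skip={"self", "pipelines"})
decorator_args = config_args(toy_data_iterator, skip={"pipeline_fn"})
assert iterator_args == decorator_args
```

A check like this fails as soon as an argument is added to one entry point but not the other, which is the divergence the tests are meant to catch.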

Decorators can be used both in a declarative way:

@data_iterator(...)
def function(...):
    ...

and in a functional way:

data_iterator(function)(...)

Both are tested.
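
The two calling styles can be supported by a single function that takes the decorated function as an optional first argument. The sketch below shows that general pattern in plain Python; `sketch_data_iterator` and its `output_map` handling are hypothetical and much simpler than the real decorator, which builds a DALI pipeline.

```python
import functools


def sketch_data_iterator(pipeline_fn=None, *, output_map=()):
    """Hypothetical decorator usable as @sketch_data_iterator(...) or sketch_data_iterator(fn, ...)."""

    def decorator(func):
        @functools.wraps(func)
        def create_iterator(*args, **kwargs):
            outputs = func(*args, **kwargs)
            # Pair each output with its name, then expose the result as an iterator.
            return iter([dict(zip(output_map, outputs))])

        return create_iterator

    # Called with a function: decorate it immediately. Otherwise return the decorator.
    if pipeline_fn is not None:
        return decorator(pipeline_fn)
    return decorator


@sketch_data_iterator(output_map=["data", "label"])
def declarative(scale):
    return (1 * scale, 2 * scale)


def plain(scale):
    return (1 * scale, 2 * scale)


functional = sketch_data_iterator(plain, output_map=["data", "label"])

first_declarative = next(declarative(10))
first_functional = next(functional(10))
assert first_declarative == first_functional
```

Both entry points go through the same inner `decorator`, so a test only needs to compare what the two styles produce for the same inputs.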

Affected modules and functionalities:

JAX plugin

Key points relevant for the review:

Is the documentation enough? Are there any issues with the API? Is it tested well enough?

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-3620

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Collaborator

CI MESSAGE: [9966859]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [9966142]: BUILD FAILED

@dali-automaton
Collaborator

CI MESSAGE: [9966859]: BUILD PASSED

@dali-automaton
Collaborator

CI MESSAGE: [9966142]: BUILD PASSED

@awolant awolant marked this pull request as ready for review October 3, 2023 09:51
@awolant awolant changed the title [WIP] Add JAX iterator decorator Add JAX iterator decorator Oct 3, 2023
data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db', 'single', 'jpeg')


def get_all_files_from_directory(dir_path, ext):
Contributor

Out of curiosity, is there a reason you are not using file_root instead of listing every file?

Contributor Author

In the tests for the JAX iterators I care a lot about the order of samples. For example, I test that with 2 GPUs the dataset is sharded properly.
To order the samples I use labels:

file_names = get_all_files_from_directory(data_path, '.jpg')
file_labels = [*range(len(file_names))]

I was using an external-source-based pipeline before, but I wanted to test it with a "proper" reader. So I pass files and labels instead of file_root and can then check that the label (sample id) is as expected.
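
The idea of using labels as sample ids can be sketched without DALI at all. In the example below the file list is made up, and `shard_range` is a hypothetical helper that assumes contiguous, near-equal shards; the real reader may split the data differently.

```python
def shard_range(num_samples, shard_id, num_shards):
    """Contiguous [start, end) slice for one shard; an assumed sharding scheme."""
    start = num_samples * shard_id // num_shards
    end = num_samples * (shard_id + 1) // num_shards
    return start, end


# Hypothetical file list standing in for get_all_files_from_directory(...).
file_names = [f"img_{i}.jpg" for i in range(7)]
file_labels = [*range(len(file_names))]  # label == position in the file list

num_shards = 2
seen = []
for shard_id in range(num_shards):
    start, end = shard_range(len(file_names), shard_id, num_shards)
    shard_labels = file_labels[start:end]
    # Labels double as sample ids, so the shard contents can be checked exactly.
    assert shard_labels == list(range(start, end))
    seen.extend(shard_labels)

# Every sample lands in exactly one shard.
assert sorted(seen) == file_labels
```

Because each label equals the sample's position in the file list, a test can tell from the label alone which sample an iterator returned and on which shard it appeared.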

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@awolant
Contributor Author

awolant commented Oct 4, 2023

!build

@dali-automaton
Collaborator

CI MESSAGE: [10103548]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [10103548]: BUILD PASSED

the iterator.
Mutually exclusive with `reader_name` argument
reader_name : str, default = None
Name of the reader which will be queried to the shard size, number of shards and
Collaborator

Suggested change
- Name of the reader which will be queried to the shard size, number of shards and
+ Name of the reader which will be queried for the shard size, number of shards and

?

Contributor Author

Done for all iterators since this was copied between them.
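
The docstring excerpt in this thread describes an argument that is mutually exclusive with `reader_name`. A minimal sketch of how such a constraint could be enforced follows. Everything in it is an assumption for illustration: the excerpt does not show the other argument's name, so `explicit_size` is a placeholder, and `resolve_size` and the `reader_meta` dictionary are hypothetical rather than part of the DALI API.

```python
def resolve_size(explicit_size=-1, reader_name=None, reader_meta=None):
    """Pick the epoch size from exactly one source (hypothetical helper)."""
    if explicit_size != -1 and reader_name is not None:
        raise ValueError("an explicit size and `reader_name` are mutually exclusive")
    if reader_name is not None:
        # Assumed layout: metadata is a dict keyed by reader name.
        return reader_meta[reader_name]["shard_size"]
    return explicit_size


meta = {"reader": {"shard_size": 128}}
size_from_reader = resolve_size(reader_name="reader", reader_meta=meta)
size_explicit = resolve_size(explicit_size=64)

try:
    resolve_size(explicit_size=64, reader_name="reader", reader_meta=meta)
    conflict_detected = False
except ValueError:
    conflict_detected = True
```

Rejecting the conflicting combination up front avoids silently preferring one source of the size over the other.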

Signed-off-by: Albert Wolant <awolant@nvidia.com>
@dali-automaton
Collaborator

CI MESSAGE: [10117490]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [10117497]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [10117490]: BUILD PASSED

@dali-automaton
Collaborator

CI MESSAGE: [10117497]: BUILD PASSED

@awolant awolant merged commit 3e1123d into NVIDIA:main Oct 5, 2023
4 checks passed
JanuszL pushed a commit to JanuszL/DALI that referenced this pull request Oct 13, 2023
5 participants