Add JAX iterator decorator #5050
Signed-off-by: Albert Wolant <awolant@nvidia.com>
CI MESSAGE: [9966859]: BUILD STARTED
CI MESSAGE: [9966142]: BUILD FAILED
CI MESSAGE: [9966859]: BUILD PASSED
CI MESSAGE: [9966142]: BUILD PASSED
data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db', 'single', 'jpeg')


def get_all_files_from_directory(dir_path, ext):
Out of curiosity: is there a reason you are listing every file instead of using `file_root`?
In the tests for the JAX iterators I care a lot about the order of samples. For example, I test that with 2 GPUs the dataset is sharded properly.
To order the samples I use labels:
file_names = get_all_files_from_directory(data_path, '.jpg')
file_labels = [*range(len(file_names))]
I was using an external-source-based pipeline before, but I wanted to test with a "proper" reader. So I pass files and labels instead of file_root and can then check that the label (sample id) is as expected.
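A minimal sketch of the labels-as-sample-ids idea, runnable without DALI. The body of `get_all_files_from_directory` and the `contiguous_shard` helper are assumptions made for illustration (the PR shows only the helper's signature, and the real reader's sharding rule may differ); the sketch also uses a temporary directory in place of `DALI_EXTRA_PATH`.

```python
import os
import tempfile


def get_all_files_from_directory(dir_path, ext):
    """Hypothetical body: sorted full paths of files in dir_path ending with ext."""
    return sorted(
        os.path.join(dir_path, name)
        for name in os.listdir(dir_path)
        if name.endswith(ext)
    )


def contiguous_shard(items, shard_id, num_shards):
    """Assumed sharding rule for illustration: split items into contiguous chunks."""
    start = len(items) * shard_id // num_shards
    end = len(items) * (shard_id + 1) // num_shards
    return items[start:end]


# Build a throwaway directory so the sketch does not need DALI_EXTRA_PATH.
with tempfile.TemporaryDirectory() as data_path:
    for i in range(6):
        open(os.path.join(data_path, f"img_{i}.jpg"), "w").close()
    open(os.path.join(data_path, "notes.txt"), "w").close()  # filtered out by ext

    file_names = get_all_files_from_directory(data_path, ".jpg")
    file_labels = [*range(len(file_names))]  # label == sample id

# With labels acting as sample ids, each shard's content can be checked directly.
shard_0 = contiguous_shard(file_labels, shard_id=0, num_shards=2)
shard_1 = contiguous_shard(file_labels, shard_id=1, num_shards=2)
```

Because every label is unique and ordered, a test can compare the labels an iterator returns on each device against the expected slice, which a `file_root` directory layout would not give for free.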
!build
CI MESSAGE: [10103548]: BUILD STARTED
CI MESSAGE: [10103548]: BUILD PASSED
the iterator.
Mutually exclusive with `reader_name` argument
reader_name : str, default = None
Name of the reader which will be queried to the shard size, number of shards and
- Name of the reader which will be queried to the shard size, number of shards and
+ Name of the reader which will be queried for the shard size, number of shards and
?
Done for all iterators since this was copied between them.
CI MESSAGE: [10117490]: BUILD STARTED
CI MESSAGE: [10117497]: BUILD STARTED
CI MESSAGE: [10117490]: BUILD PASSED
CI MESSAGE: [10117497]: BUILD PASSED
Category:
New feature
Description:
It adds a decorator API for JAX iterators. With this PR it is possible to write code:

    @data_iterator(...)
    def pipeline_def(...):
        ...
        return output

    # This is an iterator returning JAX compatible outputs
    iterator = pipeline_def(...)

Additional information:
There are two iterators for JAX: DALIGenericIterator and DALIGenericPeekableIterator. Decorators were added for both.
To make sure that the APIs do not diverge, there are tests checking that the decorators take the same arguments as the object-oriented way of creating the iterators.
Decorators can be used both in a declarative way:

    @data_iterator(...)
    def function(...):
        ...

and in a functional way:

    data_iterator(function)(...)

Both are tested.
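To illustrate how one decorator can support both call styles, here is a small pure-Python sketch. `toy_data_iterator`, its `output_map` parameter, and the toy pipeline functions are hypothetical stand-ins; this is not the code added by the PR, only the general pattern of an optional first positional argument.

```python
import functools


def toy_data_iterator(pipeline_fn=None, *, output_map=("data",)):
    """Hypothetical stand-in for a decorator usable with or without arguments."""

    def decorator(fn):
        @functools.wraps(fn)
        def make_iterator(*args, **kwargs):
            batches = fn(*args, **kwargs)
            # Yield one dict per batch, keyed by the names in output_map.
            return (dict(zip(output_map, batch)) for batch in batches)

        return make_iterator

    if pipeline_fn is None:
        # Declarative form: @toy_data_iterator(output_map=...)
        return decorator
    # Functional form: toy_data_iterator(fn)(...)
    return decorator(pipeline_fn)


@toy_data_iterator(output_map=("image", "label"))
def declarative(num_batches):
    return [(f"img{i}", i) for i in range(num_batches)]


def plain(num_batches):
    return [(f"img{i}",) for i in range(num_batches)]


functional = toy_data_iterator(plain)

declarative_batches = list(declarative(2))
functional_batches = list(functional(2))
```

Making every argument other than the function keyword-only keeps the two forms unambiguous: a positional callable means the functional form, anything else means the declarative one.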
Affected modules and functionalities:
JAX plugin
Key points relevant for the review:
Is the documentation enough? Are there any issues with the API? Is it tested well enough?
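A parity check of the kind mentioned in the description could look roughly like the sketch below. `ToyIterator`, `toy_data_iterator`, and `missing_decorator_args` are hypothetical names used only for illustration; the actual tests in the PR may be structured differently.

```python
import inspect


class ToyIterator:
    """Hypothetical iterator class standing in for the object-oriented API."""

    def __init__(self, pipelines, output_map, size=-1, reader_name=None):
        self.pipelines = pipelines
        self.output_map = output_map
        self.size = size
        self.reader_name = reader_name


def toy_data_iterator(pipeline_fn=None, output_map=(), size=-1, reader_name=None):
    """Hypothetical decorator that forwards its arguments to ToyIterator."""

    def decorator(fn):
        def make_iterator(*args, **kwargs):
            return ToyIterator([fn(*args, **kwargs)], output_map, size, reader_name)

        return make_iterator

    return decorator if pipeline_fn is None else decorator(pipeline_fn)


def missing_decorator_args(decorator, iterator_cls, ignored=("self", "pipelines")):
    """Return constructor argument names that the decorator does not accept."""
    init_params = inspect.signature(iterator_cls.__init__).parameters
    deco_params = inspect.signature(decorator).parameters
    return [
        name
        for name in init_params
        if name not in ignored and name not in deco_params
    ]


# An empty list means the decorator exposes every constructor argument.
missing = missing_decorator_args(toy_data_iterator, ToyIterator)
```

A check like this fails as soon as someone adds a constructor argument without exposing it on the decorator, which is the kind of drift the description says the tests guard against.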
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-3620