Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default arg values to JAX decorator #5115

Merged
merged 21 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions dali/python/nvidia/dali/plugin/jax/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,20 @@ def _assert_shards_shapes(self, category_outputs):
"Shards shapes have to be the same."


def default_num_threads_value():
"""Returns default value for num_threads argument of DALI iterator decorator.

.. note::
This value should not be considered as optimized for any particular workload. For best
performance, it is recommended to set this value manually.

.. note::
This value is subject to change in the future.
"""

return 4


def data_iterator_impl(
iterator_type,
pipeline_fn=None,
Expand All @@ -246,7 +260,15 @@ def data_iterator_decorator(func):
def create_iterator(*args, **wrapper_kwargs):
pipeline_def_fn = pipeline_def(func)

if 'num_threads' not in wrapper_kwargs:
wrapper_kwargs['num_threads'] = default_num_threads_value()

if sharding is None:
if 'device_id' not in wrapper_kwargs:
# Due to https://github.com/google/jax/issues/16024 the best we can do is to
# assume that the first device is the one we want to use.
wrapper_kwargs['device_id'] = 0

pipelines = [pipeline_def_fn(*args, **wrapper_kwargs)]
else:
pipelines = []
Expand Down Expand Up @@ -301,6 +323,8 @@ def data_iterator(
passes them to the iterator constructor. It also accepts all arguments of
:meth:`nvidia.dali.pipeline.pipeline_def` and passes them to the pipeline definition
function.
If no `device_id` argument is passed to the decorated function, it is assumed that
the first device is the one we want to use and `device_id` is set to 0.
If the same argument is passed to the decorator and the decorated function,
exception is raised.

Expand Down
15 changes: 15 additions & 0 deletions dali/test/python/jax_plugin/test_iterator_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,21 @@ def iterator_function():
run_and_assert_sequential_iterator(iter)


def test_dali_iterator_decorator_declarative_with_default_args():
# given
@data_iterator(
output_map=['data'],
reader_name='reader')
def iterator_function():
return iterator_function_def()

iter = iterator_function(
batch_size=batch_size)

# then
run_and_assert_sequential_iterator(iter)


def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument():
# given
@data_iterator(
Expand Down
20 changes: 20 additions & 0 deletions dali/test/python/jax_plugin/test_multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,23 @@ def iterator_function(shard_id, num_shards):

# then
run_sharded_iterator_test(data_iterator_instance)


def test_dali_sequential_iterator_decorator_non_default_device():
# given
@data_iterator(
output_map=['data'],
reader_name='reader')
def iterator_function():
return iterator_function_def()

# when
iter = iterator_function(
num_threads=4,
device_id=1,
batch_size=batch_size)

batch = next(iter)

# then
assert batch['data'].device_buffers[0].device() == jax.devices()[1]
180 changes: 150 additions & 30 deletions docs/examples/frameworks/jax/jax-getting_started.ipynb

Large diffs are not rendered by default.