Skip to content

Commit

Permalink
Add default arg values to JAX decorator (#5115)
Browse files Browse the repository at this point in the history
* Add default arg values to JAX decorator 

Adds default values for device_id and num_threads for decorated JAX iterator functions. Thanks to this change it is easier to apply sharding to scale up the computation in the most common pattern.

Signed-off-by: Albert Wolant <awolant@nvidia.com>
  • Loading branch information
awolant committed Oct 27, 2023
1 parent 969a158 commit 40a1d9f
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 30 deletions.
24 changes: 24 additions & 0 deletions dali/python/nvidia/dali/plugin/jax/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,20 @@ def _assert_shards_shapes(self, category_outputs):
"Shards shapes have to be the same."


def default_num_threads_value():
"""Returns default value for num_threads argument of DALI iterator decorator.
.. note::
This value should not be considered as optimized for any particular workload. For best
performance, it is recommended to set this value manually.
.. note::
This value is subject to change in the future.
"""

return 4


def data_iterator_impl(
iterator_type,
pipeline_fn=None,
Expand All @@ -246,7 +260,15 @@ def data_iterator_decorator(func):
def create_iterator(*args, **wrapper_kwargs):
pipeline_def_fn = pipeline_def(func)

if 'num_threads' not in wrapper_kwargs:
wrapper_kwargs['num_threads'] = default_num_threads_value()

if sharding is None:
if 'device_id' not in wrapper_kwargs:
# Due to https://github.com/google/jax/issues/16024 the best we can do is to
# assume that the first device is the one we want to use.
wrapper_kwargs['device_id'] = 0

pipelines = [pipeline_def_fn(*args, **wrapper_kwargs)]
else:
pipelines = []
Expand Down Expand Up @@ -301,6 +323,8 @@ def data_iterator(
passes them to the iterator constructor. It also accepts all arguments of
:meth:`nvidia.dali.pipeline.pipeline_def` and passes them to the pipeline definition
function.
If no `device_id` argument is passed to the decorated function, it is assumed that
the first device is the one we want to use and `device_id` is set to 0.
If the same argument is passed to the decorator and the decorated function,
exception is raised.
Expand Down
15 changes: 15 additions & 0 deletions dali/test/python/jax_plugin/test_iterator_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,21 @@ def iterator_function():
run_and_assert_sequential_iterator(iter)


def test_dali_iterator_decorator_declarative_with_default_args():
# given
@data_iterator(
output_map=['data'],
reader_name='reader')
def iterator_function():
return iterator_function_def()

iter = iterator_function(
batch_size=batch_size)

# then
run_and_assert_sequential_iterator(iter)


def test_dali_iterator_decorator_declarative_pipeline_fn_with_argument():
# given
@data_iterator(
Expand Down
20 changes: 20 additions & 0 deletions dali/test/python/jax_plugin/test_multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,23 @@ def iterator_function(shard_id, num_shards):

# then
run_sharded_iterator_test(data_iterator_instance)


def test_dali_sequential_iterator_decorator_non_default_device():
# given
@data_iterator(
output_map=['data'],
reader_name='reader')
def iterator_function():
return iterator_function_def()

# when
iter = iterator_function(
num_threads=4,
device_id=1,
batch_size=batch_size)

batch = next(iter)

# then
assert batch['data'].device_buffers[0].device() == jax.devices()[1]
180 changes: 150 additions & 30 deletions docs/examples/frameworks/jax/jax-getting_started.ipynb

Large diffs are not rendered by default.

0 comments on commit 40a1d9f

Please sign in to comment.