jax: incorrect data_iterator documentation on pipeline arguments #5136

Closed
ConnorBaker opened this issue Nov 1, 2023 · 3 comments


ConnorBaker commented Nov 1, 2023

New as of #5050 is the data_iterator decorator for JAX.

Its documentation states:

Decorated function should return DALI pipeline definition function. Decorator accepts
all arguments of :meth:`nvidia.dali.plugin.base_iterator.DALIGenericIterator.__init__` and
passes them to the iterator constructor. It also accepts all arguments of
:meth:`nvidia.dali.pipeline.pipeline_def` and passes them to the pipeline definition
function.

However, the function definition seems to contradict this:

def data_iterator(
    pipeline_fn=None,
    output_map=[],
    size=-1,
    reader_name=None,
    auto_reset=False,
    last_batch_padded=False,
    last_batch_policy=LastBatchPolicy.FILL,
    prepare_first_batch=True,
    sharding=None):

@awolant is this documentation correct? It seems to me that it doesn't accept any pipeline arguments or forward them to the pipeline function -- I don't see kwargs or arguments the pipeline decorator would accept present.

I discovered this problem trying to use the data_iterator to forward arguments to my pipeline function. Am I missing something?
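
One way to check a suspicion like this is to inspect the signature for a `**kwargs` parameter. The sketch below uses a hypothetical stand-in, `data_iterator_signature_only`, that copies only the parameter list quoted above (with `last_batch_policy` defaulted to `None` so the snippet needs no DALI import). It is not the DALI implementation, and the helper names are made up for illustration.

```python
import inspect


# Hypothetical stand-in that copies only the parameter list quoted above;
# it is not the DALI implementation.
def data_iterator_signature_only(
    pipeline_fn=None,
    output_map=[],
    size=-1,
    reader_name=None,
    auto_reset=False,
    last_batch_padded=False,
    last_batch_policy=None,  # LastBatchPolicy.FILL in the quoted signature
    prepare_first_batch=True,
    sharding=None,
):
    return pipeline_fn


def accepts_arbitrary_kwargs(func):
    """Return True if func declares a **kwargs parameter."""
    params = inspect.signature(func).parameters.values()
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)


def rejects_keyword(func, name):
    """Return True if calling func with the given keyword raises TypeError."""
    try:
        func(**{name: 8})
    except TypeError:
        return True
    return False


# No **kwargs in the signature, so an unknown keyword cannot be forwarded.
has_kwargs = accepts_arbitrary_kwargs(data_iterator_signature_only)
batch_size_rejected = rejects_keyword(data_iterator_signature_only, "batch_size")
```

With a signature like this, a pipeline argument such as `batch_size` has nowhere to go: Python raises `TypeError` before the decorator body runs.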

(Also, thank you for improving the JAX iterator -- it's greatly appreciated!)

@jantonguirao jantonguirao self-assigned this Nov 1, 2023
@JanuszL JanuszL assigned awolant and unassigned jantonguirao Nov 2, 2023
Contributor

awolant commented Nov 2, 2023

Hello @ConnorBaker, thanks for your question.

Nicely spotted. Unfortunately, this sentence of the docs is incorrect. It slipped through from a test version of this functionality where something like this was enabled. Apologies for the confusion. I created a fix: #5140

For now, you can pass pipeline_def parameters through the call to the decorated function:

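from nvidia.dali import fn
from nvidia.dali.plugin.jax import data_iterator

# image_dir is assumed to be defined elsewhere and to point at the image directory.
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs)
    images = fn.resize(images, resize_x=256, resize_y=256)

    return images, labels

# pipeline_def arguments such as batch_size are passed at this call.
iterator = iterator_fn(batch_size=8)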

I decided to skip this argument forwarding for now. Would enabling this feature in the future help you somehow with working on DALI & JAX workflows?
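
The two-stage calling convention in the example above can be mimicked in plain Python. This is a toy sketch, not DALI code: `toy_data_iterator` and the dictionary it returns are hypothetical and exist only to show which call receives which arguments.

```python
import functools


# Hypothetical toy decorator; it only mimics the calling convention from the
# example above and is not DALI code.
def toy_data_iterator(output_map=(), reader_name=None):
    # Stage 1: iterator arguments are fixed when the decorator is applied.
    def decorate(pipeline_fn):
        @functools.wraps(pipeline_fn)
        def build(**pipeline_kwargs):
            # Stage 2: pipeline arguments arrive when the decorated
            # function is called, e.g. iterator_fn(batch_size=8).
            return {
                "iterator_args": {
                    "output_map": list(output_map),
                    "reader_name": reader_name,
                },
                "pipeline_args": dict(pipeline_kwargs),
                "outputs": pipeline_fn(),
            }

        return build

    return decorate


@toy_data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    # Stand-in for the pipeline definition body.
    return "images", "labels"


config = iterator_fn(batch_size=8)
```

Here the decorator call receives only the iterator arguments, while `batch_size` travels through the call to `iterator_fn`, matching the split described above.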

@ConnorBaker
Author

No worries! It was more of a matter of confusion. When I’ve got callables wrapping callables returning callables it’s easy to get confused about which application to provide different arguments to.

Contributor

awolant commented Nov 3, 2023

I merged the fix and I am closing the issue. If you have any more questions please reopen it or create a new one. Thanks!

@awolant awolant closed this as completed Nov 3, 2023
@JanuszL JanuszL added this to the Release_1.32.0 milestone Nov 3, 2023
@awolant awolant added the JAX Issues related to DALI and JAX integration label Nov 7, 2023