Support checkpointing in JAX iterator #5282
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Category:
New feature (non-breaking change which adds functionality)
Description:
Checkpointing support was already implemented in
BaseIterator
in #5061, but wasn't tested for frameworks other than Pytorch. In this PR I add tests for JAX Iterator and fix a problem with ES checkpointing.When an iterator is created and there's no data in the external source, DALI reports "no data in pipeline" error. The problem is that if we checkpoint after the last iteration and restore from such checkpoint, there's no data in the external source but it's not an error. In this PR I silence this error if we're restoring from checkpoint. Exactly the same change was made in #5213.
Additional information:
Affected modules and functionalities:
JAX Iterator
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-3749