-
Notifications
You must be signed in to change notification settings - Fork 621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MNIST example for DALI and PyTorch Lightning #2360
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
!build |
CI MESSAGE: [1701254]: BUILD STARTED |
I didn't read the whole thing, but how about bringing the |
CI MESSAGE: [1701254]: BUILD FAILED |
"\n", | ||
"This example shows how to use DALI in PyTorch Lightning.\n", | ||
"\n", | ||
"Let us grab [a toy example](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html) of the clasification network and let us see how DALI can accelerate it.\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Let us grab [a toy example](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html) of the clasification network and let us see how DALI can accelerate it.\n", | |
"Let us grab [a toy example](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html) showcasing a classification network and see how DALI can accelerate it.\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
"\n", | ||
"Let us grab [a toy example](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html) of the clasification network and let us see how DALI can accelerate it.\n", | ||
"\n", | ||
"DALI_EXTRA_PATH environment variable should point to the place where data from DALI extra repository is downloaded. Please make sure that the proper release tag is checked out." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"DALI_EXTRA_PATH environment variable should point to the place where data from DALI extra repository is downloaded. Please make sure that the proper release tag is checked out." | |
"The DALI_EXTRA_PATH environment variable should point to a [DALI extra](https://github.com/NVIDIA/DALI_extra) copy. Please make sure that the proper release tag, the one associated with your DALI version, is checked out." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now let us implement the bare training class with the native data loader" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Now let us implement the bare training class with the native data loader" | |
"We will start by implement a training class that uses the native data loader" |
I feel there's too much of "let us ..."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me fix it...
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now let us define a DALI pipeline which would load the data." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The next step is to define a DALI pipeline that will be used for loading and pre-processing data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Add DALI to the data preparation step in the training class and adjust the `process_batch` to accept the data returned by DALIClassificationIterator, which returns a list of dictionaries, where each list element corresponds to one pipeline wrapped by the DALIIterator, and entries in the dictionary corresponds to the relevant outputs. Check for details in the DALIGenericIterator documenation." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now we are ready to modify the training class to use the DALI pipeline we have just defined. Because we want to integrate with PyTorch, we wrap our pipeline with a PyTorch DALI iterator, that can replace the native data loader with some minor changes in the code. The DALI iterator returns a list dictionaries, where each element in the list corresponds to a pipeline instance, and the entries in the dictionary map to the outputs of the pipeline. For more information, check the documentation of DALIGenericIterator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
" num_shards = self.trainer.world_size\n", | ||
" mnist_pipeline = MnistPipeline(BATCH_SIZE, device='cpu', device_id=device_id, shard_id=shard_id, num_shards=num_shards, num_threads=8)\n", | ||
"\n", | ||
" class LightingWrapper(DALIClassificationIterator):\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are several typos Lighting -> Lightning. Please search and replace.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now let us provide the custom DALI iterator wrapper so we don't have to do any extra processing inside `LitMNIST.process_batch`, also PyTorch can learn how big is the dataset" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Now let us provide the custom DALI iterator wrapper so we don't have to do any extra processing inside `LitMNIST.process_batch`, also PyTorch can learn how big is the dataset" | |
"For even better integration, we can provide a custom DALI iterator wrapper so that no extra processing is required inside `LitMNIST.process_batch`. Also, PyTorch can learn the size of the dataset this way. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -2,7 +2,7 @@ | |||
|
|||
# used pip packages | |||
# TODO(janton): remove explicit pillow version installation when torch fixes the issue with PILLOW_VERSION not being defined |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just wondering, should we try to remove the version pin on pillow?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will try
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Now, let us rerun the training:" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Now, let us rerun the training:" | |
"Let us run the training one more time:" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
" def __next__(self):\n", | ||
" out = super().__next__()\n", | ||
" # DALIClassificationIterator calls next during the construction\n", | ||
" # so first brach would be already converted to a list not a dict, no need to post process it\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
" # so first brach would be already converted to a list not a dict, no need to post process it\n", | |
" # so first batch would be already converted to a list not a dict, no need to post process it\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
!build |
CI MESSAGE: [1703974]: BUILD STARTED |
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Let us run the training one more timemodel = BetterDALILitMNIST()" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Let us run the training one more timemodel = BetterDALILitMNIST()" | |
"Let us run the training one more time" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missed that. Done
CI MESSAGE: [1703974]: BUILD PASSED |
Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
62bf0c0
to
b31e9bc
Compare
!build |
CI MESSAGE: [1708056]: BUILD STARTED |
CI MESSAGE: [1708056]: BUILD PASSED |
Signed-off-by: Janusz Lisiecki jlisiecki@nvidia.com
Why we need this PR?
Pick one, remove the rest
What happened in this PR?
Fill relevant points, put NA otherwise. Replace anything inside []
new example of DALI + PyTorch Lightning integration
examples
NA
CI
new example is added
JIRA TASK: [DALI-1660]