Add Support for multiple train loaders (#1959)
* add support for wrong dtype in apply_func

* apply loader resetting to a possible collection of loaders

* add combined loader iter class

* integrate combined loader iter to training loop

* fix imports

* fix imports

* finish supporters

* add tests for supporters

* add test for model with multiple loaders

* fix trainer integration

* fix instance check

* Train loaders (#4032)

* patch for issues discussed in #1959, encapsulating underlying data structures returned from train_dataloader

* update data_loading.py so it uses the patch discussed in #1959

* rename class

* Separate CombinedLoaderIterator into two classes, and update related tests. (#4606)

* Fix the bugs after rebasing.

* Add custom get_len for apply_to_collection

* Refactor MultiIterator into CombinedLoaderIterator

* To get the right num_training_batches, call the wrapper for the multi train loader in data_loading.py instead of training_loop.py

* Reload _loader_iters when calling __iter__

* Don't transform DataLoader to CombinedLoaderIterator when it's alone

* Update test_fit_multiple_train_loaders to test num_training_batches

* Separate CombinedLoaderIterator into CombinedLoaderIterator and CombinedDataLoader. Add CombinedDataset for a unified DataLoader format.

* Initialize CombinedDataLoader before calculating num_training_batches. Also update self._worker_check for multiple loaders

* Update tests for supporters

* Update tests for multiple trainloaders. Add few_workers tests for multiple loaders.

* Fix pep8 issues

* Add tests for train_loader_patch.py

* Add descriptions to multiple_trainloader_mode

* Remove unused variables

* Add docstrings and typing

* Add more tests for better coverage

* Remove unused commented-out code

* Add sampler property

* Remove extract_dataset

* Update typing

* pep8

* Update train_loader_patch.py

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/supporters.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* reviewer comments

* fix stupid import

* add docs

* add back line separator

* fix line sep

* pep8

* Apply suggestions from code review

* fix

* fix

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* flake8

Co-authored-by: Justus Schock <justusschock@justuss-mbp.fritz.box>
Co-authored-by: Christofer Fransson <christofer_fransson@yahoo.com>
Co-authored-by: YI-LIN SUNG <r06942076@ntu.edu.tw>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
9 people committed Jan 4, 2021
1 parent b72ed71 commit d88cf4a
Showing 12 changed files with 723 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

+- Add Support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))
+
- `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))

- `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))
31 changes: 29 additions & 2 deletions docs/source/multiple_loaders.rst
@@ -9,14 +9,16 @@ Multiple Datasets
Lightning supports multiple dataloaders in a few ways.

1. Create a dataloader that iterates multiple datasets under the hood.
-2. In the validation and test loop you also have the option to return multiple dataloaders
+2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning
+   will automatically combine the batches from different loaders.
+3. In the validation and test loop you also have the option to return multiple dataloaders
   which lightning will call sequentially.

----------

Multiple training dataloaders
-----------------------------
-For training, the best way to use multiple dataloaders is to create a ``DataLoader`` class
+For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
which wraps your multiple dataloaders (this of course also works for testing and validation
dataloaders).

@@ -59,6 +61,31 @@ dataloaders).
        # SAME
        ...

+However, with lightning you can also return multiple loaders and lightning will take care of batch combination.
+
+For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer.Trainer.multiple_trainloader_mode`
+
+.. testcode::
+
+    class LitModel(LightningModule):
+
+        def train_dataloader(self):
+
+            loader_a = torch.utils.data.DataLoader(range(6), batch_size=4)
+            loader_b = torch.utils.data.DataLoader(range(15), batch_size=5)
+
+            # pass loaders as a dict. This will create batches like this:
+            # {'a': batch from loader_a, 'b': batch from loader_b}
+            loaders = {'a': loader_a,
+                       'b': loader_b}
+
+            # OR:
+            # pass loaders as sequence. This will create batches like this:
+            # [batch from loader_a, batch from loader_b]
+            loaders = [loader_a, loader_b]
+
+            return loaders
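
For illustration (this example is not part of the commit's diff): with the dict form above, each combined training batch arrives in ``training_step`` keyed like the loaders, so a hypothetical step could look like this:

    class LitModel(LightningModule):

        def training_step(self, batch, batch_idx):
            # `batch` is {'a': <batch from loader_a>, 'b': <batch from loader_b>};
            # with the list form it would be [<batch from loader_a>, <batch from loader_b>]
            batch_a = batch['a']
            batch_b = batch['b']
            ...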

----------

Test/Val dataloaders
16 changes: 12 additions & 4 deletions pytorch_lightning/trainer/data_loading.py
@@ -29,6 +29,9 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden

+from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.trainer.supporters import CombinedLoader


class TrainerDataLoadingMixin(ABC):

@@ -137,6 +140,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
            model: The current `LightningModule`
        """
        self.train_dataloader = self.request_dataloader(model.train_dataloader)
+
        if (self.overfit_batches > 0):
            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
                rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
@@ -147,13 +151,17 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
        # debugging
        self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

+        self.num_training_batches = 0
+
        # automatically add samplers
-        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True)
+        self.train_dataloader = apply_to_collection(
+            self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True)
+
+        # check the workers recursively
+        apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')
+
+        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
+        self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
-        self._worker_check(self.train_dataloader, 'train dataloader')

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
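
The ``apply_to_collection`` calls above run a function over every ``DataLoader`` found inside an arbitrarily nested dict, list, or tuple and rebuild the collection around the results. A minimal sketch of that behavior (simplified; the real ``pytorch_lightning.utilities.apply_func.apply_to_collection`` also handles namedtuples and other mapping/sequence types):

    from typing import Any, Callable, Type

    def apply_to_collection(data: Any, dtype: Type, function: Callable, *args: Any, **kwargs: Any) -> Any:
        """Recursively apply `function` to every `dtype` instance found in `data`."""
        if isinstance(data, dtype):
            return function(data, *args, **kwargs)
        if isinstance(data, dict):
            return {k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}
        if isinstance(data, (list, tuple)):
            return type(data)(apply_to_collection(v, dtype, function, *args, **kwargs) for v in data)
        # anything else is returned untouched
        return data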
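
To make ``multiple_trainloader_mode`` concrete, here is a rough sketch of the semantics the two modes give to ``num_training_batches`` (a hypothetical ``combined_batches`` helper for illustration only, not the actual ``supporters.py`` implementation): in ``'min_size'`` mode an epoch ends with the shortest loader, while in ``'max_size_cycle'`` mode shorter loaders are restarted until the longest one is exhausted.

    from itertools import cycle

    from torch.utils.data import DataLoader

    def combined_batches(loaders: dict, mode: str = 'max_size_cycle'):
        """Simplified illustration of the two CombinedLoader modes for a dict of loaders."""
        if mode == 'min_size':
            # zip stops as soon as the shortest loader is exhausted
            for batches in zip(*loaders.values()):
                yield dict(zip(loaders.keys(), batches))
        else:  # 'max_size_cycle'
            # note: itertools.cycle caches batches; the real implementation
            # re-creates exhausted iterators instead of caching
            num_batches = max(len(loader) for loader in loaders.values())
            iterators = {name: cycle(loader) for name, loader in loaders.items()}
            for _ in range(num_batches):
                yield {name: next(it) for name, it in iterators.items()}

    loader_a = DataLoader(range(6), batch_size=4)   # 2 batches
    loader_b = DataLoader(range(15), batch_size=5)  # 3 batches

    # 'max_size_cycle' yields 3 combined batches; 'min_size' would yield 2,
    # matching the num_training_batches computed from the CombinedLoader length
    for batch in combined_batches({'a': loader_a, 'b': loader_b}):
        print(batch)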
