Add Support for multiple train loaders #1959

Merged
merged 38 commits on Jan 4, 2021
Changes from 33 commits

38 commits
98aafa1
add support for wrong dtype in apply_func
justusschock May 26, 2020
f7e2405
apply loader resetting to possible collection of loaders
justusschock May 26, 2020
dc1f0d0
add combined loader iter class
justusschock May 26, 2020
4ab644f
integrate combined loader iter to training loop
justusschock May 26, 2020
6b8f317
fix imports
Jun 1, 2020
04ce0ef
fix imports
justusschock Jun 1, 2020
aa9b153
finish supporters
justusschock Jun 29, 2020
9525024
add tests for supporters
justusschock Jun 29, 2020
dab57af
add test for model with multiple loaders
justusschock Jun 29, 2020
2f867dc
fix trainer integration
justusschock Jun 29, 2020
0ae527b
fix instance check
justusschock Jun 29, 2020
608f503
Train loaders (#4032)
christofer-f Oct 10, 2020
a1ced87
rename class
justusschock Oct 28, 2020
eea6aae
Separate CombinedLoaderIterator into two classes, and update related …
ylsung Nov 30, 2020
32aeb70
pep8
justusschock Dec 8, 2020
c4482c4
Update train_loader_patch.py
justusschock Dec 8, 2020
1a531a5
Apply suggestions from code review
justusschock Dec 8, 2020
a5d3652
Update pytorch_lightning/trainer/supporters.py
justusschock Dec 8, 2020
dc94a59
reviewer comments
justusschock Dec 9, 2020
3c85cfb
fix stupid import
justusschock Dec 9, 2020
c091414
add docs
justusschock Dec 9, 2020
ebb6277
add back line separator
justusschock Dec 9, 2020
e4e50ab
fix line sep
justusschock Dec 9, 2020
0227a68
pep8
justusschock Dec 9, 2020
644a490
Apply suggestions from code review
Borda Dec 21, 2020
552e6a6
fix
rohitgr7 Dec 21, 2020
3ab1907
fix
rohitgr7 Dec 21, 2020
a2d017f
Apply suggestions from code review
Borda Dec 31, 2020
6af5c90
Apply suggestions from code review
Borda Dec 31, 2020
e138c7a
flake8
Borda Dec 31, 2020
9265651
chlog
Borda Dec 31, 2020
7669a40
Update pytorch_lightning/trainer/supporters.py
justusschock Jan 4, 2021
8e9bd3d
add missing test
justusschock Jan 4, 2021
854756b
fix dataset length
justusschock Jan 4, 2021
55037d4
Update supporters.py
justusschock Jan 4, 2021
9d42c7e
remove unused patch
justusschock Jan 4, 2021
232c7ce
remove tests of otherwise unused patch
justusschock Jan 4, 2021
59651c6
Merge branch 'release/1.2-dev' into train_loaders
tchaton Jan 4, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))


### Changed

31 changes: 29 additions & 2 deletions docs/source/multiple_loaders.rst
@@ -9,14 +9,16 @@ Multiple Datasets
Lightning supports multiple dataloaders in a few ways.

1. Create a dataloader that iterates multiple datasets under the hood.
-2. In the validation and test loop you also have the option to return multiple dataloaders
+2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning
+   will automatically combine the batches from different loaders.
+3. In the validation and test loop you also have the option to return multiple dataloaders
which lightning will call sequentially.

----------

Multiple training dataloaders
-----------------------------
-For training, the best way to use multiple dataloaders is to create a ``DataLoader`` class
+For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
which wraps your multiple dataloaders (this of course also works for testing and validation
dataloaders).

@@ -59,6 +61,31 @@ dataloaders).
# SAME
...

However, with Lightning you can also return multiple loaders from ``train_dataloader`` and Lightning will take care of combining the batches from the different loaders for you.

For more details, please have a look at :attr:`~pytorch_lightning.trainer.trainer.Trainer.multiple_trainloader_mode`.

.. testcode::

    class LitModel(LightningModule):

        def train_dataloader(self):

            loader_a = torch.utils.data.DataLoader(range(6), batch_size=4)
            loader_b = torch.utils.data.DataLoader(range(15), batch_size=5)

            # pass loaders as a dict. This will create batches like this:
            # {'a': batch from loader_a, 'b': batch from loader_b}
            loaders = {'a': loader_a,
                       'b': loader_b}

            # OR:
            # pass loaders as sequence. This will create batches like this:
            # [batch from loader_a, batch from loader_b]
            loaders = [loader_a, loader_b]

            return loaders
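
With the dict form above, each training batch arrives in ``training_step`` as a dict with the same keys; with the sequence form it arrives as a list. A rough sketch of consuming the combined batch (the ``'a'``/``'b'`` keys follow the example above, and the exact batch contents depend on your datasets):

.. code-block:: python

    class LitModel(LightningModule):

        def training_step(self, batch, batch_idx):
            # dict form: batch == {'a': <batch from loader_a>, 'b': <batch from loader_b>}
            batch_a = batch['a']
            batch_b = batch['b']

            # sequence form: batch == [<batch from loader_a>, <batch from loader_b>]
            # batch_a, batch_b = batch
            ...

How batches are drawn when the loaders have different lengths is controlled by the ``multiple_trainloader_mode`` referenced above.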

----------

Test/Val dataloaders
16 changes: 12 additions & 4 deletions pytorch_lightning/trainer/data_loading.py
@@ -29,6 +29,9 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_utils import is_overridden

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.trainer.supporters import CombinedLoader


class TrainerDataLoadingMixin(ABC):

@@ -137,6 +140,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
model: The current `LightningModule`
"""
self.train_dataloader = self.request_dataloader(model.train_dataloader)

if (self.overfit_batches > 0):
if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
@@ -147,13 +151,17 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
# debugging
self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

self.num_training_batches = 0

# automatically add samplers
-self.train_dataloader = self.auto_add_sampler(self.train_dataloader, shuffle=True)
+self.train_dataloader = apply_to_collection(
+    self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True)

# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)

self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
-self._worker_check(self.train_dataloader, 'train dataloader')

if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
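
Taken together, the hunk above replaces the single-DataLoader assumptions with collection-aware helpers: ``apply_to_collection`` maps the existing per-loader logic (``auto_add_sampler``, ``_worker_check``) over every ``DataLoader`` in a possibly nested dict/list, and ``CombinedLoader`` gives the wrapped collection a single length for ``num_training_batches``. A minimal, self-contained sketch of those two ideas follows; ``CombinedLoaderSketch`` is a hypothetical stand-in (the real class lives in ``pytorch_lightning/trainer/supporters.py``) and the ``'min_size'``/``'max_size_cycle'`` mode names are assumed from that module:

```python
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection

loaders = {'a': DataLoader(range(6), batch_size=4),    # 2 batches
           'b': DataLoader(range(15), batch_size=5)}   # 3 batches

# apply_to_collection applies a function to every element of the given type in a
# nested dict/list/tuple and preserves the structure, the same mechanism used
# above to add samplers and run the worker check on each loader.
batch_counts = apply_to_collection(loaders, DataLoader, len)
print(batch_counts)  # {'a': 2, 'b': 3}


class CombinedLoaderSketch:
    """Illustrative stand-in for the CombinedLoader length computation."""

    def __init__(self, loaders, mode='max_size_cycle'):
        # assumes a dict of DataLoaders for brevity; the real class also accepts
        # a single loader or a list/tuple and additionally handles iteration
        self.loaders = loaders
        self.mode = mode

    def __len__(self):
        lengths = list(apply_to_collection(self.loaders, DataLoader, len).values())
        # 'min_size': stop once the shortest loader is exhausted
        # 'max_size_cycle': cycle shorter loaders until the longest one finishes
        return min(lengths) if self.mode == 'min_size' else max(lengths)


print(len(CombinedLoaderSketch(loaders)))                    # 3
print(len(CombinedLoaderSketch(loaders, mode='min_size')))   # 2
```

With a cycling mode like ``'max_size_cycle'``, ``num_training_batches`` should therefore correspond to the longest loader, which is what the ``len(self.train_dataloader)`` call after the ``CombinedLoader`` wrapping computes.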