Add missing test for "multiple dataloader + percent_check fix" (#2226)
* Init fix num_batches

* Fix num_batches in case of multiple dataloaders

* Apply suggestions from code review

* Changes based on suggestions

* Flake8

* Add test to check num_batches

* generalize dataloader percent check test

* fix formatting

* remove hparams

* tests

* CHANGELOG

* Update CHANGELOG.md

* max_batches can be int

* conflict and rebase

* add back the test


* fix
* fix message
* 0.0 works
* Revert "fix message"
  This reverts commit 839cacf8b8610f4e697e654ef6f3d2501bf23984.

* update changelog

* Update CHANGELOG.md

* Fix num batches in case of multiple dataloaders and percent_check (#1920)

* git conflict

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* missing union

* doc update suggestion by @rohitgr7

* extend test

* changelog

* docs add note about multiple loaders

* update changelog

* remove unused variable

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people committed Jun 23, 2020
1 parent 44385bb commit e085e93
Showing 13 changed files with 119 additions and 20 deletions.
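
Before the file-by-file diff, a minimal sketch of the per-dataloader limit semantics this commit pins down (the standalone helper expected_num_batches is illustrative, not part of the change):

def expected_num_batches(dataloader_lengths, limit):
    # a float limit is a fraction of each dataloader's length;
    # an int limit is an absolute count, applied to each dataloader individually
    if isinstance(limit, float):
        return [int(n * limit) for n in dataloader_lengths]
    return [limit] * len(dataloader_lengths)

expected_num_batches([100, 30], 0.4)  # -> [40, 12]
expected_num_batches([100, 30], 2)    # -> [2, 2]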
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed number batches in case of multiple dataloaders and `limit_{*}_batches` ([#1920](https://github.com/PyTorchLightning/pytorch-lightning/pull/1920), [#2226](https://github.com/PyTorchLightning/pytorch-lightning/pull/2226))

- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))

- Fixed ROC metric for CUDA tensors ([#2304](https://github.com/PyTorchLightning/pytorch-lightning/pull/2304))
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -456,6 +456,8 @@ def on_train_end(self, trainer, pl_module):
# run for only 10 batches
trainer = Trainer(limit_test_batches=10)
In the case of multiple test dataloaders, the limit applies to each dataloader individually.
limit_val_batches
^^^^^^^^^^^^^^^^^
@@ -473,6 +475,8 @@ def on_train_end(self, trainer, pl_module):
# run for only 10 batches
trainer = Trainer(limit_val_batches=10)
In the case of multiple validation dataloaders, the limit applies to each dataloader individually.
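
A hedged usage sketch of the notes above, assuming two validation dataloaders of 100 and 30 batches (lengths illustrative; the same per-loader semantics apply to limit_test_batches):

# the limit applies to each loader separately, not to their combined length
trainer = Trainer(limit_val_batches=0.5)  # runs 50 batches from the first loader, 15 from the second
trainer = Trainer(limit_val_batches=10)   # runs 10 batches from each loader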
log_gpu_memory
^^^^^^^^^^^^^^
Options:
2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
@@ -287,8 +287,6 @@ def _reset_eval_dataloader(
for i, dataloader in enumerate(dataloaders):
num_batches = 0
self._worker_check(dataloader, f'{mode} dataloader {i}')
if not _has_len(dataloader):
num_batches = float('inf')

# percent or num_steps
limit_eval_batches = getattr(self, f'limit_{mode}_batches')
21 changes: 16 additions & 5 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -124,7 +124,7 @@

from abc import ABC, abstractmethod
from pprint import pprint
from typing import Callable, Optional, List
from typing import Callable, Optional, List, Union

import torch
from torch.utils.data import DataLoader
@@ -222,13 +222,20 @@ def reset_test_dataloader(self, *args):
def reset_val_dataloader(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int], test_mode: bool = False):
def _evaluate(
self,
model: LightningModule,
dataloaders: List[DataLoader],
max_batches: Union[int, List[int]],
test_mode: bool = False
):
"""Run evaluation code.
Args:
model: PT model
dataloaders: list of PT dataloaders
max_batches: List of scalars
model: The model to evaluate.
dataloaders: A list of PyTorch dataloaders.
max_batches: An integer or list of integers with length of the number of dataloaders. Each
entry is the number of batches to process in the corresponding dataloader.
test_mode:
"""
# enable eval mode
@@ -244,6 +251,10 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int],
# bookkeeping
outputs = []

# convert max_batches to list
if isinstance(max_batches, int):
max_batches = [max_batches] * len(dataloaders)

# run validation
for dataloader_idx, dataloader in enumerate(dataloaders):
dl_outputs = []
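With the broadened max_batches type, both call forms below should now be equivalent — a hedged sketch assuming a trainer, model, and dataloaders dl_a/dl_b already exist:

# an int is broadcast to one entry per dataloader inside _evaluate
trainer._evaluate(model, dataloaders=[dl_a, dl_b], max_batches=5)
trainer._evaluate(model, dataloaders=[dl_a, dl_b], max_batches=[5, 5])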
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -223,7 +223,7 @@ def __init__(
min_steps: Force training for at least these number of steps. Disabled by default (None).
limit_train_batches: How much of training dataset to check.
limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)
limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
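A brief illustration of the float-vs-int semantics described in the docstring above:

trainer = Trainer(limit_train_batches=0.25)  # float: check 25% of the training batches
trainer = Trainer(limit_train_batches=10)    # int: check exactly 10 training batches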
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -208,7 +208,6 @@ class TrainerTrainLoopMixin(ABC):
check_val_every_n_epoch: ...
num_training_batches: int
val_check_batch: ...
num_val_batches: int
disable_validation: bool
fast_dev_run: ...
accumulation_scheduler: ...
7 changes: 6 additions & 1 deletion tests/base/model_test_dataloaders.py
@@ -7,7 +7,7 @@
class TestDataloaderVariations(ABC):

@abstractmethod
def dataloader(self, train: bool):
def dataloader(self, *args, **kwargs):
"""placeholder"""

def test_dataloader(self):
@@ -19,6 +19,11 @@ def test_dataloader__infinite(self):
def test_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))

def test_dataloader__multiple_mixed_length(self):
lengths = [50, 30, 40]
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
return dataloaders

def test_dataloader__empty(self):
return None

4 changes: 2 additions & 2 deletions tests/base/model_test_epoch_ends.py
@@ -7,7 +7,7 @@ class TestEpochEndVariations(ABC):

def test_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
Called at the end of test epoch to aggregate outputs
:param outputs: list of individual outputs of each validation step
:return:
"""
@@ -40,7 +40,7 @@ def test_epoch_end(self, outputs):

def test_epoch_end__multiple_dataloaders(self, outputs):
"""
Called at the end of validation to aggregate outputs
Called at the end of test epoch to aggregate outputs
:param outputs: list of individual outputs of each validation step
:return:
"""
4 changes: 2 additions & 2 deletions tests/base/model_utilities.py
@@ -6,8 +6,8 @@
class ModelTemplateData:
hparams: ...

def dataloader(self, train):
dataset = TrialMNIST(root=self.data_root, train=train, download=True)
def dataloader(self, train: bool, num_samples: int = 100):
dataset = TrialMNIST(root=self.data_root, train=train, num_samples=num_samples, download=True)

loader = DataLoader(
dataset=dataset,
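A hypothetical call against the widened signature (num_samples caps the TrialMNIST dataset size; the value here is illustrative):

# inside a model template, build a deliberately short loader
small_loader = self.dataloader(train=False, num_samples=30)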
7 changes: 6 additions & 1 deletion tests/base/model_valid_dataloaders.py
@@ -7,12 +7,17 @@
class ValDataloaderVariations(ABC):

@abstractmethod
def dataloader(self, train: bool):
def dataloader(self, *args, **kwargs):
"""placeholder"""

def val_dataloader(self):
return self.dataloader(train=False)

def val_dataloader__multiple_mixed_length(self):
lengths = [100, 30]
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
return dataloaders

def val_dataloader__multiple(self):
return [self.dataloader(train=False),
self.dataloader(train=False)]
2 changes: 1 addition & 1 deletion tests/base/model_valid_epoch_ends.py
@@ -28,7 +28,7 @@ def _mean(res, key):
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results

def validation_epoch_end_multiple_dataloaders(self, outputs):
def validation_epoch_end__multiple_dataloaders(self, outputs):
"""
Called at the end of validation to aggregate outputs
4 changes: 2 additions & 2 deletions tests/test_deprecated.py
@@ -140,7 +140,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
with pytest.deprecated_call(match='v1.0'):
trainer = Trainer(logger=False)
# TODO: why `dataloder` is required if it is not used
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1])
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
assert result == {'val_loss': torch.tensor(0.6)}

model = ModelVer0_7(hparams)
@@ -153,5 +153,5 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
with pytest.deprecated_call(match='v1.0'):
trainer = Trainer(logger=False)
# TODO: why `dataloder` is required if it is not used
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1])
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
assert result == {'val_loss': torch.tensor(0.7)}
79 changes: 77 additions & 2 deletions tests/trainer/test_dataloaders.py
@@ -89,7 +89,7 @@ def test_multiple_val_dataloader(tmpdir):
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end_multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

# fit model
trainer = Trainer(
@@ -224,7 +224,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):

model = EvalModelTemplate()
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end_multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
model.test_step = model.test_step__multiple_dataloaders

# train, multiple val and multiple test passed to fit
@@ -251,6 +251,81 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
pytest.param(0.0, 0.0, 0.0),
pytest.param(0, 0, 0.5),
pytest.param(1.0, 1.0, 1.0),
pytest.param(0.2, 0.4, 0.4),
]
)
def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit in percent"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
model.test_step = model.test_step__multiple_dataloaders
model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

# train, multiple val and multiple test passed with percent_check
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)
trainer.fit(model)
expected_train_batches = int(len(trainer.train_dataloader) * limit_train_batches)
expected_val_batches = [
int(len(dataloader) * limit_val_batches) for dataloader in trainer.val_dataloaders
]
assert trainer.num_training_batches == expected_train_batches
assert trainer.num_val_batches == expected_val_batches

trainer.test(ckpt_path=None)
expected_test_batches = [
int(len(dataloader) * limit_test_batches) for dataloader in trainer.test_dataloaders
]
assert trainer.num_test_batches == expected_test_batches
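
# To make the asserts above concrete (illustrative arithmetic, not part of the
# diff): with the mixed-length loaders defined earlier (val lengths [100, 30],
# test lengths [50, 30, 40]) and the 0.2/0.4/0.4 case:
#   expected_val_batches  == [int(100 * 0.4), int(30 * 0.4)]               == [40, 12]
#   expected_test_batches == [int(50 * 0.4), int(30 * 0.4), int(40 * 0.4)] == [20, 12, 16]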


@pytest.mark.parametrize(
['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
[
pytest.param(0, 0, 0),
pytest.param(1, 2, 3),
pytest.param(1, 2, 1e50),
]
)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit as number"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
model.test_step = model.test_step__multiple_dataloaders
model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

# train, multiple val and multiple test passed with absolute batch counts
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
limit_test_batches=limit_test_batches,
)
trainer.fit(model)
assert trainer.num_training_batches == limit_train_batches
assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
trainer.test(ckpt_path=None)
assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders)


@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
"""Verify that dataloaders can be passed to fit"""
