Skip to content

Commit

Permalink
Clear dataloader references before attaching new dataloaders to Train…
Browse files Browse the repository at this point in the history
…er (#8442)

* regression test

* apply fix

* simplify test and docs

* update changlog
  • Loading branch information
awaelchli committed Jul 19, 2021
1 parent 374fae5 commit 1bfa29a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -477,6 +477,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed progress bar updates for Pod Training ([#8258](https://github.com/PyTorchLightning/pytorch-lightning/pull/8258))


- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}´ runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))


## [1.3.8] - 2021-07-01

### Fixed
Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/connectors/data_connector.py
Expand Up @@ -117,15 +117,19 @@ def attach_dataloaders(
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloaders is not None:
self.trainer.train_dataloader = None
model.train_dataloader = _PatchDataLoader(train_dataloaders)

if val_dataloaders is not None:
self.trainer.val_dataloaders = None
model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
self.trainer.test_dataloaders = None
model.test_dataloader = _PatchDataLoader(test_dataloaders)

if predict_dataloaders is not None:
self.trainer.predict_dataloaders = None
model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

def attach_datamodule(
Expand Down
43 changes: 43 additions & 0 deletions tests/trainer/test_dataloaders.py
Expand Up @@ -1503,6 +1503,49 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):
assert call['name'] == expected


def test_dataloaders_reset_and_attach(tmpdir):
"""
Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before
attaching the new one.
"""
dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))
dataloader_2 = DataLoader(dataset=RandomDataset(32, 64))
dataloader_3 = DataLoader(dataset=RandomDataset(32, 64))
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

# 1st fit
trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1)
assert trainer.train_dataloader.loaders is dataloader_0
assert trainer.val_dataloaders[0] is dataloader_1
# 2nd fit
trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
assert trainer.train_dataloader.loaders is dataloader_2
assert trainer.val_dataloaders[0] is dataloader_3

# 1st validate
trainer.validate(model, dataloaders=dataloader_0)
assert trainer.val_dataloaders[0] is dataloader_0
# 2nd validate
trainer.validate(model, dataloaders=dataloader_1)
assert trainer.val_dataloaders[0] is dataloader_1

# 1st test
trainer.test(model, dataloaders=dataloader_0)
assert trainer.test_dataloaders[0] is dataloader_0
# 2nd test
trainer.test(model, dataloaders=dataloader_1)
assert trainer.test_dataloaders[0] is dataloader_1

# 1st predict
trainer.predict(model, dataloaders=dataloader_0)
assert trainer.predict_dataloaders[0] is dataloader_0
# 2nd predict
trainer.predict(model, dataloaders=dataloader_1)
assert trainer.predict_dataloaders[0] is dataloader_1


def test_replace_sampler_with_multiprocessing_context(tmpdir):
"""
This test verifies that replace_sampler conserves multiprocessing context
Expand Down

0 comments on commit 1bfa29a

Please sign in to comment.