Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clear dataloader references before attaching new dataloaders to Trainer #8442

Merged
merged 5 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed progress bar updates for Pod Training ([#8258](https://github.com/PyTorchLightning/pytorch-lightning/pull/8258))


- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}` runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))


## [1.3.8] - 2021-07-01

### Fixed
Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,19 @@ def attach_dataloaders(
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloaders is not None:
self.trainer.train_dataloader = None
model.train_dataloader = _PatchDataLoader(train_dataloaders)

if val_dataloaders is not None:
self.trainer.val_dataloaders = None
model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
self.trainer.test_dataloaders = None
model.test_dataloader = _PatchDataLoader(test_dataloaders)

if predict_dataloaders is not None:
self.trainer.predict_dataloaders = None
model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

def attach_datamodule(
Expand Down
43 changes: 43 additions & 0 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,49 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):
assert call['name'] == expected


def test_dataloaders_reset_and_attach(tmpdir):
    """
    Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset the dataloaders before
    attaching the new ones.
    """
    # Four distinct loader objects so that identity checks (`is`) can tell
    # which one the trainer is currently holding.
    dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
    dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))
    dataloader_2 = DataLoader(dataset=RandomDataset(32, 64))
    dataloader_3 = DataLoader(dataset=RandomDataset(32, 64))
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

    # Each entry point is invoked twice on the same trainer with different
    # loaders; after the second call the trainer must reference the loader
    # passed to that call, not the one left over from the first.

    # 1st fit
    trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1)
    assert trainer.train_dataloader.loaders is dataloader_0
    assert trainer.val_dataloaders[0] is dataloader_1
    # 2nd fit
    trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
    assert trainer.train_dataloader.loaders is dataloader_2
    assert trainer.val_dataloaders[0] is dataloader_3

    # 1st validate
    trainer.validate(model, dataloaders=dataloader_0)
    assert trainer.val_dataloaders[0] is dataloader_0
    # 2nd validate
    trainer.validate(model, dataloaders=dataloader_1)
    assert trainer.val_dataloaders[0] is dataloader_1

    # 1st test
    trainer.test(model, dataloaders=dataloader_0)
    assert trainer.test_dataloaders[0] is dataloader_0
    # 2nd test
    trainer.test(model, dataloaders=dataloader_1)
    assert trainer.test_dataloaders[0] is dataloader_1

    # 1st predict
    trainer.predict(model, dataloaders=dataloader_0)
    assert trainer.predict_dataloaders[0] is dataloader_0
    # 2nd predict
    trainer.predict(model, dataloaders=dataloader_1)
    assert trainer.predict_dataloaders[0] is dataloader_1


def test_replace_sampler_with_multiprocessing_context(tmpdir):
"""
This test verifies that replace_sampler conserves multiprocessing context
Expand Down