diff --git a/CHANGELOG.md b/CHANGELOG.md index ae0515cf22703..ba00fb2243979 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -150,10 +150,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - -- +- Fixed `Trainer(move_metrics_to_cpu=True)` not moving the evaluation logged results to CPU ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) -- +- Fixed the `{validation,test}_step` outputs getting moved to CPU with `Trainer(move_metrics_to_cpu=True)` ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index b4660c96a0989..102603f20302b 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -24,7 +24,6 @@ from pytorch_lightning.trainer.progress import BatchProgress from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher -from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -134,10 +133,13 @@ def advance( self.trainer.logger_connector.update_eval_step_metrics() # track epoch level outputs - if self._should_track_batch_outputs_for_epoch_end(): - output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu) - if output is not None: - self.outputs.append(output) + if self._should_track_batch_outputs_for_epoch_end() and output is not None: + self.outputs.append(output) + + if self.trainer.move_metrics_to_cpu: + # the evaluation step outputs are not moved as they are not considered "metrics" + assert self.trainer._results is not None + self.trainer._results.cpu() if not self.batch_progress.is_last_batch: # 
if fault tolerant is enabled and process has been notified, exit. diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 88229effbc8c9..66c91eaf15f1b 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -26,6 +26,7 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf def test__validation_step__log(tmpdir): @@ -699,3 +700,26 @@ def test_filter_metrics_for_dataloader(kwargs, expected): """Logged metrics should only include metrics from the concerned dataloader.""" actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs) assert actual == expected + + +@RunIf(min_gpus=1) +def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir): + class TestModel(BoringModel): + def validation_step(self, *args): + x = torch.tensor(2.0, requires_grad=True, device=self.device) + y = x * 2 + assert x.requires_grad is True + assert y.grad_fn is None # disabled by validation + + self.log("foo", y) + return y + + def validation_epoch_end(self, outputs): + # the step outputs were not moved + assert all(o.device == self.device for o in outputs), outputs + # but the logging results were + assert self.trainer.callback_metrics["foo"].device.type == "cpu" + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, gpus=1) + trainer.validate(model, verbose=False)