diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 1264fa458ca41..8608e84ab21f6 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -109,7 +109,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # optionally can be set by user
         self._example_input_array = None
         self._current_fx_name: Optional[str] = None
-        self._current_dataloader_idx: Optional[int] = None
         self._automatic_optimization: bool = True
         self._truncated_bptt_steps: int = 0
         self._param_requires_grad_state = {}
@@ -419,7 +418,6 @@ def log(
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
             add_dataloader_idx=add_dataloader_idx,
-            dataloader_idx=self._current_dataloader_idx,
             batch_size=batch_size,
             sync_dist=sync_dist and distributed_available(),
             sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index d9f6ea8c0b181..5971cc9df776f 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -100,14 +100,16 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         """Performs evaluation on one single dataloader."""
         void(*args, **kwargs)
 
-        dataloader_idx: int = self.current_dataloader_idx
+        dataloader_idx = self.current_dataloader_idx
         dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
         self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
             dataloader, dataloader_idx=dataloader_idx
         )
         dl_max_batches = self._max_batches[dataloader_idx]
 
-        dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
+        dl_outputs = self.epoch_loop.run(
+            dataloader, dataloader_idx if self.num_dataloaders > 1 else None, dl_max_batches
+        )
 
         # store batch level output per dataloader
         self._outputs.append(dl_outputs)
@@ -212,17 +214,13 @@ def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
         # inform logger the batch loop has finished
         self.trainer.logger_connector.epoch_end_reached()
 
-        # call the model epoch end
-        model = self.trainer.lightning_module
-
-        # unset dataloader_idx in model
-        model._current_dataloader_idx = None
-
         # with a single dataloader don't pass a 2D list
         output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
             outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs
         )
 
+        # call the model epoch end
+        model = self.trainer.lightning_module
         if self.trainer.testing:
             if is_overridden("test_epoch_end", model):
                 model._current_fx_name = "test_epoch_end"
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index 70fbefd1d1cc6..c0e76cba60ffe 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -48,7 +48,6 @@ def __init__(self) -> None:
 
         self._outputs: EPOCH_OUTPUT = []
         self._dl_max_batches = 0
-        self._num_dataloaders = 0
         self._dataloader_iter: Optional[Iterator] = None
         self._data_fetcher: Optional[AbstractDataFetcher] = None
         self._dataloader_state_dict: Dict[str, Any] = {}
@@ -61,7 +60,6 @@ def done(self) -> bool:
     def reset(self) -> None:
         """Resets the loop's internal state."""
         self._dl_max_batches = 0
-        self._num_dataloaders = 0
         self._data_fetcher = None
         self._outputs = []
 
@@ -71,7 +69,7 @@ def reset(self) -> None:
             self.batch_progress.reset_on_restart()
 
     def on_run_start(  # type: ignore[override]
-        self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+        self, data_fetcher: AbstractDataFetcher, dataloader_idx: Optional[int], dl_max_batches: int
     ) -> None:
         """Adds the passed arguments to the loop's state if necessary.
 
@@ -79,18 +77,16 @@ def on_run_start(  # type: ignore[override]
             data_fetcher: the current data_fetcher wrapping the dataloader
             dataloader_idx: index of the current dataloader
             dl_max_batches: maximum number of batches the dataloader can produce
-            num_dataloaders: the total number of dataloaders
         """
         void(dataloader_idx)
         self._dl_max_batches = dl_max_batches
-        self._num_dataloaders = num_dataloaders
         self._data_fetcher = data_fetcher
 
         self._reload_dataloader_state_dict(data_fetcher)
         self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_progress.current.ready)
 
     def advance(  # type: ignore[override]
-        self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+        self, data_fetcher: AbstractDataFetcher, dataloader_idx: Optional[int], dl_max_batches: int
     ) -> None:
         """Calls the evaluation step with the corresponding hooks and updates the logger connector.
 
@@ -98,12 +94,11 @@ def advance(  # type: ignore[override]
             data_fetcher: iterator over the dataloader
             dataloader_idx: index of the current dataloader
             dl_max_batches: maximum number of batches the dataloader can produce
-            num_dataloaders: the total number of dataloaders
 
         Raises:
             StopIteration: If the current batch is None
         """
-        void(dl_max_batches, num_dataloaders)
+        void(dl_max_batches)
 
         assert self._dataloader_iter is not None
         batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
@@ -113,24 +108,27 @@ def advance(  # type: ignore[override]
 
         if not data_fetcher.store_on_device:
             with self.trainer.profiler.profile("evaluation_batch_to_device"):
-                batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)
+                batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=(dataloader_idx or 0))
 
         self.batch_progress.increment_ready()
 
+        # configure step_kwargs
+        kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)
+
         # hook
-        self._on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
+        self._on_evaluation_batch_start(**kwargs)
 
         self.batch_progress.increment_started()
 
         # lightning module methods
         with self.trainer.profiler.profile("evaluation_step_and_end"):
-            output = self._evaluation_step(batch, batch_idx, dataloader_idx)
+            output = self._evaluation_step(**kwargs)
             output = self._evaluation_step_end(output)
 
         self.batch_progress.increment_processed()
 
         # track loss history
-        self._on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)
+        self._on_evaluation_batch_end(output, **kwargs)
 
         self.batch_progress.increment_completed()
 
@@ -208,7 +206,7 @@ def _num_completed_batches_reached(self) -> bool:
     def _has_completed(self) -> bool:
         return self.batch_progress.current.ready == self.batch_progress.current.completed
 
-    def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
+    def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """The evaluation step (validation_step or test_step depending on the trainer's state).
 
         Args:
@@ -219,17 +217,14 @@ def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> O
         Returns:
             the outputs of the step
         """
-        # configure step_kwargs
-        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)
-
         if self.trainer.testing:
             self.trainer.lightning_module._current_fx_name = "test_step"
             with self.trainer.profiler.profile("test_step"):
-                output = self.trainer.accelerator.test_step(*step_kwargs.values())
+                output = self.trainer.accelerator.test_step(*kwargs.values())
         else:
             self.trainer.lightning_module._current_fx_name = "validation_step"
             with self.trainer.profiler.profile("validation_step"):
-                output = self.trainer.accelerator.validation_step(*step_kwargs.values())
+                output = self.trainer.accelerator.validation_step(*kwargs.values())
 
         return output
 
@@ -239,7 +234,7 @@ def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPU
         output = self.trainer.call_hook(hook_name, *args, **kwargs)
         return output
 
-    def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
         """Calls the ``on_{validation/test}_batch_start`` hook.
 
         Args:
@@ -250,19 +245,15 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         Raises:
             AssertionError: If the number of dataloaders is None (has not yet been set).
         """
-        self.trainer.logger_connector.on_batch_start(batch_idx, batch)
-
-        assert self._num_dataloaders is not None
-        self.trainer.logger_connector.on_evaluation_batch_start(dataloader_idx, self._num_dataloaders)
+        self.trainer.logger_connector.on_batch_start(**kwargs)
+        kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
 
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook("on_test_batch_start", *kwargs.values())
         else:
-            self.trainer.call_hook("on_validation_batch_start", batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook("on_validation_batch_start", *kwargs.values())
 
-    def _on_evaluation_batch_end(
-        self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
         """The ``on_{validation/test}_batch_end`` hook.
 
         Args:
@@ -271,12 +262,13 @@ def _on_evaluation_batch_end(
             batch_idx: The index of the current batch
             dataloader_idx: Index of the dataloader producing the current batch
         """
+        kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
         hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
-        self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook(hook_name, output, *kwargs.values())
 
         self.trainer.logger_connector.on_batch_end()
 
-    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]:
+    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> Dict[str, Union[Any, int]]:
         """Helper function to build the arguments for the current step.
 
         Args:
@@ -289,13 +281,8 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
         """
         # make dataloader_idx arg in validation_step optional
         step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
-
-        multiple_val_loaders = not self.trainer.testing and self._num_dataloaders > 1
-        multiple_test_loaders = self.trainer.testing and self._num_dataloaders > 1
-
-        if multiple_test_loaders or multiple_val_loaders:
+        if dataloader_idx is not None:
             step_kwargs["dataloader_idx"] = dataloader_idx
-
         return step_kwargs
 
     @lru_cache(1)
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index ad5fa1b649b23..086f5b17b8951 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -158,7 +158,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[ov
 
         self.batch_progress.increment_ready()
 
-        self.trainer.logger_connector.on_batch_start(batch_idx, batch)
+        self.trainer.logger_connector.on_batch_start(batch, batch_idx)
 
         if batch is None:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 014ce9623e9dc..20a59fb440357 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -139,17 +139,12 @@ def _increment_eval_log_step(self) -> None:
         elif self.trainer.state.stage is RunningStage.TESTING:
             self._test_log_step += 1
 
-    def on_evaluation_batch_start(self, dataloader_idx: int, num_dataloaders: int) -> None:
-        model = self.trainer.lightning_module
-        # set dataloader_idx only if multiple ones
-        model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
-
     def update_eval_step_metrics(self) -> None:
+        assert not self._epoch_end_reached
         if self.trainer.sanity_checking:
             return
 
         # logs user requested information to logger
-        assert not self._epoch_end_reached
         self.log_metrics(self.metrics["log"], step=self._eval_log_step)
 
         # increment the step even if nothing was logged
@@ -259,23 +254,29 @@ def _log_gpus_metrics(self) -> None:
     def on_epoch_start(self) -> None:
         self._epoch_end_reached = False
 
-    def on_batch_start(self, batch_idx: int, batch: Any) -> None:
+    def on_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> None:
         self._batch_idx = batch_idx
         self._epoch_end_reached = False
 
-        assert self.trainer._results is not None
+        results = self.trainer._results
+        assert results is not None
         # attach reference to the new batch and remove the cached batch_size
-        self.trainer._results.batch = batch
-        self.trainer._results.batch_size = None
+        results.batch = batch
+        results.batch_size = None
+        results.dataloader_idx = dataloader_idx
 
     def epoch_end_reached(self) -> None:
         self._epoch_end_reached = True
         self._batch_idx = None
         self._split_idx = None
-        assert self.trainer._results is not None
 
     def on_epoch_end(self) -> None:
         assert self._epoch_end_reached
+        results = self.trainer._results
+        assert results is not None
+        # we need to reset this index before the `self.metrics` call below
+        results.dataloader_idx = None
+
         metrics = self.metrics
         self._progress_bar_metrics.update(metrics["pbar"])
         self._callback_metrics.update(metrics["callback"])
@@ -308,8 +309,9 @@ def reset_metrics(self) -> None:
         self._callback_metrics = {}
 
     def reset_results(self) -> None:
-        if self.trainer._results is not None:
-            self.trainer._results.reset()
+        results = self.trainer._results
+        if results is not None:
+            results.reset()
 
         self._batch_idx = None
         self._split_idx = None
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 47375088430ea..4878099afc524 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -393,6 +393,7 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] =
         self.device: Optional[Union[str, torch.device]] = device
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
+        self.dataloader_idx: Optional[int] = None
 
     @property
     def result_metrics(self) -> List[ResultMetric]:
@@ -436,7 +437,6 @@ def log(
         sync_dist_fn: Callable = _Sync.no_op,
         sync_dist_group: Optional[Any] = None,
         add_dataloader_idx: bool = True,
-        dataloader_idx: Optional[int] = None,
         batch_size: Optional[int] = None,
         metric_attribute: Optional[str] = None,
         rank_zero_only: bool = False,
@@ -453,9 +453,9 @@ def log(
         # storage key
         key = f"{fx}.{name}"
         # add dataloader_suffix to both key and fx
-        if add_dataloader_idx and dataloader_idx is not None:
-            key += f".{dataloader_idx}"
-            fx += f".{dataloader_idx}"
+        if add_dataloader_idx and self.dataloader_idx is not None:
+            key += f".{self.dataloader_idx}"
+            fx += f".{self.dataloader_idx}"
 
         meta = _Metadata(
             fx=fx,
@@ -467,7 +467,7 @@ def log(
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
             add_dataloader_idx=add_dataloader_idx,
-            dataloader_idx=dataloader_idx,
+            dataloader_idx=self.dataloader_idx,
             metric_attribute=metric_attribute,
         )
         meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index bb34db9c0d61f..2bed020c501bd 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1430,7 +1430,6 @@ def _call_teardown_hook(self) -> None:
         self.call_hook("teardown", stage=fn)
 
         self.lightning_module._current_fx_name = None
-        self.lightning_module._current_dataloader_idx = None
         # these could have become stale if metrics are defined in `setup`
         self.lightning_module._metric_attributes = None
 
diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
index 72eeb197e9e57..ed4f5169cb1cb 100644
--- a/tests/loops/test_loop_state_dict.py
+++ b/tests/loops/test_loop_state_dict.py
@@ -81,6 +81,7 @@ def test_loops_state_dict_structure():
            "epoch_loop.val_loop._results": {
                "batch": None,
                "batch_size": None,
+               "dataloader_idx": None,
                "training": False,
                "device": None,
                "items": {},
@@ -88,6 +89,7 @@ def test_loops_state_dict_structure():
            "epoch_loop._results": {
                "batch": None,
                "batch_size": None,
+               "dataloader_idx": None,
                "training": True,
                "device": None,
                "items": {},
@@ -109,6 +111,7 @@ def test_loops_state_dict_structure():
            "_results": {
                "batch": None,
                "batch_size": None,
+               "dataloader_idx": None,
                "training": False,
                "device": None,
                "items": {},
@@ -126,6 +129,7 @@ def test_loops_state_dict_structure():
            "_results": {
                "batch": None,
                "batch_size": None,
+               "dataloader_idx": None,
                "training": False,
                "device": None,
                "items": {},
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index 82303dbf36865..04ef7568e13ca 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -393,8 +393,8 @@ def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars):
                 pl_module.log(custom_func_name, self.count, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar)
 
                 num_dl_ext = ""
-                if pl_module._current_dataloader_idx is not None:
-                    dl_idx = pl_module._current_dataloader_idx
+                dl_idx = pl_module.trainer._results.dataloader_idx
+                if dl_idx is not None:
                     num_dl_ext = f"/dataloader_idx_{dl_idx}"
                     func_name += num_dl_ext
 
@@ -471,13 +471,13 @@ def test_dataloader(self):
     assert cb.funcs_called_count["on_test_batch_end"] == 4
     assert cb.funcs_called_count["on_test_epoch_end"] == 1
 
-    callback_metrics_keys = list(trainer.callback_metrics)
-    for func_name in cb.callback_funcs_called.keys():
-        is_in = False
-        for callback_metrics_key in callback_metrics_keys:
-            if func_name in callback_metrics_key:
-                is_in = True
-        assert is_in, (func_name, callback_metrics_keys)
+    callback_metrics = trainer.callback_metrics
+    for func_name in cb.callback_funcs_called:
+        for key in callback_metrics:
+            if func_name in key:
+                break
+        else:
+            assert False, (func_name, list(callback_metrics))
 
     def get_expected(on_epoch, values):
         reduction = np.mean if on_epoch else np.max
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index dc7070c9c2ac0..f904cc7b196ec 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1876,11 +1876,9 @@ def test_module_current_fx_attributes_reset(tmpdir):
 
     trainer.fit(model)
     assert model._current_fx_name is None
-    assert model._current_dataloader_idx is None
 
     trainer.test(model)
     assert model._current_fx_name is None
-    assert model._current_dataloader_idx is None
 
 
 def test_exception_when_lightning_module_is_not_set_on_trainer():
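
For context, here is a minimal user-level sketch of the behaviour this diff keeps intact (it is not part of the change itself; the model, dataset sizes, and Trainer flags below are illustrative): `validation_step` still receives `dataloader_idx` only when more than one dataloader is configured, `self.log` still appends the `/dataloader_idx_{i}` suffix via `add_dataloader_idx`, and the active index is now tracked on the results collection (`trainer._results.dataloader_idx`) rather than on the `LightningModule`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TwoValLoadersModel(LightningModule):
    """Toy module with two validation dataloaders (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    # `dataloader_idx` is only passed when more than one val dataloader is used;
    # with a single dataloader the argument can be omitted from the signature.
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        loss = self.layer(batch[0]).sum()
        # with `add_dataloader_idx=True` (the default) the logged keys become
        # "val_loss/dataloader_idx_0" and "val_loss/dataloader_idx_1"
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)

    def val_dataloader(self):
        dataset = TensorDataset(torch.randn(8, 32))
        return [DataLoader(dataset, batch_size=4), DataLoader(dataset, batch_size=4)]


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, logger=False, checkpoint_callback=False)
    trainer.fit(TwoValLoadersModel())
```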