diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index f39ecdb2bdc25..56293dbad729f 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -42,6 +42,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646)) +- Added support for `predict_step(dataloader_iter, batch_index)` ([#16726](https://github.com/Lightning-AI/lightning/pull/16726)) + + +- Added support for arbitrary iterables as dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726)) + + - Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743), [#16784](https://github.com/Lightning-AI/lightning/pull/16784)) ### Changed @@ -87,6 +93,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables` ([#16743](https://github.com/Lightning-AI/lightning/pull/16743)) +- The top-level loops now own the data sources and combined dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726)) + + +- The `trainer.*_dataloader` properties now return what the user returned in their `LightningModule.*_dataloader()` hook ([#16726](https://github.com/Lightning-AI/lightning/pull/16726)) + + - The `dataloader_idx` argument is now optional for the `on_{validation,test,predict}_batch_{start,end}` hooks. Remove it or default it to 0 if you don't use multiple dataloaders ([#16753](https://github.com/Lightning-AI/lightning/pull/16753)) @@ -210,6 +222,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
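> As a hedged illustration of the additions above — arbitrary iterables accepted as dataloaders and the `predict_step(dataloader_iter, ...)` flavor — a minimal sketch; the module, layer sizes, and data here are assumptions for demonstration, not code from this PR:

```python
import torch
from lightning.pytorch import LightningModule


class DemoModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        # `batch` comes straight from the arbitrary iterable below
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # a plain list of tensors is now a valid "dataloader" (any iterable works)
        return [torch.randn(2, 1) for _ in range(8)]

    def predict_step(self, dataloader_iter, batch_idx):
        # new flavor: the loop hands over the iterator and the step pulls
        # batches itself instead of receiving a pre-fetched `batch`
        batch = next(dataloader_iter)
        return self.layer(batch)
```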
* The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664)) +- Removed the `DataLoaderLoop`, `EvaluationEpochLoop`, and `PredictionEpochLoop` classes ([#16726](https://github.com/Lightning-AI/lightning/pull/16726)) + + +- Removed `trainer.reset_*_dataloader()` methods in favor of `Loop.setup_data()` for the top-level loops ([#16726](https://github.com/Lightning-AI/lightning/pull/16726)) + + - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172)) * Removed the `LightningModule.truncated_bptt_steps` attribute * Removed the `LightningModule.tbptt_split_batch` hook diff --git a/src/lightning/pytorch/callbacks/batch_size_finder.py b/src/lightning/pytorch/callbacks/batch_size_finder.py index 84e04378709f0..f9070c6d00121 100644 --- a/src/lightning/pytorch/callbacks/batch_size_finder.py +++ b/src/lightning/pytorch/callbacks/batch_size_finder.py @@ -128,13 +128,8 @@ def __init__( def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: if trainer._accelerator_connector.is_distributed: raise MisconfigurationException("The Batch size finder is not supported with distributed strategies.") - - running_stage = trainer.state.stage - assert running_stage is not None - dl_source = getattr(trainer._data_connector, f"_{running_stage.dataloader_prefix}_dataloader_source") - # TODO: check if this can be enabled (#4040) - if not trainer._data_connector._train_dataloader_source.is_module(): + if not trainer.fit_loop._data_source.is_module(): raise MisconfigurationException( "The Batch size finder cannot be used with dataloaders passed directly to `.fit()`. Please disable" " the feature or incorporate the dataloader into your LightningModule or LightningDataModule." @@ -142,10 +137,16 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O # TODO: Add support for multiple eval dataloader if stage != "fit": - dataloaders = dl_source.dataloader() - if isinstance(dataloaders, list) and len(dataloaders) > 1: + loop = trainer._active_loop + assert loop is not None + loop.setup_data() + combined_loader = loop._combined_loader + assert combined_loader is not None + if len(combined_loader._flattened) > 1: + stage = trainer.state.stage + assert stage is not None raise MisconfigurationException( - f"The Batch size finder cannot be used with multiple {running_stage.dataloader_prefix} dataloaders." + f"The Batch size finder cannot be used with multiple {stage.dataloader_prefix} dataloaders." 
) if not lightning_hasattr(pl_module, self._batch_arg_name): @@ -167,7 +168,6 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: new_size = _scale_batch_size( trainer, - pl_module, self._mode, self._steps_per_trial, self._init_val, diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index c3e7759573246..0f19c771027d5 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -140,7 +140,7 @@ def on_predict_batch_end( ) -> None: if not self.interval.on_batch: return - batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices + batch_indices = trainer.predict_loop.current_batch_indices self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx) def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: diff --git a/src/lightning/pytorch/loops/__init__.py b/src/lightning/pytorch/loops/__init__.py index e5b25857eccfe..d985fdf8dedda 100644 --- a/src/lightning/pytorch/loops/__init__.py +++ b/src/lightning/pytorch/loops/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from lightning.pytorch.loops.loop import _Loop # noqa: F401 isort: skip (avoids circular imports) -from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401 -from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401 +from lightning.pytorch.loops.epoch import _TrainingEpochLoop # noqa: F401 +from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop # noqa: F401 from lightning.pytorch.loops.fit_loop import _FitLoop # noqa: F401 from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization # noqa: F401 +from lightning.pytorch.loops.prediction_loop import _PredictionLoop # noqa: F401 diff --git a/src/lightning/pytorch/loops/dataloader/__init__.py b/src/lightning/pytorch/loops/dataloader/__init__.py deleted file mode 100644 index c4ae2488ec0ec..0000000000000 --- a/src/lightning/pytorch/loops/dataloader/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from lightning.pytorch.loops.dataloader.dataloader_loop import _DataLoaderLoop # noqa: F401 -from lightning.pytorch.loops.dataloader.evaluation_loop import _EvaluationLoop # noqa: F401 -from lightning.pytorch.loops.dataloader.prediction_loop import _PredictionLoop # noqa: F401 diff --git a/src/lightning/pytorch/loops/dataloader/dataloader_loop.py b/src/lightning/pytorch/loops/dataloader/dataloader_loop.py deleted file mode 100644 index 12ff40277be0a..0000000000000 --- a/src/lightning/pytorch/loops/dataloader/dataloader_loop.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import abstractmethod -from typing import Sequence - -from torch.utils.data import DataLoader - -import lightning.pytorch as pl -from lightning.pytorch.loops.loop import _Loop -from lightning.pytorch.loops.progress import DataLoaderProgress - - -class _DataLoaderLoop(_Loop): - """Base class to loop over all dataloaders.""" - - def __init__(self, trainer: "pl.Trainer") -> None: - super().__init__(trainer) - self.dataloader_progress = DataLoaderProgress() - - @property - @abstractmethod - def dataloaders(self) -> Sequence[DataLoader]: - """Returns the dataloaders to loop over.""" - - @property - def current_dataloader_idx(self) -> int: - """Returns the index of the current dataloader.""" - return self.dataloader_progress.current.ready - 1 - - @property - def current_dataloader(self) -> DataLoader: - """Returns the current dataloader.""" - return self.dataloaders[self.current_dataloader_idx] - - @property - def num_dataloaders(self) -> int: - """Returns the number of dataloaders present.""" - return len(self.dataloaders) if self.dataloaders is not None else 0 - - @property - def done(self) -> bool: - """Returns whether all dataloaders have been processed.""" - return self.dataloader_progress.current.completed >= self.num_dataloaders - - def reset(self) -> None: - """Resets the internal state.""" - if not self.restarting: - self.dataloader_progress.reset_on_run() - else: - self.dataloader_progress.reset_on_restart() - - def on_advance_start(self) -> None: - self.dataloader_progress.increment_ready() - - def on_advance_end(self) -> None: - self.dataloader_progress.increment_completed() diff --git a/src/lightning/pytorch/loops/dataloader/prediction_loop.py b/src/lightning/pytorch/loops/dataloader/prediction_loop.py deleted file mode 100644 index 9b292959515e3..0000000000000 --- a/src/lightning/pytorch/loops/dataloader/prediction_loop.py +++ /dev/null @@ -1,178 +0,0 @@ -from typing import Any, List, Optional, Sequence, Union - -from torch.utils.data import DataLoader - -import lightning.pytorch as pl -from lightning.pytorch.loops.dataloader.dataloader_loop import _DataLoaderLoop -from lightning.pytorch.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop -from lightning.pytorch.loops.utilities import _no_grad_context, _set_sampler_epoch -from lightning.pytorch.strategies import DDPSpawnStrategy -from lightning.pytorch.trainer import call 
-from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.types import _PREDICT_OUTPUT - - -class _PredictionLoop(_DataLoaderLoop): - """Top-level loop where prediction starts. - - It simply iterates over each predict dataloader from one to the next by calling ``_PredictionEpochLoop.run()`` in - its ``advance()`` method. - """ - - def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None: - super().__init__(trainer) - self.epoch_batch_indices: List[List[List[int]]] = [] # used by PredictionWriter - self.epoch_loop = _PredictionEpochLoop(trainer) - self.inference_mode = inference_mode - - self._results = None # for `trainer._results` access - self._predictions: List[List[Any]] = [] # num_dataloaders x batches - self._return_predictions: bool = False - - @property - def return_predictions(self) -> bool: - """Whether to return the predictions or not.""" - return self._return_predictions - - @return_predictions.setter - def return_predictions(self, return_predictions: Optional[bool] = None) -> None: - # `DDPSpawnStrategy` plugins and derivatives don't support return predictions. - is_ddp_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy) - if return_predictions and is_ddp_spawn: - raise MisconfigurationException( - "`return_predictions` should be set to `False` when using the `DDPSpawnStrategy` or children class. " - f"Found {return_predictions} with strategy {type(self.trainer.strategy)}." - ) - # For non `DDPSpawnStrategy` plugin, the `return_predictions` is True by default unless user decide otherwise. - self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions - self.epoch_loop.return_predictions = self._return_predictions - - @property - def predictions(self) -> List[Any]: - """The cached predictions.""" - if self._predictions == []: - return self._predictions - return self._predictions[0] if self.num_dataloaders == 1 else self._predictions - - @property - def num_dataloaders(self) -> int: - """Returns the number of prediction dataloaders.""" - # case where user does: - # return dl1, dl2 - dataloaders = self.dataloaders - length = len(dataloaders) - if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)): - length = len(dataloaders[0]) - return length - - @property - def max_batches(self) -> List[Union[int, float]]: - """The max number of batches this loop will run for each dataloader.""" - return self.trainer.num_predict_batches - - @property - def dataloaders(self) -> Sequence[DataLoader]: - """Returns all prediction dataloaders.""" - dataloaders = self.trainer.predict_dataloaders - return [] if dataloaders is None else dataloaders - - @property - def skip(self) -> bool: - return sum(self.max_batches) == 0 - - @_no_grad_context - def run(self) -> Optional[_PREDICT_OUTPUT]: - if self.skip: - return None - self.reset() - self.on_run_start() - while not self.done: - try: - self.on_advance_start() - self.advance() - self.on_advance_end() - self._restarting = False - except StopIteration: - break - self._restarting = False - return self.on_run_end() - - def reset(self) -> None: - """Resets the internal state of the loop for a new run.""" - self._predictions = [] - self.epoch_batch_indices = [] - - super().reset() - # when restarting, if we are running twice, since there's no concept of `max_epochs` we need to reset the - # current state when the loop has finished running - if self.done: - self.dataloader_progress.reset_on_run() - - def on_run_start(self) 
-> None: - """Calls ``_on_predict_model_eval``, ``_on_predict_start`` and ``_on_predict_epoch_start`` hooks.""" - trainer = self.trainer - call._call_lightning_module_hook(trainer, "on_predict_model_eval") - trainer.lightning_module.zero_grad() - self._on_predict_start() - self._on_predict_epoch_start() - - def advance(self) -> None: - """Predicts one entire dataloader.""" - dataloader = self.current_dataloader - if dataloader is not None: - _set_sampler_epoch(dataloader, self.trainer.fit_loop.epoch_progress.current.processed) - dataloader = self.trainer.strategy.process_dataloader(dataloader) - dataloader_iter = enumerate(dataloader) - dl_max_batches = self.max_batches[self.current_dataloader_idx] - - dl_predictions, dl_batch_indices = self.epoch_loop.run( - dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders - ) - self._predictions.append(dl_predictions) - self.epoch_batch_indices.append(dl_batch_indices) - - def on_run_end(self) -> Optional[_PREDICT_OUTPUT]: - """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders.""" - results = self._on_predict_epoch_end() - self._on_predict_end() - return results - - def teardown(self) -> None: - pass - - def _on_predict_start(self) -> None: - """Calls ``on_predict_start`` hooks.""" - trainer = self.trainer - call._call_callback_hooks(trainer, "on_predict_start") - call._call_lightning_module_hook(trainer, "on_predict_start") - call._call_strategy_hook(trainer, "on_predict_start") - - def _on_predict_epoch_start(self) -> None: - """Calls ``on_predict_epoch_start`` hooks.""" - trainer = self.trainer - call._call_callback_hooks(trainer, "on_predict_epoch_start") - call._call_lightning_module_hook(trainer, "on_predict_epoch_start") - - def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: - """Calls ``on_predict_epoch_end`` hook. - - Returns: - the results for all dataloaders - """ - trainer = self.trainer - call._call_callback_hooks(trainer, "on_predict_epoch_end") - call._call_lightning_module_hook(trainer, "on_predict_epoch_end") - - if self.return_predictions: - return self.predictions - - def _on_predict_end(self) -> None: - """Resets previous gradient status and calls ``on_predict_end`` hook.""" - if not self.return_predictions: - self._predictions = [] - self.epoch_batch_indices = [] - - trainer = self.trainer - call._call_callback_hooks(trainer, "on_predict_end") - call._call_lightning_module_hook(trainer, "on_predict_end") - call._call_strategy_hook(trainer, "on_predict_end") diff --git a/src/lightning/pytorch/loops/epoch/__init__.py b/src/lightning/pytorch/loops/epoch/__init__.py index 990f935ac212c..13a5f34b74a71 100644 --- a/src/lightning/pytorch/loops/epoch/__init__.py +++ b/src/lightning/pytorch/loops/epoch/__init__.py @@ -12,6 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lightning.pytorch.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop # noqa: F401 -from lightning.pytorch.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop # noqa: F401 from lightning.pytorch.loops.epoch.training_epoch_loop import _TrainingEpochLoop # noqa: F401 diff --git a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py deleted file mode 100644 index 007068febf2ff..0000000000000 --- a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright The Lightning AI team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import OrderedDict -from typing import Any, Optional, Union - -import lightning.pytorch as pl -from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher -from lightning.pytorch.loops.loop import _Loop -from lightning.pytorch.loops.progress import BatchProgress -from lightning.pytorch.trainer import call -from lightning.pytorch.trainer.states import TrainerFn -from lightning.pytorch.utilities.exceptions import SIGTERMException -from lightning.pytorch.utilities.types import STEP_OUTPUT - - -class _EvaluationEpochLoop(_Loop): - """This is the loop performing the evaluation. - - It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's current - state). - """ - - def __init__(self, trainer: "pl.Trainer") -> None: - super().__init__(trainer) - self.batch_progress = BatchProgress() - - self._dl_max_batches: Union[int, float] = 0 - self._data_fetcher: Optional[_DataFetcher] = None - self._dl_batch_idx = [0] - - @property - def done(self) -> bool: - """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" - return self.batch_progress.current.completed >= self._dl_max_batches - - def run(self, data_fetcher: _DataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None: - self.reset() - self.on_run_start(data_fetcher, dl_max_batches, kwargs) - while not self.done: - try: - self.advance(data_fetcher, kwargs) - self._restarting = False - except StopIteration: - break - self._restarting = False - self.on_run_end() - - def reset(self) -> None: - """Resets the loop's internal state.""" - self._dl_max_batches = 0 - self._data_fetcher = None - - if not self.restarting: - self.batch_progress.reset_on_run() - else: - self.batch_progress.reset_on_restart() - # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we - # need to reset the current state when the loop has finished running - if self.done and self.trainer.state.fn != TrainerFn.FITTING: - self.batch_progress.reset_on_run() - - def on_run_start(self, data_fetcher: _DataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None: - """Adds the passed arguments to the loop's state if necessary. - - Args: - data_fetcher: the current data_fetcher wrapping the dataloader - dl_max_batches: maximum number of batches the dataloader can produce - kwargs: the kwargs passed down to the hooks. 
- """ - self._dl_max_batches = dl_max_batches - # creates the iterator inside the fetcher but returns `self` - self._data_fetcher = iter(data_fetcher) - # add the previous `fetched` value to properly track `is_last_batch` with no prefetching - data_fetcher.fetched += self.batch_progress.current.ready - - stage = self.trainer.state.stage - assert stage is not None - stage = stage.dataloader_prefix - self._profiler_fetch_action = ( - f"[{self.__class__.__name__}].{stage}_dataloader_idx_{kwargs.get('dataloader_idx', 0)}_next" - ) - data_fetcher._start_profiler = self._on_before_fetch - data_fetcher._stop_profiler = self._on_after_fetch - - def _on_before_fetch(self) -> None: - self.trainer.profiler.start(self._profiler_fetch_action) - - def _on_after_fetch(self) -> None: - self.trainer.profiler.stop(self._profiler_fetch_action) - - def advance( - self, - data_fetcher: _DataFetcher, - kwargs: OrderedDict, - ) -> None: - """Calls the evaluation step with the corresponding hooks and updates the logger connector. - - Args: - data_fetcher: iterator over the dataloader - kwargs: the kwargs passed down to the hooks. - - Raises: - StopIteration: If the current batch is None - """ - batch_idx = ( - data_fetcher.fetched - if isinstance(data_fetcher, _DataLoaderIterDataFetcher) - else self.batch_progress.current.ready - ) - batch = next(data_fetcher) - self.batch_progress.is_last_batch = data_fetcher.done - - dataloader_idx = kwargs.get("dataloader_idx", 0) - batch = self.trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx) - batch = call._call_strategy_hook(self.trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx) - - # configure step_kwargs - kwargs = self._build_kwargs(kwargs, batch, batch_idx) - - self.batch_progress.increment_ready() - - # hook - self._on_evaluation_batch_start(**kwargs) - - self.batch_progress.increment_started() - - # lightning module methods - output = self._evaluation_step(**kwargs) - output = self._evaluation_step_end(output) - - self.batch_progress.increment_processed() - - # track loss history - self._on_evaluation_batch_end(output, **kwargs) - - self.batch_progress.increment_completed() - - # log batch metrics - if not self.trainer.sanity_checking: - self.trainer._logger_connector.update_eval_step_metrics(self._dl_batch_idx[dataloader_idx]) - self._dl_batch_idx[dataloader_idx] += 1 - - if not self.batch_progress.is_last_batch and self.trainer.received_sigterm: - raise SIGTERMException - - def on_run_end(self) -> None: - self._data_fetcher = None - - def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]: - """The evaluation step (validation_step or test_step depending on the trainer's state). - - Args: - batch: The current batch to run through the step. 
- batch_idx: The index of the current batch - dataloader_idx: the index of the dataloader producing the current batch - - Returns: - the outputs of the step - """ - trainer = self.trainer - hook_name = "test_step" if trainer.testing else "validation_step" - return call._call_strategy_hook(trainer, hook_name, *kwargs.values()) - - def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - """Calls the `{validation/test}_step_end` hook.""" - trainer = self.trainer - hook_name = "test_step_end" if trainer.testing else "validation_step_end" - model_output = call._call_lightning_module_hook(trainer, hook_name, *args, **kwargs) - strategy_output = call._call_strategy_hook(trainer, hook_name, *args, **kwargs) - output = strategy_output if model_output is None else model_output - return output - - def _on_evaluation_batch_start(self, **kwargs: Any) -> None: - """Calls the ``on_{validation/test}_batch_start`` hook. - - Args: - batch: The current batch to run through the step - batch_idx: The index of the current batch - dataloader_idx: The index of the dataloader producing the current batch - - Raises: - AssertionError: If the number of dataloaders is None (has not yet been set). - """ - trainer = self.trainer - trainer._logger_connector.on_batch_start(**kwargs) - - hook_name = "on_test_batch_start" if trainer.testing else "on_validation_batch_start" - call._call_callback_hooks(trainer, hook_name, *kwargs.values()) - call._call_lightning_module_hook(trainer, hook_name, *kwargs.values()) - - def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None: - """The ``on_{validation/test}_batch_end`` hook. - - Args: - output: The output of the performed step - batch: The input batch for the step - batch_idx: The index of the current batch - dataloader_idx: Index of the dataloader producing the current batch - """ - trainer = self.trainer - - hook_name = "on_test_batch_end" if trainer.testing else "on_validation_batch_end" - call._call_callback_hooks(trainer, hook_name, output, *kwargs.values()) - call._call_lightning_module_hook(trainer, hook_name, output, *kwargs.values()) - - trainer._logger_connector.on_batch_end() - - def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict: - """Helper method to build the arguments for the current step. - - Args: - kwargs: The kwargs passed down to the hooks. - batch: The current batch to run through the step. - - Returns: - The kwargs passed down to the hooks. 
- """ - kwargs.update(batch=batch, batch_idx=batch_idx) - # `dataloader_idx` should be last so we need to push these to the front - kwargs.move_to_end("batch_idx", last=False) - kwargs.move_to_end("batch", last=False) - return kwargs - - def _reset_dl_batch_idx(self, num_dataloaders: int) -> None: - self._dl_batch_idx = [0] * num_dataloaders diff --git a/src/lightning/pytorch/loops/epoch/prediction_epoch_loop.py b/src/lightning/pytorch/loops/epoch/prediction_epoch_loop.py deleted file mode 100644 index 1e7d944b348d2..0000000000000 --- a/src/lightning/pytorch/loops/epoch/prediction_epoch_loop.py +++ /dev/null @@ -1,190 +0,0 @@ -from collections import OrderedDict -from typing import Any, Dict, Iterator, List, Tuple, Union - -import torch - -import lightning.pytorch as pl -from lightning.fabric.utilities import move_data_to_device -from lightning.pytorch.callbacks import BasePredictionWriter -from lightning.pytorch.loops.loop import _Loop -from lightning.pytorch.loops.progress import Progress -from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper -from lightning.pytorch.trainer import call -from lightning.pytorch.utilities.rank_zero import WarningCache - -warning_cache = WarningCache() - - -class _PredictionEpochLoop(_Loop): - """Loop performing prediction on arbitrary sequentially used dataloaders.""" - - def __init__(self, trainer: "pl.Trainer") -> None: - super().__init__(trainer) - self.return_predictions = False - self.predictions: List[Any] = [] - self.current_batch_indices: List[int] = [] - self.batch_progress = Progress() - - self._dl_max_batches: Union[int, float] = 0 - self._num_dataloaders = 0 - self._warning_cache = WarningCache() - self._seen_batch_indices: List[List[int]] = [] - - @property - def done(self) -> bool: - """Ends prediction when the iteration count exceeds the total number of available batches.""" - return self.batch_progress.current.completed >= self._dl_max_batches - - @property - def should_store_predictions(self) -> bool: - """Whether the predictions should be stored for later usage (e.g. aggregation or returning)""" - prediction_writers = [cb for cb in self.trainer.callbacks if isinstance(cb, BasePredictionWriter)] - any_pred = any(cb.interval.on_epoch for cb in prediction_writers) - return self.return_predictions or any_pred - - def run( - self, - dataloader_iter: Iterator, - dataloader_idx: int, - dl_max_batches: Union[int, float], - num_dataloaders: int, - ) -> Tuple[List[Any], List[List[int]]]: - self.reset() - self.on_run_start(dataloader_idx, dl_max_batches, num_dataloaders) - while not self.done: - try: - self.advance(dataloader_iter, dataloader_idx) - self._restarting = False - except StopIteration: - break - self._restarting = False - return self.on_run_end() - - def reset(self) -> None: - """Resets the loops internal state.""" - self._seen_batch_indices = [] - self.predictions = [] - self.batch_progress.reset_on_run() - - def on_run_start( - self, - dataloader_idx: int, - dl_max_batches: Union[int, float], - num_dataloaders: int, - ) -> None: - """Prepares the loops internal state. 
- - Args: - dataloader_idx: the index of the current dataloader - dl_max_batches: the maximum number of batches the current loader can produce - num_dataloaders: the total number of dataloaders - """ - self._dl_max_batches = dl_max_batches - self._num_dataloaders = num_dataloaders - # this call requires that `self.return_predictions` is set - self._seen_batch_indices = self._get_batch_indices(dataloader_idx) if self.should_store_predictions else [] - - def advance( - self, - dataloader_iter: Iterator, - dataloader_idx: int, - ) -> None: - """Runs one prediction step. - - Args: - dataloader_iter: the iterator over the current dataloader - dataloader_idx: the index of the current dataloader - """ - trainer = self.trainer - - action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next" - with trainer.profiler.profile(action_name): - batch_idx, batch = next(dataloader_iter) - self._seen_batch_indices = self._get_batch_indices(dataloader_idx) if self.should_store_predictions else [] - # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning - self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)] - - if batch is None: - raise StopIteration - - batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx) - batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx) - - self.batch_progress.increment_ready() - - self._predict_step(batch, batch_idx, dataloader_idx) - - def on_run_end(self) -> Tuple[List[Any], List[List[int]]]: - """Returns the predictions and the corresponding batch indices.""" - predictions, all_batch_indices = self.predictions, self._seen_batch_indices - self.predictions, self._seen_batch_indices = [], [] # free memory - return predictions, all_batch_indices - - def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the - predict step. 
- - Args: - batch: the current batch to run the prediction on - batch_idx: the index of the current batch - dataloader_idx: the index of the dataloader producing the current batch - """ - # configure step_kwargs - step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) - - # extract batch_indices and store them - batch_indices = self._get_batch_indices(dataloader_idx) - self.current_batch_indices = batch_indices[batch_idx] if batch_indices else [] - - trainer = self.trainer - call._call_callback_hooks(trainer, "on_predict_batch_start", *step_kwargs.values()) - call._call_lightning_module_hook(trainer, "on_predict_batch_start", *step_kwargs.values()) - - self.batch_progress.increment_started() - - predictions = call._call_strategy_hook(trainer, "predict_step", *step_kwargs.values()) - - self.batch_progress.increment_processed() - - if predictions is None: - self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") - - call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *step_kwargs.values()) - call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *step_kwargs.values()) - - self.batch_progress.increment_completed() - - if self.should_store_predictions: - self.predictions.append(move_data_to_device(predictions, torch.device("cpu"))) - - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]: - """Assembles the keyword arguments for the ``predict_step`` - - Args: - batch: the current batch to run the prediction on - batch_idx: the index of the current batch - dataloader_idx: the index of the dataloader producing the current batch - - Returns: - the dictionary containing all the keyboard arguments for the predict step - """ - step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) - if self._num_dataloaders > 1: - step_kwargs["dataloader_idx"] = dataloader_idx - return step_kwargs - - def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]: - """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our - :class:`~lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper`.""" - # the batch_sampler is not be defined in case of CombinedDataLoaders - assert self.trainer.predict_dataloaders - batch_sampler = getattr( - self.trainer.predict_dataloaders[dataloader_idx], - "batch_sampler", - None, - ) - if isinstance(batch_sampler, IndexBatchSamplerWrapper): - return batch_sampler.seen_batch_indices - - warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.") - return [] diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index 26d801bf21ad2..045c8322b9e52 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -104,7 +104,7 @@ def _is_training_done(self) -> bool: @property def _is_validation_done(self) -> bool: # when we are restarting we want to check whether the val loop has finished - return not self.restarting or self.val_loop.done + return not self.restarting or self.val_loop._has_run @property def done(self) -> bool: @@ -158,13 +158,12 @@ def reset(self) -> None: self.automatic_optimization.optim_progress.reset_on_run() # when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches # seen per epoch, this is useful for tracking when validation is run multiple times 
per epoch - self.val_loop.epoch_loop.batch_progress.total.reset() + self.val_loop.batch_progress.total.reset() def on_run_start(self, data_fetcher: _DataFetcher) -> None: - _ = iter(data_fetcher) # creates the iterator inside the fetcher + iter(data_fetcher) # creates the iterator inside the fetcher # add the previous `fetched` value to properly track `is_last_batch` with no prefetching data_fetcher.fetched += self.batch_progress.current.ready - data_fetcher._start_profiler = self._on_before_fetch data_fetcher._stop_profiler = self._on_after_fetch @@ -246,7 +245,7 @@ def on_advance_end(self) -> None: should_check_val = self._should_check_val_fx() if should_check_val: self.trainer.validating = True - self._run_validation() + self.val_loop.run() self.trainer.training = True # update plateau LR scheduler after metrics are logged @@ -275,12 +274,6 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) - def _run_validation(self) -> None: - # reload dataloaders - self.val_loop._reload_evaluation_dataloaders() - - self.val_loop.run() - def _accumulated_batches_reached(self) -> bool: """Determine if accumulation will be finished by the end of the current batch.""" return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0 diff --git a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py similarity index 51% rename from src/lightning/pytorch/loops/dataloader/evaluation_loop.py rename to src/lightning/pytorch/loops/evaluation_loop.py index 3e30da9adaeb1..0775bd7a94aa1 100644 --- a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -14,151 +14,198 @@ import os import shutil import sys -from collections import ChainMap, OrderedDict -from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union +from collections import ChainMap, defaultdict, OrderedDict +from typing import Any, DefaultDict, Iterable, List, Optional, Tuple, Union from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor -from torch.utils.data.dataloader import DataLoader import lightning.pytorch as pl from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE -from lightning.pytorch.loops.dataloader import _DataLoaderLoop -from lightning.pytorch.loops.epoch import _EvaluationEpochLoop -from lightning.pytorch.loops.fetchers import _DataFetcher +from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher +from lightning.pytorch.loops.loop import _Loop +from lightning.pytorch.loops.progress import BatchProgress from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch from lightning.pytorch.trainer import call +from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.trainer.supporters import _Sequential, CombinedLoader +from lightning.pytorch.utilities.exceptions import SIGTERMException +from lightning.pytorch.utilities.model_helpers import is_overridden if _RICH_AVAILABLE: from rich import get_console from rich.table import Column, Table -class _EvaluationLoop(_DataLoaderLoop): - """Top-level loop where validation/testing starts. 
- - It simply iterates over each evaluation dataloader from one to the next by calling ``EvaluationEpochLoop.run()`` in - its ``advance()`` method. - """ +class _EvaluationLoop(_Loop): + """Top-level loop where validation/testing starts.""" def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode: bool = True) -> None: super().__init__(trainer) - self.epoch_loop = _EvaluationEpochLoop(trainer) self.verbose = verbose self.inference_mode = inference_mode + self.batch_progress = BatchProgress() # across dataloaders self._results = _ResultCollection(training=False) self._logged_outputs: List[_OUT_DICT] = [] - self._max_batches: List[Union[int, float]] = [] self._has_run: bool = False + self._data_source = _DataLoaderSource(None, "") + self._combined_loader: Optional[CombinedLoader] = None self._data_fetcher: Optional[_DataFetcher] = None + self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int) @property def num_dataloaders(self) -> int: - """Returns the total number of dataloaders.""" - # case where user does: - # return dl1, dl2 - dataloaders = self.dataloaders - length = len(dataloaders) - if length > 0 and isinstance(dataloaders[0], (list, tuple)): - length = len(dataloaders[0]) - return length - - @property - def dataloaders(self) -> Sequence[DataLoader]: - """Returns the validation or test dataloaders.""" - dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders - if dataloaders is None: - return [] - return dataloaders + """Returns the number of evaluation dataloaders.""" + combined_loader = self._combined_loader + assert combined_loader is not None + return len(combined_loader._flattened) @property - def done(self) -> bool: - """Returns whether all dataloaders are processed or evaluation should be skipped altogether.""" - return super().done or self.skip + def max_batches(self) -> List[Union[int, float]]: + """The max number of batches this loop will run for each dataloader.""" + if self.trainer.testing: + return self.trainer.num_test_batches + elif self.trainer.sanity_checking: + return self.trainer.num_sanity_val_batches + elif self.trainer.validating: + return self.trainer.num_val_batches + raise RuntimeError(f"Unexpected stage: {self.trainer.state.stage}") @property def skip(self) -> bool: """Returns whether the evaluation should be skipped.""" - max_batches = self._get_max_batches() - return sum(max_batches) == 0 + return sum(self.max_batches) == 0 @_no_grad_context def run(self) -> List[_OUT_DICT]: + self.setup_data() if self.skip: return [] self.reset() self.on_run_start() - while not self.done: + data_fetcher = self._data_fetcher + assert data_fetcher is not None + previous_dataloader_idx = 0 + while True: try: - self.on_advance_start() - self.advance() - self.on_advance_end() - self._restarting = False + batch, batch_idx, dataloader_idx = next(data_fetcher) + self.batch_progress.is_last_batch = data_fetcher.done + if previous_dataloader_idx != dataloader_idx: + # the dataloader has changed, notify the logger connector + self._store_dataloader_outputs() + previous_dataloader_idx = dataloader_idx + # run step hooks + self._evaluation_step(batch, batch_idx, dataloader_idx) except StopIteration: + # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support break - self._restarting = False + finally: + self._restarting = False + self._store_dataloader_outputs() return self.on_run_end() + def setup_data(self) -> None: + trainer = self.trainer + + if (
self._combined_loader is not None + and trainer.state.fn == "fit" + and not trainer._data_connector._should_reload_val_dl + ): + return + + source = self._data_source + pl_module = trainer.lightning_module + limit_batches = trainer.limit_test_batches if trainer.testing else trainer.limit_val_batches + hook_name = "test_step" if trainer.testing else "validation_step" + if not source.is_defined() or limit_batches == 0 or not is_overridden(hook_name, pl_module): + return + + # store epoch of dataloader reset for reload_dataloaders_every_n_epochs + # it should not reload again if it has already reloaded during sanity_check + if trainer.state.fn == "fit" and ( + (trainer.sanity_checking and trainer.fit_loop.epoch_loop._should_check_val_epoch()) + or not trainer.sanity_checking + ): + trainer._last_val_dl_reload_epoch = trainer.current_epoch + + stage = trainer.state.stage + assert stage is not None + num_batches, iterables = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module) + if trainer.testing: + trainer.num_test_batches = num_batches + elif trainer.sanity_checking: + trainer.num_val_batches = num_batches + trainer.num_sanity_val_batches = [ + min(trainer.num_sanity_val_steps, val_batches) for val_batches in num_batches + ] + else: + trainer.num_val_batches = num_batches + + combined_loader = CombinedLoader(iterables, "sequential") + for i, dl in enumerate(combined_loader._flattened): + if trainer.state.fn != "fit": # if we are fitting, we need to do this in the loop + # some users want validation shuffling based on the training progress + _set_sampler_epoch(dl, trainer.fit_loop.epoch_progress.current.processed) + # allow the strategy to inject logic + dl = trainer.strategy.process_dataloader(dl) + combined_loader._update_index(dl, i) + self._combined_loader = combined_loader + + # this depends on the data used, so reset it too + self._seen_batches_per_dataloader = defaultdict(int) + def reset(self) -> None: """Resets the internal state of the loop.""" - self._max_batches = self._get_max_batches() - # bookkeeping - self._logged_outputs = [] + trainer = self.trainer - if isinstance(self._max_batches, int): - self._max_batches = [self._max_batches] * len(self.dataloaders) + self._has_run = False + self._logged_outputs = [] - super().reset() + if not self.restarting: + self.batch_progress.reset_on_run() + else: + self.batch_progress.reset_on_restart() # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we # need to reset the current state when the loop has finished running - if self.done and self.trainer.state.fn != TrainerFn.FITTING: - self.dataloader_progress.reset_on_run() + if trainer.state.fn != TrainerFn.FITTING: + self.batch_progress.reset_on_run() + + data_fetcher = _select_data_fetcher(trainer) + if isinstance(data_fetcher, _DataLoaderIterDataFetcher) and self.num_dataloaders > 1: + raise NotImplementedError( + "Using `dataloader_iter` in your step method is not supported with multiple dataloaders" + ) + combined_loader = self._combined_loader + assert combined_loader is not None + + if trainer.state.fn == "fit": + for i, dl in enumerate(combined_loader._flattened): + # some users want validation shuffling based on the training progress + _set_sampler_epoch(dl, trainer.fit_loop.epoch_progress.current.processed) + + data_fetcher.setup(combined_loader) + iter(data_fetcher) # creates the iterator inside the fetcher + assert isinstance(combined_loader._iterator, _Sequential) + # set the per-dataloader limits + 
combined_loader._iterator.limits = self.max_batches + # add the previous `fetched` value to properly track `is_last_batch` with no prefetching + data_fetcher.fetched += self.batch_progress.current.ready + data_fetcher._start_profiler = self._on_before_fetch + data_fetcher._stop_profiler = self._on_after_fetch + self._data_fetcher = data_fetcher def on_run_start(self) -> None: """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start`` hooks.""" - self._data_fetcher = _select_data_fetcher(self.trainer) - - # hook self._on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() self._on_evaluation_start() self._on_evaluation_epoch_start() - def advance(self) -> None: - """Performs evaluation on one single dataloader.""" - dataloader_idx = self.current_dataloader_idx - dataloader = self.current_dataloader - - assert self._data_fetcher is not None - self._data_fetcher.setup(dataloader) - dl_max_batches = self._max_batches[dataloader_idx] - - kwargs = OrderedDict() - if self.num_dataloaders > 1: - kwargs["dataloader_idx"] = dataloader_idx - self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs) - - if not self.trainer.sanity_checking: - # indicate the loop has run - self._has_run = True - - def on_advance_start(self, *args: Any, **kwargs: Any) -> None: - if self.current_dataloader is not None: - _set_sampler_epoch(self.current_dataloader, self.trainer.fit_loop.epoch_progress.current.processed) - - super().on_advance_start(*args, **kwargs) - - def on_advance_end(self) -> None: - self.trainer._logger_connector.epoch_end_reached() - - self._logged_outputs.append(self.trainer._logger_connector.update_eval_epoch_metrics()) - - super().on_advance_end() - def on_run_end(self) -> List[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end` @@ -197,29 +244,6 @@ def teardown(self) -> None: self._data_fetcher = None self._results.cpu() - def _get_max_batches(self) -> List[Union[int, float]]: - """Returns the max number of batches for each dataloader.""" - if self.trainer.testing: - max_batches = self.trainer.num_test_batches - else: - if self.trainer.sanity_checking: - max_batches = self.trainer.num_sanity_val_batches - else: - max_batches = self.trainer.num_val_batches - return max_batches - - def _reload_evaluation_dataloaders(self) -> None: - """Reloads dataloaders if necessary.""" - dataloaders = None - if self.trainer.testing: - self.trainer.reset_test_dataloader() - dataloaders = self.trainer.test_dataloaders - elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl: - self.trainer.reset_val_dataloader() - dataloaders = self.trainer.val_dataloaders - if dataloaders is not None: - self.epoch_loop._reset_dl_batch_idx(len(dataloaders)) - def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: """Runs ``on_{validation/test}_start`` hooks.""" trainer = self.trainer @@ -275,6 +299,95 @@ def _on_evaluation_epoch_end(self) -> None: trainer._logger_connector.on_epoch_end() + def _store_dataloader_outputs(self) -> None: + trainer = self.trainer + trainer._logger_connector.epoch_end_reached() + self._logged_outputs.append(trainer._logger_connector.update_eval_epoch_metrics()) + + def _on_before_fetch(self) -> None: + stage = self.trainer.state.stage + assert stage is not None + stage = stage.dataloader_prefix + self.trainer.profiler.start(f"[{type(self).__name__}].{stage}_next") + + def 
_on_after_fetch(self) -> None: + stage = self.trainer.state.stage + assert stage is not None + stage = stage.dataloader_prefix + # the dataloader_idx cannot be easily included here because it might be different from the index used on + # profiler start, since the `__next__` call might use a different iterator + self.trainer.profiler.stop(f"[{type(self).__name__}].{stage}_next") + + def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """Runs the actual evaluation step together with all the necessary bookkeeping and the hooks tied to it. + + Args: + batch: The current batch to run through the step. + batch_idx: The index of the current batch + dataloader_idx: the index of the dataloader producing the current batch + """ + trainer = self.trainer + + batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx) + batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx) + + step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None) + + self.batch_progress.increment_ready() + + trainer._logger_connector.on_batch_start(**step_kwargs) + + hook_name = "on_test_batch_start" if trainer.testing else "on_validation_batch_start" + call._call_callback_hooks(trainer, hook_name, *step_kwargs.values()) + call._call_lightning_module_hook(trainer, hook_name, *step_kwargs.values()) + + self.batch_progress.increment_started() + + hook_name = "test_step" if trainer.testing else "validation_step" + output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values()) + + hook_name = "test_step_end" if trainer.testing else "validation_step_end" + model_output = call._call_lightning_module_hook(trainer, hook_name, output) + strategy_output = call._call_strategy_hook(trainer, hook_name, output) + output = strategy_output if model_output is None else model_output + + self.batch_progress.increment_processed() + + hook_name = "on_test_batch_end" if trainer.testing else "on_validation_batch_end" + call._call_callback_hooks(trainer, hook_name, output, *step_kwargs.values()) + call._call_lightning_module_hook(trainer, hook_name, output, *step_kwargs.values()) + + trainer._logger_connector.on_batch_end() + + self.batch_progress.increment_completed() + + if not trainer.sanity_checking: + # indicate the loop has run + self._has_run = True + + # log batch metrics + trainer._logger_connector.update_eval_step_metrics(self._seen_batches_per_dataloader[dataloader_idx]) + self._seen_batches_per_dataloader[dataloader_idx] += 1 + + if not self.batch_progress.is_last_batch and trainer.received_sigterm: + raise SIGTERMException + + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> OrderedDict: + """Helper method to build the arguments for the current step. + + Args: + batch: the current batch to run through the step. + batch_idx: the index of the current batch + dataloader_idx: the index of the dataloader producing the current batch. None if not multiple dataloaders. 
+ + Returns: + the dictionary containing all the keyword arguments for the step + """ + step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) + if dataloader_idx is not None: + step_kwargs["dataloader_idx"] = dataloader_idx + return step_kwargs + @staticmethod def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: for k, v in data.items(): diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index ef9f44c13f465..87854c27f5e86 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -12,19 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import Iterable, Optional + +from lightning_utilities.core.apply_func import apply_to_collection +from torch.utils.data import DataLoader import lightning.pytorch as pl +from lightning.fabric.utilities.data import _auto_add_worker_init_fn from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.epoch import _TrainingEpochLoop from lightning.pytorch.loops.fetchers import _DataFetcher from lightning.pytorch.loops.progress import Progress from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher, _set_sampler_epoch from lightning.pytorch.trainer import call +from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection +from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.trainer.supporters import CombinedLoader +from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException -from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info +from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn +from lightning.pytorch.utilities.warnings import PossibleUserWarning log = logging.getLogger(__name__) @@ -74,7 +83,8 @@ def __init__( self.epoch_loop = _TrainingEpochLoop(trainer) self.epoch_progress = Progress() - self._is_fresh_start_epoch: bool = True + self._data_source = _DataLoaderSource(None, "train_dataloader") + self._combined_loader: Optional[CombinedLoader] = None self._data_fetcher: Optional[_DataFetcher] = None @property @@ -182,6 +192,7 @@ def skip(self) -> bool: return self.done or self.trainer.limit_train_batches == 0 def run(self) -> None: + self.setup_data() if self.skip: return self.reset() @@ -197,6 +208,117 @@ def run(self) -> None: self._restarting = False self.on_run_end() + def setup_data(self) -> None: + trainer = self.trainer + + if self._combined_loader is not None and not trainer._data_connector._should_reload_train_dl: + return + + source = self._data_source + pl_module = trainer.lightning_module + if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module): + return + + log.debug(f"{self.__class__.__name__}: resetting train dataloader") + + train_dataloader = trainer._data_connector._request_dataloader() + + if trainer.overfit_batches > 0: + train_dataloader = trainer._data_connector._resolve_overfit_batches( + train_dataloader, mode=RunningStage.TRAINING + ) + + # automatically add samplers + train_dataloader = apply_to_collection( + train_dataloader, 
(DataLoader, CombinedLoader), + trainer._data_connector._prepare_dataloader, + mode=RunningStage.TRAINING, + ) + train_dataloader = ( + train_dataloader.iterables if isinstance(train_dataloader, CombinedLoader) else train_dataloader + ) + + apply_to_collection(train_dataloader, Iterable, trainer.strategy.process_dataloader) + + # check the workers recursively + apply_to_collection(train_dataloader, DataLoader, trainer._data_connector._worker_check, "train_dataloader") + + # add worker_init_fn for correct seeding in worker processes + apply_to_collection(train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=trainer.global_rank) + + if not isinstance(train_dataloader, CombinedLoader): + self._combined_loader = CombinedLoader(train_dataloader, trainer._data_connector.multiple_trainloader_mode) + else: + self._combined_loader = train_dataloader + + module = pl_module or trainer.datamodule + orig_train_batches = trainer.num_training_batches = ( + len(self._combined_loader) + if has_len_all_ranks(self._combined_loader, trainer.strategy, module) + else float("inf") + ) + if orig_train_batches == 0: + return + + # store epoch of dataloader reset for reload_dataloaders_every_n_epochs + trainer._last_train_dl_reload_epoch = trainer.current_epoch + + if isinstance(trainer.limit_train_batches, int): + trainer.num_training_batches = min(orig_train_batches, trainer.limit_train_batches) + elif trainer.num_training_batches != float("inf"): + trainer.num_training_batches = int(orig_train_batches * trainer.limit_train_batches) + elif trainer.limit_train_batches != 1.0: + raise MisconfigurationException( + "When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int." + " An int specifies `num_training_batches` to use." + ) + + if isinstance(trainer.val_check_interval, int): + trainer.val_check_batch = trainer.val_check_interval + if trainer.val_check_batch > trainer.num_training_batches and trainer.check_val_every_n_epoch is not None: + raise ValueError( + f" `val_check_interval` ({trainer.val_check_interval}) must be less than or equal" + f" to the number of the training batches ({trainer.num_training_batches})." + " If you want to disable validation set `limit_val_batches` to 0.0 instead." + " If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`." + ) + else: + if not has_len_all_ranks(self._combined_loader, trainer.strategy, module): + if trainer.val_check_interval == 1.0: + trainer.val_check_batch = float("inf") + else: + raise MisconfigurationException( + "When using an IterableDataset for `train_dataloader`," + " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies" + " checking validation every k training batches." + ) + else: + trainer.val_check_batch = int(trainer.num_training_batches * trainer.val_check_interval) + trainer.val_check_batch = max(1, trainer.val_check_batch) + + if trainer.loggers and trainer.num_training_batches < trainer.log_every_n_steps: + rank_zero_warn( + f"The number of training batches ({trainer.num_training_batches}) is smaller than the logging interval" + f" Trainer(log_every_n_steps={trainer.log_every_n_steps}). 
Set a lower value for log_every_n_steps if" + " you want to see logs for the training epoch.", + category=PossibleUserWarning, + ) + + if ( + trainer.num_training_batches == 0 + and trainer.limit_train_batches > 0.0 + and isinstance(trainer.limit_train_batches, float) + and orig_train_batches != float("inf") + ): + min_percentage = 1.0 / orig_train_batches + raise MisconfigurationException( + f"You requested to check {trainer.limit_train_batches} of the `train_dataloader` but" + f" {trainer.limit_train_batches} * {orig_train_batches} < 1. Please increase the" + f" `limit_train_batches` argument. Try at least" + f" `limit_train_batches={min_percentage}`" + ) + def reset(self) -> None: """Resets the internal state of this loop.""" if self.restarting: @@ -209,14 +331,16 @@ def on_run_start(self) -> None: self.epoch_progress.current.completed = self.epoch_progress.current.processed trainer = self.trainer - trainer.reset_train_dataloader(trainer.lightning_module) + # reload the evaluation dataloaders too for proper display in the progress bar - if self.epoch_loop._should_check_val_epoch(): - self.epoch_loop.val_loop._reload_evaluation_dataloaders() + if self.epoch_loop._should_check_val_epoch() and trainer.val_dataloaders is None: + # TODO(carmocca): avoid having to set validating + trainer.validating = True + self.epoch_loop.val_loop.setup_data() + trainer.training = True self._data_fetcher = _select_data_fetcher(trainer) - self._is_fresh_start_epoch = True self._results.to(device=trainer.lightning_module.device) call._call_callback_hooks(trainer, "on_train_start") @@ -226,17 +350,14 @@ def on_run_start(self) -> None: def on_advance_start(self) -> None: """Prepares the dataloader for training and calls the hook ``on_train_epoch_start``""" trainer = self.trainer - model = trainer.lightning_module - # reset train dataloader - if not self._is_fresh_start_epoch and trainer._data_connector._should_reload_train_dl: - log.debug(f"{self.__class__.__name__}: resetting train dataloader") - trainer.reset_train_dataloader(model) - self._is_fresh_start_epoch = False + # might need to setup data again depending on `trainer.reload_dataloaders_every_n_epochs` + self.setup_data() - if trainer.train_dataloader is not None: - assert isinstance(trainer.train_dataloader, CombinedLoader) - _set_sampler_epoch(trainer.train_dataloader, self.epoch_progress.current.processed) + # update the epoch value for all samplers + assert self._combined_loader is not None + for dl in self._combined_loader._flattened: + _set_sampler_epoch(dl, self.epoch_progress.current.processed) self.epoch_progress.increment_ready() @@ -249,15 +370,17 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" - log.debug(f"{self.__class__.__name__}: advancing loop") - - trainer = self.trainer - assert trainer.train_dataloader is not None - dataloader = trainer.train_dataloader + log.debug(f"{type(self).__name__}: advancing loop") + combined_loader = self._combined_loader + assert combined_loader is not None + if combined_loader._mode not in ("max_size_cycle", "min_size"): + raise ValueError( + f'`{type(self).__name__}` only supports the `CombinedLoader(mode="max_size_cycle" | "min_size")` modes.'
+ ) assert self._data_fetcher is not None - self._data_fetcher.setup(dataloader) - with trainer.profiler.profile("run_training_epoch"): + self._data_fetcher.setup(combined_loader) + with self.trainer.profiler.profile("run_training_epoch"): self.epoch_loop.run(self._data_fetcher) def on_advance_end(self) -> None: diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py new file mode 100644 index 0000000000000..bae51864d9094 --- /dev/null +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -0,0 +1,303 @@ +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Union + +import torch +from lightning_utilities import WarningCache + +import lightning.pytorch as pl +from lightning.fabric.utilities import move_data_to_device +from lightning.pytorch.callbacks import BasePredictionWriter +from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher +from lightning.pytorch.loops.loop import _Loop +from lightning.pytorch.loops.progress import Progress +from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch +from lightning.pytorch.overrides.distributed import IndexBatchSamplerWrapper +from lightning.pytorch.strategies import DDPSpawnStrategy +from lightning.pytorch.trainer import call +from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource +from lightning.pytorch.trainer.states import RunningStage +from lightning.pytorch.trainer.supporters import _Sequential, CombinedLoader +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.types import _PREDICT_OUTPUT + + +class _PredictionLoop(_Loop): + """Top-level loop where prediction starts.""" + + def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None: + super().__init__(trainer) + self.inference_mode = inference_mode + # dataloaders x batches x samples. used by PredictionWriter + self.epoch_batch_indices: List[List[List[int]]] = [] + self.current_batch_indices: List[int] = [] # used by PredictionWriter + self.batch_progress = Progress() # across dataloaders + + self._warning_cache = WarningCache() + self._data_source = _DataLoaderSource(None, "predict_dataloader") + self._combined_loader: Optional[CombinedLoader] = None + self._data_fetcher: Optional[_DataFetcher] = None + self._results = None # for `trainer._results` access + self._predictions: List[List[Any]] = [] # dataloaders x batches + self._return_predictions = False + + @property + def return_predictions(self) -> bool: + """Whether to return the predictions or not.""" + return self._return_predictions + + @return_predictions.setter + def return_predictions(self, return_predictions: Optional[bool] = None) -> None: + # `DDPSpawnStrategy` and its subclasses don't support returning predictions. + is_ddp_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy) + if return_predictions and is_ddp_spawn: + raise MisconfigurationException( + "`return_predictions` should be set to `False` when using the `DDPSpawnStrategy` or its subclasses. " + f"Found {return_predictions} with strategy {type(self.trainer.strategy)}." + ) + # For strategies other than `DDPSpawnStrategy`, `return_predictions` is True by default unless the user decides otherwise.
+ self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions + + @property + def predictions(self) -> List[Any]: + """The cached predictions.""" + if self._predictions == []: + return self._predictions + return self._predictions[0] if self.num_dataloaders == 1 else self._predictions + + @property + def num_dataloaders(self) -> int: + """Returns the number of prediction dataloaders.""" + combined_loader = self._combined_loader + assert combined_loader is not None + return len(combined_loader._flattened) + + @property + def max_batches(self) -> List[Union[int, float]]: + """The max number of batches this loop will run for each dataloader.""" + return self.trainer.num_predict_batches + + @property + def skip(self) -> bool: + return sum(self.max_batches) == 0 + + @_no_grad_context + def run(self) -> Optional[_PREDICT_OUTPUT]: + self.setup_data() + if self.skip: + return None + self.reset() + self.on_run_start() + data_fetcher = self._data_fetcher + assert data_fetcher is not None + while True: + try: + batch, batch_idx, dataloader_idx = next(data_fetcher) + self.batch_progress.is_last_batch = data_fetcher.done + self._predict_step(batch, batch_idx, dataloader_idx) + except StopIteration: + # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support + break + finally: + self._restarting = False + return self.on_run_end() + + def setup_data(self) -> None: + trainer = self.trainer + source = self._data_source + pl_module = trainer.lightning_module + # a default `predict_step` exists in the LightningModule, so no need to check if it's overridden + if not source.is_defined() or trainer.limit_predict_batches == 0: + return + + trainer.num_predict_batches, iterables = trainer._data_connector._reset_eval_dataloader( + RunningStage.PREDICTING, model=pl_module + ) + combined_loader = CombinedLoader(iterables, "sequential") + for i, dl in enumerate(combined_loader._flattened): + # some users want prediction shuffling based on the training progress + _set_sampler_epoch(dl, trainer.fit_loop.epoch_progress.current.processed) + # allow the strategy to inject logic + dl = trainer.strategy.process_dataloader(dl) + combined_loader._update_index(dl, i) + self._combined_loader = combined_loader + + def reset(self) -> None: + """Resets the internal state of the loop for a new run.""" + self.batch_progress.reset_on_run() + + data_fetcher = _select_data_fetcher(self.trainer) + if isinstance(data_fetcher, _DataLoaderIterDataFetcher) and self.num_dataloaders > 1: + raise NotImplementedError( + "Using `dataloader_iter` in your step method is not supported with multiple dataloaders" + ) + combined_loader = self._combined_loader + assert combined_loader is not None + data_fetcher.setup(combined_loader) + iter(data_fetcher) # creates the iterator inside the fetcher + assert isinstance(combined_loader._iterator, _Sequential) + # set the per-dataloader limits + combined_loader._iterator.limits = self.max_batches + # add the previous `fetched` value to properly track `is_last_batch` with no prefetching + data_fetcher.fetched += self.batch_progress.current.ready + data_fetcher._start_profiler = self._on_before_fetch + data_fetcher._stop_profiler = self._on_after_fetch + self._data_fetcher = data_fetcher + + num_dataloaders = self.num_dataloaders + self.epoch_batch_indices = [[] for _ in range(num_dataloaders)] + self._predictions = [[] for _ in range(num_dataloaders)] + + def on_run_start(self) -> None: + """Calls ``_on_predict_model_eval``,
``_on_predict_start`` and ``_on_predict_epoch_start`` hooks.""" + trainer = self.trainer + call._call_lightning_module_hook(trainer, "on_predict_model_eval") + trainer.lightning_module.zero_grad() + self._on_predict_start() + self._on_predict_epoch_start() + + def on_run_end(self) -> Optional[_PREDICT_OUTPUT]: + """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders.""" + results = self._on_predict_epoch_end() + self._on_predict_end() + return results + + def teardown(self) -> None: + if self._data_fetcher is not None: + self._data_fetcher.teardown() + self._data_fetcher = None + + def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to it. + + Args: + batch: the current batch to run the prediction on + batch_idx: the index of the current batch + dataloader_idx: the index of the dataloader producing the current batch + """ + trainer = self.trainer + batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx) + batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=dataloader_idx) + + self.batch_progress.increment_ready() + + any_on_epoch = self._store_data_for_prediction_writer(batch_idx, dataloader_idx) + + step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None) + + call._call_callback_hooks(trainer, "on_predict_batch_start", *step_kwargs.values()) + call._call_lightning_module_hook(trainer, "on_predict_batch_start", *step_kwargs.values()) + + self.batch_progress.increment_started() + + # run the actual predict step through the strategy + predictions = call._call_strategy_hook(trainer, "predict_step", *step_kwargs.values()) + + self.batch_progress.increment_processed() + + if predictions is None: + self._warning_cache.warn("predict returned None. If it was on purpose, ignore this warning...") + + call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *step_kwargs.values()) + call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *step_kwargs.values()) + + self.batch_progress.increment_completed() + + if self._return_predictions or any_on_epoch: + self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu"))) + + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> Dict[str, Any]: + """Assembles the keyword arguments for the ``predict_step``. + + Args: + batch: the current batch to run the prediction on + batch_idx: the index of the current batch + dataloader_idx: the index of the dataloader producing the current batch. None if not multiple dataloaders.
+ + Returns: + the dictionary containing all the keyword arguments for the predict step + """ + step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) + if dataloader_idx is not None: + step_kwargs["dataloader_idx"] = dataloader_idx + return step_kwargs + + def _get_batch_indices(self, dataloader: object) -> List[List[int]]: # batches x samples + """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our + :class:`~lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper`.""" + batch_sampler = getattr(dataloader, "batch_sampler", None) + if not isinstance(batch_sampler, IndexBatchSamplerWrapper): + self._warning_cache.warn( + f"Couldn't infer the batch indices fetched from your dataloader: `{type(dataloader).__name__}`" + ) + return [] + seen_batch_indices = batch_sampler.seen_batch_indices + # TODO(carmocca): this could be avoided + # we need to truncate the list because `IndexBatchSamplerWrapper` computes all indices on `__iter__` + seen_batch_indices = seen_batch_indices[: (self.batch_progress.current.completed + 1)] + return seen_batch_indices + + def _store_data_for_prediction_writer(self, batch_idx: int, dataloader_idx: int) -> bool: + prediction_writers = [cb for cb in self.trainer.callbacks if isinstance(cb, BasePredictionWriter)] + any_on_epoch = any(cb.interval.on_epoch for cb in prediction_writers) + any_on_batch = any(cb.interval.on_batch for cb in prediction_writers) + if any_on_batch or any_on_epoch: + combined_loader = self._combined_loader + assert combined_loader is not None + dataloader = combined_loader._flattened[dataloader_idx] + batch_indices = self._get_batch_indices(dataloader) + if not batch_indices: + # this is only available with `IndexBatchSamplerWrapper`, but it's only used on DataLoaders; if this is + # reached, it's likely because a non-DataLoader was passed + return any_on_epoch + batch_indices = batch_indices[batch_idx] + if any_on_epoch: + self.epoch_batch_indices[dataloader_idx].append(batch_indices) + if any_on_batch: + self.current_batch_indices = batch_indices + return any_on_epoch + + def _on_before_fetch(self) -> None: + self.trainer.profiler.start(f"[{type(self).__name__}].predict_next") + + def _on_after_fetch(self) -> None: + # the dataloader_idx cannot be easily included here because it might be different from the index used on + # profiler start, since the `__next__` call might use a different iterator + self.trainer.profiler.stop(f"[{type(self).__name__}].predict_next") + + def _on_predict_start(self) -> None: + """Calls ``on_predict_start`` hooks.""" + trainer = self.trainer + call._call_callback_hooks(trainer, "on_predict_start") + call._call_lightning_module_hook(trainer, "on_predict_start") + call._call_strategy_hook(trainer, "on_predict_start") + + def _on_predict_epoch_start(self) -> None: + """Calls ``on_predict_epoch_start`` hooks.""" + trainer = self.trainer + call._call_callback_hooks(trainer, "on_predict_epoch_start") + call._call_lightning_module_hook(trainer, "on_predict_epoch_start") + + def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: + """Calls ``on_predict_epoch_end`` hook.
+ + Returns: + the results for all dataloaders + """ + trainer = self.trainer + call._call_callback_hooks(trainer, "on_predict_epoch_end") + call._call_lightning_module_hook(trainer, "on_predict_epoch_end") + + if self.return_predictions: + return self.predictions + + def _on_predict_end(self) -> None: + """Clears the cached predictions if they are not to be returned and calls the ``on_predict_end`` hooks.""" + if not self.return_predictions: + self._predictions = [] + self.epoch_batch_indices = [] + + trainer = self.trainer + # hook + call._call_callback_hooks(trainer, "on_predict_end") + call._call_lightning_module_hook(trainer, "on_predict_end") + call._call_strategy_hook(trainer, "on_predict_end") diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index b6f61b75036b6..0d6aa9182307e 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -144,6 +144,8 @@ def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher: step_fx_name = "training_step" elif trainer.validating or trainer.sanity_checking: step_fx_name = "validation_step" + elif trainer.predicting: + step_fx_name = "predict_step" else: raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}") step_fx = getattr(lightning_module, step_fx_name) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 4b7a806cbf95e..d3d0545e05865 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -622,10 +622,10 @@ def _auto_select_batch_size(self) -> int: # by default we try to use the batch size of the loader assert self.lightning_module is not None batch_size = 1 - train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source - if train_dl_source.is_defined(): + data_source = self.lightning_module.trainer.fit_loop._data_source + if data_source.is_defined(): try: - train_dataloader = train_dl_source.dataloader() + train_dataloader = data_source.dataloader() if hasattr(train_dataloader, "batch_sampler"): batch_size = train_dataloader.batch_sampler.batch_size # type: ignore[union-attr] # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup` diff --git a/src/lightning/pytorch/trainer/configuration_validator.py b/src/lightning/pytorch/trainer/configuration_validator.py index c87b27b70966c..bb486f3d408ca 100644 --- a/src/lightning/pytorch/trainer/configuration_validator.py +++ b/src/lightning/pytorch/trainer/configuration_validator.py @@ -64,7 +64,7 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh ) # verify minimum validation requirements - has_val_loader = trainer._data_connector._val_dataloader_source.is_defined() + has_val_loader = trainer.fit_loop.epoch_loop.val_loop._data_source.is_defined() has_val_step = is_overridden("validation_step", model) if has_val_loader and not has_val_step: rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`.
Skipping val loop.") diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 8f16fb82a0f8c..5395adc7012f4 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -44,11 +44,6 @@ class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: _LITERAL_SUPPORTED_MODES = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - self._train_dataloader_source = _DataLoaderSource(None, "") - self._val_dataloader_source = _DataLoaderSource(None, "") - self._test_dataloader_source = _DataLoaderSource(None, "") - self._predict_dataloader_source = _DataLoaderSource(None, "") - self._datahook_selector: Optional[_DataHookSelector] = None @property @@ -61,7 +56,7 @@ def _should_reload_train_dl(self) -> bool: def _should_reload_val_dl(self) -> bool: """Check if validation dataloader should be reloaded.""" n_epochs = self.trainer.reload_dataloaders_every_n_epochs - return n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs + return bool(n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs) def on_trainer_init( self, @@ -135,18 +130,21 @@ def attach_data( ) self.attach_datamodule(model, datamodule=datamodule) + trainer = self.trainer + fn = trainer.state.fn # Validate that the required data sources are available - if self.trainer.state.fn == TrainerFn.FITTING: - _check_dataloader_none(train_dataloaders, self._train_dataloader_source, self.trainer.state.fn) - elif self.trainer.state.fn == TrainerFn.VALIDATING: - _check_dataloader_none(val_dataloaders, self._val_dataloader_source, self.trainer.state.fn) - elif self.trainer.state.fn == TrainerFn.TESTING: - _check_dataloader_none(test_dataloaders, self._test_dataloader_source, self.trainer.state.fn) - elif self.trainer.state.fn == TrainerFn.PREDICTING: - _check_dataloader_none(predict_dataloaders, self._predict_dataloader_source, self.trainer.state.fn) + if fn == TrainerFn.FITTING: + _check_dataloader_none(train_dataloaders, trainer.fit_loop._data_source, fn) + # TODO(carmocca): fit's validation dataloaders should be checked too + elif fn == TrainerFn.VALIDATING: + _check_dataloader_none(val_dataloaders, trainer.validate_loop._data_source, fn) + elif fn == TrainerFn.TESTING: + _check_dataloader_none(test_dataloaders, trainer.test_loop._data_source, fn) + elif fn == TrainerFn.PREDICTING: + _check_dataloader_none(predict_dataloaders, trainer.predict_loop._data_source, fn) # Attach the trainer to the LightningModule - model.trainer = proxy(self.trainer) + model.trainer = proxy(trainer) def attach_dataloaders( self, @@ -156,23 +154,24 @@ def attach_dataloaders( test_dataloaders: Optional[EVAL_DATALOADERS] = None, predict_dataloaders: Optional[EVAL_DATALOADERS] = None, ) -> None: - self.trainer.train_dataloader = None - self.trainer.val_dataloaders = None - self.trainer.test_dataloaders = None - self.trainer.predict_dataloaders = None + trainer = self.trainer - self._train_dataloader_source = _DataLoaderSource( - train_dataloaders if train_dataloaders is not None else model, "train_dataloader" - ) - self._val_dataloader_source = _DataLoaderSource( - val_dataloaders if val_dataloaders is not None else model, "val_dataloader" - ) - self._test_dataloader_source = _DataLoaderSource( - test_dataloaders if test_dataloaders is not None else model, 
"test_dataloader" - ) - self._predict_dataloader_source = _DataLoaderSource( - predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader" + trainer.fit_loop._combined_loader = None + trainer.fit_loop.epoch_loop.val_loop._combined_loader = None + trainer.validate_loop._combined_loader = None + trainer.test_loop._combined_loader = None + trainer.predict_loop._combined_loader = None + + trainer.fit_loop._data_source.instance = train_dataloaders if train_dataloaders is not None else model + trainer.fit_loop.epoch_loop.val_loop._data_source.instance = ( + val_dataloaders if val_dataloaders is not None else model ) + trainer.fit_loop.epoch_loop.val_loop._data_source.name = "val_dataloader" + trainer.validate_loop._data_source.instance = val_dataloaders if val_dataloaders is not None else model + trainer.validate_loop._data_source.name = "val_dataloader" + trainer.test_loop._data_source.instance = test_dataloaders if test_dataloaders is not None else model + trainer.test_loop._data_source.name = "test_dataloader" + trainer.predict_loop._data_source.instance = predict_dataloaders if predict_dataloaders is not None else model def attach_datamodule( self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None @@ -183,13 +182,18 @@ def attach_datamodule( if datamodule is None: return - self._train_dataloader_source = _DataLoaderSource(datamodule, "train_dataloader") - self._val_dataloader_source = _DataLoaderSource(datamodule, "val_dataloader") - self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader") - self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader") - - self.trainer.datamodule = datamodule - datamodule.trainer = self.trainer + trainer = self.trainer + trainer.fit_loop._data_source.instance = datamodule + trainer.fit_loop.epoch_loop.val_loop._data_source.instance = datamodule + trainer.fit_loop.epoch_loop.val_loop._data_source.name = "val_dataloader" + trainer.validate_loop._data_source.instance = datamodule + trainer.validate_loop._data_source.name = "val_dataloader" + trainer.test_loop._data_source.instance = datamodule + trainer.test_loop._data_source.name = "test_dataloader" + trainer.predict_loop._data_source.instance = datamodule + + trainer.datamodule = datamodule + datamodule.trainer = trainer def _worker_check(self, dataloader: DataLoader, name: str) -> None: if not isinstance(dataloader, DataLoader): @@ -267,8 +271,6 @@ def _prepare_dataloader( sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode) dataloader = _update_dataloader(dataloader, sampler, mode=mode) - dataloader = self.trainer.strategy.process_dataloader(dataloader) - return dataloader def _resolve_sampler( @@ -331,14 +333,13 @@ def _reset_eval_dataloader( Returns: Tuple (num_batches, dataloaders) """ - assert mode.evaluating or mode == RunningStage.PREDICTING - # always get the loaders first so we can count how many there are - dataloaders = self._request_dataloader(mode) + dataloaders = self._request_dataloader() if self.trainer.overfit_batches > 0: dataloaders = self._resolve_overfit_batches(dataloaders, mode) + # TODO(carmocca): list conversion shouldn't be forced if not isinstance(dataloaders, list): dataloaders = [dataloaders] # type: ignore[assignment] @@ -410,20 +411,22 @@ def _reset_eval_dataloader( return loader_num_batches, dataloaders - def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS: + def _request_dataloader(self) -> Union[TRAIN_DATALOADERS, 
EVAL_DATALOADERS]: """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage. Returns: The requested dataloader """ - source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source") + loop = self.trainer._active_loop + if loop is None: + raise RuntimeError("No active loop running") with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning. # Also, it records all attribute setting and deletion using patched `__setattr__` and `__delattr__` # methods so that the re-instantiated object is as close to the original as possible. - dataloader = source.dataloader() + dataloader = loop._data_source.dataloader() if isinstance(dataloader, tuple): dataloader = list(dataloader) self.trainer.strategy.barrier("get_dataloaders") @@ -477,8 +480,8 @@ class _DataLoaderSource: The source can be - 1. from a ``*_datalaoder()`` method on the :class:`~lightning.pytorch.core.module.LightningModule`, - 2. from a ``*_datalaoder()`` method on the :class:`~lightning.pytorch.core.datamodule.LightningDataModule`, + 1. from a ``*_dataloader()`` method on the :class:`~lightning.pytorch.core.module.LightningModule`, + 2. from a ``*_dataloader()`` method on the :class:`~lightning.pytorch.core.datamodule.LightningDataModule`, 3. a direct instance of a :class:`~torch.utils.data.DataLoader` or supported collections thereof. Arguments: diff --git a/src/lightning/pytorch/trainer/supporters.py b/src/lightning/pytorch/trainer/supporters.py index ffa56538adc60..ef062e3c7a745 100644 --- a/src/lightning/pytorch/trainer/supporters.py +++ b/src/lightning/pytorch/trainer/supporters.py @@ -281,6 +281,7 @@ def reset(self) -> None: def _update_index(self, dataloader: Iterable, index: int) -> None: # mutation needs to be done using this method to avoid stale references + # TODO(carmocca): avoid this, inefficient self._flattened[index] = dataloader self._iterables = tree_unflatten(self._flattened, self._spec) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index d94d660e4f1e8..85db7aaed8ea7 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -29,16 +29,12 @@ from weakref import proxy import torch -from lightning_utilities.core.apply_func import apply_to_collection from torch.optim import Optimizer -from torch.utils.data import DataLoader import lightning.pytorch as pl from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars from lightning.fabric.utilities.cloud_io import get_filesystem -from lightning.fabric.utilities.data import _auto_add_worker_init_fn from lightning.fabric.utilities.types import _PATH -from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.accelerators import Accelerator from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase from lightning.pytorch.core.datamodule import LightningDataModule @@ -46,7 +42,7 @@ from lightning.pytorch.loggers.tensorboard import TensorBoardLogger from lightning.pytorch.loggers.utilities import _log_hyperparams from lightning.pytorch.loops import _PredictionLoop, _TrainingEpochLoop -from lightning.pytorch.loops.dataloader.evaluation_loop import _EvaluationLoop +from lightning.pytorch.loops.evaluation_loop import 
_EvaluationLoop from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.loops.utilities import _parse_loop_limits, _reset_progress from lightning.pytorch.plugins import PLUGIN_INPUT, PrecisionPlugin @@ -67,14 +63,13 @@ from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _PBAR_DICT, _ResultCollection from lightning.pytorch.trainer.connectors.signal_connector import SignalConnector from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus -from lightning.pytorch.trainer.supporters import _LITERAL_SUPPORTED_MODES, CombinedLoader +from lightning.pytorch.trainer.supporters import _LITERAL_SUPPORTED_MODES from lightning.pytorch.utilities import GradClipAlgorithmType, parsing from lightning.pytorch.utilities.argparse import _defaults_from_env_vars from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile -from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import is_overridden -from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn +from lightning.pytorch.utilities.rank_zero import rank_zero_info from lightning.pytorch.utilities.seed import isolate_rng from lightning.pytorch.utilities.types import ( _EVALUATE_OUTPUT, @@ -392,18 +387,14 @@ def _setup_on_init(self) -> None: self.should_stop = False self.state = TrainerState() - self.num_training_batches = float("inf") - - self.train_dataloader: Optional[Union[CombinedLoader, TRAIN_DATALOADERS]] = None + # TODO(carmocca): move these to the loops + self.num_training_batches = float("inf") self.num_sanity_val_batches: List[Union[int, float]] = [] self.num_test_batches: List[Union[int, float]] = [] self.num_val_batches: List[Union[int, float]] = [] self.num_predict_batches: List[Union[int, float]] = [] - self.test_dataloaders: Optional[List[DataLoader]] = None - self.val_dataloaders: Optional[List[DataLoader]] = None - self.predict_dataloaders: Optional[List[DataLoader]] = None self._last_train_dl_reload_epoch = float("-inf") self._last_val_dl_reload_epoch = float("-inf") @@ -813,13 +804,10 @@ def _run( | || {self._run_stage} || FLOW | || - {self._run_train} || DIRECTION - or {self._run_evaluate} || - or {self._run_predict} || + loops || DIRECTION | || results \/ This is used to guide readers to the core loops: train, test, predict. 
- {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :) """ # ---------------------------- @@ -887,7 +875,7 @@ def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: if self.evaluating: return self._run_evaluate() if self.predicting: - return self._run_predict() + return self.predict_loop.run() self._run_train() def _pre_training_routine(self) -> None: @@ -914,19 +902,12 @@ def _run_train(self) -> None: def _run_evaluate(self) -> _EVALUATE_OUTPUT: assert self.evaluating - # reload dataloaders - self._evaluation_loop._reload_evaluation_dataloaders() - with self.profiler.profile(f"run_{self.state.stage}_evaluation"): eval_loop_results = self._evaluation_loop.run() # remove the tensors from the eval results return convert_tensors_to_scalars(eval_loop_results) - def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: - self.reset_predict_dataloader(self.lightning_module) - return self.predict_loop.run() - def _run_sanity_check(self) -> None: val_loop = self.fit_loop.epoch_loop.val_loop @@ -949,12 +930,6 @@ def _run_sanity_check(self) -> None: call._call_callback_hooks(self, "on_sanity_check_start") - # reload dataloaders - val_loop._reload_evaluation_dataloaders() - self.num_sanity_val_batches = [ - min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches - ] - # run eval step val_loop.run() @@ -977,173 +952,6 @@ def __setup_profiler(self) -> None: self.profiler._lightning_module = proxy(self.lightning_module) self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir) - """ - Data loading methods - """ - - def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the train dataloader and initialises required variables (number of batches, when to validate, - etc.). - - Args: - model: The ``LightningModule`` if calling this outside of the trainer scope. 
- """ - source = self._data_connector._train_dataloader_source - pl_module = model or self.lightning_module - has_step = is_overridden("training_step", pl_module) - enable_training = self.limit_train_batches > 0 - if not (source.is_defined() and has_step and enable_training): - return - - self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING) - - if self.overfit_batches > 0: - self.train_dataloader = self._data_connector._resolve_overfit_batches( - self.train_dataloader, mode=RunningStage.TRAINING - ) - - # automatically add samplers - self.train_dataloader = apply_to_collection( - self.train_dataloader, - (DataLoader, CombinedLoader), - self._data_connector._prepare_dataloader, - mode=RunningStage.TRAINING, - ) - loaders = ( - self.train_dataloader.iterables - if isinstance(self.train_dataloader, CombinedLoader) - else self.train_dataloader - ) - - # check the workers recursively - apply_to_collection(loaders, DataLoader, self._data_connector._worker_check, "train_dataloader") - - # add worker_init_fn for correct seeding in worker processes - apply_to_collection(loaders, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank) - - # wrap the sequence of train iterables to a CombinedLoader object for computing the num_training_batches - if not isinstance(self.train_dataloader, CombinedLoader): - self.train_dataloader = CombinedLoader(loaders, self._data_connector.multiple_trainloader_mode) - - module = model or self.lightning_module or self.datamodule - orig_train_batches = self.num_training_batches = ( - len(self.train_dataloader) - if has_len_all_ranks(self.train_dataloader, self.strategy, module) - else float("inf") - ) - if orig_train_batches == 0: - return - - # store epoch of dataloader reset for reload_dataloaders_every_n_epochs - self._last_train_dl_reload_epoch = self.current_epoch - - if isinstance(self.limit_train_batches, int): - self.num_training_batches = min(orig_train_batches, self.limit_train_batches) - elif self.num_training_batches != float("inf"): - self.num_training_batches = int(orig_train_batches * self.limit_train_batches) - elif self.limit_train_batches != 1.0: - raise MisconfigurationException( - "When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int." - "An int specifies `num_training_batches` to use." - ) - - if isinstance(self.val_check_interval, int): - self.val_check_batch = self.val_check_interval - if self.val_check_batch > self.num_training_batches and self.check_val_every_n_epoch is not None: - raise ValueError( - f"`val_check_interval` ({self.val_check_interval}) must be less than or equal " - f"to the number of the training batches ({self.num_training_batches}). " - "If you want to disable validation set `limit_val_batches` to 0.0 instead." - "If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`." - ) - else: - if not has_len_all_ranks(self.train_dataloader, self.strategy, module): - if self.val_check_interval == 1.0: - self.val_check_batch = float("inf") - else: - raise MisconfigurationException( - "When using an IterableDataset for `train_dataloader`," - " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies" - " checking validation every k training batches." 
- ) - else: - self.val_check_batch = int(self.num_training_batches * self.val_check_interval) - self.val_check_batch = max(1, self.val_check_batch) - - if self.loggers and self.num_training_batches < self.log_every_n_steps: - rank_zero_warn( - f"The number of training batches ({self.num_training_batches}) is smaller than the logging interval" - f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if" - " you want to see logs for the training epoch.", - category=PossibleUserWarning, - ) - - if ( - self.num_training_batches == 0 - and self.limit_train_batches > 0.0 - and isinstance(self.limit_train_batches, float) - and orig_train_batches != float("inf") - ): - min_percentage = 1.0 / orig_train_batches - raise MisconfigurationException( - f"You requested to check {self.limit_train_batches} of the `train_dataloader` but" - f" {self.limit_train_batches} * {orig_train_batches} < 1. Please increase the" - f" `limit_train_batches` argument. Try at least" - f" `limit_train_batches={min_percentage}`" - ) - - def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the validation dataloader and determines the number of batches. - - Args: - model: The ``LightningModule`` if called outside of the trainer scope. - """ - source = self._data_connector._val_dataloader_source - pl_module = self.lightning_module or model - has_step = is_overridden("validation_step", pl_module) - enable_validation = self.limit_val_batches > 0 - if source.is_defined() and has_step and enable_validation: - # store epoch of dataloader reset for reload_dataloaders_every_n_epochs - # it should not reload again if it has already reloaded during sanity_check - if self.state.fn == TrainerFn.FITTING and ( - (self.sanity_checking and self.fit_loop.epoch_loop._should_check_val_epoch()) - or not self.sanity_checking - ): - self._last_val_dl_reload_epoch = self.current_epoch - - self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader( - RunningStage.VALIDATING, model=pl_module - ) - - def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the test dataloader and determines the number of batches. - - Args: - model: The ``LightningModule`` if called outside of the trainer scope. - """ - source = self._data_connector._test_dataloader_source - pl_module = self.lightning_module or model - has_step = is_overridden("test_step", pl_module) - enable_testing = self.limit_test_batches > 0 - if source.is_defined() and has_step and enable_testing: - self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader( - RunningStage.TESTING, model=pl_module - ) - - def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the predict dataloader and determines the number of batches. - - Args: - model: The ``LightningModule`` if called outside of the trainer scope. 
- """ - source = self._data_connector._predict_dataloader_source - pl_module = self.lightning_module or model - enable_prediction = self.limit_predict_batches > 0 - if source.is_defined() and enable_prediction: - self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader( - RunningStage.PREDICTING, model=pl_module - ) - """ Accelerator properties """ @@ -1271,7 +1079,7 @@ def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: def enable_validation(self) -> bool: """Check if we should run validation during training.""" return ( - self._data_connector._val_dataloader_source.is_defined() + self.fit_loop.epoch_loop.val_loop._data_source.is_defined() and is_overridden("validation_step", self.lightning_module) and self.limit_val_batches > 0 ) @@ -1479,6 +1287,28 @@ def is_last_batch(self) -> bool: """Whether trainer is executing the last batch.""" return self.fit_loop.epoch_loop.batch_progress.is_last_batch + @property + def train_dataloader(self) -> TRAIN_DATALOADERS: + if (combined_loader := self.fit_loop._combined_loader) is not None: + return combined_loader.iterables + + @property + def val_dataloaders(self) -> EVAL_DATALOADERS: + if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None: + return combined_loader.iterables + elif (combined_loader := self.validate_loop._combined_loader) is not None: + return combined_loader.iterables + + @property + def test_dataloaders(self) -> EVAL_DATALOADERS: + if (combined_loader := self.test_loop._combined_loader) is not None: + return combined_loader.iterables + + @property + def predict_dataloaders(self) -> EVAL_DATALOADERS: + if (combined_loader := self.predict_loop._combined_loader) is not None: + return combined_loader.iterables + @property def _evaluation_loop(self) -> _EvaluationLoop: if self.state.fn == TrainerFn.FITTING: @@ -1565,7 +1395,10 @@ def configure_optimizers(self): if self.train_dataloader is None: rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.") - self.reset_train_dataloader() + stage = self.state.stage + self.training = True + self.fit_loop.setup_data() + self.state.stage = stage total_batches = self.num_training_batches diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 00413ec8242be..dd590cc34701c 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -27,7 +27,6 @@ def _scale_batch_size( trainer: "pl.Trainer", - model: "pl.LightningModule", mode: str = "power", steps_per_trial: int = 3, init_val: int = 2, @@ -39,7 +38,6 @@ def _scale_batch_size( Args: trainer: A Trainer instance. - model: Model to tune. mode: Search strategy to update the batch size: - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error. 
@@ -80,9 +78,9 @@ def _scale_batch_size( new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) if mode == "power": - new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, params) + new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params) elif mode == "binsearch": - new_size = _run_binary_scaling(trainer, model, new_size, batch_arg_name, max_trials, params) + new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params) garbage_collection_cuda() @@ -104,19 +102,17 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: "loggers": trainer.loggers, "callbacks": trainer.callbacks, } - if trainer.state.fn == "fit": - loop = trainer.fit_loop + loop = trainer._active_loop + assert loop is not None + if isinstance(loop, pl.loops._FitLoop): dumped_params["max_steps"] = trainer.max_steps dumped_params["limit_train_batches"] = trainer.limit_train_batches dumped_params["limit_val_batches"] = trainer.limit_val_batches - else: + elif isinstance(loop, pl.loops._EvaluationLoop): stage = trainer.state.stage - loop = getattr(trainer, f"{stage}_loop") assert stage is not None dumped_params["limit_eval_batches"] = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches") - - if hasattr(loop, "verbose"): - dumped_params["loop_verbose"] = loop.verbose + dumped_params["loop_verbose"] = loop.verbose dumped_params["loop_state_dict"] = deepcopy(loop.state_dict()) return dumped_params @@ -128,18 +124,17 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N trainer.logger = DummyLogger() if trainer.logger is not None else None trainer.callbacks = [] - if trainer.state.fn == "fit": + loop = trainer._active_loop + assert loop is not None + if isinstance(loop, pl.loops._FitLoop): trainer.limit_train_batches = 1.0 trainer.limit_val_batches = steps_per_trial trainer.fit_loop.max_steps = steps_per_trial - else: + elif isinstance(loop, pl.loops._EvaluationLoop): stage = trainer.state.stage - loop = getattr(trainer, f"{stage}_loop") assert stage is not None setattr(trainer, f"limit_{stage.dataloader_prefix}_batches", steps_per_trial) - - if hasattr(loop, "verbose"): - loop.verbose = False + loop.verbose = False def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: @@ -147,26 +142,29 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) trainer.loggers = params["loggers"] trainer.callbacks = params["callbacks"] - if trainer.state.fn == "fit": - loop = trainer.fit_loop + loop = trainer._active_loop + assert loop is not None + if isinstance(loop, pl.loops._FitLoop): loop.max_steps = params["max_steps"] trainer.limit_train_batches = params["limit_train_batches"] trainer.limit_val_batches = params["limit_val_batches"] - else: + elif isinstance(loop, pl.loops._EvaluationLoop): stage = trainer.state.stage - loop = getattr(trainer, f"{stage}_loop") assert stage is not None setattr(trainer, f"limit_{stage.dataloader_prefix}_batches", params["limit_eval_batches"]) loop.load_state_dict(deepcopy(params["loop_state_dict"])) loop.restarting = False - if "loop_verbose" in params: + if isinstance(loop, pl.loops._EvaluationLoop) and "loop_verbose" in params: loop.verbose = params["loop_verbose"] + # make sure the loop's state is reset + _reset_dataloaders(trainer) + loop.reset() + def _run_power_scaling( trainer: "pl.Trainer", - pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, @@ -190,7 
+188,7 @@ def _run_power_scaling( break # Force the train dataloader to reset as the batch size has changed - _reset_dataloaders(trainer, pl_module) + _reset_dataloaders(trainer) any_success = True except RuntimeError as exception: if is_oom_error(exception): @@ -198,7 +196,7 @@ def _run_power_scaling( garbage_collection_cuda() new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc="failed") # Force the train dataloader to reset as the batch size has changed - _reset_dataloaders(trainer, pl_module) + _reset_dataloaders(trainer) if any_success: break else: @@ -209,7 +207,6 @@ def _run_power_scaling( def _run_binary_scaling( trainer: "pl.Trainer", - pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, @@ -249,7 +246,7 @@ def _run_binary_scaling( break # Force the train dataloader to reset as the batch size has changed - _reset_dataloaders(trainer, pl_module) + _reset_dataloaders(trainer) except RuntimeError as exception: # Only these errors should trigger an adjustment @@ -262,7 +259,7 @@ def _run_binary_scaling( new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed") # Force the train dataloader to reset as the batch size has changed - _reset_dataloaders(trainer, pl_module) + _reset_dataloaders(trainer) if high - low <= 1: break @@ -300,28 +297,13 @@ def _adjust_batch_size( if desc: rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") - if trainer.state.fn == "fit": - from lightning.pytorch.trainer.supporters import CombinedLoader - - if trainer.train_dataloader is None: - trainer.reset_train_dataloader() - - assert isinstance(trainer.train_dataloader, CombinedLoader) - if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer): - # at this moment, `train_dataloader` is already a CombinedLoader. len can return a size or infinity - new_size = min(new_size, len(trainer.train_dataloader.dataset)) # type: ignore[arg-type] - else: - stage = trainer.state.stage - assert stage is not None - dataloaders = getattr(trainer, f"{stage.dataloader_prefix}_dataloaders") - if dataloaders is None: - _reset_dataloaders(trainer, model) - - dataloaders = getattr(trainer, f"{stage.dataloader_prefix}_dataloaders") - assert dataloaders is not None - # TODO: should we consider all the eval dataloaders here? 
- if not _is_valid_batch_size(new_size, dataloaders[0], trainer): - new_size = min(new_size, len(dataloaders[0].dataset)) + loop = trainer._active_loop + assert loop is not None + loop.setup_data() + combined_loader = loop._combined_loader + assert combined_loader is not None + if not _is_valid_batch_size(new_size, combined_loader, trainer): + new_size = min(new_size, len(combined_loader.dataset)) # type: ignore[arg-type] changed = new_size != batch_size lightning_setattr(model, batch_arg_name, new_size) @@ -336,22 +318,16 @@ def _is_valid_batch_size(batch_size: int, dataloader: Iterable, trainer: "pl.Tra return not has_len or batch_size <= len(dataloader) # type: ignore[arg-type] -def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if trainer.state.fn == "fit": - trainer.reset_train_dataloader(pl_module) - else: - stage = trainer.state.stage - assert stage is not None - reset_fn = getattr(trainer, f"reset_{stage.dataloader_prefix}_dataloader") - reset_fn(pl_module) +def _reset_dataloaders(trainer: "pl.Trainer") -> None: + loop = trainer._active_loop + assert loop is not None + loop._combined_loader = None # force a reload + loop.setup_data() def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: - if trainer.state.fn == "fit": - loop = trainer.fit_loop - else: - loop = getattr(trainer, f"{trainer.state.stage}_loop") - + loop = trainer._active_loop + assert loop is not None loop.load_state_dict(deepcopy(params["loop_state_dict"])) loop.restarting = False loop.run() diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index bccba5ad9f2c1..c8f6e6d93df4c 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -18,11 +18,10 @@ """ from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Generator, List, Optional, Protocol, runtime_checkable, Sequence, Type, Union +from typing import Any, Dict, Generator, List, Optional, Protocol, runtime_checkable, Type, Union import torch from torch import Tensor -from torch.utils.data import DataLoader from torchmetrics import Metric from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER, LRScheduler, ProcessGroup, ReduceLROnPlateau @@ -32,16 +31,8 @@ STEP_OUTPUT = Union[Tensor, Dict[str, Any]] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] -TRAIN_DATALOADERS = Union[ - DataLoader, - Sequence[DataLoader], - Sequence[Sequence[DataLoader]], - Sequence[Dict[str, DataLoader]], - Dict[str, DataLoader], - Dict[str, Dict[str, DataLoader]], - Dict[str, Sequence[DataLoader]], -] -EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] +TRAIN_DATALOADERS = Any # any iterable or collection of iterables +EVAL_DATALOADERS = Any # any iterable or collection of iterables @runtime_checkable diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index 01978eaeb9c8e..dd45734d0d818 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -27,7 +27,6 @@ from lightning.pytorch.plugins import IPUPrecisionPlugin from lightning.pytorch.strategies.ipu import IPUStrategy from lightning.pytorch.trainer.states import RunningStage, TrainerFn -from lightning.pytorch.trainer.supporters import CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException from 
tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -365,7 +364,7 @@ def train_dataloader(self): assert isinstance(trainer.strategy, IPUStrategy) assert trainer.strategy.training_opts is other_options - dataloader = trainer.train_dataloader.iterables + dataloader = trainer.train_dataloader assert dataloader is model.poptorch_dataloader # exact object, was not recreated # dataloader uses the options in the model, not the strategy assert dataloader.options is model_options @@ -393,7 +392,7 @@ def test_manual_poptorch_opts(tmpdir): assert trainer.strategy.training_opts == training_opts assert trainer.strategy.inference_opts == inference_opts - dataloader = trainer.train_dataloader.iterables + dataloader = trainer.train_dataloader assert isinstance(dataloader, poptorch.DataLoader) assert dataloader.options == training_opts assert trainer.num_devices > 1 # testing this only makes sense in a distributed setting @@ -426,8 +425,6 @@ def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: val_dataloader = trainer.val_dataloaders[0] train_dataloader = trainer.train_dataloader - assert isinstance(train_dataloader, CombinedLoader) - train_dataloader = train_dataloader.iterables assert isinstance(val_dataloader, poptorch.DataLoader) assert isinstance(train_dataloader, poptorch.DataLoader) assert train_dataloader.options.replication_factor == 2 diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py index 0ab1d279861bc..70312ff24b710 100644 --- a/tests/tests_pytorch/callbacks/test_prediction_writer.py +++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py @@ -16,7 +16,6 @@ import pytest from torch.utils.data import DataLoader -import lightning.pytorch as pl from lightning.pytorch import Trainer from lightning.pytorch.callbacks import BasePredictionWriter from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -111,10 +110,10 @@ def test_prediction_writer_batch_indices(num_workers): def test_prediction_writer_partial_support_for_combined_loader(): """Test partial support for CombinedLoader: prediction works but sample indices don't get tracked.""" - pl.loops.epoch.prediction_epoch_loop.warning_cache.clear() class PredictionModel(BoringModel): def predict_dataloader(self): + # TODO(carmocca): this should work return CombinedLoader( { "a": DataLoader(RandomDataset(32, 8), batch_size=2), @@ -131,7 +130,7 @@ def predict_step(self, batch, *args, **kwargs): model = PredictionModel() writer = DummyPredictionWriter("batch_and_epoch") trainer = Trainer(callbacks=writer) - with pytest.warns(UserWarning, match="Lightning couldn't infer the indices fetched for your dataloader."): + with pytest.warns(UserWarning, match="infer the batch indices fetched from your dataloader: `CombinedLoader"): trainer.predict(model) writer.write_on_batch_end.assert_has_calls( diff --git a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py index ba305a461fd0f..8de5e45350546 100644 --- a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py @@ -39,10 +39,12 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir): trainer = Trainer(**trainer_kwargs) with patch.object( - trainer.fit_loop.epoch_loop.val_loop, "advance", wraps=trainer.fit_loop.epoch_loop.val_loop.advance - ) as advance_mocked: + 
trainer.fit_loop.epoch_loop.val_loop, + "_evaluation_step", + wraps=trainer.fit_loop.epoch_loop.val_loop._evaluation_step, + ) as step_mock: trainer.fit(model, ckpt_path=ckpt_path) - assert advance_mocked.call_count == 1 + assert step_mock.call_count == 1 @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 293bde5c08d92..5d189da6d17f5 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -14,6 +14,7 @@ from unittest import mock from unittest.mock import call, Mock +import pytest import torch from torch.utils.data.dataloader import DataLoader from torch.utils.data.sampler import BatchSampler, RandomSampler @@ -23,7 +24,7 @@ from tests_pytorch.helpers.runif import RunIf -@mock.patch("lightning.pytorch.loops.dataloader.evaluation_loop._EvaluationLoop._on_evaluation_epoch_end") +@mock.patch("lightning.pytorch.loops.evaluation_loop._EvaluationLoop._on_evaluation_epoch_end") def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end` hooks.""" @@ -66,13 +67,13 @@ def _get_dataloader(): val_dataloader = _get_dataloader() trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) # One for each epoch - assert train_dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)] + assert train_dataloader.sampler.set_epoch.mock_calls == [call(0), call(1)] # One for each epoch + sanity check - assert val_dataloader.sampler.set_epoch.call_args_list == [call(0), call(0), call(1)] + assert val_dataloader.sampler.set_epoch.mock_calls == [call(0), call(0), call(1)] val_dataloader = _get_dataloader() trainer.validate(model, val_dataloader) - assert val_dataloader.sampler.set_epoch.call_args_list == [call(2)] + assert val_dataloader.sampler.set_epoch.mock_calls == [call(2)] def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir): @@ -178,3 +179,21 @@ def validation_step(self, batch, batch_idx): enable_model_summary=False, ) trainer.fit(BoringLargeBatchModel()) + + +def test_evaluation_loop_dataloader_iter_multiple_dataloaders(tmp_path): + trainer = Trainer( + default_root_dir=tmp_path, + limit_val_batches=1, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, + ) + + class MyModel(BoringModel): + def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + ... 
+
+    model = MyModel()
+    with pytest.raises(NotImplementedError, match="dataloader_iter.*is not supported with multiple dataloaders"):
+        trainer.validate(model, {"a": [0, 1], "b": [2, 3]})
diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py
index 07c6c1507c8d3..c08659ad03652 100644
--- a/tests/tests_pytorch/loops/test_fetchers.py
+++ b/tests/tests_pytorch/loops/test_fetchers.py
@@ -243,8 +243,8 @@ def on_train_epoch_end(self):
     trainer.fit(model)
 
 
-@pytest.mark.parametrize("fn", ("validate", "test"))
-def test_fetching_dataloader_iter_running_stages(fn, tmpdir):
+@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
+def test_fetching_dataloader_iter_running_stages(fn, tmp_path):
     class TestModel(BoringModel):
         def fetch(self, data_fetcher, dataloader_iter, batch_idx):
             assert isinstance(data_fetcher, _DataLoaderIterDataFetcher)
@@ -254,19 +254,46 @@ def fetch(self, data_fetcher, dataloader_iter, batch_idx):
             return batch
 
         def validation_step(self, dataloader_iter, batch_idx):
-            batch = self.fetch(self.trainer.validate_loop._data_fetcher, dataloader_iter, batch_idx)
+            data_fetcher = self.trainer.validate_loop._data_fetcher
+            batch = self.fetch(data_fetcher, dataloader_iter, batch_idx)
             return super().validation_step(batch, batch_idx)
 
         def test_step(self, dataloader_iter, batch_idx):
-            batch = self.fetch(self.trainer.test_loop._data_fetcher, dataloader_iter, batch_idx)
+            data_fetcher = self.trainer.test_loop._data_fetcher
+            batch = self.fetch(data_fetcher, dataloader_iter, batch_idx)
+            return super().test_step(batch, batch_idx)
+
+        def predict_step(self, dataloader_iter, batch_idx):
+            data_fetcher = self.trainer.predict_loop._data_fetcher
+            batch = self.fetch(data_fetcher, dataloader_iter, batch_idx)
-            return super().test_step(batch, batch_idx)
+            return super().predict_step(batch, batch_idx)
 
     model = TestModel()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
-    if fn == "validate":
-        trainer.validate(model)
-    elif fn == "test":
-        trainer.test(model)
+    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=1)
+    trainer_fn = getattr(trainer, fn)
+    trainer_fn(model)
+
+
+@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
+def test_fetching_dataloader_iter_running_stages_multiple_dataloaders(fn, tmp_path):
+    class MyModel(BoringModel):
+        def validation_step(self, dataloader_iter, batch_idx, dataloader_idx):
+            ...
+
+        def test_step(self, dataloader_iter, batch_idx, dataloader_idx):
+            ...
+
+        def predict_step(self, dataloader_iter, batch_idx, dataloader_idx):
+            ...
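For context on the `fetch` helper asserted above: the loops inspect the step signature, and declaring a `dataloader_iter` parameter is what swaps the default prefetching fetcher for `_DataLoaderIterDataFetcher`. A rough probe of that selection, assuming `_DataLoaderIterDataFetcher` can be imported from `lightning.pytorch.loops.fetchers`:

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel
    from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher


    class ProbeModel(BoringModel):
        def test_step(self, dataloader_iter, batch_idx):
            # the signature opted into the iterator, so the loop installed the
            # iterator-passing fetcher instead of the prefetching one
            assert isinstance(self.trainer.test_loop._data_fetcher, _DataLoaderIterDataFetcher)
            return super().test_step(next(dataloader_iter), batch_idx)


    Trainer(fast_dev_run=1).test(ProbeModel())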
+
+    def dataloaders():
+        return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]
+
+    model = MyModel()
+    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=1)
+    trainer_fn = getattr(trainer, fn)
+    with pytest.raises(NotImplementedError, match="dataloader_iter.*is not supported with multiple dataloaders"):
+        trainer_fn(model, dataloaders())
 
 
 class DummyWaitable:
@@ -463,12 +490,12 @@ def val_dataloader(self):
     assert isinstance(profiler, SimpleProfiler)
 
     # validation
-    for i in range(2):
-        key = f"[_EvaluationEpochLoop].val_dataloader_idx_{i}_next"
-        assert key in profiler.recorded_durations
-        durations = profiler.recorded_durations[key]
-        assert len(durations) == fast_dev_run
-        assert all(d > 0 for d in durations)
+    key = "[_EvaluationLoop].val_next"
+    assert key in profiler.recorded_durations
+    durations = profiler.recorded_durations[key]
+    # +1 because one extra batch is fetched before the loop breaks on the fast_dev_run limit
+    assert len(durations) == 2 * fast_dev_run + 1
+    assert all(d > 0 for d in durations)
     # training
     key = "[_TrainingEpochLoop].train_dataloader_next"
     assert key in profiler.recorded_durations
@@ -476,16 +503,16 @@ def val_dataloader(self):
     assert len(durations) == fast_dev_run
     assert all(d > 0 for d in durations)
     # test
-    key = "[_EvaluationEpochLoop].val_dataloader_idx_0_next"
+    key = "[_EvaluationLoop].test_next"
     assert key in profiler.recorded_durations
     durations = profiler.recorded_durations[key]
-    assert len(durations) == fast_dev_run
+    assert len(durations) == fast_dev_run + 1
     assert all(d > 0 for d in durations)
     # predict
-    key = "[_PredictionEpochLoop].predict_dataloader_idx_0_next"
+    key = "[_PredictionLoop].predict_next"
     assert key in profiler.recorded_durations
     durations = profiler.recorded_durations[key]
-    assert len(durations) == fast_dev_run
+    assert len(durations) == fast_dev_run + 1
     assert all(d > 0 for d in durations)
 
     # now test profiling when the dataloader_iter is polled manually
diff --git a/tests/tests_pytorch/loops/test_loop_state_dict.py b/tests/tests_pytorch/loops/test_loop_state_dict.py
index 8aec34f3327f0..8d1cd37e6ed0e 100644
--- a/tests/tests_pytorch/loops/test_loop_state_dict.py
+++ b/tests/tests_pytorch/loops/test_loop_state_dict.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
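The expectations below encode the flattened hierarchy: with the evaluation and prediction epoch loops removed, `batch_progress` lives directly on each top-level loop and the per-dataloader `dataloader_progress` entries are gone. A quick way to see the new layout, using the same private connector API the test calls (key names as asserted below):

    from lightning.pytorch import Trainer

    state = Trainer()._checkpoint_connector._get_loops_state_dict()
    # progress is tracked once per loop, accumulated across dataloaders
    assert "batch_progress" in state["validate_loop"]
    assert "dataloader_progress" not in state["validate_loop"]
    assert "epoch_loop.val_loop.batch_progress" in state["fit_loop"]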
-from unittest.mock import Mock from lightning.pytorch.loops import _FitLoop from lightning.pytorch.trainer.trainer import Trainer @@ -31,7 +30,6 @@ def test_loops_state_dict(): def test_loops_state_dict_structure(): trainer = Trainer() - trainer.train_dataloader = Mock() state_dict = trainer._checkpoint_connector._get_loops_state_dict() expected = { "fit_loop": { @@ -62,15 +60,10 @@ def test_loops_state_dict_structure(): }, }, "epoch_loop.val_loop.state_dict": {}, - "epoch_loop.val_loop.dataloader_progress": { - "total": {"ready": 0, "completed": 0}, - "current": {"ready": 0, "completed": 0}, - }, - "epoch_loop.val_loop.epoch_loop.state_dict": {}, - "epoch_loop.val_loop.epoch_loop.batch_progress": { - # number of batches across validation runs per epoch + "epoch_loop.val_loop.batch_progress": { + # number of batches across validation runs per epoch across dataloaders "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - # number of batches for this validation run + # number of batches for this validation run across dataloaders "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "is_last_batch": False, }, @@ -81,21 +74,17 @@ def test_loops_state_dict_structure(): }, "validate_loop": { "state_dict": {}, - "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, - "epoch_loop.state_dict": {}, - "epoch_loop.batch_progress": { - # total batches run by `validate` + "batch_progress": { + # total batches run by `validate` across dataloaders "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - # number of batches run by this `validate` call + # number of batches run by this `validate` call across dataloaders "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "is_last_batch": False, }, }, "test_loop": { "state_dict": {}, - "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, - "epoch_loop.state_dict": {}, - "epoch_loop.batch_progress": { + "batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "is_last_batch": False, @@ -103,9 +92,7 @@ def test_loops_state_dict_structure(): }, "predict_loop": { "state_dict": {}, - "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, - "epoch_loop.state_dict": {}, - "epoch_loop.batch_progress": { + "batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 322c3619a9cc1..ab32150d368cc 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -226,13 +226,6 @@ def val_dataloader(self): ckpt_path = str(tmpdir / "on_exception.ckpt") checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"] - total_dataloader = stop_epoch * n_dataloaders + stop_dataloader - expected = { - "total": {"ready": total_dataloader + 1, "completed": total_dataloader}, - "current": {"ready": stop_dataloader + 1, "completed": stop_dataloader}, - } - assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected - trainer.fit_loop.load_state_dict(checkpoint) # `nbe_`: non-breaking epoch, as in, no exception will be raised. 
`be_`: breaking epoch @@ -248,14 +241,14 @@ def val_dataloader(self): "completed": total_val_batch, }, "current": { - "ready": stop_batch + 1, - "started": stop_batch + 1, - "processed": stop_batch, - "completed": stop_batch, + "ready": total_val_batch + 1, + "started": total_val_batch + 1, + "processed": total_val_batch, + "completed": total_val_batch, }, "is_last_batch": False, } - assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected + assert trainer.fit_loop.epoch_loop.val_loop.batch_progress.state_dict() == expected @pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3)) @@ -387,9 +380,7 @@ def training_step(self, batch, batch_idx): }, }, "epoch_loop.val_loop.state_dict": ANY, - "epoch_loop.val_loop.dataloader_progress": ANY, - "epoch_loop.val_loop.epoch_loop.state_dict": ANY, - "epoch_loop.val_loop.epoch_loop.batch_progress": ANY, + "epoch_loop.val_loop.batch_progress": ANY, } assert checkpoint["loops"]["fit_loop"] == expected @@ -527,9 +518,7 @@ def train_dataloader(self): }, }, "epoch_loop.val_loop.state_dict": ANY, - "epoch_loop.val_loop.dataloader_progress": ANY, - "epoch_loop.val_loop.epoch_loop.state_dict": ANY, - "epoch_loop.val_loop.epoch_loop.batch_progress": ANY, + "epoch_loop.val_loop.batch_progress": ANY, } assert checkpoint["loops"]["fit_loop"] == expected @@ -675,19 +664,19 @@ def val_dataloader(self): } val_per_epoch = int(1 // val_check_interval) - assert state_dict["epoch_loop.val_loop.dataloader_progress"] == { - "total": {"ready": n_val_dataloaders * val_per_epoch, "completed": n_val_dataloaders * val_per_epoch}, - "current": {"ready": n_val_dataloaders, "completed": n_val_dataloaders}, - } - - assert state_dict["epoch_loop.val_loop.epoch_loop.batch_progress"] == { + assert state_dict["epoch_loop.val_loop.batch_progress"] == { "total": { "ready": n_val_dataloaders * val_per_epoch * n_batches, "started": n_val_dataloaders * val_per_epoch * n_batches, "processed": n_val_dataloaders * val_per_epoch * n_batches, "completed": n_val_dataloaders * val_per_epoch * n_batches, }, - "current": {"ready": n_batches, "completed": n_batches, "started": n_batches, "processed": n_batches}, + "current": { + "ready": n_val_dataloaders * n_batches, + "started": n_val_dataloaders * n_batches, + "processed": n_val_dataloaders * n_batches, + "completed": n_val_dataloaders * n_batches, + }, "is_last_batch": True, } @@ -724,7 +713,7 @@ def val_dataloader(self): "is_last_batch": val_check_interval == 1, } - val_batch_progress = "epoch_loop.val_loop.epoch_loop.batch_progress" + val_batch_progress = "epoch_loop.val_loop.batch_progress" # "nb_": non-breaking nb_total_val_batch = stop_dataloader * n_batches assert checkpoint[val_batch_progress] == { @@ -735,10 +724,10 @@ def val_dataloader(self): "completed": nb_total_val_batch + stop_batch, }, "current": { - "ready": stop_batch + 1, - "started": stop_batch + 1, - "processed": stop_batch, - "completed": stop_batch, + "ready": nb_total_val_batch + stop_batch + 1, + "started": nb_total_val_batch + stop_batch + 1, + "processed": nb_total_val_batch + stop_batch, + "completed": nb_total_val_batch + stop_batch, }, "is_last_batch": False, } @@ -762,10 +751,6 @@ def val_dataloader(self): assert state_dict_after_restart["epoch_loop.batch_progress"] == expected["epoch_loop.batch_progress"] - val_dl_progress = "epoch_loop.val_loop.dataloader_progress" - expected[val_dl_progress]["total"]["ready"] += 1 - assert state_dict_after_restart[val_dl_progress] == expected[val_dl_progress] - 
expected[val_batch_progress]["total"]["ready"] += 1 expected[val_batch_progress]["total"]["started"] += 1 assert state_dict_after_restart[val_batch_progress] == expected[val_batch_progress] @@ -825,6 +810,5 @@ def on_train_epoch_end(self, trainer, *_): assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else max_epochs) # on sanity checking end, the workers are being deleted too. - assert val_dataloader.count_shutdown_workers == 2 if persistent_workers else (3 if should_fail else max_epochs + 1) - assert train_dataloader._iterator is None - assert val_dataloader._iterator is None + expected = 2 if persistent_workers else (3 if should_fail else max_epochs + 1) + assert val_dataloader.count_shutdown_workers == expected diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index 1b5e05c502e5a..6bf3bdd9f9456 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools from unittest import mock from unittest.mock import call +import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel @@ -63,3 +66,37 @@ def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path): with mock.patch("lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper.set_epoch") as set_epoch_mock: trainer.predict(model) assert set_epoch_mock.mock_calls == [call(2)] + + +def test_prediction_loop_with_iterable_dataset(tmp_path): + class MyModel(BoringModel): + def predict_step(self, batch, batch_idx, dataloader_idx=0): + return (batch, batch_idx, dataloader_idx) + + model = MyModel() + trainer = Trainer( + default_root_dir=tmp_path, + limit_predict_batches=3, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, + ) + preds = trainer.predict(model, itertools.count()) + assert preds == [(0, 0, 0), (1, 1, 0), (2, 2, 0)] + + preds = trainer.predict(model, [itertools.count(), itertools.count()]) + assert preds == [[(0, 0, 0), (1, 1, 0), (2, 2, 0)], [(0, 0, 1), (1, 1, 1), (2, 2, 1)]] + + # TODO(carmocca): this shouldn't raise + with pytest.raises(ValueError, match="Mismatch in number of limits"): + trainer.predict(model, {"a": [0, 1], "b": [2, 3]}) + with pytest.raises(ValueError, match="Mismatch in number of limits"): + trainer.predict(model, [0, 1, 2]) + + class MyModel(BoringModel): + def predict_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + ... + + model = MyModel() + with pytest.raises(NotImplementedError, match="dataloader_iter.*is not supported with multiple dataloaders"): + trainer.predict(model, {"a": [0, 1], "b": [2, 3]}) diff --git a/tests/tests_pytorch/strategies/test_single_device_strategy.py b/tests/tests_pytorch/strategies/test_single_device_strategy.py index 5afc2cb5c2783..c27d9322ed67e 100644 --- a/tests/tests_pytorch/strategies/test_single_device_strategy.py +++ b/tests/tests_pytorch/strategies/test_single_device_strategy.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
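Several tests from here on wire the trainer up by hand in place of the removed `trainer.reset_*_dataloader()` helpers. The recipe is the same each time; a condensed sketch of the sequence (the state values mirror what the tests set explicitly):

    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

    trainer = Trainer()
    model = BoringModel()
    trainer.strategy.connect(model)
    trainer._data_connector.attach_dataloaders(model, train_dataloaders=DataLoader(RandomDataset(32, 64)))
    # the top-level loop now owns its data source and builds it in `setup_data()`
    trainer.state.fn = "fit"
    trainer.training = True
    trainer.fit_loop.setup_data()
    # the property returns the user's loader directly, not a `CombinedLoader` wrapper
    assert trainer.train_dataloader is not None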
import pickle +from unittest.mock import Mock +import pytest import torch +from torch.utils.data import DataLoader from lightning.pytorch import Trainer from lightning.pytorch.core.optimizer import LightningOptimizer -from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.strategies import SingleDeviceStrategy +from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf @@ -75,3 +79,63 @@ def test_strategy_pickle(): strategy_reloaded = pickle.loads(state) # loading restores the lightning optimizers assert isinstance(strategy_reloaded._lightning_optimizers[0], LightningOptimizer) + + +class BoringModelNoDataloaders(BoringModel): + def train_dataloader(self): + raise NotImplementedError + + def val_dataloader(self): + raise NotImplementedError + + def test_dataloader(self): + raise NotImplementedError + + def predict_dataloader(self): + raise NotImplementedError + + +_loader = DataLoader(RandomDataset(32, 64)) +_loader_no_len = CustomNotImplementedErrorDataloader(_loader) + + +@pytest.mark.parametrize( + ("keyword", "value"), + ( + ("train_dataloaders", _loader_no_len), + ("val_dataloaders", _loader_no_len), + ("test_dataloaders", _loader_no_len), + ("predict_dataloaders", _loader_no_len), + ("val_dataloaders", [_loader, _loader_no_len]), + ), +) +def test_process_dataloader_gets_called_as_expected(keyword, value, monkeypatch): + trainer = Trainer() + model = BoringModelNoDataloaders() + strategy = SingleDeviceStrategy(accelerator=Mock()) + strategy.connect(model) + trainer._accelerator_connector.strategy = strategy + process_dataloader_mock = Mock() + monkeypatch.setattr(strategy, "process_dataloader", process_dataloader_mock) + + if "train" in keyword: + trainer.state.fn = "fit" + trainer.training = True + fn = trainer.fit_loop.setup_data + elif "val" in keyword: + trainer.state.fn = "validate" + trainer.validating = True + fn = trainer.validate_loop.setup_data + elif "test" in keyword: + trainer.state.fn = "test" + trainer.testing = True + fn = trainer.test_loop.setup_data + else: + trainer.predicting = True + fn = trainer.predict_loop.setup_data + + trainer._data_connector.attach_dataloaders(model, **{keyword: value}) + fn() + + expected = len(value) if isinstance(value, list) else 1 + assert process_dataloader_mock.call_count == expected diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py index d7724464a5515..83add3849c45d 100644 --- a/tests/tests_pytorch/strategies/test_xla.py +++ b/tests/tests_pytorch/strategies/test_xla.py @@ -26,28 +26,11 @@ from tests_pytorch.helpers.runif import RunIf -class BoringModelNoDataloaders(BoringModel): - def train_dataloader(self): - raise NotImplementedError - - def val_dataloader(self): - raise NotImplementedError - - def test_dataloader(self): - raise NotImplementedError - - def predict_dataloader(self): - raise NotImplementedError - - -_loader = DataLoader(RandomDataset(32, 64)) -_loader_no_len = CustomNotImplementedErrorDataloader(_loader) - - def test_error_process_iterable_dataloader(xla_available): strategy = XLAStrategy(MagicMock()) + loader_no_len = CustomNotImplementedErrorDataloader(DataLoader(RandomDataset(32, 64))) with pytest.raises(TypeError, match="TPUs do not currently support"): - strategy.process_dataloader(_loader_no_len) + strategy.process_dataloader(loader_no_len) class BoringModelTPU(BoringModel): diff 
--git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index d2651720857b3..7691b23a0bb24 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -116,7 +116,7 @@ def on_fit_start(self): def on_train_end(self): def _get_warning_msg(): - dl = self.trainer.train_dataloader.iterables + dl = self.trainer.train_dataloader if hasattr(dl, "persistent_workers"): if self.num_workers == 0: warn_str = "Consider setting num_workers>0 and persistent_workers=True" @@ -295,7 +295,7 @@ def __iter__(self): class LoaderTestModel(BoringModel): def training_step(self, batch, batch_idx): - assert len(self.trainer.train_dataloader.iterables) == 10 + assert len(self.trainer.train_dataloader) == 10 return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): @@ -368,6 +368,8 @@ def test_error_raised_with_float_limited_eval_batches(): limit_val_batches = 1 / (dl_size + 2) trainer = Trainer(limit_val_batches=limit_val_batches) trainer._data_connector.attach_data(model) + trainer.state.fn = TrainerFn.VALIDATING + trainer.state.stage = RunningStage.VALIDATING with pytest.raises( MisconfigurationException, match=rf"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`", @@ -403,6 +405,8 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl, wa model = BoringModel() trainer._data_connector.attach_data(model, val_dataloaders=val_dl) context = pytest.warns if warns else no_warning_call + trainer.state.fn = TrainerFn.VALIDATING + trainer.state.stage = RunningStage.VALIDATING with context(PossibleUserWarning, match="recommended .* turn shuffling off for val/test/predict"): trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model) @@ -531,15 +535,18 @@ def test_eval_distributed_sampler_warning(devices, warn_context): model = BoringModel() trainer = Trainer(strategy="ddp", devices=devices, accelerator="cpu") + trainer.strategy.connect(model) trainer._data_connector.attach_data(model) trainer.state.fn = TrainerFn.VALIDATING + trainer.state.stage = RunningStage.VALIDATING with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"): - trainer.reset_val_dataloader(model) + trainer.validate_loop.setup_data() trainer.state.fn = TrainerFn.TESTING + trainer.state.stage = RunningStage.TESTING with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"): - trainer.reset_test_dataloader(model) + trainer.test_loop.setup_data() @pytest.mark.parametrize("shuffle", [True, False]) @@ -552,8 +559,11 @@ def val_dataloader(self): trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp") model = CustomModel() + trainer.strategy.connect(model) trainer._data_connector.attach_data(model) - trainer.reset_val_dataloader(model) + trainer.state.fn = TrainerFn.FITTING + trainer.state.stage = RunningStage.VALIDATING + trainer.fit_loop.epoch_loop.val_loop.setup_data() assert trainer.val_dataloaders[0].sampler.shuffle == shuffle @@ -562,13 +572,15 @@ def test_error_raised_with_insufficient_float_limit_train_dataloader(): dl = DataLoader(RandomDataset(32, batch_size * 9), batch_size=batch_size) trainer = Trainer(limit_train_batches=0.1) model = BoringModel() - + trainer.strategy.connect(model) trainer._data_connector.attach_data(model=model, train_dataloaders=dl) + trainer.state.fn = TrainerFn.FITTING + 
trainer.state.stage = RunningStage.TRAINING with pytest.raises( MisconfigurationException, match="Please increase the `limit_train_batches` argument. Try at least", ): - trainer.reset_train_dataloader(model) + trainer.fit_loop.setup_data() @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/flags/test_limit_batches.py b/tests/tests_pytorch/trainer/flags/test_limit_batches.py index 1e38896d03383..1268dd8efd029 100644 --- a/tests/tests_pytorch/trainer/flags/test_limit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_limit_batches.py @@ -62,8 +62,22 @@ def test_eval_limit_batches(stage, mode, limit_batches): trainer = Trainer(**{limit_eval_batches: limit_batches}) model.trainer = trainer + trainer.strategy.connect(model) trainer._data_connector.attach_dataloaders(model) - loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model) + + trainer.state.stage = stage + trainer.state.fn = stage.value + trainer._active_loop.setup_data() + if stage == RunningStage.VALIDATING: + loader_num_batches = trainer.num_val_batches + dataloaders = trainer.val_dataloaders + elif stage == RunningStage.TESTING: + loader_num_batches = trainer.num_test_batches + dataloaders = trainer.test_dataloaders + elif stage == RunningStage.PREDICTING: + loader_num_batches = trainer.num_predict_batches + dataloaders = trainer.predict_dataloaders + expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches assert loader_num_batches[0] == expected_batches assert len(dataloaders[0]) == len(eval_loader) diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index 36a75717c6856..b218539fd7d7f 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -74,7 +74,7 @@ def val_dataloader(self): with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"): trainer.fit(model) - assert isinstance(trainer.train_dataloader.iterables.sampler, SequentialSampler) + assert isinstance(trainer.train_dataloader.sampler, SequentialSampler) assert isinstance(trainer.val_dataloaders[0].sampler, SequentialSampler) @@ -91,18 +91,25 @@ def test_overfit_batch_limits_eval(stage, mode, overfit_batches): eval_loader = getattr(dm, f"{mode}_dataloader")() trainer = Trainer(overfit_batches=overfit_batches) model.trainer = trainer + trainer.strategy.connect(model) trainer._data_connector.attach_datamodule(model, datamodule=dm) - loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model) + trainer.state.stage = stage + trainer.state.fn = stage.value + trainer._active_loop.setup_data() + if stage == RunningStage.VALIDATING: assert ( - loader_num_batches[0] == overfit_batches + trainer.num_val_batches[0] == overfit_batches if isinstance(overfit_batches, int) else len(dm.val_dataloader()) * overfit_batches ) - else: - assert loader_num_batches[0] == len(eval_loader) - assert isinstance(dataloaders[0].sampler, SequentialSampler) + elif stage == RunningStage.TESTING: + assert trainer.num_test_batches[0] == len(eval_loader) + assert isinstance(trainer.test_dataloaders[0].sampler, SequentialSampler) + elif stage == RunningStage.PREDICTING: + assert trainer.num_predict_batches[0] == len(eval_loader) + assert isinstance(trainer.predict_dataloaders[0].sampler, SequentialSampler) @pytest.mark.parametrize("overfit_batches", 
[0.11, 4]) @@ -135,8 +142,10 @@ def train_dataloader(self): # test train loader applies correct limits trainer = Trainer(overfit_batches=overfit_batches) model.trainer = trainer + trainer.strategy.connect(model) trainer._data_connector.attach_dataloaders(model=model) - trainer.reset_train_dataloader(model) + trainer.training = True + trainer.fit_loop.setup_data() expected_batches = ( int(overfit_batches * full_train_samples) if isinstance(overfit_batches, float) else overfit_batches ) @@ -158,9 +167,10 @@ def test_distributed_sampler_with_overfit_batches(): strategy="ddp_spawn", ) model.trainer = trainer - trainer.strategy._lightning_module = model + trainer.strategy.connect(model) trainer._data_connector.attach_dataloaders(model) - trainer.reset_train_dataloader() - train_sampler = trainer.train_dataloader.iterables.sampler + trainer.training = True + trainer.fit_loop.setup_data() + train_sampler = trainer.train_dataloader.sampler assert isinstance(train_sampler, DistributedSampler) assert train_sampler.shuffle is False diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index d0aadbf756f33..4e31b877bcef7 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -29,7 +29,7 @@ from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.loggers import TensorBoardLogger -from lightning.pytorch.loops.dataloader import _EvaluationLoop +from lightning.pytorch.loops import _EvaluationLoop from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 @@ -843,7 +843,7 @@ def test_dataloader(self): ], ) def test_native_print_results(monkeypatch, inputs, expected): - import lightning.pytorch.loops.dataloader.evaluation_loop as imports + import lightning.pytorch.loops.evaluation_loop as imports monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) @@ -855,7 +855,7 @@ def test_native_print_results(monkeypatch, inputs, expected): @pytest.mark.parametrize("encoding", ["latin-1", "utf-8"]) def test_native_print_results_encodings(monkeypatch, encoding): - import lightning.pytorch.loops.dataloader.evaluation_loop as imports + import lightning.pytorch.loops.evaluation_loop as imports monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) @@ -975,28 +975,33 @@ def test_dataloader(self): ) model = CustomBoringModel() - trainer.fit(model) - trainer.validate(model) - trainer.test(model) - def get_suffix(dl_idx): return f"/dataloader_idx_{dl_idx}" if num_dataloaders == 2 else "" eval_steps = range(limit_batches) + trainer.fit(model) fit_calls = [ call(metrics={f"val_log_fit{get_suffix(dl_idx)}": float(step)}, step=step + (limit_batches * epoch)) for epoch in range(max_epochs) for dl_idx in range(num_dataloaders) for step in eval_steps ] + assert mock_log_metrics.mock_calls == fit_calls + + mock_log_metrics.reset_mock() + trainer.validate(model) validate_calls = [ call(metrics={f"val_log_validate{get_suffix(dl_idx)}": float(val)}, step=val) for dl_idx in range(num_dataloaders) for val in eval_steps ] + assert mock_log_metrics.mock_calls == validate_calls + + mock_log_metrics.reset_mock() + trainer.test(model) test_calls = [ call(metrics={f"test_log{get_suffix(dl_idx)}": 
float(val)}, step=val) for dl_idx in range(num_dataloaders) for val in eval_steps ] - assert mock_log_metrics.mock_calls == fit_calls + validate_calls + test_calls + assert mock_log_metrics.mock_calls == test_calls diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 740af109dad4a..8109520bd7610 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -46,11 +46,14 @@ def test_num_stepping_batches_raises_info_with_no_dataloaders_loaded(caplog): trainer._data_connector.attach_data(model) trainer.strategy.connect(model) - message = "to estimate number of stepping batches" - trainer.reset_train_dataloader() + # artificially setup the data + trainer.training = True + trainer.fit_loop.setup_data() + with caplog.at_level(logging.INFO): assert trainer.estimated_stepping_batches == 64 + message = "to estimate number of stepping batches" assert message not in caplog.text trainer = Trainer(max_epochs=1) diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 9e6489afc815e..3cc3d78b51e64 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -34,7 +34,6 @@ ) from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.trainer.states import RunningStage -from lightning.pytorch.trainer.supporters import CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader @@ -70,7 +69,6 @@ def test_fit_train_loader_only(tmpdir): model.test_dataloader = None model.validation_step = None - model.test_step = None trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) @@ -122,14 +120,10 @@ def test_dataloader_config_errors_init(tmpdir, dataloader_options): def test_multiple_val_dataloader(tmpdir): """Verify multiple val_dataloader.""" - model = MultiValDataLoaderBoringModel() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.3, limit_train_batches=1.0) trainer.fit(model) - # verify training completed - assert trainer.state.finished, f"Training failed with {trainer.state}" - # verify there are 2 val loaders assert len(trainer.val_dataloaders) == 2, "Multiple val_dataloaders not initiated properly" @@ -160,7 +154,7 @@ def test_train_dataloader_passed_to_fit(tmpdir): fit_options = dict(train_dataloaders=train_loader) trainer.fit(model, **fit_options) assert trainer.num_training_batches == 2 - assert trainer.train_dataloader.iterables == train_loader + assert trainer.train_dataloader == train_loader assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -433,9 +427,9 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v ) with patch.object( - trainer.fit_loop.epoch_loop.val_loop.epoch_loop, + trainer.fit_loop.epoch_loop.val_loop, "_evaluation_step", - wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop._evaluation_step, + wraps=trainer.fit_loop.epoch_loop.val_loop._evaluation_step, ) as mocked: trainer.fit(model) assert trainer.num_training_batches == limit_train_batches @@ -443,9 +437,9 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v assert 
mocked.call_count == limit_val_batches * len(trainer.val_dataloaders) with patch.object( - trainer.test_loop.epoch_loop, + trainer.test_loop, "_evaluation_step", - wraps=trainer.test_loop.epoch_loop._evaluation_step, + wraps=trainer.test_loop._evaluation_step, ) as mocked: trainer.test(model) test_dataloader_lengths = [len(x) for x in model.test_dataloader()] @@ -515,20 +509,24 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): assert len(trainer.test_dataloaders) == 1 -def test_warning_on_zero_len_dataloader(tmpdir): +def test_warning_on_zero_len_dataloader(): """Test that a warning is raised if a zero-length dataloader is defined.""" model = BoringModel() trainer = Trainer() + trainer.strategy.connect(model) train_dataloader = DataLoader(RandomDataset(32, 0)) val_dataloader = DataLoader(RandomDataset(32, 0)) trainer._data_connector.attach_data(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) + trainer.training = True with pytest.warns(UserWarning, match="Total length of `CombinedLoader` across ranks is zero"): - trainer.reset_train_dataloader(model) + trainer.fit_loop.setup_data() assert trainer.num_training_batches == 0 + trainer.state.fn = "validate" + trainer.validating = True with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero"): - trainer.reset_val_dataloader(model) + trainer.validate_loop.setup_data() assert trainer.num_val_batches == [0] @@ -1123,7 +1121,8 @@ def validation_step(self, batch, batch_idx): assert tracker.mock_calls == expected_calls -def test_dataloaders_load_only_once_passed_loaders(tmpdir): +@pytest.mark.parametrize("sanity_check", (False, True)) +def test_dataloaders_load_only_once_passed_loaders(tmpdir, sanity_check): model = BoringModel() train_dataloader = model.train_dataloader() val_dataloader = model.val_dataloader() @@ -1134,29 +1133,33 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): model.val_dataloader = None model.test_dataloader = None - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, max_epochs=3) + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0.3, + limit_val_batches=0.3, + max_epochs=3, + num_sanity_val_steps=1 if sanity_check else 0, + ) - trainer.reset_train_dataloader = Mock(wraps=trainer.reset_train_dataloader) - trainer.reset_val_dataloader = Mock(wraps=trainer.reset_val_dataloader) - trainer.reset_test_dataloader = Mock(wraps=trainer.reset_test_dataloader) + stages = [] + original_request_dataloader = trainer._data_connector._request_dataloader - tracker = Mock() - tracker.attach_mock(trainer.reset_train_dataloader, "reset_train_dataloader") - tracker.attach_mock(trainer.reset_val_dataloader, "reset_val_dataloader") - tracker.attach_mock(trainer.reset_test_dataloader, "reset_test_dataloader") + def mock_request_dataloader(): + stages.append(trainer.state.stage) + return original_request_dataloader() + + request_dataloader_mock = Mock(wraps=mock_request_dataloader) + trainer._data_connector._request_dataloader = request_dataloader_mock trainer.fit(model, train_dataloader, val_dataloader) - trainer.test(model, dataloaders=test_dataloader) + assert request_dataloader_mock.call_count == 2 - trainer.reset_train_dataloader.assert_called_once() - trainer.reset_val_dataloader.assert_called_once() - trainer.reset_test_dataloader.assert_called_once() + request_dataloader_mock.reset_mock() + trainer.test(model, dataloaders=test_dataloader) + assert request_dataloader_mock.call_count == 1 - assert 
tracker.mock_calls == [ - call.reset_val_dataloader(), - call.reset_train_dataloader(model), - call.reset_test_dataloader(), - ] + expected = ["sanity_check", "train", "test"] if sanity_check else ["train", "validate", "test"] + assert stages == expected def test_dataloaders_reset_and_attach(tmpdir): @@ -1178,12 +1181,12 @@ def test_dataloaders_reset_and_attach(tmpdir): # 1st fit trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1) - assert trainer.train_dataloader.iterables.dataset is dataloader_0.dataset + assert trainer.train_dataloader.dataset is dataloader_0.dataset assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset # 2nd fit trainer.fit_loop.max_steps += 1 trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3) - assert trainer.train_dataloader.iterables.dataset is dataloader_2.dataset + assert trainer.train_dataloader.dataset is dataloader_2.dataset assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset # 1st validate @@ -1282,7 +1285,6 @@ def predict_dataloader(self): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=5, multiple_trainloader_mode=multiple_trainloader_mode) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" trainer.test(model) preds = trainer.predict(model) @@ -1316,7 +1318,7 @@ def train_dataloader(self): return DataLoaderWrapper(loader) def on_train_batch_start(self, batch, batch_idx: int) -> None: - assert isinstance(self.trainer.train_dataloader.iterables, DataLoaderWrapper) + assert isinstance(self.trainer.train_dataloader, DataLoaderWrapper) self.on_train_batch_start_called = True def val_dataloader(self): @@ -1341,8 +1343,7 @@ def on_validation_batch_start(self, *_): def test_multiple_dataloaders_with_random_sampler_overfit_batches(num_loaders, tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx): - assert isinstance(self.trainer.train_dataloader, CombinedLoader) - assert all(isinstance(s, SequentialSampler) for s in self.trainer.train_dataloader.sampler) + assert all(isinstance(dl.sampler, SequentialSampler) for dl in self.trainer.train_dataloader) return super().training_step(batch[0], batch_idx) def _create_dataloader(self): diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py index 08af8ca7148e8..38c6dd1876f55 100644 --- a/tests/tests_pytorch/trainer/test_supporters.py +++ b/tests/tests_pytorch/trainer/test_supporters.py @@ -171,10 +171,6 @@ def test_combined_loader_raises(): with pytest.raises(ValueError, match="Unsupported mode 'testtt'"): CombinedLoader([range(10)], "testtt") - combined_loader = CombinedLoader(None, "max_size_cycle") - with pytest.raises(NotImplementedError, match="NoneType` does not define `__len__"): - len(combined_loader) - class TestIterableDataset(IterableDataset): def __init__(self, size: int = 10): @@ -399,6 +395,7 @@ def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp, mode, us replace_sampler_ddp=replace_sampler_ddp, multiple_trainloader_mode=mode, ) + trainer.strategy.connect(model) trainer._data_connector.attach_data( model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None ) @@ -409,8 +406,9 @@ def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp, mode, us if replace_sampler_ddp else expected_length_before_ddp ) - trainer.reset_train_dataloader(model=model) + trainer.state.stage = "train" + trainer.fit_loop.setup_data() assert trainer.train_dataloader is not 
None - assert isinstance(trainer.train_dataloader, CombinedLoader) - assert trainer.train_dataloader._mode == mode + assert isinstance(trainer.fit_loop._combined_loader, CombinedLoader) + assert trainer.fit_loop._combined_loader._mode == mode assert trainer.num_training_batches == expected_length_after_ddp diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 8d7b044d24a12..1a77eed0c96eb 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1090,17 +1090,7 @@ def test_invalid_gradient_clip_algo(tmpdir): @pytest.mark.parametrize("limit_val_batches", [0.0, 1, 1.0, 0.5, 5]) def test_num_sanity_val_steps(tmpdir, limit_val_batches): """Test that the number of sanity check batches is clipped to `limit_val_batches`.""" - - class CustomModel(BoringModel): - def validation_step(self, batch, batch_idx, dataloader_idx): - return super().validation_step(batch, batch_idx) - - def val_dataloader(self): - return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] - - model = CustomModel() num_sanity_val_steps = 4 - trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=num_sanity_val_steps, @@ -1109,16 +1099,19 @@ def val_dataloader(self): ) assert trainer.num_sanity_val_steps == num_sanity_val_steps - class CustomModelMixedVal(CustomModel): + class CustomModelMixedVal(BoringModel): + def validation_step(self, batch, batch_idx, dataloader_idx): + return super().validation_step(batch, batch_idx) + def val_dataloader(self): return [DataLoader(RandomDataset(32, 64), batch_size=8), DataLoader(RandomDataset(32, 64))] model = CustomModelMixedVal() with patch.object( - trainer.fit_loop.epoch_loop.val_loop.epoch_loop, + trainer.fit_loop.epoch_loop.val_loop, "_evaluation_step", - wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop._evaluation_step, + wraps=trainer.fit_loop.epoch_loop.val_loop._evaluation_step, ) as mocked: trainer.fit(model) assert mocked.call_count == sum( @@ -1145,9 +1138,9 @@ def val_dataloader(self): assert trainer.num_sanity_val_steps == float("inf") with patch.object( - trainer.fit_loop.epoch_loop.val_loop.epoch_loop, + trainer.fit_loop.epoch_loop.val_loop, "_evaluation_step", - wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop._evaluation_step, + wraps=trainer.fit_loop.epoch_loop.val_loop._evaluation_step, ) as mocked: val_dataloaders = model.val_dataloader() trainer.fit(model, val_dataloaders=val_dataloaders) @@ -1969,18 +1962,36 @@ def test_trainer_config_strategy(monkeypatch, trainer_kwargs, strategy_cls, stra ) def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_stage): dl_prefix = running_stage.dataloader_prefix - trainer_kwargs = {f"limit_{dl_prefix}_batches": 0} + argument = f"limit_{dl_prefix}_batches" + trainer_kwargs = {argument: 0} trainer = Trainer(**trainer_kwargs) model = BoringModel() + trainer.strategy.connect(model) trainer._data_connector.attach_data(model) - reset_dataloader = getattr(trainer, f"reset_{dl_prefix}_dataloader") - reset_dataloader(model) - dl = ( - trainer.train_dataloader - if running_stage == RunningStage.TRAINING - else getattr(trainer, f"{dl_prefix}_dataloaders") - ) - assert dl is None + + trainer.state.stage = running_stage + if running_stage == "train": + trainer.state.fn = "fit" + fn = trainer.fit_loop.setup_data + elif running_stage == "validate": + trainer.state.fn = "validate" + fn = trainer.validate_loop.setup_data + elif running_stage == "test": + trainer.state.fn = "test" + 
fn = trainer.test_loop.setup_data + else: + fn = trainer.predict_loop.setup_data + + # with no limit, the attribute is None + fn() + dataloader_attribute = f"{dl_prefix}_dataloader{'' if running_stage == 'train' else 's'}" + assert getattr(trainer, dataloader_attribute) is None + + # validate it would've worked if a limit was set + setattr(trainer, argument, 1) + fn() + expected = DataLoader if running_stage == "train" else list + assert isinstance(getattr(trainer, dataloader_attribute), expected) @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index 08b94a4763a8f..a5fec99febf71 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -73,10 +73,10 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_b assert model.batch_size == new_batch_size if dm_bs == -1: # datamodule batch size takes precedence - assert trainer.train_dataloader.iterables.batch_size == new_batch_size + assert trainer.train_dataloader.batch_size == new_batch_size if dm_bs not in (-1, None): assert datamodule.batch_size == new_batch_size - assert trainer.train_dataloader.iterables.batch_size == new_batch_size + assert trainer.train_dataloader.batch_size == new_batch_size @pytest.mark.parametrize("trainer_fn", ["fit", "validate", "test", "predict"]) @@ -312,7 +312,7 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method): new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs) assert advance_mocked.call_count == max_trials - assert trainer.train_dataloader.iterables.batch_size == new_batch_size + assert trainer.train_dataloader.batch_size == new_batch_size assert trainer.val_dataloaders[0].batch_size == init_batch_size @@ -333,8 +333,7 @@ def test_tuner_with_evaluation_methods(tmpdir, trainer_fn): assert trainer.global_step == 0 assert trainer.current_epoch == 0 - assert loop.dataloader_progress.current.completed == 0 - assert loop.epoch_loop.batch_progress.current.completed == 0 + assert loop.batch_progress.current.completed == 0 assert expected_scaled_batch_size == after_batch_size assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size_temp_model")) @@ -357,7 +356,7 @@ def test_batch_size_finder_callback(tmpdir, trainer_fn): loop = getattr(trainer, f"{trainer_fn}_loop") if trainer_fn == "fit": - expected_steps = trainer.train_dataloader.iterables.dataset.len // after_batch_size + expected_steps = trainer.train_dataloader.dataset.len // after_batch_size assert trainer.global_step == expected_steps * max_epochs assert trainer.current_epoch == max_epochs assert loop.epoch_loop.batch_progress.total.completed == expected_steps * max_epochs @@ -372,8 +371,7 @@ def test_batch_size_finder_callback(tmpdir, trainer_fn): expected_steps = dl.dataset.len // after_batch_size assert trainer.global_step == 0 assert trainer.current_epoch == 0 - assert loop.dataloader_progress.current.completed == 1 - assert loop.epoch_loop.batch_progress.current.completed == expected_steps + assert loop.batch_progress.current.completed == expected_steps assert expected_scaled_batch_size == after_batch_size assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size_temp_model")) @@ -466,4 +464,4 @@ def train_dataloader(self): new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs) assert new_batch_size == model.batch_size assert new_batch_size == expected_batch_size - 
assert trainer.train_dataloader.iterables.batch_size == expected_batch_size
+    assert trainer.train_dataloader.batch_size == expected_batch_size
diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py
index d6a4c96e5f530..afdb27c50d9a5 100644
--- a/tests/tests_pytorch/utilities/test_auto_restart.py
+++ b/tests/tests_pytorch/utilities/test_auto_restart.py
@@ -47,7 +47,7 @@ def training_step(self, batch, batch_idx):
 
     def validation_step(self, batch, batch_idx):
         should_signal = (
-            self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.is_last_batch
+            self.trainer.fit_loop.epoch_loop.val_loop.batch_progress.is_last_batch
             if self.on_last_batch
             else batch_idx == 2
         )
@@ -117,12 +117,12 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
             if failure_on_training:
                 # Breaking on first validation batch.
                 # This is done to capture the random state of the validation dataloader.
-                status = "_EvaluationEpochLoop:advance"
+                status = "_EvaluationLoop:_evaluation_step"
             else:
                 # when breaking on last batch of validation, we should exit on `run_end` when val_check_interval == 1.0
                 status = "_FitLoop:on_advance_end" if val_check_interval == 1.0 else "_TrainingEpochLoop:on_advance_end"
         else:
-            status = "_TrainingEpochLoop:on_advance_end" if failure_on_training else "_EvaluationEpochLoop:advance"
+            status = "_TrainingEpochLoop:on_advance_end" if failure_on_training else "_EvaluationLoop:_evaluation_step"
     else:
         if val_check_interval == 1.0:
             status = "_FitLoop:on_advance_end"
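A closing note on the arithmetic behind the updated progress expectations: the evaluation `batch_progress.current` counters now accumulate across all dataloaders of a single run instead of resetting per dataloader. With the shapes used in the restart tests above:

    # e.g. a run over two val dataloaders with n_batches=3 each, interrupted in
    # the second dataloader (stop_dataloader=1) after its second batch (stop_batch=1)
    n_batches, stop_dataloader, stop_batch = 3, 1, 1
    nb_total_val_batch = stop_dataloader * n_batches
    # previously `current.completed` stopped at `stop_batch` (1); now it spans
    # dataloaders, matching the `nb_total_val_batch + stop_batch` assertions
    assert nb_total_val_batch + stop_batch == 4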