diff --git a/.github/workflows/ci_dockers.yml b/.github/workflows/ci_dockers.yml index 2481eddd7f633..1e8412319df44 100644 --- a/.github/workflows/ci_dockers.yml +++ b/.github/workflows/ci_dockers.yml @@ -79,6 +79,7 @@ jobs: - {python_version: "3.7", pytorch_version: "1.11", cuda_version: "11.3.1"} # latest (used in Tutorials) - {python_version: "3.8", pytorch_version: "1.8", cuda_version: "11.1"} + - {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1"} - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.1"} - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"} steps: diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 361b318629b28..3506a099fb6e4 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -119,6 +119,7 @@ jobs: - {python_version: "3.7", pytorch_version: "1.11", cuda_version: "11.3.1"} # latest (used in Tutorials) - {python_version: "3.8", pytorch_version: "1.8", cuda_version: "11.1"} + - {python_version: "3.8", pytorch_version: "1.9", cuda_version: "11.1"} - {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.1"} - {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f88bee224734..800dd6859f608 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,6 +36,7 @@ repos: args: ['--maxkb=350', '--enforce-all'] exclude: | (?x)^( + CHANGELOG.md| docs/source/_static/images/general/fast_2.gif| docs/source/_static/images/mnist_imgs/pt_to_pl.jpg| docs/source/_static/images/lightning_module/pt_to_pl.png| diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f662344f58a0..dee9e6fc3daa5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
+## [1.6.3] - 2022-05-03 + +### Fixed + +- Use only a single instance of `rich.console.Console` throughout codebase ([#12886](https://github.com/PyTorchLightning/pytorch-lightning/pull/12886)) +- Fixed an issue to ensure all the checkpoint states are saved in a common filepath with `DeepspeedStrategy` ([#12887](https://github.com/PyTorchLightning/pytorch-lightning/pull/12887)) +- Fixed `trainer.logger` deprecation message ([#12671](https://github.com/PyTorchLightning/pytorch-lightning/pull/12671)) +- Fixed an issue where sharded grad scaler is passed in when using BF16 with the `ShardedStrategy` ([#12915](https://github.com/PyTorchLightning/pytorch-lightning/pull/12915)) +- Fixed an issue wrt recursive invocation of DDP configuration in hpu parallel plugin ([#12912](https://github.com/PyTorchLightning/pytorch-lightning/pull/12912)) +- Fixed printing of ragged dictionaries in `Trainer.validate` and `Trainer.test` ([#12857](https://github.com/PyTorchLightning/pytorch-lightning/pull/12857)) +- Fixed threading support for legacy loading of checkpoints ([#12814](https://github.com/PyTorchLightning/pytorch-lightning/pull/12814)) +- Fixed pickling of `KFoldLoop` ([#12441](https://github.com/PyTorchLightning/pytorch-lightning/pull/12441)) +- Stopped `optimizer_zero_grad` from being called after IPU execution ([#12913](https://github.com/PyTorchLightning/pytorch-lightning/pull/12913)) +- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891)) +- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653)) +- Enabled mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965)) +- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889)) +- Fixed fit loop restart logic to enable resume using the checkpoint ([#12821](https://github.com/PyTorchLightning/pytorch-lightning/pull/12821)) + + ## [1.6.2] - 2022-04-27 ### Fixed diff --git a/dockers/base-conda/Dockerfile b/dockers/base-conda/Dockerfile index 790c997f7b007..5600f5f3cb909 100644 --- a/dockers/base-conda/Dockerfile +++ b/dockers/base-conda/Dockerfile @@ -29,7 +29,11 @@ ENV \ # CUDA_TOOLKIT_ROOT_DIR="/usr/local/cuda" \ MKL_THREADING_LAYER=GNU -RUN apt-get update -qq --fix-missing && \ +RUN \ + # TODO: Remove the manual key installation once the base image is updated.
+ # https://github.com/NVIDIA/nvidia-docker/issues/1631 + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \ + apt-get update -qq --fix-missing && \ apt-get install -y --no-install-recommends \ build-essential \ cmake \ @@ -104,16 +108,6 @@ RUN \ pip install -r requirements-examples.txt --no-cache-dir --find-links https://download.pytorch.org/whl/test/torch_test.html && \ rm assistant.py -RUN \ - apt-get purge -y cmake && \ - wget -q https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz && \ - tar -zxvf cmake-3.20.2.tar.gz && \ - cd cmake-3.20.2 && \ - ./bootstrap -- -DCMAKE_USE_OPENSSL=OFF && \ - make && \ - make install && \ - cmake --version - ENV \ # if you want this environment to be the default one, uncomment the following line: CONDA_DEFAULT_ENV=${CONDA_ENV} \ diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 2e70a72d16ee3..15f490f2da01b 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -31,7 +31,11 @@ ENV \ # MAKEFLAGS="-j$(nproc)" MAKEFLAGS="-j2" -RUN apt-get update -qq --fix-missing && \ +RUN \ + # TODO: Remove the manual key installation once the base image is updated. + # https://github.com/NVIDIA/nvidia-docker/issues/1631 + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \ + apt-get update -qq --fix-missing && \ apt-get install -y --no-install-recommends \ build-essential \ pkg-config \ diff --git a/docs/source/conf.py b/docs/source/conf.py index 502354c9f95f8..2cfa0175a9a16 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -121,7 +121,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: "sphinx_copybutton", "sphinx_paramlinks", "sphinx_togglebutton", - "pt_lightning_sphinx_theme.extensions.lightning_tutorials", + "pt_lightning_sphinx_theme.extensions.lightning", ] # Suppress warnings about duplicate labels (needed for PL tutorials) diff --git a/pl_examples/loop_examples/kfold.py b/pl_examples/loop_examples/kfold.py index 811ad409c2e91..229673ec78df4 100644 --- a/pl_examples/loop_examples/kfold.py +++ b/pl_examples/loop_examples/kfold.py @@ -239,6 +239,9 @@ def __getattr__(self, key) -> Any: return getattr(self.fit_loop, key) return self.__dict__[key] + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + class LitImageClassifier(ImageClassifier): def __init__(self) -> None: diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py index 8b7e408363bcd..bf1938828cd42 100644 --- a/pytorch_lightning/__about__.py +++ b/pytorch_lightning/__about__.py @@ -1,7 +1,7 @@ import time _this_year = time.strftime("%Y") -__version__ = "1.6.2" +__version__ = "1.6.3" __author__ = "William Falcon et al."
__author_email__ = "waf2107@columbia.edu" __license__ = "Apache-2.0" diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index cc475634ff6ea..fb5914a7a5d41 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -22,7 +22,8 @@ Task, Style = None, None if _RICH_AVAILABLE: - from rich.console import Console, RenderableType + from rich import get_console, reconfigure + from rich.console import RenderableType from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn from rich.progress_bar import ProgressBar from rich.style import Style @@ -278,7 +279,8 @@ def enable(self) -> None: def _init_progress(self, trainer): if self.is_enabled and (self.progress is None or self._progress_stopped): self._reset_progress_bar_ids() - self._console = Console(**self._console_kwargs) + reconfigure(**self._console_kwargs) + self._console = get_console() self._console.clear_live() self._metric_component = MetricsTextColumn(trainer, self.theme.metrics) self.progress = CustomProgress( diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 2d4da1c15eea8..f6467e4606e6a 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -262,13 +262,13 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches - self.main_progress_bar.total = convert_inf(total_batches) + self.main_progress_bar.reset(convert_inf(total_batches)) self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: current = self.train_batch_idx + self._val_processed if self._should_update(current, self.main_progress_bar.total): - _update_n(self.main_progress_bar, current) + _update_n(self.main_progress_bar, current, self.refresh_rate) self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -288,17 +288,17 @@ def on_validation_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader) + self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader)) desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: if self._should_update(self.val_batch_idx, self.val_progress_bar.total): - _update_n(self.val_progress_bar, self.val_batch_idx) + _update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate) current = self.train_batch_idx + self._val_processed if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total): - _update_n(self.main_progress_bar, current) + _update_n(self.main_progress_bar, current, self.refresh_rate) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._main_progress_bar is not None and trainer.state.fn == 
"fit": @@ -315,12 +315,12 @@ def on_test_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader) + self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader)) self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") def on_test_batch_end(self, *_: Any) -> None: if self._should_update(self.test_batch_idx, self.test_progress_bar.total): - _update_n(self.test_progress_bar, self.test_batch_idx) + _update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate) def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar.close() @@ -335,12 +335,12 @@ def on_predict_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader) + self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader)) self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") def on_predict_batch_end(self, *_: Any) -> None: if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total): - _update_n(self.predict_progress_bar, self.predict_batch_idx) + _update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate) def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar.close() @@ -384,7 +384,10 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: return x -def _update_n(bar: _tqdm, value: int) -> None: +def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None: if not bar.disable: - bar.n = value + total = bar.total + leftover = current % refresh_rate + advance = leftover if (current == total and leftover != 0) else refresh_rate + bar.update(advance) bar.refresh() diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index 4b0b3f702cf6e..2ae1262eb25d9 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -26,7 +26,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10 +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11 from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TORCH_GREATER_EQUAL_1_10: @@ -34,6 +34,11 @@ else: from torch.quantization import QConfig +if _TORCH_GREATER_EQUAL_1_11: + from torch.ao.quantization import fuse_modules_qat as fuse_modules +else: + from torch.quantization import fuse_modules + def wrap_qat_forward_context( quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None @@ -252,7 +257,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None: model.qconfig = self._qconfig if self._check_feasible_fuse(model): - torch.quantization.fuse_modules(model, self._modules_to_fuse, inplace=True) + fuse_modules(model, self._modules_to_fuse, inplace=True) # Prepare the model for QAT. This inserts observers and fake_quants in # the model that will observe weight and activation tensors during calibration. 
diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index 14c078a273ece..148de6275950e 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -18,7 +18,7 @@ from pytorch_lightning.utilities.model_summary import get_human_readable_count if _RICH_AVAILABLE: - from rich.console import Console + from rich import get_console from rich.table import Table @@ -73,7 +73,7 @@ def summarize( model_size: float, ) -> None: - console = Console() + console = get_console() table = Table(header_style="bold magenta") table.add_column(" ", style="dim") diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b5a748295a7ea..b3d2adec571e3 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -37,10 +37,10 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO -from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator -from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType +from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp @@ -249,7 +249,26 @@ def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None: @property def logger(self) -> Optional[LightningLoggerBase]: """Reference to the logger object in the Trainer.""" - return self.trainer.logger if self.trainer else None + # this should match the implementation of `trainer.logger` + # we don't reuse it so we can properly set the deprecation stacklevel + if self.trainer is None: + return + loggers = self.trainer.loggers + if len(loggers) == 0: + return None + if len(loggers) == 1: + return loggers[0] + else: + if not self._running_torchscript: + rank_zero_deprecation( + "Using `lightning_module.logger` when multiple loggers are configured." 
+ " This behavior will change in v1.8 when `LoggerCollection` is removed, and" + " `lightning_module.logger` will return the first logger available.", + stacklevel=5, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return LoggerCollection(loggers) @property def loggers(self) -> List[LightningLoggerBase]: diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index b8c1cc9550475..e0db99c3f5d41 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -16,7 +16,7 @@ import sys from collections import ChainMap, OrderedDict from functools import partial -from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Type, Union import torch from deprecate.utils import void @@ -42,7 +42,7 @@ from pytorch_lightning.utilities.types import EPOCH_OUTPUT if _RICH_AVAILABLE: - from rich.console import Console + from rich import get_console from rich.table import Column, Table @@ -320,45 +320,48 @@ def _on_evaluation_epoch_end(self) -> None: self.trainer._logger_connector.on_epoch_end() @staticmethod - def _get_keys(data: dict) -> Iterable[str]: - if any(isinstance(v, dict) for v in data.values()): - for v in data.values(): - yield from apply_to_collection(v, dict, dict.keys) - else: - yield from data.keys() - - @staticmethod - def _find_value(data: dict, target: str) -> Iterable[Any]: + def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: for k, v in data.items(): - if k == target: - yield v - elif isinstance(v, dict): - yield from EvaluationLoop._find_value(v, target) + if isinstance(v, dict): + for new_key in apply_to_collection(v, dict, EvaluationLoop._get_keys): + yield (k, *new_key) # this need to be in parenthesis for older python versions + else: + yield k, @staticmethod - def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None: - # print to stdout by default - if file is None: - file = sys.stdout + def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]: + target_start, *rest = target + if target_start not in data: + return None + result = data[target_start] + if not rest: + return result + return EvaluationLoop._find_value(result, rest) + @staticmethod + def _print_results(results: List[_OUT_DICT], stage: str) -> None: # remove the dl idx suffix results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results] - metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys}) - if not metrics: + metrics_paths = {k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys} + if not metrics_paths: return + + metrics_strs = [":".join(metric) for metric in metrics_paths] + # sort both lists based on metrics_strs + metrics_strs, metrics_paths = zip(*sorted(zip(metrics_strs, metrics_paths))) + headers = [f"DataLoader {i}" for i in range(len(results))] # fallback is useful for testing of printed output term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120 - max_length = int(min(max(len(max(metrics + headers, key=len)), 25), term_size / 2)) + max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2)) - rows: List[List[Any]] = [[] for _ in metrics] + rows: List[List[Any]] = [[] for _ in metrics_paths] for result in results: - for 
metric, row in zip(metrics, rows): - v = list(EvaluationLoop._find_value(result, metric)) - if v: - val = v[0] + for metric, row in zip(metrics_paths, rows): + val = EvaluationLoop._find_value(result, metric) + if val is not None: if isinstance(val, torch.Tensor): val = val.item() if val.numel() == 1 else val.tolist() row.append(f"{val}") @@ -375,15 +378,15 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] table_headers.insert(0, f"{stage} Metric".capitalize()) if _RICH_AVAILABLE: - console = Console(file=file) - columns = [Column(h, justify="center", style="magenta", width=max_length) for h in table_headers] columns[0].style = "cyan" table = Table(*columns) - for metric, row in zip(metrics, table_rows): + for metric, row in zip(metrics_strs, table_rows): row.insert(0, metric) table.add_row(*row) + + console = get_console() console.print(table) else: row_format = f"{{:^{max_length}}}" * len(table_headers) @@ -391,8 +394,8 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] try: # some terminals do not support this character - if hasattr(file, "encoding") and file.encoding is not None: - "─".encode(file.encoding) + if sys.stdout.encoding is not None: + "─".encode(sys.stdout.encoding) except UnicodeEncodeError: bar_character = "-" else: @@ -400,7 +403,7 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] bar = bar_character * term_size lines = [bar, row_format.format(*table_headers).rstrip(), bar] - for metric, row in zip(metrics, table_rows): + for metric, row in zip(metrics_strs, table_rows): # deal with column overflow if len(metric) > half_term_size: while len(metric) > half_term_size: @@ -411,7 +414,7 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] else: lines.append(row_format.format(metric, *row).rstrip()) lines.append(bar) - print(os.linesep.join(lines), file=file) + print(os.linesep.join(lines)) def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]: diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index db3f60fb28ede..40334387c0688 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -123,15 +123,10 @@ def running_loss(self) -> TensorRunningAccum: @Loop.restarting.setter def restarting(self, restarting: bool) -> None: - # if the last epoch completely finished, we are not actually restarting, we can check this to see if all - # current values are equal - values = ( - self.epoch_progress.current.ready, - self.epoch_progress.current.started, - self.epoch_progress.current.processed, - ) - finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values) - restarting &= finished_before_on_train_end + # if the last epoch completely finished, we are not actually restarting + values = self.epoch_progress.current.ready, self.epoch_progress.current.started + epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values) + restarting = restarting and epoch_unfinished or self._iteration_based_training() Loop.restarting.fset(self, restarting) # call the parent setter @property @@ -205,6 +200,10 @@ def reset(self) -> None: def on_run_start(self) -> None: # type: ignore[override] """Calls the ``on_train_start`` hook.""" + # update the current_epoch in-case of checkpoint reload + if not self._iteration_based_training(): + self.epoch_progress.current.completed = self.epoch_progress.current.processed + # reset train 
dataloader and val dataloader self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module) @@ -336,6 +335,9 @@ def _should_accumulate(self) -> bool: """Whether the gradients should be accumulated.""" return self.epoch_loop._should_accumulate() + def _iteration_based_training(self) -> bool: + return self.trainer.max_steps != -1 + def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]: training_step_fx = getattr(trainer.lightning_module, "training_step") diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index 57840a918a2e1..e40aea8ecf4eb 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -35,7 +35,7 @@ def __init__( "You have asked for sharded AMP but you have not installed it." " Install `fairscale` using this guide: https://https://github.com/facebookresearch/fairscale" ) - super().__init__(precision, device, scaler=scaler or ShardedGradScaler()) + super().__init__(precision, device, scaler=ShardedGradScaler() if scaler is None and precision == 16 else None) def clip_grad_by_norm(self, optimizer: "OSS", clip_val: Union[int, float]) -> None: optimizer.clip_grad_norm(clip_val) diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index f3e3951e4ffd7..8eac4a18b2bad 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -757,6 +757,9 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Op TypeError: If ``storage_options`` arg is passed in """ + # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath + filepath = self.broadcast(filepath) + if storage_options is not None: raise TypeError( "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg" diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py index b61429264d80a..6a902d3e09a3a 100644 --- a/pytorch_lightning/strategies/fully_sharded.py +++ b/pytorch_lightning/strategies/fully_sharded.py @@ -163,7 +163,7 @@ def wrap_policy(*args, **kwargs): cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, flatten_parameters=self.flatten_parameters, - mixed_precision=(precision == PrecisionType.MIXED), + mixed_precision=(precision in (PrecisionType.MIXED, PrecisionType.HALF)), reshard_after_forward=self.reshard_after_forward, fp32_reduce_scatter=self.fp32_reduce_scatter, compute_dtype=self.compute_dtype, diff --git a/pytorch_lightning/strategies/hpu_parallel.py b/pytorch_lightning/strategies/hpu_parallel.py index 562a841b89510..4996ddabcf960 100644 --- a/pytorch_lightning/strategies/hpu_parallel.py +++ b/pytorch_lightning/strategies/hpu_parallel.py @@ -103,7 +103,7 @@ def configure_ddp(self) -> None: self._model._set_static_graph() # type: ignore self._register_ddp_hooks() else: - self.configure_ddp() + super().configure_ddp() def broadcast(self, obj: object, src: int = 0) -> object: # type: ignore obj = [obj] diff --git a/pytorch_lightning/strategies/ipu.py b/pytorch_lightning/strategies/ipu.py index 4603110c01536..29d6a3f068e29 100644 --- a/pytorch_lightning/strategies/ipu.py +++ b/pytorch_lightning/strategies/ipu.py @@ -25,12 +25,13 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.trainer.states 
import RunningStage, TrainerFn -from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE +from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import STEP_OUTPUT if _POPTORCH_AVAILABLE: @@ -121,6 +122,7 @@ def __init__( os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options) self._update_dataloader_original: Optional[Callable] = None + self._optimizer_zero_grad_original: Optional[Callable] = None def setup(self, trainer: "pl.Trainer") -> None: # set the `accumulate_grad_batches` property as early as possible @@ -134,6 +136,11 @@ def setup(self, trainer: "pl.Trainer") -> None: super().setup(trainer) + # disable the `optimizer_zero_grad` function by setting it to `None`. + # this is because the IPU zeros the gradients internally + self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad + self._disable_zero_grad() + model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision) self.model = model @@ -260,6 +267,16 @@ def to_tensor(x): args = apply_to_collection(args, dtype=(int, float), function=to_tensor) return args + def _disable_zero_grad(self) -> None: + lightning_module = self.lightning_module + if is_overridden("optimizer_zero_grad", lightning_module): + assert lightning_module is not None # `is_overridden` returns False otherwise + rank_zero_warn( + "You have overridden the `LightningModule.optimizer_zero_grad` hook but it will be ignored since" + " IPUs handle the zeroing of gradients internally." 
+ ) + lightning_module.optimizer_zero_grad = None # type: ignore[assignment] + def _step(self, stage: RunningStage, *args: Any, **kwargs: Any): args = self._prepare_input(args) poptorch_model = self.poptorch_models[stage] @@ -290,6 +307,10 @@ def teardown(self) -> None: # undo dataloader patching pl.trainer.connectors.data_connector._update_dataloader = self._update_dataloader_original + if self._optimizer_zero_grad_original is not None: + # re-enable `optimizer_zero_grad` + self.lightning_module.optimizer_zero_grad = self._optimizer_zero_grad_original + for model in self.poptorch_models.values(): model.destroy() diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 58f730f43eef1..7d9e6e48fc43d 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -447,7 +447,7 @@ def _request_dataloader( self.trainer._call_lightning_module_hook("on_" + hook, pl_module=model) with _replace_dataloader_init_method(): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as - # attributes on the instance in case the dataloader needs to be re-instantiated later by Ligtning + # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning dataloader = source.dataloader() if isinstance(dataloader, tuple): dataloader = list(dataloader) @@ -482,6 +482,7 @@ def replace_sampler(dataloader: DataLoader) -> DataLoader: @staticmethod def _check_eval_shuffling(dataloader, mode): + # limit this warning only for samplers assigned automatically when shuffle is set if _is_dataloader_shuffled(dataloader): rank_zero_warn( f"Your `{mode.dataloader_prefix}_dataloader`'s sampler has shuffling enabled," diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2565b6ca51338..af9515006fb74 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2700,19 +2700,21 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop @property def logger(self) -> Optional[LightningLoggerBase]: - if len(self.loggers) == 0: + loggers = self.loggers + if len(loggers) == 0: return None - if len(self.loggers) == 1: - return self.loggers[0] + if len(loggers) == 1: + return loggers[0] else: - rank_zero_warn( - "Using trainer.logger when Trainer is configured to use multiple loggers." - " This behavior will change in v1.8 when LoggerCollection is removed, and" - " trainer.logger will return the first logger in trainer.loggers" + rank_zero_deprecation( + "Using `trainer.logger` when multiple loggers are configured." 
+ " This behavior will change in v1.8 when `LoggerCollection` is removed, and" + " `trainer.logger` will return the first logger available.", + stacklevel=5, ) with warnings.catch_warnings(): warnings.simplefilter("ignore") - return LoggerCollection(self.loggers) + return LoggerCollection(loggers) @logger.setter def logger(self, logger: Optional[LightningLoggerBase]) -> None: diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 289b7faa431e2..87947ac9a10f3 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -49,6 +49,7 @@ _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_10, + _TORCH_GREATER_EQUAL_1_11, _TORCH_QUANTIZE_AVAILABLE, _TORCHTEXT_AVAILABLE, _TORCHVISION_AVAILABLE, diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 700ba843bb516..01c6b91a5108b 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -29,7 +29,7 @@ from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE +from pytorch_lightning.utilities.imports import _DOCSTRING_PARSER_AVAILABLE, _JSONARGPARSE_AVAILABLE from pytorch_lightning.utilities.meta import get_all_subclasses from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn @@ -37,13 +37,15 @@ if _JSONARGPARSE_AVAILABLE: from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode - from jsonargparse.optionals import import_docstring_parse set_config_read_mode(fsspec_enabled=True) else: locals()["ArgumentParser"] = object locals()["Namespace"] = object +if _DOCSTRING_PARSER_AVAILABLE: + import docstring_parser + class _Registry(dict): def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> Type: @@ -888,9 +890,13 @@ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) - def _get_short_description(component: object) -> Optional[str]: - parse, _ = import_docstring_parse("LightningCLI(run=True)") - try: - docstring = parse(component.__doc__) - return docstring.short_description - except ValueError: - rank_zero_warn(f"Failed parsing docstring for {component}") + if component.__doc__ is None: + return None + if not _DOCSTRING_PARSER_AVAILABLE: + rank_zero_warn(f"Failed parsing docstring for {component}: docstring-parser package is required") + else: + try: + docstring = docstring_parser.parse(component.__doc__) + return docstring.short_description + except (ValueError, docstring_parser.ParseError) as ex: + rank_zero_warn(f"Failed parsing docstring for {component}: {ex}") diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 5d54c8e53f091..872b07476e750 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union import torch -from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler, SequentialSampler +from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler import pytorch_lightning as 
pl from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper @@ -384,9 +384,18 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> return dl_kwargs -def _is_dataloader_shuffled(dataloader: DataLoader): - return ( - hasattr(dataloader, "sampler") - and not isinstance(dataloader.sampler, SequentialSampler) - and not isinstance(dataloader.dataset, IterableDataset) - ) +def _is_dataloader_shuffled(dataloader: object) -> bool: + if hasattr(dataloader, "shuffle"): + # this attribute is not part of PyTorch's DataLoader, but could have been set by + # our `_replace_dataloader_init_method` context manager + return dataloader.shuffle + if isinstance(dataloader.dataset, IterableDataset): + # shuffling is useless with iterable datasets + return False + if not hasattr(dataloader, "sampler"): + # shuffling is enabled via a sampler. No sampler, no shuffling + return False + sampler = dataloader.sampler + if isinstance(sampler, SequentialSampler): + return False + return isinstance(sampler, RandomSampler) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 9bf0fdd046134..835e56f1816da 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -100,6 +100,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: _DEEPSPEED_AVAILABLE = _package_available("deepspeed") _DEEPSPEED_GREATER_EQUAL_0_5_9 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.5.9") _DEEPSPEED_GREATER_EQUAL_0_6 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.6.0") +_DOCSTRING_PARSER_AVAILABLE = _package_available("docstring_parser") _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3") _FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4") diff --git a/pytorch_lightning/utilities/migration.py b/pytorch_lightning/utilities/migration.py index bc3761e47b835..30cc823210423 100644 --- a/pytorch_lightning/utilities/migration.py +++ b/pytorch_lightning/utilities/migration.py @@ -14,10 +14,14 @@ from __future__ import annotations import sys +import threading from types import ModuleType, TracebackType import pytorch_lightning.utilities.argparse +# Create a global lock to ensure no race condition with deleting sys modules +_lock = threading.Lock() + class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) 
that were removed but still need to be included for @@ -35,6 +39,7 @@ class pl_legacy_patch: """ def __enter__(self) -> None: + _lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module @@ -49,3 +54,4 @@ def __exit__( if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") del sys.modules["pytorch_lightning.utilities.argparse_utils"] + _lock.release() diff --git a/requirements.txt b/requirements.txt index 6aa080fc7e8fb..39f0d586ba18f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy>=1.17.2 torch>=1.8.* -tqdm>=4.41.0 +tqdm>=4.57.0 PyYAML>=5.4 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0 diff --git a/tests/callbacks/test_rich_model_summary.py b/tests/callbacks/test_rich_model_summary.py index c596557eed0dc..d9e4ec55902ca 100644 --- a/tests/callbacks/test_rich_model_summary.py +++ b/tests/callbacks/test_rich_model_summary.py @@ -41,8 +41,8 @@ def test_rich_progress_bar_import_error(monkeypatch): @RunIf(rich=True) -@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Console.print", autospec=True) -@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Table.add_row", autospec=True) +@mock.patch("rich.console.Console.print", autospec=True) +@mock.patch("rich.table.Table.add_row", autospec=True) def test_rich_summary_tuples(mock_table_add_row, mock_console): """Ensure that tuples are converted into string, and print is called correctly.""" model_summary = RichModelSummary() diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index f36f9d3353093..f46ea267f9d64 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -53,6 +53,7 @@ def n(self): @n.setter def n(self, value): self.__n = value + # track the changes in the `n` value if not len(self.n_values) or value != self.n_values[-1]: self.n_values.append(value) @@ -158,7 +159,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): assert not pbar.val_progress_bar.leave assert trainer.num_sanity_val_batches == expected_sanity_steps assert pbar.val_progress_bar.total_values == expected_sanity_steps - assert pbar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl + assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] # fit @@ -177,7 +178,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): # check val progress bar total assert pbar.val_progress_bar.total_values == m - assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl + assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] assert not pbar.val_progress_bar.leave @@ -186,7 +187,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): trainer.validate(model) assert trainer.num_val_batches == m assert pbar.val_progress_bar.total_values == m - assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl + assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl assert pbar.val_progress_bar.descriptions == [f"Validation 
DataLoader {i}: " for i in range(num_dl)] # test @@ -195,7 +196,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): assert pbar.test_progress_bar.leave k = trainer.num_test_batches assert pbar.test_progress_bar.total_values == k - assert pbar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl + assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] assert pbar.test_progress_bar.leave @@ -205,7 +206,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): assert pbar.predict_progress_bar.leave k = trainer.num_predict_batches assert pbar.predict_progress_bar.total_values == k - assert pbar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl + assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] assert pbar.predict_progress_bar.leave @@ -359,13 +360,13 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): @pytest.mark.parametrize( "train_batches,val_batches,refresh_rate,train_updates,val_updates", [ - [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]], + [2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]], [0, 0, 3, None, None], - [1, 0, 3, [1], None], - [1, 1, 3, [2], [1]], - [5, 0, 3, [3, 5], None], - [5, 2, 3, [3, 6, 7], [2]], - [5, 2, 6, [6, 7], [2]], + [1, 0, 3, [0, 1], None], + [1, 1, 3, [0, 2], [0, 1]], + [5, 0, 3, [0, 3, 5], None], + [5, 2, 3, [0, 3, 6, 7], [0, 2]], + [5, 2, 6, [0, 6, 7], [0, 2]], ], ) def test_main_progress_bar_update_amount( @@ -395,7 +396,7 @@ def test_main_progress_bar_update_amount( assert progress_bar.val_progress_bar.n_values == val_updates -@pytest.mark.parametrize("test_batches,refresh_rate,updates", [[1, 3, [1]], [3, 1, [1, 2, 3]], [5, 3, [3, 5]]]) +@pytest.mark.parametrize("test_batches,refresh_rate,updates", [(1, 3, [0, 1]), (3, 1, [0, 1, 2, 3]), (5, 3, [0, 3, 5])]) def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list): """Test that test progress updates with the correct amount.""" model = BoringModel() @@ -566,7 +567,7 @@ def test_tqdm_progress_bar_can_be_pickled(): @pytest.mark.parametrize( ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"], - [(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])], + [(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, [0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])], ) def test_progress_bar_max_val_check_interval( tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 7e753617a6331..ac2806cbf3811 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -14,6 +14,7 @@ import glob import os import sys +import threading from unittest.mock import patch import pytest @@ -60,6 +61,28 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.should_stop = True +@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) +def test_legacy_ckpt_threading(tmpdir, pl_version: str): + def load_model(): + import torch + + from pytorch_lightning.utilities.migration import pl_legacy_patch + + with pl_legacy_patch(): + _ = torch.load(PATH_LEGACY) + + PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + with 
patch("sys.path", [PATH_LEGACY] + sys.path): + t1 = threading.Thread(target=load_model) + t2 = threading.Thread(target=load_model) + + t1.start() + t2.start() + + t1.join() + t2.join() + + @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) def test_resume_legacy_checkpoints(tmpdir, pl_version: str): PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index c7fee3b0d5fd0..07fcf8dadccc3 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -77,7 +77,7 @@ def test_property_logger(tmpdir): assert model.logger is None logger = TensorBoardLogger(tmpdir) - trainer = Mock(logger=logger) + trainer = Mock(loggers=[logger]) model.trainer = trainer assert model.logger == logger diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 8975a8a3d47c1..39ad68a21e039 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -760,10 +760,11 @@ def test_v1_8_0_logger_collection(tmpdir): trainer1.logger trainer1.loggers trainer2.loggers - trainer2.logger + with pytest.deprecated_call(match="logger` will return the first logger"): + _ = trainer2.logger with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): - LoggerCollection([logger1, logger2]) + _ = LoggerCollection([logger1, logger2]) def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 05dc566949a08..0e62c4b109ef5 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -13,6 +13,7 @@ # limitations under the License. import functools import os +import re import traceback from contextlib import contextmanager from typing import Optional, Type @@ -126,7 +127,7 @@ def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Option return else: for w in record.list: - if w.category is expected_warning and match in w.message.args[0]: + if w.category is expected_warning and re.compile(match).search(w.message.args[0]): break else: return diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 6c24512d782c6..fa31b50a4d807 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -570,7 +570,102 @@ def training_step(self, batch, batch_idx): assert called == expected -def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): +def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir): + # initial training to get a checkpoint + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=0, + enable_progress_bar=False, + enable_model_summary=False, + callbacks=[HookedCallback([])], + ) + trainer.fit(model) + best_model_path = trainer.checkpoint_callback.best_model_path + + called = [] + callback = HookedCallback(called) + # already performed 1 step, resume and do 2 more + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=2, + limit_val_batches=0, + enable_progress_bar=False, + enable_model_summary=False, + callbacks=[callback], + track_grad_norm=1, + ) + assert called == [ + dict(name="Callback.on_init_start", args=(trainer,)), + dict(name="Callback.on_init_end", args=(trainer,)), + ] + + # resume from checkpoint with HookedModel + model = HookedModel(called) + trainer.fit(model, ckpt_path=best_model_path) + loaded_ckpt = { + "callbacks": 
ANY, + "epoch": 0, + "global_step": 2, + "lr_schedulers": ANY, + "optimizer_states": ANY, + "pytorch-lightning_version": __version__, + "state_dict": ANY, + "loops": ANY, + } + saved_ckpt = {**loaded_ckpt, "global_step": 4, "epoch": 1} + expected = [ + dict(name="Callback.on_init_start", args=(trainer,)), + dict(name="Callback.on_init_end", args=(trainer,)), + dict(name="configure_callbacks"), + dict(name="prepare_data"), + dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)), + dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")), + dict(name="setup", kwargs=dict(stage="fit")), + dict(name="on_load_checkpoint", args=(loaded_ckpt,)), + dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})), + dict(name="Callback.load_state_dict", args=({"foo": True},)), + dict(name="configure_sharded_model"), + dict(name="Callback.on_configure_sharded_model", args=(trainer, model)), + dict(name="configure_optimizers"), + dict(name="Callback.on_fit_start", args=(trainer, model)), + dict(name="on_fit_start"), + dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)), + dict(name="on_pretrain_routine_start"), + dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)), + dict(name="on_pretrain_routine_end"), + dict(name="train", args=(True,)), + dict(name="on_train_dataloader"), + dict(name="train_dataloader"), + dict(name="Callback.on_train_start", args=(trainer, model)), + dict(name="on_train_start"), + dict(name="Callback.on_epoch_start", args=(trainer, model)), + dict(name="on_epoch_start"), + dict(name="Callback.on_train_epoch_start", args=(trainer, model)), + dict(name="on_train_epoch_start"), + *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0), + dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)), + dict(name="Callback.on_train_epoch_end", args=(trainer, model)), + dict(name="Callback.state_dict"), + dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)), + dict(name="on_save_checkpoint", args=(saved_ckpt,)), + dict(name="on_train_epoch_end"), + dict(name="Callback.on_epoch_end", args=(trainer, model)), + dict(name="on_epoch_end"), + dict(name="Callback.on_train_end", args=(trainer, model)), + dict(name="on_train_end"), + dict(name="Callback.on_fit_end", args=(trainer, model)), + dict(name="on_fit_end"), + dict(name="Callback.teardown", args=(trainer, model), kwargs=dict(stage="fit")), + dict(name="teardown", kwargs=dict(stage="fit")), + ] + assert called == expected + + +def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir): # initial training to get a checkpoint model = BoringModel() trainer = Trainer( diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 0d6c9772b9c45..136e8ee516bbb 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -199,7 +199,7 @@ def on_train_start(self): if self.trainer.state.fn == TrainerFn.TUNING: self._test_on_val_test_predict_tune_start() else: - assert self.trainer.current_epoch == state_dict["epoch"] + assert self.trainer.current_epoch == state_dict["epoch"] + 1 assert self.trainer.global_step == state_dict["global_step"] assert self._check_model_state_dict() assert self._check_optimizers() diff --git a/tests/plugins/precision/test_sharded_precision.py b/tests/plugins/precision/test_sharded_precision.py new file mode 100644 index 0000000000000..754095912fb53 --- /dev/null +++ b/tests/plugins/precision/test_sharded_precision.py @@ -0,0 +1,42 @@ 
+# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from pytorch_lightning.plugins import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE +from tests.helpers.runif import RunIf + +ShardedGradScaler = None +if _FAIRSCALE_AVAILABLE: + from fairscale.optim.grad_scaler import ShardedGradScaler + + +@RunIf(fairscale=True) +@pytest.mark.parametrize( + "precision,scaler,expected", + [ + (16, torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler), + (16, None, ShardedGradScaler), + pytest.param("bf16", None, None, marks=RunIf(min_torch="1.10")), + (32, None, None), + ], +) +def test_sharded_precision_scaler(precision, scaler, expected): + plugin = ShardedNativeMixedPrecisionPlugin(precision=precision, scaler=scaler, device="cuda") + if expected: + assert isinstance(plugin.scaler, expected) + else: + assert not plugin.scaler diff --git a/tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py b/tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py index 4b237c8704ddc..2912d59598220 100644 --- a/tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/strategies/test_ddp_fully_sharded_with_full_state_dict.py @@ -90,6 +90,11 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer.module[0].reshard_after_forward is True assert self.layer.module[2].reshard_after_forward is True + if isinstance(self.trainer.precision_plugin, FullyShardedNativeMixedPrecisionPlugin): + assert self.layer.mixed_precision + assert self.layer.module[0].mixed_precision + assert self.layer.module[2].mixed_precision + @RunIf(min_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True) def test_fully_sharded_strategy_checkpoint(tmpdir): diff --git a/tests/strategies/test_ddp_strategy_with_comm_hook.py b/tests/strategies/test_ddp_strategy_with_comm_hook.py index 34ff4d412828c..11082849d684d 100644 --- a/tests/strategies/test_ddp_strategy_with_comm_hook.py +++ b/tests/strategies/test_ddp_strategy_with_comm_hook.py @@ -30,11 +30,26 @@ import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD +class TestDDPStrategy(DDPStrategy): + def __init__(self, expected_ddp_comm_hook_name, *args, **kwargs): + self.expected_ddp_comm_hook_name = expected_ddp_comm_hook_name + super().__init__(*args, **kwargs) + + def teardown(self): + # check here before unwrapping DistributedDataParallel in self.teardown + attached_ddp_comm_hook_name = self.model._get_ddp_logging_data()["comm_hook"] + assert attached_ddp_comm_hook_name == self.expected_ddp_comm_hook_name + return super().teardown() + + @RunIf(min_gpus=2, min_torch="1.9.0", skip_windows=True, standalone=True) def test_ddp_fp16_compress_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" model = BoringModel() - strategy = DDPStrategy(ddp_comm_hook=default.fp16_compress_hook) + strategy = TestDDPStrategy( + expected_ddp_comm_hook_name=default.fp16_compress_hook.__qualname__, + 
ddp_comm_hook=default.fp16_compress_hook, + ) trainer = Trainer( max_epochs=1, accelerator="gpu", @@ -45,9 +60,6 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook - expected_comm_hook = default.fp16_compress_hook.__qualname__ - assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -55,7 +67,8 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): def test_ddp_sgd_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" model = BoringModel() - strategy = DDPStrategy( + strategy = TestDDPStrategy( + expected_ddp_comm_hook_name=powerSGD.powerSGD_hook.__qualname__, ddp_comm_state=powerSGD.PowerSGDState(process_group=None), ddp_comm_hook=powerSGD.powerSGD_hook, ) @@ -69,9 +82,6 @@ def test_ddp_sgd_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook - expected_comm_hook = powerSGD.powerSGD_hook.__qualname__ - assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -79,7 +89,8 @@ def test_ddp_sgd_comm_hook(tmpdir): def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): """Test for DDP FP16 compress wrapper for SGD hook.""" model = BoringModel() - strategy = DDPStrategy( + strategy = TestDDPStrategy( + expected_ddp_comm_hook_name=default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__, ddp_comm_state=powerSGD.PowerSGDState(process_group=None), ddp_comm_hook=powerSGD.powerSGD_hook, ddp_comm_wrapper=default.fp16_compress_wrapper, @@ -94,9 +105,6 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook - expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__ - assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -122,8 +130,8 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir): def test_ddp_post_local_sgd_comm_hook(tmpdir): """Test for DDP post-localSGD hook.""" model = BoringModel() - - strategy = DDPStrategy( + strategy = TestDDPStrategy( + expected_ddp_comm_hook_name=post_localSGD.post_localSGD_hook.__qualname__, ddp_comm_state=post_localSGD.PostLocalSGDState( process_group=None, subgroup=None, @@ -141,9 +149,6 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir): sync_batchnorm=True, ) trainer.fit(model) - trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook - expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__ - assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py index 9ca285434e8a2..de009e55fe875 100644 --- a/tests/strategies/test_deepspeed_strategy.py +++ b/tests/strategies/test_deepspeed_strategy.py @@ -1172,19 +1172,30 @@ def test_deepspeed_with_meta_device(tmpdir): def test_deepspeed_multi_save_same_filepath(tmpdir): """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old sharded checkpoints.""" - model = BoringModel() + + class CustomModel(BoringModel): + def training_step(self, *args, **kwargs): + self.log("grank", self.global_rank) + return super().training_step(*args, **kwargs) + + 
diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py
index 9ca285434e8a2..de009e55fe875 100644
--- a/tests/strategies/test_deepspeed_strategy.py
+++ b/tests/strategies/test_deepspeed_strategy.py
@@ -1172,19 +1172,30 @@ def test_deepspeed_with_meta_device(tmpdir):
 def test_deepspeed_multi_save_same_filepath(tmpdir):
     """Test that DeepSpeed saves only the latest checkpoint in the specified path and deletes the older
     sharded checkpoints."""
-    model = BoringModel()
+
+    class CustomModel(BoringModel):
+        def training_step(self, *args, **kwargs):
+            self.log("grank", self.global_rank)
+            return super().training_step(*args, **kwargs)
+
+    model = CustomModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         strategy="deepspeed",
         accelerator="gpu",
         devices=2,
-        callbacks=[ModelCheckpoint(save_top_k=1, save_last=True)],
+        callbacks=[ModelCheckpoint(filename="{epoch}_{step}_{grank}", save_top_k=1)],
         limit_train_batches=1,
         limit_val_batches=0,
         num_sanity_val_steps=0,
         max_epochs=2,
     )
     trainer.fit(model)
-    ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, "last.ckpt")
-    expected = ["latest", "zero_to_fp32.py", "checkpoint"]
-    assert set(expected) == set(os.listdir(ckpt_path))
+
+    filepath = "epoch=1_step=2_grank=0.0.ckpt"
+    expected = {filepath}
+    assert expected == set(os.listdir(trainer.checkpoint_callback.dirpath))
+
+    ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filepath)
+    expected = {"latest", "zero_to_fp32.py", "checkpoint"}
+    assert expected == set(os.listdir(ckpt_path))
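The rewritten assertions above rely on two facts: the `ModelCheckpoint` filename template interpolates logged values, so `"{epoch}_{step}_{grank}"` resolves to `epoch=1_step=2_grank=0.0.ckpt`, and a DeepSpeed checkpoint is a directory of shards rather than a single file. A minimal sketch of the final layout check, assuming `dirpath` is the checkpoint directory (`assert_latest_deepspeed_ckpt` is an illustrative helper, not Lightning API):

import os


def assert_latest_deepspeed_ckpt(dirpath: str, expected_name: str) -> None:
    # only the most recent checkpoint directory should remain
    assert set(os.listdir(dirpath)) == {expected_name}
    # inside it, DeepSpeed keeps the shard folder plus its bookkeeping files
    contents = set(os.listdir(os.path.join(dirpath, expected_name)))
    assert {"latest", "zero_to_fp32.py", "checkpoint"} <= contents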
diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py
index e22e846600122..c335644048db2 100644
--- a/tests/trainer/connectors/test_data_connector.py
+++ b/tests/trainer/connectors/test_data_connector.py
@@ -11,20 +11,388 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import redirect_stderr
+from io import StringIO
+from re import escape
 from unittest.mock import Mock
 
 import pytest
-from torch.utils.data import DataLoader
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
-from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.trainer.states import RunningStage, TrainerFn
+from pytorch_lightning.trainer.supporters import CombinedLoader
+from pytorch_lightning.utilities.data import _update_dataloader
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
-from tests.helpers import BoringDataModule, BoringModel
-from tests.helpers.boring_model import RandomDataset
+from tests.helpers.boring_model import BoringDataModule, BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 from tests.helpers.utils import no_warning_call
 
 
+@RunIf(skip_windows=True)
+@pytest.mark.parametrize("mode", (1, 2))
+def test_replace_distributed_sampler(tmpdir, mode):
+    class IndexedRandomDataset(RandomDataset):
+        def __getitem__(self, index):
+            return self.data[index]
+
+    class CustomDataLoader(DataLoader):
+        def __init__(self, num_features, dataset, *args, **kwargs):
+            # argument `num_features` unused on purpose
+            # it gets automatically captured by _replace_dataloader_init_method()
+            super().__init__(dataset, *args, **kwargs)
+
+    class CustomBatchSampler(BatchSampler):
+        pass
+
+    class TestModel(BoringModel):
+        def __init__(self, numbers_test_dataloaders, mode):
+            super().__init__()
+            self._numbers_test_dataloaders = numbers_test_dataloaders
+            self._mode = mode
+
+        def test_step(self, batch, batch_idx, dataloader_idx=0):
+            return super().test_step(batch, batch_idx)
+
+        def on_test_start(self) -> None:
+            dataloader = self.trainer.test_dataloaders[0]
+            assert isinstance(dataloader, CustomDataLoader)
+            batch_sampler = dataloader.batch_sampler
+            if self._mode == 1:
+                assert isinstance(batch_sampler, CustomBatchSampler)
+                # the batch_size is set on the batch sampler
+                assert dataloader.batch_size is None
+            elif self._mode == 2:
+                assert type(batch_sampler) is BatchSampler
+                assert dataloader.batch_size == self._mode
+                assert batch_sampler.batch_size == self._mode
+                assert batch_sampler.drop_last
+            # the sampler has been replaced
+            assert isinstance(batch_sampler.sampler, DistributedSampler)
+
+        def create_dataset(self):
+            dataset = IndexedRandomDataset(32, 64)
+            if self._mode == 1:
+                # with a custom batch sampler
+                batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=1, drop_last=True)
+                return CustomDataLoader(32, dataset, batch_sampler=batch_sampler)
+            elif self._mode == 2:
+                # with no batch sampler provided
+                return CustomDataLoader(32, dataset, batch_size=2, drop_last=True)
+
+        def test_dataloader(self):
+            return [self.create_dataset()] * self._numbers_test_dataloaders
+
+    model = TestModel(2, mode)
+    model.test_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_test_batches=2,
+        accelerator="cpu",
+        devices=1,
+        strategy="ddp_find_unused_parameters_false",
+    )
+    trainer.test(model)
+
+
+class TestSpawnBoringModel(BoringModel):
+    def __init__(self, num_workers):
+        super().__init__()
+        self.num_workers = num_workers
+
+    def train_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers)
+
+    def on_fit_start(self):
+        self._resout = StringIO()
+        self.ctx = redirect_stderr(self._resout)
+        self.ctx.__enter__()
+
+    def on_train_end(self):
+        def _get_warning_msg():
+            dl = self.trainer.train_dataloader.loaders
+            if hasattr(dl, "persistent_workers"):
+                if self.num_workers == 0:
+                    warn_str = "Consider setting num_workers>0 and persistent_workers=True"
+                else:
+                    warn_str = "Consider setting persistent_workers=True"
+            else:
+                warn_str = "Consider setting strategy=ddp"
+
+            return warn_str
+
+        if self.trainer.is_global_zero:
+            self.ctx.__exit__(None, None, None)
+            msg = self._resout.getvalue()
+            warn_str = _get_warning_msg()
+            assert warn_str in msg
+
+
+@RunIf(skip_windows=True)
+@pytest.mark.parametrize("num_workers", [0, 1])
+def test_dataloader_warnings(tmpdir, num_workers):
+    trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=4)
+    assert isinstance(trainer.strategy, DDPSpawnStrategy)
+    trainer.fit(TestSpawnBoringModel(num_workers))
+
+
+def test_update_dataloader_raises():
+    with pytest.raises(ValueError, match="needs to subclass `torch.utils.data.DataLoader"):
+        _update_dataloader(object(), object(), mode="fit")
+
+
+def test_dataloaders_with_missing_keyword_arguments():
+    ds = RandomDataset(10, 20)
+
+    class TestDataLoader(DataLoader):
+        def __init__(self, dataset):
+            super().__init__(dataset)
+
+    loader = TestDataLoader(ds)
+    sampler = SequentialSampler(ds)
+    match = escape("missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
+    with pytest.raises(MisconfigurationException, match=match):
+        _update_dataloader(loader, sampler, mode="fit")
+    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']")
+    with pytest.raises(MisconfigurationException, match=match):
+        _update_dataloader(loader, sampler, mode="predict")
+
+    class TestDataLoader(DataLoader):
+        def __init__(self, dataset, *args, **kwargs):
+            super().__init__(dataset)
+
+    loader = TestDataLoader(ds)
+    sampler = SequentialSampler(ds)
+    _update_dataloader(loader, sampler, mode="fit")
+    _update_dataloader(loader, sampler, mode="predict")
+
+    class TestDataLoader(DataLoader):
+        def __init__(self, *foo, **bar):
+            super().__init__(*foo, **bar)
+
+    loader = TestDataLoader(ds)
+    sampler = SequentialSampler(ds)
+    _update_dataloader(loader, sampler, mode="fit")
+    _update_dataloader(loader, sampler, mode="predict")
+
+    class TestDataLoader(DataLoader):
+        def __init__(self, num_feat, dataset, *args, shuffle=False):
+            self.num_feat = num_feat
+            super().__init__(dataset)
+
+    loader = TestDataLoader(1, ds)
+    sampler = SequentialSampler(ds)
+    match = escape("missing arguments are ['batch_sampler', 'sampler']")
+    with pytest.raises(MisconfigurationException, match=match):
+        _update_dataloader(loader, sampler, mode="fit")
+    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']")
+    with pytest.raises(MisconfigurationException, match=match):
+        _update_dataloader(loader, sampler, mode="predict")
+
+    class TestDataLoader(DataLoader):
+        def __init__(self, num_feat, dataset, **kwargs):
+            self.feat_num = num_feat
+            super().__init__(dataset)
+
+    loader = TestDataLoader(1, ds)
+    sampler = SequentialSampler(ds)
+    match = escape("missing attributes are ['num_feat']")
+    with pytest.raises(MisconfigurationException, match=match):
+        _update_dataloader(loader, sampler, mode="fit")
+    match = escape("missing attributes are ['num_feat']")
+    with pytest.raises(MisconfigurationException, match=match):
+        _update_dataloader(loader, sampler, mode="predict")
+
+
+def test_update_dataloader_with_multiprocessing_context():
+    """This test verifies that replace_sampler conserves multiprocessing context."""
+    train = RandomDataset(32, 64)
+    context = "spawn"
+    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
+    new_data_loader = _update_dataloader(train, SequentialSampler(train.dataset))
+    assert new_data_loader.multiprocessing_context == train.multiprocessing_context
+
+
+def test_dataloader_reinit_for_subclass():
+    class CustomDataLoader(DataLoader):
+        def __init__(
+            self,
+            dataset,
+            batch_size=1,
+            shuffle=False,
+            sampler=None,
+            batch_sampler=None,
+            num_workers=0,
+            collate_fn=None,
+            pin_memory=False,
+            drop_last=False,
+            timeout=0,
+            worker_init_fn=None,
+            dummy_kwarg=None,
+        ):
+            super().__init__(
+                dataset,
+                batch_size,
+                shuffle,
+                sampler,
+                batch_sampler,
+                num_workers,
+                collate_fn,
+                pin_memory,
+                drop_last,
+                timeout,
+                worker_init_fn,
+            )
+            self.dummy_kwarg = dummy_kwarg
+            self.something_unrelated = 1
+
+    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn")
+
+    class CustomDummyObj:
+        sampler = None
+
+    result = trainer._data_connector._prepare_dataloader(CustomDummyObj(), shuffle=True)
+    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
+
+    dataset = list(range(10))
+    result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset), shuffle=True)
+    assert isinstance(result, DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert result.dummy_kwarg is None
+
+    # Shuffled DataLoader should also work
+    result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset, shuffle=True), shuffle=True)
+    assert isinstance(result, DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert result.dummy_kwarg is None
+
+    class CustomSampler(Sampler):
+        pass
+
+    # Should raise an error if existing sampler is being replaced
+    dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
+    with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
+        trainer._data_connector._prepare_dataloader(dataloader, shuffle=True)
+
+
+class LoaderTestModel(BoringModel):
+    def training_step(self, batch, batch_idx):
+        assert len(self.trainer.train_dataloader.loaders) == 10
+        return super().training_step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        assert len(self.trainer.val_dataloaders[0]) == 10
+        return super().validation_step(batch, batch_idx)
+
+    def test_step(self, batch, batch_idx):
+        assert len(self.trainer.test_dataloaders[0]) == 10
+        return super().test_step(batch, batch_idx)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        assert len(self.trainer.predict_dataloaders[0]) == 10
+        return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
+
+
+def test_loader_detaching():
+    """Checks that the loader has been reset after the entrypoint."""
+
+    loader = DataLoader(RandomDataset(32, 10), batch_size=1)
+
+    model = LoaderTestModel()
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer = Trainer(fast_dev_run=1)
+    trainer.fit(model, loader, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.validate(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.predict(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.test(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+
+def test_pre_made_batches():
+    """Check that loader works with pre-made batches."""
+    loader = DataLoader(RandomDataset(32, 10), batch_size=None)
+    trainer = Trainer(fast_dev_run=1)
+    trainer.predict(LoaderTestModel(), loader)
+
+
+def test_error_raised_with_float_limited_eval_batches():
+    """Test that an error is raised if there are not enough batches when a float value of
+    `limit_eval_batches` is passed."""
+    model = BoringModel()
+    dl_size = len(model.val_dataloader())
+    limit_val_batches = 1 / (dl_size + 2)
+    trainer = Trainer(limit_val_batches=limit_val_batches)
+    trainer._data_connector.attach_data(model)
+    with pytest.raises(
+        MisconfigurationException,
+        match=rf"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
+    ):
+        trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)
+
+
+@pytest.mark.parametrize(
+    "val_dl,warns",
+    [
+        (DataLoader(dataset=RandomDataset(32, 64), shuffle=True), True),
+        (DataLoader(dataset=RandomDataset(32, 64), sampler=list(range(64))), False),
+        (CombinedLoader(DataLoader(dataset=RandomDataset(32, 64), shuffle=True)), True),
+        (
+            CombinedLoader(
+                [DataLoader(dataset=RandomDataset(32, 64)), DataLoader(dataset=RandomDataset(32, 64), shuffle=True)]
+            ),
+            True,
+        ),
+        (
+            CombinedLoader(
+                {
+                    "dl1": DataLoader(dataset=RandomDataset(32, 64)),
+                    "dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
+                }
+            ),
+            True,
+        ),
+    ],
+)
+def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl, warns):
+    trainer = Trainer()
+    model = BoringModel()
+    trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
+    context = pytest.warns if warns else no_warning_call
+    with context(PossibleUserWarning, match="recommended .* turn shuffling off for val/test/predict"):
+        trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)
+
+
 class NoDataLoaderModel(BoringModel):
     def __init__(self):
         super().__init__()
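The "missing arguments" failures asserted above stem from re-instantiation: to inject a `DistributedSampler`, Lightning has to rebuild the DataLoader, which only succeeds when the subclass forwards its constructor arguments. A minimal sketch of the contract (class names are illustrative):

from torch.utils.data import DataLoader, Dataset


class ForwardingLoader(DataLoader):
    # forwards everything, so rebuilding with a new sampler works
    def __init__(self, dataset: Dataset, *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)


class DroppingLoader(DataLoader):
    # swallows sampler/batch_size/..., so a rebuild cannot restore them
    def __init__(self, dataset: Dataset):
        super().__init__(dataset)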
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index 0ca9bf3107b9c..aed065596474c 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -15,6 +15,7 @@
 import collections
 import itertools
 import os
+from contextlib import redirect_stdout
 from io import StringIO
 from unittest import mock
 from unittest.mock import call
@@ -28,10 +29,13 @@
 from pytorch_lightning.loops.dataloader import EvaluationLoop
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
+from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _RICH_AVAILABLE
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
+if _RICH_AVAILABLE:
+    from rich import get_console
+
 
 def test__validation_step__log(tmpdir):
     """Tests that validation_step can log."""
@@ -794,8 +798,12 @@ def test_dataloader(self):
 inputs1 = (
     [
-        {"performance": {"log1": torch.tensor(5), "log2": torch.tensor(3)}},
-        {"test": {"no_log1": torch.tensor(6), "no_log2": torch.tensor(1)}},
+        {
+            "value": torch.tensor(2),
+            "performance": {"log:1": torch.tensor(0), "log2": torch.tensor(3), "log3": torch.tensor(7)},
+            "extra": {"log3": torch.tensor(7)},
+        },
+        {"different value": torch.tensor(1.5), "tes:t": {"no_log1": torch.tensor(6), "no_log2": torch.tensor(1)}},
     ],
     RunningStage.TESTING,
 )
@@ -803,10 +811,14 @@
 ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Test metric             DataLoader 0              DataLoader 1
 ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
-          log1                       5
-          log2                       3
-         no_log1                                              6
-         no_log2                                              1
+     different value                                         1.5
+       extra:log3                    7
+    performance:log2                 3
+    performance:log3                 7
+    performance:log:1                0
+      tes:t:no_log1                                           6
+      tes:t:no_log2                                           1
+          value                      2
 ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 """
 
@@ -864,8 +876,9 @@ def test_native_print_results(monkeypatch, inputs, expected):
     import pytorch_lightning.loops.dataloader.evaluation_loop as imports
 
     monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
-    out = StringIO()
-    EvaluationLoop._print_results(*inputs, file=out)
+
+    with redirect_stdout(StringIO()) as out:
+        EvaluationLoop._print_results(*inputs)
 
     expected = expected[1:]  # remove the initial line break from the """ string
     assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()
@@ -878,7 +891,8 @@ def test_native_print_results_encodings(monkeypatch, encoding):
     out = mock.Mock()
     out.encoding = encoding
-    EvaluationLoop._print_results(*inputs0, file=out)
+    with redirect_stdout(out) as out:
+        EvaluationLoop._print_results(*inputs0)
 
     # Attempt to encode everything the file is told to write with the given encoding
     for call_ in out.method_calls:
@@ -900,10 +914,14 @@
 ┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
 ┃       Test metric       ┃      DataLoader 0       ┃       DataLoader 1       ┃
 ┡━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
-│          log1           │            5            │                          │
-│          log2           │            3            │                          │
-│         no_log1         │                         │            6             │
-│         no_log2         │                         │            1             │
+│     different value     │                         │           1.5            │
+│       extra:log3        │            7            │                          │
+│    performance:log2     │            3            │                          │
+│    performance:log3     │            7            │                          │
+│    performance:log:1    │            0            │                          │
+│      tes:t:no_log1      │                         │            6             │
+│      tes:t:no_log2      │                         │            1             │
+│          value          │            2            │                          │
 └─────────────────────────┴─────────────────────────┴──────────────────────────┘
 """
 
@@ -950,7 +968,8 @@ def test_native_print_results_encodings(monkeypatch, encoding):
 )
 @RunIf(skip_windows=True, rich=True)
 def test_rich_print_results(inputs, expected):
-    out = StringIO()
-    EvaluationLoop._print_results(*inputs, file=out)
+    console = get_console()
+    with console.capture() as capture:
+        EvaluationLoop._print_results(*inputs)
     expected = expected[1:]  # remove the initial line break from the """ string
-    assert out.getvalue() == expected.lstrip()
+    assert capture.get() == expected.lstrip()
diff --git a/tests/trainer/properties/test_loggers.py b/tests/trainer/properties/test_loggers.py
index 7598bc153243e..1b81d4a942f02 100644
--- a/tests/trainer/properties/test_loggers.py
+++ b/tests/trainer/properties/test_loggers.py
@@ -61,7 +61,8 @@ def test_trainer_loggers_setters():
     assert trainer.loggers == [logger1]
 
     trainer.logger = logger_collection
-    assert trainer.logger._logger_iterable == logger_collection._logger_iterable
+    with pytest.deprecated_call(match="logger` when multiple loggers are configured"):
+        assert trainer.logger._logger_iterable == logger_collection._logger_iterable
     assert trainer.loggers == [logger1, logger2]
 
     # LoggerCollection of size 1 should result in trainer.logger becoming the contained logger.
@@ -76,7 +77,8 @@ def test_trainer_loggers_setters():
     # Test setters for trainer.loggers
     trainer.loggers = [logger1, logger2]
     assert trainer.loggers == [logger1, logger2]
-    assert trainer.logger._logger_iterable == logger_collection._logger_iterable
+    with pytest.deprecated_call(match="logger` when multiple loggers are configured"):
+        assert trainer.logger._logger_iterable == logger_collection._logger_iterable
 
     trainer.loggers = [logger1]
     assert trainer.loggers == [logger1]
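The rich-based test above stops writing to a `StringIO` and instead captures the global rich console via `get_console()`, which is the console `EvaluationLoop._print_results` now prints to. A minimal standalone example of the capture API in use (the printed string is arbitrary):

from rich import get_console

console = get_console()
with console.capture() as capture:
    console.print("Test metric")
assert "Test metric" in capture.get()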
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
deleted file mode 100644
index 4455cd89e2104..0000000000000
--- a/tests/trainer/test_data_loading.py
+++ /dev/null
@@ -1,382 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from contextlib import redirect_stderr
-from io import StringIO
-from re import escape
-
-import pytest
-from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
-
-from pytorch_lightning import Trainer
-from pytorch_lightning.strategies import DDPSpawnStrategy
-from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities.data import _update_dataloader
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.warnings import PossibleUserWarning
-from tests.helpers import BoringModel, RandomDataset
-from tests.helpers.runif import RunIf
-
-
-@RunIf(skip_windows=True)
-@pytest.mark.parametrize("mode", (1, 2))
-def test_replace_distributed_sampler(tmpdir, mode):
-    class IndexedRandomDataset(RandomDataset):
-        def __getitem__(self, index):
-            return self.data[index]
-
-    class CustomDataLoader(DataLoader):
-        def __init__(self, num_features, dataset, *args, **kwargs):
-            # argument `num_features` unused on purpose
-            # it gets automatically captured by _replace_dataloader_init_method()
-            super().__init__(dataset, *args, **kwargs)
-
-    class CustomBatchSampler(BatchSampler):
-        pass
-
-    class TestModel(BoringModel):
-        def __init__(self, numbers_test_dataloaders, mode):
-            super().__init__()
-            self._numbers_test_dataloaders = numbers_test_dataloaders
-            self._mode = mode
-
-        def test_step(self, batch, batch_idx, dataloader_idx=0):
-            return super().test_step(batch, batch_idx)
-
-        def on_test_start(self) -> None:
-            dataloader = self.trainer.test_dataloaders[0]
-            assert isinstance(dataloader, CustomDataLoader)
-            batch_sampler = dataloader.batch_sampler
-            if self._mode == 1:
-                assert isinstance(batch_sampler, CustomBatchSampler)
-                # the batch_size is set on the batch sampler
-                assert dataloader.batch_size is None
-            elif self._mode == 2:
-                assert type(batch_sampler) is BatchSampler
-                assert dataloader.batch_size == self._mode
-                assert batch_sampler.batch_size == self._mode
-                assert batch_sampler.drop_last
-            # the sampler has been replaced
-            assert isinstance(batch_sampler.sampler, DistributedSampler)
-
-        def create_dataset(self):
-            dataset = IndexedRandomDataset(32, 64)
-            if self._mode == 1:
-                # with a custom batch sampler
-                batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=1, drop_last=True)
-                return CustomDataLoader(32, dataset, batch_sampler=batch_sampler)
-            elif self._mode == 2:
-                # with no batch sampler provided
-                return CustomDataLoader(32, dataset, batch_size=2, drop_last=True)
-
-        def test_dataloader(self):
-            return [self.create_dataset()] * self._numbers_test_dataloaders
-
-    model = TestModel(2, mode)
-    model.test_epoch_end = None
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_test_batches=2,
-        accelerator="cpu",
-        devices=1,
-        strategy="ddp_find_unused_parameters_false",
-    )
-    trainer.test(model)
-
-
-class TestSpawnBoringModel(BoringModel):
-    def __init__(self, num_workers):
-        super().__init__()
-        self.num_workers = num_workers
-
-    def train_dataloader(self):
-        return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers)
-
-    def on_fit_start(self):
-        self._resout = StringIO()
-        self.ctx = redirect_stderr(self._resout)
-        self.ctx.__enter__()
-
-    def on_train_end(self):
-        def _get_warning_msg():
-            dl = self.trainer.train_dataloader.loaders
-            if hasattr(dl, "persistent_workers"):
-                if self.num_workers == 0:
-                    warn_str = "Consider setting num_workers>0 and persistent_workers=True"
-                else:
-                    warn_str = "Consider setting persistent_workers=True"
-            else:
-                warn_str = "Consider setting strategy=ddp"
-
-            return warn_str
-
-        if self.trainer.is_global_zero:
-            self.ctx.__exit__(None, None, None)
-            msg = self._resout.getvalue()
-            warn_str = _get_warning_msg()
-            assert warn_str in msg
-
-
-@RunIf(skip_windows=True)
-@pytest.mark.parametrize("num_workers", [0, 1])
-def test_dataloader_warnings(tmpdir, num_workers):
-    trainer = Trainer(default_root_dir=tmpdir, accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=4)
-    assert isinstance(trainer.strategy, DDPSpawnStrategy)
-    trainer.fit(TestSpawnBoringModel(num_workers))
-
-
-def test_update_dataloader_raises():
-    with pytest.raises(ValueError, match="needs to subclass `torch.utils.data.DataLoader"):
-        _update_dataloader(object(), object(), mode="fit")
-
-
-def test_dataloaders_with_missing_keyword_arguments():
-    ds = RandomDataset(10, 20)
-
-    class TestDataLoader(DataLoader):
-        def __init__(self, dataset):
-            super().__init__(dataset)
-
-    loader = TestDataLoader(ds)
-    sampler = SequentialSampler(ds)
-    match = escape("missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
-    with pytest.raises(MisconfigurationException, match=match):
-        _update_dataloader(loader, sampler, mode="fit")
-    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']")
-    with pytest.raises(MisconfigurationException, match=match):
-        _update_dataloader(loader, sampler, mode="predict")
-
-    class TestDataLoader(DataLoader):
-        def __init__(self, dataset, *args, **kwargs):
-            super().__init__(dataset)
-
-    loader = TestDataLoader(ds)
-    sampler = SequentialSampler(ds)
-    _update_dataloader(loader, sampler, mode="fit")
-    _update_dataloader(loader, sampler, mode="predict")
-
-    class TestDataLoader(DataLoader):
-        def __init__(self, *foo, **bar):
-            super().__init__(*foo, **bar)
-
-    loader = TestDataLoader(ds)
-    sampler = SequentialSampler(ds)
-    _update_dataloader(loader, sampler, mode="fit")
-    _update_dataloader(loader, sampler, mode="predict")
-
-    class TestDataLoader(DataLoader):
-        def __init__(self, num_feat, dataset, *args, shuffle=False):
-            self.num_feat = num_feat
-            super().__init__(dataset)
-
-    loader = TestDataLoader(1, ds)
-    sampler = SequentialSampler(ds)
-    match = escape("missing arguments are ['batch_sampler', 'sampler']")
-    with pytest.raises(MisconfigurationException, match=match):
-        _update_dataloader(loader, sampler, mode="fit")
-    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']")
-    with pytest.raises(MisconfigurationException, match=match):
-        _update_dataloader(loader, sampler, mode="predict")
-
-    class TestDataLoader(DataLoader):
-        def __init__(self, num_feat, dataset, **kwargs):
-            self.feat_num = num_feat
-            super().__init__(dataset)
-
-    loader = TestDataLoader(1, ds)
-    sampler = SequentialSampler(ds)
-    match = escape("missing attributes are ['num_feat']")
-    with pytest.raises(MisconfigurationException, match=match):
-        _update_dataloader(loader, sampler, mode="fit")
-    match = escape("missing attributes are ['num_feat']")
-    with pytest.raises(MisconfigurationException, match=match):
-        _update_dataloader(loader, sampler, mode="predict")
-
-
-def test_update_dataloader_with_multiprocessing_context():
-    """This test verifies that replace_sampler conserves multiprocessing context."""
-    train = RandomDataset(32, 64)
-    context = "spawn"
-    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
-    new_data_loader = _update_dataloader(train, SequentialSampler(train.dataset))
-    assert new_data_loader.multiprocessing_context == train.multiprocessing_context
-
-
-def test_dataloader_reinit_for_subclass():
-    class CustomDataLoader(DataLoader):
-        def __init__(
-            self,
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            sampler=None,
-            batch_sampler=None,
-            num_workers=0,
-            collate_fn=None,
-            pin_memory=False,
-            drop_last=False,
-            timeout=0,
-            worker_init_fn=None,
-            dummy_kwarg=None,
-        ):
-            super().__init__(
-                dataset,
-                batch_size,
-                shuffle,
-                sampler,
-                batch_sampler,
-                num_workers,
-                collate_fn,
-                pin_memory,
-                drop_last,
-                timeout,
-                worker_init_fn,
-            )
-            self.dummy_kwarg = dummy_kwarg
-            self.something_unrelated = 1
-
-    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn")
-
-    class CustomDummyObj:
-        sampler = None
-
-    result = trainer._data_connector._prepare_dataloader(CustomDummyObj(), shuffle=True)
-    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
-
-    dataset = list(range(10))
-    result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset), shuffle=True)
-    assert isinstance(result, DataLoader)
-    assert isinstance(result, CustomDataLoader)
-    assert result.dummy_kwarg is None
-
-    # Shuffled DataLoader should also work
-    result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset, shuffle=True), shuffle=True)
-    assert isinstance(result, DataLoader)
-    assert isinstance(result, CustomDataLoader)
-    assert result.dummy_kwarg is None
-
-    class CustomSampler(Sampler):
-        pass
-
-    # Should raise an error if existing sampler is being replaced
-    dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
-    with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
-        trainer._data_connector._prepare_dataloader(dataloader, shuffle=True)
-
-
-class LoaderTestModel(BoringModel):
-    def training_step(self, batch, batch_idx):
-        assert len(self.trainer.train_dataloader.loaders) == 10
-        return super().training_step(batch, batch_idx)
-
-    def validation_step(self, batch, batch_idx):
-        assert len(self.trainer.val_dataloaders[0]) == 10
-        return super().validation_step(batch, batch_idx)
-
-    def test_step(self, batch, batch_idx):
-        assert len(self.trainer.test_dataloaders[0]) == 10
-        return super().test_step(batch, batch_idx)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        assert len(self.trainer.predict_dataloaders[0]) == 10
-        return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
-
-
-def test_loader_detaching():
-    """Checks that the loader has been reset after the entrypoint."""
-
-    loader = DataLoader(RandomDataset(32, 10), batch_size=1)
-
-    model = LoaderTestModel()
-
-    assert len(model.train_dataloader()) == 64
-    assert len(model.val_dataloader()) == 64
-    assert len(model.predict_dataloader()) == 64
-    assert len(model.test_dataloader()) == 64
-
-    trainer = Trainer(fast_dev_run=1)
-    trainer.fit(model, loader, loader)
-
-    assert len(model.train_dataloader()) == 64
-    assert len(model.val_dataloader()) == 64
-    assert len(model.predict_dataloader()) == 64
-    assert len(model.test_dataloader()) == 64
-
-    trainer.validate(model, loader)
-
-    assert len(model.train_dataloader()) == 64
-    assert len(model.val_dataloader()) == 64
-    assert len(model.predict_dataloader()) == 64
-    assert len(model.test_dataloader()) == 64
-
-    trainer.predict(model, loader)
-
-    assert len(model.train_dataloader()) == 64
-    assert len(model.val_dataloader()) == 64
-    assert len(model.predict_dataloader()) == 64
-    assert len(model.test_dataloader()) == 64
-
-    trainer.test(model, loader)
-
-    assert len(model.train_dataloader()) == 64
-    assert len(model.val_dataloader()) == 64
-    assert len(model.predict_dataloader()) == 64
-    assert len(model.test_dataloader()) == 64
-
-
-def test_pre_made_batches():
-    """Check that loader works with pre-made batches."""
-    loader = DataLoader(RandomDataset(32, 10), batch_size=None)
-    trainer = Trainer(fast_dev_run=1)
-    trainer.predict(LoaderTestModel(), loader)
-
-
-def test_error_raised_with_float_limited_eval_batches():
-    """Test that an error is raised if there are not enough batches when passed with float value of
-    limit_eval_batches."""
-    model = BoringModel()
-    dl_size = len(model.val_dataloader())
-    limit_val_batches = 1 / (dl_size + 2)
-    trainer = Trainer(limit_val_batches=limit_val_batches)
-    trainer._data_connector.attach_data(model)
-    with pytest.raises(
-        MisconfigurationException,
-        match=rf"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
-    ):
-        trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)
-
-
-@pytest.mark.parametrize(
-    "val_dl",
-    [
-        DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
-        CombinedLoader(DataLoader(dataset=RandomDataset(32, 64), shuffle=True)),
-        CombinedLoader(
-            [DataLoader(dataset=RandomDataset(32, 64)), DataLoader(dataset=RandomDataset(32, 64), shuffle=True)]
-        ),
-        CombinedLoader(
-            {
-                "dl1": DataLoader(dataset=RandomDataset(32, 64)),
-                "dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True),
-            }
-        ),
-    ],
-)
-def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
-    trainer = Trainer()
-    model = BoringModel()
-    trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
-    with pytest.warns(PossibleUserWarning, match="recommended .* turn shuffling off for val/test/predict"):
-        trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)