diff --git a/CHANGELOG.md b/CHANGELOG.md
index 45f5efcc56216..1d23dc150648b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [1.5.2] - 2021-11-16
+
+### Fixed
+
+- Fixed `CombinedLoader` with `max_size_cycle` not receiving a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
+- Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in `utilities.apply_to_collection` ([#9702](https://github.com/PyTorchLightning/pytorch-lightning/issues/9702))
+- Fixed `isinstance` not working with `init_meta_context` and the materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))
+- Fixed an issue that prevented the Trainer from shutting down workers when execution is interrupted due to a failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
+- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
+- Fixed sampler replacement logic with `overfit_batches` to only replace the sampler when a `SequentialSampler` is not already used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
+- Fixed scripting causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/pull/10470), [#10555](https://github.com/PyTorchLightning/pytorch-lightning/pull/10555))
+- Do not fail if the batch size could not be inferred for logging when using DeepSpeed ([#10438](https://github.com/PyTorchLightning/pytorch-lightning/issues/10438))
+- Fixed propagation of device and dtype information to submodules of LightningLite when they inherit from `DeviceDtypeModuleMixin` ([#10559](https://github.com/PyTorchLightning/pytorch-lightning/issues/10559))
+
+
 ## [1.5.1] - 2021-11-09
 
 ### Fixed
diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py
index e1bdee9b7320b..bc9aefd1d2dad 100644
--- a/pytorch_lightning/__about__.py
+++ b/pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
 import time
 
 _this_year = time.strftime("%Y")
-__version__ = "1.5.1"
+__version__ = "1.5.2"
 __author__ = "William Falcon et al."
 __author_email__ = "waf2107@columbia.edu"
 __license__ = "Apache-2.0"
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index b5118846875db..096cb4849cc39 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -202,7 +202,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         ):  # short circuit if metric not present
             return
 
-        current = logs.get(self.monitor)
+        current = logs[self.monitor].squeeze()
         should_stop, reason = self._evaluate_stopping_criteria(current)
 
         # stop every ddp process if any world process decides to stop
diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py
index ab771992a960b..3c70bfe735f95 100644
--- a/pytorch_lightning/callbacks/progress/rich_progress.py
+++ b/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -129,11 +129,12 @@ def render(self, task) -> RenderableType:
 class MetricsTextColumn(ProgressColumn):
     """A column containing text."""
 
-    def __init__(self, trainer):
+    def __init__(self, trainer, style):
         self._trainer = trainer
         self._tasks = {}
         self._current_task_id = 0
         self._metrics = {}
+        self._style = style
         super().__init__()
 
     def update(self, metrics):
@@ -158,23 +159,34 @@ def render(self, task) -> Text:
         for k, v in self._metrics.items():
             _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
-        return Text(_text, justify="left")
+        return Text(_text, justify="left", style=self._style)
 
 
 @dataclass
 class RichProgressBarTheme:
     """Styles to associate to different base components.
 
+    Args:
+        description: Style for the progress bar description, e.g. Epoch x, Testing, etc.
+        progress_bar: Style for the bar in progress.
+        progress_bar_finished: Style for the finished progress bar.
+        progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
+        batch_progress: Style for the progress tracker (i.e. 10/50 batches completed).
+        time: Style for the elapsed time and the estimated time remaining.
+        processing_speed: Style for the speed of the batches being processed.
+ metrics: Style for the metrics + https://rich.readthedocs.io/en/stable/style.html """ - text_color: str = "white" - progress_bar_complete: Union[str, Style] = "#6206E0" + description: Union[str, Style] = "white" + progress_bar: Union[str, Style] = "#6206E0" progress_bar_finished: Union[str, Style] = "#6206E0" progress_bar_pulse: Union[str, Style] = "#6206E0" - batch_process: str = "white" - time: str = "grey54" - processing_speed: str = "grey70" + batch_progress: Union[str, Style] = "white" + time: Union[str, Style] = "grey54" + processing_speed: Union[str, Style] = "grey70" + metrics: Union[str, Style] = "white" class RichProgressBar(ProgressBarBase): @@ -268,7 +280,7 @@ def _init_progress(self, trainer): self._reset_progress_bar_ids() self._console: Console = Console() self._console.clear_live() - self._metric_component = MetricsTextColumn(trainer) + self._metric_component = MetricsTextColumn(trainer, self.theme.metrics) self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, @@ -351,7 +363,7 @@ def on_validation_epoch_start(self, trainer, pl_module): def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]: if self.progress is not None: return self.progress.add_task( - f"[{self.theme.text_color}]{description}", total=total_batches, visible=visible + f"[{self.theme.description}]{description}", total=total_batches, visible=visible ) def _update(self, progress_bar_id: int, visible: bool = True) -> None: @@ -448,11 +460,11 @@ def configure_columns(self, trainer) -> list: return [ TextColumn("[progress.description]{task.description}"), CustomBarColumn( - complete_style=self.theme.progress_bar_complete, + complete_style=self.theme.progress_bar, finished_style=self.theme.progress_bar_finished, pulse_style=self.theme.progress_bar_pulse, ), - BatchesProcessedColumn(style=self.theme.batch_process), + BatchesProcessedColumn(style=self.theme.batch_progress), CustomTimeColumn(style=self.theme.time), ProcessingSpeedColumn(style=self.theme.processing_speed), ] diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c59193859b171..bc89cc2b18e93 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -116,6 +116,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._param_requires_grad_state = {} self._metric_attributes: Optional[Dict[int, str]] = None self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False + # TODO: remove after the 1.6 release + self._running_torchscript = False self._register_sharded_tensor_state_dict_hooks_if_available() @@ -1962,6 +1964,8 @@ def to_torchscript( """ mode = self.training + self._running_torchscript = True + if method == "script": torchscript_module = torch.jit.script(self.eval(), **kwargs) elif method == "trace": @@ -1987,6 +1991,8 @@ def to_torchscript( with fs.open(file_path, "wb") as f: torch.jit.save(torchscript_module, f) + self._running_torchscript = False + return torchscript_module @property @@ -1996,11 +2002,12 @@ def model_size(self) -> float: Note: This property will not return correct value for Deepspeed (stage 3) and fully-sharded training. """ - rank_zero_deprecation( - "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7." 
- " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.", - stacklevel=5, - ) + if not self._running_torchscript: # remove with the deprecation removal + rank_zero_deprecation( + "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7." + " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.", + stacklevel=5, + ) return get_model_size_mb(self) def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index e02790edddd1e..e8b122989cd9c 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,6 +17,8 @@ import torch from torch.nn import Module +import pytorch_lightning as pl + class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ["device", "dtype"] @@ -177,7 +179,9 @@ def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: - if not isinstance(module, DeviceDtypeModuleMixin): + # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't + # work when using `init_meta_context`. + if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): return if device is not None: module._device = device diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 615f461055204..ff95e89d1d2cf 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -24,6 +24,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.plugins import PrecisionPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device @@ -64,7 +65,7 @@ def step(self, closure: Optional[Callable] = None) -> None: ) -class _LiteModule(nn.Module): +class _LiteModule(DeviceDtypeModuleMixin): def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None: """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast automatically for the forward pass. 
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index f26fc75ac58db..1ceadb8658a3d 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -240,7 +240,9 @@ def log_graph(self, model: "pl.LightningModule", input_array=None): if input_array is not None: input_array = model._apply_batch_transfer_handler(input_array) + model._running_torchscript = True self.experiment.add_graph(model, input_array) + model._running_torchscript = False else: rank_zero_warn( "Could not log computational graph since the" diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 8dc42c2f36b88..13f518d53ec00 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -622,11 +622,6 @@ def _format_batch_size_and_grad_accum_config(self): ) self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches if "train_micro_batch_size_per_gpu" not in self.config: - rank_zero_warn( - "Inferring the batch size for internal deepspeed logging from the `train_dataloader()`. " - "If you require skipping this, please pass " - "`Trainer(strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`" - ) batch_size = self._auto_select_batch_size() self.config["train_micro_batch_size_per_gpu"] = batch_size if "gradient_clipping" not in self.config: @@ -638,9 +633,19 @@ def _auto_select_batch_size(self): batch_size = 1 train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source if train_dl_source.is_defined(): - train_dataloader = train_dl_source.dataloader() - if hasattr(train_dataloader, "batch_sampler"): - batch_size = train_dataloader.batch_sampler.batch_size + try: + train_dataloader = train_dl_source.dataloader() + if hasattr(train_dataloader, "batch_sampler"): + batch_size = train_dataloader.batch_sampler.batch_size + # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup` + # to have been called before + except Exception: + if self.global_rank == 0: + deepspeed.utils.logging.logger.warning( + "Tried to infer the batch size for internal deepspeed logging from the `train_dataloader()`. " + "To ensure DeepSpeed logging remains correct, please manually pass the plugin with the " + "batch size, `Trainer(strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=batch_size))`." 
+ ) return batch_size def _format_precision_config(self): diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 4d9f937c58467..898e62791d6ee 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -237,21 +237,25 @@ def to_tensor(x): args = apply_to_collection(args, dtype=(int, float), function=to_tensor) return args - def training_step(self, *args, **kwargs): + def _step(self, stage: RunningStage, *args: Any, **kwargs: Any): args = self._prepare_input(args) - return self.poptorch_models[RunningStage.TRAINING](*args, **kwargs) + poptorch_model = self.poptorch_models[stage] + self.lightning_module._running_torchscript = True + out = poptorch_model(*args, **kwargs) + self.lightning_module._running_torchscript = False + return out + + def training_step(self, *args, **kwargs): + return self._step(RunningStage.TRAINING, *args, **kwargs) def validation_step(self, *args, **kwargs): - args = self._prepare_input(args) - return self.poptorch_models[RunningStage.VALIDATING](*args, **kwargs) + return self._step(RunningStage.VALIDATING, *args, **kwargs) def test_step(self, *args, **kwargs): - args = self._prepare_input(args) - return self.poptorch_models[RunningStage.TESTING](*args, **kwargs) + return self._step(RunningStage.TESTING, *args, **kwargs) def predict_step(self, *args, **kwargs): - args = self._prepare_input(args) - return self.poptorch_models[RunningStage.PREDICTING](*args, **kwargs) + return self._step(RunningStage.PREDICTING, *args, **kwargs) def teardown(self) -> None: # undo dataloader patching diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f798cf3ee2b82..53034ac77db3f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -51,8 +51,8 @@ class _Sync: fn: Optional[Callable] = None _should: bool = False rank_zero_only: bool = False - op: Optional[str] = None - group: Optional[Any] = None + _op: Optional[str] = None + _group: Optional[Any] = None def __post_init__(self) -> None: self._generate_sync_fn() @@ -67,6 +67,26 @@ def should(self, should: bool) -> None: # `self._fn` needs to be re-generated. self._generate_sync_fn() + @property + def op(self) -> Optional[str]: + return self._op + + @op.setter + def op(self, op: Optional[str]) -> None: + self._op = op + # `self._fn` needs to be re-generated. + self._generate_sync_fn() + + @property + def group(self) -> Optional[Any]: + return self._group + + @group.setter + def group(self, group: Optional[Any]) -> None: + self._group = group + # `self._fn` needs to be re-generated. 
+ self._generate_sync_fn() + def _generate_sync_fn(self) -> None: """Used to compute the syncing function and cache it.""" fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn @@ -426,7 +446,7 @@ def log( dataloader_idx=dataloader_idx, metric_attribute=metric_attribute, ) - meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only) + meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only) # register logged value if it doesn't exist if key not in self: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 0f1eccc3f7cbb..9c40e728391c1 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -28,7 +28,7 @@ from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( @@ -136,14 +136,22 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn if isinstance(dataloader, CombinedLoader): # apply `prepare_dataloader` on all the collection of loaders dataloader.loaders = apply_to_collection( - dataloader.loaders, DataLoader, self.prepare_dataloader, shuffle, mode=mode + dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode ) + # the length need to recomputed across all dataloaders in case of special behavior. 
+ dataloader._apply_cycle_iterator_length() return dataloader # don't do anything if it's not a dataloader - if not isinstance(dataloader, DataLoader): + if not isinstance(dataloader, (DataLoader, CycleIterator)): return dataloader + cycle_iterator: Optional[CycleIterator] = None + + if isinstance(dataloader, CycleIterator): + cycle_iterator = dataloader + dataloader = dataloader.loader + if ( _fault_tolerant_training() # injects components to track the state or self._requires_distributed_sampler(dataloader) # sets the distributed sampler @@ -153,6 +161,10 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode) dataloader = self._update_dataloader(dataloader, sampler, mode=mode) + if cycle_iterator is not None: + cycle_iterator.loader = dataloader + return cycle_iterator + return dataloader def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler: @@ -426,8 +438,7 @@ def _reset_eval_dataloader( for loader_i in range(len(dataloaders)): loader = dataloaders[loader_i] - if hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler): - + if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler): # when overfitting, the dataloader should not have sampler if self.overfit_batches > 0 and mode.evaluating: rank_zero_warn( @@ -579,16 +590,17 @@ def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: @staticmethod def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]: - has_random_sampler = False + all_have_sequential_sampler = True - def resolve_had_random_sampler(dataloader: DataLoader): - nonlocal has_random_sampler - if not has_random_sampler: - has_random_sampler = isinstance(dataloader.sampler, RandomSampler) + def resolve_has_no_sequential_sampler(dataloader: DataLoader): + nonlocal all_have_sequential_sampler + all_have_sequential_sampler = all_have_sequential_sampler & isinstance( + dataloader.sampler, SequentialSampler + ) - apply_to_collection(dataloader, DataLoader, resolve_had_random_sampler) + apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler) - if has_random_sampler: + if not all_have_sequential_sampler: rank_zero_warn( "You requested to overfit but enabled training dataloader shuffling." " We are turning off the training dataloader shuffling for you." 
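The reworked `_resolve_overfit_batches` check above warns and swaps the sampler whenever any training dataloader is not sequential. The snippet below is an illustration of the invariant being protected, in plain `torch` with no Lightning involved: overfitting should replay identical batches every epoch, which only holds with a `SequentialSampler`.

```python
# Illustrative only: why overfit_batches warns on (and replaces) any
# non-sequential sampler. Plain torch, no Lightning involved.
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))

shuffled = DataLoader(dataset, batch_size=2, shuffle=True)
# a shuffled loader uses a RandomSampler, the case that now triggers the warning
assert not isinstance(shuffled.sampler, SequentialSampler)

# what the trainer effectively enforces: rebuild the loader with sequential sampling
fixed = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))
first_epoch = [batch[0].tolist() for batch in fixed]
second_epoch = [batch[0].tolist() for batch in fixed]
assert first_epoch == second_epoch  # identical batches on every pass
```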
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 816f4da38f5b9..6e2e51e82bbf1 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -457,6 +457,19 @@ def _wrap_loaders_max_size_cycle(self) -> Any: ) state.reset() + def _apply_cycle_iterator_length(self) -> None: + """When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to + all dataloaders.""" + if self.mode != "max_size_cycle": + return + + def set_len(cycle_iterator: CycleIterator, length: int) -> None: + cycle_iterator.length = length + + all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader)) + max_length = _nested_calc_num_data(all_lengths, max) + apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length) + def __iter__(self) -> Any: """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.""" @@ -473,11 +486,12 @@ def __getstate__patch__(*_): return iterator @staticmethod - def _calc_num_batches(loaders: Any) -> Union[int, float]: + def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]: """Compute the length (aka the number of batches) of `CombinedLoader`. Args: loaders: a collections of loaders. + mode: Mode used by the CombinedDataloader Returns: length: the minimum length of loaders @@ -486,10 +500,10 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]: if isinstance(all_lengths, (int, float)): return all_lengths - return _nested_calc_num_data(all_lengths, min) + return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min) def __len__(self) -> int: - return self._calc_num_batches(self.loaders) + return self._calc_num_batches(self.loaders, mode=self.mode) @staticmethod def _shutdown_workers_and_reset_iterator(dataloader) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7e5d21e18dc26..b4a3f97025701 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -86,7 +86,7 @@ from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.meta import materialize_module +from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import ( @@ -697,6 +697,8 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: # reset bookkeeping self.state.stage = None self.on_exception(exception) + # shutdown workers + self._data_connector.teardown() raise def fit( @@ -1435,10 +1437,21 @@ def _call_setup_hook(self) -> None: def _call_configure_sharded_model(self) -> None: with self.accelerator.model_sharded_context(): - materialize_module(self.lightning_module) + self._handle_meta_model() self.call_hook("configure_sharded_model") self.call_hook("on_configure_sharded_model") + def _handle_meta_model(self) -> None: + if not is_on_meta_device(self.lightning_module): + return + + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") + + 
materialize_module(self.lightning_module) + # the trainer reference is lost during materialization + self.lightning_module.trainer = proxy(self) + def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 1e981a0f543e7..5a76f402bcc02 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -16,7 +16,7 @@ from abc import ABC from collections import defaultdict, OrderedDict from collections.abc import Mapping, Sequence -from copy import copy +from copy import copy, deepcopy from functools import partial from typing import Any, Callable, List, Optional, Tuple, Union @@ -119,11 +119,21 @@ def apply_to_collection( return elem_type(*out) if is_namedtuple else elem_type(out) if _is_dataclass_instance(data): - out_dict = {} + # make a deepcopy of the data, + # but do not deepcopy mapped fields since the computation would + # be wasted on values that likely get immediately overwritten + fields = {} + memo = {} for field in dataclasses.fields(data): - if field.init: + field_value = getattr(data, field.name) + fields[field.name] = (field_value, field.init) + memo[id(field_value)] = field_value + result = deepcopy(data, memo=memo) + # apply function to each field + for field_name, (field_value, field_init) in fields.items(): + if field_init: v = apply_to_collection( - getattr(data, field.name), + field_value, dtype, function, *args, @@ -131,9 +141,10 @@ def apply_to_collection( include_none=include_none, **kwargs, ) - if include_none or v is not None: - out_dict[field.name] = v - return elem_type(**out_dict) + if not field_init or (not include_none and v is None): # retain old value + v = getattr(data, field_name) + setattr(result, field_name, v) + return result # data is neither of dtype, nor a collection return data diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 60e6cc791b7ae..6d3c1d6b5f11b 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -18,13 +18,14 @@ from functools import partial from itertools import chain from types import ModuleType -from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type import torch from torch import nn, Tensor from torch.nn import Module from torch.nn.modules.container import ModuleDict, ModuleList, Sequential +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10 @@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module: # cache subclasses to optimize the search when resetting the meta device later on. __STORAGE_META__ = {} - __CREATED_MODULES__ = set() @@ -237,45 +237,52 @@ def _set_meta_device() -> None: for subclass in get_all_subclasses(torch.nn.modules.module.Module): - if isinstance(subclass, (Sequential, ModuleList, ModuleDict)): + if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule): continue # if a subclass has already been stored, we should use the cache if str(subclass) in __STORAGE_META__: - # reset the class import package to its rightfull state. + # reset the class import package to its rightful state. 
mods, subclass, meta_class = __STORAGE_META__[subclass] for mod in mods: setattr(mod, subclass.__name__, meta_class) continue + class _IsinstanceMetaclass(type(subclass)): + def __instancecheck__(self, instance: Any) -> bool: + """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects.""" + return isinstance(instance, self.__bases__[0]) + # Create a class subclassing current `subclass` overriding its new method. # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta` # version of the current subclass module - class _MetaClass(subclass): + class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass): @classmethod @contextmanager - def instantiation_context(cls, materialize: bool): + def instantiation_context(cls): _unset_meta_device(from_created=True) yield _set_meta_device_populated(from_created=True) @classmethod def materialize(cls, materialize_fn: Callable): - with cls.instantiation_context(materialize=True): + with cls.instantiation_context(): obj = materialize_fn() return obj @staticmethod def add_subclasses(subclass): - """This is used to unrol the instantion tree while creating the modules.""" - __CREATED_MODULES__.add(subclass) + """This is used to unroll the instantiation tree while creating the modules.""" + # Don't store the LightningModule as skipped from the Meta process. + if subclass != pl.LightningModule: + __CREATED_MODULES__.add(subclass) if subclass.__bases__[0] != torch.nn.modules.module.Module: - _MetaClass.add_subclasses(subclass.__bases__[0]) + _MaterializerModule.add_subclasses(subclass.__bases__[0]) def __new__(cls, *args, **kwargs): subclass = cls.__bases__[0] cls.add_subclasses(subclass) - with cls.instantiation_context(materialize=False): + with cls.instantiation_context(): obj = init_meta(subclass, *args, **kwargs) obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize) @@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]: # nn.Module class can be imported at different level and they all need to be mocked. 
# Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear - # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass - out = [] - out.append(search(mod)) + # needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule + out = [search(mod)] for name in submodules[1:]: mod = getattr(mod, name) out.append(search(mod)) @@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]: mods = [mod for mod in chain(*out) if mod] # store the modules search so it doesn't have to be performed again for this class - __STORAGE_META__[subclass] = (mods, subclass, _MetaClass) + __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule) # replace all subclass by its meta form for mod in mods: - setattr(mod, subclass.__name__, _MetaClass) + setattr(mod, subclass.__name__, _MaterializerModule) @contextmanager @@ -321,3 +327,11 @@ def init_meta_context() -> Generator: _set_meta_device() yield _unset_meta_device() + + +def is_on_meta_device(module: nn.Module) -> bool: + try: + param = next(module.parameters()) + return param.device.type == "meta" + except StopIteration: + return False diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 2b4fe9f05eb87..1540cbeba5189 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -381,7 +381,7 @@ def on_train_end(self) -> None: _ES_CHECK = dict(check_on_train_epoch_end=True) _ES_CHECK_P3 = dict(patience=3, check_on_train_epoch_end=True) -_NO_WIN = dict(marks=RunIf(skip_windows=True)) +_SPAWN_MARK = dict(marks=RunIf(skip_windows=True, skip_49370=True)) @pytest.mark.parametrize( @@ -389,8 +389,8 @@ def on_train_end(self) -> None: [ ([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, None, 1), ([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, None, 1), - pytest.param([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, **_NO_WIN), - pytest.param([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, **_NO_WIN), + pytest.param([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, **_SPAWN_MARK), + pytest.param([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, **_SPAWN_MARK), ([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, None, 1), ([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, None, 1), pytest.param( @@ -399,7 +399,7 @@ def on_train_end(self) -> None: True, "ddp_spawn", 2, - **_NO_WIN, + **_SPAWN_MARK, ), pytest.param( [EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], @@ -407,7 +407,7 @@ def on_train_end(self) -> None: True, "ddp_spawn", 2, - **_NO_WIN, + **_SPAWN_MARK, ), ], ) @@ -469,3 +469,16 @@ def validation_step(self, batch, batch_idx): assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval) else: assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1 + + +def test_early_stopping_squeezes(): + early_stopping = EarlyStopping(monitor="foo") + trainer = Trainer() + trainer.callback_metrics["foo"] = torch.tensor([[[0]]]) + + with mock.patch( + "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", return_value=(False, "") + ) as es_mock: + 
early_stopping._run_early_stopping_check(trainer) + + es_mock.assert_called_once_with(torch.tensor(0)) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 1c1f84b5b95a0..c813ed2b02e28 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -187,7 +187,7 @@ def test_pruning_callback_ddp_spawn(tmpdir): train_with_pruning_callback(tmpdir, use_global_unstructured=True, strategy="ddp_spawn", gpus=2) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_pruning_callback_ddp_cpu(tmpdir): train_with_pruning_callback(tmpdir, parameters_to_prune=True, strategy="ddp_spawn", num_processes=2) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 31681754423a8..8f3f20630b5c0 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -106,11 +106,11 @@ def test_rich_progress_bar_custom_theme(tmpdir): assert progress_bar.theme == theme args, kwargs = mocks["CustomBarColumn"].call_args - assert kwargs["complete_style"] == theme.progress_bar_complete + assert kwargs["complete_style"] == theme.progress_bar assert kwargs["finished_style"] == theme.progress_bar_finished args, kwargs = mocks["BatchesProcessedColumn"].call_args - assert kwargs["style"] == theme.batch_process + assert kwargs["style"] == theme.batch_progress args, kwargs = mocks["CustomTimeColumn"].call_args assert kwargs["style"] == theme.time diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index e10f99d33d564..4a0f154928adb 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -148,7 +148,7 @@ def test_swa_callback_ddp_spawn(tmpdir): train_with_swa(tmpdir, strategy="ddp_spawn", gpus=2) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_swa_callback_ddp_cpu(tmpdir): train_with_swa(tmpdir, strategy="ddp_spawn", num_processes=2) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 518d67cf251f5..04255d51ad069 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -385,7 +385,7 @@ def on_train_end(self, trainer, pl_module): assert torch.save.call_count == 0 -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_model_checkpoint_no_extraneous_invocations(tmpdir): """Test to ensure that the model callback saves the checkpoints only once in distributed mode.""" model = LogInTwoMethods() diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py index 8b0f0e457bff9..f9634a9dadb2a 100644 --- a/tests/checkpointing/test_torch_saving.py +++ b/tests/checkpointing/test_torch_saving.py @@ -34,7 +34,7 @@ def test_model_torch_save(tmpdir): trainer = torch.load(temp_path) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_model_torch_save_ddp_cpu(tmpdir): """Test to ensure torch save does not fail for model and trainer using cpu ddp.""" model = BoringModel() diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 0e62441b1d40e..a39ce51788ff9 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -33,7 +33,7 @@ def _setup_ddp(rank, worldsize): def _ddp_test_fn(rank, worldsize): _setup_ddp(rank, worldsize) tensor = torch.tensor([1.0]) - sync = _Sync(sync_ddp_if_available, _should=True, 
op="SUM") + sync = _Sync(sync_ddp_if_available, _should=True, _op="SUM") actual = sync(tensor) assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors" diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 16c511b6effd9..47d72814cd2b6 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -242,7 +242,7 @@ def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: return super().get_from_queue(queue) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_v1_7_0_deprecate_add_get_queue(tmpdir): model = BoringCallbackDDPSpawnModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, num_processes=2, strategy="ddp_spawn") diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 490e023662f79..e53d3811f6b34 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -70,6 +70,7 @@ def __new__( fairscale_fully_sharded: bool = False, deepspeed: bool = False, rich: bool = False, + skip_49370: bool = False, **kwargs, ): """ @@ -91,6 +92,7 @@ def __new__( fairscale_fully_sharded: if `fairscale` fully sharded module is required to run the test deepspeed: if `deepspeed` module is required to run the test rich: if `rich` module is required to run the test + skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370. kwargs: native pytest.mark.skipif keyword arguments """ conditions = [] @@ -165,6 +167,15 @@ def __new__( conditions.append(not _RICH_AVAILABLE) reasons.append("Rich") + if skip_49370: + # strategy=ddp_spawn, accelerator=cpu, python>=3.9, torch<1.8 does not work + py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ge_3_9 = Version(py_version) >= Version("3.9") + torch_version = get_distribution("torch").version + old_torch = Version(torch_version) < Version("1.8") + conditions.append(ge_3_9 and old_torch) + reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370") + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 4993a10c8dbc2..c271d3b3163ed 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -17,6 +17,7 @@ import torch from torch.utils.data.dataloader import DataLoader +from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from tests.helpers.runif import RunIf @@ -65,6 +66,27 @@ def check_autocast(forward_input): assert out.dtype == input_type or out.dtype == torch.get_default_dtype() +@pytest.mark.parametrize( + "device", [torch.device("cpu"), pytest.param(torch.device("cuda", 0), marks=RunIf(min_gpus=1))] +) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_lite_module_device_dtype_propagation(device, dtype): + """Test that the LiteModule propagates device and dtype properties to its submodules (e.g. 
torchmetrics).""" + + class DeviceModule(DeviceDtypeModuleMixin): + pass + + device_module = DeviceModule() + lite_module = _LiteModule(device_module, Mock()) + lite_module.to(device) + assert device_module.device == device + assert lite_module.device == device + + lite_module.to(dtype) + assert device_module.dtype == dtype + assert lite_module.dtype == dtype + + def test_lite_dataloader_iterator(): """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic device placement).""" diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 271ffce811fe5..370b24431b088 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -321,8 +321,8 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): assert pl_module.logger.experiment.something(foo="bar") is None +@RunIf(skip_windows=True, skip_49370=True) @pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger]) -@RunIf(skip_windows=True) def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class): """Test that loggers get replaced by dummy loggers on global rank > 0.""" _patch_comet_atexit(monkeypatch) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index bad9a717d1629..28ac1f3f2aefc 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -24,7 +24,7 @@ from pl_examples.bug_report_model import RandomDataset from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.loops import Loop, TrainingBatchLoop from pytorch_lightning.trainer.progress import BaseProgress from tests.helpers import BoringModel @@ -912,8 +912,10 @@ def val_dataloader(self): @RunIf(min_torch="1.8.0") -@pytest.mark.parametrize("persistent_workers", (False, True)) -def test_workers_are_shutdown(tmpdir, persistent_workers): +@pytest.mark.parametrize("should_fail", [False, True]) +# False is de-activated due to slowness +@pytest.mark.parametrize("persistent_workers", [True]) +def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers): # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance @@ -941,12 +943,30 @@ def _get_iterator(self): train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) + class TestCallback(Callback): + def on_train_epoch_end(self, trainer, *_): + if trainer.current_epoch == 1: + raise CustomException + max_epochs = 3 + model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=max_epochs) - trainer.fit(model, train_dataloader, val_dataloader) - assert train_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs) + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=max_epochs, + callbacks=TestCallback() if should_fail else None, + ) + + if should_fail: + with pytest.raises(CustomException): + trainer.fit(model, train_dataloader, val_dataloader) + else: + trainer.fit(model, train_dataloader, val_dataloader) + + assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else 
max_epochs) # on sanity checking end, the workers are being deleted too. - assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs + 1) + assert val_dataloader.count_shutdown_workers == 2 if persistent_workers else (3 if should_fail else max_epochs + 1) assert train_dataloader._iterator is None assert val_dataloader._iterator is None diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 2fb537b1d2861..c110f3a83d815 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -122,7 +122,7 @@ def validation_step(self, *args, **kwargs): model.unfreeze() -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_multi_cpu_model_ddp(tmpdir): """Make sure DDP works.""" tutils.set_random_main_port() diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index abf5a34757424..59a22cf1656d1 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -66,7 +66,7 @@ def _run_horovod(trainer_options, on_gpu=False): assert exit_code == 0 -@RunIf(skip_windows=True, horovod=True) +@RunIf(skip_windows=True, horovod=True, skip_49370=True) def test_horovod_cpu(tmpdir): """Test Horovod running multi-process on CPU.""" trainer_options = dict( @@ -82,7 +82,7 @@ def test_horovod_cpu(tmpdir): _run_horovod(trainer_options) -@RunIf(skip_windows=True, horovod=True) +@RunIf(skip_windows=True, horovod=True, skip_49370=True) def test_horovod_cpu_clip_grad_by_value(tmpdir): """Test Horovod running multi-process on CPU.""" trainer_options = dict( @@ -99,7 +99,7 @@ def test_horovod_cpu_clip_grad_by_value(tmpdir): _run_horovod(trainer_options) -@RunIf(skip_windows=True, horovod=True) +@RunIf(skip_windows=True, horovod=True, skip_49370=True) def test_horovod_cpu_implicit(tmpdir): """Test Horovod without specifying a backend, inferring from env set by `horovodrun`.""" trainer_options = dict( diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index d8ceb4106fd07..31ebd3968ff3e 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -407,7 +407,7 @@ def test_tpu_sync_dist(): """Test tpu spawn sync dist operation.""" def test_sync_dist(_): - sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM) + sync = _Sync(TPUSpawnPlugin().reduce, should=True, _op=torch.distributed.ReduceOp.SUM) value = torch.tensor([1.0]) value = (sync(value),) assert value.item() == 8 diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index c389cf9290c78..c5e5f7ccda748 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -46,7 +46,7 @@ def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: return super().get_from_queue(queue) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_ddp_cpu(): """Tests if device is set correctly when training for DDPSpawnPlugin.""" trainer = Trainer(num_processes=2, fast_dev_run=True) @@ -91,7 +91,7 @@ def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQu return super().get_from_queue(trainer, queue) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_ddp_spawn_add_get_queue(tmpdir): """Tests add_to_queue/get_from_queue with DDPSpawnPlugin.""" @@ -128,7 +128,7 @@ def on_predict_start(self) -> None: assert isinstance(self.trainer.model, LightningModule) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def 
test_ddp_spawn_configure_ddp(tmpdir): """Tests with ddp spawn plugin.""" trainer = Trainer(default_root_dir=tmpdir, num_processes=2, strategy="ddp_spawn", fast_dev_run=True) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 981d0d5db8cf6..c5b71e908795b 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -1,5 +1,6 @@ import contextlib import json +import logging import os from typing import Any, Dict, Optional from unittest import mock @@ -872,24 +873,9 @@ def training_step(self, batch, batch_idx): trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) -def test_deepspeed_warn_train_dataloader_called(tmpdir): - """Test DeepSpeed warns when it calls ``lightning_module.train_dataloader`` internally for logging batch - size.""" - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - strategy=DeepSpeedPlugin(), - gpus=1, - fast_dev_run=True, - ) - with pytest.warns(UserWarning, match="Inferring the batch size for internal deepspeed logging"): - trainer.fit(model) - - @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_setup_train_dataloader(tmpdir): - """Test DeepSpeed works when setup is required to call, and the user passes the batch size manually.""" + """Test DeepSpeed works when setup is required to call in the DataModule.""" class TestSetupIsCalledDataModule(LightningDataModule): def __init__(self): @@ -914,13 +900,14 @@ def test_dataloader(self): model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, - strategy=DeepSpeedPlugin(logging_batch_size_per_gpu=32), + strategy=DeepSpeedPlugin(logging_level=logging.INFO), gpus=1, fast_dev_run=True, ) dm = TestSetupIsCalledDataModule() - trainer.fit(model, datamodule=dm) - trainer.test(model, datamodule=dm) + with mock.patch("deepspeed.utils.logging.logger.warning", autospec=True) as mock_object: + trainer.fit(model, datamodule=dm) + assert any("Tried to infer the batch size" in str(arg) for arg in mock_object.call_args_list) @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True) diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 7369ab9a4a140..f9a4727334a4f 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -162,7 +162,7 @@ def test_simple_profiler_with_nonexisting_dirpath(tmpdir): assert nonexisting_tmpdir.join("fit-profiler.txt").exists() -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_simple_profiler_distributed_files(tmpdir): """Ensure the proper files are saved in distributed.""" profiler = SimpleProfiler(dirpath=tmpdir, filename="profiler") @@ -227,6 +227,7 @@ def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, ex np.testing.assert_allclose(recored_total_duration, expected_total_duration, rtol=0.2) +@pytest.mark.flaky(reruns=3) def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): """ensure that the profiler doesn't introduce too much overhead during training.""" for _ in range(n_iter): diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py index 76c8b37405b47..3860d85ec9836 100644 --- a/tests/trainer/flags/test_overfit_batches.py +++ b/tests/trainer/flags/test_overfit_batches.py @@ -13,13 +13,16 @@ # limitations under the License. 
import pytest import torch +from torch.utils.data.sampler import Sampler, SequentialSampler from pytorch_lightning import Trainer from tests.helpers.boring_model import BoringModel, RandomDataset def test_overfit_multiple_val_loaders(tmpdir): - """Tests that only training_step can be used.""" + """Tests that overfit batches works with multiple val dataloaders.""" + val_dl_count = 2 + overfit_batches = 3 class TestModel(BoringModel): def validation_step(self, batch, batch_idx, dataloader_idx): @@ -31,25 +34,65 @@ def validation_epoch_end(self, outputs) -> None: pass def val_dataloader(self): - dl1 = torch.utils.data.DataLoader(RandomDataset(32, 64)) - dl2 = torch.utils.data.DataLoader(RandomDataset(32, 64)) - return [dl1, dl2] + dls = [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(val_dl_count)] + return dls model = TestModel() trainer = Trainer( - default_root_dir=tmpdir, max_epochs=2, overfit_batches=1, log_every_n_steps=1, enable_model_summary=False + default_root_dir=tmpdir, + max_epochs=2, + overfit_batches=overfit_batches, + log_every_n_steps=1, + enable_model_summary=False, ) trainer.fit(model) + assert trainer.num_training_batches == overfit_batches + assert len(trainer.num_val_batches) == val_dl_count + assert all(nbatches == overfit_batches for nbatches in trainer.num_val_batches) -@pytest.mark.parametrize("overfit", [1, 2, 0.1, 0.25, 1.0]) -def test_overfit_basic(tmpdir, overfit): - """Tests that only training_step can be used.""" +@pytest.mark.parametrize("overfit_batches", [1, 2, 0.1, 0.25, 1.0]) +def test_overfit_basic(tmpdir, overfit_batches): + """Tests that only training_step can be used when overfitting.""" model = BoringModel() + model.validation_step = None + total_train_samples = len(BoringModel().train_dataloader()) - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit, enable_model_summary=False) - + trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, overfit_batches=overfit_batches, enable_model_summary=False + ) trainer.fit(model) + + assert trainer.num_val_batches == [] + assert trainer.num_training_batches == int( + overfit_batches * (1 if isinstance(overfit_batches, int) else total_train_samples) + ) + + +def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir): + class NonSequentialSampler(Sampler): + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + class TestModel(BoringModel): + def train_dataloader(self): + dataset = RandomDataset(32, 64) + sampler = NonSequentialSampler(dataset) + return torch.utils.data.DataLoader(dataset, sampler=sampler) + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2) + + with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"): + trainer.fit(model) + + assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler) diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py index 487b7f38e4e19..d4ba4f242294a 100644 --- a/tests/trainer/logging_/test_distributed_logging.py +++ b/tests/trainer/logging_/test_distributed_logging.py @@ -59,7 +59,7 @@ def on_train_end(self): assert self.log_name.format(rank=self.local_rank) in self.logger.logs, "Expected rank to be logged" -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def 
test_all_rank_logging_ddp_cpu(tmpdir): """Check that all ranks can be logged from.""" model = TestModel() diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 5b775b9968d99..22a1a2c90d756 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -395,7 +395,7 @@ def validation_step(self, batch, batch_idx): return super().validation_step(batch, batch_idx) -@pytest.mark.parametrize("devices", [1, pytest.param(2, marks=RunIf(skip_windows=True))]) +@pytest.mark.parametrize("devices", [1, pytest.param(2, marks=RunIf(skip_windows=True, skip_49370=True))]) def test_logging_sync_dist_true(tmpdir, devices): """Tests to ensure that the sync_dist flag works (should just return the original value)""" fake_result = 1 diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 6e405739e83fe..ed81b90a2d142 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -37,7 +37,7 @@ def test_get_model(tmpdir): trainer.fit(model) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) def test_get_model_ddp_cpu(tmpdir): """Tests that `trainer.lightning_module` extracts the model correctly when using ddp on cpu.""" diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index 35f9838f0b04a..c04c6f0f6ea41 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -133,7 +133,7 @@ def _get_warning_msg(): assert warn_str in msg -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, skip_49370=True) @pytest.mark.parametrize("num_workers", [0, 1]) def test_dataloader_warnings(tmpdir, num_workers): trainer = Trainer(default_root_dir=tmpdir, strategy="ddp_spawn", num_processes=2, fast_dev_run=4) diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 204f3079f544b..e4598550c24fb 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -33,8 +33,10 @@ ) from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler +from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 +from tests.helpers.boring_model import RandomDataset def test_tensor_running_accum_reset(): @@ -379,3 +381,56 @@ def _assert_dataset(loader): assert isinstance(d, CustomDataset) apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset) + + +@pytest.mark.parametrize("replace_sampler_ddp", [False, True]) +def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp, tmpdir): + """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader + with ddp and `max_size_cycle` mode.""" + trainer = Trainer(strategy="ddp", accelerator="auto", devices=2, replace_sampler_ddp=replace_sampler_ddp) + + dataloader = CombinedLoader( + {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)}, + ) + dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) + assert len(dataloader) == 4 if replace_sampler_ddp else 8 + + for a_length in [6, 8, 10]: + dataloader = CombinedLoader( + { + "a": DataLoader(range(a_length), batch_size=1), + "b": 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 0a3eacb23863f..c4a8884a23ccd 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1809,7 +1809,7 @@ def on_predict_start(self) -> None:


 @pytest.mark.parametrize(
-    "strategy,num_processes", [(None, 1), pytest.param("ddp_spawn", 2, marks=RunIf(skip_windows=True))]
+    "strategy,num_processes", [(None, 1), pytest.param("ddp_spawn", 2, marks=RunIf(skip_windows=True, skip_49370=True))]
 )
 def test_model_in_correct_mode_during_stages(tmpdir, strategy, num_processes):
     model = TrainerStagesModel()
@@ -1830,7 +1830,7 @@ def validation_epoch_end(self, outputs) -> None:
         pass


-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, skip_49370=True)
 def test_fit_test_synchronization(tmpdir):
     """Test that the trainer synchronizes processes before returning control back to the caller."""
     tutils.set_random_main_port()
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index 073468fc4cb28..2ed42b0b0f21a 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -41,8 +41,8 @@ def _test_all_gather_ddp(rank, world_size):
     assert torch.allclose(grad2, tensor2.grad)


-@RunIf(skip_windows=True)
-def test_all_gather_ddp():
+@RunIf(skip_windows=True, skip_49370=True)
+def test_all_gather_ddp_spawn():
     world_size = 3
     torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)
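The renamed `test_all_gather_ddp_spawn` keeps driving `_test_all_gather_ddp` through `torch.multiprocessing.spawn` with three CPU processes. For readers unfamiliar with that pattern, here is a rough, self-contained sketch using plain `torch.distributed` with the gloo backend rather than Lightning's gradient-preserving `all_gather`; the address and port are placeholder values chosen only for this example.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # Minimal CPU/gloo process group; MASTER_ADDR/MASTER_PORT are illustrative defaults.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    tensor = torch.tensor([float(rank)])
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)  # every rank ends up with all ranks' tensors
    assert torch.cat(gathered).tolist() == [float(r) for r in range(world_size)]

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(3,), nprocs=3)  # mirrors world_size = 3 in the test
```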
diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py
index da309f7d22b50..9b0fcbd643744 100644
--- a/tests/utilities/test_apply_func.py
+++ b/tests/utilities/test_apply_func.py
@@ -14,7 +14,8 @@
 import dataclasses
 import numbers
 from collections import defaultdict, namedtuple, OrderedDict
-from typing import List
+from dataclasses import InitVar
+from typing import Any, ClassVar, List, Optional

 import numpy as np
 import pytest
@@ -31,6 +32,12 @@ class Feature:
         input_ids: torch.Tensor
         segment_ids: np.ndarray

+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, Feature):
+                return NotImplemented
+            else:
+                return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all()
+
     @dataclasses.dataclass
     class ModelExample:
         example_ids: List[str]
@@ -41,6 +48,71 @@ class ModelExample:
         def __post_init__(self):
             self.some_constant = 7

+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, ModelExample):
+                return NotImplemented
+            else:
+                return (
+                    self.example_ids == o.example_ids
+                    and self.feature == o.feature
+                    and torch.equal(self.label, o.label)
+                    and self.some_constant == o.some_constant
+                )
+
+    @dataclasses.dataclass
+    class WithClassVar:
+        class_var: ClassVar[int] = 0
+        dummy: Any
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithClassVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    @dataclasses.dataclass
+    class WithInitVar:
+        dummy: Any
+        override: InitVar[Optional[Any]] = None
+
+        def __post_init__(self, override: Optional[Any]):
+            if override is not None:
+                self.dummy = override
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithInitVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    @dataclasses.dataclass
+    class WithClassAndInitVar:
+        class_var: ClassVar[torch.Tensor] = torch.tensor(0)
+        dummy: Any
+        override: InitVar[Optional[Any]] = torch.tensor(1)
+
+        def __post_init__(self, override: Optional[Any]):
+            if override is not None:
+                self.dummy = override
+
+        def __eq__(self, o: object) -> bool:
+            if not isinstance(o, WithClassAndInitVar):
+                return NotImplemented
+            elif isinstance(self.dummy, torch.Tensor):
+                return torch.equal(self.dummy, o.dummy)
+            else:
+                return self.dummy == o.dummy
+
+    model_example = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
+        label=torch.tensor([7.0, 8.0, 9.0]),
+    )
+
     to_reduce = {
         "a": torch.tensor([1.0]),  # Tensor
         "b": [torch.tensor([2.0])],  # list
@@ -50,13 +122,18 @@ def __post_init__(self):
         "f": "this_is_a_dummy_str",  # string
         "g": 12.0,  # number
         "h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),  # dataclass
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
-            label=torch.tensor([7.0, 8.0, 9.0]),
-        ),  # nested dataclass
+        "i": model_example,  # nested dataclass
+        "j": WithClassVar(torch.arange(3)),  # dataclass with class variable
+        "k": WithInitVar("this_gets_overridden", torch.tensor([2.0])),  # dataclass with init-only variable
+        "l": WithClassAndInitVar(model_example, None),  # nested dataclass with class and init-only variables
     }

+    model_example_result = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
+        label=torch.tensor([14.0, 16.0, 18.0]),
+    )
+
     expected_result = {
         "a": torch.tensor([2.0]),
         "b": [torch.tensor([4.0])],
@@ -66,32 +143,31 @@ def __post_init__(self):
         "f": "this_is_a_dummy_str",
         "g": 24.0,
         "h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-        "i": ModelExample(
-            example_ids=["i-1", "i-2", "i-3"],
-            feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])),
-            label=torch.tensor([14.0, 16.0, 18.0]),
-        ),
+        "i": model_example_result,
+        "j": WithClassVar(torch.arange(0, 6, 2)),
+        "k": WithInitVar(torch.tensor([4.0])),
+        "l": WithClassAndInitVar(model_example_result, None),
     }

     reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)

-    assert isinstance(reduced, dict), " Type Consistency of dict not preserved"
+    assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
     assert all(x in reduced for x in to_reduce), "Not all entries of the dict were preserved"
     assert all(
         isinstance(reduced[k], type(expected_result[k])) for k in to_reduce
     ), "At least one type was not correctly preserved"

     assert isinstance(reduced["a"], torch.Tensor), "Reduction Result of a Tensor should be a Tensor"
-    assert torch.allclose(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"
+    assert torch.equal(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"

     assert isinstance(reduced["b"], list), "Reduction Result of a list should be a list"
     assert all(
-        torch.allclose(x, y) for x, y in zip(reduced["b"], expected_result["b"])
+        torch.equal(x, y) for x, y in zip(reduced["b"], expected_result["b"])
     ), "At least one value of list reduction did not come out as expected"

     assert isinstance(reduced["c"], tuple), "Reduction Result of a tuple should be a tuple"
     assert all(
-        torch.allclose(x, y) for x, y in zip(reduced["c"], expected_result["c"])
+        torch.equal(x, y) for x, y in zip(reduced["c"], expected_result["c"])
     ), "At least one value of tuple reduction did not come out as expected"

     assert isinstance(reduced["d"], ntc), "Type Consistency for named tuple not given"
@@ -109,34 +185,30 @@ def __post_init__(self):
     assert isinstance(reduced["g"], numbers.Number), "Reduction of a number should result in a number"
     assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"

-    assert dataclasses.is_dataclass(reduced["h"]) and not isinstance(
-        reduced["h"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert torch.allclose(
-        reduced["h"].input_ids, expected_result["h"].input_ids
-    ), "Reduction of a dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["h"].segment_ids, expected_result["h"].segment_ids
-    ), "Reduction of a dataclass did not yield the desired result"
-
-    assert dataclasses.is_dataclass(reduced["i"]) and not isinstance(
-        reduced["i"], type
-    ), "Reduction of a dataclass should result in a dataclass"
-    assert dataclasses.is_dataclass(reduced["i"].feature) and not isinstance(
-        reduced["i"].feature, type
-    ), "Reduction of a nested dataclass should result in a nested dataclass"
-    assert (
-        reduced["i"].example_ids == expected_result["i"].example_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert torch.allclose(
-        reduced["i"].label, expected_result["i"].label
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert torch.allclose(
-        reduced["i"].feature.input_ids, expected_result["i"].feature.input_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
-    assert np.allclose(
-        reduced["i"].feature.segment_ids, expected_result["i"].feature.segment_ids
-    ), "Reduction of a nested dataclass did not yield the desired result"
+    def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
+        assert dataclasses.is_dataclass(actual) and not isinstance(
+            actual, type
+        ), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
+        for field in dataclasses.fields(actual):
+            if dataclasses.is_dataclass(field.type):
+                _assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
+        assert actual == expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result"
+
+    _assert_dataclass_reduction(reduced["h"], expected_result["h"])
+
+    _assert_dataclass_reduction(reduced["i"], expected_result["i"])
+
+    dataclass_type = "ClassVar-containing"
+    _assert_dataclass_reduction(reduced["j"], expected_result["j"], dataclass_type)
+    assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var"
+
+    _assert_dataclass_reduction(reduced["k"], expected_result["k"], "InitVar-containing")
+
+    dataclass_type = "Class-and-InitVar-containing"
+    _assert_dataclass_reduction(reduced["l"], expected_result["l"], dataclass_type)
+    assert torch.equal(
+        WithClassAndInitVar.class_var, torch.tensor(0)
+    ), f"Reduction of a {dataclass_type} dataclass should not change the class var"

     # mapping support
     reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
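The new `WithClassVar`/`WithInitVar` fixtures pin down the behaviour fixed by #9702: `dataclasses.fields()` reports neither `ClassVar` nor `InitVar` pseudo-fields, so `apply_to_collection` has to rebuild a dataclass from its real fields only instead of passing everything back to the constructor. Below is a small sketch of the expected behaviour; the `Example` dataclass and its field names are invented for illustration and assume a pytorch-lightning release that already includes this fix.

```python
import dataclasses
from dataclasses import InitVar
from typing import Any, ClassVar, Optional

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection


@dataclasses.dataclass
class Example:
    value: torch.Tensor  # real field: listed by dataclasses.fields() and transformed
    counter: ClassVar[int] = 0  # class variable: not a field, must be left alone
    override: InitVar[Optional[Any]] = None  # init-only: consumed by __post_init__ only

    def __post_init__(self, override: Optional[Any]):
        if override is not None:
            self.value = override


doubled = apply_to_collection(Example(torch.tensor([1.0, 2.0])), torch.Tensor, lambda t: t * 2)
assert torch.equal(doubled.value, torch.tensor([2.0, 4.0]))
assert Example.counter == 0  # the ClassVar is neither passed to the constructor nor modified
```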
expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result" + + _assert_dataclass_reduction(reduced["h"], expected_result["h"]) + + _assert_dataclass_reduction(reduced["i"], expected_result["i"]) + + dataclass_type = "ClassVar-containing" + _assert_dataclass_reduction(reduced["j"], expected_result["j"], dataclass_type) + assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var" + + _assert_dataclass_reduction(reduced["k"], expected_result["k"], "InitVar-containing") + + dataclass_type = "Class-and-InitVar-containing" + _assert_dataclass_reduction(reduced["l"], expected_result["l"], dataclass_type) + assert torch.equal( + WithClassAndInitVar.class_var, torch.tensor(0) + ), f"Reduction of a {dataclass_type} dataclass should not change the class var" # mapping support reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x)) diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index 8e36a86c3beef..581b949d9167f 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -14,7 +14,7 @@ from torch import nn from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities.meta import init_meta_context, materialize_module +from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module from tests.helpers.runif import RunIf @@ -31,18 +31,23 @@ def __init__(self, num_layers: int): self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)]) -@RunIf(min_torch="1.10.0") +@RunIf(special=True, min_torch="1.10.0") def test_init_meta_context(): with init_meta_context(): m = nn.Linear(in_features=1, out_features=1) + assert isinstance(m, nn.Linear) assert m.weight.device.type == "meta" + assert is_on_meta_device(m) mlp = MLP(4) assert mlp.layer[0].weight.device.type == "meta" mlp = materialize_module(mlp) assert mlp.layer[0].weight.device.type == "cpu" + assert not is_on_meta_device(mlp) + assert not is_on_meta_device(nn.Module()) + model = BoringModel(4) assert model.layer[0].weight.device.type == "meta" materialize_module(model)