From 0fee28409b00ed8bad9c3ba52477292c9cf4687a Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 17 Feb 2023 17:58:14 +0100 Subject: [PATCH] Introduce new precision layout in PL (#16783) --- .../source-pytorch/common/precision_basic.rst | 24 ++++++-- .../common/precision_expert.rst | 2 +- .../common/precision_intermediate.rst | 2 +- docs/source-pytorch/common/trainer.rst | 4 +- .../fabric/fundamentals/launch.rst | 1 - examples/app_multi_node/train_fabric.py | 2 +- examples/pl_hpu/mnist_sample.py | 2 +- src/lightning/pytorch/CHANGELOG.md | 2 + .../pytorch/plugins/precision/amp.py | 15 +++-- .../pytorch/plugins/precision/deepspeed.py | 10 ++-- .../pytorch/plugins/precision/double.py | 2 +- .../pytorch/plugins/precision/fsdp.py | 8 +-- .../pytorch/plugins/precision/hpu.py | 12 ++-- .../pytorch/plugins/precision/ipu.py | 12 ++-- .../pytorch/plugins/precision/tpu_bf16.py | 2 +- src/lightning/pytorch/strategies/deepspeed.py | 12 ++-- src/lightning/pytorch/strategies/fsdp.py | 4 +- src/lightning/pytorch/strategies/utils.py | 14 ++++- .../connectors/accelerator_connector.py | 50 +++++++---------- src/lightning/pytorch/trainer/trainer.py | 7 ++- tests/tests_pytorch/accelerators/test_hpu.py | 2 +- tests/tests_pytorch/accelerators/test_ipu.py | 34 +++++++---- .../checkpointing/test_legacy_checkpoints.py | 2 +- .../helpers/deterministic_model.py | 2 +- tests/tests_pytorch/models/test_amp.py | 18 +++--- .../tests_pytorch/models/test_ddp_fork_amp.py | 2 +- tests/tests_pytorch/models/test_hooks.py | 6 +- tests/tests_pytorch/models/test_tpu.py | 6 +- .../plugins/precision/hpu/test_hpu.py | 16 +++--- .../plugins/precision/test_amp.py | 4 +- .../plugins/precision/test_amp_integration.py | 2 +- .../precision/test_deepspeed_precision.py | 2 +- .../tests_pytorch/plugins/test_amp_plugins.py | 10 ++-- .../plugins/test_double_plugin.py | 4 +- tests/tests_pytorch/strategies/test_ddp.py | 2 +- .../strategies/test_deepspeed_strategy.py | 56 +++++++++---------- tests/tests_pytorch/strategies/test_fsdp.py | 12 ++-- .../tests_pytorch/strategies/test_registry.py | 2 +- .../connectors/test_accelerator_connector.py | 23 ++++---- .../optimization/test_manual_optimization.py | 17 +++--- tests/tests_pytorch/trainer/test_trainer.py | 8 +-- .../tuner/test_scale_batch_size.py | 2 +- .../test_deepspeed_collate_checkpoint.py | 2 +- .../utilities/test_deepspeed_model_summary.py | 2 +- .../utilities/test_torchdistx.py | 2 +- 45 files changed, 227 insertions(+), 198 deletions(-) diff --git a/docs/source-pytorch/common/precision_basic.rst b/docs/source-pytorch/common/precision_basic.rst index 3cc0b3a9677be..0b8706a194b68 100644 --- a/docs/source-pytorch/common/precision_basic.rst +++ b/docs/source-pytorch/common/precision_basic.rst @@ -20,11 +20,11 @@ Higher precision, such as the 64-bit floating-point, can be used for highly sens 16-bit Precision **************** -Use 16-bit precision to cut your memory consumption in half so that you can train and deploy larger models. If your GPUs are [`Tensor Core `_] GPUs, you can also get a ~3x speed improvement. Half precision can sometimes lead to unstable training. +Use 16-bit mixed precision to lower your memory consumption by up to half so that you can train and deploy larger models. If your GPUs are [`Tensor Core `_] GPUs, you can also get a ~3x speed improvement. Half precision can sometimes lead to unstable training. .. code:: - Trainer(precision=16) + Trainer(precision='16-mixed') ---- @@ -36,6 +36,12 @@ Use 16-bit precision to cut your memory consumption in half so that you can trai .. testcode:: + Trainer(precision='32-true') + + # or + Trainer(precision='32') + + # or Trainer(precision=32) ---- @@ -48,6 +54,12 @@ For certain scientific computations, 64-bit precision enables more accurate mode .. testcode:: + Trainer(precision='64-true') + + # or + Trainer(precision='64') + + # or Trainer(precision=64) .. note:: @@ -70,22 +82,22 @@ Precision support by accelerator - GPU - TPU - IPU - * - 16 + * - 16 Mixed - No - Yes - No - Yes - * - BFloat16 + * - BFloat16 Mixed - Yes - Yes - Yes - No - * - 32 + * - 32 True - Yes - Yes - Yes - Yes - * - 64 + * - 64 True - Yes - Yes - No diff --git a/docs/source-pytorch/common/precision_expert.rst b/docs/source-pytorch/common/precision_expert.rst index 34bc95568c962..7a6c2dada1c17 100644 --- a/docs/source-pytorch/common/precision_expert.rst +++ b/docs/source-pytorch/common/precision_expert.rst @@ -20,7 +20,7 @@ You can also customize and pass your own Precision Plugin by subclassing the :cl .. code-block:: python class CustomPrecisionPlugin(PrecisionPlugin): - precision = 16 + precision = '16-mixed' ... diff --git a/docs/source-pytorch/common/precision_intermediate.rst b/docs/source-pytorch/common/precision_intermediate.rst index 52ad86d004e0b..7cdd929ad0e4b 100644 --- a/docs/source-pytorch/common/precision_intermediate.rst +++ b/docs/source-pytorch/common/precision_intermediate.rst @@ -63,7 +63,7 @@ Since computation happens in FP16, there is a chance of numerical instability du .. note:: - When using TPUs, setting ``precision=16`` will enable bfloat16, the only supported half precision type on TPUs. + When using TPUs, setting ``precision='16-mixed'`` will enable bfloat16, the only supported half precision type on TPUs. .. testcode:: :skipif: not torch.cuda.is_available() diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index fd8b3af0cf982..3ea6436f7c537 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -926,10 +926,10 @@ Half precision, or mixed precision, is the combined use of 32 and 16 bit floatin trainer = Trainer(precision=32) # 16-bit precision - trainer = Trainer(precision=16, accelerator="gpu", devices=1) # works only on CUDA + trainer = Trainer(precision="16-mixed", accelerator="gpu", devices=1) # works only on CUDA # bfloat16 precision - trainer = Trainer(precision="bf16") + trainer = Trainer(precision="bf16-mixed") # 64-bit precision trainer = Trainer(precision=64) diff --git a/docs/source-pytorch/fabric/fundamentals/launch.rst b/docs/source-pytorch/fabric/fundamentals/launch.rst index a8311e6134c14..af766c56e4a0c 100644 --- a/docs/source-pytorch/fabric/fundamentals/launch.rst +++ b/docs/source-pytorch/fabric/fundamentals/launch.rst @@ -74,7 +74,6 @@ This is essentially the same as running ``python path/to/your/script.py``, but i precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``) - --help Show this message and exit. diff --git a/examples/app_multi_node/train_fabric.py b/examples/app_multi_node/train_fabric.py index 1bb2ecd313202..335e1e73db6e0 100644 --- a/examples/app_multi_node/train_fabric.py +++ b/examples/app_multi_node/train_fabric.py @@ -15,7 +15,7 @@ def run(self): ) # 2. Create Fabric. - fabric = Fabric(strategy="ddp", precision=16) + fabric = Fabric(strategy="ddp", precision="16-mixed") model, optimizer = fabric.setup(model, torch.optim.SGD(model.parameters(), lr=0.01)) criterion = torch.nn.MSELoss() diff --git a/examples/pl_hpu/mnist_sample.py b/examples/pl_hpu/mnist_sample.py index ccb60e7c9de14..0ed24ad75403b 100644 --- a/examples/pl_hpu/mnist_sample.py +++ b/examples/pl_hpu/mnist_sample.py @@ -63,7 +63,7 @@ def configure_optimizers(self): "accelerator": "hpu", "devices": 1, "max_epochs": 1, - "plugins": lazy_instance(HPUPrecisionPlugin, precision=16), + "plugins": lazy_instance(HPUPrecisionPlugin, precision="16-mixed"), }, run=False, save_config_kwargs={"overwrite": True}, diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 2a230c608d991..18fc06564c254 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -107,6 +107,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781)) +- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16783](https://github.com/Lightning-AI/lightning/pull/16783)) + ### Deprecated - diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 3d6b894097649..1ac94415e33dc 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -34,15 +34,18 @@ class MixedPrecisionPlugin(PrecisionPlugin): """ def __init__( - self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None + self, + precision: Literal["16-mixed", "bf16-mixed"], + device: str, + scaler: Optional[torch.cuda.amp.GradScaler] = None, ) -> None: - self.precision = cast(Literal["16", "bf16"], str(precision)) # type: ignore - if scaler is None and self.precision == "16": + self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision)) + if scaler is None and self.precision == "16-mixed": with _patch_cuda_is_available(): # if possible, we defer CUDA initialization to support strategies that will attempt forks scaler = torch.cuda.amp.GradScaler() - if scaler is not None and self.precision == "bf16": - raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.") + if scaler is not None and self.precision == "bf16-mixed": + raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.") self.device = device self.scaler = scaler @@ -97,7 +100,7 @@ def clip_gradients( def autocast_context_manager(self) -> torch.autocast: # the dtype could be automatically inferred but we need to manually set it due to a bug upstream # https://github.com/pytorch/pytorch/issues/67233 - return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half) + return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half) @contextmanager def forward_context(self) -> Generator[None, None, None]: diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 8f0845303c8ba..627026214eaf4 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -31,9 +31,7 @@ warning_cache = WarningCache() -_PRECISION_INPUT_INT = Literal[32, 16] -_PRECISION_INPUT_STR = Literal["32", "16", "bf16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] +_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"] class DeepSpeedPrecisionPlugin(PrecisionPlugin): @@ -46,14 +44,14 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin): If unsupported ``precision`` is provided. """ - def __init__(self, precision: Literal["32", 32, "16", 16, "bf16"]) -> None: - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + def __init__(self, precision: Literal["32-true", "16-mixed", "bf16-mixed"]) -> None: + supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." ) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore + self.precision = cast(_PRECISION_INPUT, str(precision)) def backward( # type: ignore[override] self, diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 78785a4c58ca5..77fa9c4171a2b 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -72,7 +72,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class DoublePrecisionPlugin(PrecisionPlugin): """Plugin for training with double (``torch.float64``) precision.""" - precision: Literal["64"] = "64" # type: ignore + precision: Literal["64-true"] = "64-true" def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 7e1d6a5250294..1561bd693f037 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -31,12 +31,12 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin): """AMP for Fully Sharded Data Parallel (FSDP) Training.""" def __init__( - self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None + self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[ShardedGradScaler] = None ) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.") super().__init__( - precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None) + precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16-mixed" else None) ) def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: @@ -52,9 +52,9 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: @property def mixed_precision_config(self) -> Optional[MixedPrecision]: assert MixedPrecision is not None - if self.precision == "16": + if self.precision == "16-mixed": dtype = torch.float16 - elif self.precision == "bf16": + elif self.precision == "bf16-mixed": dtype = torch.bfloat16 else: raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.") diff --git a/src/lightning/pytorch/plugins/precision/hpu.py b/src/lightning/pytorch/plugins/precision/hpu.py index e668285c445c5..47a145807bcff 100644 --- a/src/lightning/pytorch/plugins/precision/hpu.py +++ b/src/lightning/pytorch/plugins/precision/hpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Literal, Optional, Union +from typing import cast, Literal, Optional from typing_extensions import get_args @@ -22,9 +22,7 @@ if _HPU_AVAILABLE: from habana_frameworks.torch.hpex import hmp -_PRECISION_INPUT_INT = Literal[32, 16] -_PRECISION_INPUT_STR = Literal["32", "16", "bf16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] +_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"] class HPUPrecisionPlugin(PrecisionPlugin): @@ -48,14 +46,14 @@ def __init__( ) -> None: if not _HPU_AVAILABLE: raise MisconfigurationException("HPU precision plugin requires HPU devices.") - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." ) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore - if self.precision in ("16", "bf16"): + self.precision = cast(_PRECISION_INPUT, str(precision)) + if self.precision in ("16-mixed", "bf16-mixed"): hmp.convert( opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose ) diff --git a/src/lightning/pytorch/plugins/precision/ipu.py b/src/lightning/pytorch/plugins/precision/ipu.py index 104cec0dcfe99..e414bc693163e 100644 --- a/src/lightning/pytorch/plugins/precision/ipu.py +++ b/src/lightning/pytorch/plugins/precision/ipu.py @@ -27,9 +27,7 @@ warning_cache = WarningCache() -_PRECISION_INPUT_INT = Literal[32, 16] -_PRECISION_INPUT_STR = Literal["32", "16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] +_PRECISION_INPUT = Literal["32-true", "16-mixed"] class IPUPrecisionPlugin(PrecisionPlugin): @@ -37,17 +35,17 @@ class IPUPrecisionPlugin(PrecisionPlugin): Raises: ValueError: - If the precision is neither 16 nor 32. + If the precision is neither 16-mixed nor 32-true. """ - def __init__(self, precision: Literal["32", 32, "16", 16]) -> None: - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + def __init__(self, precision: Literal["32-true", "16-mixed"]) -> None: + supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." ) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore + self.precision = cast(_PRECISION_INPUT, str(precision)) def backward( # type: ignore[override] self, diff --git a/src/lightning/pytorch/plugins/precision/tpu_bf16.py b/src/lightning/pytorch/plugins/precision/tpu_bf16.py index aff41d9c92357..bef5989736a18 100644 --- a/src/lightning/pytorch/plugins/precision/tpu_bf16.py +++ b/src/lightning/pytorch/plugins/precision/tpu_bf16.py @@ -23,7 +23,7 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin): """Plugin that enables bfloats on TPUs.""" - precision: Literal["bf16"] = "bf16" # type: ignore + precision: Literal["bf16-mixed"] = "bf16-mixed" def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index d3d0545e05865..2c7a61827ff24 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -127,8 +127,8 @@ def __init__( Arguments: - zero_optimization: Enable ZeRO optimization. This is compatible with either `precision=16` or - `precision="bf16"`. + zero_optimization: Enable ZeRO optimization. This is compatible with either `precision="16-mixed"` or + `precision="bf16-mixed"`. stage: Different stages of the ZeRO Optimizer. 0 is disabled, 1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning, @@ -505,9 +505,9 @@ def model_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: assert self._config_initialized - if self.precision_plugin.precision == "16": + if self.precision_plugin.precision == "16-mixed": dtype = torch.float16 - elif self.precision_plugin.precision == "bf16": + elif self.precision_plugin.precision == "bf16-mixed": dtype = torch.bfloat16 else: dtype = torch.float32 @@ -641,7 +641,7 @@ def _auto_select_batch_size(self) -> int: def _format_precision_config(self) -> None: assert isinstance(self.config, dict) - if self.precision_plugin.precision == "16": + if self.precision_plugin.precision == "16-mixed": if "fp16" not in self.config: # FP16 is a DeepSpeed standalone AMP implementation rank_zero_info("Enabling DeepSpeed FP16.") @@ -653,7 +653,7 @@ def _format_precision_config(self) -> None: "hysteresis": self.hysteresis, "min_loss_scale": self.min_loss_scale, } - elif "bf16" not in self.config and self.precision_plugin.precision == "bf16": + elif "bf16" not in self.config and self.precision_plugin.precision == "bf16-mixed": rank_zero_info("Enabling DeepSpeed BF16.") self.config["bf16"] = {"enabled": True} diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 0ac1709ad3680..f58dfc1db90f8 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -99,8 +99,8 @@ class FSDPStrategy(ParallelStrategy): algorithms to help backward communication and computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. mixed_precision: - Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` - or BF16 if ``precision=bf16`` unless a config is passed in. + Mixed Precision config. By default, Lightning will enable FP16 if ``precision="16-mixed"`` + or BF16 if ``precision="bf16-mixed"`` unless a config is passed in. This is only available in PyTorch 1.12 and later. activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation checkpointing. This is typically your transformer block (including attention + feed-forward). diff --git a/src/lightning/pytorch/strategies/utils.py b/src/lightning/pytorch/strategies/utils.py index 1c3d72337786d..f67fb55823a51 100644 --- a/src/lightning/pytorch/strategies/utils.py +++ b/src/lightning/pytorch/strategies/utils.py @@ -32,9 +32,17 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) -> mod.register_strategies(registry) -def _fp_to_half(tensor: Tensor, precision: Literal["64", 64, "32", 32, "16", 16, "bf16"]) -> Tensor: - if str(precision) == "16": +def _fp_to_half( + tensor: Tensor, + precision: Literal[ + "64-true", + "32-true", + "16-mixed", + "bf16-mixed", + ], +) -> Tensor: + if str(precision) == "16-mixed": return _convert_fp_tensor(tensor, torch.half) - if precision == "bf16": + if precision == "bf16-mixed": return _convert_fp_tensor(tensor, torch.bfloat16) return tensor diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 9b09ac2c29542..ced2daefb1508 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -15,11 +15,11 @@ import logging import os from collections import Counter -from typing import cast, Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union import torch -from typing_extensions import get_args +from lightning.fabric.connector import _convert_precision_to_unified_args, _PRECISION_INPUT, _PRECISION_INPUT_STR from lightning.fabric.plugins.environments import ( ClusterEnvironment, KubeflowEnvironment, @@ -75,9 +75,6 @@ log = logging.getLogger(__name__) _LITERAL_WARN = Literal["warn"] -_PRECISION_INPUT_INT = Literal[64, 32, 16] -_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] class AcceleratorConnector: @@ -88,7 +85,7 @@ def __init__( accelerator: Optional[Union[str, Accelerator]] = None, strategy: Optional[Union[str, Strategy]] = None, plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, - precision: _PRECISION_INPUT = 32, + precision: _PRECISION_INPUT = "32-true", sync_batchnorm: bool = False, benchmark: Optional[bool] = None, replace_sampler_ddp: bool = True, @@ -136,7 +133,7 @@ def __init__( # Set each valid flag to `self._x_flag` after validation self._strategy_flag: Optional[Union[Strategy, str]] = None self._accelerator_flag: Optional[Union[Accelerator, str]] = None - self._precision_flag: _PRECISION_INPUT_STR = "32" + self._precision_flag: _PRECISION_INPUT_STR = "32-true" self._precision_plugin_flag: Optional[PrecisionPlugin] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None self._parallel_devices: List[Union[int, torch.device, str]] = [] @@ -243,12 +240,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = accelerator - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) - if precision not in supported_precision: - raise MisconfigurationException( - f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}" - ) - self._precision_flag = cast(_PRECISION_INPUT_STR, str(precision)) + self._precision_flag = _convert_precision_to_unified_args(precision) if plugins: plugins_flags_types: Dict[str, int] = Counter() @@ -518,13 +510,13 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if isinstance(self.accelerator, HPUAccelerator): return HPUPrecisionPlugin(self._precision_flag) # type: ignore if isinstance(self.accelerator, TPUAccelerator): - if self._precision_flag == "32": + if self._precision_flag == "32-true": return TPUPrecisionPlugin() - elif self._precision_flag in ("16", "bf16"): - if self._precision_flag == "16": + elif self._precision_flag in ("16-mixed", "bf16-mixed"): + if self._precision_flag == "16-mixed": rank_zero_warn( - "You passed `Trainer(accelerator='tpu', precision=16)` but AMP" - " is not supported with TPUs. Using `precision='bf16'` instead." + "You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16" + " is not supported on TPUs. Using `precision='bf16-mixed'` instead." ) return TPUBf16PrecisionPlugin() @@ -537,21 +529,21 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecisionPlugin(self._precision_flag) - if self._precision_flag == "32": + if self._precision_flag == "32-true": return PrecisionPlugin() - if self._precision_flag == "64": + if self._precision_flag == "64-true": return DoublePrecisionPlugin() - if self._precision_flag == "16" and self._accelerator_flag == "cpu": + if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu": rank_zero_warn( - "You passed `Trainer(accelerator='cpu', precision=16)` but AMP is not supported on CPU." - " Using `precision='bf16'` instead." + "You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on " + "CPU. Using `precision='bf16-mixed'` instead." ) - self._precision_flag = "bf16" + self._precision_flag = "bf16-mixed" - if self._precision_flag in ("16", "bf16"): + if self._precision_flag in ("16-mixed", "bf16-mixed"): rank_zero_info( - f"Using {'16bit' if self._precision_flag == 16 else 'bfloat16'} Automatic Mixed Precision (AMP)" + f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" @@ -564,9 +556,9 @@ def _check_and_init_precision(self) -> PrecisionPlugin: def _validate_precision_choice(self) -> None: """Validate the combination of choices for precision, AMP type, and accelerator.""" if isinstance(self.accelerator, TPUAccelerator): - if self._precision_flag == "64": + if self._precision_flag == "64-true": raise MisconfigurationException( - "`Trainer(accelerator='tpu', precision=64)` is not implemented." + "`Trainer(accelerator='tpu', precision='64-true')` is not implemented." " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" " requesting this feature." ) @@ -578,7 +570,7 @@ def _validate_precision_choice(self) -> None: f" found: {self._precision_plugin_flag}." ) if isinstance(self.accelerator, HPUAccelerator): - if self._precision_flag not in ("16", "bf16", "32"): + if self._precision_flag not in ("16-mixed", "bf16-mixed", "32-true"): raise MisconfigurationException( f"`Trainer(accelerator='hpu', precision={self._precision_flag!r})` is not supported." ) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 85db7aaed8ea7..546918fbb008e 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -117,7 +117,7 @@ def __init__( accelerator: Optional[Union[str, Accelerator]] = None, strategy: Optional[Union[str, Strategy]] = None, sync_batchnorm: bool = False, - precision: _PRECISION_INPUT = 32, + precision: _PRECISION_INPUT = "32-true", enable_model_summary: bool = True, num_sanity_val_steps: int = 2, profiler: Optional[Union[Profiler, str]] = None, @@ -221,9 +221,10 @@ def __init__( plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. Default: ``None``. - precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16). + precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), + 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). Can be used on CPU, GPU, TPUs, HPUs or IPUs. - Default: ``32``. + Default: ``'32-true'``. max_epochs: Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``. diff --git a/tests/tests_pytorch/accelerators/test_hpu.py b/tests/tests_pytorch/accelerators/test_hpu.py index 6307a78b1c815..b8ba801e7ede4 100644 --- a/tests/tests_pytorch/accelerators/test_hpu.py +++ b/tests/tests_pytorch/accelerators/test_hpu.py @@ -61,7 +61,7 @@ def test_all_stages(tmpdir, hpus): fast_dev_run=True, accelerator="hpu", devices=hpus, - precision=16, + precision="16-mixed", ) trainer.fit(model) trainer.validate(model) diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index dd45734d0d818..f33e0201ba66e 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -177,15 +177,20 @@ def test_optimization(tmpdir): def test_half_precision(tmpdir): class TestCallback(Callback): def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: - assert trainer.precision == "16" + assert trainer.precision == "16-mixed" raise SystemExit model = IPUModel() trainer = Trainer( - default_root_dir=tmpdir, fast_dev_run=True, accelerator="ipu", devices=1, precision=16, callbacks=TestCallback() + default_root_dir=tmpdir, + fast_dev_run=True, + accelerator="ipu", + devices=1, + precision="16-mixed", + callbacks=TestCallback(), ) assert isinstance(trainer.strategy.precision_plugin, IPUPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == "16" + assert trainer.strategy.precision_plugin.precision == "16-mixed" with pytest.raises(SystemExit): trainer.fit(model) @@ -194,7 +199,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non def test_pure_half_precision(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - assert trainer.strategy.precision_plugin.precision == "16" + assert trainer.strategy.precision_plugin.precision == "16-mixed" for param in trainer.strategy.model.parameters(): assert param.dtype == torch.float16 raise SystemExit @@ -202,22 +207,31 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: model = IPUModel() model = model.half() trainer = Trainer( - default_root_dir=tmpdir, fast_dev_run=True, accelerator="ipu", devices=1, precision=16, callbacks=TestCallback() + default_root_dir=tmpdir, + fast_dev_run=True, + accelerator="ipu", + devices=1, + precision="16-mixed", + callbacks=TestCallback(), ) assert isinstance(trainer.strategy, IPUStrategy) assert isinstance(trainer.strategy.precision_plugin, IPUPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == "16" + assert trainer.strategy.precision_plugin.precision == "16-mixed" changed_dtypes = [torch.float, torch.float64] data = [torch.zeros((1), dtype=dtype) for dtype in changed_dtypes] new_data = trainer.strategy.batch_to_device(data) - assert all(val.dtype is torch.half for val in new_data) + assert all(val.dtype is torch.half for val in new_data), "".join( + [f"{dtype}: {val.dtype}" for dtype, val in zip(changed_dtypes, new_data)] + ) not_changed_dtypes = [torch.uint8, torch.int8, torch.int32, torch.int64] data = [torch.zeros((1), dtype=dtype) for dtype in not_changed_dtypes] new_data = trainer.strategy.batch_to_device(data) - assert all(val.dtype is dtype for val, dtype in zip(new_data, not_changed_dtypes)) + assert all(val.dtype is dtype for val, dtype in zip(new_data, not_changed_dtypes)), "".join( + [f"{dtype}: {val.dtype}" for dtype, val in zip(not_changed_dtypes, new_data)] + ) with pytest.raises(SystemExit): trainer.fit(model) @@ -531,8 +545,8 @@ def configure_optimizers(self): def test_precision_plugin(): """Ensure precision plugin value is set correctly.""" - plugin = IPUPrecisionPlugin(precision=16) - assert plugin.precision == "16" + plugin = IPUPrecisionPlugin(precision="16-mixed") + assert plugin.precision == "16-mixed" @RunIf(ipu=True) diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 829c498e1e7c5..86dd5c6cfe9b7 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -103,7 +103,7 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): default_root_dir=str(tmpdir), accelerator="auto", devices=1, - precision=(16 if torch.cuda.is_available() else 32), + precision=("16-mixed" if torch.cuda.is_available() else "32-true"), callbacks=[stop], max_epochs=21, accumulate_grad_batches=2, diff --git a/tests/tests_pytorch/helpers/deterministic_model.py b/tests/tests_pytorch/helpers/deterministic_model.py index b5a4b588881c2..158406b4b7435 100644 --- a/tests/tests_pytorch/helpers/deterministic_model.py +++ b/tests/tests_pytorch/helpers/deterministic_model.py @@ -98,7 +98,7 @@ def configure_optimizers__lr_on_plateau_step(self): def backward(self, loss, *args, **kwargs): if self.assert_backward: - if self.trainer.precision == "16": + if self.trainer.precision == "16-mixed": assert loss > 171 * 1000 else: assert loss == 171.0 diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index d7c6922362141..01d16e1c64adb 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -29,7 +29,7 @@ class AMPTestModel(BoringModel): def step(self, batch): self._assert_autocast_enabled() output = self(batch) - is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" + is_bfloat16 = self.trainer.precision_plugin.precision == "bf16-mixed" assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 loss = self.loss(output) return loss @@ -37,7 +37,7 @@ def step(self, batch): def predict_step(self, batch, batch_idx, dataloader_idx=0): self._assert_autocast_enabled() output = self(batch) - is_bfloat16 = self.trainer.precision_plugin.precision == "bf16" + is_bfloat16 = self.trainer.precision_plugin.precision == "bf16-mixed" assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16 return output @@ -52,10 +52,10 @@ def _assert_autocast_enabled(self): @pytest.mark.parametrize( ("strategy", "precision", "devices"), ( - ("single_device", 16, 1), - ("single_device", "bf16", 1), - ("ddp_spawn", 16, 2), - ("ddp_spawn", "bf16", 2), + ("single_device", "16-mixed", 1), + ("single_device", "bf16-mixed", 1), + ("ddp_spawn", "16-mixed", 2), + ("ddp_spawn", "bf16-mixed", 2), ), ) def test_amp_cpus(tmpdir, strategy, precision, devices): @@ -83,7 +83,7 @@ def test_amp_cpus(tmpdir, strategy, precision, devices): @pytest.mark.parametrize("strategy", [None, "ddp_spawn"]) -@pytest.mark.parametrize("precision", [16, pytest.param("bf16", marks=RunIf(bf16_cuda=True))]) +@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))]) @pytest.mark.parametrize( "devices", (pytest.param(1, marks=RunIf(min_cuda_gpus=1)), pytest.param(2, marks=RunIf(min_cuda_gpus=2))) ) @@ -135,7 +135,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): accelerator="gpu", devices=[0], strategy="ddp_spawn", - precision=16, + precision="16-mixed", callbacks=[checkpoint], logger=logger, ) @@ -153,7 +153,7 @@ def test_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): enable_progress_bar=False, max_epochs=1, devices=1, - precision=16, + precision="16-mixed", limit_train_batches=4, limit_val_batches=0, gradient_clip_val=clip_val, diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index ae873ccad6eb0..13434dcab69bf 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -24,7 +24,7 @@ def test_amp_gpus_ddp_fork(): """Ensure the use of AMP with `ddp_fork` (or associated alias strategies) does not generate CUDA initialization errors.""" - _ = MixedPrecisionPlugin(precision=16, device="cuda") + _ = MixedPrecisionPlugin(precision="16-mixed", device="cuda") with multiprocessing.get_context("fork").Pool(1) as pool: in_bad_fork = pool.apply(torch.cuda._is_in_bad_fork) assert not in_bad_fork diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 44bbedc3a819d..50f6f36a0811d 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -401,9 +401,9 @@ def _predict_batch(trainer, model, batches): [ {}, # these precision plugins modify the optimization flow, so testing them explicitly - pytest.param(dict(accelerator="gpu", devices=1, precision=16), marks=RunIf(min_cuda_gpus=1)), + pytest.param(dict(accelerator="gpu", devices=1, precision="16-mixed"), marks=RunIf(min_cuda_gpus=1)), pytest.param( - dict(accelerator="gpu", devices=1, precision=16, strategy="deepspeed"), + dict(accelerator="gpu", devices=1, precision="16-mixed", strategy="deepspeed"), marks=RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True), ), ], @@ -453,7 +453,7 @@ def training_step(self, batch, batch_idx): "loops": ANY, } using_deepspeed = kwargs.get("strategy") == "deepspeed" - if kwargs.get("precision") == 16 and not using_deepspeed: + if kwargs.get("precision") == "16-mixed" and not using_deepspeed: saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY device = torch.device("cuda:0" if "accelerator" in kwargs and kwargs["accelerator"] == "gpu" else "cpu") expected = [ diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index ceebbca6a7194..5685739c78837 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -104,7 +104,7 @@ def test_model_16bit_tpu_devices_1(tmpdir): """Make sure model trains on TPU.""" trainer_options = dict( default_root_dir=tmpdir, - precision=16, + precision="16-mixed", enable_progress_bar=False, max_epochs=2, accelerator="tpu", @@ -124,7 +124,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core): """Make sure model trains on TPU.""" trainer_options = dict( default_root_dir=tmpdir, - precision=16, + precision="16-mixed", enable_progress_bar=False, max_epochs=2, accelerator="tpu", @@ -146,7 +146,7 @@ def test_model_16bit_tpu_devices_8(tmpdir): """Make sure model trains on TPU.""" trainer_options = dict( default_root_dir=tmpdir, - precision=16, + precision="16-mixed", enable_progress_bar=False, max_epochs=1, accelerator="tpu", diff --git a/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py b/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py index 718ef030eb507..54599f58448c0 100644 --- a/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py +++ b/tests/tests_pytorch/plugins/precision/hpu/test_hpu.py @@ -34,15 +34,15 @@ def hmp_params(request): @RunIf(hpu=True) def test_precision_plugin(hmp_params): - plugin = HPUPrecisionPlugin(precision="bf16", **hmp_params) - assert plugin.precision == "bf16" + plugin = HPUPrecisionPlugin(precision="bf16-mixed", **hmp_params) + assert plugin.precision == "bf16-mixed" @RunIf(hpu=True) def test_mixed_precision(tmpdir, hmp_params: dict): class TestCallback(Callback): def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: - assert trainer.precision == "bf16" + assert trainer.precision == "bf16-mixed" raise SystemExit model = BoringModel() @@ -51,12 +51,12 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non fast_dev_run=True, accelerator="hpu", devices=1, - plugins=[HPUPrecisionPlugin(precision="bf16", **hmp_params)], + plugins=[HPUPrecisionPlugin(precision="bf16-mixed", **hmp_params)], callbacks=TestCallback(), ) assert isinstance(trainer.strategy, SingleHPUStrategy) assert isinstance(trainer.strategy.precision_plugin, HPUPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == "bf16" + assert trainer.strategy.precision_plugin.precision == "bf16-mixed" with pytest.raises(SystemExit): trainer.fit(model) @@ -65,7 +65,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non def test_pure_half_precision(tmpdir, hmp_params: dict): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - assert trainer.precision == "16" + assert trainer.precision == "16-mixed" for param in trainer.strategy.model.parameters(): assert param.dtype == torch.float16 raise SystemExit @@ -77,13 +77,13 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: fast_dev_run=True, accelerator="hpu", devices=1, - plugins=[HPUPrecisionPlugin(precision=16, **hmp_params)], + plugins=[HPUPrecisionPlugin(precision="16-mixed", **hmp_params)], callbacks=TestCallback(), ) assert isinstance(trainer.strategy, SingleHPUStrategy) assert isinstance(trainer.strategy.precision_plugin, HPUPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == "16" + assert trainer.strategy.precision_plugin.precision == "16-mixed" with pytest.raises(RuntimeError, match=r"float16/half is not supported on Gaudi."): trainer.fit(model) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index 189386cb90502..4c86f02986894 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -23,7 +23,7 @@ def test_clip_gradients(): """Test that `.clip_gradients()` is a no-op when clipping is disabled.""" optimizer = Mock(spec=Optimizer) - precision = MixedPrecisionPlugin(precision=16, device="cuda:0", scaler=Mock()) + precision = MixedPrecisionPlugin(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() precision.clip_grad_by_norm = Mock() precision.clip_gradients(optimizer) @@ -47,7 +47,7 @@ def test_optimizer_amp_scaling_support_in_step_method(): gradient clipping (example: fused Adam).""" optimizer = Mock(_step_supports_amp_scaling=True) - precision = MixedPrecisionPlugin(precision=16, device="cuda:0", scaler=Mock()) + precision = MixedPrecisionPlugin(precision="16-mixed", device="cuda:0", scaler=Mock()) with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"): precision.clip_gradients(optimizer, clip_val=1.0) diff --git a/tests/tests_pytorch/plugins/precision/test_amp_integration.py b/tests/tests_pytorch/plugins/precision/test_amp_integration.py index 0d7fb3f8e2bc0..8a64169e9f1fe 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp_integration.py +++ b/tests/tests_pytorch/plugins/precision/test_amp_integration.py @@ -38,7 +38,7 @@ def run(fused=False): default_root_dir=tmpdir, accelerator="cuda", devices=1, - precision=16, + precision="16-mixed", max_steps=5, logger=False, enable_checkpointing=False, diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py index 8420c5c793aec..b0ef260309639 100644 --- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py @@ -19,4 +19,4 @@ def test_invalid_precision_with_deepspeed_precision(): with pytest.raises(ValueError, match="is not supported. `precision` must be one of"): - DeepSpeedPrecisionPlugin(precision=64) + DeepSpeedPrecisionPlugin(precision="64-true") diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index e542c01967cf7..5cca3a93aa518 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -54,10 +54,10 @@ class MyAMP(MixedPrecisionPlugin): def test_amp_ddp(cuda_count_2, strategy, devices, custom_plugin, plugin_cls): plugin = None if custom_plugin: - plugin = plugin_cls(16, "cpu") + plugin = plugin_cls("16-mixed", "cpu") trainer = Trainer( fast_dev_run=True, - precision=16, + precision="16-mixed", accelerator="gpu", devices=devices, strategy=strategy, @@ -137,7 +137,7 @@ def test_amp_gradient_unscale(tmpdir, accum: int): strategy="ddp_spawn", accelerator="gpu", devices=2, - precision=16, + precision="16-mixed", # use a tiny value to make sure it works gradient_clip_val=1e-3, gradient_clip_algorithm="value", @@ -179,14 +179,14 @@ def configure_optimizers(self): torch.optim.SGD(self.layer2.parameters(), lr=0.1), ] - trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=1, fast_dev_run=1, precision=16) + trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=1, fast_dev_run=1, precision="16-mixed") model = CustomBoringModel() trainer.fit(model) def test_cpu_amp_precision_context_manager(tmpdir): """Test to ensure that the context manager correctly is set to CPU + bfloat16.""" - plugin = MixedPrecisionPlugin("bf16", "cpu") + plugin = MixedPrecisionPlugin("bf16-mixed", "cpu") assert plugin.device == "cpu" assert plugin.scaler is None context_manager = plugin.autocast_context_manager() diff --git a/tests/tests_pytorch/plugins/test_double_plugin.py b/tests/tests_pytorch/plugins/test_double_plugin.py index 9c93f09cad221..8d801d6eaf7eb 100644 --- a/tests/tests_pytorch/plugins/test_double_plugin.py +++ b/tests/tests_pytorch/plugins/test_double_plugin.py @@ -135,7 +135,7 @@ def on_fit_start(self): def test_double_precision(tmpdir, boring_model): model = boring_model() - trainer = Trainer(max_epochs=2, default_root_dir=tmpdir, fast_dev_run=2, precision=64, log_every_n_steps=1) + trainer = Trainer(max_epochs=2, default_root_dir=tmpdir, fast_dev_run=2, precision="64-true", log_every_n_steps=1) trainer.fit(model) trainer.test(model) trainer.predict(model) @@ -152,7 +152,7 @@ def test_double_precision_ddp(tmpdir): accelerator="gpu", devices=2, fast_dev_run=2, - precision=64, + precision="64-true", log_every_n_steps=1, ) trainer.fit(model) diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index 248e42bd7e69d..f6470764d9016 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -96,7 +96,7 @@ def setup(self, stage: str) -> None: @RunIf(min_cuda_gpus=2, standalone=True) -@pytest.mark.parametrize("precision", (16, 32)) +@pytest.mark.parametrize("precision", ("16-mixed", "32-true")) def test_ddp_wrapper(tmpdir, precision): """Test parameters to ignore are carried over for DDP.""" diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index e6eeff8c36f5f..76f248bb5264e 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -139,12 +139,12 @@ def test_deepspeed_precision_choice(cuda_count_1, tmpdir): default_root_dir=tmpdir, accelerator="gpu", strategy="deepspeed", - precision=16, + precision="16-mixed", ) assert isinstance(trainer.strategy, DeepSpeedStrategy) assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == "16" + assert trainer.strategy.precision_plugin.precision == "16-mixed" @RunIf(deepspeed=True) @@ -189,7 +189,7 @@ def backward(self, loss: Tensor, *args, **kwargs) -> None: strategy=DeepSpeedStrategy(), accelerator="gpu", devices=1, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -264,7 +264,7 @@ def configure_optimizers(self): accelerator="gpu", devices=1, fast_dev_run=True, - precision=16, + precision="16-mixed", callbacks=[TestCB(), lr_monitor], logger=CSVLogger(tmpdir), enable_progress_bar=False, @@ -303,7 +303,7 @@ def on_train_start(self, trainer, pl_module) -> None: limit_val_batches=4, limit_test_batches=4, max_epochs=2, - precision=16, + precision="16-mixed", callbacks=[TestCB(), lr_monitor], logger=CSVLogger(tmpdir), enable_progress_bar=False, @@ -337,7 +337,7 @@ def on_train_start(self, trainer, pl_module) -> None: trainer = Trainer( default_root_dir=tmpdir, strategy=ds, - precision=16, + precision="16-mixed", accelerator="gpu", devices=1, callbacks=[TestCB()], @@ -380,7 +380,7 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmpdir): default_root_dir=tmpdir, fast_dev_run=1, strategy=ds, - precision=16, + precision="16-mixed", accelerator="gpu", devices=1, enable_progress_bar=False, @@ -413,7 +413,7 @@ def setup(self, trainer, pl_module, stage=None) -> None: enable_progress_bar=False, max_epochs=1, strategy=DeepSpeedStrategy(config=deepspeed_zero_config), - precision=16, + precision="16-mixed", accelerator="gpu", devices=1, callbacks=[TestCallback()], @@ -433,7 +433,7 @@ def test_deepspeed_multigpu(tmpdir): accelerator="gpu", devices=2, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -476,7 +476,7 @@ def test_deepspeed_stage_3_save_warning(tmpdir): accelerator="gpu", devices=2, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -508,7 +508,7 @@ def test_deepspeed_multigpu_single_file(tmpdir): accelerator="gpu", devices=1, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -524,7 +524,7 @@ def test_deepspeed_multigpu_single_file(tmpdir): accelerator="gpu", devices=1, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -626,7 +626,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir): accelerator="gpu", devices=2, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -646,7 +646,7 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config accelerator="gpu", devices=2, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -672,7 +672,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir, automatic_optimization strategy=DeepSpeedStrategy(stage=3), accelerator="gpu", devices=2, - precision=16, + precision="16-mixed", accumulate_grad_batches=accumulate_grad_batches, callbacks=[ck], enable_progress_bar=False, @@ -693,7 +693,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir, automatic_optimization accelerator="gpu", devices=2, strategy=DeepSpeedStrategy(stage=3), - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -722,7 +722,7 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): strategy=DeepSpeedStrategy(stage=3, load_full_weights=True), accelerator="gpu", devices=1, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -751,7 +751,7 @@ def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): strategy=DeepSpeedStrategy(stage=3), accelerator="gpu", devices=1, - precision=16, + precision="16-mixed", callbacks=[ck], enable_progress_bar=False, enable_model_summary=False, @@ -792,7 +792,7 @@ def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> max_epochs=2, limit_train_batches=1, limit_val_batches=0, - precision=16, + precision="16-mixed", callbacks=TestCallback(), enable_progress_bar=False, enable_model_summary=False, @@ -828,7 +828,7 @@ def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any, devices=2, limit_train_batches=5, limit_val_batches=2, - precision=16, + precision="16-mixed", accumulate_grad_batches=2, callbacks=[verification_callback], enable_progress_bar=False, @@ -849,7 +849,7 @@ def test_deepspeed_multigpu_test(tmpdir): accelerator="gpu", devices=2, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -885,7 +885,7 @@ def on_train_epoch_start(self) -> None: accelerator="gpu", devices=1, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -912,7 +912,7 @@ def on_train_epoch_start(self) -> None: accelerator="gpu", devices=1, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -976,7 +976,7 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir): accelerator="gpu", devices=2, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -998,7 +998,7 @@ def training_step(self, batch, batch_idx): accelerator="gpu", devices=1, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) @@ -1212,7 +1212,7 @@ def test_deepspeed_with_bfloat16_precision(tmpdir): accelerator="gpu", devices=2, fast_dev_run=True, - precision="bf16", + precision="bf16-mixed", num_sanity_val_steps=0, enable_progress_bar=False, enable_model_summary=False, @@ -1220,7 +1220,7 @@ def test_deepspeed_with_bfloat16_precision(tmpdir): trainer.fit(model) assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.strategy.precision_plugin.precision == "bf16" + assert trainer.strategy.precision_plugin.precision == "bf16-mixed" assert trainer.strategy.config["zero_optimization"]["stage"] == 3 assert trainer.strategy.config["bf16"]["enabled"] assert model.layer.weight.dtype == torch.bfloat16 @@ -1271,7 +1271,7 @@ def transfer_batch_to_device(self, batch, *args, **kwargs): return super().transfer_batch_to_device(batch, *args, **kwargs) model = CustomBoringModel() - trainer = Trainer(strategy="deepspeed", devices=1, accelerator="cuda", precision=16) + trainer = Trainer(strategy="deepspeed", devices=1, accelerator="cuda", precision="16-mixed") trainer.strategy.connect(model) batch = torch.zeros((1), dtype=torch.float32) batch = trainer.strategy.batch_to_device(batch) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 42425f581765f..05aec225204a4 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -64,7 +64,7 @@ def on_predict_batch_end(self, *_) -> None: def _assert_layer_fsdp_instance(self) -> None: assert isinstance(self.layer, FullyShardedDataParallel) assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin) - precision = torch.float16 if self.trainer.precision == "16" else torch.bfloat16 + precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16 assert self.layer.mixed_precision.param_dtype == precision assert self.layer.mixed_precision.reduce_dtype == precision assert self.layer.mixed_precision.buffer_dtype == precision @@ -100,7 +100,7 @@ def _assert_layer_fsdp_instance(self) -> None: assert isinstance(self.layer, torch.nn.Sequential) assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin) - precision = torch.float16 if self.trainer.precision == "16" else torch.bfloat16 + precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16 for layer_num in [0, 2]: assert isinstance(self.layer[layer_num], FullyShardedDataParallel) assert self.layer[layer_num].mixed_precision.param_dtype == precision @@ -164,7 +164,7 @@ def test_invalid_on_cpu(tmpdir): @RunIf(min_torch="1.12", min_cuda_gpus=1) -@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) +@pytest.mark.parametrize("precision, expected", [("16-mixed", torch.float16), ("bf16-mixed", torch.bfloat16)]) def test_precision_plugin_config(precision, expected): plugin = FSDPMixedPrecisionPlugin(precision=precision, device="cuda") config = plugin.mixed_precision_config @@ -191,7 +191,7 @@ def test_fsdp_strategy_sync_batchnorm(tmpdir): accelerator="gpu", devices=2, strategy="fsdp", - precision=16, + precision="16-mixed", max_epochs=1, sync_batchnorm=True, ) @@ -199,7 +199,7 @@ def test_fsdp_strategy_sync_batchnorm(tmpdir): @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") -@pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) +@pytest.mark.parametrize("precision", ("16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)))) def test_fsdp_strategy_checkpoint(tmpdir, precision): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" model = TestFSDPModel() @@ -230,7 +230,7 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy): accelerator="gpu", devices=2, strategy=strategy, - precision=16, + precision="16-mixed", max_epochs=1, limit_train_batches=2, limit_val_batches=2, diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py index 8882bd441fe1a..75b7b63957387 100644 --- a/tests/tests_pytorch/strategies/test_registry.py +++ b/tests/tests_pytorch/strategies/test_registry.py @@ -48,7 +48,7 @@ def test_strategy_registry_with_deepspeed_strategies(strategy_name, init_params) @pytest.mark.parametrize("strategy", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"]) def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy): - trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, precision=16) + trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, precision="16-mixed") assert isinstance(trainer.strategy, DeepSpeedStrategy) diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index f5b6c25200940..e98d4df2a9c54 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -413,7 +413,7 @@ def test_device_type_when_strategy_instance_gpu_passed(strategy_class, cuda_coun @pytest.mark.parametrize("precision", [1, 12, "invalid"]) def test_validate_precision_type(precision): - with pytest.raises(MisconfigurationException, match=f"Precision {repr(precision)} is invalid"): + with pytest.raises(ValueError, match=f"Precision {repr(precision)} is invalid"): Trainer(precision=precision) @@ -596,14 +596,16 @@ def test_check_fsdp_strategy_and_fallback(): def test_unsupported_tpu_choice(tpu_available): - with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"): - Trainer(accelerator="tpu", precision=64) + with pytest.raises( + MisconfigurationException, match=r"accelerator='tpu', precision='64-true'\)` is not implemented" + ): + Trainer(accelerator="tpu", precision="64-true") # if user didn't set strategy, AcceleratorConnector will choose the TPUSingleStrategy or XLAStrategy with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"), pytest.warns( - UserWarning, match=r"accelerator='tpu', precision=16\)` but AMP is not supported" + UserWarning, match=r"accelerator='tpu', precision=16-mixed\)` but AMP with fp16 is not supported" ): - Trainer(accelerator="tpu", precision=16, strategy="ddp") + Trainer(accelerator="tpu", precision="16-mixed", strategy="ddp") @mock.patch("lightning.pytorch.accelerators.ipu.IPUAccelerator.is_available", return_value=True) @@ -613,10 +615,10 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch): monkeypatch.setattr(ipu_, "_IPU_AVAILABLE", True) monkeypatch.setattr(ipu, "_IPU_AVAILABLE", True) - with pytest.raises(ValueError, match=r"accelerator='ipu', precision='bf16'\)` is not supported"): - Trainer(accelerator="ipu", precision="bf16") - with pytest.raises(ValueError, match=r"accelerator='ipu', precision='64'\)` is not supported"): - Trainer(accelerator="ipu", precision=64) + with pytest.raises(ValueError, match=r"accelerator='ipu', precision='bf16-mixed'\)` is not supported"): + Trainer(accelerator="ipu", precision="bf16-mixed") + with pytest.raises(ValueError, match=r"accelerator='ipu', precision='64-true'\)` is not supported"): + Trainer(accelerator="ipu", precision="64-true") @mock.patch("lightning.pytorch.accelerators.tpu._XLA_AVAILABLE", return_value=False) @@ -839,6 +841,7 @@ def get_defaults(cls): @RunIf(min_cuda_gpus=1) # trigger this test on our GPU pipeline, because we don't install the package on the CPU suite @pytest.mark.skipif(not package_available("lightning_colossalai"), reason="Requires Colossal AI Strategy") +@pytest.mark.skip def test_colossalai_external_strategy(monkeypatch): with mock.patch( "lightning.pytorch.trainer.connectors.accelerator_connector._LIGHTNING_COLOSSALAI_AVAILABLE", False @@ -847,5 +850,5 @@ def test_colossalai_external_strategy(monkeypatch): from lightning_colossalai import ColossalAIStrategy - trainer = Trainer(strategy="colossalai", precision=16) + trainer = Trainer(strategy="colossalai", precision="16-mixed") assert isinstance(trainer.strategy, ColossalAIStrategy) diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py index 8a5bedf8efabe..ad6e0c69908af 100644 --- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py +++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py @@ -72,7 +72,8 @@ def configure_optimizers(self): @pytest.mark.parametrize( - "kwargs", [{}, pytest.param({"accelerator": "gpu", "devices": 1, "precision": 16}, marks=RunIf(min_cuda_gpus=1))] + "kwargs", + [{}, pytest.param({"accelerator": "gpu", "devices": 1, "precision": "16-mixed"}, marks=RunIf(min_cuda_gpus=1))], ) def test_multiple_optimizers_manual_call_counts(tmpdir, kwargs): model = ManualOptModel() @@ -87,7 +88,7 @@ def test_multiple_optimizers_manual_call_counts(tmpdir, kwargs): **kwargs, ) - if kwargs.get("precision") == 16: + if kwargs.get("precision") == "16-mixed": # mock the scaler instead of the optimizer step because it can be skipped with NaNs scaler_step_patch = mock.patch.object( trainer.precision_plugin.scaler, "step", wraps=trainer.precision_plugin.scaler.step @@ -99,7 +100,7 @@ def test_multiple_optimizers_manual_call_counts(tmpdir, kwargs): assert bwd_mock.call_count == limit_train_batches * 3 assert trainer.global_step == limit_train_batches * 2 - if kwargs.get("precision") == 16: + if kwargs.get("precision") == "16-mixed": scaler_step_patch.stop() assert scaler_step.call_count == len(model.optimizers()) * limit_train_batches @@ -141,7 +142,7 @@ def test_multiple_optimizers_manual_amp(tmpdir, accelerator): max_epochs=1, log_every_n_steps=1, enable_model_summary=False, - precision=16, + precision="16-mixed", accelerator=accelerator, devices=1, ) @@ -224,7 +225,7 @@ def test_manual_optimization_and_return_tensor(tmpdir): limit_train_batches=10, limit_test_batches=0, limit_val_batches=0, - precision=16, + precision="16-mixed", strategy="ddp_spawn", accelerator="gpu", devices=2, @@ -309,7 +310,7 @@ def on_train_epoch_end(self, *_, **__): limit_train_batches=20, limit_test_batches=0, limit_val_batches=0, - precision=16, + precision="16-mixed", accelerator="gpu", devices=1, ) @@ -383,7 +384,7 @@ def on_before_optimizer_step(self, optimizer, *_): max_epochs=1, log_every_n_steps=1, enable_model_summary=False, - precision=16, + precision="16-mixed", accelerator="gpu", devices=1, ) @@ -848,7 +849,7 @@ def test_lr_scheduler_step_not_called(tmpdir): @RunIf(min_cuda_gpus=1) -@pytest.mark.parametrize("precision", [16, 32]) +@pytest.mark.parametrize("precision", ["16-mixed", "32-true"]) def test_multiple_optimizers_logging(precision, tmpdir): """Tests that metrics are properly being logged.""" diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 1a77eed0c96eb..f7042f3a5ad9a 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1019,7 +1019,7 @@ def on_exception(self, trainer, pl_module, exception): assert isinstance(handle_interrupt_callback.exception, MisconfigurationException) -@pytest.mark.parametrize("precision", [32, pytest.param(16, marks=RunIf(min_cuda_gpus=1))]) +@pytest.mark.parametrize("precision", ["32-true", pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1))]) @RunIf(sklearn=True) def test_gradient_clipping_by_norm(tmpdir, precision): """Test gradient clipping by norm.""" @@ -1048,7 +1048,7 @@ def configure_gradient_clipping(self, *args, **kwargs): assert model.assertion_called -@pytest.mark.parametrize("precision", [32, pytest.param(16, marks=RunIf(min_cuda_gpus=1))]) +@pytest.mark.parametrize("precision", ["32-true", pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1))]) def test_gradient_clipping_by_value(tmpdir, precision): """Test gradient clipping by value.""" trainer = Trainer( @@ -1437,7 +1437,7 @@ def test_spawn_predict_return_predictions(tmpdir): @pytest.mark.parametrize("return_predictions", [None, False, True]) -@pytest.mark.parametrize("precision", [32, 64]) +@pytest.mark.parametrize("precision", ["32-true", "64-true"]) def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir): """Test that `return_predictions=True`.""" seed_everything(42) @@ -1448,7 +1448,7 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir): if return_predictions or return_predictions is None: assert len(preds) == 1 assert preds[0].shape == torch.Size([1, 2]) - assert preds[0].dtype == (torch.float64 if precision == 64 else torch.float32) + assert preds[0].dtype == (torch.float64 if precision == "64-true" else torch.float32) @pytest.mark.parametrize(["max_steps", "max_epochs", "global_step"], [(10, 5, 10), (20, None, 20)]) diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index a5fec99febf71..d36efb88c7495 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -254,7 +254,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir): def test_auto_scale_batch_size_with_amp(tmpdir): before_batch_size = 2 model = BatchSizeModel(batch_size=before_batch_size) - trainer = Trainer(default_root_dir=tmpdir, max_steps=1, accelerator="gpu", devices=1, precision=16) + trainer = Trainer(default_root_dir=tmpdir, max_steps=1, accelerator="gpu", devices=1, precision="16-mixed") tuner = Tuner(trainer) tuner.scale_batch_size(model) after_batch_size = model.batch_size diff --git a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py index 5c0cb588ebe5b..9f4bcef723434 100644 --- a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py @@ -32,7 +32,7 @@ def test_deepspeed_collate_checkpoint(tmpdir): accelerator="gpu", devices=2, fast_dev_run=True, - precision=16, + precision="16-mixed", enable_progress_bar=False, enable_model_summary=False, ) diff --git a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py index f06c101db8246..146ab1aa6601b 100644 --- a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py +++ b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py @@ -45,7 +45,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - accelerator="gpu", fast_dev_run=True, devices=2, - precision=16, + precision="16-mixed", enable_model_summary=True, callbacks=[TestCallback()], ) diff --git a/tests/tests_pytorch/utilities/test_torchdistx.py b/tests/tests_pytorch/utilities/test_torchdistx.py index 9fee068cee9ab..187a9a5c56084 100644 --- a/tests/tests_pytorch/utilities/test_torchdistx.py +++ b/tests/tests_pytorch/utilities/test_torchdistx.py @@ -55,7 +55,7 @@ def test_deferred_init_with_lightning_module(): ( {"accelerator": "auto", "devices": 1}, pytest.param( - {"strategy": "deepspeed_stage_3", "accelerator": "gpu", "devices": 2, "precision": 16}, + {"strategy": "deepspeed_stage_3", "accelerator": "gpu", "devices": 2, "precision": "16-mixed"}, marks=RunIf(min_cuda_gpus=2, deepspeed=True), ), ),