Introduce new precision layout in PL (#16783)
justusschock committed Feb 17, 2023
1 parent ec4f592 commit 0fee284
Showing 45 changed files with 227 additions and 198 deletions.
24 changes: 18 additions & 6 deletions docs/source-pytorch/common/precision_basic.rst
@@ -20,11 +20,11 @@ Higher precision, such as the 64-bit floating-point, can be used for highly sens
16-bit Precision
****************

Use 16-bit precision to cut your memory consumption in half so that you can train and deploy larger models. If your GPUs are [`Tensor Core <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html>`_] GPUs, you can also get a ~3x speed improvement. Half precision can sometimes lead to unstable training.
Use 16-bit mixed precision to lower your memory consumption by up to half so that you can train and deploy larger models. If your GPUs are [`Tensor Core <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html>`_] GPUs, you can also get a ~3x speed improvement. Half precision can sometimes lead to unstable training.

.. code::
Trainer(precision=16)
Trainer(precision='16-mixed')
----

@@ -36,6 +36,12 @@ Use 16-bit precision to cut your memory consumption in half so that you can trai

.. testcode::

Trainer(precision='32-true')

# or
Trainer(precision='32')

# or
Trainer(precision=32)

----
@@ -48,6 +54,12 @@ For certain scientific computations, 64-bit precision enables more accurate mode

.. testcode::

Trainer(precision='64-true')

# or
Trainer(precision='64')

# or
Trainer(precision=64)

.. note::
@@ -70,22 +82,22 @@ Precision support by accelerator
- GPU
- TPU
- IPU
* - 16
* - 16 Mixed
- No
- Yes
- No
- Yes
* - BFloat16
* - BFloat16 Mixed
- Yes
- Yes
- Yes
- No
* - 32
* - 32 True
- Yes
- Yes
- Yes
- Yes
* - 64
* - 64 True
- Yes
- Yes
- No
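For reference, a minimal sketch of the new ``Trainer`` arguments introduced by this change, next to the spellings they replace (assuming the usual ``lightning.pytorch`` import):

```python
from lightning.pytorch import Trainer

# 16-bit mixed precision (previously precision=16)
trainer = Trainer(precision="16-mixed")

# bfloat16 mixed precision (previously precision="bf16")
trainer = Trainer(precision="bf16-mixed")

# full 32-bit precision; "32" and 32 are still accepted as aliases
trainer = Trainer(precision="32-true")

# double precision; "64" and 64 are still accepted as aliases
trainer = Trainer(precision="64-true")
```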
2 changes: 1 addition & 1 deletion docs/source-pytorch/common/precision_expert.rst
@@ -20,7 +20,7 @@ You can also customize and pass your own Precision Plugin by subclassing the :cl
.. code-block:: python
class CustomPrecisionPlugin(PrecisionPlugin):
precision = 16
precision = '16-mixed'
...
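A slightly fuller sketch of the customization shown above, assuming ``PrecisionPlugin`` is exposed under ``lightning.pytorch.plugins`` as elsewhere in the docs:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.plugins import PrecisionPlugin


class CustomPrecisionPlugin(PrecisionPlugin):
    # plugins now report the new string layout instead of the bare integer 16
    precision = "16-mixed"


trainer = Trainer(plugins=[CustomPrecisionPlugin()])
```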
2 changes: 1 addition & 1 deletion docs/source-pytorch/common/precision_intermediate.rst
@@ -63,7 +63,7 @@ Since computation happens in FP16, there is a chance of numerical instability du

.. note::

When using TPUs, setting ``precision=16`` will enable bfloat16, the only supported half precision type on TPUs.
When using TPUs, setting ``precision='16-mixed'`` will enable bfloat16, the only supported half precision type on TPUs.

.. testcode::
:skipif: not torch.cuda.is_available()
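A one-line sketch of the TPU behaviour described in the note above (assuming a TPU environment is available):

```python
from lightning.pytorch import Trainer

# On TPUs, "16-mixed" enables bfloat16, the only supported half-precision type.
trainer = Trainer(accelerator="tpu", precision="16-mixed")
```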
4 changes: 2 additions & 2 deletions docs/source-pytorch/common/trainer.rst
@@ -926,10 +926,10 @@ Half precision, or mixed precision, is the combined use of 32 and 16 bit floatin
trainer = Trainer(precision=32)

# 16-bit precision
trainer = Trainer(precision=16, accelerator="gpu", devices=1) # works only on CUDA
trainer = Trainer(precision="16-mixed", accelerator="gpu", devices=1) # works only on CUDA

# bfloat16 precision
trainer = Trainer(precision="bf16")
trainer = Trainer(precision="bf16-mixed")

# 64-bit precision
trainer = Trainer(precision=64)
1 change: 0 additions & 1 deletion docs/source-pytorch/fabric/fundamentals/launch.rst
@@ -74,7 +74,6 @@ This is essentially the same as running ``python path/to/your/script.py``, but i
precision (``16-mixed`` or ``16``) or
bfloat16 precision (``bf16-mixed`` or
``bf16``)
--help Show this message and exit.
2 changes: 1 addition & 1 deletion examples/app_multi_node/train_fabric.py
@@ -15,7 +15,7 @@ def run(self):
)

# 2. Create Fabric.
fabric = Fabric(strategy="ddp", precision=16)
fabric = Fabric(strategy="ddp", precision="16-mixed")
model, optimizer = fabric.setup(model, torch.optim.SGD(model.parameters(), lr=0.01))
criterion = torch.nn.MSELoss()

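A condensed, self-contained sketch of the Fabric snippet above with the renamed precision argument, assuming ``Fabric`` is importable from ``lightning.fabric`` and a CUDA-capable machine mirroring the example's setup:

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(strategy="ddp", precision="16-mixed")  # previously precision=16
fabric.launch()

model = torch.nn.Linear(32, 2)
model, optimizer = fabric.setup(model, torch.optim.SGD(model.parameters(), lr=0.01))
criterion = torch.nn.MSELoss()

output = model(torch.randn(4, 32, device=fabric.device))
loss = criterion(output, torch.zeros(4, 2, device=fabric.device))
fabric.backward(loss)
optimizer.step()
```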
2 changes: 1 addition & 1 deletion examples/pl_hpu/mnist_sample.py
@@ -63,7 +63,7 @@ def configure_optimizers(self):
"accelerator": "hpu",
"devices": 1,
"max_epochs": 1,
"plugins": lazy_instance(HPUPrecisionPlugin, precision=16),
"plugins": lazy_instance(HPUPrecisionPlugin, precision="16-mixed"),
},
run=False,
save_config_kwargs={"overwrite": True},
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -107,6 +107,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781))


- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16783](https://github.com/Lightning-AI/lightning/pull/16783))

### Deprecated

-
15 changes: 9 additions & 6 deletions src/lightning/pytorch/plugins/precision/amp.py
@@ -34,15 +34,18 @@ class MixedPrecisionPlugin(PrecisionPlugin):
"""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> None:
self.precision = cast(Literal["16", "bf16"], str(precision)) # type: ignore
if scaler is None and self.precision == "16":
self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
if scaler is None and self.precision == "16-mixed":
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16":
raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
self.scaler = scaler

@@ -97,7 +100,7 @@ def clip_gradients(
def autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
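A short sketch of using the updated constructor directly, following the signature and checks above:

```python
import torch
from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin

# "16-mixed" creates a GradScaler automatically when none is supplied
fp16 = MixedPrecisionPlugin("16-mixed", device="cuda")
assert isinstance(fp16.scaler, torch.cuda.amp.GradScaler)

# "bf16-mixed" must run without a scaler; passing one raises MisconfigurationException
bf16 = MixedPrecisionPlugin("bf16-mixed", device="cuda")
assert bf16.scaler is None
```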
10 changes: 4 additions & 6 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -31,9 +31,7 @@

warning_cache = WarningCache()

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]


class DeepSpeedPrecisionPlugin(PrecisionPlugin):
@@ -46,14 +44,14 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
If unsupported ``precision`` is provided.
"""

def __init__(self, precision: Literal["32", 32, "16", 16, "bf16"]) -> None:
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
def __init__(self, precision: Literal["32-true", "16-mixed", "bf16-mixed"]) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore
self.precision = cast(_PRECISION_INPUT, str(precision))

def backward( # type: ignore[override]
self,
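A brief sketch of the tightened validation, based on the constructor above (constructing the plugin itself should not require the deepspeed package):

```python
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

plugin = DeepSpeedPrecisionPlugin("bf16-mixed")
assert plugin.precision == "bf16-mixed"

try:
    DeepSpeedPrecisionPlugin(16)  # the old integer spelling is rejected now
except ValueError as err:
    print(err)  # ... `precision` must be one of: ('32-true', '16-mixed', 'bf16-mixed')
```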
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/double.py
@@ -72,7 +72,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
class DoublePrecisionPlugin(PrecisionPlugin):
"""Plugin for training with double (``torch.float64``) precision."""

precision: Literal["64"] = "64" # type: ignore
precision: Literal["64-true"] = "64-true"

def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
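The double-precision plugin now reports the new string as well; a tiny check, assuming the constructor still takes no arguments:

```python
from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin

assert DoublePrecisionPlugin().precision == "64-true"
```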
8 changes: 4 additions & 4 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -31,12 +31,12 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):
"""AMP for Fully Sharded Data Parallel (FSDP) Training."""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[ShardedGradScaler] = None
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.")
super().__init__(
precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None)
precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16-mixed" else None)
)

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
@@ -52,9 +52,9 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
assert MixedPrecision is not None
if self.precision == "16":
if self.precision == "16-mixed":
dtype = torch.float16
elif self.precision == "bf16":
elif self.precision == "bf16-mixed":
dtype = torch.bfloat16
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
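A sketch of the renamed FSDP inputs, assuming PyTorch 1.12+ so the plugin can be constructed:

```python
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin

# "16-mixed" selects torch.float16 (plus a ShardedGradScaler); "bf16-mixed" selects torch.bfloat16
plugin = FSDPMixedPrecisionPlugin("16-mixed", device="cuda")
assert plugin.precision == "16-mixed"
assert plugin.scaler is not None  # the ShardedGradScaler created above
```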
12 changes: 5 additions & 7 deletions src/lightning/pytorch/plugins/precision/hpu.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast, Literal, Optional, Union
from typing import cast, Literal, Optional

from typing_extensions import get_args

@@ -22,9 +22,7 @@
if _HPU_AVAILABLE:
from habana_frameworks.torch.hpex import hmp

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]


class HPUPrecisionPlugin(PrecisionPlugin):
@@ -48,14 +46,14 @@ def __init__(
) -> None:
if not _HPU_AVAILABLE:
raise MisconfigurationException("HPU precision plugin requires HPU devices.")
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore
if self.precision in ("16", "bf16"):
self.precision = cast(_PRECISION_INPUT, str(precision))
if self.precision in ("16-mixed", "bf16-mixed"):
hmp.convert(
opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose
)
12 changes: 5 additions & 7 deletions src/lightning/pytorch/plugins/precision/ipu.py
@@ -27,27 +27,25 @@

warning_cache = WarningCache()

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed"]


class IPUPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for IPU integration.
Raises:
ValueError:
If the precision is neither 16 nor 32.
If the precision is neither 16-mixed nor 32-true.
"""

def __init__(self, precision: Literal["32", 32, "16", 16]) -> None:
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
def __init__(self, precision: Literal["32-true", "16-mixed"]) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore
self.precision = cast(_PRECISION_INPUT, str(precision))

def backward( # type: ignore[override]
self,
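As with the DeepSpeed plugin, only the new strings pass validation here; a minimal sketch:

```python
from lightning.pytorch.plugins.precision.ipu import IPUPrecisionPlugin

plugin = IPUPrecisionPlugin("16-mixed")
assert plugin.precision == "16-mixed"
# the previous spellings (16, "16", 32, "32") now raise a ValueError
```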
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/tpu_bf16.py
@@ -23,7 +23,7 @@
class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
"""Plugin that enables bfloats on TPUs."""

precision: Literal["bf16"] = "bf16" # type: ignore
precision: Literal["bf16-mixed"] = "bf16-mixed"

def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
12 changes: 6 additions & 6 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -127,8 +127,8 @@ def __init__(
Arguments:
zero_optimization: Enable ZeRO optimization. This is compatible with either `precision=16` or
`precision="bf16"`.
zero_optimization: Enable ZeRO optimization. This is compatible with either `precision="16-mixed"` or
`precision="bf16-mixed"`.
stage: Different stages of the ZeRO Optimizer. 0 is disabled,
1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning,
@@ -505,9 +505,9 @@ def model_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
assert self._config_initialized

if self.precision_plugin.precision == "16":
if self.precision_plugin.precision == "16-mixed":
dtype = torch.float16
elif self.precision_plugin.precision == "bf16":
elif self.precision_plugin.precision == "bf16-mixed":
dtype = torch.bfloat16
else:
dtype = torch.float32
@@ -641,7 +641,7 @@ def _auto_select_batch_size(self) -> int:

def _format_precision_config(self) -> None:
assert isinstance(self.config, dict)
if self.precision_plugin.precision == "16":
if self.precision_plugin.precision == "16-mixed":
if "fp16" not in self.config:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")
Expand All @@ -653,7 +653,7 @@ def _format_precision_config(self) -> None:
"hysteresis": self.hysteresis,
"min_loss_scale": self.min_loss_scale,
}
elif "bf16" not in self.config and self.precision_plugin.precision == "bf16":
elif "bf16" not in self.config and self.precision_plugin.precision == "bf16-mixed":
rank_zero_info("Enabling DeepSpeed BF16.")
self.config["bf16"] = {"enabled": True}

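For illustration, a simplified, hypothetical helper mirroring how `_format_precision_config` now keys off the new strings (the real method also fills in the loss-scaling options shown above):

```python
from typing import Any, Dict


def sketch_deepspeed_precision_section(precision: str) -> Dict[str, Any]:
    """Rough sketch of the config section chosen for each precision string."""
    if precision == "16-mixed":
        return {"fp16": {"enabled": True}}  # DeepSpeed's standalone FP16 AMP implementation
    if precision == "bf16-mixed":
        return {"bf16": {"enabled": True}}
    return {}  # "32-true" leaves the precision section untouched


print(sketch_deepspeed_precision_section("16-mixed"))    # {'fp16': {'enabled': True}}
print(sketch_deepspeed_precision_section("bf16-mixed"))  # {'bf16': {'enabled': True}}
```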
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -99,8 +99,8 @@ class FSDPStrategy(ParallelStrategy):
algorithms to help backward communication and computation overlapping.
The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
mixed_precision:
Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
or BF16 if ``precision=bf16`` unless a config is passed in.
Mixed Precision config. By default, Lightning will enable FP16 if ``precision="16-mixed"``
or BF16 if ``precision="bf16-mixed"`` unless a config is passed in.
This is only available in PyTorch 1.12 and later.
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention + feed-forward).
14 changes: 11 additions & 3 deletions src/lightning/pytorch/strategies/utils.py
@@ -32,9 +32,17 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
mod.register_strategies(registry)


def _fp_to_half(tensor: Tensor, precision: Literal["64", 64, "32", 32, "16", 16, "bf16"]) -> Tensor:
if str(precision) == "16":
def _fp_to_half(
tensor: Tensor,
precision: Literal[
"64-true",
"32-true",
"16-mixed",
"bf16-mixed",
],
) -> Tensor:
if str(precision) == "16-mixed":
return _convert_fp_tensor(tensor, torch.half)
if precision == "bf16":
if precision == "bf16-mixed":
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor
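
A quick sketch of the updated helper in action; `_fp_to_half` is private, so this is illustrative only:

```python
import torch
from lightning.pytorch.strategies.utils import _fp_to_half

t = torch.rand(2, dtype=torch.float32)
assert _fp_to_half(t, "16-mixed").dtype is torch.half
assert _fp_to_half(t, "bf16-mixed").dtype is torch.bfloat16
assert _fp_to_half(t, "32-true").dtype is torch.float32  # other modes pass through unchanged
```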
