# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, ContextManager, Literal, Mapping, Optional, Union

import torch
from lightning_utilities import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import (
    _ClassReplacementContextManager,
    _convert_fp_tensor,
    _DtypeContextManager,
)
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn

if TYPE_CHECKING:
    from transformer_engine.common.recipe import DelayedScaling

_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
log = logging.getLogger(__name__)


class TransformerEnginePrecision(Precision):
    """Plugin for training with fp8 precision via NVIDIA's
    `Transformer Engine <https://docs.nvidia.com/deeplearning/transformer-engine>`__.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

    Args:
        weights_dtype: The weights dtype to use.
        recipe: Recipe for the DelayedScaling
            `configuration <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling>`__.
            In dict format or the dataclass format.
        replace_layers: Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their Transformer
            Engine alternatives. Note that they don't subclass the torch equivalents, so checks like
            ``isinstance(l, torch.nn.Linear)`` will not pass.
        fallback_compute_dtype: The compute dtype to use for operations that don't support fp8 autocast. Defaults to
            the same as ``weights_dtype``.

    .. note::
        Support for FP8 in the linear layers with this plugin is currently limited to tensors with shapes where the
        dimensions are divisible by 8 and 16 respectively. You might want to add padding to your inputs to conform
        to this restriction.
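
    A minimal usage sketch (illustrative only; assumes ``transformer_engine`` is installed and the GPU supports
    FP8 compute). The recipe values shown are placeholders, any ``DelayedScaling`` field can be passed::

        import torch
        from lightning.fabric import Fabric
        from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision

        # build the plugin; string values for fp8_format are resolved to the Format enum in __init__
        precision = TransformerEnginePrecision(
            weights_dtype=torch.bfloat16,
            recipe={"fp8_format": "HYBRID"},
        )
        # the plugin is typically handed to Fabric via its `plugins` argument
        fabric = Fabric(accelerator="cuda", devices=1, plugins=precision)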
"""
precision: Literal["transformer-engine", "transformer-engine-float16"] = "transformer-engine"

    def __init__(
        self,
        *,
        weights_dtype: torch.dtype,
        recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
        replace_layers: Optional[bool] = None,
        fallback_compute_dtype: Optional[torch.dtype] = None,
    ) -> None:
        if not _TRANSFORMER_ENGINE_AVAILABLE:
            raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
        from transformer_engine.common.recipe import DelayedScaling

        if recipe is None:
            recipe = DelayedScaling()
        elif isinstance(recipe, Mapping):
            recipe = dict(recipe)  # copy
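            # a string format name (e.g. "HYBRID") is looked up on the `Format` enum so plain dicts can
            # round-trip through config files; the remaining keys are passed to `DelayedScaling` unchanged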
if "fp8_format" in recipe:
from transformer_engine.common.recipe import Format
recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
recipe = DelayedScaling(**recipe)
self.weights_dtype = weights_dtype
self.recipe = recipe
self.replace_layers = replace_layers
self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype

    @override
    def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
        # if Transformer Engine layers are already present, skip conversion and assume the user took care of it
        if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
            if self.replace_layers is True:
                # info level because this is expected with `init_module`
                rank_zero_info(
                    "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
                    " TransformerEngine layers. Skipping"
                )
        elif self.replace_layers in (None, True):
            _convert_layers(module)
        module = module.to(dtype=self.weights_dtype)
        return module

    @override
    def tensor_init_context(self) -> ContextManager:
        return _DtypeContextManager(self.weights_dtype)

    @override
    def module_init_context(self) -> ContextManager:
        dtype_ctx = self.tensor_init_context()
        stack = ExitStack()
        if self.replace_layers:
            import transformer_engine.pytorch as te

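            # temporarily patch the torch classes so that layers instantiated inside this context are created
            # as Transformer Engine modules directly, avoiding a separate conversion and weight copy afterwards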
            context_manager = _ClassReplacementContextManager({
                "torch.nn.Linear": te.Linear,
                "torch.nn.LayerNorm": te.LayerNorm,
            })
            stack.enter_context(context_manager)
        stack.enter_context(dtype_ctx)
        return stack

    @override
    def forward_context(self) -> ContextManager:
        dtype_ctx = _DtypeContextManager(self.weights_dtype)
        fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
        import transformer_engine.pytorch as te

        autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
        stack = ExitStack()
        stack.enter_context(dtype_ctx)
        # enable an outer fallback autocast for operations that do not support fp8
        stack.enter_context(fallback_autocast_ctx)
        stack.enter_context(autocast_ctx)
        return stack
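
    # floating-point inputs are cast to the weights dtype on the way in, and outputs are cast back to the
    # default dtype so code outside the plugin sees tensors in the precision it expects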
    @override
    def convert_input(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)

    @override
    def convert_output(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())


def _convert_layers(module: torch.nn.Module) -> None:
    import transformer_engine.pytorch as te

    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            if child.in_features % 8 != 0 or child.out_features % 16 != 0:
                # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
                rank_zero_warn(
                    "Support for FP8 in the linear layers with this plugin is currently limited to"
                    " tensors with shapes where the dimensions are divisible by 8 and 16 respectively."
                    f" The layer {name!r} does not meet these criteria. You might want to add padding to your inputs."
                )
                continue
            has_bias = child.bias is not None
            replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
            replacement.weight.data = child.weight.data.clone()
            if has_bias:
                replacement.bias.data = child.bias.data.clone()
            log.debug(f"Replacing layer {name!r} with its Transformer Engine equivalent")
            module.__setattr__(name, replacement)
        elif isinstance(child, torch.nn.LayerNorm):
            replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
            replacement.weight.data = child.weight.data.clone()
            replacement.bias.data = child.bias.data.clone()
            log.debug(f"Replacing layer {name!r} with its Transformer Engine equivalent")
            module.__setattr__(name, replacement)
        else:
            # there are other Transformer Engine layers that we could convert, but they require fusion. full list at:
            # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html
            _convert_layers(child)