-
Notifications
You must be signed in to change notification settings - Fork 234
/
fp8.py
672 lines (572 loc) · 24 KB
/
fp8.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 utilities for TransformerEngine"""
import os
from contextlib import contextmanager
from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch
import transformer_engine_extensions as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser
# Public API of this module: only `fp8_autocast` is exported by `import *`.
__all__ = ["fp8_autocast"]
def check_fp8_support() -> Tuple[bool, str]:
    """Return whether FP8 execution is supported in the current environment.

    Returns:
        ``(True, "")`` when FP8 is supported, otherwise ``(False, reason)``
        where ``reason`` names the requirement that is not met.
    """
    compute_capability = get_device_compute_capability()
    if compute_capability >= 9.0:  # hopper and above
        return True, ""
    if compute_capability < 8.9:  # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    # Ada additionally needs a recent cuBLASLt and CUDA toolkit.
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    # Compare (major, minor) as integers: float() misorders versions such as
    # "12.10" vs "12.9", and non-CUDA torch builds report the version as None.
    cuda_version = torch.version.cuda
    if cuda_version is None or tuple(
        int(part) for part in cuda_version.split(".")[:2]
    ) < (12, 1):
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""
def get_default_fp8_recipe() -> DelayedScaling:
    """Build the recipe used when the caller does not provide one.

    The returned object is a ``DelayedScaling`` instance constructed with
    all of its own default field values.
    """
    default_recipe = DelayedScaling()
    return default_recipe
def get_fp8_te_dtype(
    fp8_recipe: DelayedScaling, fprop_tensor: bool = True
) -> tex.DType:
    """Select the TE FP8 dtype for a tensor under the given recipe.

    E4M3 is used when the recipe format is E4M3, or when it is HYBRID and
    the tensor belongs to forward propagation; every other case gets E5M2.
    """
    recipe_format = fp8_recipe.fp8_format
    wants_e4m3 = recipe_format == Format.E4M3 or (
        recipe_format == Format.HYBRID and fprop_tensor
    )
    return tex.DType.kFloat8E4M3 if wants_e4m3 else tex.DType.kFloat8E5M2
class FP8GlobalStateManager:
    """Class to keep track of and manipulate the global
    FP8 state at different stages of execution.

    All state is stored in class attributes and every method is a
    ``classmethod`` or ``staticmethod``, so the state is shared by every
    caller in the process; the class is not meant to be instantiated.
    """

    # Autocast configuration; snapshotted and restored as one tuple by
    # get_fp8_autocast_state / set_fp8_autocast_state.
    FP8_ENABLED = False
    FP8_CALIBRATION = False
    FP8_RECIPE = None
    FP8_DISTRIBUTED_GROUP = None
    IS_FIRST_FP8_MODULE = False
    # Incremented on every outermost autocast entry (see fp8_autocast_enter).
    FP8_AUTOCAST_COUNTER = 0
    FP8_CURRENT_CONTEXT_ID = 0
    # Current nesting level of fp8_autocast regions.
    FP8_AUTOCAST_DEPTH = 0
    # Amax tensors keyed by "FWD_AMAX_<id>" / "BWD_AMAX_<id>" (see get_amax_buffer_key).
    global_fp8_buffer = {}
    # One deque per module of stashed [amax_history, scale, scale_inv] copies.
    fp8_tensors_recompute_buffer = []
    # Callable invoked when the outermost autocast region exits.
    amax_forward_global_reduce_func = None
    # Buffer keys scheduled for deletion (see set_amax_buffer_key_deletion).
    buffer_delete_key_fwd = None
    buffer_delete_key_bwd = None
    # Value returned by amax_forward_global_reduce_func on the last outermost exit.
    amax_reduce_handle_fwd = None
    # Cached result of check_fp8_support(); None until first queried.
    fp8_available = None
    reason_for_no_fp8 = ""
    # DP amax reduction period (NVTE_DP_AMAX_REDUCE_INTERVAL) and per-direction counters.
    dp_amax_reduce_interval = None
    dp_amax_reduce_forward_idx = 0
    dp_amax_reduce_backward_idx = 0

    @classmethod
    def is_fp8_available(cls) -> Tuple[bool, str]:
        """Return if fp8 support is available (result is cached after the first call)."""
        if cls.fp8_available is None:
            cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
        return cls.fp8_available, cls.reason_for_no_fp8

    @classmethod
    def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]:
        """Returns global fp8 state variables."""
        # Convert attributes to dictionary to make future proof against
        # changes in global state variables in order to make setting the
        # checkpoint backwards compatible.
        return {
            "FP8_AUTOCAST_COUNTER": cls.FP8_AUTOCAST_COUNTER,
            "FP8_CURRENT_CONTEXT_ID": cls.FP8_CURRENT_CONTEXT_ID,
            "FP8_AUTOCAST_DEPTH": cls.FP8_AUTOCAST_DEPTH,
            "buffer_delete_key_fwd": cls.buffer_delete_key_fwd,
            "buffer_delete_key_bwd": cls.buffer_delete_key_bwd,
            "dp_amax_reduce_interval": cls.dp_amax_reduce_interval,
            "dp_amax_reduce_forward_idx": cls.dp_amax_reduce_forward_idx,
            "dp_amax_reduce_backward_idx": cls.dp_amax_reduce_backward_idx,
        }

    @classmethod
    def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> None:
        """Sets global fp8 state variables.

        Keys that do not name an existing class attribute are ignored, so
        checkpoints written by other versions can still be loaded.
        """
        for k, v in state.items():
            if hasattr(cls, k):
                setattr(cls, k, v)

    @classmethod
    def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]:
        """Returns global fp8 amax buffer."""
        return cls.global_fp8_buffer

    @classmethod
    def set_global_fp8_buffer_checkpoint(cls, buffer: Dict[str, List[torch.Tensor]]) -> None:
        """Sets global fp8 amax buffer.

        Note: ``buffer`` is updated in place (its tensors are replaced by
        ``.cuda()`` copies) and then installed as the global buffer.
        """
        # Map all tensors back to GPU.
        for k, v in buffer.items():
            buffer[k] = [tensor.cuda() for tensor in v]

        cls.global_fp8_buffer = buffer

    @staticmethod
    def get_meta_tensor_key(forward: bool = True) -> str:
        """Returns scaling key in `fp8_meta`."""
        if forward:
            return "scaling_fwd"
        return "scaling_bwd"

    @staticmethod
    def get_buffer_position_key(forward: bool = True) -> str:
        """Returns module position key in `fp8_meta`."""
        if forward:
            return "global_fp8_buffer_pos_fwd"
        return "global_fp8_buffer_pos_bwd"

    @staticmethod
    def get_autocast_key(forward: bool = True) -> str:
        """Returns autocast ID key in `fp8_meta`."""
        if forward:
            return "autocast_id_fwd"
        return "autocast_id_bwd"

    @staticmethod
    def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
        """Return a key in `_global_fp8_buffer` for the AMAX storage."""
        if forward:
            return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
        return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"

    @classmethod
    def get_amax_reduce_handle_fwd(cls) -> Any:
        """Return AMAX reduction wait handle of forward prop.

        This is whatever the registered forward reduce function returned on
        the last outermost autocast exit, or None if none was recorded.
        """
        return cls.amax_reduce_handle_fwd

    @classmethod
    def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None:
        """Sets up the function to call during autocast exit."""
        cls.amax_forward_global_reduce_func = f

    @classmethod
    def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None:
        """Append 1D tensor `amax` to global buffer.

        The module's position in the buffer list is recorded in `fp8_meta` on
        first use; a mismatch on later calls means the module ran more than
        once in the same autocast region, which is rejected.
        """
        buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
        fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
        buffer_position_key = cls.get_buffer_position_key(forward=forward)

        cls.global_fp8_buffer.setdefault(buffer_key, []).append(
            fp8_meta[fp8_meta_tensor_key].amax_history[0]
        )

        if buffer_position_key not in fp8_meta:
            fp8_meta[buffer_position_key] = len(cls.global_fp8_buffer[buffer_key]) - 1

        # Catch incorrect fp8_autocast usage.
        assert fp8_meta[buffer_position_key] == len(cls.global_fp8_buffer[buffer_key]) - 1, \
            "Same module is being invoked more than once inside an `fp8_autocast` " \
            "region when using FP8 with amax reduction. This behavior is currently" \
            " unsupported. For more details and correct usage, please see " \
            "https://github.com/NVIDIA/TransformerEngine/pull/93."

    @classmethod
    def copy_amax_from_global_buffer(
        cls, fp8_meta: Dict[str, Any], forward: bool = True
    ) -> None:
        """Populate current amax with the correct location from buffer."""
        fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
        buffer_position_key = cls.get_buffer_position_key(forward=forward)
        # Module never registered an amax in the buffer: nothing to copy.
        if buffer_position_key not in fp8_meta:
            return
        amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
        assert amax_buffer_key in cls.global_fp8_buffer, "TE internal error."

        fp8_meta[fp8_meta_tensor_key].amax_history[0] = cls.global_fp8_buffer[amax_buffer_key][
            fp8_meta[buffer_position_key]
        ]

    @classmethod
    def set_amax_buffer_key_deletion(
        cls, fp8_meta: Dict[str, Any], forward: bool = True
    ) -> None:
        """Delete this amax key from global buffer during autocast end."""
        if cls.get_autocast_key(forward=forward) not in fp8_meta:
            return
        if forward:
            cls.buffer_delete_key_fwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
        else:
            cls.buffer_delete_key_bwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)

    @classmethod
    def delete_key_from_amax_buffer(cls, forward: bool = True) -> None:
        """Delete the scheduled key (if any) from the global amax buffer."""
        key = cls.buffer_delete_key_fwd if forward else cls.buffer_delete_key_bwd
        if key is not None:
            # pop with a default: the key may already have been removed.
            cls.global_fp8_buffer.pop(key, None)

    @classmethod
    def get_fp8_context_id(cls) -> int:
        """Returns an ID for the current FP8 context."""
        return cls.FP8_CURRENT_CONTEXT_ID

    @classmethod
    def set_fp8_context_id(cls, ctx_id: int) -> None:
        """Sets the current FP8 context."""
        cls.FP8_CURRENT_CONTEXT_ID = ctx_id

    @classmethod
    def new_fp8_context_id(cls) -> int:
        """Returns global autocast counter as a proxy to be used
        as the autocast ID for FP8 modules.
        """
        return cls.FP8_AUTOCAST_COUNTER

    @classmethod
    def is_fp8_enabled(cls) -> bool:
        """Is FP8 enabled"""
        return cls.FP8_ENABLED

    @classmethod
    def is_fp8_calibration(cls) -> bool:
        """Is FP8 calibration"""
        return cls.FP8_CALIBRATION

    @classmethod
    def is_first_fp8_module(cls) -> bool:
        """Returns `True` only the first time when called multiple
        times from within the same `fp8_autocast` context.
        """
        tmp = cls.IS_FIRST_FP8_MODULE
        cls.IS_FIRST_FP8_MODULE = False
        return tmp

    @classmethod
    def get_fp8_recipe(cls) -> DelayedScaling:
        """Return the fp8 recipe"""
        return cls.FP8_RECIPE

    @classmethod
    def get_fp8_group(cls) -> Union[dist_group_type, None]:
        """Return the fp8 group for scale/amax comm"""
        return cls.FP8_DISTRIBUTED_GROUP

    @classmethod
    def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_type, bool]:
        """FP8 autocast state getter"""
        return (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE)

    @classmethod
    def set_fp8_autocast_state(
        cls,
        fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool]
    ) -> None:
        """FP8 autocast state setter"""
        (cls.FP8_ENABLED,
         cls.FP8_CALIBRATION,
         cls.FP8_RECIPE,
         cls.FP8_DISTRIBUTED_GROUP,
         cls.IS_FIRST_FP8_MODULE) = fp8_state

    @staticmethod
    def reduce_tensor_across_group_op_max(
        tensor: torch.Tensor, group: dist_group_type, async_op: bool
    ) -> Any:
        """Reduce tensor across given group (in place, MAX).

        Returns the result of ``torch.distributed.all_reduce`` (a work handle
        when ``async_op`` is true), or None if torch.distributed is not
        initialized.
        """
        if torch.distributed.is_initialized():
            wait_handle = torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=group,
                async_op=async_op,
            )
            return wait_handle
        return None

    @classmethod
    def global_amax_reduction(
        cls,
        fp8_meta: Dict[str, Any],
        tp_group: dist_group_type,
        tp_size: int,
        forward: bool = True,
    ) -> Any:
        """Concatenate, reduce, and split amaxes in the global buffer.

        Per direction, every ``dp_amax_reduce_interval``-th call (starting
        with the first) reduces over ``fp8_meta["fp8_group"]``; the remaining
        calls reduce over ``tp_group`` only, and are skipped when
        ``tp_size <= 1``.

        Returns:
            The handle from `reduce_tensor_across_group_op_max`, or None when
            the buffer key is gone or no reduction is needed.

        Raises:
            ValueError: if ``NVTE_DP_AMAX_REDUCE_INTERVAL`` is not positive.
        """
        amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
        # Key already deleted.
        if amax_buffer_key not in cls.global_fp8_buffer:
            return None

        # Reduce AMAX in DP-domain at an interval.
        if cls.dp_amax_reduce_interval is None:
            interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
            # A non-positive interval would make the modulo below divide by zero.
            if interval < 1:
                raise ValueError(
                    "NVTE_DP_AMAX_REDUCE_INTERVAL must be a positive integer, "
                    f"got {interval}."
                )
            cls.dp_amax_reduce_interval = interval

        tp_amax_reduce = False
        if forward:
            if cls.dp_amax_reduce_forward_idx == 0:
                reduce_group = fp8_meta["fp8_group"]
            else:
                tp_amax_reduce = True
            cls.dp_amax_reduce_forward_idx = (
                (cls.dp_amax_reduce_forward_idx + 1) % cls.dp_amax_reduce_interval)
        else:
            if cls.dp_amax_reduce_backward_idx == 0:
                reduce_group = fp8_meta["fp8_group"]
            else:
                tp_amax_reduce = True
            cls.dp_amax_reduce_backward_idx = (
                (cls.dp_amax_reduce_backward_idx + 1) % cls.dp_amax_reduce_interval)

        if tp_amax_reduce:
            if tp_size > 1:
                reduce_group = tp_group
            else:
                return None

        # Reduce all module amaxes in one collective, then split back per module.
        chunk_sizes = [x.numel() for x in cls.global_fp8_buffer[amax_buffer_key]]
        contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key])

        wait_handle = cls.reduce_tensor_across_group_op_max(
            contiguous_amax,
            reduce_group,
            fp8_meta["async_amax_reduction"],
        )

        cls.global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
        return wait_handle

    @classmethod
    def fp8_autocast_enter(
        cls,
        enabled: bool = False,
        calibrating: bool = False,
        fp8_recipe: Optional[DelayedScaling] = None,
        fp8_group: Optional[dist_group_type] = None,
    ) -> None:
        """Set state and tracking variables for entry into FP8 region."""
        cls.FP8_ENABLED = enabled
        cls.FP8_CALIBRATION = calibrating
        cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
        cls.FP8_DISTRIBUTED_GROUP = fp8_group

        # Only the outermost region starts a new autocast ID.
        if cls.FP8_AUTOCAST_DEPTH == 0:
            cls.IS_FIRST_FP8_MODULE = True
            cls.FP8_AUTOCAST_COUNTER += 1
        cls.FP8_AUTOCAST_DEPTH += 1

        if enabled:
            fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
            assert fp8_available, reason_for_no_fp8

    @classmethod
    def fp8_autocast_exit(cls) -> None:
        """Set state and tracking variables for exit from FP8 region."""
        cls.FP8_AUTOCAST_DEPTH -= 1

        # Forward amax reduction and buffer cleanup run only on the outermost exit.
        if cls.FP8_AUTOCAST_DEPTH == 0:
            if callable(cls.amax_forward_global_reduce_func):
                cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable
            cls.delete_key_from_amax_buffer(forward=True)

    @classmethod
    def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Copy the scaling factors and amaxes for recompute forward phase
        to ensure both forward steps are numerically same.
        """
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"

        to_copy = [
            fp8_meta["scaling_fwd"].amax_history.clone(),
            fp8_meta["scaling_fwd"].scale.clone(),
            fp8_meta["scaling_fwd"].scale_inv.clone(),
        ]

        if buffer_position_key in fp8_meta:
            cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
        else:
            # First stash for this module: give it its own deque slot.
            if len(cls.fp8_tensors_recompute_buffer) == 0:
                cls.fp8_tensors_recompute_buffer = [deque()]
            else:
                cls.fp8_tensors_recompute_buffer.append(deque())
            cls.fp8_tensors_recompute_buffer[-1].append(to_copy)
            fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1

    @classmethod
    def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Switch to the copied scaling factors and amaxes from phase
        1 forward for identical numerical outputs.
        """

        # Store updated amaxes and scales from phase 1 post forward.
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
        fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv

        # Retrieve stashed amaxes and scales from phase 1 pre forward.
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
        stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[
            fp8_meta[buffer_position_key]
        ].popleft()

        # Replace amaxes and scales with stashed values for phase 2 forward
        fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
        fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
        fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]

    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
        fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
        fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
        fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
@contextmanager
def fp8_autocast(
    enabled: bool = False,
    calibrating: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    fp8_group: Optional[dist_group_type] = None,
) -> None:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
        with shapes where both dimensions are divisible by 16. In terms of the input to the full
        Transformer network, this typically requires padding sequence length to be multiple of 16.

    .. note::

        When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
        inside a single `fp8_autocast` region. This is unsupported behavior because the amax
        reduction is handled during the exit of the `fp8_autocast` context. Calling the same
        module more than once inside an `fp8_autocast` region overrides the amax tensors
        before reduction can occur.

    On exit (normal or via an exception) the autocast state that was active
    before entering is restored, so regions may be nested.

    Parameters
    ----------
    enabled: bool, default = `False`
             whether or not to enable fp8
    calibrating: bool, default = `False`
                 calibration mode allows collecting statistics such as amax and scale
                 data of fp8 tensors even when executing without fp8 enabled. This is
                 useful for saving an inference ready fp8 checkpoint while training
                 using a higher precision.
    fp8_recipe: recipe.DelayedScaling, default = `None`
                recipe used for FP8 training.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
               are reduced at the end of each training step.
    """
    # Snapshot the enclosing state before `try` so `finally` can always
    # restore it; previously a failure here left `fp8_state` unbound.
    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
    try:
        FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group)
        yield
    finally:
        # fp8_autocast_enter increments the depth before it can fail, so the
        # matching exit is always needed here.
        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
        FP8GlobalStateManager.fp8_autocast_exit()
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
    """Advance the amax history window and clear the slot for the next amax.

    With more than one row, the history is rolled one step towards index 0
    (a new tensor is returned); with a single row it is reused as-is. In
    both cases row 0 of the returned tensor is then zeroed in place.
    """
    rolled = amax_history
    if amax_history.shape[0] > 1:
        rolled = torch.roll(amax_history, shifts=-1, dims=0)
    rolled[0].fill_(0.0)
    return rolled
@jit_fuser
def _default_get_amax(
    amax_history: torch.Tensor,
    amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Extract an amax from the history, then advance the history.

    ``"max"`` reduces the history with an element-wise max over dim 0; any
    other value (expected to be ``"most_recent"``) copies row 0. The amax is
    taken before the history is updated by ``_update_amax_history``.
    """
    if amax_compute_algo != "max":  # "most_recent"
        amax = amax_history[0].clone()
    else:
        amax = torch.max(amax_history, dim=0).values
    return _update_amax_history(amax_history), amax
@jit_fuser
def _default_sf_compute(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
) -> torch.Tensor:
    """Default function to convert amax to scaling factor.

    The new scale is the power of two
    ``2 ** (floor(log2(fp8_max / amax)) - margin)``. Wherever ``amax`` is not
    positive or not finite, the previous ``scale`` is kept unchanged.
    """
    exp = torch.floor(torch.log2(fp8_max / amax)) - margin
    sf = torch.round(torch.pow(2, torch.abs(exp)))
    # Invert for negative exponents *before* applying the fallbacks below.
    # Doing it afterwards turned the preserved `scale` into `1 / scale` for an
    # infinite amax, because then exp == -inf < 0.
    sf = torch.where(exp < 0, 1 / sf, sf)
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    return sf
@jit_fuser
def _compute_scaling_factor_inverse(
    scale: torch.Tensor,
    scale_inv: torch.Tensor,
    non_weight_mask: torch.Tensor,
    update_weight_scale_inv: bool,
) -> torch.Tensor:
    """Return ``1 / scale``, optionally freezing the weight entries.

    When ``update_weight_scale_inv`` is false, positions where
    ``non_weight_mask`` is false keep their current ``scale_inv`` value.
    """
    reciprocal = 1.0 / scale
    if not update_weight_scale_inv:
        reciprocal = torch.where(non_weight_mask, reciprocal, scale_inv)
    return reciprocal
@jit_fuser
def _fused_amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    scale_inv: torch.Tensor,
    fp8_max: float,
    margin: int,
    amax_compute_algo: str,
    non_weight_mask: torch.Tensor,
    update_weight_scale_inv: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the default amax -> scale -> scale_inv pipeline in one fused call.

    Returns:
        ``(new_amax_history, new_scale, new_scale_inv)``.
    """
    new_history, amax = _default_get_amax(amax_history, amax_compute_algo)
    new_scale = _default_sf_compute(amax, scale, fp8_max, margin)
    new_scale_inv = _compute_scaling_factor_inverse(
        new_scale, scale_inv, non_weight_mask, update_weight_scale_inv
    )
    return new_history, new_scale, new_scale_inv
def _compute_amax(
    amax_history: torch.Tensor,
    recipe: DelayedScaling,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the amax according to ``recipe.amax_compute_algo``.

    A callable algorithm is applied to the history before the history is
    advanced; a string algorithm is delegated to ``_default_get_amax``.

    Returns:
        ``(updated_amax_history, amax)``.
    """
    algo = recipe.amax_compute_algo
    if not callable(algo):
        return _default_get_amax(amax_history, algo)
    amax = algo(amax_history)
    return _update_amax_history(amax_history), amax
def _compute_scaling_factor(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> torch.Tensor:
    """Turn an amax into a scaling factor.

    Uses ``recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)``
    when the recipe supplies one, otherwise ``_default_sf_compute`` with
    ``recipe.margin``.
    """
    custom_sf_algo = recipe.scaling_factor_compute_algo
    if custom_sf_algo is not None:
        return custom_sf_algo(amax, scale, fp8_max, recipe)
    return _default_sf_compute(amax, scale, fp8_max, recipe.margin)
def amax_and_scale_update(
    fp8_meta: Dict[str, Any],
    fwd_update: bool,
    update_weight_scale_inv: bool = True,
) -> None:
    """Refresh the amax history, scale and scale_inv stored in ``fp8_meta``.

    Args:
        fp8_meta: FP8 metadata dict. Reads ``"recipe"``, the
            ``"scaling_fwd"``/``"scaling_bwd"`` object (its ``amax_history``,
            ``scale`` and ``scale_inv`` attributes are replaced), the matching
            ``"fp8_max_fwd"``/``"fp8_max_bwd"`` value and the
            ``"<scaling key>_non_weight_mask"`` entry.
        fwd_update: update the forward (True) or backward (False) meta.
        update_weight_scale_inv: when False, scale_inv entries where the
            non-weight mask is false keep their previous values.

    The fused path is taken only when the recipe uses a built-in amax
    algorithm (not callable) and no custom scaling-factor algorithm.
    """
    recipe = fp8_meta["recipe"]
    use_fused_path = (
        not callable(recipe.amax_compute_algo)
        and recipe.scaling_factor_compute_algo is None
    )
    meta_key = "scaling_fwd" if fwd_update else "scaling_bwd"
    max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
    meta = fp8_meta[meta_key]

    if use_fused_path:
        meta.amax_history, meta.scale, meta.scale_inv = _fused_amax_and_scale_update(
            meta.amax_history,
            meta.scale,
            meta.scale_inv,
            fp8_meta[max_key],
            recipe.margin,
            recipe.amax_compute_algo,
            fp8_meta[meta_key + "_non_weight_mask"],
            update_weight_scale_inv,
        )
        return

    # Unfused path: custom amax and/or scaling-factor algorithms.
    meta.amax_history, amax = _compute_amax(meta.amax_history, recipe)
    meta.scale = _compute_scaling_factor(amax, meta.scale, fp8_meta[max_key], recipe)
    meta.scale_inv = _compute_scaling_factor_inverse(
        meta.scale,
        meta.scale_inv,
        fp8_meta[meta_key + "_non_weight_mask"],
        update_weight_scale_inv,
    )