TE: Support v1.8 - current main (#515)
kshitij12345 authored Jun 4, 2024
1 parent 293a228 commit 77d0fbd
Showing 1 changed file with 57 additions and 7 deletions.
64 changes: 57 additions & 7 deletions thunder/executors/transformer_engineex.py
@@ -37,6 +37,7 @@
# between version 1.2 and 1.3.
# Hence, we have these guards based on version.
TE_VERSION_1_6_PLUS: bool = False
TE_VERSION_1_8_PLUS: bool = False

te: None | Any = None
if TE_AVAILABLE:
@@ -48,11 +49,13 @@
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import check_dim_for_fp8_exec
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
import transformer_engine_extensions as tex
except Exception as ex:
warnings.warn(f"transformer_engine failed to import with exception {ex}")
TE_AVAILABLE = False

TE_VERSION_1_6_PLUS = LooseVersion(version("transformer_engine")) > LooseVersion("1.6")
TE_VERSION_1_8_PLUS = LooseVersion(version("transformer_engine")) > LooseVersion("1.8")
if not TE_VERSION_1_6_PLUS:
warnings.warn(
f"Installed version of transformer_engine {version('transformer_engine')} is not supported, please upgrade. `transformer_engine_ex` will not be used."
@@ -166,8 +169,9 @@ def __init__(self, in_features: int, out_features: int) -> None:
if FP8GlobalStateManager.with_fp8_parameters():
raise RuntimeError("Primary weights in FP8 is not supported under `thunder.jit`.")

# Required by `get_fp8_weights_scratchpad`
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
if not TE_VERSION_1_8_PLUS:
# Required by `get_fp8_weights_scratchpad`
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

# NOTE: Backward FP8 metadata sync
# TransformerEngine v1.6 onwards, we control the sync and update of FP8 metadata for FP8 tensors
@@ -189,8 +193,10 @@ def forward(self, inp, weight, bias, is_first_microbatch: bool | None = None, is
assert (
self.fp8 or not self.primary_weights_in_fp8
), "Need to run inside fp8_autocast region when weights are stored in FP8."
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)

weight_fp8, weight_t_fp8 = self.get_fp8_weight_version_compat(
weight=weight, is_first_microbatch=is_first_microbatch, is_grad_enabled=is_grad_enabled
)

ctx = Context() if is_grad_enabled else None

@@ -204,8 +210,8 @@ def forward(self, inp, weight, bias, is_first_microbatch: bool | None = None, is
kwargs = {
"ctx": ctx,
"weight": weight,
"weight_fp8": weight1_fp8,
"weight_t_fp8": weight1_t_fp8,
"weight_fp8": weight_fp8,
"weight_t_fp8": weight_t_fp8,
"inp": inp,
"bias": torch.tensor([]) if not use_bias else bias,
"use_bias": bias is not None,
@@ -240,10 +246,48 @@ def forward(self, inp, weight, bias, is_first_microbatch: bool | None = None, is
else:
kwargs[param_name] = None

# Remove kwargs if they are not used in the current version.
unused_kwargs = set(kwargs.keys()) - set(params)
if TE_VERSION_1_8_PLUS:
# Since v1.8 onwards, these args are not part of the `_Linear` API.
assert unused_kwargs == {"skip_fp8_weight_update", "primary_weights_in_fp8", "weight_t_fp8"}

for unused_kwarg in unused_kwargs:
kwargs.pop(unused_kwarg)

out = _Linear.forward(**kwargs)
ctx_dict = ctx.to_dict() if is_grad_enabled else None
return out, ctx_dict

def get_fp8_weight_version_compat(self, weight, is_first_microbatch, is_grad_enabled):
weight_t_fp8: torch.Tensor = None
weight_fp8: torch.Tensor = None
# Fetch the fp8 weights placeholders (for linear/gemm)
if not TE_VERSION_1_8_PLUS:
weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)
else:
# Initialize FP8 weights workspace if needed

# FP8 cast to workspace buffer
with_transpose = is_grad_enabled
update_workspace = is_first_microbatch is None or is_first_microbatch
skip_fp8_weight_update = None

weight_fp8 = self.get_fp8_workspace(
tensor=weight,
fp8_meta_forward=True,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
cache_name=(None if is_first_microbatch is None else "weight"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
with_transpose=with_transpose,
)

return weight_fp8, weight_t_fp8

# This method is used for supporting TE v1.6 and v1.7.
# From v1.8 onwards, the implementation has moved to `TransformerEngineBaseModule`.
# See `get_fp8_workspace`: https://github.com/NVIDIA/TransformerEngine/blob/8b210490b3f46cd409df0ba6a8f4b14273f2975c/transformer_engine/pytorch/module/base.py#L753-L754
def get_fp8_weights_scratchpad(
self,
is_first_microbatch: bool | None,
@@ -305,7 +349,13 @@ def _te_functional_linear_backward_impl(
# https://github.com/NVIDIA/TransformerEngine/blob/b957aa475bcbcf22405381d18bd7fefe4fb6b171/transformer_engine/pytorch/module/linear.py#L434
with enable_grad(ctx.saved_tensors[2]):
grads = _Linear.backward(ctx, g)
grad_inputs = (grads[3], grads[0], grads[4])

# Due to a difference in the `_Linear.forward` API, the position of
# the returned grads has changed.
if TE_VERSION_1_8_PLUS:
grad_inputs = (grads[2], grads[0], grads[3])
else:
grad_inputs = (grads[3], grads[0], grads[4])
return grad_inputs


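The whole patch pivots on the two module-level flags computed near the top of the diff. A minimal standalone sketch of that version-flag pattern, assuming `LooseVersion` comes from the `looseversion` package and `version` from `importlib.metadata` (the actual imports sit above the diff context shown here):

from importlib.metadata import version
from looseversion import LooseVersion

def te_version_at_least(minimum: str) -> bool:
    # Mirrors the guard pattern in the diff: compare the installed
    # transformer_engine release against a threshold once, at import time.
    return LooseVersion(version("transformer_engine")) > LooseVersion(minimum)

TE_VERSION_1_6_PLUS = te_version_at_least("1.6")
TE_VERSION_1_8_PLUS = te_version_at_least("1.8")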
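The `unused_kwargs` block in `forward` keeps one call site working across TE releases by dropping keyword arguments the installed `_Linear.forward` no longer accepts. In the commit, `params` is built from code outside this diff; the sketch below is a hypothetical, self-contained version that derives the accepted names directly from the target function's signature:

import inspect

def filter_kwargs_for(fn, kwargs: dict) -> dict:
    # Keep only the keyword arguments that `fn` actually accepts, so the same
    # wrapper can drive several library versions whose signatures differ.
    accepted = set(inspect.signature(fn).parameters)
    return {name: value for name, value in kwargs.items() if name in accepted}

# On TE v1.8+, filtering against `_Linear.forward` would drop `weight_t_fp8`,
# `primary_weights_in_fp8`, and `skip_fp8_weight_update`, matching the assert above.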
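The index change in `_te_functional_linear_backward_impl` follows from the kwargs built in the forward wrapper: `_Linear.backward` returns one gradient per forward input (after `ctx`), in order, and v1.8 removed `weight_t_fp8` from that signature. A tiny self-contained check of the index arithmetic, listing only the leading inputs that matter here:

PRE_1_8_INPUTS = ("weight", "weight_fp8", "weight_t_fp8", "inp", "bias")
POST_1_8_INPUTS = ("weight", "weight_fp8", "inp", "bias")

def grad_index(inputs: tuple, name: str) -> int:
    # Gradients come back in the same order as the forward inputs, so dropping
    # weight_t_fp8 shifts everything after it down by one slot.
    return inputs.index(name)

assert (grad_index(PRE_1_8_INPUTS, "inp"), grad_index(PRE_1_8_INPUTS, "bias")) == (3, 4)
assert (grad_index(POST_1_8_INPUTS, "inp"), grad_index(POST_1_8_INPUTS, "bias")) == (2, 3)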
