transformer_engine: Control backward FP8 metadata synchronization (v1.6 onwards) #379

Merged: 7 commits, May 13, 2024
6 changes: 3 additions & 3 deletions thunder/executors/torch_autograd.py
@@ -202,11 +202,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
     bw_extrace = sort_waits(bw_extrace)
 
     # Importing here to avoid cyclical dependencies in future.
-    from thunder.executors.transformer_engineex import _rearrange_transformer_engine_linear, transformer_engine_ex
+    from thunder.executors.transformer_engineex import _transformer_engine_bwd_fp8_meta_sync, transformer_engine_ex
 
     if transformer_engine_ex in compile_data.executors_list:
-        # NOTE: `_rearrange_transformer_engine_linear` mutates `fw_extrace`.
-        _rearrange_transformer_engine_linear(fw_extrace, bw_extrace)
+        # NOTE: `_transformer_engine_bwd_fp8_meta_sync` may mutate `fw_extrace` or `bw_extrace`.
+        _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace)
 
     fw_extrace = del_last_used(fw_extrace)
     fw_traces.append(fw_extrace)
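The new `_transformer_engine_bwd_fp8_meta_sync` entry point dispatches on a `TE_VERSION_1_6_PLUS` flag (defined in the next file). As a minimal sketch, one plausible way to compute such a version gate — an illustrative assumption only; the actual definition lives in `thunder/executors/transformer_engineex.py` and may differ:

# Sketch only: one plausible way to derive the TE_VERSION_1_6_PLUS gate.
# Assumption: the installed "transformer_engine" distribution reports its
# version through importlib.metadata; the real code may read it differently.
from importlib.metadata import version

from packaging.version import Version

TE_VERSION_1_6_PLUS: bool = Version(version("transformer_engine")) >= Version("1.6")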
40 changes: 40 additions & 0 deletions thunder/executors/transformer_engineex.py
@@ -164,6 +164,16 @@ def __init__(self, in_features: int, out_features: int) -> None:
         # Required by `get_fp8_weights_scratchpad`
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
 
+        if TE_VERSION_1_6_PLUS:
+            # NOTE: Backward FP8 metadata sync
+            # From TransformerEngine v1.6 onwards, we control the sync and update of the FP8 metadata
+            # for FP8 tensors tied to the backward pass (i.e. the gradient tensors).
+            # Note that the forward tensor metadata sync still occurs at the exit of the
+            # `fp8_autocast` context manager, which is not controlled by us.
+            #
+            # We consume the `is_first_fp8_module` token so that the automatic sync of FP8 metadata is disabled.
+            FP8GlobalStateManager.is_first_fp8_module()  # Consume the first-module token.
+
     def forward(self, inp, weight, bias, is_first_microbatch: bool | None = None, is_grad_enabled: bool = False):
         tensor_inputs = tuple(filter(lambda t: isinstance(t, torch.Tensor), (inp, weight, bias)))
         # See [NOTE] Enable grad within context
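To make the NOTE above concrete, here is a minimal sketch of where the two syncs happen, assuming standard TransformerEngine usage (`model` and `inp` are illustrative placeholders):

import transformer_engine.pytorch as te

with te.fp8_autocast(enabled=True):
    out = model(inp)  # forward pass runs under FP8
# Exiting fp8_autocast syncs and updates the *forward* FP8 metadata automatically.

out.sum().backward()
# The *backward* (gradient) FP8 metadata is what this executor now updates
# explicitly, via the te_sync_fp8_meta_bwd symbol registered below.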
@@ -458,6 +468,36 @@ def _get_te_wrapper_string():
     return TE_CTX_STR
 
 
+def te_sync_fp8_meta_bwd_meta():
+    pass
+
+
+def te_sync_fp8_meta_bwd_impl():
+    FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)


+te_sync_fp8_meta_bwd = transformer_engine_ex.register_operator(
+    "te_sync_fp8_meta_bwd", meta=te_sync_fp8_meta_bwd_meta, fn=te_sync_fp8_meta_bwd_impl
+)
+
+
+def _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace):
+    if TE_VERSION_1_6_PLUS:
+        # See doc of `_insert_bwd_fp8_meta_sync` for more details.
+        _insert_bwd_fp8_meta_sync(bw_extrace)
+    else:
+        # See doc of `_rearrange_transformer_engine_linear` for more details.
+        _rearrange_transformer_engine_linear(fw_extrace, bw_extrace)


+def _insert_bwd_fp8_meta_sync(bw_extrace):
+    # This function inserts the symbol `te_sync_fp8_meta_bwd` at the end of the backward
+    # trace, which takes care of syncing and updating the FP8 metadata for backward tensors.
+    # See NOTE: Backward FP8 metadata sync
+    bwd_idx = len(bw_extrace.bound_symbols) - 1
+    bw_extrace.bound_symbols.insert(bwd_idx + 1, te_sync_fp8_meta_bwd.bind(output=None))


 def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace):
     """
     Rearrange the TransformerEngine linear symbols `te_linear_*` in forward trace
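Outside of a Thunder trace, the same backward-side sync can be triggered by hand. A minimal sketch under the same TE v1.6+ assumption (`loss` is an illustrative placeholder):

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

loss.backward()
# Reduce the amax histories across ranks and update the scaling factors of the
# FP8 tensors tied to the backward pass; forward=False selects that metadata.
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)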
51 changes: 35 additions & 16 deletions thunder/tests/distributed/test_ddp.py
@@ -30,7 +30,13 @@
 from thunder.tests.framework import TorchExecutor, nvFuserExecutor
 from thunder.tests.framework import instantiate
 
-from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE
+from thunder.executors.transformer_engineex import (
+    transformer_engine_ex,
+    TE_AVAILABLE,
+    TE_VERSION_1_6_PLUS,
+    te_sync_fp8_meta_bwd,
+)
+
 
 is_fp8_supported: bool = False
 # This will be correctly updated below when TE Engine is installed
@@ -1502,21 +1508,34 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
         fwd_exec_trace = thunder.last_traces(jit_model)[-1]
         bwd_exec_trace = thunder.last_backward_traces(jit_model)[-1]
 
-        # Verify that the first te_linear in fwd_exec_trace is the
-        # last one in bwd_exec_tarce.
-        # We verify that by managing the `ctx` (CollectionProxy) output by `te_linear` which is
-        # passed to backward.
-        # As CollectionProxy don't implement __eq__, we verify them by name.
-        first_ctx_name = None
-        for bsym in fwd_exec_trace.bound_symbols:
-            if bsym.sym.name.startswith("te_linear"):
-                first_ctx_name = bsym.output[1].name
-                break
-
-        for bsym in reversed(bwd_exec_trace.bound_symbols):
-            if bsym.sym.name.startswith("te_functional"):
-                assert first_ctx_name == bsym.args[-1].name, (first_ctx_name, bsym.args[-1].name)
-                break
+        if TE_VERSION_1_6_PLUS:
+            # Verify that the symbol that syncs the backward
+            # FP8 metadata is present in the backward trace.
+            found_bwd_sync_symbol = False
+            for bsym in reversed(bwd_exec_trace.bound_symbols):
+                if bsym.sym.id == te_sync_fp8_meta_bwd.id:
+                    found_bwd_sync_symbol = True
+                    break
+
+            if not found_bwd_sync_symbol:
+                raise RuntimeError("Backward sync symbol not found.")

+        else:
+            # Verify that the first te_linear in fwd_exec_trace is the
+            # last one in bwd_exec_trace.
+            # We verify that by matching the `ctx` (CollectionProxy) output by `te_linear`, which is
+            # passed to backward.
+            # As CollectionProxy doesn't implement __eq__, we compare them by name.
+            first_ctx_name = None
+            for bsym in fwd_exec_trace.bound_symbols:
+                if bsym.sym.name.startswith("te_linear"):
+                    first_ctx_name = bsym.output[1].name
+                    break
+
+            for bsym in reversed(bwd_exec_trace.bound_symbols):
+                if bsym.sym.name.startswith("te_functional"):
+                    assert first_ctx_name == bsym.args[-1].name, (first_ctx_name, bsym.args[-1].name)
+                    break
     except Exception as e:
         sanity_exceptions.append(e)
 
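For the pre-1.6 path, the else branch above pins down the contract of `_rearrange_transformer_engine_linear`: the `ctx` produced by the first `te_linear_*` in the forward trace must be consumed by the last `te_functional_*` in the backward trace, so that the module responsible for the automatic FP8 metadata sync executes its backward last. A condensed sketch of the same check (assuming the `fwd_exec_trace` and `bwd_exec_trace` objects from the test above):

first_ctx_name = next(
    bsym.output[1].name
    for bsym in fwd_exec_trace.bound_symbols
    if bsym.sym.name.startswith("te_linear")
)
last_ctx_name = next(
    bsym.args[-1].name
    for bsym in reversed(bwd_exec_trace.bound_symbols)
    if bsym.sym.name.startswith("te_functional")
)
assert first_ctx_name == last_ctx_name, (first_ctx_name, last_ctx_name)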