diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 9dac1dead..43a0d4152 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -242,11 +242,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
         bw_extrace = sort_waits(bw_extrace)
 
     # Importing here to avoid cyclical dependencies in future.
-    from thunder.executors.transformer_engineex import _rearrange_transformer_engine_linear, transformer_engine_ex
+    from thunder.executors.transformer_engineex import _transformer_engine_bwd_fp8_meta_sync, transformer_engine_ex
 
     if transformer_engine_ex in compile_data.executors_list:
-        # NOTE: `_rearrange_transformer_engine_linear` mutates `fw_extrace`.
-        _rearrange_transformer_engine_linear(fw_extrace, bw_extrace)
+        # NOTE: `_transformer_engine_bwd_fp8_meta_sync` may mutate `fw_extrace` or `bw_extrace`.
+        _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace)
 
     fw_extrace = del_last_used(fw_extrace)
     fw_traces.append(fw_extrace)
diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py
index 8e94ef03b..42aa19bcc 100644
--- a/thunder/executors/transformer_engineex.py
+++ b/thunder/executors/transformer_engineex.py
@@ -164,6 +164,16 @@ def __init__(self, in_features: int, out_features: int) -> None:
         # Required by `get_fp8_weights_scratchpad`
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
 
+        if TE_VERSION_1_6_PLUS:
+            # NOTE: Backward FP8 metadata sync
+            # From TransformerEngine v1.6 onwards, we control the sync and update of FP8 metadata
+            # for FP8 tensors tied to the backward pass (i.e. the gradient tensors).
+            # Also, note that the forward tensor metadata sync occurs at the exit of the
+            # `fp8_autocast` context manager, which is not controlled by us.
+            #
+            # We consume the `is_first_fp8_module` token so that the automatic sync for FP8 metadata is disabled.
+            FP8GlobalStateManager.is_first_fp8_module()  # Consume first module token.
+
     def forward(self, inp, weight, bias, is_first_microbatch: bool | None = None, is_grad_enabled: bool = False):
         tensor_inputs = tuple(filter(lambda t: isinstance(t, torch.Tensor), (inp, weight, bias)))
         # See [NOTE] Enable grad within context
@@ -458,6 +468,36 @@ def _get_te_wrapper_string():
     return TE_CTX_STR
 
 
+def te_sync_fp8_meta_bwd_meta():
+    pass
+
+
+def te_sync_fp8_meta_bwd_impl():
+    FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+
+
+te_sync_fp8_meta_bwd = transformer_engine_ex.register_operator(
+    "te_sync_fp8_meta_bwd", meta=te_sync_fp8_meta_bwd_meta, fn=te_sync_fp8_meta_bwd_impl
+)
+
+
+def _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace):
+    if TE_VERSION_1_6_PLUS:
+        # See doc of `_insert_bwd_fp8_meta_sync` for more details.
+        _insert_bwd_fp8_meta_sync(bw_extrace)
+    else:
+        # See doc of `_rearrange_transformer_engine_linear` for more details.
+        _rearrange_transformer_engine_linear(fw_extrace, bw_extrace)
+
+
+def _insert_bwd_fp8_meta_sync(bw_extrace):
+    # This function inserts the symbol `te_sync_fp8_meta_bwd` at the end of the backward
+    # trace, which takes care of syncing and updating the FP8 metadata for backward tensors.
+    # See NOTE: Backward FP8 metadata sync
+    bwd_idx = len(bw_extrace.bound_symbols) - 1
+    bw_extrace.bound_symbols.insert(bwd_idx + 1, te_sync_fp8_meta_bwd.bind(output=None))
+
+
 def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace):
     """
     Rearrange the TransformerEngine linear symbols `te_linear_*` in forward trace
diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py
index 85266d709..f61edcba4 100644
--- a/thunder/tests/distributed/test_ddp.py
+++ b/thunder/tests/distributed/test_ddp.py
@@ -30,7 +30,13 @@
 from thunder.tests.framework import TorchExecutor, nvFuserExecutor
 from thunder.tests.framework import instantiate
 
-from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE
+from thunder.executors.transformer_engineex import (
+    transformer_engine_ex,
+    TE_AVAILABLE,
+    TE_VERSION_1_6_PLUS,
+    te_sync_fp8_meta_bwd,
+)
+
 
 is_fp8_supported: bool = False
 # This will be correctly updated below when TE Engine is installed
@@ -1501,21 +1507,31 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
         fwd_exec_trace = thunder.last_traces(jit_model)[-1]
         bwd_exec_trace = thunder.last_backward_traces(jit_model)[-1]
 
-        # Verify that the first te_linear in fwd_exec_trace is the
-        # last one in bwd_exec_tarce.
-        # We verify that by managing the `ctx` (CollectionProxy) output by `te_linear` which is
-        # passed to backward.
-        # As CollectionProxy don't implement __eq__, we verify them by name.
-        first_ctx_name = None
-        for bsym in fwd_exec_trace.bound_symbols:
-            if bsym.sym.name.startswith("te_linear"):
-                first_ctx_name = bsym.output[1].name
-                break
-
-        for bsym in reversed(bwd_exec_trace.bound_symbols):
-            if bsym.sym.name.startswith("te_functional"):
-                assert first_ctx_name == bsym.args[-1].name, (first_ctx_name, bsym.args[-1].name)
-                break
+        if TE_VERSION_1_6_PLUS:
+            # Verify that the symbol to sync backward
+            # FP8 metadata is present in the backward trace.
+            for bsym in reversed(bwd_exec_trace.bound_symbols):
+                if bsym.sym.id == te_sync_fp8_meta_bwd.id:
+                    break
+            else:
+                raise RuntimeError("Backward sync symbol not found.")
+
+        else:
+            # Verify that the first te_linear in fwd_exec_trace is the
+            # last one in bwd_exec_trace.
+            # We verify that by matching the `ctx` (CollectionProxy) output by `te_linear` which is
+            # passed to backward.
+            # As CollectionProxy doesn't implement __eq__, we verify them by name.
+            first_ctx_name = None
+            for bsym in fwd_exec_trace.bound_symbols:
+                if bsym.sym.name.startswith("te_linear"):
+                    first_ctx_name = bsym.output[1].name
+                    break
+
+            for bsym in reversed(bwd_exec_trace.bound_symbols):
+                if bsym.sym.name.startswith("te_functional"):
+                    assert first_ctx_name == bsym.args[-1].name, (first_ctx_name, bsym.args[-1].name)
+                    break
     except Exception as e:
         sanity_exceptions.append(e)
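Note: the `TE_VERSION_1_6_PLUS` gate used throughout this diff is defined elsewhere in `transformer_engineex.py` and is not shown here. A minimal sketch of how such a gate is commonly derived (an assumption, not necessarily the exact check used in the module):

```python
# Hypothetical sketch only; the real TE_VERSION_1_6_PLUS definition lives
# outside this diff and may be computed differently.
import transformer_engine as te
from packaging.version import Version

# True when the installed TransformerEngine is v1.6 or newer, i.e. when the
# executor must drive the backward FP8 metadata sync explicitly (see the
# `te_sync_fp8_meta_bwd` operator added above).
TE_VERSION_1_6_PLUS: bool = Version(te.__version__) >= Version("1.6")
```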
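The version-gated check added to `_test_ddp_transformer_engine_llama_sanity` relies on Python's `for ... else`: the `else` suite runs only when the loop finishes without hitting `break`, i.e. only when no bound symbol with the sync operator's id was found. A self-contained toy of the same pattern (the symbol names below are made up for illustration):

```python
# Toy for/else scan mirroring the sanity test: `else` fires only if no
# `break` did, i.e. only if the backward sync symbol is missing.
symbol_names = ["unpack_trivial", "te_functional_linear_backward", "te_sync_fp8_meta_bwd"]

for name in reversed(symbol_names):
    if name == "te_sync_fp8_meta_bwd":
        break
else:
    raise RuntimeError("Backward sync symbol not found.")
```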