From e3d253aca6ea585ce86f37016ca876a5d6e58d98 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 6 May 2024 11:49:21 +0200 Subject: [PATCH 1/4] te: own fp8 sync comms --- thunder/benchmarks/test_benchmark_litgpt.py | 18 ++++---- thunder/executors/transformer_engineex.py | 49 +++++++++------------ 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/thunder/benchmarks/test_benchmark_litgpt.py b/thunder/benchmarks/test_benchmark_litgpt.py index f5b48cf1e..0a8e80cf1 100644 --- a/thunder/benchmarks/test_benchmark_litgpt.py +++ b/thunder/benchmarks/test_benchmark_litgpt.py @@ -16,6 +16,7 @@ from collections import defaultdict import os import subprocess +from subprocess import PIPE, Popen import json import pandas as pd from datetime import datetime @@ -133,7 +134,10 @@ def run_benchmark(self, kwargs): subprocess_cmd.extend(command_list) print(f'Running {" ".join(subprocess_cmd)!r}') - proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True) + # proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True) + with Popen(subprocess_cmd, stdout=PIPE, bufsize=1, universal_newlines=True) as proc_output: + for line in proc_output.stdout: + print(line, end="") # process line here self.perf_metrics_dict = {} if os.path.exists(self.json_file_path): @@ -220,17 +224,15 @@ def tearDownClass(cls): @parameterized.product( distributed_mode=("fsdp",), - shard_mode=("zero2",), - model_name=("Llama-2-7b-hf",), - micro_batch_size=( - 1, - 4, - ), + shard_mode=("zero3",), + model_name=("pythia-14m",), + micro_batch_size=(1,), compile=( "eager", - "inductor", + # "inductor", "thunder", "thunder_inductor", + "thunder_inductor_transformerengine", ), ) def test(self, **kwargs): diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index 2b96100b4..ae3034d01 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -220,6 +220,7 @@ def __init__(self, in_features: int, out_features: int) -> None: self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) def forward(self, inp, weight, bias, is_first_microbatch: bool | None = None, is_grad_enabled: bool = False): + FP8GlobalStateManager.is_first_fp8_module() # Consume first module. tensor_inputs = tuple(filter(lambda t: isinstance(t, torch.Tensor), (inp, weight, bias))) # See [NOTE] Enable grad within context # TE backward depends on `requires_grad` to compute grads. @@ -513,6 +514,19 @@ def _get_te_wrapper_string(): return TE_CTX_STR +def te_sync_fp8_meta_bwd_meta(): + pass + + +def te_sync_fp8_meta_bwd_impl(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + + +te_sync_fp8_meta_bwd = transformer_engine_ex.register_operator( + "te_sync_fp8_meta_bwd", meta=te_sync_fp8_meta_bwd_meta, fn=te_sync_fp8_meta_bwd_impl +) + + def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace): """ Rearrange the TransformerEngine linear symbols `te_linear_*` in forward trace @@ -567,35 +581,12 @@ def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace): (t6861, t6862, _) = te_functional_linear_backward(t6859, (i304, i305, i306), (i307, i308), None, ctx_te_1) """ # Get the ctx name for the last `te_functional_linear_backward`. 
- bwd_bsym_ctx = None - for _, bsym in enumerate(reversed(bw_extrace.bound_symbols)): + bwd_idx = None + len_bound_symbols = len(bw_extrace.bound_symbols) + for idx, bsym in enumerate(reversed(bw_extrace.bound_symbols), start=1): if bsym.sym.id == te_functional_linear_backward.id: - bwd_bsym_ctx = bsym.args[-1].name + bwd_idx = len_bound_symbols - idx break - first_sym_idx = None - detected_first_sym_idx = None - # Find the first `te_linear` in forward trace - # and the position of `te_linear` which has the last `ctx_name` - # in backward. - for idx, bsym in enumerate(fw_extrace.bound_symbols): - # Forward symbols are generated on the fly so we don't - # have access here. - # Instead we check for the executor field. - if bsym.sym.executor == transformer_engine_ex: - # Sanity check. - assert "te_linear" in bsym.sym.name - if first_sym_idx is None: - first_sym_idx = idx - if bsym.output[-1].name == bwd_bsym_ctx: - detected_first_sym_idx = idx - break - - # If the first `te_linear` is not same as that one that should be - # we move it to be the first one. - if detected_first_sym_idx != first_sym_idx: - # Move the symbol to be the first `te_linear`. - fwd_bsyms = fw_extrace.bound_symbols - sym_to_swap = fwd_bsyms[detected_first_sym_idx] - del fwd_bsyms[detected_first_sym_idx] - fwd_bsyms.insert(first_sym_idx, sym_to_swap) + if bwd_idx is not None: + bw_extrace.bound_symbols.insert(bwd_idx + 1, te_sync_fp8_meta_bwd.bind(output=None)) From 07312cd09cb10823b85e24d08226d0622ba6a299 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 8 May 2024 13:50:49 +0200 Subject: [PATCH 2/4] update to work with current stable and upcoming releases --- thunder/executors/torch_autograd.py | 6 +-- thunder/executors/transformer_engineex.py | 63 ++++++++++++++++++++--- thunder/tests/distributed/test_ddp.py | 51 ++++++++++++------ 3 files changed, 94 insertions(+), 26 deletions(-) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index a952d3ba4..8e25de207 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -202,11 +202,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat bw_extrace = sort_waits(bw_extrace) # Importing here to avoid cyclical dependencies in future. - from thunder.executors.transformer_engineex import _rearrange_transformer_engine_linear, transformer_engine_ex + from thunder.executors.transformer_engineex import _transformer_engine_bwd_fp8_meta_sync, transformer_engine_ex if transformer_engine_ex in compile_data.executors_list: - # NOTE: `_rearrange_transformer_engine_linear` mutates `fw_extrace`. - _rearrange_transformer_engine_linear(fw_extrace, bw_extrace) + # NOTE: `_transformer_engine_bwd_fp8_meta_sync` may mutate `fw_extrace` or `bw_extrace`. 
+        _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace)
 
     fw_extrace = del_last_used(fw_extrace)
     fw_traces.append(fw_extrace)
diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py
index 8932d148a..42aa19bcc 100644
--- a/thunder/executors/transformer_engineex.py
+++ b/thunder/executors/transformer_engineex.py
@@ -164,8 +164,17 @@ def __init__(self, in_features: int, out_features: int) -> None:
         # Required by `get_fp8_weights_scratchpad`
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
 
+        if TE_VERSION_1_6_PLUS:
+            # NOTE: Backward FP8 metadata sync
+            # From TransformerEngine v1.6 onwards, we control the sync and update of FP8 metadata for FP8 tensors
+            # tied to the backward pass (i.e. the gradient tensors).
+            # Also, note that the forward tensor metadata sync occurs at the exit of the `fp8_autocast` context manager,
+            # which is not controlled by us.
+            #
+            # We consume the `is_first_fp8_module` token so that the automatic sync for FP8 metadata is disabled.
+            FP8GlobalStateManager.is_first_fp8_module()  # Consume first module token.
+
     def forward(self, inp, weight, bias, is_first_microbatch: bool | None = None, is_grad_enabled: bool = False):
-        FP8GlobalStateManager.is_first_fp8_module()  # Consume first module.
         tensor_inputs = tuple(filter(lambda t: isinstance(t, torch.Tensor), (inp, weight, bias)))
         # See [NOTE] Enable grad within context
         # TE backward depends on `requires_grad` to compute grads.
@@ -472,6 +481,23 @@ def te_sync_fp8_meta_bwd_impl():
 )
 
 
+def _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace):
+    if TE_VERSION_1_6_PLUS:
+        # See doc of `_insert_bwd_fp8_meta_sync` for more details.
+        _insert_bwd_fp8_meta_sync(bw_extrace)
+    else:
+        # See doc of `_rearrange_transformer_engine_linear` for more details.
+        _rearrange_transformer_engine_linear(fw_extrace, bw_extrace)
+
+
+def _insert_bwd_fp8_meta_sync(bw_extrace):
+    # This function inserts the symbol `te_sync_fp8_meta_bwd` at the end of the backward
+    # trace, which takes care of syncing and updating the FP8 metadata for backward tensors.
+    # See NOTE: Backward FP8 metadata sync
+    bwd_idx = len(bw_extrace.bound_symbols) - 1
+    bw_extrace.bound_symbols.insert(bwd_idx + 1, te_sync_fp8_meta_bwd.bind(output=None))
+
+
 def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace):
     """
     Rearrange the TransformerEngine linear symbols `te_linear_*` in forward trace
@@ -526,12 +552,35 @@ def _rearrange_transformer_engine_linear(fw_extrace, bw_extrace):
     (t6861, t6862, _) = te_functional_linear_backward(t6859, (i304, i305, i306), (i307, i308), None, ctx_te_1)
     """
     # Get the ctx name for the last `te_functional_linear_backward`.
-    bwd_idx = None
-    len_bound_symbols = len(bw_extrace.bound_symbols)
-    for idx, bsym in enumerate(reversed(bw_extrace.bound_symbols), start=1):
+    bwd_bsym_ctx = None
+    for _, bsym in enumerate(reversed(bw_extrace.bound_symbols)):
         if bsym.sym.id == te_functional_linear_backward.id:
-            bwd_idx = len_bound_symbols - idx
+            bwd_bsym_ctx = bsym.args[-1].name
             break
 
-    if bwd_idx is not None:
-        bw_extrace.bound_symbols.insert(bwd_idx + 1, te_sync_fp8_meta_bwd.bind(output=None))
+    first_sym_idx = None
+    detected_first_sym_idx = None
+    # Find the first `te_linear` in forward trace
+    # and the position of `te_linear` which has the last `ctx_name`
+    # in backward.
+    for idx, bsym in enumerate(fw_extrace.bound_symbols):
+        # Forward symbols are generated on the fly so we don't
+        # have access here.
+        # Instead we check for the executor field.
+        if bsym.sym.executor == transformer_engine_ex:
+            # Sanity check.
+            assert "te_linear" in bsym.sym.name
+            if first_sym_idx is None:
+                first_sym_idx = idx
+            if bsym.output[-1].name == bwd_bsym_ctx:
+                detected_first_sym_idx = idx
+                break
+
+    # If the first `te_linear` is not the same as the one that should come first,
+    # we move it to be the first one.
+    if detected_first_sym_idx != first_sym_idx:
+        # Move the symbol to be the first `te_linear`.
+        fwd_bsyms = fw_extrace.bound_symbols
+        sym_to_swap = fwd_bsyms[detected_first_sym_idx]
+        del fwd_bsyms[detected_first_sym_idx]
+        fwd_bsyms.insert(first_sym_idx, sym_to_swap)
diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py
index d76a06ca6..6b837de8d 100644
--- a/thunder/tests/distributed/test_ddp.py
+++ b/thunder/tests/distributed/test_ddp.py
@@ -30,7 +30,13 @@
 from thunder.tests.framework import TorchExecutor, nvFuserExecutor
 from thunder.tests.framework import instantiate
 
-from thunder.executors.transformer_engineex import transformer_engine_ex, TE_AVAILABLE
+from thunder.executors.transformer_engineex import (
+    transformer_engine_ex,
+    TE_AVAILABLE,
+    TE_VERSION_1_6_PLUS,
+    te_sync_fp8_meta_bwd,
+)
+
 
 is_fp8_supported: bool = False
 # This will be correctly updated below when TE Engine is installed
@@ -1502,21 +1508,34 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
         fwd_exec_trace = thunder.last_traces(jit_model)[-1]
         bwd_exec_trace = thunder.last_backward_traces(jit_model)[-1]
 
-        # Verify that the first te_linear in fwd_exec_trace is the
-        # last one in bwd_exec_tarce.
-        # We verify that by managing the `ctx` (CollectionProxy) output by `te_linear` which is
-        # passed to backward.
-        # As CollectionProxy don't implement __eq__, we verify them by name.
-        first_ctx_name = None
-        for bsym in fwd_exec_trace.bound_symbols:
-            if bsym.sym.name.startswith("te_linear"):
-                first_ctx_name = bsym.output[1].name
-                break
-
-        for bsym in reversed(bwd_exec_trace.bound_symbols):
-            if bsym.sym.name.startswith("te_functional"):
-                assert first_ctx_name == bsym.args[-1].name, (first_ctx_name, bsym.args[-1].name)
-                break
+        if TE_VERSION_1_6_PLUS:
+            # Verify that the symbol to sync backward
+            # fp8 metadata is present in backward trace.
+            found_bwd_sync_symbol = False
+            for bsym in reversed(bwd_exec_trace.bound_symbols):
+                if bsym.sym.id == te_sync_fp8_meta_bwd.id:
+                    found_bwd_sync_symbol = True
+                    break
+
+            if not found_bwd_sync_symbol:
+                raise RuntimeError("Backward sync symbol not found.")
+
+        else:
+            # Verify that the first te_linear in fwd_exec_trace is the
+            # last one in bwd_exec_trace.
+            # We verify that by checking the `ctx` (CollectionProxy) output by `te_linear`, which is
+            # passed to backward.
+            # As CollectionProxy doesn't implement __eq__, we verify them by name.
+            first_ctx_name = None
+            for bsym in fwd_exec_trace.bound_symbols:
+                if bsym.sym.name.startswith("te_linear"):
+                    first_ctx_name = bsym.output[1].name
+                    break
+
+            for bsym in reversed(bwd_exec_trace.bound_symbols):
+                if bsym.sym.name.startswith("te_functional"):
+                    assert first_ctx_name == bsym.args[-1].name, (first_ctx_name, bsym.args[-1].name)
+                    break
     except Exception as e:
         sanity_exceptions.append(e)
 

From 9c8e4fcd34965b76a8059c62b56b64fd0ddc97dc Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Wed, 8 May 2024 14:00:01 +0200
Subject: [PATCH 3/4] undo changes to test_benchmark_litgpt.py

---
 thunder/benchmarks/test_benchmark_litgpt.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/thunder/benchmarks/test_benchmark_litgpt.py b/thunder/benchmarks/test_benchmark_litgpt.py
index 0a8e80cf1..f5b48cf1e 100644
--- a/thunder/benchmarks/test_benchmark_litgpt.py
+++ b/thunder/benchmarks/test_benchmark_litgpt.py
@@ -16,7 +16,6 @@
 from collections import defaultdict
 import os
 import subprocess
-from subprocess import PIPE, Popen
 import json
 import pandas as pd
 from datetime import datetime
@@ -134,10 +133,7 @@ def run_benchmark(self, kwargs):
         subprocess_cmd.extend(command_list)
 
         print(f'Running {" ".join(subprocess_cmd)!r}')
-        # proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True)
-        with Popen(subprocess_cmd, stdout=PIPE, bufsize=1, universal_newlines=True) as proc_output:
-            for line in proc_output.stdout:
-                print(line, end="")  # process line here
+        proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True)
 
         self.perf_metrics_dict = {}
         if os.path.exists(self.json_file_path):
@@ -224,15 +220,17 @@ def tearDownClass(cls):
 
     @parameterized.product(
         distributed_mode=("fsdp",),
-        shard_mode=("zero3",),
-        model_name=("pythia-14m",),
-        micro_batch_size=(1,),
+        shard_mode=("zero2",),
+        model_name=("Llama-2-7b-hf",),
+        micro_batch_size=(
+            1,
+            4,
+        ),
         compile=(
             "eager",
-            # "inductor",
+            "inductor",
             "thunder",
             "thunder_inductor",
-            "thunder_inductor_transformerengine",
         ),
     )
     def test(self, **kwargs):

From 44f47cc6aebac84bf16bb1f6bb489437ebab8478 Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Mon, 13 May 2024 13:11:46 +0200
Subject: [PATCH 4/4] Update thunder/tests/distributed/test_ddp.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 thunder/tests/distributed/test_ddp.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py
index a9049022a..f61edcba4 100644
--- a/thunder/tests/distributed/test_ddp.py
+++ b/thunder/tests/distributed/test_ddp.py
@@ -1510,13 +1510,10 @@ def _test_ddp_transformer_engine_llama_sanity(input_data):
         if TE_VERSION_1_6_PLUS:
             # Verify that the symbol to sync backward
             # fp8 metadata is present in backward trace.
-            found_bwd_sync_symbol = False
             for bsym in reversed(bwd_exec_trace.bound_symbols):
                 if bsym.sym.id == te_sync_fp8_meta_bwd.id:
-                    found_bwd_sync_symbol = True
                     break
-
-            if not found_bwd_sync_symbol:
+            else:
                 raise RuntimeError("Backward sync symbol not found.")
 
         else:
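
Note (not part of the patch series): a rough sketch of how the new backward FP8 metadata sync could be exercised end to end. It assumes an FP8-capable GPU, TransformerEngine >= 1.6, and the `thunder.jit(..., executors=...)` API used by the tests above; the toy model and shapes are made up for illustration, not taken from the patches.

# Illustrative sketch only; assumes TE >= 1.6 and an FP8-capable device.
import torch
import thunder
import transformer_engine.pytorch as te
from thunder.executors.transformer_engineex import transformer_engine_ex, te_sync_fp8_meta_bwd

# TE FP8 kernels want feature dimensions that are multiples of 16.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
    torch.nn.Linear(64, 64, bias=False),
).cuda()
jit_model = thunder.jit(model, executors=[transformer_engine_ex])

x = torch.randn(32, 64, device="cuda", requires_grad=True)
with te.fp8_autocast():
    out = jit_model(x)
out.sum().backward()

# With TE >= 1.6 the backward trace should now contain the registered sync symbol,
# mirroring the check added to _test_ddp_transformer_engine_llama_sanity above.
bwd_trace = thunder.last_backward_traces(jit_model)[-1]
assert any(bsym.sym.id == te_sync_fp8_meta_bwd.id for bsym in bwd_trace.bound_symbols)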