From 601709df1c6f0bd7fb35c7d1709a89289481c291 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 10 Jun 2025 23:15:19 +0000 Subject: [PATCH 1/6] Flatten basic op params during fuser init Signed-off-by: Jan Bielak (cherry picked from commit 949abe97070721b1da5117903067608250f5fb61) --- transformer_engine/pytorch/ops/fuser.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 8ff0242229..3c18e6e08a 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -346,6 +346,10 @@ def __init__( if fuse_ops: self.fuse_ops() + # Flatten list of parameters + self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()] + self._num_basic_op_params = len(self._basic_op_params) + @classmethod def _fuse_forward_ops( cls, @@ -385,10 +389,7 @@ def __call__( # Canonicalize op kwargs if basic_op_kwargs is None: - basic_op_kwargs = [{} for _ in range(len(self._basic_ops))] - - # Flatten list of parameters - params = [param for op in self._basic_ops for param in op.parameters()] + basic_op_kwargs = [{}] * self._num_basic_ops # Fuser forward pass is_grad_enabled = torch.is_grad_enabled() @@ -405,9 +406,9 @@ def __call__( self._basic_ops, basic_op_kwargs, is_grad_enabled, - len(params), + self._num_basic_op_params, self._num_extra_inputs, - *params, + *self._basic_op_params, *extra_inputs, ) return forward_func(*args) From 39f4a59f3decaf3dccbad65f820c4c1cf87797ef Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 11 Jun 2025 00:29:39 +0000 Subject: [PATCH 2/6] Add caching for is_non_tn_fp8_gemm_supported Signed-off-by: Jan Bielak (cherry picked from commit fd830ae24ffbd2d0727010b1a8a119ca72f61ce5) --- transformer_engine/pytorch/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 3abebdf1e4..8dee68a38e 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -440,6 +440,7 @@ def is_bf16_compatible() -> None: return torch.cuda.get_device_capability()[0] >= 8 +@functools.lru_cache(maxsize=None) def is_non_tn_fp8_gemm_supported() -> bool: """Checks whether the device supports non-TN layouts for FP8 GEMMs. 
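[Editor's note: a minimal sketch of the memoization pattern patch 2 applies. `functools.lru_cache(maxsize=None)` on a zero-argument capability probe runs the body once and serves every later call from the cache, so repeated queries skip the device check. The `probe_device` name and body below are hypothetical stand-ins, not Transformer Engine code.]

    import functools

    @functools.lru_cache(maxsize=None)
    def probe_device() -> bool:
        print("querying device capability ...")  # runs only on the first call
        return True

    probe_device()  # performs the check and caches the result
    probe_device()  # returns the cached True without re-running the body
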
From 147acd73a1afff4ce83d79c43a3ecdfe2122b8c0 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 11 Jun 2025 20:25:19 +0000 Subject: [PATCH 3/6] Pass fuser to _OperationFuserAutogradFunction.forward and moving computation to __init__ Signed-off-by: Jan Bielak (cherry picked from commit fd808991993958b670726896254b82fcb967fa07) --- transformer_engine/pytorch/ops/fuser.py | 91 +++++++++---------------- 1 file changed, 31 insertions(+), 60 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 3c18e6e08a..8ad596ca5b 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -61,13 +61,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): def forward( func_ctx: Optional[torch.autograd.function.FunctionCtx], input_: torch.Tensor, - forward_ops: list[tuple[FusibleOperation, list[int]]], - backward_ops: list[tuple[FusibleOperation, list[int]]], - basic_ops: list[BasicOperation], - basic_op_kwargs: list[dict[str, Any]], - is_grad_enabled: bool, - num_params: int, - num_extra_inputs: int, + fuser: OperationFuser, *params_and_extra_inputs: torch.nn.Parameter, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass @@ -78,20 +72,8 @@ def forward( Context for PyTorch autograd function input_: torch.Tensor Input to first operation in pipeline - forward_ops: list of tuple - Forward pass operations and the indices of the - corresponding basic operations. The order should match - basic_ops. - backward_ops: list of tuple - Backward pass operations and the indices of the - corresponding basic operations. The order should be the - reverse of basic_ops. - basic_ops: list of BasicOperation - Basic operations - basic_op_kwargs: list of dict - Keyword arguments to BasicOperation - num_params: int - Number of parameter tensors to include in autograd graph. + fuser: OperationFuser + Container for the pipeline of operations to run *params_and_extra_inputs: torch.Tensor Other tensor inputs to include in autograd graph. Consists of parameter tensors, followed by extra operation inputs. 
@@ -106,29 +88,23 @@ def forward( """ # Operation autograd contexts - basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))] + basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)] # Unflatten list of parameters and extra tensor inputs - if len(params_and_extra_inputs) != num_params + num_extra_inputs: - raise ValueError( - f"Expected {num_params + num_extra_inputs} extra tensor arguments " - f"({num_params} parameters, {num_extra_inputs} extra inputs), " - f"but got {len(params_and_extra_inputs)}" - ) - _, extra_inputs = _split_tuple(params_and_extra_inputs, num_params) + extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs:] basic_op_extra_inputs = [] - for op in basic_ops: + for op in fuser._basic_ops: xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) basic_op_extra_inputs.append(xs) # Apply forward ops x = input_ - requires_grad = is_grad_enabled and x.requires_grad - extra_outputs = [None for _ in range(len(basic_ops))] - for op, basic_op_idxs in forward_ops: + requires_grad = fuser._is_grad_enabled and x.requires_grad + extra_outputs = [None] * fuser._num_basic_ops + for op, basic_op_idxs in fuser._forward_ops: # Check if backward op is required - if is_grad_enabled: + if fuser._is_grad_enabled: if not requires_grad: requires_grad = any(param.requires_grad for param in op.parameters()) if not requires_grad: @@ -143,9 +119,9 @@ def forward( # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] - prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] + prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] next_ops = [ - basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs + fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None for idx in basic_op_idxs ] x, fused_op_extra_outputs = op.fuser_forward( [basic_op_ctxs[idx] for idx in basic_op_idxs], @@ -153,7 +129,7 @@ def forward( basic_op_extra_inputs=extra_inputs, basic_op_prev_ops=prev_ops, basic_op_next_ops=next_ops, - basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], + basic_op_kwargs=[fuser._basic_op_kwargs[idx] for idx in basic_op_idxs], ) x.requires_grad_(requires_grad=requires_grad) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): @@ -165,7 +141,7 @@ def forward( extra_outputs_flat = [] for idx, ys in enumerate(extra_outputs): ys = list(ys) - num_extra_outputs = basic_ops[idx].num_extra_outputs + num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs if len(ys) != num_extra_outputs: raise RuntimeError( f"Expected op {idx} to generate " @@ -175,7 +151,7 @@ def forward( extra_outputs_flat.extend(ys) # Save context for backward pass - if is_grad_enabled: + if fuser._is_grad_enabled: # Flatten list of saved tensors to_save = [] @@ -189,11 +165,11 @@ def forward( func_ctx.save_for_backward(*to_save) # Other context - func_ctx.backward_ops = backward_ops - func_ctx.basic_ops = basic_ops + func_ctx.backward_ops = fuser._backward_ops + func_ctx.basic_ops = fuser._basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops] - func_ctx.num_extra_inputs = num_extra_inputs + func_ctx.basic_op_num_params = fuser._num_list_basic_op_params + func_ctx.num_extra_inputs = fuser._num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() @@ -293,13 +269,7 @@ def backward( 
return ( dx, # input_ - None, # forward_ops - None, # backward_ops - None, # basic_ops - None, # basic_op_kwargs - None, # is_grad_enabled - None, # num_params - None, # num_extra_inputs + None, # fuser *grad_params_flat, *grad_extra_inputs_flat, ) @@ -348,7 +318,7 @@ def __init__( # Flatten list of parameters self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()] - self._num_basic_op_params = len(self._basic_op_params) + self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops] @classmethod def _fuse_forward_ops( @@ -382,6 +352,12 @@ def __call__( *extra_inputs: torch.Tensor, basic_op_kwargs: Optional[list[dict[str, Any]]] = None, ) -> torch.Tensor | tuple[torch.Tensor, ...]: + # Verify extra input count + if len(extra_inputs) != self._num_extra_inputs: + raise ValueError( + f"Expected {self._num_extra_inputs} extra inputs " + f"but got {len(extra_inputs)}" + ) # Initialization before forward pass for op in self._basic_ops: @@ -390,10 +366,11 @@ def __call__( # Canonicalize op kwargs if basic_op_kwargs is None: basic_op_kwargs = [{}] * self._num_basic_ops + self._basic_op_kwargs = basic_op_kwargs # Fuser forward pass - is_grad_enabled = torch.is_grad_enabled() - if is_grad_enabled: + self._is_grad_enabled = torch.is_grad_enabled() + if self._is_grad_enabled: forward_func = _OperationFuserAutogradFunction.apply args = [] else: @@ -401,13 +378,7 @@ def __call__( args = [None] args += ( input, - self._forward_ops, - self._backward_ops, - self._basic_ops, - basic_op_kwargs, - is_grad_enabled, - self._num_basic_op_params, - self._num_extra_inputs, + self, *self._basic_op_params, *extra_inputs, ) From 03267a24a5ff72bb9382891bcbdf89a4525304cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Jun 2025 21:21:47 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fuser.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 8ad596ca5b..1591031a84 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -91,7 +91,7 @@ def forward( basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)] # Unflatten list of parameters and extra tensor inputs - extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs:] + extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :] basic_op_extra_inputs = [] for op in fuser._basic_ops: xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) @@ -121,7 +121,8 @@ def forward( extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] next_ops = [ - fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None for idx in basic_op_idxs + fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None + for idx in basic_op_idxs ] x, fused_op_extra_outputs = op.fuser_forward( [basic_op_ctxs[idx] for idx in basic_op_idxs], @@ -355,9 +356,8 @@ def __call__( # Verify extra input count if len(extra_inputs) != self._num_extra_inputs: raise ValueError( - f"Expected {self._num_extra_inputs} extra inputs " - f"but got {len(extra_inputs)}" - ) + f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}" + ) # 
Initialization before forward pass for op in self._basic_ops: From 9b2ef2c3dc4c7d1cf8e1683e99ce1a4257b38628 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 11 Jun 2025 22:54:37 +0000 Subject: [PATCH 5/6] Pass basic_op_kwargs and is_grad_enabled as parameters rather than in fuser Signed-off-by: Jan Bielak --- transformer_engine/pytorch/ops/fuser.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 1591031a84..5e0cee3e80 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -62,6 +62,8 @@ def forward( func_ctx: Optional[torch.autograd.function.FunctionCtx], input_: torch.Tensor, fuser: OperationFuser, + basic_op_kwargs: list[dict[str, Any]], + is_grad_enabled: bool, *params_and_extra_inputs: torch.nn.Parameter, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass @@ -74,6 +76,10 @@ def forward( Input to first operation in pipeline fuser: OperationFuser Container for the pipeline of operations to run + basic_op_kwargs: list of dict + Keyword arguments to BasicOperation + is_grad_enabled: bool + Should context be saved for backward *params_and_extra_inputs: torch.Tensor Other tensor inputs to include in autograd graph. Consists of parameter tensors, followed by extra operation inputs. @@ -91,7 +97,7 @@ def forward( basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)] # Unflatten list of parameters and extra tensor inputs - extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :] + extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs:] basic_op_extra_inputs = [] for op in fuser._basic_ops: xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) @@ -99,12 +105,12 @@ def forward( # Apply forward ops x = input_ - requires_grad = fuser._is_grad_enabled and x.requires_grad + requires_grad = is_grad_enabled and x.requires_grad extra_outputs = [None] * fuser._num_basic_ops for op, basic_op_idxs in fuser._forward_ops: # Check if backward op is required - if fuser._is_grad_enabled: + if is_grad_enabled: if not requires_grad: requires_grad = any(param.requires_grad for param in op.parameters()) if not requires_grad: @@ -130,7 +136,7 @@ def forward( basic_op_extra_inputs=extra_inputs, basic_op_prev_ops=prev_ops, basic_op_next_ops=next_ops, - basic_op_kwargs=[fuser._basic_op_kwargs[idx] for idx in basic_op_idxs], + basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], ) x.requires_grad_(requires_grad=requires_grad) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): @@ -152,7 +158,7 @@ def forward( extra_outputs_flat.extend(ys) # Save context for backward pass - if fuser._is_grad_enabled: + if is_grad_enabled: # Flatten list of saved tensors to_save = [] @@ -271,6 +277,8 @@ def backward( return ( dx, # input_ None, # fuser + None, # basic_op_kwargs + None, # is_grad_enabled *grad_params_flat, *grad_extra_inputs_flat, ) @@ -366,11 +374,10 @@ def __call__( # Canonicalize op kwargs if basic_op_kwargs is None: basic_op_kwargs = [{}] * self._num_basic_ops - self._basic_op_kwargs = basic_op_kwargs # Fuser forward pass - self._is_grad_enabled = torch.is_grad_enabled() - if self._is_grad_enabled: + is_grad_enabled = torch.is_grad_enabled() + if is_grad_enabled: forward_func = _OperationFuserAutogradFunction.apply args = [] else: @@ -379,6 +386,8 @@ def __call__( args += ( input, self, + basic_op_kwargs, + is_grad_enabled, *self._basic_op_params, 
*extra_inputs, ) From 7f768e93a0469f29611472ad5c6f2aeb5eac9c6f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Jun 2025 22:56:03 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fuser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 5e0cee3e80..cf61d15eff 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -97,7 +97,7 @@ def forward( basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)] # Unflatten list of parameters and extra tensor inputs - extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs:] + extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :] basic_op_extra_inputs = [] for op in fuser._basic_ops: xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
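
[Editor's note: a minimal, self-contained sketch of the pattern patches 3 and 5 converge on. Non-tensor state (here a plain `config` dict standing in for the fuser, plus the grad-enabled flag) is passed to `torch.autograd.Function.forward` as ordinary arguments that receive `None` gradients in `backward`, while parameters are flattened into `*args` so autograd tracks them. All names below (`_ScalePipelineFunc`, `config`) are illustrative, not Transformer Engine code, and the sketch uses the standard static `forward(ctx, ...)` signature rather than the fuser's explicit `func_ctx` convention.]

    import torch

    class _ScalePipelineFunc(torch.autograd.Function):

        @staticmethod
        def forward(ctx, input_, config, is_grad_enabled, *params):
            (weight,) = params  # flattened parameter list, here a single scalar weight
            if is_grad_enabled:
                ctx.save_for_backward(input_, weight)
            return input_ * weight

        @staticmethod
        def backward(ctx, grad_output):
            input_, weight = ctx.saved_tensors
            grad_input = grad_output * weight                 # d(x * w) / dx = w
            grad_weight = (grad_output * input_).sum()        # w is a scalar parameter
            # One None per non-differentiable forward argument (config, is_grad_enabled),
            # mirroring the "None,  # fuser" placeholders in the patches above.
            return grad_input, None, None, grad_weight

    x = torch.randn(4, requires_grad=True)
    w = torch.nn.Parameter(torch.tensor(2.0))
    y = _ScalePipelineFunc.apply(x, {"name": "demo"}, torch.is_grad_enabled(), w)
    y.sum().backward()  # gradients flow to x and w; the dict gets no gradient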