transformer_engine/pytorch/ops/fuser.py (81 changes: 31 additions & 50 deletions)
@@ -61,13 +61,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
def forward(
func_ctx: Optional[torch.autograd.function.FunctionCtx],
input_: torch.Tensor,
forward_ops: list[tuple[FusibleOperation, list[int]]],
backward_ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: list[BasicOperation],
fuser: OperationFuser,
basic_op_kwargs: list[dict[str, Any]],
is_grad_enabled: bool,
num_params: int,
num_extra_inputs: int,
*params_and_extra_inputs: torch.nn.Parameter,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass
@@ -78,20 +74,12 @@ def forward(
Context for PyTorch autograd function
input_: torch.Tensor
Input to first operation in pipeline
forward_ops: list of tuple
Forward pass operations and the indices of the
corresponding basic operations. The order should match
basic_ops.
backward_ops: list of tuple
Backward pass operations and the indices of the
corresponding basic operations. The order should be the
reverse of basic_ops.
basic_ops: list of BasicOperation
Basic operations
fuser: OperationFuser
Container for the pipeline of operations to run
basic_op_kwargs: list of dict
Keyword arguments to BasicOperation
num_params: int
Number of parameter tensors to include in autograd graph.
is_grad_enabled: bool
Should context be saved for backward
*params_and_extra_inputs: torch.Tensor
Other tensor inputs to include in autograd graph. Consists
of parameter tensors, followed by extra operation inputs.
@@ -106,26 +94,20 @@
"""

# Operation autograd contexts
basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]
basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

# Unflatten list of parameters and extra tensor inputs
if len(params_and_extra_inputs) != num_params + num_extra_inputs:
raise ValueError(
f"Expected {num_params + num_extra_inputs} extra tensor arguments "
f"({num_params} parameters, {num_extra_inputs} extra inputs), "
f"but got {len(params_and_extra_inputs)}"
)
_, extra_inputs = _split_tuple(params_and_extra_inputs, num_params)
extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
basic_op_extra_inputs = []
for op in basic_ops:
for op in fuser._basic_ops:
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
basic_op_extra_inputs.append(xs)

# Apply forward ops
x = input_
requires_grad = is_grad_enabled and x.requires_grad
extra_outputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in forward_ops:
extra_outputs = [None] * fuser._num_basic_ops
for op, basic_op_idxs in fuser._forward_ops:

# Check if backward op is required
if is_grad_enabled:
@@ -143,9 +125,10 @@

# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
next_ops = [
basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
for idx in basic_op_idxs
]
x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
@@ -165,7 +148,7 @@
extra_outputs_flat = []
for idx, ys in enumerate(extra_outputs):
ys = list(ys)
num_extra_outputs = basic_ops[idx].num_extra_outputs
num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
if len(ys) != num_extra_outputs:
raise RuntimeError(
f"Expected op {idx} to generate "
@@ -189,11 +172,11 @@
func_ctx.save_for_backward(*to_save)

# Other context
func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops
func_ctx.backward_ops = fuser._backward_ops
func_ctx.basic_ops = fuser._basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
func_ctx.num_extra_inputs = num_extra_inputs
func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
func_ctx.num_extra_inputs = fuser._num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

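In the block above, tensors are stashed with func_ctx.save_for_backward while op lists, counts, and flags are stored as plain attributes on the context. That split is part of the torch.autograd.Function API: save_for_backward is for tensors that backward will read (they come back as ctx.saved_tensors, with autograd checking for in-place modification), whereas arbitrary Python state can live directly on the ctx object. A minimal illustration of that split, unrelated to the code above:

import torch

class _Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # tensors go through save_for_backward
        ctx.exponent = 3           # arbitrary Python state can sit on ctx directly
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return ctx.exponent * x ** (ctx.exponent - 1) * grad_output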
@@ -293,13 +276,9 @@ def backward(

return (
dx, # input_
None, # forward_ops
None, # backward_ops
None, # basic_ops
None, # fuser
None, # basic_op_kwargs
None, # is_grad_enabled
None, # num_params
None, # num_extra_inputs
*grad_params_flat,
*grad_extra_inputs_flat,
)
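The None entries above follow torch.autograd.Function's contract: backward must return exactly one value per forward argument, and non-tensor arguments (the fuser container, the kwargs list, the is_grad_enabled flag) receive None. A minimal, self-contained sketch of that convention with toy names, not the actual TransformerEngine code:

import torch

class _ToyFuserFunction(torch.autograd.Function):
    """Toy analogue of the calling convention above: one tensor input, one
    non-tensor container object, then flattened parameter tensors."""

    @staticmethod
    def forward(ctx, input_, container, *params):
        # The container rides along as a plain Python object; autograd only
        # tracks tensor arguments. ctx may be None when forward is invoked directly.
        if ctx is not None:
            ctx.num_params = len(params)
        return input_ + sum(params) if params else input_.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Exactly one return value per forward argument, in order.
        return (
            grad_output,                          # input_
            None,                                 # container (non-tensor)
            *([grad_output] * ctx.num_params),    # one grad per flattened parameter
        )

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
y = _ToyFuserFunction.apply(x, {"note": "non-tensor arg"}, w)
y.sum().backward()   # x.grad and w.grad are populated; the dict needs no grad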
@@ -346,6 +325,10 @@ def __init__(
if fuse_ops:
self.fuse_ops()

# Flatten list of parameters
self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]

@classmethod
def _fuse_forward_ops(
cls,
@@ -378,17 +361,19 @@ def __call__(
*extra_inputs: torch.Tensor,
basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
# Verify extra input count
if len(extra_inputs) != self._num_extra_inputs:
raise ValueError(
f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
)

# Initialization before forward pass
for op in self._basic_ops:
op.pre_forward()

# Canonicalize op kwargs
if basic_op_kwargs is None:
basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]

# Flatten list of parameters
params = [param for op in self._basic_ops for param in op.parameters()]
basic_op_kwargs = [{}] * self._num_basic_ops

# Fuser forward pass
is_grad_enabled = torch.is_grad_enabled()
@@ -400,14 +385,10 @@
args = [None]
args += (
input,
self._forward_ops,
self._backward_ops,
self._basic_ops,
self,
basic_op_kwargs,
is_grad_enabled,
len(params),
self._num_extra_inputs,
*params,
*self._basic_op_params,
*extra_inputs,
)
return forward_func(*args)
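The collapsed context above hides how forward_func is chosen, but the visible args = [None] and forward_func(*args) lines suggest the usual dispatch pattern: route through autograd only when gradients are enabled, otherwise call the static forward directly and pass None in place of the autograd context. A hedged sketch of that pattern, reusing the toy Function from the earlier example; the names and structure are assumptions, not the actual OperationFuser.__call__ implementation:

import torch

def run_toy_pipeline(x, container, *params):
    if torch.is_grad_enabled():
        forward_func = _ToyFuserFunction.apply      # autograd supplies ctx
        args = []
    else:
        forward_func = _ToyFuserFunction.forward    # ctx passed explicitly
        args = [None]
    args += [x, container, *params]
    return forward_func(*args)

with torch.no_grad():
    out = run_toy_pipeline(torch.randn(4), {"note": "non-tensor"}, torch.randn(4))
    # No autograd graph is built, and the Function's ctx is simply None.

Passing the fuser object itself, together with the parameter list flattened once in __init__, replaces the five separate bookkeeping arguments while still handing every parameter and extra input tensor to autograd as an explicit positional argument.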
transformer_engine/pytorch/utils.py (1 change: 1 addition & 0 deletions)
@@ -440,6 +440,7 @@ def is_bf16_compatible() -> None:
return torch.cuda.get_device_capability()[0] >= 8


@functools.lru_cache(maxsize=None)
def is_non_tn_fp8_gemm_supported() -> bool:
"""Checks whether the device supports
non-TN layouts for FP8 GEMMs.
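The functools.lru_cache(maxsize=None) decorator added above memoizes the capability check, so the underlying CUDA device query runs once per process instead of on every call. A small sketch of the same pattern with a hypothetical helper, not the TransformerEngine function itself:

import functools

import torch

@functools.lru_cache(maxsize=None)
def device_is_hopper_or_newer() -> bool:
    # Hypothetical helper: run the CUDA capability query once per process
    # and reuse the cached result on every later call.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9

The trade-off is that the result is fixed for the lifetime of the process, which is acceptable for a device capability that cannot change at runtime.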