# Custom epilogue fusions for GEMMs

Note: this notebook requires a GPU with compute capability 100 or 103:

In [None]:
import cutlass_api

if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):
    print(f"This notebook requires a GPU with compute capability 100 or 103.\n{status.error}")
    import sys

    sys.exit(0)

The CUTLASS API supports flexible epilogue fusion: an epilogue is specified as a sequence of high-level tensor operations to be composed with a base operation such as a GEMM.

For those familiar with the legacy CUTLASS Python API's [epilogue visitor tree frontend](https://github.com/NVIDIA/cutlass/blob/a2439551c765c5393aebe557ee75d3a0412d2211/examples/python/deprecated/04_epilogue_visitor.ipynb), much of the interface is shared.

The CUTLASS API lets one express an epilogue as a function operating at the `torch.Tensor` level, and provides tooling to automatically fuse that function into kernels that support it.

For example, in PyTorch one might write the following to compute a GEMM + epilogue:

In [None]:
import torch

torch.manual_seed(2025)

L, M, N, K = 1, 1024, 1024, 1024
A = torch.randn(L, M, K, device="cuda", dtype=torch.float16)
B = torch.randn(L, K, N, device="cuda", dtype=torch.float16)
C = torch.randn(L, M, N, device="cuda", dtype=torch.float16)

def my_epilogue(accum, C, alpha, beta, extra_scalar):
    Aux = (alpha * accum) + (beta * C)
    D = extra_scalar * Aux
    return D, Aux

alpha, beta, extra_scalar = 1.0, 2.0, 0.5
D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)


The CUTLASS API allows the same epilogue function `my_epilogue` to be used in GEMMs provided by the API.

To do so, one defines `EpilogueArguments` consisting of the epilogue function to compute (or a string representation of it) along with arguments corresponding to each input and output of the function (except for `accum`):

In [None]:
import cutlass_api
from cutlass_api.arguments import GemmArguments, EpilogueArguments

# Allocate buffers for D and Aux
D_, Aux_ = [torch.empty((L, M, N), device="cuda", dtype=torch.float16) for _ in range(2)]

epi_args = EpilogueArguments(my_epilogue, C=C, alpha=alpha, beta=beta, extra_scalar=extra_scalar, D=D_, Aux=Aux_)


These arguments can be added to `GemmArguments` and passed in to `get_kernels()` for use when retrieving compatible kernels:

In [None]:
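# Write into the freshly allocated D_ (not the PyTorch reference D) so the comparison below is meaningful
args = GemmArguments(A=A, B=B, out=D_, accumulator_type=torch.float32, epilogue=epi_args)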
kernels = cutlass_api.get_kernels(args)
assert len(kernels) > 0


Each kernel returned by `get_kernels` can be compiled and executed with these new arguments in the same way as in the examples without epilogue fusion. For example, using the first kernel:

In [None]:
kernels[0].run(args)

torch.testing.assert_close(D, D_)
torch.testing.assert_close(Aux, Aux_)


## How the epilogue fusion API works
To support specifying an epilogue via a Python function, a kernel needs some mechanism to:
1. Detect the operations in the epilogue function
2. Determine if the kernel can support the operations
3. Emit code to perform these operations within the kernel

Step 1 listed above does not depend on the kernel and its implementation (e.g., DSL), while steps 2 and 3 depend on the kernel and/or its implementation.

Thus, the CUTLASS API separates these components so that step 1 takes place at the API level and steps 2 and 3 take place in the kernel. This process is visualized below. We will walk through each step in greater detail.

```text
    +------------------------------------+
    | def epi(accum, alpha, beta, C):    |
    |   D = (accum * alpha) + (beta * C) |      1. Define epilogue via a Python function
    |   return D                         |
    +------------------------------------+
                      |
                      |
                      |
   GemmArguments(...,                           2. Pass epilogue function, operands, and outputs
     epilogue=EpilogueArguments(                   to EpilogueArguments constructor,
        epi, alpha=alpha, beta=beta, C=C))         and add this to the GemmArguments. Under the
                      |                            hood, this parses the Python AST of the
                      |                            epilogue function to produce a DAG of load,
                      |                            store, and compute nodes.
                      V
  +-----------------------------------------+ 
  |      Intermediate DAG representation    |
  |      ===============================    |
  |                                         |
  |                Store()                  |
  |                   |                     |
  |                 Add()                   |
  |                 /    \                  |
  |                /      \                 |
  |               /        \                |
  |            Mul()        Mul()           |
  |            /   \       /   \            |
  |  AccFetch()     |  Load(C)  \           |
  |                 |            \          |
  |       Load(alpha)           Load(beta)  |
  |                                         |
  +-----------------------------------------+
            /         |         \
           /          |          \              3. Individual kernel classes use the DAG representation
          /           |           \                to determine if the kernel class supports the DAG.
    Kernel 0      Kernel 1     Kernel 2            If so, the kernel class emits DSL-level operations
    epilogue      epilogue     epilogue            needed to compute the epilogue DAG alongside the
    emitter       emitter      emitter             basic operation of the kernel (e.g., GEMM).
       |              |            |
       |              |            |
       V              V            V
```
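
The AST-to-DAG step in the diagram can be illustrated with Python's standard `ast` module. This is only a minimal sketch and not the CUTLASS API implementation: the function `build_dag` and the nested-tuple node encoding (`AccFetch`, `Load`, `Store`, and so on) are hypothetical names chosen to mirror the diagram, and only the four arithmetic operators are handled.

```python
import ast

# Hypothetical node names that mirror the diagram; the real DAG classes may differ.
BINOPS = {ast.Add: "Add", ast.Sub: "Sub", ast.Mult: "Mul", ast.Div: "Div"}

def build_dag(source: str) -> dict:
    """Return {output_name: nested-tuple expression tree} for an epilogue source string."""
    fn = ast.parse(source).body[0]
    params = [a.arg for a in fn.args.args]
    defined = {}  # SSA variable name -> expression tree that produced it

    def lower(node):
        if isinstance(node, ast.Name):
            if node.id == "accum":
                return ("AccFetch",)
            if node.id in defined:
                return defined[node.id]
            if node.id in params:
                return ("Load", node.id)
            raise ValueError(f"unknown name {node.id!r}")
        if isinstance(node, ast.Constant):
            return ("Const", node.value)
        if isinstance(node, ast.BinOp) and type(node.op) in BINOPS:
            return (BINOPS[type(node.op)], lower(node.left), lower(node.right))
        raise ValueError(f"unsupported expression: {ast.dump(node)}")

    stores = {}
    for stmt in fn.body:
        if isinstance(stmt, ast.Assign):
            defined[stmt.targets[0].id] = lower(stmt.value)
        elif isinstance(stmt, ast.Return):
            value = stmt.value
            outputs = value.elts if isinstance(value, ast.Tuple) else [value]
            for out in outputs:
                stores[out.id] = ("Store", lower(out))
    return stores

epi_src = """
def epi(accum, alpha, beta, C):
    D = (accum * alpha) + (beta * C)
    return D
"""
dag = build_dag(epi_src)
print(dag["D"])
# ('Store', ('Add', ('Mul', ('AccFetch',), ('Load', 'alpha')), ('Mul', ('Load', 'beta'), ('Load', 'C'))))
```

The printed tree has the same shape as the intermediate DAG drawn above: a store at the root, an add below it, and two multiplies whose leaves are the accumulator fetch and the loads of `alpha`, `beta`, and `C`.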

### Defining an epilogue via a Python function
Users define epilogue fusion patterns as Python functions that perform Tensor-level operations. Given a tensor (for example, a `torch.Tensor`) resulting from matrix multiplication, the function must compute the desired results of the epilogue.

The structure of these functions is as follows:
```python
def custom_epi_name(accum, *args) -> Union[TensorType, tuple[TensorType]]:
  """
  :param accum: result of matrix multiplication, convolution, etc. before the epilogue
  :type accum: TensorType
  :param args: additional arguments to be used in the epilogue (e.g., aux tensors)
  :type args: list[Union[TensorType, ScalarType]]

  :returns: at least one tensor resulting from the operation of the epilogue
  :rtype: Union[TensorType, tuple[TensorType]]
  """
  # Do some compute
  return D # and potentially other values
```

The user defines a custom epilogue via a Python function that **must** do at least the following:
1. Take in a first positional argument named `accum` that represents the result of the operation just before the epilogue is performed. For example, in a GEMM, `accum = A @ B`.
2. Return at least one tensor that results from computing the epilogue. Currently, the return list must contain at least one output named `D`, though this constraint may be loosened in the future.

Each additional argument following `accum` in the function definition is expected to be either a Tensor or scalar to be loaded. Each variable in the return statement represents a Tensor or scalar to be stored. The underlying implementation of the epilogue in the kernel will determine how operands are loaded and stored.

Compute operations are represented in static single assignment (SSA) form.
This means that each variable can be assigned exactly once.
Operations currently supported are:
* Tensor-tensor elementwise addition, subtraction, multiplication, and division
* Scalar broadcasts via addition, subtraction, multiplication, and division
* Predefined elementwise activation functions (e.g., ReLU, sigmoid, tanh)

Operations that are not yet supported include:
* Row/column broadcasts (planned to be added soon)
* Reductions (planned to be added soon)
* Binary minimum and maximum functions (planned to be added soon)

Attempting to use these operations will result in no kernels being found in the call to `get_kernels`.

Violations of SSA or use of unexpected operators will be flagged with an exception when parsing the AST of the custom epilogue.
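
To make the SSA requirement concrete, the following sketch shows one way such a check could be written with the standard `ast` module. It is an illustration under assumptions, not the library's parser: `check_ssa` is a hypothetical helper, the error text merely imitates the style of message shown in the failure examples of this notebook, and treating function parameters as already-defined names is an assumption of this sketch.

```python
import ast

def check_ssa(source: str) -> None:
    """Raise ValueError if any name is bound more than once in the epilogue."""
    fn = ast.parse(source).body[0]
    # Assumption of this sketch: parameters count as already-defined names.
    seen = {a.arg for a in fn.args.args}
    for stmt in fn.body:
        if not isinstance(stmt, ast.Assign):
            continue
        for target in stmt.targets:
            if not isinstance(target, ast.Name):
                raise ValueError("this sketch only handles simple `name = expr` assignments")
            if target.id in seen:
                raise ValueError(f"Variable '{target.id}' cannot be defined twice.")
            seen.add(target.id)

good_src = """
def epi(accum, alpha):
    F = accum * alpha
    D = F + 1.0
    return D, F
"""
bad_src = """
def epi(accum, alpha):
    D = accum * alpha
    D = D + 1.0
    return D
"""

check_ssa(good_src)  # passes silently: every name is assigned exactly once

try:
    check_ssa(bad_src)
    ssa_error = None
except ValueError as e:
    ssa_error = str(e)
print(ssa_error)  # Variable 'D' cannot be defined twice.
```

Rewriting the second epilogue so that the intermediate value gets its own name (as in `good_src`) is enough to restore SSA form.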

Examples of epilogues fitting these patterns are given below. We will show full, runnable examples at the end of this notebook.
```python
def relu_aux_store(accum, alpha, C):
  # Note that the function definition itself does not indicate the types and
  # ranks of alpha and C. Thus, one cannot tell whether the epilogue is performing
  # broadcasts or elementwise operations until actual arguments or metadata are
  # provided to the epilogue. See below for details.
  F = (accum * alpha) + (C * 2.0) # Constant beta of 2.0
  D = relu(F)
  return D, F

def aux_normalize(accum, aux):
  D = accum / aux
  return D
```

Additional information about each operand and output must be provided by the user when constructing `EpilogueArguments`, as we will discuss below. This additional information is necessary for fully defining the operations being performed: without knowledge of whether `alpha` is a scalar or a Tensor, we cannot determine whether multiplication by `alpha` is a broadcasted or elementwise operation.
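
As a rough illustration of why operand metadata matters, the sketch below classifies an operand as a scalar broadcast or an elementwise tensor operand based on what is actually passed in. Everything here is hypothetical: `TensorMeta` stands in for a real tensor (only its `shape` is used), and `classify_operand` is not part of the CUTLASS API.

```python
from dataclasses import dataclass
from numbers import Number

@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for a tensor; only the shape is needed for this sketch."""
    shape: tuple

def classify_operand(accum_shape: tuple, operand) -> str:
    """Describe how `operand` would combine with an accumulator of shape `accum_shape`."""
    if isinstance(operand, Number):
        return "scalar broadcast"
    if tuple(operand.shape) == tuple(accum_shape):
        return "elementwise"
    # Anything else (e.g., a row or column vector) would need a different kind of broadcast.
    return "other"

accum_shape = (1, 1024, 1024)
as_scalar = classify_operand(accum_shape, 3.0)
as_tensor = classify_operand(accum_shape, TensorMeta((1, 1024, 1024)))
as_row = classify_operand(accum_shape, TensorMeta((1, 1, 1024)))
print(as_scalar, "|", as_tensor, "|", as_row)
# scalar broadcast | elementwise | other
```

The same expression `accum * alpha` therefore maps to different operations depending on the argument bound to `alpha`, which is why the function source alone is not enough.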

### Constructing epilogue arguments
`EpilogueArguments` encapsulates the arguments needed to determine the functional operation of a fused epilogue.

When constructing `EpilogueArguments`, a user must provide tensors or scalars for all operands and outputs of the epilogue. However, unlike the arguments for basic operations (e.g., GEMM), the full set of operands that must be specified depends on the custom epilogue defined by the user.

Therefore, `EpilogueArguments` is defined generically as taking in an `epilogue_fn` and additional `kwargs`. Under the hood, the AST for `epilogue_fn` is parsed to determine the operands and outputs of the epilogue. The user is required to provide in `kwargs` Tensors or scalars for all operands and outputs in the provided epilogue.

For example, with an epilogue of:
```python
def my_epi(accum, alpha, C, beta):
  F = (accum * alpha) + (C * beta)
  D = relu(F)
  return D, F
```
A user would need to construct epilogue arguments as follows:
```python
epi_args = EpilogueArguments(my_epi, alpha=..., C=..., beta=..., D=..., F=...)
```

After verifying that all required operands and outputs are present, the constructor to `EpilogueArguments` will perform additional passes on the AST of `epilogue_fn` using the provided inputs to generate an internal DAG representing the epilogue. This DAG structure is attached to `EpilogueArguments` for use as they are passed through a call to `get_kernels`.
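
The operand and output check described above can be sketched with the standard `ast` module. This is an assumption-laden illustration rather than the actual constructor logic: `required_names` and `check_kwargs` are hypothetical helpers, the error messages are invented for this sketch, and placeholder strings stand in for real tensors.

```python
import ast

def required_names(source: str):
    """Return (operands, outputs) named by an epilogue source string."""
    fn = ast.parse(source).body[0]
    params = [a.arg for a in fn.args.args]
    if not params or params[0] != "accum":
        raise ValueError("accum must be the first argument of the epilogue function")
    ret = next(s for s in fn.body if isinstance(s, ast.Return))
    elts = ret.value.elts if isinstance(ret.value, ast.Tuple) else [ret.value]
    outputs = [e.id for e in elts]
    return params[1:], outputs  # accum is supplied by the kernel, not by the user

def check_kwargs(source: str, **kwargs) -> None:
    """Raise ValueError if any operand or output of the epilogue is missing from kwargs."""
    operands, outputs = required_names(source)
    missing = [name for name in operands + outputs if name not in kwargs]
    if missing:
        raise ValueError(f"missing epilogue arguments: {missing}")

my_epi_src = """
def my_epi(accum, alpha, C, beta):
    F = (accum * alpha) + (C * beta)
    D = relu(F)
    return D, F
"""
names = required_names(my_epi_src)
print(names)  # (['alpha', 'C', 'beta'], ['D', 'F'])

try:
    # F is returned by the epilogue but no buffer is provided for it.
    check_kwargs(my_epi_src, alpha=1.0, C="C_tensor", beta=2.0, D="D_tensor")
    kwargs_error = None
except ValueError as e:
    kwargs_error = str(e)
print(kwargs_error)  # missing epilogue arguments: ['F']
```

Because the required names come from the function itself, the same generic constructor signature can serve any epilogue.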

### Discovering kernels that support the epilogue pattern

The call to `get_kernels(args)` will return any kernels that support the provided `GemmArguments`.
Since the `GemmArguments` constructed above now include `EpilogueArguments`, returned kernels must support the provided epilogue.

Under the hood of `get_kernels()`, each `Kernel` class will determine in its `generate_kernels()` method whether it supports the provided `EpilogueArguments`.
It can do so by traversing the DAG that resulted from the construction of `EpilogueArguments` to find the operations that compose the epilogue.
Assuming that the `Kernel` can support the DAG, it must then add to the source for the kernel any operations needed to support the DAG.
An example of how this is done generically for an SM100 CuTe DSL GEMM is provided in `sm100_static_persistent_efc.py`.
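
The support check can be pictured as a set comparison between the operations an epilogue uses and the operations a kernel knows how to emit. The sketch below is purely illustrative: `ToyKernel`, `supports`, and `epilogue_ops` are hypothetical names, and for brevity it walks the Python AST directly, whereas the text above describes real kernels traversing the DAG attached to `EpilogueArguments`.

```python
import ast

def epilogue_ops(source: str) -> set:
    """Collect the names of arithmetic operators and called functions in the epilogue."""
    ops = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.BinOp):
            ops.add(type(node.op).__name__)  # e.g., "Add", "Mult"
        elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            ops.add(node.func.id)            # e.g., "relu"
    return ops

class ToyKernel:
    """Hypothetical kernel class that declares which epilogue ops it can emit code for."""
    def __init__(self, name, supported_ops):
        self.name = name
        self.supported_ops = set(supported_ops)

    def supports(self, source: str) -> bool:
        # The kernel is compatible only if every op in the epilogue is supported.
        return epilogue_ops(source) <= self.supported_ops

src = """
def epi(accum, alpha, C):
    F = (accum * alpha) + (C * 2.0)
    D = relu(F)
    return D, F
"""
toy_kernels = [
    ToyKernel("linear_only", {"Add", "Sub", "Mult", "Div"}),
    ToyKernel("with_activations", {"Add", "Sub", "Mult", "Div", "relu", "sigmoid", "tanh"}),
]
compatible = [k.name for k in toy_kernels if k.supports(src)]
print(compatible)  # ['with_activations']
```

In this toy model, an epilogue that uses an operation no kernel declares simply yields an empty list, which parallels `get_kernels` returning no kernels for unsupported operations.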

## Example epilogues
We now provide various examples of adding custom epilogues to GEMM kernels targeting SM100. A broader set of epilogue examples is available in `test_gemm_epilogue_fusion.py`.

### Auxiliary input and output tensors

In [None]:
from cutlass_api.fusion.activation import relu

def relu_aux_store(accum, alpha, C):
  F = (accum * alpha) + (C * 2.0) # Constant beta
  D = relu(F)
  return D, F

C = torch.randn((L, M, N), device="cuda", dtype=torch.float16)
alpha = 3.0
D = torch.empty((L, M, N), device="cuda", dtype=torch.float16)
F = torch.empty((L, M, N), device="cuda", dtype=torch.float16)

epi_args = EpilogueArguments(relu_aux_store, alpha=alpha, C=C, D=D, F=F)
args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)

D_ref, F_ref = relu_aux_store(A @ B, alpha, C)

torch.testing.assert_close(D, D_ref)
torch.testing.assert_close(F, F_ref)


### Keyword functions and returning accumulator

In [None]:
def relu_scale_return_acc(accum, alpha, beta, C, scale):
  F = relu((accum * alpha) + (C * beta))
  D = F * scale
  return D, F, accum

C = torch.randn((L, M, N), device="cuda", dtype=torch.float16)
alpha = 1.0
beta = 2.0
scale = 0.5
D = torch.empty((L, M, N), device="cuda", dtype=torch.float16)
F = torch.empty((L, M, N), device="cuda", dtype=torch.float16)
accum = torch.empty((L, M, N), device="cuda", dtype=torch.float32)

epi_args = EpilogueArguments(relu_scale_return_acc, alpha=alpha, beta=beta, C=C, scale=scale, D=D, F=F, accum=accum)
args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)

D_ref, F_ref, accum_ref = relu_scale_return_acc(A @ B, alpha, beta, C, scale)

torch.testing.assert_close(D, D_ref)
torch.testing.assert_close(F, F_ref)
torch.testing.assert_close(accum, accum_ref.to(accum.dtype))


### Passing a string representation of the function
`EpilogueArguments` can additionally be constructed using a string representation of the epilogue function:

In [None]:
epi_str = "def epi(accum, alpha, beta, C): F = (accum * alpha) + (C * beta); D = relu(F); return D, F"

C = torch.randn((L, M, N), device="cuda", dtype=torch.float16)
alpha = 1.0
beta = 0.5
D = torch.empty((L, M, N), device="cuda", dtype=torch.float16)
F = torch.empty((L, M, N), device="cuda", dtype=torch.float16)

epi_args = EpilogueArguments(epi_str, alpha=alpha, beta=beta, C=C, D=D, F=F)
args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)

F_ref = (A @ B) * alpha + (C * beta)
D_ref = torch.relu(F_ref)

torch.testing.assert_close(D, D_ref)
torch.testing.assert_close(F, F_ref)


### Failure examples
The following are examples of constructing `EpilogueArguments` that are expected to fail.

In [None]:
####################################################
# Epilogues must take in an accumulator
####################################################
def fail_missing_accum(alpha, beta, C):
  D = (C * beta)
  return D

try:
  epi_args = EpilogueArguments(fail_missing_accum, alpha=alpha, beta=beta, C=C, D=D)
  args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)
except Exception as e:
  # "accum must be an input to the epilogue function"
  print(e)


In [None]:
####################################################
# Epilogues must return an output named D
####################################################
def fail_missing_D(accum, alpha, beta, C):
  F = (accum * alpha) + (C * beta)
  return F

try:
  epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)
  args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)
except Exception as e:
  # "On SM90 or higher, D is expected to be a output node with 0 users to enable smem reuse between C and D, but got []"
  print(e)


In [None]:
####################################################
# Epilogues must use static single assignment (SSA)
####################################################
def fail_ssa(accum):
  tmp = accum * 2.0
  # Redefine tmp, which violates SSA form.
  tmp = tmp - 1.0
  D = tmp / 4.0
  return D, tmp

try:
  epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)
  args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)
except Exception as e:
  # "Variable 'tmp' cannot be defined twice."
  print(e)


In [None]:
####################################################
# Must provide all operands and outputs to
# EpilogueArguments
####################################################
def my_epi(accum, alpha, beta, C):
  F = (accum * alpha) + (C * beta)
  D = relu(F)
  return D

try:
  # Missing D
  epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)
  args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)
except Exception as e:
  # "Argument D is not provided in the kwargs of the EpilogueArguments constructor"
  print(e)

try:
  # Missing alpha
  epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)
  args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)
except Exception as e:
  # "Argument alpha is not provided in the kwargs of the EpilogueArguments constructor"
  print(e)
