# Adding a kernel to the CUTLASS API
The CUTLASS API is designed to make it easy for users to add their own kernels
so that they can be discovered and run under the uniform API. We welcome contributions
to the API in the form of "bringing your own kernel."

This example shows how to add a CuTe DSL kernel to the CUTLASS API.

## Bring your own implementation
Individuals wishing to add a CuTe DSL kernel to the CUTLASS API likely already
have the kernel written in CuTe DSL but have not yet implemented the interface the
API requires. Within the API, we separate these components into the "implementation"
(the kernel written in CuTe DSL) and the "interface" (the methods a kernel must
define to be used within the CUTLASS API).

For example, consider the following implementation of a simple FP64 GEMM kernel:

In [1]:
from collections.abc import Callable

import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute


class F64GemmKernelImplementation:
    def __init__(self, cta_tile_shape_mn: tuple[int, int]):
        self.cta_tile_shape_mn = cta_tile_shape_mn

    @cute.jit
    def __call__(
        self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor, stream: cuda.CUstream
    ):
        l, m, n = out.shape
        m_tiles = (m + self.cta_tile_shape_mn[0] - 1) // self.cta_tile_shape_mn[0]
        n_tiles = (n + self.cta_tile_shape_mn[1] - 1) // self.cta_tile_shape_mn[1]

        grid = (m_tiles, n_tiles, l)
        block = [self.cta_tile_shape_mn[0], self.cta_tile_shape_mn[1], 1]
        self.kernel(a, b, out).launch(grid=grid, block=block, stream=stream)

    @cute.kernel
    def kernel(self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor):
        l, m, n = out.shape
        k = a.shape[-1]
        m_tile, n_tile, l_idx = cute.arch.block_idx()
        tidx, tidy, _ = cute.arch.thread_idx()

        m_idx = m_tile * self.cta_tile_shape_mn[0] + tidx
        n_idx = n_tile * self.cta_tile_shape_mn[1] + tidy

        if m_idx < m and n_idx < n:
            out[l_idx, m_idx, n_idx] = cutlass.Float64(0)
            for k_idx in range(k):
                out[l_idx, m_idx, n_idx] += (
                    a[l_idx, m_idx, k_idx] * b[l_idx, k_idx, n_idx]
                )

The implementation is configurable via a `cta_tile_shape_mn` argument, which
sets the CTA (thread block) tile size in the M and N modes; each CTA launches one thread
per element of its output tile. A simple `cute.jit` function computes the grid and block
dimensions for the input problem from `cta_tile_shape_mn` and launches the kernel. The
`cute.kernel` itself has each thread compute a single output element by taking the dot
product of a row of A with a column of B.

This implementation is not performant (each thread accumulates directly into global memory, with no data reuse through shared memory), but it is kept simple for illustrative purposes.
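To make the launch geometry concrete, the following pure-Python sketch emulates, without a GPU, what the `cute.jit` launcher and `cute.kernel` above compute. It derives the grid with ceiling division and then visits every (block, thread) pair serially, letting each emulated thread compute one output element as a dot product. Nested lists stand in for `cute.Tensor`s, and the helper names (`ceil_div`, `launch_geometry`, `emulate_gemm`) are hypothetical; this illustrates the indexing only, not how CuTe DSL executes kernels.

```python
def ceil_div(a: int, b: int) -> int:
    """Number of tiles of size b needed to cover a elements."""
    return (a + b - 1) // b


def launch_geometry(shape_lmn, cta_tile_shape_mn):
    """Mirror the grid/block computation in F64GemmKernelImplementation.__call__."""
    l, m, n = shape_lmn
    grid = (ceil_div(m, cta_tile_shape_mn[0]), ceil_div(n, cta_tile_shape_mn[1]), l)
    block = (cta_tile_shape_mn[0], cta_tile_shape_mn[1], 1)
    return grid, block


def emulate_gemm(a, b, shape_lmn, k, cta_tile_shape_mn):
    """Run every emulated thread serially; each writes at most one output element."""
    l, m, n = shape_lmn
    out = [[[0.0] * n for _ in range(m)] for _ in range(l)]
    grid, block = launch_geometry(shape_lmn, cta_tile_shape_mn)
    for m_tile in range(grid[0]):
        for n_tile in range(grid[1]):
            for l_idx in range(grid[2]):
                for tidx in range(block[0]):
                    for tidy in range(block[1]):
                        m_idx = m_tile * cta_tile_shape_mn[0] + tidx
                        n_idx = n_tile * cta_tile_shape_mn[1] + tidy
                        # Threads past the edge of a partial tile do nothing
                        if m_idx < m and n_idx < n:
                            # Accumulate locally; the real kernel adds into `out`
                            # directly, which yields the same result.
                            acc = 0.0
                            for k_idx in range(k):
                                acc += a[l_idx][m_idx][k_idx] * b[l_idx][k_idx][n_idx]
                            out[l_idx][m_idx][n_idx] = acc
    return out


# The problem used later in this notebook, with 32x32 tiles
print(launch_geometry((1, 256, 1024), (32, 32)))  # ((8, 32, 1), (32, 32, 1))
```

Because the grid is rounded up, some threads in edge tiles fall outside the matrix, which is why the kernel guards its write with `m_idx < m and n_idx < n`.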

## Defining interface methods
As it currently stands, this GEMM kernel implementation cannot be used via the
CUTLASS API because it does not implement the interface methods. Specifically, kernels
within the CUTLASS API must inherit from and implement the `cutlass_api.Kernel`
abstract class, which defines the methods needed for common operations
performed when compiling and executing DSL kernels.

Certain providers (i.e., DSLs), such as CuTe DSL, provide an additional layer atop the
`cutlass_api.Kernel` class with utilities for kernels written with that provider.
For example, the CuTe DSL provider in the CUTLASS API defines
`cutlass_api.providers.cutedsl.kernel.CuteDslKernel`, which wraps `cute.compile()`
to supply the compile-time arguments needed for TVM-FFI when it is enabled.

We next walk through the steps of defining interface methods for this
implementation.

In [2]:
import itertools

import cutlass_api
from cutlass_api.arguments import GemmArguments
from cutlass_api.metadata import KernelMetadata
from cutlass_api.status import Status

We begin by defining a class to represent the kernel's interface.
As mentioned above, since this is a CuTe DSL kernel, our interface must
inherit from and implement `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`.

The class must additionally be registered with the CuTe DSL provider
via the `@CuTeDSLProvider.register` decorator so that the class
can be considered when discovering kernels.

In [None]:
@cutlass_api.providers.cutedsl.CuTeDSLProvider.register
class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):
    # Empty versions of interface methods. These will be implemented later, interspersed
    # with notebook markdown. Normally, one would define them inline with the class definition.
    def __init__(self, metadata: KernelMetadata):
        pass

    def _run(
        self,
        args: GemmArguments,
        artifact: cutlass_api.artifact.CompiledArtifact,
        stream,
        workspace=None,
    ):
        pass

    def compile(
        self, args: GemmArguments, cc: int = None
    ) -> cutlass_api.artifact.CompiledArtifact:
        pass

    @staticmethod
    def generate_kernels(
        metadata_filter, epilogue_args=None, cc=None
    ) -> list["F64GemmKernel"]:
        pass

    def _supports(self, args: GemmArguments) -> Status:
        pass

    def get_workspace_size(self, args: GemmArguments) -> int:
        pass
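Class decorators such as `@CuTeDSLProvider.register` are commonly implemented as a registry that records the class and returns it unchanged, so discovery can later iterate over every registered class. The sketch below shows that general pattern with a hypothetical `ToyProvider` and `ToyKernel`; it is an assumption for illustration, not the actual implementation of `CuTeDSLProvider.register`. For brevity, its filter receives a kernel name rather than a `KernelMetadata`.

```python
class ToyProvider:
    """Hypothetical provider that keeps a registry of kernel classes."""

    _kernel_classes: list[type] = []

    @classmethod
    def register(cls, kernel_class: type) -> type:
        # Record the class for later discovery, then return it unchanged
        # so the decorated name still refers to the original class.
        cls._kernel_classes.append(kernel_class)
        return kernel_class

    @classmethod
    def discover(cls, metadata_filter):
        """Ask every registered class for its instances that pass the filter."""
        kernels = []
        for kernel_class in cls._kernel_classes:
            kernels.extend(kernel_class.generate_kernels(metadata_filter))
        return kernels


@ToyProvider.register
class ToyKernel:
    def __init__(self, name: str):
        self.name = name

    @staticmethod
    def generate_kernels(metadata_filter):
        names = ["ToyKernel_tile32x32", "ToyKernel_tile16x16"]
        return [ToyKernel(n) for n in names if metadata_filter(n)]


found = ToyProvider.discover(lambda name: "32x32" in name)
print([k.name for k in found])  # ['ToyKernel_tile32x32']
```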

The `__init__` method of the class takes a `KernelMetadata` object,
extracts `cta_tile_shape_mn` from the first two modes of its design's tile shape,
and uses it to construct the kernel implementation object. We discuss later how
the `KernelMetadata` object passed in here is constructed:

In [4]:
def __init__(self, metadata: KernelMetadata):
    # Use the explicit two-argument super() because zero-argument super() only works in functions defined inside the class body.
    super(F64GemmKernel, self).__init__(metadata)
    cta_tile_shape_mn = metadata.design.tile_shape[:2]
    self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)
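The comment about `super()` reflects a general Python rule: zero-argument `super()` relies on a `__class__` cell that the compiler creates only for functions defined inside a class body. A function defined at module level and attached to a class afterward has no such cell, so it must name the class explicitly. The classes below (`Base`, `ChildExplicit`, `ChildImplicit`) are hypothetical stand-ins that demonstrate both cases.

```python
class Base:
    def __init__(self, metadata):
        self.metadata = metadata


class ChildExplicit(Base):
    pass


class ChildImplicit(Base):
    pass


def explicit_init(self, metadata):
    # Two-argument form names the class, so no __class__ cell is needed
    super(ChildExplicit, self).__init__(metadata)
    self.tile_shape_mn = metadata["tile_shape"][:2]


def implicit_init(self, metadata):
    # Zero-argument form: defined outside any class body, so no __class__ cell
    super().__init__(metadata)


# Attach both after the fact, as the notebook does for F64GemmKernel
ChildExplicit.__init__ = explicit_init
ChildImplicit.__init__ = implicit_init

child = ChildExplicit({"tile_shape": (32, 32, 1)})
print(child.tile_shape_mn)  # (32, 32)

try:
    ChildImplicit({"tile_shape": (16, 16, 1)})
except RuntimeError as err:
    zero_arg_error = err  # zero-argument super() raises RuntimeError here
```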

### Defining interfaces for compilation and execution
The interfaces needed for compilation and execution are simple.

The `compile` method constructs a placeholder stream object and passes it,
along with the relevant arguments, to `self.cute_compile`. This utility,
defined in the `CuteDslKernel` abstract class, passes the compilation flags
required by certain options (e.g., TVM-FFI) to `cute.compile`. The result is
wrapped in a `CompiledArtifact`.

In [5]:
def compile(
    self, args: GemmArguments, cc: int = None
) -> cutlass_api.artifact.CompiledArtifact:
    stream = cutlass.cute.runtime.make_fake_stream()
    compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)
    return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)
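Separating `compile` from `_run`, with the artifact passed between them, means a kernel can be compiled once and the resulting artifact reused across many runs. A minimal stand-in, assuming only the two pieces this notebook relies on (the compiled object, read back in `_run` as `compiled_obj`, and the kernel that produced it), might look like the following. `ToyArtifact`, `CountingKernel`, and the field name `kernel` are hypothetical, and the real `CompiledArtifact` may store more.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyArtifact:
    """Hypothetical stand-in for cutlass_api.artifact.CompiledArtifact."""

    compiled_obj: Callable[..., Any]  # what compile() produced; read back in _run
    kernel: Any  # the kernel that produced it (hypothetical field name)


class CountingKernel:
    """Toy kernel that counts how many times it has been compiled."""

    def __init__(self):
        self.compile_count = 0

    def compile(self) -> ToyArtifact:
        # Stand-in for self.cute_compile(...): the expensive step
        self.compile_count += 1
        return ToyArtifact(compiled_obj=lambda x, y: x + y, kernel=self)

    def _run(self, args, artifact: ToyArtifact):
        # Reuse the compiled object carried by the artifact
        return artifact.compiled_obj(*args)


kernel = CountingKernel()
artifact = kernel.compile()
results = [kernel._run((i, 1), artifact) for i in range(3)]
print(results)               # [1, 2, 3]
print(kernel.compile_count)  # 1
```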

Kernel authors implement `_run` rather than the public `run` method
(no leading underscore) that callers use to execute kernels. `_run` (1) converts
`stream` to a CUDA stream with `cutlass_api.utils.to_cuda_stream`, (2) retrieves the
compiled JIT function from `artifact`, and (3) invokes it via `self.cute_run` with
the operands extracted from `args`.

In [6]:
def _run(
    self,
    args: GemmArguments,
    artifact: cutlass_api.artifact.CompiledArtifact,
    stream,
    workspace=None,
):
    stream = cutlass_api.utils.to_cuda_stream(stream)
    compiled_gemm = artifact.compiled_obj
    self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)
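The public-`run`/private-`_run` split follows the template-method pattern: the base class owns the public entry point and calls a hook that subclasses fill in. The sketch below assumes, for illustration only, that the public `run` compiles on demand when no artifact is supplied (the notebook later calls `kernels[0].run(args)` without one) before delegating to `_run`. `ToyKernelBase` and `ToyAddKernel` are hypothetical, and what `cutlass_api.Kernel.run` actually does beyond reaching `_run` is not shown in this notebook.

```python
class ToyKernelBase:
    """Hypothetical base class: public run() wraps the subclass's _run() hook."""

    def run(self, args, artifact=None, stream=None):
        if artifact is None:
            # Assumed behavior for illustration: compile on demand
            artifact = self.compile(args)
        return self._run(args, artifact, stream)

    def compile(self, args):
        raise NotImplementedError

    def _run(self, args, artifact, stream):
        raise NotImplementedError


class ToyAddKernel(ToyKernelBase):
    def compile(self, args):
        # Stand-in for a compiled JIT function
        return lambda a, b: [x + y for x, y in zip(a, b)]

    def _run(self, args, artifact, stream):
        # (1) extract operands from args, (2) call the compiled function
        return artifact(args["A"], args["B"])


kernel = ToyAddKernel()
print(kernel.run({"A": [1, 2], "B": [10, 20]}))  # [11, 22]
```

Subclasses only describe how to compile and how to launch; any shared bookkeeping stays in one place in the base class.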

Finally, since this kernel does not require any device workspace,
we give it a simple `get_workspace_size` method that always returns 0.

In [7]:
def get_workspace_size(self, args: GemmArguments) -> int:
    return 0

### Defining interfaces for kernel generation
We have implemented the interfaces needed for constructing the kernel
interface, compiling it, and running it. We now must implement methods for
generating the possible configurations of this kernel that the kernel
class itself supports. This will be used in kernel discovery (e.g., via
`cutlass_api.get_kernels()`).

To do so, we write the `generate_kernels` method. This takes a predicate
`metadata_filter` (a function that takes a `KernelMetadata` and returns a `bool`),
epilogue arguments `epilogue_args`, and a compute capability `cc`. It returns a
list of all instances of the kernel interface that support the `epilogue_args`,
are compatible with the given `cc`, and pass the `metadata_filter`.

Each `Kernel` class is responsible for defining which configurations (instances) of it are valid.
In this example, the valid configurations are the cross product of row- or column-major strides for each of the three operands and two preset tile shapes.
We loop over these knobs and create a `KernelMetadata` for each unique configuration.
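The size of this design space follows directly from the cross product: two tile shapes times two stride choices for each of the three operands gives 2 × 2³ = 16 configurations. The standalone sketch below reproduces only the enumeration loop, with plain tuples in place of `KernelMetadata`. We read the stride tuples as markers of which mode has unit stride; that reading is our interpretation of the example, not a documented convention.

```python
import itertools

supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]
row_major_stride = (0, 0, 1)  # unit stride in the last mode
col_major_stride = (0, 1, 0)  # unit stride in the middle mode

# Same nesting as generate_kernels: tile shapes outer, stride combos inner
configs = [
    (tile_shape, stride_A, stride_B, stride_out)
    for tile_shape in supported_tile_shapes
    for stride_A, stride_B, stride_out in itertools.product(
        [row_major_stride, col_major_stride], repeat=3
    )
]

print(len(configs))  # 16
print(configs[0])    # ((32, 32, 1), (0, 0, 1), (0, 0, 1), (0, 0, 1))
```

The ordering explains the printed names later in the notebook: for all-row-major arguments, the first matching configuration uses the 32x32 tile and the next uses the 16x16 tile.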

The `generate_kernels` method must additionally pass each generated `KernelMetadata` through `metadata_filter`, a user-provided predicate for pruning generated configurations, and keep only the kernels for which it returns `True`.
More information on `metadata_filter` is provided in other examples.
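Because a `metadata_filter` is just a predicate, filters compose with ordinary boolean logic. The sketch below uses a hypothetical `ToyMetadata` dataclass (only a kernel name and tile shape) in place of the real `KernelMetadata`, plus hypothetical helpers `all_of` and `tile_m_is`, to show how a filter prunes generated configurations before they become kernel instances.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ToyMetadata:
    """Hypothetical stand-in for KernelMetadata with just two fields."""

    kernel_name: str
    tile_shape: tuple[int, int, int]


Filter = Callable[[ToyMetadata], bool]


def all_of(*filters: Filter) -> Filter:
    """Combine filters into one that accepts only what every filter accepts."""
    return lambda metadata: all(f(metadata) for f in filters)


def is_f64gemm(metadata: ToyMetadata) -> bool:
    return metadata.kernel_name.startswith("F64GemmKernel")


def tile_m_is(m: int) -> Filter:
    """Build a filter that requires a specific CTA tile size in the M mode."""
    return lambda metadata: metadata.tile_shape[0] == m


candidates = [
    ToyMetadata("F64GemmKernel_tile32x32_ttt", (32, 32, 1)),
    ToyMetadata("F64GemmKernel_tile16x16_ttt", (16, 16, 1)),
]


def survivors(metadata_filter: Filter) -> list[str]:
    # Mirrors the `if metadata_filter(metadata):` check in generate_kernels
    return [md.kernel_name for md in candidates if metadata_filter(md)]


print(survivors(all_of(is_f64gemm, tile_m_is(32))))   # ['F64GemmKernel_tile32x32_ttt']
print(survivors(all_of(is_f64gemm, tile_m_is(256))))  # []
```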

In [8]:
@staticmethod
def generate_kernels(
    metadata_filter: Callable[[KernelMetadata], bool],
    epilogue_args: cutlass_api.arguments.EpilogueArguments = None,
    cc: int = None,
) -> list["F64GemmKernel"]:
    # The tile shapes this kernel supports/exposes
    supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]

    if epilogue_args is not None:
        return []

    row_major_stride = (0, 0, 1)
    col_major_stride = (0, 1, 0)
    stride_combos = list(
        itertools.product([row_major_stride, col_major_stride], repeat=3)
    )
    divisibility = 1

    kernels = []
    for tile_shape in supported_tile_shapes:
        design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))
        for stride_A, stride_B, stride_out in stride_combos:
            # Create TensorAttributes for A, B, and out tensors
            a_attrs = cutlass_api.metadata.TensorAttributes(
                cutlass.Float64, stride_A, divisibility
            )
            b_attrs = cutlass_api.metadata.TensorAttributes(
                cutlass.Float64, stride_B, divisibility
            )
            out_attrs = cutlass_api.metadata.TensorAttributes(
                cutlass.Float64, stride_out, divisibility
            )
            layout_str = cutlass_api.utils.strides_to_layout_string(
                stride_A, stride_B, stride_out
            )

            name = f"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}"

            metadata = KernelMetadata(
                kernel_name=name,
                kernel_class=F64GemmKernel,
                operands=cutlass_api.metadata.GemmOperandsMetadata(
                    a_attrs, b_attrs, out_attrs, accumulator_type=cutlass.Float64
                ),
                design=design_metadata,
                min_cc=0,
            )

            if metadata_filter(metadata):
                kernels.append(F64GemmKernel(metadata))

    return kernels

We also add a method indicating whether a given kernel instance
supports a set of arguments. The top-level `Kernel.supports` method
already verifies that the `args` passed in match the metadata with which
this `Kernel` instance was constructed. Here, we define additional
checks specific to this kernel, such as requiring all operands
to be of rank 3:

In [9]:
def _supports(self, args: GemmArguments) -> Status:
    if not (
        len(args.A.shape) == 3  # A should be (L, M, K)
        and len(args.B.shape) == 3  # B should be (L, K, N)
        and len(args.out.shape) == 3  # out should be (L, M, N)
    ):
        return Status.fail("All operands must be rank 3.")
    return Status.success()
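The division of labor between `Kernel.supports` and `_supports` can be pictured as a two-stage check: a generic comparison against the instance's metadata, followed by the kernel-specific hook. The sketch below is hypothetical throughout: `ToyStatus`, the dtype-only metadata check, and plain shape tuples stand in for `Status`, the real metadata comparison, and tensors.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyStatus:
    """Hypothetical stand-in for cutlass_api.status.Status."""

    ok: bool
    message: str = ""


class ToyGemmKernel:
    def __init__(self, dtype: str):
        self.dtype = dtype  # the only "metadata" this toy kernel records

    def supports(self, shapes: dict, dtype: str) -> ToyStatus:
        # Stage 1 (generic): do the arguments match this instance's metadata?
        if dtype != self.dtype:
            return ToyStatus(False, f"expected dtype {self.dtype}, got {dtype}")
        # Stage 2 (kernel-specific): defer to the subclass-style hook
        return self._supports(shapes)

    def _supports(self, shapes: dict) -> ToyStatus:
        if not all(len(shapes[name]) == 3 for name in ("A", "B", "out")):
            return ToyStatus(False, "All operands must be rank 3.")
        return ToyStatus(True)


kernel = ToyGemmKernel("float64")
rank3 = {"A": (1, 256, 128), "B": (1, 128, 1024), "out": (1, 256, 1024)}
rank2 = {"A": (256, 128), "B": (128, 1024), "out": (256, 1024)}
print(kernel.supports(rank3, "float64").ok)       # True
print(kernel.supports(rank2, "float64").message)  # All operands must be rank 3.
print(kernel.supports(rank3, "float32").ok)       # False
```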

In [10]:
# Assign methods to the class because we interspersed notebook markdown
# with the class definition. This is not needed in a real implementation.
F64GemmKernel.__init__ = __init__
F64GemmKernel.compile = compile
F64GemmKernel._run = _run
F64GemmKernel._supports = _supports
F64GemmKernel.generate_kernels = generate_kernels
F64GemmKernel.get_workspace_size = get_workspace_size
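Assigning functions to class attributes after the class body works because Python looks methods up on the class at call time: a plain function becomes a bound method when accessed through an instance, and a `staticmethod` object keeps its static behavior once stored on the class (which is why `generate_kernels` above is decorated with `@staticmethod` at module level). A small self-contained demonstration with a hypothetical `Widget` class:

```python
class Widget:
    pass


def describe(self) -> str:
    return f"Widget with size {self.size}"


@staticmethod
def make_all(sizes):
    widgets = []
    for size in sizes:
        w = Widget()
        w.size = size
        widgets.append(w)
    return widgets


# Attach after the class body, as the notebook does for F64GemmKernel
Widget.describe = describe  # becomes a regular instance method
Widget.make_all = make_all  # stays a static method

widgets = Widget.make_all([16, 32])
print([w.describe() for w in widgets])  # ['Widget with size 16', 'Widget with size 32']
```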

## Discovering instances of the kernel and using them
The CUTLASS API is now prepared to discover instances of this
kernel interface, just as in previous examples.

We make one small modification: we pass a `metadata_filter`
that ensures all returned kernels are instances of the
`F64GemmKernel` class we just implemented. This is needed
only for example/testing purposes.

In [11]:
import torch

torch.manual_seed(2025)

L, M, N, K = 1, 256, 1024, 128
A = torch.randn(L, M, K, device="cuda", dtype=torch.float64)
B = torch.randn(L, K, N, device="cuda", dtype=torch.float64)
out = torch.empty(L, M, N, device="cuda", dtype=torch.float64)

args = GemmArguments(A, B, out, accumulator_type=torch.float64)


def is_f64gemm_kernel(metadata):
    return metadata.kernel_class == F64GemmKernel


kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)

We can print the names of the first two kernels to see that
they come from our newly added kernel.

In [12]:
print(kernels[0].metadata.kernel_name)
print(kernels[1].metadata.kernel_name)

F64GemmKernel_tile32x32_ttt
F64GemmKernel_tile16x16_ttt


We can run an instance of our kernel and check its output against PyTorch's matrix multiply:

In [13]:
kernels[0].run(args)
torch.testing.assert_close(out, A @ B)

We can also probe the limits of our kernel's design space by providing a
metadata filter that requires a CTA tile size of 256 in the M mode, which is not
exposed by the `generate_kernels` method of our newly added kernel. We expect
no kernels of type `F64GemmKernel` to be returned.

In [14]:
def my_filter(metadata):
    return (
        is_f64gemm_kernel(metadata)
        and isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata)
        and metadata.design.tile_shape[0] == 256
    )


kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)

# No kernels should be found
assert len(kernels_ctam256) == 0

## A note on the directory structure for contributed kernels
This example showed how to define a kernel inline and add it to the
API for example purposes. A kernel added this way does not need to live
within the API's source code.

We welcome contributions of kernels that do live within the CUTLASS
API's repository as well.

Kernels in the repository are organized based on the "provider" in which they are
authored (i.e., the DSL). All kernels corresponding to a given
provider live in a directory corresponding to that provider under
`cutlass_api/providers`. For example, CuTe DSL kernels live
under `cutlass_api/providers/cutedsl`.

Each provider can organize kernels differently. For CuTe DSL,
kernels are further split based on their logical operation,
with GEMM kernels under the `cutlass_api/providers/cutedsl/gemm`
directory.

We recommend separating the implementation of the kernel from
its interface not just by using separate classes, as done in
this example, but also by separating the implementation and
interface into separate files. This makes it easier to update
each without affecting the other.

For example, CuTe DSL GEMM kernels have the following organization:
```text
cutlass_api/
  providers/
    cutedsl/
      gemm/
        sm100_static_persistent.py
        implementations/
          sm100_static_persistent_impl.py
```