# Basic GEMM using CUTLASS Python API

The CUTLASS API provides a uniform interface for discovering, compiling, and running GPU kernels from various DSL sources.

This notebook walks through a minimal GEMM (General Matrix-Matrix Multiplication) example and introduces the core concepts of the API.

In [1]:
import torch

import cutlass

import cutlass_api

if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):
    print(
        f"This notebook requires a GPU with compute capability 100 or 103.\n{status.error}"
    )
    import sys

    sys.exit(0)

## Running your first kernel

### Setting up arguments

CUTLASS API has first-class support for PyTorch tensors. We start by creating torch tensors that will be operands to a matrix multiplication.

In [2]:
M, N, K, L = 128, 256, 64, 2
ab_type = torch.float16
out_type = torch.float32
acc_type = torch.float32

A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_type)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_type)
out = torch.empty((L, M, N), device="cuda", dtype=out_type)

reference = (A @ B).to(out.dtype)
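
The reference above is a batched GEMM: for each batch index `l`, an `M x K` matrix is multiplied by a `K x N` matrix. The following pure-Python sketch spells out those semantics on tiny nested lists, with no GPU or CUTLASS required. `batched_gemm`, `A_small`, and `B_small` are hypothetical names for illustration only; they are not part of CUTLASS API.

```python
def batched_gemm(A, B):
    """Reference batched GEMM: out[l][m][n] = sum over k of A[l][m][k] * B[l][k][n].

    A has shape (L, M, K) and B has shape (L, K, N), given as nested lists.
    """
    L, M, K = len(A), len(A[0]), len(A[0][0])
    N = len(B[0][0])
    assert len(B) == L and len(B[0]) == K, "batch or K dimensions do not match"
    out = [[[0.0] * N for _ in range(M)] for _ in range(L)]
    for l in range(L):
        for m in range(M):
            for n in range(N):
                # Accumulate in a Python float (double precision), playing the
                # role that accumulator_type plays in the GPU kernel.
                acc = 0.0
                for k in range(K):
                    acc += A[l][m][k] * B[l][k][n]
                out[l][m][n] = acc
    return out


# Tiny example with L=1, M=2, K=2, N=2
A_small = [[[1, 2], [3, 4]]]
B_small = [[[5, 6], [7, 8]]]
print(batched_gemm(A_small, B_small))  # [[[19.0, 22.0], [43.0, 50.0]]]
```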

We then create a `GemmArguments` object. This object specifies:
1. what logical operation we want to perform (a GEMM)
2. on which operands we want to perform that operation (`A`, `B`, `out` as declared above)
3. the type used for intermediate accumulation (`acc_type`)

In [3]:
args = cutlass_api.arguments.GemmArguments(A=A, B=B, out=out, accumulator_type=acc_type)

### Kernel discovery

We now need to find kernels that can perform the operation we expressed in `args`.

The simplest way to do so is to use `get_kernels(args)`. It searches the set of kernels pre-registered in the library and returns the subset that can run with the given `args`.

All of these kernels are functionally equivalent, though they may differ in design and performance characteristics. Here, we arbitrarily pick the first returned kernel to execute.
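
Conceptually, discovery is a filter over a registry of kernel instances, keeping only those whose `supports` check accepts the arguments. The sketch below illustrates that idea; `ToyKernel`, `REGISTRY`, and `toy_get_kernels` are hypothetical, and the library's real internals (and the properties a real kernel checks) may differ considerably.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyKernel:
    """Hypothetical stand-in for one registered kernel instance."""

    name: str
    a_dtype: str
    b_dtype: str

    def supports(self, args: dict) -> bool:
        # A real kernel would also check layouts, alignment, shapes, etc.
        return args["a_dtype"] == self.a_dtype and args["b_dtype"] == self.b_dtype


REGISTRY = [
    ToyKernel("gemm_f16_f16_v1", "float16", "float16"),
    ToyKernel("gemm_f16_f16_v2", "float16", "float16"),
    ToyKernel("gemm_bf16_bf16", "bfloat16", "bfloat16"),
]


def toy_get_kernels(args=None):
    """Return every registered kernel, or only those that support `args`."""
    if args is None:
        return list(REGISTRY)
    return [k for k in REGISTRY if k.supports(args)]


found = toy_get_kernels({"a_dtype": "float16", "b_dtype": "float16"})
print([k.name for k in found])  # ['gemm_f16_f16_v1', 'gemm_f16_f16_v2']
```

Both float16 entries are returned, which mirrors why the real `get_kernels(args)` can yield many interchangeable candidates.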

In [4]:
kernels = cutlass_api.get_kernels(args)
assert kernels, "No kernels found for the given arguments!"

kernel = kernels[0]

### Running the kernel

Running the kernel is as simple as `kernel.run(args)`.

This implicitly JIT-compiles the kernel and launches it on the GPU with the given arguments.

In [5]:
kernel.run(args)

torch.testing.assert_close(out, reference)

You can also compile the kernel explicitly and pass the resulting artifact to `kernel.run`, avoiding JIT compilation on subsequent invocations. This is described in more detail below.

In [None]:
artifact = kernel.compile(args)
kernel.run(args, compiled_artifact=artifact)
torch.testing.assert_close(out, reference)

---

### Understanding the core interfaces

#### 1. `RuntimeArguments` / `GemmArguments`

`RuntimeArguments` describe the operation a user wants to perform, along with all the runtime operands and other runtime parameters it needs.
This includes the primary operands of the operation, as well as any custom epilogue fusions and runtime performance knobs.

We provide built-in subtypes of `RuntimeArguments` for common operations (e.g., GEMM and elementwise operations; more are covered later).

For instance, `GemmArguments` is a type of `RuntimeArguments`:

```python
@dataclass
class GemmArguments(RuntimeArguments):
    A: TensorLike
    B: TensorLike
    out: TensorLike
    accumulator_type: NumericLike
```

`GemmArguments` conveys:
* We want to perform a dense GEMM operation (`out = A @ B`)
* We want to perform it on the operands `A`, `B`, and `out`, accumulating intermediate results in `accumulator_type`
* Optionally, a custom epilogue can be fused on top of the base GEMM. Some kernels also accept runtime performance controls, which can be specified here as well. These are discussed in detail in other tutorials.

In other words, `GemmArguments` specifies the desired functionality in a kernel-agnostic way.

`RuntimeArguments` can be constructed from any `TensorLike` object. This includes `torch.Tensor`, `cute.Tensor`, and any other DLPack-compatible tensor.
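
"DLPack-compatible" is a duck-typed notion: an object qualifies if it implements the `__dlpack__` and `__dlpack_device__` methods. The sketch below expresses that check with a `typing.Protocol`. Whether CUTLASS API defines `TensorLike` this way is an assumption; `DLPackCompatible` and `FakeTensor` are hypothetical names used only to illustrate the idea.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DLPackCompatible(Protocol):
    """Structural type: any object exposing the two DLPack dunder methods."""

    def __dlpack__(self, *, stream: Any = None) -> Any: ...

    def __dlpack_device__(self) -> tuple[int, int]: ...


class FakeTensor:
    """Hypothetical object that implements both DLPack methods."""

    def __dlpack__(self, *, stream=None):
        raise NotImplementedError("illustration only")

    def __dlpack_device__(self):
        return (1, 0)  # (device type, device id); 1 is kDLCPU in DLPack


# isinstance only checks that the methods exist; it does not call them.
print(isinstance(FakeTensor(), DLPackCompatible))  # True
print(isinstance([1, 2, 3], DLPackCompatible))     # False
```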

#### 2. Kernel Discovery

Many kernels written in the CUTLASS DSLs are registered with, and discoverable via, the CUTLASS API.

These include kernels for various operations (GEMM, elementwise operations, ...) that implement various algorithms and architecture features. Each implementation is, in turn, instantiated in many configurations with different combinations of operand types, layouts, tile sizes, etc.
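
The number of instances grows multiplicatively with each configuration axis. The sketch below enumerates instance names for one hypothetical implementation with `itertools.product`. The axes and their values are invented for illustration; real kernels expose different (and more) axes, and not every combination is necessarily valid.

```python
import itertools

# Hypothetical configuration axes for ONE kernel implementation.
ab_dtypes = ["Float16", "BFloat16", "Float8E4M3FN"]
layouts = ["ttt", "tnt", "ntt", "nnt"]
tile_shapes = [(128, 64, 64), (128, 128, 64), (256, 128, 64)]
cluster_shapes = [(1, 1, 1), (2, 1, 1)]


def shape_str(shape):
    """Format a shape tuple like (128, 64, 64) as '128x64x64'."""
    return "x".join(map(str, shape))


instances = [
    f"gemm_{dt}_{layout}_tile{shape_str(tile)}_cluster{shape_str(cluster)}"
    for dt, layout, tile, cluster in itertools.product(
        ab_dtypes, layouts, tile_shapes, cluster_shapes
    )
]
print(len(instances))  # 72, i.e. 3 * 4 * 3 * 2
print(instances[0])    # gemm_Float16_ttt_tile128x64x64_cluster1x1x1
```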

In the previous step, we used `GemmArguments` to specify our desired GEMM in a kernel-agnostic way. Now we find kernels that can fulfill that functionality. A subset of the available kernels perform GEMM, and a subset of _those_ support the properties of the specific operands we are currently using.

In [6]:
# get_kernels() fetches all kernels when called without args
all_kernels = cutlass_api.get_kernels()
print(f"A total of {len(all_kernels)} kernel instances are available.")

# we can limit the search to kernels supporting given args
kernels = cutlass_api.get_kernels(args)
print(f"Of these, {len(kernels)} support the given arguments.")

kernel = kernels[0]
print(f"Picked kernel with name: {kernel.metadata.kernel_name}")

A total of 107616 kernel instances are available.
Of these, 350 support the given arguments.
Picked kernel with name: cutedsl.PersistentDenseGemmKernel_sm100_ttt_AFloat16_BFloat16_outFloat32_accFloat32_2cta_cluster2x1x1_tile128x32x256_tma_store


#### 3. `Kernel` execution

Once we have selected a kernel, we are ready to execute it. As shown earlier, the simplest way to do this is `kernel.run(args)`.

This method does the following:
* verify that the kernel supports the given `args`
* JIT-compile the kernel
* launch the compiled kernel function
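
The three steps can be pictured as a single method that optionally skips the first two. The following is a minimal sketch assuming that structure; `ToyStagedKernel` is hypothetical, and although its `compiled_artifact` and `assume_supported_args` parameters mirror the ones used in this notebook, the real `Kernel.run` may be implemented differently.

```python
class ToyStagedKernel:
    """Hypothetical kernel showing the three stages of a run() call."""

    def __init__(self, supported_dtype):
        self.supported_dtype = supported_dtype
        self.log = []  # records which stages actually executed

    def supports(self, args):
        return args.get("dtype") == self.supported_dtype

    def compile(self, args):
        self.log.append("compile")
        # The "artifact" is just a callable standing in for compiled device code.
        return lambda a: self.log.append("launch")

    def run(self, args, compiled_artifact=None, assume_supported_args=False):
        # 1. Verify support, unless the caller has already done so.
        if not assume_supported_args and not self.supports(args):
            raise ValueError(f"unsupported args: {args}")
        # 2. JIT-compile only if no precompiled artifact was supplied.
        if compiled_artifact is None:
            compiled_artifact = self.compile(args)
        # 3. Launch the compiled function.
        compiled_artifact(args)


toy = ToyStagedKernel("float16")
toy.run({"dtype": "float16"})
print(toy.log)  # ['compile', 'launch']
```

Passing a precompiled artifact skips step 2 on every call, which is why compiling once up front avoids repeated JIT work.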

Users can perform these steps individually for more control:

* `kernel.supports(args)` checks whether the kernel supports the given `args`
    * this is relevant when the kernel was not selected by `get_kernels(args)` for these specific `args`

In [7]:
supported = kernel.supports(args)
assert supported

If the arguments are not supported by this kernel, `supports` returns a falsy `Status` object whose `error` attribute explains why.
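
The `if not (status := ...)` pattern in the next cell works because a failed status evaluates as false while still carrying an error message. A minimal sketch of such a type via `__bool__`, assuming that truthiness-based design; `ToyStatus` and `check_a_dtype` are hypothetical, and the real `Status` class may be implemented differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyStatus:
    """Hypothetical status: truthy on success, falsy (with an error) on failure."""

    error: Optional[str] = None

    def __bool__(self) -> bool:
        return self.error is None


def check_a_dtype(dtype: str) -> ToyStatus:
    """Hypothetical support check on operand A's element type."""
    if dtype != "Float16":
        return ToyStatus(
            error=f"Operand `A` is unsupported: Expected element type Float16, got {dtype}"
        )
    return ToyStatus()


if not (toy_status := check_a_dtype("BFloat16")):
    print(toy_status.error)
```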

In [8]:
unsupported_args = cutlass_api.arguments.GemmArguments(
    A=A.to(torch.bfloat16), B=B, out=out, accumulator_type=acc_type
)
if not (status := kernel.supports(unsupported_args)):
    print(status.error)

assert not status

Operand `A` is unsupported: Expected element type Float16, got BFloat16


* `kernel.compile(args)` compiles the kernel and returns a `CompiledArtifact`.

This compiled artifact is a lightweight wrapper over the result of compiling a kernel (e.g., via `cute.compile()`).

For just-in-time compilation, the compiled artifact can be used right away.
In the future, we will support optionally serializing it for ahead-of-time compilation and deserializing it in a different context.


In [9]:
compiled_artifact = kernel.compile(args)

* `kernel.run(args)` launches the compiled kernel function. This example uses:
    * the precompiled artifact
    * a custom stream to launch on
    * `assume_supported_args=True`, which bypasses the support check already performed above

In [10]:
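# zero the output to avoid testing stale output
out.zero_()

# Launch on a side stream: first make it wait for work already queued on the
# current stream (the zero_ above), then synchronize before checking results.
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
kernel.run(
    args,
    compiled_artifact,
    stream=stream,
    assume_supported_args=True,
)
stream.synchronize()
torch.testing.assert_close(out, reference)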

Some kernels may also require a device "workspace": an additional buffer used for bookkeeping, temporary results, etc.
Its size can be queried with `kernel.get_workspace_size(args)`; for most kernels it is 0.
If a kernel does report a non-zero workspace size, a buffer of at least that size must be provided. Without it, the kernel's behavior is undefined.
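
The workspace contract ("at least that size") means an existing, larger buffer can be reused across launches. A minimal host-side sketch of that rule; `prepare_workspace` is a hypothetical helper, and `bytearray` stands in for a device allocation such as the `torch.empty(..., device="cuda")` call used in the next cell.

```python
def prepare_workspace(required_bytes: int, provided=None):
    """Return a buffer satisfying a kernel's workspace requirement.

    Hypothetical helper: reuses `provided` when it is large enough, otherwise
    allocates a new zero-filled host buffer (a stand-in for device memory).
    """
    if required_bytes < 0:
        raise ValueError("workspace size cannot be negative")
    if provided is not None and len(provided) >= required_bytes:
        return provided  # any buffer of at least the required size is acceptable
    return bytearray(required_bytes)


print(len(prepare_workspace(0)))                # 0: the common case
shared = bytearray(1024)
print(prepare_workspace(256, shared) is shared)  # True: large enough, reused
```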

In [11]:
workspace_size = kernel.get_workspace_size(args)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.int8)

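out.zero_()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())  # order after zero_()
kernel.run(args, compiled_artifact, stream=stream, workspace=workspace)
stream.synchronize()  # wait for the kernel before reading `out`
torch.testing.assert_close(out, reference)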

### Advanced: Filtering on Metadata

Using `RuntimeArguments` to search for supporting kernels is a convenient way to discover kernels: users directly specify their desired functionality, and `get_kernels()` finds the kernels that support it.
This covers all logical operands of an operation, as well as epilogue fusions and performance controls (shown in later examples).

However, there may be cases where users want more advanced ways to query kernels, such as:
* when the desired properties cannot be expressed through runtime arguments
    * the simplest case is searching for a kernel with a specific name, a specific class, etc.
    * another is searching for a kernel's static properties, such as tile size or cluster size
* when `RuntimeArguments` are not available, or when you want to generate and pre-compile a broader set of kernels

For such cases, we provide more advanced filtering based on `KernelMetadata`.

`KernelMetadata` captures a wide variety of properties of a `Kernel`.

These include properties describing a kernel's functional support (such as operand types, layouts, and alignments), as well as its architectural and design choices and performance characteristics (such as tile size and scheduling).

For flexibility, different kernels may use different subclasses for `metadata.operands`, `metadata.design`, and `metadata.epilogue`; the subclass type itself can therefore also identify a kernel's characteristics.

```python
@dataclass
class KernelMetadata:
    kernel_name: str
    kernel_class: type["Kernel"]
    min_cc: int
    operands: OperandsMetadata
    design: DesignMetadata | None = None
    epilogue: EpilogueMetadata | None = None
```

Every unique kernel instance can be distinguished by its metadata.
Metadata can be used to filter kernels, in addition to `RuntimeArguments`, by passing a custom `metadata_filter` to `get_kernels()`.

Here, we get all kernels that support `args` and have a `metadata.design` of type `Sm100DesignMetadata`.
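
Conceptually, the two criteria are combined with a logical AND: a kernel is returned only if it supports the arguments and its metadata passes the filter. The sketch below shows that composition together with an `isinstance` check on a design subclass. All `Toy*` classes and `toy_filter_kernels` are hypothetical, and the `supports` callable here takes metadata as a simplified stand-in for `kernel.supports(args)`.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class ToyDesign:
    tile_shape: tuple


@dataclass(frozen=True)
class ToySm100Design(ToyDesign):
    """Hypothetical subclass whose type marks an SM100-specific design."""


@dataclass(frozen=True)
class ToyMetadata:
    kernel_name: str
    a_dtype: str
    design: Optional[ToyDesign] = None


def toy_filter_kernels(
    registry: list,
    supports: Optional[Callable[[ToyMetadata], bool]] = None,
    metadata_filter: Optional[Callable[[ToyMetadata], bool]] = None,
) -> list:
    """Keep kernels passing BOTH the support check and the metadata filter."""
    return [
        md
        for md in registry
        if (supports is None or supports(md))
        and (metadata_filter is None or metadata_filter(md))
    ]


registry = [
    ToyMetadata("k0", "Float16", ToySm100Design((128, 64, 64))),
    ToyMetadata("k1", "Float16", ToyDesign((64, 64, 32))),
    ToyMetadata("k2", "BFloat16", ToySm100Design((256, 128, 64))),
]
picked = toy_filter_kernels(
    registry,
    supports=lambda md: md.a_dtype == "Float16",
    metadata_filter=lambda md: isinstance(md.design, ToySm100Design),
)
print([md.kernel_name for md in picked])  # ['k0']
```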



In [12]:
kernels = cutlass_api.get_kernels(
    args,
    metadata_filter=lambda metadata: isinstance(
        metadata.design, cutlass_api.metadata.Sm100DesignMetadata
    ),
)
print(f"Found {len(kernels)} kernels which support args & have Sm100DesignMetadata")

Found 350 kernels which support args & have Sm100DesignMetadata


We can construct more advanced filters by leveraging duck typing.
Additionally, we can fetch all kernels that match a filter, rather than only those supporting a fully defined set of arguments.
This can be useful for pre-generating a large set of kernels that is not targeted at any one problem.
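
Because a metadata filter is just a callable that takes metadata and returns a `bool`, small duck-typed predicates can be composed into larger ones. A minimal sketch; `all_of` and `has_attr_value` are hypothetical helpers (not part of CUTLASS API), and `SimpleNamespace` objects stand in for real `KernelMetadata`.

```python
from types import SimpleNamespace


def all_of(*predicates):
    """Combine predicates: metadata must satisfy every one of them."""
    return lambda metadata: all(p(metadata) for p in predicates)


def has_attr_value(path, expected):
    """Duck-typed predicate: follow a dotted attribute path and compare.

    A missing attribute (or a None along the way) simply fails the check.
    """

    def predicate(metadata):
        obj = metadata
        for name in path.split("."):
            obj = getattr(obj, name, None)
            if obj is None:
                return False
        return obj == expected

    return predicate


f16_tile128 = all_of(
    has_attr_value("operands.A.dtype", "Float16"),
    lambda md: getattr(md.design, "tile_shape", (None,))[0] == 128,
)

md_ok = SimpleNamespace(
    operands=SimpleNamespace(A=SimpleNamespace(dtype="Float16")),
    design=SimpleNamespace(tile_shape=(128, 64, 64)),
)
md_no_design = SimpleNamespace(
    operands=SimpleNamespace(A=SimpleNamespace(dtype="Float16")), design=None
)
print(f16_tile128(md_ok), f16_tile128(md_no_design))  # True False
```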

In [13]:
def a_more_complex_filter(metadata: cutlass_api.metadata.KernelMetadata) -> bool:
    """
    Find all GEMM kernels with Float16 A operands and a tile shape whose first dimension is 128
    """
    # Only look at GEMM kernels
    if not isinstance(metadata.operands, cutlass_api.metadata.GemmOperandsMetadata):
        return False
    # Only look at kernels with A-type F16
    if metadata.operands.A.dtype != cutlass.Float16:
        return False
    # Only look at kernels with tile_shape[0] == 128
    if getattr(metadata.design, "tile_shape", [None])[0] != 128:
        return False
    return True


# Look ma, no args! Fetch all kernels that match the filter,
# instead of supporting a complete set of args
kernels = cutlass_api.get_kernels(
    args=None,
    metadata_filter=a_more_complex_filter,
)
print(f"Found {len(kernels)} matching kernels")

Found 9400 matching kernels
