# Grouped GEMM with contiguous tensors via the CUTLASS API

Note: this notebook requires a GPU with compute capability 100. The cell below exits early if no supported device is found.

In [1]:
import cutlass_api

if not (status := cutlass_api.utils.is_device_cc_supported({100})):
    print(
        f"This notebook requires a GPU with compute capability 100.\n{status.error}"
    )
    import sys
    sys.exit(0)

This notebook shows how to use the CUTLASS API to discover, compile, and execute
kernels supporting contiguous offset grouped GEMMs.

A "contiguous offset" grouped GEMM executes `G` problems that differ only
in the `M` mode. Their problem sizes are thus represented as:

```text
M0 x N x K
M1 x N x K
M2 x N x K
...
M(G-1) x N x K
```

The grouped GEMM is referred to as "contiguous" because operands for different
problems in the group are contained within contiguous tensors.

Rather than using `G` separate tensors for each of the operands `A` and `B`, the operands
for all problems in the group are packed together:
* `A` is of shape `(TotalM, K)`, where `TotalM` is the sum of the `M` modes of all problems in the group.
  The `A` operands of the individual problems are stacked along the `M` mode to form this input. More on this below.
* `B` is of shape `(G, K, N)`, where `B[i, :, :]` is the GEMM `B` operand for the `i`th problem in the group.
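
The packing described above can be sketched on the CPU with small tensors. The per-problem lists `As` and `Bs` and the sizes used here are hypothetical, chosen only for illustration; the rest of this notebook constructs the packed tensors directly.

```python
import torch

# Hypothetical per-problem sizes: G = 3 problems sharing N and K.
Ms = [3, 5, 2]
N, K = 4, 6

# One (M_i, K) A operand and one (K, N) B operand per problem.
As = [torch.randn(m, K) for m in Ms]
Bs = [torch.randn(K, N) for _ in Ms]

# A operands are stacked along the M mode; B operands gain a leading group mode.
A_packed = torch.cat(As, dim=0)    # shape (TotalM, K)
B_packed = torch.stack(Bs, dim=0)  # shape (G, K, N)

print(A_packed.shape, B_packed.shape)
```

Here `TotalM` is `3 + 5 + 2 = 10`, and rows `3:8` of `A_packed` hold the second problem's `A` operand.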

For example, with `G=3` (three problems in the group) and `M` modes of `M0`, `M1`, and `M2`,
the tensor `A` is laid out as follows:

```text

    +----------------------------------+         ^  
    |                                  |  |      |  
    |               A0                 |  M0     |  
    |                                  |  |      |  
    |-  -  -  -  -  -  -  -  -  -  -  -|         |  
    |                                  |  |      |
    |                                  |  |    TotalM  
    |               A1                 |  M1     |
    |                                  |  |      |
    |                                  |  |      |  
    |-  -  -  -  -  -  -  -  -  -  -  -|         |  
    |               A2                 |  M2     |  
    +----------------------------------+         v   
```

The extents of individual `A` operands packed within the overall contiguous offset `A` tensor
are provided by an auxiliary `offsets` vector of shape `(G,)`. `offsets[i]` indicates the ending
M coordinate (exclusive) for the `i`th `A` operand.

Thus, for the example above, `offsets = [M0, M0 + M1, M0 + M1 + M2]`.
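
As a minimal sketch, the offsets are a running sum of the per-problem `M` sizes. The helper name `offsets_from_group_sizes` is hypothetical and not part of any library. The resulting list could then be converted to an `int32` tensor on the GPU, which is the dtype the cells later in this notebook use.

```python
from itertools import accumulate

def offsets_from_group_sizes(group_ms):
    """Return the exclusive ending M coordinate of each group."""
    return list(accumulate(group_ms))

offsets_example = offsets_from_group_sizes([3, 5, 2])
print(offsets_example)  # [3, 8, 10]
```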

The output of the operation is of shape `(TotalM, N)`. The `i`th output occupies `out[start:end, :]`,
where `start` and `end` are `offsets[i-1]` and `offsets[i]`, respectively (unless `i=0`, in which case
`start` is 0).
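
The `start`/`end` rule above can be written as a small helper that recovers every group's row range from the offsets. This is an illustrative sketch in plain Python; `group_row_ranges` is a hypothetical name.

```python
def group_row_ranges(offsets):
    """Return (start, end) row ranges, one per group, from exclusive end offsets."""
    ranges = []
    start = 0  # the first group always starts at row 0
    for end in offsets:
        ranges.append((start, end))
        start = end  # the next group starts where this one ends
    return ranges

print(group_row_ranges([3, 8, 10]))  # [(0, 3), (3, 8), (8, 10)]
```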

The reference implementation below shows the computation this kernel performs.

In [2]:
import torch

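def reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype):
    """Compute each group's GEMM with a separate matmul.

    A is (TotalM, K), B is (G, K, N), and offsets holds G exclusive end rows.
    """
    G, K, N = B.shape
    TotalM = A.shape[0]

    out = torch.empty((TotalM, N), dtype=out_dtype, device=A.device)

    start = 0
    for i in range(G):
        # Rows [start, end) of A and out belong to problem i.
        end = int(offsets[i])
        out[start:end, :] = A[start:end, :] @ B[i, :, :]
        start = end

    return out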

## Contiguous offset grouped GEMM in PyTorch

PyTorch provides the same operation as `torch._grouped_mm` (torch < 2.10)
and `torch.nn.functional.grouped_mm` (torch >= 2.10). The cell below uses `torch._grouped_mm`.
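
One way to cope with the two entry points is to look up whichever one the installed build exposes. The sketch below only resolves the callable; `resolve_grouped_mm` is a hypothetical helper, and it does not assume that the two functions share an identical signature, so check the documentation of the installed version before passing arguments.

```python
import torch

def resolve_grouped_mm():
    """Return the grouped-matmul callable this torch build exposes, or None."""
    # Prefer the public location named above for newer releases.
    fn = getattr(torch.nn.functional, "grouped_mm", None)
    if fn is None:
        # Fall back to the private entry point of older releases.
        fn = getattr(torch, "_grouped_mm", None)
    return fn

grouped_mm = resolve_grouped_mm()
print(grouped_mm is not None)
```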

In [3]:
TotalM = 8192
G = 12
K = 1024
N = 2048

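# Exclusive end row of each group: G evenly spaced offsets, with the last
# group absorbing the remainder so that offsets[-1] == TotalM.
offsets = torch.arange(1, G + 1, device="cuda", dtype=torch.int32) * (TotalM // G)
offsets[-1] = TotalM

A = torch.randn(TotalM, K, device="cuda", dtype=torch.bfloat16)
# B is allocated as (G, N, K) and permuted to a (G, K, N) view,
# so K is the contiguous mode of each B[i].
B = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16).permute(0, 2, 1)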

out_torch = torch._grouped_mm(A, B, offsets, out_dtype=torch.bfloat16)
reference = reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype=torch.bfloat16)

torch.testing.assert_close(out_torch, reference)

## Contiguous offset grouped GEMM in CUTLASS API

The CUTLASS API exposes this contiguous offset grouped GEMM via `GroupedGemmArguments`,
which is constructed like `GemmArguments` but additionally takes an `offsets`
tensor:

In [4]:
out = torch.empty((TotalM, N), device="cuda", dtype=torch.bfloat16)

args = cutlass_api.arguments.GroupedGemmArguments(
    A,
    B,
    out,
    accumulator_type=torch.float32,
    offsets=offsets,
)

One can then use the same APIs for finding, compiling, and executing a
kernel that supports this operation:

In [5]:
kernels = cutlass_api.get_kernels(args, cc=100)

assert kernels, "No kernels found"

# Select the first kernel found for simplicity
kernel = kernels[0]

compiled_kernel = kernel.compile(args)

# Execute the kernel
kernel.run(args, compiled_artifact=compiled_kernel)

torch.testing.assert_close(out, reference)