# Using fake tensors with the CUTLASS API
Fake tensors (e.g., [torch's FakeTensor](https://docs.pytorch.org/docs/2.8/torch.compiler_fake_tensor.html))
describe a tensor's properties, such as its shape, dtype, device, and strides, without allocating any backing storage.
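As a minimal, CPU-only sketch in plain PyTorch (no CUTLASS involved), the snippet below creates a fake tensor whose real counterpart would occupy 8 GiB. Only its metadata exists; no storage is allocated. It assumes a recent PyTorch in which `FakeTensorMode` and `FakeTensor` live in the private `torch._subclasses.fake_tensor` module, which may change between releases.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

with FakeTensorMode():
    # A real float16 tensor of this shape would need 8 GiB of memory.
    x = torch.empty(65536, 65536, dtype=torch.float16)

print(type(x).__name__)  # FakeTensor
# Metadata queries work as usual on a fake tensor.
print(x.shape, x.dtype, x.device, x.stride())
print(x.numel() * x.element_size() / 2**30, "GiB described, none allocated")
```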

This example shows how fake tensors can be used with the CUTLASS API
to discover and compile a GEMM kernel, which is then run on real tensors.

```python
import sys

import torch

import cutlass_api

torch.manual_seed(2025)

# Exit early if the current GPU's compute capability is not supported by this example.
if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):
    print(f"This notebook requires a GPU with a supported compute capability (80, 89, 90, 100, or 103).\n{status.error}")
    sys.exit(0)
```

We first create operands `A`, `B`, and `out` under torch's `FakeTensorMode`.
These fake tensors carry all the properties the CUTLASS API needs to construct
its internal tensor representations for discovering and compiling kernels,
but no device memory is allocated for them.

```python
M, N, K = 128, 256, 512

# Tensors created inside FakeTensorMode have metadata (shape, dtype, device)
# but no backing data.
with torch._subclasses.fake_tensor.FakeTensorMode():
    A_fake = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B_fake = torch.randn(K, N, device="cuda", dtype=torch.float16)
    out_fake = torch.empty(M, N, device="cuda", dtype=torch.float16)

print(A_fake)
print(B_fake)
print(out_fake)
```

```
FakeTensor(..., device='cuda:0', size=(128, 512), dtype=torch.float16)
FakeTensor(..., device='cuda:0', size=(512, 256), dtype=torch.float16)
FakeTensor(..., device='cuda:0', size=(128, 256), dtype=torch.float16)
```
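Fake tensors behave like ordinary tensors for any operation that only needs metadata. The CPU-only sketch below (plain PyTorch, same `M`, `N`, `K`) illustrates two such properties that a GEMM kernel search plausibly depends on: shape and dtype propagation through a fake `A @ B`, and strides, which distinguish a row-major operand from a column-major (transposed) view of the same shape. Whether a particular CUTLASS kernel accepts a column-major `B` is not something this sketch checks.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

M, N, K = 128, 256, 512

with FakeTensorMode():
    # CPU fake tensors: only metadata is created, as in the CUDA cell above.
    A = torch.empty(M, K, dtype=torch.float16)
    B = torch.empty(K, N, dtype=torch.float16)
    out = torch.empty(M, N, dtype=torch.float16)
    # A fake matmul propagates shape and dtype without computing anything.
    expected = A @ B
    # A (K, N) view of an (N, K) buffer: same shape as B, column-major strides.
    B_col = torch.empty(N, K, dtype=torch.float16).t()

print(expected.shape == out.shape, expected.dtype == out.dtype)  # True True
print(A.stride(), B.stride(), B_col.stride())  # (512, 1) (256, 1) (1, 512)
```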


We can now use these fake tensors to create `GemmArguments`, and use
those arguments to discover and compile a compatible kernel. `GemmArguments`
is constructed exactly as it would be with real tensors.

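```python
# Build GEMM arguments from the fake tensors, accumulating in float32.
args_fake = cutlass_api.arguments.GemmArguments(
    A_fake, B_fake, out_fake, accumulator_type=torch.float32)

# Discover kernels compatible with these arguments and the device's compute capability.
cc = cutlass_api.utils.device_cc()
kernels = cutlass_api.get_kernels(args_fake, cc=cc)
assert len(kernels) > 0

# Compile the first compatible kernel; no real tensor data is needed for this step.
kernel = kernels[0]
compiled_artifact = kernel.compile(args_fake)
```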
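One way to see why the same calls work for fake and real tensors: code that reads only shapes, dtypes, and strides cannot tell them apart. The sketch below uses a hypothetical helper, `describe_gemm` (not part of `cutlass_api`), that derives a GEMM problem description from tensor metadata and returns identical results for fake and real CPU tensors. It illustrates the idea only; it is not how `GemmArguments` is implemented.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def describe_gemm(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor) -> dict:
    """Hypothetical helper: summarize an (M, K) x (K, N) GEMM from metadata only."""
    M, K = A.shape
    K_b, N = B.shape
    if K != K_b:
        raise ValueError(f"inner dimensions differ: {K} vs {K_b}")
    if tuple(out.shape) != (M, N):
        raise ValueError(f"out has shape {tuple(out.shape)}, expected {(M, N)}")
    return {
        "problem": (M, N, K),
        "dtypes": (A.dtype, B.dtype, out.dtype),
        "strides": (A.stride(), B.stride(), out.stride()),
    }


def make_operands(M: int, N: int, K: int):
    # Only metadata matters here, so uninitialized torch.empty tensors suffice.
    return (
        torch.empty(M, K, dtype=torch.float16),
        torch.empty(K, N, dtype=torch.float16),
        torch.empty(M, N, dtype=torch.float16),
    )


M, N, K = 128, 256, 512
with FakeTensorMode():
    fake_desc = describe_gemm(*make_operands(M, N, K))
real_desc = describe_gemm(*make_operands(M, N, K))

print(fake_desc == real_desc)  # True
print(fake_desc["problem"])    # (128, 256, 512)
```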

The `kernel` and `compiled_artifact` obtained with fake tensors
above can now be used to run the kernel on real tensors.

```python
# Create real tensors with the same shapes, dtypes, and device as the fake ones.
A_real = torch.randn(M, K, device="cuda", dtype=torch.float16)
B_real = torch.randn(K, N, device="cuda", dtype=torch.float16)
out_real = torch.empty(M, N, device="cuda", dtype=torch.float16)

args_real = cutlass_api.arguments.GemmArguments(
    A_real, B_real, out_real, accumulator_type=torch.float32)

# Run the kernel using the compiled_artifact that resulted
# from compiling with fake tensors.
kernel.run(args_real, compiled_artifact)
```
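Reusing a compiled artifact assumes the real tensors match the fake ones it was compiled for. Below is a minimal sketch of a guard one might add before calling `kernel.run`, using hypothetical helpers `metadata_key` and `same_metadata` (not part of `cutlass_api`) that compare shape, strides, dtype, and device type. Which properties a given compiled kernel actually depends on is an assumption here; a real kernel may also care about things fake tensors do not model, such as the alignment of the data pointer.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def metadata_key(t: torch.Tensor) -> tuple:
    """Hypothetical helper: tensor properties a compiled kernel is assumed to depend on."""
    return (tuple(t.shape), t.stride(), t.dtype, t.device.type)


def same_metadata(fake: torch.Tensor, real: torch.Tensor) -> bool:
    return metadata_key(fake) == metadata_key(real)


with FakeTensorMode():
    a_fake = torch.empty(128, 512, dtype=torch.float16)

a_real = torch.empty(128, 512, dtype=torch.float16)
a_t = torch.empty(512, 128, dtype=torch.float16).t()  # same shape, different strides
a_f32 = torch.empty(128, 512, dtype=torch.float32)    # same shape, different dtype

print(same_metadata(a_fake, a_real))  # True
print(same_metadata(a_fake, a_t))     # False
print(same_metadata(a_fake, a_f32))   # False
```

In this sketch, a mismatch would be the signal to compile again (with fake or real tensors) rather than reuse the existing artifact.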

```python
# Check the CUTLASS result against PyTorch's own matmul.
ref = A_real @ B_real
torch.testing.assert_close(out_real, ref)
```