In [6]:
import torch
import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout

In [7]:
# Interactive check only: display the T.Buffer annotation proxy. The kernel
# below annotates its arguments with T.Tensor instead; this cell has no effect.
T.Buffer

<tilelang.language.proxy.BufferProxy at 0x7f16813866c0>

In [8]:
def matmul(M, N, K, block_M, block_N, block_K, num_stages=3, dtype="float16", accum_dtype="float", threads=128):
    """
    Build a tiled TileLang GEMM kernel computing C = A @ B.

    The launch grid has one kernel block per (block_M, block_N) tile of C.
    Each block walks the K dimension in block_K-wide steps inside a
    ``T.Pipelined`` loop. At each step it copies the current A and B tiles into
    shared-memory buffers and multiplies them into an accumulator fragment of
    type ``accum_dtype``. At the end it copies the accumulated tile into C.

    Args:
        M, N, K: Problem sizes; A is (M, K), B is (K, N), C is (M, N).
        block_M, block_N, block_K: Tile sizes along the M, N and K dimensions.
        num_stages: Pipeline depth passed to ``T.Pipelined``.
        dtype: Element type of A, B and C.
        accum_dtype: Element type of the per-tile accumulator.
        threads: Threads per kernel block (defaults to 128, the value that was
            previously hard-coded).

    Returns:
        The ``T.prim_func`` ``main`` taking (A, B, C).
    """
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Grid: bx enumerates tiles along N, by enumerates tiles along M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local  = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Start the accumulator at zero before reducing over K.
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Copy the A tile for rows [by*block_M, ...) and K-step ko.
                T.copy(A[by * block_M, ko * block_K], A_shared)
                # Copy the B tile for K-step ko and columns [bx*block_N, ...).
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                # C_local += A_shared @ B_shared
                T.gemm(A_shared, B_shared, C_local)

            # Write the accumulated tile back to global C (casting to dtype).
            T.copy(C_local, C[by * block_M, bx * block_N])
    return main

In [9]:
import itertools
import pandas as pd

def build_gemm_config(
    block_M=(64, 128),
    block_N=(64, 128),
    block_K=(32, 64),
    num_stages=(2, 3),
):
    """Generate configurations for autotuning the GEMM kernel.

    Returns the Cartesian product of the candidate values as a list of dicts
    with keys ``block_M``, ``block_N``, ``block_K`` and ``num_stages``. The
    order follows ``itertools.product``: ``num_stages`` varies fastest and
    ``block_M`` slowest.

    Args:
        block_M, block_N, block_K: Candidate tile sizes. The defaults are a
            reduced search space for faster testing.
        num_stages: Candidate pipeline depths.

    Returns:
        list[dict]: One dict per configuration; empty if any candidate list is empty.
    """
    return [
        {"block_M": m, "block_N": n, "block_K": k, "num_stages": s}
        for m, n, k, s in itertools.product(block_M, block_N, block_K, num_stages)
    ]

# Enumerate the tuning search space and render it as a table.
configs = build_gemm_config()
print(f"Total configurations: {len(configs)}")
df = pd.DataFrame.from_records(configs)
df

Total configurations: 16


Unnamed: 0,block_M,block_N,block_K,num_stages
0,64,64,32,2
1,64,64,32,3
2,64,64,64,2
3,64,64,64,3
4,64,128,32,2
5,64,128,32,3
6,64,128,64,2
7,64,128,64,3
8,128,64,32,2
9,128,64,32,3


In [10]:
import torch as th
from tilelang.autotuner import AutoTuner
from tilelang.autotuner.capture import set_autotune_inputs

# Problem size and random fp16 inputs shared by autotuning and the correctness check.
device = th.device("cuda")
th.manual_seed(0)  # seed so the tuning inputs (and the reported error) are reproducible
M, N, K = 8192, 4096, 4096
A = th.randn((M, K), dtype=th.float16, device=device)
B = th.randn((K, N), dtype=th.float16, device=device)
# Output-shaped buffer; registered as an autotune input below.
C = th.empty((M, N), dtype=th.float16, device=device)

def kernel_entrypoint(block_M=None, block_N=None, block_K=None, num_stages=None):
    """Autotuner entrypoint: build the GEMM prim_func for one tuning config.

    The parameter names must match the config keys exactly (presumably the
    autotuner passes each config dict as keyword arguments). The problem size
    comes from the module-level M, N, K.
    """
    return matmul(
        M, N, K,
        block_M=block_M,
        block_N=block_N,
        block_K=block_K,
        num_stages=num_stages,
    )

# NOTE(review): with out_idx=[-1] the tuned kernel is called with only (A, B)
# (see best_kernel(A, B) below), yet three tensors are registered here. This may
# be related to the repeated "Incompatible input tensor properties" warnings in
# the tuning log -- confirm against the tilelang autotuner docs.
with set_autotune_inputs(A, B, C):
    autotuner = AutoTuner.from_kernel(
        kernel=kernel_entrypoint,
        configs=build_gemm_config(),
    ).set_compile_args(
        out_idx=[-1],  # the last parameter (C) is returned by the kernel, not passed in
        target="auto",
    )

    result = autotuner.run(warmup=10, rep=50)

print(f"Best config: {result.config}")
print(f"Best latency: {result.latency:.4f} ms")

best_kernel = result.kernel
C_output = best_kernel(A, B)

# Verify correctness.
C_ref = A @ B
print(f"Max error: {(C_output - C_ref).abs().max().item():.6f}")
# Fail loudly instead of only printing. Compare against an fp32-accumulated
# reference so the tolerance mainly has to absorb the kernel's fp16 output rounding.
th.testing.assert_close(C_output.float(), A.float() @ B.float(), rtol=1e-2, atol=1e-2)

2026-01-14 05:45:53,444 INFO:Auto-tuning with 0.9 CPU utilizations, 180 CPUs available, 162 CPUs will be used


Compiling configurations:   0%|          | 0/16 [00:00<?, ?it/s]

2026-01-14 05:45:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[-1]`
2026-01-14 05:45:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[-1]`
2026-01-14 05:45:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[-1]`
2026-01-14 05:45:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[-1]`
2026-01-14 05:45:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[-1]`
2026-01-14 05:45:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[-1]`
2026-01-14 05:45:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[-1]`
2026-01-14 05:45:55  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `main` with `out_idx=[-1]`
2026-01-14 05:45:55  [TileLang:tilelang.jit.kern

Bench configurations:   0%|          | 0/16 [00:00<?, ?it/s]

Tuned Latency 2.7071359157562256 with config {'block_M': 64, 'block_N': 64, 'block_K': 32, 'num_stages': 2} at index 0
Incompatible input tensor properties detected between cached tensors and tensors regenerated for the current configuration trial. This can happen if different tuning configurations require different input shapes/dtypes and input tensor caching is enabled.
To ensure fresh, compatible inputs are generated for every trial you can disable caching by setting:
  `cache_input_tensors=False`
within your `.set_compile_args(...)` call.

Tuned Latency 1.7991039752960205 with config {'block_M': 64, 'block_N': 128, 'block_K': 32, 'num_stages': 2} at index 1
Incompatible input tensor properties detected between cached tensors and tensors regenerated for the current configuration trial. This can happen if different tuning configurations require different input shapes/dtypes and input tensor caching is enabled.
To ensure fresh, compatible inputs are generated for every trial you can d