In [1]:
import triton
import triton.language as tl
import torch

In [2]:
def is_cuda():
    """Return True when the target reported by Triton's active driver uses the "cuda" backend."""
    current_target = triton.runtime.driver.active.get_current_target()
    return current_target.backend == "cuda"

def get_autotune_configs():
    """Build the list of `triton.Config` candidates handed to `triton.autotune`.

    Every candidate uses GROUP_SIZE_M = 8; only the tile sizes, the number of
    pipeline stages and the number of warps vary.

    Returns:
        list: 16 `triton.Config` objects, in the order of the table below.
    """
    group_size_m = 8
    # Each row: (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps)
    candidates = [
        (128, 256, 64, 3, 8),
        (64, 256, 32, 4, 4),
        (128, 128, 32, 4, 4),
        (128, 64, 32, 4, 4),
        (64, 128, 32, 4, 4),
        (128, 32, 32, 4, 4),
        (64, 32, 32, 5, 2),
        (32, 64, 32, 5, 2),
        # Good config for fp8 inputs.
        (128, 256, 128, 3, 8),
        (256, 128, 128, 3, 8),
        (256, 64, 128, 4, 4),
        (64, 256, 128, 4, 4),
        (128, 128, 128, 4, 4),
        (128, 64, 64, 4, 4),
        (64, 128, 64, 4, 4),
        (128, 32, 64, 4, 4),
    ]
    return [
        triton.Config(
            {
                'BLOCK_SIZE_M': block_m,
                'BLOCK_SIZE_N': block_n,
                'BLOCK_SIZE_K': block_k,
                'GROUP_SIZE_M': group_size_m,
            },
            num_stages=stages,
            num_warps=warps,
        )
        for block_m, block_n, block_k, stages, warps in candidates
    ]

@triton.autotune(
    configs=get_autotune_configs(),
    key=['M', 'N', 'K']
)
@triton.jit
def mm_kernel(a_ptr, b_ptr, c_ptr,
              M, N, K,
              am_stride, ak_stride,
              bk_stride, bn_stride,
              cm_stride, cn_stride,
              BLOCK_SIZE_M: tl.constexpr,
              BLOCK_SIZE_N: tl.constexpr,
              BLOCK_SIZE_K: tl.constexpr,
              GROUP_SIZE_M: tl.constexpr,
              ):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = A @ B per program.

    Launched on a 1D grid with one program per output tile. Programs are
    assigned to tiles in groups of up to GROUP_SIZE_M tile-rows: within a
    group, consecutive program ids walk down the rows of one tile-column
    before moving to the next column.

    Args:
        a_ptr, b_ptr, c_ptr: pointers to A (M x K), B (K x N) and C (M x N).
        M, N, K: problem dimensions.
        am_stride, ak_stride: strides of A along its row / column dimension.
        bk_stride, bn_stride: strides of B along its row / column dimension.
        cm_stride, cn_stride: strides of C along its row / column dimension.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: tile sizes (compile-time).
        GROUP_SIZE_M: number of tile-rows per program group (compile-time).

    The accumulation is done in float32 and the tile is stored as float16.
    """
    pid = tl.program_id(0)

    n_blocks          = tl.cdiv(N, BLOCK_SIZE_N)
    m_blocks          = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_per_group = GROUP_SIZE_M * n_blocks
    grp_idx           = pid // num_pid_per_group
    # Position of this program inside its group; both the row and the column
    # slot are derived from it so the mapping is uniform across all groups,
    # including a last group that holds fewer than GROUP_SIZE_M tile-rows.
    pid_in_group      = pid % num_pid_per_group
    # The last group may be partial when m_blocks is not a multiple of GROUP_SIZE_M.
    m_group_size      = min(m_blocks - grp_idx * GROUP_SIZE_M, GROUP_SIZE_M)

    # NOTE: pid_m / pid_n are element offsets of the tile's first row / column,
    # not tile indices.
    group_m = grp_idx * GROUP_SIZE_M * BLOCK_SIZE_M
    pid_m   = group_m + BLOCK_SIZE_M * (pid_in_group % m_group_size)
    pid_n   = BLOCK_SIZE_N * (pid_in_group // m_group_size)

    # Wrap row/column indices modulo M/N so loads for edge tiles stay inside
    # the matrices; the surplus rows/columns are discarded by the store mask.
    am_offsets = (pid_m + tl.arange(0, BLOCK_SIZE_M)) % M
    bn_offsets = (pid_n + tl.arange(0, BLOCK_SIZE_N)) % N
    k_offsets  = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + (am_offsets[:, None] * am_stride + k_offsets[None, :] * ak_stride)
    b_ptrs = b_ptr + (k_offsets[:, None] * bk_stride + bn_offsets[None, :] * bn_stride)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in tl.range(0, K, BLOCK_SIZE_K):
        # Only the first (K - k) entries along K are valid in the final step;
        # the remainder is loaded as zero so it does not contribute to the dot.
        a_block = tl.load(a_ptrs, mask=k_offsets[None, :] < (K - k), other=0.0)
        b_block = tl.load(b_ptrs, mask=k_offsets[:, None] < (K - k), other=0.0)

        accumulator = tl.dot(a_block, b_block, accumulator)

        # Advance both tile pointers by one K-block.
        a_ptrs += BLOCK_SIZE_K * ak_stride
        b_ptrs += BLOCK_SIZE_K * bk_stride

    c = accumulator.to(tl.float16)

    # Write back the block of the output matrix C, masking out-of-range
    # rows and columns of edge tiles.
    cm_offsets = pid_m + tl.arange(0, BLOCK_SIZE_M)
    cn_offsets = pid_n + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + (cm_offsets[:, None] * cm_stride + cn_offsets[None, :] * cn_stride)
    c_mask = (cm_offsets[:, None] < M) & (cn_offsets[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


In [3]:
def matmul(a, b, activation=""):
    """Multiply two 2-D tensors with the Triton `mm_kernel`.

    Args:
        a: left operand of shape (M, K); must be contiguous.
        b: right operand of shape (K, N), on the same device as `a`.
        activation: kept for interface compatibility. The kernel has no fused
            activation, so only an empty value is accepted.

    Returns:
        A new float16 tensor of shape (M, N) on `a.device`.

    Raises:
        ValueError: if a non-empty `activation` is requested.
        AssertionError: if the shapes are incompatible, `a` is not contiguous,
            or the operands live on different devices.
    """
    # Refuse rather than silently return a result without the activation applied.
    if activation:
        raise ValueError(
            f"Unsupported activation {activation!r}: mm_kernel has no fused activation"
        )
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    # The output is allocated on a.device and raw pointers are passed to the kernel.
    assert a.device == b.device, "Matrices A and B must be on the same device"
    M, K = a.shape
    _, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    mm_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

In [5]:
DEVICE = triton.runtime.driver.active.get_active_torch_device()

# fp16 check: compare the Triton kernel against torch.matmul on 512x512 inputs.
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
print(
    "✅ Triton and Torch match"
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0.0)
    else "❌ Triton and Torch differ"
)

# fp8 check: only when this torch build exposes float8_e5m2 and the backend is CUDA.
TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and is_cuda():
    torch.manual_seed(0)
    a = torch.randn((16, 16), device=DEVICE, dtype=torch.float16)
    b = torch.randn((16, 16), device=DEVICE, dtype=torch.float16)
    a = a.to(torch.float8_e5m2)
    # pre-transpose b for efficiency.
    b = b.T.to(torch.float8_e5m2)
    triton_output = matmul(a, b)
    # Reference result: the same fp8 values upcast to fp16 and multiplied by torch.
    torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
    print(f"triton_output_with_fp8_inputs={triton_output}")
    print(f"torch_output_with_fp8_inputs={torch_output}")
    print(
        "✅ Triton and Torch match"
        if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0)
        else "❌ Triton and Torch differ"
    )


triton_output_with_fp16_inputs=tensor([[  0.8882, -25.5312,  12.9375,  ...,  -0.1477,  -8.3750,  -5.4609],
        [ 22.3438, -14.1719, -17.5312,  ..., -28.3906, -22.6406,  28.9219],
        [-33.6875,   1.8291, -22.8438,  ...,   9.2578, -52.2812,  12.8750],
        ...,
        [ 14.2031, -54.0312,  10.4609,  ..., -13.9531, -14.6953,  -2.7188],
        [-33.2500,  29.1094, -20.9375,  ...,   4.3203, -14.3906, -10.8672],
        [-18.5938,  22.7188, -30.7500,  ...,  -9.0312,  -6.8281,  -1.4648]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[  0.8882, -25.5312,  12.9375,  ...,  -0.1477,  -8.3750,  -5.4609],
        [ 22.3438, -14.1719, -17.5312,  ..., -28.3906, -22.6406,  28.9219],
        [-33.6875,   1.8291, -22.8438,  ...,   9.2578, -52.2812,  12.8750],
        ...,
        [ 14.2031, -54.0312,  10.4609,  ..., -13.9531, -14.6953,  -2.7188],
        [-33.2500,  29.1094, -20.9375,  ...,   4.3203, -14.3906, -10.8672],
        [-18.5938,  22.7188, -3

In [None]:
ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'

# One Benchmark per input precision; the fp8 one is only added when this torch
# build has fp8 dtypes and the backend is CUDA.
configs = []
for fp8_inputs in [False, True]:
    if fp8_inputs and not (TORCH_HAS_FP8 and is_cuda()):
        continue
    if fp8_inputs:
        # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
        line_vals, line_names, precision_tag = ["triton"], ["Triton"], "fp8"
    else:
        line_vals, line_names, precision_tag = [ref_lib.lower(), "triton"], [ref_lib, "Triton"], "fp16"
    configs.append(
        triton.testing.Benchmark(
            # Argument names to use as an x-axis for the plot
            x_names=["M", "N", "K"],
            # Different possible values for `x_names`: 256, 384, ..., 4096
            x_vals=list(range(256, 4097, 128)),
            # Argument name whose value corresponds to a different line in the plot
            line_arg="provider",
            # Possible values for `line_arg`
            line_vals=line_vals,
            # Label name for the lines
            line_names=line_names,
            # Line styles
            styles=[("green", "-"), ("blue", "-")],
            # Label name for the y-axis
            ylabel="TFLOPS",
            # Name for the plot, used also as a file name for saving the plot.
            plot_name="matmul-performance-" + precision_tag,
            args={"fp8_inputs": fp8_inputs},
        ))


@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
    """Time one (M, N, K) matmul for `provider` and report throughput in TFLOPS.

    Args:
        M, N, K: problem dimensions; A is (M, K) and B is (K, N).
        provider: `ref_lib.lower()` to time `torch.matmul`, or 'triton' to
            time the Triton `matmul` wrapper.
        fp8_inputs: when True (and torch has fp8 dtypes) the inputs are
            converted to `torch.float8_e5m2`.

    Returns:
        A 3-tuple of TFLOPS computed from the 0.5, 0.8 and 0.2 quantiles of
        the measured time, in that order (i.e. median, lower and upper
        throughput).

    Raises:
        ValueError: if `provider` is not one of the two supported values.
    """
    a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
    if TORCH_HAS_FP8 and fp8_inputs:
        a = a.to(torch.float8_e5m2)
        # Allocate B as (N, K) and transpose it, so that the operand passed to
        # matmul is a pre-transposed (K, N) view for any N and K, not only
        # for square problems.
        b = torch.randn((N, K), device=DEVICE, dtype=torch.float16).T
        b = b.to(torch.float8_e5m2)
    else:
        b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)

    if provider == ref_lib.lower():
        run = lambda: torch.matmul(a, b)
    elif provider == 'triton':
        run = lambda: matmul(a, b)
    else:
        raise ValueError(f"Unknown benchmark provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

    def perf(ms):
        # 2*M*N*K floating-point operations, time given in milliseconds -> TFLOPS.
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    # A longer time means lower throughput, hence max_ms comes before min_ms.
    return perf(ms), perf(max_ms), perf(min_ms)


# Run every configured benchmark, printing the result table and showing the plots.
benchmark.run(print_data=True, show_plots=True)