In [1]:
import torch
import triton
import triton.language as tl

# Device string used when allocating tensors in this notebook:
# "cuda" when torch reports CUDA as available, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
def is_cuda():
    """Return True when the notebook-level ``DEVICE`` refers to a CUDA device.

    ``DEVICE`` is a plain string ("cuda" or "cpu"), which has no ``backend``
    attribute, so it is normalised through ``torch.device`` first. This also
    accepts a ``torch.device`` object or an indexed string such as "cuda:0".

    Returns:
        bool: True if the device type is "cuda", False otherwise.
    """
    return torch.device(DEVICE).type == "cuda"

In [5]:
@triton.jit
def matmul_kernel(a_ptr,b_ptr,c_ptr,
                  M,N,K,
                  stride_am,stride_ak,
                  stride_bk,stride_bn,
                  stride_cm,stride_cn,
                  BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr,BLOCK_SIZE_K:tl.constexpr,
                  GROUP_SIZE_M:tl.constexpr,
                  ACTIVATION:tl.constexpr):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = A @ B.

    Each program handles a single output tile. Tiles are assigned to program
    ids in row-major order (``pid // grid_n`` is the tile row, ``pid % grid_n``
    the tile column).

    Args:
        a_ptr, b_ptr, c_ptr: base pointers of A (M x K), B (K x N) and C (M x N).
        M, N, K: problem dimensions.
        stride_am, stride_ak: element strides of A along its rows / columns.
        stride_bk, stride_bn: element strides of B along its rows / columns.
        stride_cm, stride_cn: element strides of C along its rows / columns.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: compile-time tile sizes.
        GROUP_SIZE_M: accepted for interface compatibility; currently unused
            because the tile mapping is plain row-major.
        ACTIVATION: when equal to 'leaky_relu', ``leaky_relu`` (defined
            elsewhere in this notebook) is applied to the accumulator before
            the result is stored.

    The accumulator is float32 and is cast to float16 before being stored.
    """
    # Map the 1-D program id onto a (tile row, tile column) pair.
    pid = tl.program_id(0)
    grid_n = tl.cdiv(N,BLOCK_SIZE_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    # Row indices of A / C and column indices of B / C covered by this tile,
    # with masks marking which of them are inside the matrices.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0,BLOCK_SIZE_M)
    mask_am = offs_am < M
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0,BLOCK_SIZE_N)
    mask_bn = offs_bn < N
    offs_k = tl.arange(0,BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:,None] * stride_am + offs_k[None,:] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:,None] * stride_bk + offs_bn[None,:] * stride_bn)
    acc = tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32)

    for k in range(0,tl.cdiv(K,BLOCK_SIZE_K)):
        # Only the first (K - k * BLOCK_SIZE_K) entries along K are valid in
        # this step; out-of-range rows, columns and K entries load as 0.0 so
        # they contribute nothing to the dot product.
        mask_k = offs_k < K - k * BLOCK_SIZE_K
        a = tl.load(a_ptrs,mask=mask_am[:,None] & mask_k[None,:],other=0.0)
        b = tl.load(b_ptrs,mask=mask_k[:,None] & mask_bn[None,:],other=0.0)
        acc = tl.dot(a,b,acc)
        # Advance both operand tiles by one block along K.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    if ACTIVATION == 'leaky_relu':
        acc = leaky_relu(acc)
    c = acc.to(tl.float16)

    # Write the tile back, skipping rows / columns that fall outside C.
    c_ptrs = c_ptr + (offs_am[:,None] * stride_cm + offs_bn[None,:] * stride_cn)
    mask_c = mask_am[:,None] & mask_bn[None,:]
    tl.store(c_ptrs,c,mask=mask_c)

In [6]:
@triton.jit
def leaky_relu(x):
    """Element-wise leaky ReLU: keep x where x >= 0, otherwise scale it by 0.01."""
    scaled = 0.01 * x
    return tl.where(x >= 0, x, scaled)

In [13]:
def matmul(a,b,activation=""):
    """Launch ``matmul_kernel`` to compute ``a @ b`` and return the result.

    Args:
        a: left operand; must be contiguous, and ``a.shape[1]`` must equal
            ``b.shape[0]``.
        b: right operand.
        activation: forwarded to the kernel as ``ACTIVATION``.

    Returns:
        A new float16 tensor of shape (M, N) allocated on ``a.device``.

    Raises:
        AssertionError: if the inner dimensions differ or ``a`` is not
            contiguous.
    """
    assert a.shape[1] == b.shape[0], "Incompatible dim"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, _ = a.shape
    K, N = b.shape

    # Fixed tile configuration, passed to the kernel as compile-time constants.
    tile_config = dict(BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32, GROUP_SIZE_M=8)

    c = torch.empty((M,N),device=a.device,dtype=torch.float16)

    def grid(META):
        # 1-D launch grid: one program per output tile.
        tiles_m = triton.cdiv(M,META['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(N,META['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n,)

    matmul_kernel[grid](
        a,b,c,
        M,N,K,
        a.stride(0),a.stride(1),
        b.stride(0),b.stride(1),
        c.stride(0),c.stride(1),
        ACTIVATION=activation,
        **tile_config,
    )
    return c

In [15]:
# Seed the generator so the random fp16 inputs are reproducible.
torch.manual_seed(0)
a = torch.randn(512, 512, dtype=torch.float16, device=DEVICE)
b = torch.randn(512, 512, dtype=torch.float16, device=DEVICE)

# Run the Triton wrapper first, then torch's own matmul as the reference.
triton_output = matmul(a, b)
torch_output = a @ b
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")

# Report whether the two results agree within an absolute tolerance of 1e-2.
print(
    "✅ Triton and Torch match"
    if torch.allclose(triton_output, torch_output, atol=1e-2)
    else "❌ Triton and Torch differ"
)

triton_output_with_fp16_inputs=tensor([[ 48.4688,  28.1406, -28.7656,  ..., -23.0000,  15.6641,  13.9219],
        [ 33.5625,  -2.6055,  14.5000,  ...,  -0.0503,  18.5625,  -9.4453],
        [ -4.6367,  -7.6758,  30.5625,  ...,  20.5469,  35.0312,  -5.9375],
        ...,
        [-29.6562,  -0.5352,  29.2344,  ...,  45.6875, -20.1719, -15.7109],
        [ 23.1094,  -6.1484, -17.9062,  ...,  14.5547,  21.3125, -19.8750],
        [  5.2539,  31.2344, -14.6641,  ..., -16.0938,  24.2812,   7.0430]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[ 48.4688,  28.1406, -28.7656,  ..., -23.0000,  15.6641,  13.9219],
        [ 33.5625,  -2.6055,  14.5000,  ...,  -0.0503,  18.5625,  -9.4453],
        [ -4.6367,  -7.6758,  30.5625,  ...,  20.5469,  35.0312,  -5.9375],
        ...,
        [-29.6562,  -0.5352,  29.2344,  ...,  45.6875, -20.1719, -15.7109],
        [ 23.1094,  -6.1484, -17.9062,  ...,  14.5547,  21.3125, -19.8750],
        [  5.2539,  31.2344, -1