In [1]:
import time
import torch
import triton
import triton.language as tl
import math

In [2]:
def attention_naive(X: torch.Tensor, W_q, W_k, W_v,
                    b_q, b_k, b_v, device) -> torch.Tensor:
    """
    Single-head scaled dot-product self-attention over an unbatched sequence.

    Args:
        X:      NxD input (N: number of tokens, D: embedding dimension).
        W_q:    D_kxD query projection weight.
        W_k:    D_kxD key projection weight.
        W_v:    D_vxD value projection weight.
        b_q:    D_k query projection bias.
        b_k:    D_k key projection bias.
        b_v:    D_v value projection bias.
        device: Unused. Kept in the signature so existing callers keep
                working; every tensor created here follows X's device.

    Returns:
        NxD_v tensor: softmax(Q K^T / sqrt(D_k)) V.
    """

    # check if X is NxD
    assert X.dim() == 2

    # Linear projections; weights are stored (out_features, in_features).
    Q = torch.matmul(X, W_q.transpose(0, 1)) + b_q[None, :]
    K = torch.matmul(X, W_k.transpose(0, 1)) + b_k[None, :]
    V = torch.matmul(X, W_v.transpose(0, 1)) + b_v[None, :]

    # The logits are dot products over the query/key width, so that width
    # (not the value width) sets the scale. A Python float also avoids
    # allocating a device tensor on every call.
    d_k = Q.shape[1]
    scores = torch.matmul(Q, K.transpose(0, 1)) / math.sqrt(d_k)

    # Normalise over the key axis so each query row sums to 1.
    weights = torch.softmax(scores, dim=1)

    return torch.matmul(weights, V)

def multiheaded_attention_naive(X: torch.Tensor, W_qkv, W_out,
                    b_qkv, b_out, num_heads=1, device="cuda") -> torch.Tensor:
    """
    Multi-head self-attention that evaluates one head at a time with
    attention_naive and then applies the output projection.

    Args:
        X:         NxD input (N tokens, D embedding dimension).
        W_qkv:     3DxD stacked projection weights; rows [0, D) are the
                   query weights, [D, 2D) the key weights, [2D, 3D) the
                   value weights.
        W_out:     DxD output projection weight.
        b_qkv:     3D stacked projection biases, same layout as W_qkv.
        b_out:     D output projection bias.
        num_heads: Number of heads. The head width is ceil(D / num_heads),
                   so when D is not a multiple of num_heads the trailing
                   heads receive narrower (possibly empty) slices.
        device:    Forwarded to attention_naive. The per-head output buffer
                   is allocated on X's device with X's dtype.

    Returns:
        NxD tensor.
    """
    # check if X is NxD
    assert X.dim() == 2

    N, D = X.shape
    D_head = math.ceil(D / num_heads)

    # Follow the input's dtype/device instead of hard-coding float16, which
    # silently down-cast other dtypes and then broke the matmul with W_out.
    attention = torch.empty((N, D), device=X.device, dtype=X.dtype)

    # Split the stacked projection once; these are views, not copies.
    W_q, W_k, W_v = W_qkv[0:D, :], W_qkv[D:2*D, :], W_qkv[2*D:3*D, :]
    b_q, b_k, b_v = b_qkv[0:D], b_qkv[D:2*D], b_qkv[2*D:3*D]

    for head in range(num_heads):
        head_start = head * D_head
        head_end = min(D, (head + 1) * D_head)
        # Each head writes its own column band of the concatenated output.
        attention[:, head_start:head_end] = attention_naive(
            X,
            W_q[head_start:head_end, :],
            W_k[head_start:head_end, :],
            W_v[head_start:head_end, :],
            b_q[head_start:head_end],
            b_k[head_start:head_end],
            b_v[head_start:head_end],
            device
        )

    # Output projection over the concatenated heads.
    attention = torch.matmul(attention, W_out.transpose(0, 1)) + b_out[None, :]

    return attention

In [5]:
# -----------------------------
# Timing helpers (your style)
# -----------------------------
@torch.no_grad()
def time_ms0(fn, iters=100, warmup=25):
    """
    Time a zero-argument callable with CUDA events.

    Args:
        fn:     Callable invoked with no arguments.
        iters:  Number of timed invocations.
        warmup: Number of untimed invocations executed first.

    Returns:
        Elapsed time between the two recorded CUDA events divided by
        ``iters`` (i.e. the mean per-call time in milliseconds).
    """
    def _repeat(count):
        for _ in range(count):
            fn()

    # Untimed warm-up, then drain the device before starting the clock.
    _repeat(warmup)
    torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    tic.record()
    _repeat(iters)
    toc.record()

    # Both events must have completed before the elapsed time can be read.
    torch.cuda.synchronize()

    total_ms = tic.elapsed_time(toc)
    return total_ms / iters

def mha_naive_wrapper(X, mha: torch.nn.MultiheadAttention):
    """
    Run multiheaded_attention_naive using the parameters stored in ``mha``.

    Args:
        X:   Input passed through unchanged as the first argument.
        mha: Module supplying ``in_proj_weight``, ``in_proj_bias``,
             ``out_proj`` (weight and bias) and ``num_heads``.

    Returns:
        Whatever multiheaded_attention_naive returns for these arguments.
    """
    qkv_weight = mha.in_proj_weight
    projection = mha.out_proj

    # The device handed to the naive implementation is wherever the packed
    # input-projection weight lives.
    return multiheaded_attention_naive(
        X,
        qkv_weight,
        projection.weight,
        mha.in_proj_bias,
        projection.bias,
        mha.num_heads,
        qkv_weight.device,
    )

def mha_torch_wrapper(X, mha: torch.nn.MultiheadAttention):
    """
    Self-attention through ``mha`` with X used as query, key and value.

    For a 2-D (N, D) input PyTorch treats X as an unbatched (L, E) sequence
    and returns an (L, E) output. ``need_weights=False`` is passed so only
    the output is produced, which keeps timings focused on it.

    Returns:
        The first element of the tuple returned by ``mha``.
    """
    outputs = mha(X, X, X, need_weights=False)
    attn_output, _ = outputs
    return attn_output

def report(name, ms, N, D):
    """
    Print one benchmark line: label, per-call latency and token throughput.

    Args:
        name: Label, right-aligned in a 14-character column.
        ms:   Per-call time in milliseconds.
        N:    Number of tokens processed per call.
        D:    Accepted for call-site symmetry; not used in the output.
    """
    seconds = ms / 1e3
    throughput = N / seconds
    print(f"{name:>14}: {ms:8.3f} ms | {throughput:10.1f} tokens/s")


In [9]:
# -----------------------------
# Run benchmark
# -----------------------------
torch.manual_seed(42)
device="cuda"
N, D, H = 8192, 128, 2  # tokens, embedding dim, heads

X = torch.randn((N, D), device=device, dtype=torch.float16)

mha = torch.nn.MultiheadAttention(embed_dim=D, num_heads=H, device=device, dtype=torch.float16)
mha.eval()

# correctness check first
with torch.no_grad():
    ref = mha_torch_wrapper(X, mha)
    out = mha_naive_wrapper(X, mha)
    # Compute the element-wise error once and reuse it for both statistics.
    abs_err = (out - ref).abs()
    max_abs_err = abs_err.max().item()
    print("max abs err:", max_abs_err)
    print("mean abs err:", abs_err.mean().item())
    # Fail loudly instead of timing a wrong implementation. The bound is
    # loose for fp16 and also trips on NaN (NaN < x is False).
    assert max_abs_err < 1e-2, f"naive MHA diverges from torch MHA: max abs err {max_abs_err}"

# timing
torch_ms = time_ms0(lambda: mha_torch_wrapper(X, mha), iters=100, warmup=25)
naive_ms = time_ms0(lambda: mha_naive_wrapper(X, mha), iters=20, warmup=5)  # naive materialises the NxN scores per head and is slower; use fewer iters

report("torch_mha", torch_ms, N, D)
report("naive_mha", naive_ms, N, D)

max abs err: 6.103515625e-05
mean abs err: 4.649162292480469e-06
     torch_mha:    2.959 ms |  2768683.4 tokens/s
     naive_mha:    9.990 ms |   820018.2 tokens/s
