In [137]:
import triton
import triton.language as tl
import torch

@triton.jit
def _ctxt_fwd_inner(
    O_block,
    s_i,
    Q_block,
    K_block_ptr,
    V_block_ptr,
    P_block_ptr,
    index_block_q,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
    offs_q: tl.constexpr,
    offs_kv: tl.constexpr,
    SEQ_LEN: tl.constexpr,
):
    if STAGE == 1:
        lo, hi = 0, index_block_q * BLOCK_SIZE_Q
    elif STAGE == 2:
        lo, hi = index_block_q * BLOCK_SIZE_Q, (index_block_q + 1) * BLOCK_SIZE_Q
        lo = tl.multiple_of(lo, BLOCK_SIZE_Q)
    else:
        lo, hi = 0, SEQ_LEN

    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    for start_kv in tl.range(lo, hi, BLOCK_SIZE_KV):
        start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)

        K_block  = tl.load(K_block_ptr)
        QK_block = tl.dot(Q_block, K_block)
        
        if STAGE == 2:
            mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
            QK_block = QK_block * mask

        s_ij = tl.sum(QK_block, axis=1) + s_i

        V_block = tl.load(V_block_ptr)
        # QK_block = QK_block.to(tl.float16)
        # O_block = tl.dot(QK_block, V_block, O_block)
        O_block = tl.dot(QK_block.to(V_block.dtype), V_block, O_block)

        s_i = s_ij

        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV))
        # P_block_ptr = tl.advance(P_block_ptr, (0, BLOCK_SIZE_KV))
    
    return O_block, s_i

@triton.autotune(
    [
        triton.Config(
            {"BLOCK_SIZE_Q": BLOCK_SIZE_Q, "BLOCK_SIZE_KV": BLOCK_SIZE_KV},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_SIZE_Q in [16]
        for BLOCK_SIZE_KV in [16]
        for num_stages in ([3])
        for num_warps in [2]
    ],
    key=["SEQ_LEN", "DIM"],
)
@triton.jit
def _ctxt_fwd(
    Q,  # BATCH_SIZE, SEQ_LEN, DIM
    K,  # BATCH_SIZE, SEQ_LEN, DIM
    V,  # BATCH_SIZE, SEQ_LEN, DIM
    O,  # BATCH_SIZE, SEQ_LEN, DIM
    P,  # SEQ_LEN, SEQ_LEN
    S,  # BATCH_SIZE, SEQ_LEN
    stride_Q_batch,
    stride_Q_seq,
    stride_Q_dim,
    stride_K_batch,
    stride_K_seq,
    stride_K_dim,
    stride_V_batch,
    stride_V_seq,
    stride_V_dim,
    stride_O_batch,
    stride_O_seq,
    stride_O_dim,
    stride_P_row,
    stride_P_col,
    BATCH_SIZE,
    SEQ_LEN: tl.constexpr,
    DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
):
    tl.static_assert(BLOCK_SIZE_KV <= DIM)
    
    index_block_q = tl.program_id(0)
    index_batch   = tl.program_id(1)

    qkv_offset = index_batch.to(tl.int64) * stride_Q_batch

    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(SEQ_LEN, DIM),
        strides=(stride_Q_seq, stride_Q_dim),
        offsets=(index_block_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, DIM),
        order=(1, 0), 
    )

    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(SEQ_LEN, DIM),
        strides=(stride_V_seq, stride_V_dim),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KV, DIM),
        order=(1, 0)
    )

    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(DIM, SEQ_LEN),
        # invert the strides w.r.t Q, so we transpose the matrix
        strides=(stride_K_dim, stride_K_seq),
        offsets=(0, 0),
        block_shape=(DIM, BLOCK_SIZE_KV),
        order=(0, 1),
    )

    P_block_ptr = tl.make_block_ptr(
        base=P,
        shape=(SEQ_LEN, SEQ_LEN),
        strides=(stride_P_row, stride_P_col),
        offsets=(index_block_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KV),
        order=(1, 0)
    )

    O_block_ptr = tl.make_block_ptr(
        base=O + qkv_offset,
        shape=(SEQ_LEN, DIM),
        strides=(stride_O_seq, stride_O_dim),
        offsets=(index_block_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, DIM),
        order=(1, 0),
    )

    offs_q  = index_block_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    offs_kv = tl.arange(0, BLOCK_SIZE_KV)

    s_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32)
    O_block = tl.zeros([BLOCK_SIZE_Q, DIM], dtype=tl.float32)

    Q_block = tl.load(Q_block_ptr)

    if STAGE == 1 or STAGE == 3:
        # non-causal attention or blocks to the left of the diagonal in causal attention
        O_block, s_i = _ctxt_fwd_inner(
            O_block,
            s_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            P_block_ptr,
            index_block_q,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            4 - STAGE,
            offs_q,
            offs_kv,
            SEQ_LEN,
        )

    if STAGE == 3:
        # blocks to the right of the diagonal in causal attention
        O_block, s_i = _ctxt_fwd_inner(
            O_block,
            s_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            P_block_ptr,
            index_block_q,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            2,
            offs_q,
            offs_kv,
            SEQ_LEN,
        )
    
    O_block = O_block / (s_i[:, None] + 1e-8)
    tl.store(O_block_ptr, O_block.to(O.type.element_ty))

    s_ptrs = S + index_batch * SEQ_LEN + offs_q
    tl.store(s_ptrs, s_i)

class TritonCtxt(torch.autograd.Function):

    @staticmethod
    def forward(ctx, Q, K, V, P, causal):
        HEAD_DIM_Q, HEAD_DIM_K = Q.shape[-1], K.shape[-1]
        HEAD_DIM_V = V.shape[-1]

        BATCH_SIZE, SEQ_LEN, HEAD_DIM = Q.shape

        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V

        O = torch.empty_like(Q)
        stage = 3 if causal else 1

        grid = lambda args: (
            triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),
            BATCH_SIZE,
            1,
        )

        # S is the logsumexp for the backward pass, one for each query
        S = torch.empty(
            (BATCH_SIZE, SEQ_LEN), device=Q.device, dtype=torch.float32
        )

        _ctxt_fwd[grid](
            Q=Q,
            K=K,
            V=V,
            O=O,
            P=P,
            S=S,
            stride_Q_batch=Q.stride(0),
            stride_Q_seq=Q.stride(1),
            stride_Q_dim=Q.stride(2),
            stride_K_batch=K.stride(0),
            stride_K_seq=K.stride(1),
            stride_K_dim=K.stride(2),
            stride_V_batch=V.stride(0),
            stride_V_seq=V.stride(1),
            stride_V_dim=V.stride(2),
            stride_O_batch=O.stride(0),
            stride_O_seq=O.stride(1),
            stride_O_dim=O.stride(2),
            stride_P_row=P.stride(0),
            stride_P_col=P.stride(1),
            BATCH_SIZE=Q.shape[0],
            SEQ_LEN=Q.shape[1],
            DIM=HEAD_DIM_K,
            STAGE=stage,
        )

        ctx.save_for_backward(Q, K, V, P, O, S)
        ctx.grid = grid
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return O, S

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, P, O, S = ctx.saved_tensors
        return None, None, None, None, None

def test_op(BATCH_SIZE, SEQ_LEN, HEAD_DIM, causal, dtype=torch.float16):
    Q = (
        torch.empty(
            (BATCH_SIZE, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda"
        )
        .normal_(mean=0.0, std=0.5)
        .requires_grad_(False)
    )

    K = (
        torch.empty(
            (BATCH_SIZE, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda"
        )
        .normal_(mean=0.0, std=0.5)
        .requires_grad_(False)
    )
    
    V = (
        torch.empty(
            (BATCH_SIZE, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda"
        )
        .normal_(mean=0.0, std=0.5)
        .requires_grad_(False)
    )

    P = (
        torch.empty(
            (SEQ_LEN, SEQ_LEN), dtype=dtype, device="cuda"
        )
        .normal_(mean=0.0, std=0.5)
        .requires_grad_(False)
    )
    P = torch.nn.init.xavier_normal_(P)

    norm = torch.sqrt(torch.sum(Q ** 2, dim=-1, keepdim=True) + 1e-8)
    Q = Q / norm
    # Q = torch.nn.functional.normalize(Q, dim=-1)
    norm = torch.sqrt(torch.sum(K ** 2, dim=-1, keepdim=True) + 1e-8)
    K = K / norm
    # K = torch.nn.functional.normalize(K, dim=-1)

    # reference implementation
    cosim = torch.matmul(Q, K.transpose(-1, -2))
    # row_sums = torch.tril(cosim).sum(dim=-1, keepdim=True)
    row_sums = cosim.sum(dim=-1, keepdim=True)
    cosim_scores = cosim / (row_sums + 1e-8)
    ref_O = (cosim_scores) @ V
    ref_no_norm = cosim @ V

    print("refernce is complete")

    # triton implementation
    tri_out, s = TritonCtxt.apply(Q, K, V, P, causal)
    tri_out = tri_out.half()

    print("triton is complete")
    # compare
    rtol = 0.0
    atol = 1e-2
    torch.set_printoptions(threshold=float('inf'))
    try:
        assert torch.allclose(ref_O, tri_out, atol=atol, rtol=rtol)
        print("Passed ...")
    except:
        print("Failed ...")
        diff = (ref_O - tri_out).abs()
        mask = diff > atol
        print("Differences (abs > atol):")
        print("Indices:", mask.nonzero(as_tuple=True))
        print("Reference values:\n", ref_O[mask][:10])
        print("Triton values:\n", tri_out[mask][:10])
        print("Diffs:\n", diff[mask][:10])

    # print("My sums")
    # print(s[:, :, None])
    # print("Torch sums")
    # print(row_sums)
    print("Sums diff")
    print((s[:, :, None] - row_sums).abs())
    
    # print("My cosim")
    # print(tri_out[:, :1, :])
    # print("Torch cosim")
    # print(ref_no_norm[:, :1, :])
    print("Cosim diff")
    print((ref_O[:, :, :] - tri_out[:, :, :]).abs())

test_op(BATCH_SIZE=1, SEQ_LEN=16, HEAD_DIM=16, causal=False)

refernce is complete
triton is complete
Failed ...
Differences (abs > atol):
Indices: (tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0'), tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0'), tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
       device='cuda:0'))
Reference values:
 tensor([  3.7129,  45.9375, -20.0000, -43.8125, -25.3438, -29.8594,   7.2305,
         46.4688, -35.1562, -48.3125], device='cuda:0', dtype=torch.float16)
Triton values:
 tensor([  3.6875,  45.5938, -19.8594, -43.5000, -25.1719, -29.6406,   7.1797,
         46.1250, -34.9375, -47.9688], device='cuda:0', dtype=torch.float16)
Diffs:
 tensor([0.0254, 0.3438, 0.1406, 0.3125, 0.1719, 0.2188, 0.0508, 0.3438, 0.2188,
        0.3438], device='cuda:0', dtype=torch.float16)
Sums diff
tensor([[[2.0871e-04],
         [1.8287e-04],
         [1.1675e-03],
         [3.0988e-04],
         [4.8286e-04],
         [3.2783e-06],
         [1.6522e-04],
         [6.3121

In [3]:
from tabulate import tabulate

def naive_contextualizer(Q, K, V, P, causal):
    """
    The memory-inefficient naive implementation.
    """
    cosim = torch.matmul(Q, K.transpose(-1, -2))
    
    # The user's formula uses torch.tril(P). If P is already causal,
    # this is redundant. We assume P needs the mask applied.
    if causal:
        # We apply the causal mask to the P matrix before multiplication
        masked_P = torch.tril(P)
        scores = masked_P * cosim
    else:
        scores = P * cosim
        
    output = torch.matmul(scores, V)
    return output

# ==============================================================================
# Benchmarking Code
# ==============================================================================

def benchmark_forward_pass(func, func_name, *args):
    """
    Benchmarks the forward pass of a function for latency and memory.
    """
    # --- Measure Latency using Triton's utility ---
    # `do_bench` returns the median time in ms.
    latency_ms = triton.testing.do_bench(lambda: func(*args))

    # --- Measure Peak Memory ---
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats("cuda")
    func(*args) # Run one forward pass to measure memory
    torch.cuda.synchronize()
    peak_memory_mib = torch.cuda.max_memory_allocated("cuda") / (1024 * 1024)

    print(
        f"Finished benchmarking {func_name:<20} | "
        f"Latency: {latency_ms:6.3f} ms | "
        f"Peak Memory: {peak_memory_mib:8.2f} MiB"
    )
    return latency_ms, peak_memory_mib


In [4]:


# --- Configuration ---
BATCH_SIZE = 4
SEQ_LEN = 8192
HEAD_DIM = 128
DTYPE = torch.float16
DEVICE = "cuda"

# --- Create Tensors ---
# Note: Triton kernel expects contiguous tensors
Q = torch.randn(BATCH_SIZE, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE).contiguous()
V = torch.randn(BATCH_SIZE, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE).contiguous()
P = torch.randn(SEQ_LEN, SEQ_LEN, dtype=DTYPE, device=DEVICE).contiguous()

# --- Run Benchmarks ---
results = []
for causal_flag in [True]:
    causal_str = "Causal" if causal_flag else "Non-Causal"
    print("-" * 80)
    print(f"Benchmarking {causal_str} Forward Pass")
    print("-" * 80)

    # Benchmark Naive Implementation
    naive_latency, naive_memory = benchmark_forward_pass(
        naive_contextualizer, "Naive PyTorch", Q, Q, V, P, causal_flag
    )
    results.append(["Naive PyTorch", causal_str, naive_latency, naive_memory])

    # Benchmark Triton Implementation
    triton_latency, triton_memory = benchmark_forward_pass(
        TritonCtxt.apply, "Triton Custom", Q, Q, V, P, causal_flag
    )
    results.append(["Triton Custom", causal_str, triton_latency, triton_memory])

# --- Print Final Table ---
print("\n" * 2)
print("=" * 80)
print(" " * 20 + "Forward Pass Benchmark Results")
print("=" * 80)
print(
    tabulate(
        results,
        headers=["Implementation", "Type", "Latency (ms)", "Peak Memory (MiB)"],
        tablefmt="pretty",
        floatfmt=".3f",
    )
)

--------------------------------------------------------------------------------
Benchmarking Causal Forward Pass
--------------------------------------------------------------------------------
Finished benchmarking Naive PyTorch        | Latency: 12.476 ms | Peak Memory:  1312.12 MiB
Finished benchmarking Triton Custom        | Latency:  2.687 ms | Peak Memory:   160.25 MiB



                    Forward Pass Benchmark Results
+----------------+--------+--------------------+-------------------+
| Implementation |  Type  |    Latency (ms)    | Peak Memory (MiB) |
+----------------+--------+--------------------+-------------------+
| Naive PyTorch  | Causal | 12.475830895560128 |     1312.125      |
| Triton Custom  | Causal | 2.6867278055711226 |      160.25       |
+----------------+--------+--------------------+-------------------+


In [5]:
# --- Configuration ---
BATCH_SIZE = 4
SEQ_LEN = 16384
HEAD_DIM = 128
DTYPE = torch.float16
DEVICE = "cuda"

# --- Create Tensors ---
# Note: Triton kernel expects contiguous tensors
Q = torch.randn(BATCH_SIZE, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE).contiguous()
V = torch.randn(BATCH_SIZE, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE).contiguous()
P = torch.randn(SEQ_LEN, SEQ_LEN, dtype=DTYPE, device=DEVICE).contiguous()

# --- Run Benchmarks ---
results = []
for causal_flag in [True]:
    causal_str = "Causal" if causal_flag else "Non-Causal"
    print("-" * 80)
    print(f"Benchmarking {causal_str} Forward Pass")
    print("-" * 80)

    # Benchmark Naive Implementation
    naive_latency, naive_memory = benchmark_forward_pass(
        naive_contextualizer, "Naive PyTorch", Q, Q, V, P, causal_flag
    )
    results.append(["Naive PyTorch", causal_str, naive_latency, naive_memory])

    # Benchmark Triton Implementation
    triton_latency, triton_memory = benchmark_forward_pass(
        TritonCtxt.apply, "Triton Custom", Q, Q, V, P, causal_flag
    )
    results.append(["Triton Custom", causal_str, triton_latency, triton_memory])

# --- Print Final Table ---
print("\n" * 2)
print("=" * 80)
print(" " * 20 + "Forward Pass Benchmark Results")
print("=" * 80)
print(
    tabulate(
        results,
        headers=["Implementation", "Type", "Latency (ms)", "Peak Memory (MiB)"],
        tablefmt="pretty",
        floatfmt=".3f",
    )
)

--------------------------------------------------------------------------------
Benchmarking Causal Forward Pass
--------------------------------------------------------------------------------


Finished benchmarking Naive PyTorch        | Latency: 48.971 ms | Peak Memory:  5176.12 MiB
Finished benchmarking Triton Custom        | Latency: 10.375 ms | Peak Memory:   568.38 MiB



                    Forward Pass Benchmark Results
+----------------+--------+--------------------+-------------------+
| Implementation |  Type  |    Latency (ms)    | Peak Memory (MiB) |
+----------------+--------+--------------------+-------------------+
| Naive PyTorch  | Causal | 48.97075271606445  |     5176.125      |
| Triton Custom  | Causal | 10.375074820085006 |      568.375      |
+----------------+--------+--------------------+-------------------+


In [None]:
# --- Configuration ---
BATCH_SIZE = 1
SEQ_LEN = 32768
HEAD_DIM = 128
DTYPE = torch.float16
DEVICE = "cuda"

# --- Create Tensors ---
# Note: Triton kernel expects contiguous tensors
Q = torch.randn(BATCH_SIZE, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE).contiguous()
V = torch.randn(BATCH_SIZE, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE).contiguous()
P = torch.randn(SEQ_LEN, SEQ_LEN, dtype=DTYPE, device=DEVICE).contiguous()

# --- Run Benchmarks ---
results = []
for causal_flag in [True]:
    causal_str = "Causal" if causal_flag else "Non-Causal"
    print("-" * 80)
    print(f"Benchmarking {causal_str} Forward Pass")
    print("-" * 80)

    # Benchmark Naive Implementation
    naive_latency, naive_memory = benchmark_forward_pass(
        naive_contextualizer, "Naive PyTorch", Q, Q, V, P, causal_flag
    )
    results.append(["Naive PyTorch", causal_str, naive_latency, naive_memory])

    # Benchmark Triton Implementation
    triton_latency, triton_memory = benchmark_forward_pass(
        TritonCtxt.apply, "Triton Custom", Q, Q, V, P, causal_flag
    )
    results.append(["Triton Custom", causal_str, triton_latency, triton_memory])

# --- Print Final Table ---
print("\n" * 2)
print("=" * 80)
print(" " * 20 + "Forward Pass Benchmark Results")
print("=" * 80)
print(
    tabulate(
        results,
        headers=["Implementation", "Type", "Latency (ms)", "Peak Memory (MiB)"],
        tablefmt="pretty",
        floatfmt=".3f",
    )
)