# Triton Puzzles



In [6]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import triton_viz
from triton_viz.interpreter import record_builder
import jaxtyping 
import inspect
from jaxtyping import Float

def test(puzzle, puzzle_spec, nelem={}, B={"B0": 32}):
    B = dict(B)
    if "N1" in nelem:
        B["B1"] = 32
    if "N2" in nelem:
        B["B2"] = 32
        
    triton_viz.interpreter.record_builder.reset()
    torch.manual_seed(0)
    signature = inspect.signature(puzzle_spec)
    args = {}
    for n, p in signature.parameters.items():
        args[n + "_ptr"] = [d.size for d in p.annotation.dims]
    args["z_ptr"] = [d.size for d in signature.return_annotation.dims]
    
    tt_args = []
    for k, v in args.items():
        tt_args.append(torch.rand(*v))
    grid = lambda meta: (triton.cdiv(nelem["N0"], meta["B0"]), 
                         triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)), 
                         triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)))   

    for k, v in args.items():
        print(k, v)
    triton_viz.trace(puzzle)[grid](*tt_args, **B, **nelem)
    z = tt_args[-1]
    tt_args = tt_args[:-1]
    z_ = puzzle_spec(*tt_args)
    match = torch.allclose(z, z_)
    print("Results match:",  match)
    if not match:
        print(z)
        print(z_)
    triton_viz.launch()

## Puzzle 1: Constant Add

Add a constant to a vector. Uses one program block. Block size `B0` is always the same as vector length `N0`.


In [3]:
def add_spec(x: Float[Tensor, "32"]) -> Float[Tensor, "32"]:
    return x + 10.

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    range = tl.arange(0, B0)
    x = tl.load(x_ptr + range)
    z = x + 10
    z = tl.store(z_ptr + range, z)

test(add_kernel, add_spec, nelem={"N0": 32})

Results match: True
Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://3d9851881489f442ca.gradio.live


## Puzzle 2: Constant Add Block

Add a constant to a vector. Uses one program block. Block size `B0` is always the same as vector length `N0`.



In [4]:
def add2_spec(x: Float[Tensor, "200"]) -> Float[Tensor, "200"]:
    return x + 10.

@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid = tl.program_id(0)
    range = pid * B0 + tl.arange(0, B0)
    x = tl.load(x_ptr + range, range < N0, 0
               )
    z = x + 10
    z = tl.store(z_ptr + range, 
                 z, range < N0)
    
test(add_mask2_kernel, add2_spec, nelem={"N0": 200})

Results match: True
Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://c6ee2b6ec3dcfcb629.gradio.live


## Puzzle 3: Outer Vector Add

Add two vectors. Uses one program block. Block size `B0` is always the same as vector `x` length `N0`.
Block size `B1` is always the same as vector `y` length `N1`.


In [7]:
def add_vec_spec(x: Float[Tensor, "32"], y: Float[Tensor, "32"]) -> Float[Tensor, "32 32"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    i_range = tl.arange(0, B0)[None, :] 
    j_range = tl.arange(0, B1)[:, None]
    
    x = tl.load(x_ptr + i_range)
    y = tl.load(y_ptr + j_range)
    
    z = x + y
    z = tl.store(z_ptr + i_range + B0 * j_range, z)
    
test(add_vec_kernel, add_vec_spec, nelem={"N0": 32, "N1": 32})

Results match: True
Running on local URL:  http://127.0.0.1:7863
Running on public URL: https://674c0a088a506af40a.gradio.live


## Puzzle 4: Outer Vector Add Block

Add two vectors. Uses one program block. Block size `B0` is always greater than the vector `x` length `N0`.
Block size `B1` is always greater than vector `y` length `N1`.


In [4]:
def add_vec_block_spec(x: Float[Tensor, "100"], y: Float[Tensor, "90"]) -> Float[Tensor, "90 100"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    i_range = tl.arange(0, B0)[:, None] + pid_0 * B0
    j_range = tl.arange(0, B1)[None, :] + pid_1 * B1
    
    x = tl.load(x_ptr + i_range, i_range < N0, 0)
    y = tl.load(y_ptr + j_range, j_range < N1, 0)
    
    z = x + y
    z = tl.store(z_ptr + i_range + N0 * j_range, z, (i_range < N0) & (j_range < N1))
    
test(add_vec_block_kernel, add_vec_block_spec, nelem={"N0": 100, "N1": 90})

Results match: True
Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://9706a1f20c839923f8.gradio.live


## Puzzle 5: Fused Op

In [5]:
def mul_relu_block_spec(x: Float[Tensor, "100"], y: Float[Tensor, "90"]) -> Float[Tensor, "90 100"]:
    return torch.relu(x[None, :] * y[:, None])

@triton.jit
def mul_relu_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    i_range = tl.arange(0, B0)[:, None] + pid_0 * B0
    j_range = tl.arange(0, B1)[None, :] + pid_1 * B1
    
    x = tl.load(x_ptr + i_range, i_range < N0, 0)
    y = tl.load(y_ptr + j_range, j_range < N1, 0)
    
    z = x * y
    z = tl.where(z > 0, z, 0)
    
    z = tl.store(z_ptr + i_range + N0 * j_range, z, (i_range < N0) & (j_range < N1))
    
test(mul_relu_block_kernel, mul_relu_block_spec, nelem={"N0": 100, "N1": 90})

x_ptr [100]
y_ptr [90]
z_ptr [90, 100]
Results match: True
Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://4dbb2aeb4405a9b24f.gradio.live


## Puzzle 6: Fused Op Backwards

In [5]:
def mul_relu_block_back_spec(x: Float[Tensor, "90 100"], y: Float[Tensor, "90"], dz: Float[Tensor, "90 100"]) -> Float[Tensor, "90 100"]:
    x = x.clone()
    y = y.clone()
    x = x.requires_grad_(True)
    y = y.requires_grad_(True)
    z = torch.relu(x * y[:, None])
    z.backward(dz)
    dx = x.grad
    return dx

@triton.jit
def mul_relu_block_back_kernel(x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    i_range = tl.arange(0, B0)[None, :] + pid_0 * B0
    j_range = tl.arange(0, B1)[:, None] + pid_1 * B1
    
    x = tl.load(x_ptr + i_range + N0 * j_range, (i_range < N0) & (j_range < N1), 0)
    y = tl.load(y_ptr + j_range, j_range < N1, 0)

    # Forward
    z = x * y
    dz = tl.load(dz_ptr + i_range + N0 * j_range, (i_range < N0) & (j_range < N1), 0)
    dr = tl.where(z > 0, dz, 0)
    dx = dr * y
    tl.store(dx_ptr + i_range + N0 * j_range, dx, (i_range < N0) & (j_range < N1))
    
test(mul_relu_block_back_kernel, mul_relu_block_back_spec, nelem={"N0": 100, "N1": 90})


x_ptr [90, 100]
y_ptr [90]
dz_ptr [90, 100]
z_ptr [90, 100]
Results match: True
Running on local URL:  http://127.0.0.1:7862
Running on public URL: https://fd12f8adccb27f9f24.gradio.live


## Puzzle 7: Fused Softmax

In [8]:
def softmax_spec(x: Float[Tensor, "4 200"]) -> Float[Tensor, "4 200"]:
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp() 
    return x_exp / x_exp.sum(1, keepdim=True)

@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, TN1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    x_max = -1e9
    for i in range(0, TN1, B1):
        i_range = tl.arange(0, B1)[None, :] + i
        offset = TN1 * pid_0 + i_range
        x = tl.load(x_ptr + offset, i_range < TN1, -1e9)
        chunk_max = tl.max(x, 1)[:, None]
        x_max = tl.where(chunk_max > x_max, chunk_max, x_max)

    partition = 0
    for i in range(0, TN1, B1):
        i_range = tl.arange(0, B1)[None, :] + i
        offset = TN1 * pid_0 + i_range
        x = tl.load(x_ptr + offset, i_range < TN1, -1e9) - x_max
        partition = partition + tl.sum(tl.exp(x), 1)[:, None]
    
    for i in range(0, TN1, B1):
        i_range = tl.arange(0, B1)[None, :] + i
        offset = TN1 * pid_0 + i_range
        x = tl.load(x_ptr + offset, i_range < TN1, -1e9) - x_max
        x_exp = tl.exp(x)
        z = x_exp / partition
        tl.store(z_ptr + offset, z, i_range < TN1)
    
test(softmax_kernel, softmax_spec, B={"B0": 1, "B1": 32}, 
     nelem={"N0": 4, "N1": 32, "TN1": 200})


x_ptr [4, 200]
z_ptr [4, 200]
Results match: True
Running on local URL:  http://127.0.0.1:7864
Running on public URL: https://c088e7020ccd7d9071.gradio.live


## Puzzle 8: Manual Conv.

In [3]:
def conv2d_spec(x: Float[Tensor, "4 8 8"], k: Float[Tensor, "4 4"]) -> Float[Tensor, "4 8 8"]:
    return x

@triton.jit
def conv2d_kernel(x_ptr, k_ptr, z_ptr, N0, N1, H, W, KH: tl.constexpr, KW: tl.constexpr, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    kh_range = tl.arange(0, KH)[None, :, None]
    kw_range = tl.arange(0, KW)[None, None, :]
    k = tl.load(k_ptr + kh_range * KW + kw_range)
    for i in range(0, H):
        for j in range(0, W):
            x = tl.load(x_ptr + pid_0 * H * W + (kh_range + i) * W + (kw_range + j), 
                        ((kh_range + i) < H) & ((kw_range + j) < W), 0)
            out = tl.sum(tl.sum(x * k, 2), 1)
            tl.store(z_ptr + pid_0 * H * W + i*W + j + tl.arange(0,1), out)
    
test(conv2d_kernel, conv2d_spec, B={"B0": 1, "B1": 32}, nelem={"N0": 4, "N1": 32, "H": 8, "W": 8, "KH": 4, "KW": 4})

x_ptr [4, 8, 8]
k_ptr [4, 4]
z_ptr [4, 8, 8]
Results match: False
tensor([[[3.7979e+00, 3.1989e+00, 2.9189e+00, 2.2978e+00, 3.7707e+00,
          3.1057e+00, 2.5557e+00, 1.5586e+00],
         [2.9854e+00, 2.7109e+00, 3.3419e+00, 2.7745e+00, 4.3589e+00,
          3.7183e+00, 2.7992e+00, 1.4989e+00],
         [3.0391e+00, 2.3854e+00, 2.6818e+00, 2.8011e+00, 4.2043e+00,
          3.6416e+00, 3.4220e+00, 1.4298e+00],
         [2.7617e+00, 2.9887e+00, 3.5100e+00, 4.2107e+00, 4.7459e+00,
          3.6242e+00, 3.1270e+00, 8.8920e-01],
         [3.0284e+00, 2.6824e+00, 2.8830e+00, 4.1134e+00, 4.9645e+00,
          3.7006e+00, 2.5766e+00, 9.0460e-01],
         [2.9366e+00, 3.2696e+00, 2.9984e+00, 3.0241e+00, 3.1591e+00,
          2.5144e+00, 1.2738e+00, 4.9069e-01],
         [1.2750e+00, 1.9437e+00, 2.4010e+00, 2.1371e+00, 2.1035e+00,
          1.4691e+00, 7.0879e-01, 1.3890e-01],
         [1.6424e-01, 2.9547e-01, 4.8142e-01, 7.4419e-01, 6.7324e-01,
          3.2021e-01, 4.7337e-01, 2.8814e-03]

## Puzzle 9: Matrix Mult

In [12]:
def dot_spec(x: Float[Tensor, "4 32 32"], y: Float[Tensor, "4 32 32"]) -> Float[Tensor, "4 32 32"]:
    return x @ y
    
@triton.jit
def dot_kernel(x_ptr, y_ptr, z_ptr, N0, N1, N2, MID, B0: tl.constexpr, B1: tl.constexpr, B2: tl.constexpr):
    r = tl.program_id(0) * B0
    rows = tl.arange(0, B0) [:, None]
    c = tl.program_id(1) * B1
    cols = tl.arange(0, B1)[None, :]
    mids = tl.arange(0, B0)
    b = tl.program_id(2)
    bid = b * N0 * N1
    z = 0
    for i in range(0, MID, B0): 
        x_val = tl.load(
            x_ptr
            + bid
            + (r + rows) * MID
            + (mids + i) 
            (i + mids < MID) & (rows + r < N0), 0
        )
        y_val = tl.load(
            y_ptr
            + bid
            + (i + mids[:, None]) * N1
            + (cols + c),
            (i + mids[:, None] < MID) & (cols + c < N1)
            , 0
        )
        z = z + tl.dot(x_val, y_val)
    tl.store(
        z_ptr
        + (b * N0 * N1)
        + (r + rows) * N1
        + (cols + c),
        z,
        mask= (r + rows) < N0  & (c + cols) < N1
    )
test(dot_kernel, dot_spec, B={"B0": 16, "B1": 16, "B2": 1}, nelem={"N0": 32, "N1": 32, "N2": 4, "MID": 32})


x_ptr [4, 32, 32]
y_ptr [4, 32, 32]
z_ptr [4, 32, 32]
Results match: False
tensor([[[ 8.3391,  7.9111,  8.4097,  ...,  6.7225,  8.9109,  8.1703],
         [ 7.6265,  7.7135,  7.8818,  ...,  5.5574,  8.3774,  9.3551],
         [ 9.3186,  8.2885,  8.0333,  ...,  6.4727,  9.5322,  9.6548],
         ...,
         [ 7.6429,  8.3466,  7.9423,  ...,  6.2658,  8.3070,  8.1512],
         [ 8.1154,  7.7465,  8.8425,  ...,  6.6136,  8.6476, 10.4751],
         [ 7.4778,  7.5884,  7.1350,  ...,  6.0462,  8.4397,  8.8035]],

        [[ 0.3006,  0.5485,  0.2660,  ...,  0.3626,  0.5522,  0.8689],
         [ 0.7159,  0.0853,  0.7634,  ...,  0.1857,  0.7333,  0.9228],
         [ 0.2173,  0.3993,  0.9514,  ...,  0.1072,  0.5109,  0.9906],
         ...,
         [ 0.7561,  0.7053,  0.1599,  ...,  0.2121,  0.5966,  0.9737],
         [ 0.5935,  0.1449,  0.1956,  ...,  0.2828,  0.7119,  0.4724],
         [ 0.5482,  0.1261,  0.2800,  ...,  0.6109,  0.8317,  0.9479]],

        [[ 0.8422,  0.4910,  0.0869,  ...

## Puzzle 10: Quantized Matrix Mult 

GPT-Q like puzzles

In [None]:
FPINT = 32 // 4
GROUP = 16

def quant_dot_spec(scale : Float[Tensor, "64 4"], 
                   offset : Int[Tensor, "64 1"], 
                   weight: Int[Tensor, "64 8"], 
                   activation: Float[Tensor, "4 32 32"]) -> Float[Tensor, "4 32 32"]:
    return (scale * (weight - offset[:, :, None].expand(32, 4, 8).view(32, 32)))  @ activation
    
@triton.jit
def quant_dot_kernel(scale_ptr, offset_ptr, weight_ptr, activation_ptr,
                     z_ptr, N0, N1, N2, MID, B0: tl.constexpr, B1: tl.constexpr, B2: tl.constexpr):
    r = tl.program_id(0) * B0
    rows = tl.arange(0, B0) [:, None]
    c = tl.program_id(1) * B1
    cols = tl.arange(0, B1)[None, :]
    mids = tl.arange(0, B0)

    b = tl.program_id(2)
    bid = b * N0 * N1

    z = 0
    for i in range(0, MID, B0): 
        scale_val = tl.load(
            scale_ptr
            + (r + rows) * (MID // GROUP)
            + (mids + i) // GROUP,
            (i + mids < MID // GROUP) & (rows + r < N0), 0
        )
        scale_val = scale_val
        
        offset_val = tl.load(
            offset_ptr
            + (r + rows) * MID // (FPINT * GROUP)
            + (mids + i) // (FPINT * GROUP),
            (i + mids < MID // (FPINT * GROUP)) & (rows + r < N0), 0
        )

        v = (tl.arange(0, MID) % FPINT) * 4as
        offset_val = (offset_val >> v) & (16)
        
        weight_val = tl.load(
            weight_ptr
            + (r + rows) * MID // FPINT
            + (mids + i) // FPINT,
            (i + mids < MID // FPINT) & (rows + r < N0), 0
        )
        v = (tl.arange(0, MID) % FPINT) * 4
        weight_val = (weight_val >> v) & (16)
        weight = (weight_val - offset_val) * scale_val
        
        activation_val = tl.load(
            activation_ptr
            + bid
            + (i + mids[:, None]) * N1
            + (cols + c),
            (i + mids[:, None] < MID) & (cols + c < N1)
            , 0
        )
        z = z + tl.dot(weight, activation_val)
    tl.store(
        z_ptr
        + (b * N0 * N1)
        + (r + rows) * N1
        + (cols + c),
        z,
        mask= (r + rows) < N0  & (c + cols) < N1
    )
test(quant_dot_kernel, quannt_dot_spec, B={"B0": 16, "B1": 16, "B2": 1}, nelem={"N0": 64, "N1": 64, "N2": 4, "MID": 64})


## Puzzle 11: Flash Attention 

Long reduction. 