# Triton Puzzles



In [15]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import triton_viz
from triton_viz.interpreter import record_builder
import jaxtyping 
import inspect
from jaxtyping import Float32, Int32

def test(puzzle, puzzle_spec, nelem={}, B={"B0": 32}):
    B = dict(B)
    if "N1" in nelem and "B1" not in B:
        B["B1"] = 32
    if "N2" in nelem and "B2" not in B:
        B["B2"] = 32
        
    triton_viz.interpreter.record_builder.reset()
    torch.manual_seed(0)
    signature = inspect.signature(puzzle_spec)
    args = {}
    for n, p in signature.parameters.items():
        print(p)
        args[n + "_ptr"] = ([d.size for d in p.annotation.dims], p)
    args["z_ptr"] = ([d.size for d in signature.return_annotation.dims], None)
    
    tt_args = []
    for k, (v, t) in args.items():
        tt_args.append(torch.rand(*v))
        if t is not None and t.annotation.dtypes[0] == "int32":
            tt_args[-1] = torch.randint(-100000, 100000, v)
    grid = lambda meta: (triton.cdiv(nelem["N0"], meta["B0"]), 
                         triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)), 
                         triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)))   

    #for k, v in args.items():
    #    print(k, v)
    triton_viz.trace(puzzle)[grid](*tt_args, **B, **nelem)
    z = tt_args[-1]
    tt_args = tt_args[:-1]
    z_ = puzzle_spec(*tt_args)
    match = torch.allclose(z, z_, rtol=1e-3, atol=1e-3)
    print("Results match:",  match)
    if not match:
        print("Yours:", z)
        print("Spec:", z_)
        print(torch.isclose(z, z_))
    triton_viz.launch()

## Puzzle 1: Constant Add

Add a constant to a vector. Uses one program block. Block size `B0` is always the same as vector `x` with length `N0`.


$$z_i = 10 + x_i \text{ for } i = 1\ldots N_0$$


In [2]:
def add_spec(x: Float32[Tensor, "32"]) -> Float32[Tensor, "32"]:
    return x + 10.

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    range = tl.arange(0, B0)
    x = tl.load(x_ptr + range)
    z = x + 10
    z = tl.store(z_ptr + range, z)

test(add_kernel, add_spec, nelem={"N0": 32})

x: jaxtyping.Float32[Tensor, '32']
x_ptr ([32], <Parameter "x: jaxtyping.Float32[Tensor, '32']">)
z_ptr ([32], None)
Results match: True
Running on local URL:  http://127.0.0.1:7860


## Puzzle 2: Constant Add Block

Add a constant to a vector. Uses one program block. Block size `B0` is now smaller than the shape vector `x` which is `N0`.


$$z_i = 10 + x_i \text{ for } i = 1\ldots N_0$$



In [3]:
def add2_spec(x: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    return x + 10.

@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid = tl.program_id(0)
    range = pid * B0 + tl.arange(0, B0)
    x = tl.load(x_ptr + range, range < N0, 0
               )
    z = x + 10
    z = tl.store(z_ptr + range, 
                 z, range < N0)
    
test(add_mask2_kernel, add2_spec, nelem={"N0": 200})

x: jaxtyping.Float32[Tensor, '200']
x_ptr ([200], <Parameter "x: jaxtyping.Float32[Tensor, '200']">)
z_ptr ([200], None)
Results match: True
Running on local URL:  http://127.0.0.1:7861


## Puzzle 3: Outer Vector Add

Add two vectors. 

Uses one program block. Block size `B0` is always the same as vector `x` length `N0`.
Block size `B1` is always the same as vector `y` length `N1`.


$$z_{i, j} = x_i + y_j\text{ for } i = 1\ldots B_0,\ j = 1\ldots B_1$$


In [4]:
def add_vec_spec(x: Float32[Tensor, "32"], y: Float32[Tensor, "32"]) -> Float32[Tensor, "32 32"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    i_range = tl.arange(0, B0)[None, :] 
    j_range = tl.arange(0, B1)[:, None]
    
    x = tl.load(x_ptr + i_range)
    y = tl.load(y_ptr + j_range)
    
    z = x + y
    z = tl.store(z_ptr + i_range + B0 * j_range, z)
    
test(add_vec_kernel, add_vec_spec, nelem={"N0": 32, "N1": 32})

x: jaxtyping.Float32[Tensor, '32']
y: jaxtyping.Float32[Tensor, '32']
x_ptr ([32], <Parameter "x: jaxtyping.Float32[Tensor, '32']">)
y_ptr ([32], <Parameter "y: jaxtyping.Float32[Tensor, '32']">)
z_ptr ([32, 32], None)
Results match: True
Running on local URL:  http://127.0.0.1:7862


## Puzzle 4: Outer Vector Add Block

Add a row vector to a column vector. 

Uses two program blocks. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.

$$z_{i, j} = x_i + y_j\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$


In [5]:
def add_vec_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    i_range = tl.arange(0, B0)[:, None] + pid_0 * B0
    j_range = tl.arange(0, B1)[None, :] + pid_1 * B1
    
    x = tl.load(x_ptr + i_range, i_range < N0, 0)
    y = tl.load(y_ptr + j_range, j_range < N1, 0)
    
    z = x + y
    z = tl.store(z_ptr + i_range + N0 * j_range, z, (i_range < N0) & (j_range < N1))
    
test(add_vec_block_kernel, add_vec_block_spec, nelem={"N0": 100, "N1": 90})

x: jaxtyping.Float32[Tensor, '100']
y: jaxtyping.Float32[Tensor, '90']
x_ptr ([100], <Parameter "x: jaxtyping.Float32[Tensor, '100']">)
y_ptr ([90], <Parameter "y: jaxtyping.Float32[Tensor, '90']">)
z_ptr ([90, 100], None)
Results match: True
Running on local URL:  http://127.0.0.1:7863


## Puzzle 5: Fused Outer Multiplication

Multiply a row vector to a column vector and take a relu. 

Uses two program blocks. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.

$$z_{i, j} = \text{relu}(x_i \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$



In [6]:
def mul_relu_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    return torch.relu(x[None, :] * y[:, None])

@triton.jit
def mul_relu_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    i_range = tl.arange(0, B0)[:, None] + pid_0 * B0
    j_range = tl.arange(0, B1)[None, :] + pid_1 * B1
    
    x = tl.load(x_ptr + i_range, i_range < N0, 0)
    y = tl.load(y_ptr + j_range, j_range < N1, 0)
    
    z = x * y
    z = tl.where(z > 0, z, 0)
    
    z = tl.store(z_ptr + i_range + N0 * j_range, z, (i_range < N0) & (j_range < N1))
    
test(mul_relu_block_kernel, mul_relu_block_spec, nelem={"N0": 100, "N1": 90})

x: jaxtyping.Float32[Tensor, '100']
y: jaxtyping.Float32[Tensor, '90']
x_ptr ([100], <Parameter "x: jaxtyping.Float32[Tensor, '100']">)
y_ptr ([90], <Parameter "y: jaxtyping.Float32[Tensor, '90']">)
z_ptr ([90, 100], None)
Results match: True
Running on local URL:  http://127.0.0.1:7864


## Puzzle 6: Fused Outer Multiplication - Backwards


Backwards of a function that multiplies a matrix with a row vector and take a relu. 

Uses two program blocks. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`. Chain rule backward `dz` 
is of shape `N0`

$$f(x, y) = \text{relu}(x_i \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$

$$dx_{i, j} = f_x'(x, y)_{i, j} \times dz_{i,j}$$

In [7]:
def mul_relu_block_back_spec(x: Float32[Tensor, "90 100"], y: Float32[Tensor, "90"], 
                             dz: Float32[Tensor, "90 100"]) -> Float32[Tensor, "90 100"]:
    x = x.clone()
    y = y.clone()
    x = x.requires_grad_(True)
    y = y.requires_grad_(True)
    z = torch.relu(x * y[:, None])
    z.backward(dz)
    dx = x.grad
    return dx

@triton.jit
def mul_relu_block_back_kernel(x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    i_range = tl.arange(0, B0)[None, :] + pid_0 * B0
    j_range = tl.arange(0, B1)[:, None] + pid_1 * B1
    
    x = tl.load(x_ptr + i_range + N0 * j_range, (i_range < N0) & (j_range < N1), 0)
    y = tl.load(y_ptr + j_range, j_range < N1, 0)

    # Forward
    z = x * y
    dz = tl.load(dz_ptr + i_range + N0 * j_range, (i_range < N0) & (j_range < N1), 0)
    dr = tl.where(z > 0, dz, 0)
    dx = dr * y
    tl.store(dx_ptr + i_range + N0 * j_range, dx, (i_range < N0) & (j_range < N1))
    
test(mul_relu_block_back_kernel, mul_relu_block_back_spec, nelem={"N0": 100, "N1": 90})


x: jaxtyping.Float32[Tensor, '90 100']
y: jaxtyping.Float32[Tensor, '90']
dz: jaxtyping.Float32[Tensor, '90 100']
x_ptr ([90, 100], <Parameter "x: jaxtyping.Float32[Tensor, '90 100']">)
y_ptr ([90], <Parameter "y: jaxtyping.Float32[Tensor, '90']">)
dz_ptr ([90, 100], <Parameter "dz: jaxtyping.Float32[Tensor, '90 100']">)
z_ptr ([90, 100], None)
Results match: True


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Running on local URL:  http://127.0.0.1:7865


## Puzzle 7: Long Sum

Sum of a batch of numbers. 

Uses one program blocks. Block size `B0` represents a range of batches of  `x` of length `N0`.
Each element is of length `T`. Process it `B1 < T` elements at a time.  

$$z_{i} = \sum^{T}_j x_{i,j} =  \text{ for } i = 1\ldots N_0$$

Hint: You will need a for loop for this problem. These work and look the same as in Python. 

In [8]:
def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]:
    return x.sum(1)

@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    batch = tl.arange(0, B0)[:, None] + pid_0 * B0
    total = 0
    for i in range(0, T, B1):
        i_range = tl.arange(0, B1)[None, :] + i
        offset = batch*T + i_range
        x = tl.load(x_ptr + offset, i_range < T, 0)
        total = total + tl.sum(x, 1)
    
    tl.store(z_ptr + batch, total[None, :], batch < N0)
    
test(sum_kernel, sum_spec, B={"B0": 1, "B1": 32}, nelem={"N0": 4, "N1": 32, "T": 200})


x: jaxtyping.Float32[Tensor, '4 200']
x_ptr ([4, 200], <Parameter "x: jaxtyping.Float32[Tensor, '4 200']">)
z_ptr ([4], None)
Results match: True
Running on local URL:  http://127.0.0.1:7866


## Puzzle 8: Long Softmax


Softmax of a batch of logits. 

Uses one program blocks. Block size `B0` represents the batch of `x` of length `N0`.
Block logit length `T`.   Process it `B1 < T` elements at a time.  

$$z_{i, j} = \text{softmax}(x_{i,1} \ldots x_{i, T}) \text{ for } i = 1\ldots N_0$$

Note softmax needs to be computed in numerically stable form as in Python. 

There is a simple way to do this with 3 loops. You can also do it with 2 loops if you are clever. Hint: you will find this identity useful:

$$\exp(x_i - m) =  \exp(x_i - m/2 - m/2) = \exp(x_i - m/ 2) /  \exp(m/2) $$

In [9]:
def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp() 
    return x_exp / x_exp.sum(1, keepdim=True)

@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    batch = tl.arange(0, B0)[:, None] + pid_0 * B0
    x_max = -1e9
    partition = 0
    for i in range(0, T, B1):
        i_range = tl.arange(0, B1)[None, :] + i
        offset = batch*T + i_range
        x = tl.load(x_ptr + offset, i_range < T, -1e9)
        chunk_max = tl.max(x, 1)
        if i == 0:
            x_max_old = chunk_max
        x_max = tl.where(chunk_max > x_max_old, chunk_max, x_max_old)
        diff = x_max - x_max_old
        partition = partition / tl.exp(diff) + tl.sum(tl.exp(x - x_max), 1)
        x_max_old = x_max
    
    for i in range(0, T, B1):
        i_range = tl.arange(0, B1)[None, :] + i
        offset = batch*T + i_range
        x = tl.load(x_ptr + offset, i_range < T, -1e9) - x_max
        x_exp = tl.exp(x)
        z = x_exp / partition
        tl.store(z_ptr + offset, z, i_range < T)
    
test(softmax_kernel, softmax_spec, B={"B0": 1, "B1":32}, 
     nelem={"N0": 4, "N1": 32, "T": 200})


x: jaxtyping.Float32[Tensor, '4 200']
x_ptr ([4, 200], <Parameter "x: jaxtyping.Float32[Tensor, '4 200']">)
z_ptr ([4, 200], None)
Results match: True
Running on local URL:  http://127.0.0.1:7867


## Puzzle 8: Simple FlashAttention

A scalar version of FlashAttention. 

Uses zero programs. Block size `B0` represents `k` of length `N0`.
Block size `B0` represents `q` of length `N0`. Block size `B0` represents `v` of length `N0`. 
Sequence length is `T`. Process it `B1 < T` elements at a time.  

$$z_{i} = \sum_{j} \text{softmax}(q_1 k_1, \ldots, q_T k_T)_j v_{j} \text{ for } i = 1\ldots N_0$$

This can be done in 1 loop using a similar trick from the last puzzle. 

In [10]:
def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    x = q[:, None] * k[None, :]
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp() 
    soft =  x_exp  / x_exp.sum(1, keepdim=True)
    return (v[None, :] * soft).sum(1)

@triton.jit
def flashatt_kernel(q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr):
    x_max = -1e9
    partition = 0
    acc = 0
    for i in range(0, T, B0):
        i_range = tl.arange(0, B0) + i
        offset = i_range
        q = tl.load(q_ptr + offset, i_range < T, -1e9)[:, None]
        k = tl.load(k_ptr + offset, i_range < T, -1e9)[None, :]
        v = tl.load(v_ptr + offset, i_range < T, -1e9)[None, :]
        x = q * k
        chunk_max = tl.max(x, 1)
        if i == 0:
            x_max_old = chunk_max
        x_max = tl.where(chunk_max > x_max_old, chunk_max, x_max_old)
        diff = x_max - x_max_old
        exp_x = tl.exp(x - x_max[:, None])
        partition = partition / tl.exp(diff) + tl.sum(exp_x, 1)
        acc = acc / tl.exp(diff) + tl.sum(v * exp_x, 1)
        x_max_old = x_max
    i_range = tl.arange(0, B0)
    tl.store(z_ptr + i_range, acc / partition, i_range < T)
    
test(flashatt_kernel, flashatt_spec, B={"B0":200}, 
     nelem={"N0": 200, "T": 200})

q: jaxtyping.Float32[Tensor, '200']
k: jaxtyping.Float32[Tensor, '200']
v: jaxtyping.Float32[Tensor, '200']
q_ptr ([200], <Parameter "q: jaxtyping.Float32[Tensor, '200']">)
k_ptr ([200], <Parameter "k: jaxtyping.Float32[Tensor, '200']">)
v_ptr ([200], <Parameter "v: jaxtyping.Float32[Tensor, '200']">)
z_ptr ([200], None)
Results match: True
Running on local URL:  http://127.0.0.1:7868


## Puzzle 9: Two Dimensional Convolution

A batched 2D convolution. 

Uses one program id axis. Block size `B0` represent the batches to process out of `N0`.
Image `x` is size is `H` by `W` with only 1 channel, and kernel `k` is size `KH` by `KW`.

$$z_{i, j, k} = \sum_{oj, ok} k_{oj,ok} \times x_{i,j + oj, k + ok} \text{ for } i = 1\ldots N_0$$



In [11]:
def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]:
    z = torch.zeros(4, 8, 8)
    x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
    print(x.shape, k.shape)
    for i in range(8):
        for j in range(8):
            z[:, i, j] = (k[None, :, :] * x[:, i: i+4, j: j + 4]).sum(1).sum(1)
    return z


@triton.jit
def conv2d_kernel(x_ptr, k_ptr, z_ptr, N0, H, W, KH: tl.constexpr, KW: tl.constexpr, B0: tl.constexpr):
    pid_0 = tl.program_id(0)
    kh_range = tl.arange(0, KH)[None, :, None]
    kw_range = tl.arange(0, KW)[None, None, :]
    k = tl.load(k_ptr + kh_range * KW + kw_range)
    for i in range(0, H):
        for j in range(0, W):
            x = tl.load(x_ptr + pid_0 * H * W + (kh_range + i) * W + (kw_range + j), 
                        ((kh_range + i) < H) & ((kw_range + j) < W), 0)
            out = tl.sum(tl.sum(x * k, 2), 1)
            tl.store(z_ptr + pid_0 * H * W + i*W + j + tl.arange(0,1), out)
    
test(conv2d_kernel, conv2d_spec, B={"B0": 1}, nelem={"N0": 4, "H": 8, "W": 8, "KH": 4, "KW": 4})

x: jaxtyping.Float32[Tensor, '4 8 8']
k: jaxtyping.Float32[Tensor, '4 4']
x_ptr ([4, 8, 8], <Parameter "x: jaxtyping.Float32[Tensor, '4 8 8']">)
k_ptr ([4, 4], <Parameter "k: jaxtyping.Float32[Tensor, '4 4']">)
z_ptr ([4, 8, 8], None)
torch.Size([4, 12, 12]) torch.Size([4, 4])
Results match: True
Running on local URL:  http://127.0.0.1:7869

Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB


## Puzzle 10: Matrix Multiplication

A blocked matrix multiplication.

Uses two program id axes. Block size `B2` represent the batches to process out of `N2`.
Block size `B0` represent the rows of `x` to process out of `N0`. Block size `B1` represent the cols of `y` to process out of `N1`. The middle shape is `MID`.

$$z_{i, j, k} = \sum_{k} x_{i,j, l} \times y_{i, l, k} \text{ for } i = 1\ldots N_2, j = 1\ldots N_0, k = 1\ldots N_1$$

You are allowed to use `tl.dot` which computes a smaller mat mul. 

Hint: the main trick is that you can split a matmul into smaller parts. 

$$z_{i, j, k} = \sum_{k=1}^{K/2} x_{i,j, l} \times y_{i, l, k} +  \sum_{k=K/2}^{K} x_{i,j, l} \times y_{i, l, k} $$


In [16]:
def dot_spec(x: Float32[Tensor, "4 32 32"], y: Float32[Tensor, "4 32 32"]) -> Float32[Tensor, "4 32 32"]:
    return x @ y
    
@triton.jit
def dot_kernel(x_ptr, y_ptr, z_ptr, N0, N1, N2, MID, B0: tl.constexpr, B1: tl.constexpr, B2: tl.constexpr):
    r = tl.program_id(0) * B0
    rows = tl.arange(0, B0) [:, None]
    c = tl.program_id(1) * B1
    cols = tl.arange(0, B1)[None, :]
    mids = tl.arange(0, B0)
    b = tl.program_id(2)
    bid = b * N0 * N1
    z = 0
    for i in range(0, MID, B0): 
        x_val = tl.load(
            x_ptr
            + bid
            + (r + rows) * MID
            + (mids + i),
            (i + mids < MID) & (rows + r < N0), 0
        )
        y_val = tl.load(
            y_ptr
            + bid
            + (i + mids[:, None]) * N1
            + (cols + c),
            (i + mids[:, None] < MID) & (cols + c < N1)
            , 0
        )
        z = z + tl.dot(x_val, y_val)
    tl.store(
        z_ptr
        + (b * N0 * N1)
        + (r + rows) * N1
        + (cols + c),
        z,
        mask= (r + rows) < N0  & (c + cols) < N1
    )
test(dot_kernel, dot_spec, B={"B0": 16, "B1": 16, "B2": 1}, nelem={"N0": 32, "N1": 32, "N2": 4, "MID": 32})


x: jaxtyping.Float32[Tensor, '4 32 32']
y: jaxtyping.Float32[Tensor, '4 32 32']
x_ptr ([4, 32, 32], <Parameter "x: jaxtyping.Float32[Tensor, '4 32 32']">)
y_ptr ([4, 32, 32], <Parameter "y: jaxtyping.Float32[Tensor, '4 32 32']">)
z_ptr ([4, 32, 32], None)
Results match: True
Running on local URL:  http://127.0.0.1:7872


## Puzzle 11: Quantized Matrix Mult 

When doing matrix multiplication with quantized neural networks a common strategy is to store the weight matrix in lower precision, with a shift and scale term. 

For this problem our `weight` will be stored in 4 bits. We can store `FPINT` of these in a 32 bit integer. In addition for every `group` weights in order we will store 1 `scale` float value and 1 `shift` 4 bit value. We store these for the column of weight. The `activation`s are stored separately in standard floats. 

Mathematically it looks like.

$$z_{j, k} = \sum_{k} sc_{j, l/g} (w_{j, l} - sh_{j, l/g}) \times y_{l, k} \text{ for } i = 1\ldots N_2, j = 1\ldots N_0, k = 1\ldots N_1$$

However, it is a bit more complex since we need to also extract the 4-bit values into floats to begin. 




In [13]:

FPINT = 32 // 4
GROUP = 8

def quant_dot_spec(scale : Float32[Tensor, "32 8"], 
                   offset : Int32[Tensor, "32 1"], 
                   weight: Int32[Tensor, "32 8"], 
                   activation: Float32[Tensor, "64 32"]) -> Float32[Tensor, "32 32"]:
    def extract(x):
        over = torch.arange(8) * 4 
        mask = 2**4 - 1
        return (x[..., None] >> over) & mask
    scale = scale[..., None].expand(-1, 8, GROUP).contiguous().view(-1, 64)
    offset = extract(offset)[..., None].expand(-1, 1, 8, GROUP).contiguous().view(-1, 64)
    return ( scale * (extract(weight).view(-1, 64) - offset))  @ activation
    
@triton.jit
def quant_dot_kernel(scale_ptr, offset_ptr, weight_ptr, activation_ptr,
                     z_ptr, N0, N1, MID, B0: tl.constexpr, B1: tl.constexpr):
    r = tl.program_id(0) * B0
    rows = tl.arange(0, B0) [:, None]
    c = tl.program_id(1) * B1
    cols = tl.arange(0, B1)[None, :]
    mids = tl.arange(0, B0)

    row_range = r + rows
    z = 0
    for i in range(0, MID, B0): 
        mid_range = i + mids[None, :]
        check = (i + mids < MID) & (rows + r < N0)
        scale_val = tl.load(
            scale_ptr 
            + row_range * (MID // GROUP)
            + mid_range // GROUP, check, 0)
        scale_val = scale_val

        weight_val = tl.load(
            weight_ptr
            + row_range * (MID // FPINT)
            + mid_range // FPINT,
            check, 0)
        v = (mid_range % FPINT) * 4
        weight_val = (weight_val >> v) & (2**4 - 1)

        offset_val = tl.load(
            offset_ptr
            + row_range * (MID // (FPINT * GROUP))
            + mid_range // (FPINT * GROUP), check, 0)
        v = ((mid_range // GROUP) % FPINT) * 4
        offset_val = (offset_val >> v) & (2**4 - 1)
        
        weight = (weight_val - offset_val) * scale_val
        
        activation_val = tl.load(
            activation_ptr
            + (i + mids[:, None]) * N1
            + (cols + c),
            (i + mids[:, None] < MID) & (cols + c < N1)
            , 0
        )
        z = z + tl.dot(weight, activation_val)
    tl.store(
        z_ptr
        + (r + rows) * N1
        + (cols + c),
        z,
        mask= (r + rows) < N0  & (c + cols) < N1
    )
test(quant_dot_kernel, quant_dot_spec, B={"B0": 16, "B1": 16},
                                       nelem={"N0": 32, "N1": 32, "MID": 64})


scale: jaxtyping.Float32[Tensor, '32 8']
offset: jaxtyping.Int32[Tensor, '32 1']
weight: jaxtyping.Int32[Tensor, '32 8']
activation: jaxtyping.Float32[Tensor, '64 32']
scale_ptr ([32, 8], <Parameter "scale: jaxtyping.Float32[Tensor, '32 8']">)
offset_ptr ([32, 1], <Parameter "offset: jaxtyping.Int32[Tensor, '32 1']">)
weight_ptr ([32, 8], <Parameter "weight: jaxtyping.Int32[Tensor, '32 8']">)
activation_ptr ([64, 32], <Parameter "activation: jaxtyping.Float32[Tensor, '64 32']">)
z_ptr ([32, 32], None)
Results match: True
Running on local URL:  http://127.0.0.1:7870
