# Triton Puzzles

In [1]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import triton_viz
from triton_viz.interpreter import record_builder
import jaxtyping 
import inspect
from jaxtyping import Float

SyntaxError: unterminated string literal (detected at line 162) (draw.py, line 162)

In [None]:
def test(puzzle, puzzle_spec, nelem=None, B=None):
    """Trace a puzzle kernel with triton_viz and compare its output to the spec.

    Tensor shapes are read from the annotations of ``puzzle_spec``: one random
    tensor is created per spec parameter, plus one for the return annotation,
    which is passed to the kernel last and then compared against the spec.

    Args:
        puzzle: Kernel to run; launched as ``triton_viz.trace(puzzle)[grid]``.
        puzzle_spec: Reference implementation whose parameter and return
            annotations expose ``dims`` with a ``size`` per dimension.
        nelem: Problem sizes keyed ``"N0"``, ``"N1"``, ``"N2"``; forwarded to
            the kernel as keyword arguments. ``"N0"`` is read when the launch
            grid is built.
        B: Block sizes forwarded to the kernel as keyword arguments. Defaults
            to ``{"B0": 128}``; ``"B1"`` / ``"B2"`` are set to 128 whenever
            ``"N1"`` / ``"N2"`` appear in ``nelem``.
    """
    # None sentinels avoid sharing one mutable dict between calls.
    nelem = {} if nelem is None else nelem
    B = {"B0": 128} if B is None else dict(B)
    if "N1" in nelem:
        B["B1"] = 128
    if "N2" in nelem:
        B["B2"] = 128

    triton_viz.interpreter.record_builder.reset()
    torch.manual_seed(0)
    signature = inspect.signature(puzzle_spec)
    args = {}
    for n, p in signature.parameters.items():
        args[n + "_ptr"] = [d.size for d in p.annotation.dims]
    # The output buffer goes last, shaped by the spec's return annotation.
    args["z_ptr"] = [d.size for d in signature.return_annotation.dims]

    tt_args = [torch.rand(*shape) for shape in args.values()]
    grid = lambda meta: (triton.cdiv(nelem["N0"], meta["B0"]),
                         triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)),
                         triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)))
    triton_viz.trace(puzzle)[grid](*tt_args, **B, **nelem)

    # Split the kernel-written output from the inputs before calling the spec.
    *inputs, z = tt_args
    z_ = puzzle_spec(*inputs)
    print("Results match:",  torch.allclose(z, z_))
    triton_viz.launch()

## Puzzle 1: Constant Add

Add a constant to a vector. Uses one program block. Block size `B0` is always the same as vector length `N0`.


In [None]:
def add_spec(x: Float[Tensor, "128"]) -> Float[Tensor, "128"]:
    """Reference for the constant-add puzzle: return ``x`` with 10 added element-wise."""
    offset = 10.
    return x + offset

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    """Write ``x + 10`` to ``z`` using a single block of ``B0`` offsets.

    No mask is applied, so the block is expected to cover the vector exactly.
    ``N0`` is part of the signature but is not read here.
    """
    # Called `offsets` so the builtin `range` is not shadowed inside the kernel.
    offsets = tl.arange(0, B0)
    x = tl.load(x_ptr + offsets)
    # tl.store is used for its side effect only; nothing is bound to its result.
    tl.store(z_ptr + offsets, x + 10)

# N0 = 128 with the harness default B0 = 128, so the launch grid has one program along axis 0.
test(add_kernel, add_spec, nelem={"N0": 128})

## Puzzle 2: Constant Add Block

Add a constant to a vector. Uses several program blocks: block size `B0` is now smaller than vector length `N0`, so the last block must be masked.



In [None]:
def add2_spec(x: Float[Tensor, "200"]) -> Float[Tensor, "200"]:
    """Reference for the blocked constant-add puzzle: return ``x`` with 10 added element-wise."""
    offset = 10.
    return x + offset

@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    """Write ``x + 10`` to ``z``, one ``B0``-wide block of offsets per program.

    Offsets at or beyond ``N0`` are masked out of both the load and the store.
    """
    pid = tl.program_id(0)
    # Called `offsets` so the builtin `range` is not shadowed inside the kernel.
    offsets = pid * B0 + tl.arange(0, B0)
    in_bounds = offsets < N0
    # Masked-off lanes read as 0 and are never written back.
    x = tl.load(x_ptr + offsets, mask=in_bounds, other=0)
    tl.store(z_ptr + offsets, x + 10, mask=in_bounds)
    
# N0 = 200 with the harness default B0 = 128: two programs along axis 0, the second only partly in bounds.
test(add_mask2_kernel, add2_spec, nelem={"N0": 200})

## Puzzle 3: Outer Vector Add

Add two vectors. Uses one program block. Block size `B0` is always the same as vector `x` length `N0`.
Block size `B1` is always the same as vector `y` length `N1`.


In [None]:
def add_vec_spec(x: Float[Tensor, "128"], y: Float[Tensor, "128"]) -> Float[Tensor, "128 128"]:
    """Reference outer sum: ``out[j, i] = x[i] + y[j]``."""
    x_row = x[None, :]  # x laid out along the last axis
    y_col = y[:, None]  # y laid out along the first axis
    return x_row + y_col

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    """Outer sum in a single block: store ``x[i] + y[j]`` at ``z_ptr + i + B0 * j``.

    No masks are applied and ``B0`` doubles as the row stride of ``z``, so the
    blocks are expected to cover both vectors exactly. ``N0`` and ``N1`` are
    part of the signature but are not read here.
    """
    col = tl.arange(0, B0)[None, :]  # (1, B0) offsets into x
    row = tl.arange(0, B1)[:, None]  # (B1, 1) offsets into y

    x = tl.load(x_ptr + col)
    y = tl.load(y_ptr + row)

    # (1, B0) + (B1, 1) broadcasts to a (B1, B0) tile; tl.store is used for
    # its side effect only, so nothing is bound to its result.
    tl.store(z_ptr + col + B0 * row, x + y)
    
# N0 = N1 = 128; the harness sets B1 = 128 because "N1" is present, so the grid is a single program.
test(add_vec_kernel, add_vec_spec, nelem={"N0": 128, "N1": 128})

## Puzzle 4: Outer Vector Add Block

In [None]:
def add_vec_spec(x: Float[Tensor, "i j"], i: int) -> Float[Tensor, "i"]:
    """Return ``x`` with the constant 10 added to every element.

    ``i`` is accepted but does not take part in the computation.
    NOTE(review): this adds a constant rather than combining two vectors;
    confirm it is the intended reference for an outer vector add.
    """
    offset = 10.
    return x + offset

@triton_viz.trace
@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, B0: tl.constexpr, B1: tl.constexpr):
    """Blocked outer sum: each program stores a ``(B0, B1)`` tile of ``x[i] + y[j]``.

    Program ids 0 and 1 select the tile along ``i`` and ``j``. The result for
    ``(i, j)`` is written to ``z_ptr + i * B1 + j``.

    NOTE(review): the kernel receives no vector lengths, so no masks are
    applied and ``B1`` is used as the row stride of ``z``; that addresses ``z``
    correctly only when the output has exactly ``B1`` columns. Confirm sizes.
    NOTE(review): the ``test`` harness in this file also wraps its argument in
    ``triton_viz.trace``; confirm the decorator here is intended.
    """
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    # Tile extents must match the per-program strides B0 / B1.
    rows = tl.arange(0, B0)[:, None] + pid_0 * B0
    cols = tl.arange(0, B1)[None, :] + pid_1 * B1

    x = tl.load(x_ptr + rows)
    # y is read from its own buffer.
    y = tl.load(y_ptr + cols)

    # (B0, 1) + (1, B1) broadcasts to the (B0, B1) tile.
    tl.store(z_ptr + rows * B1 + cols, x + y)
    
# NOTE(review): no `nelem` is passed, but the harness grid reads nelem["N0"] — confirm the intended sizes.
test(add_vec_kernel, add_vec_spec)

## Puzzle 5: Fused Op

## Puzzle 6: Fused Op Backwards

## Puzzle 7: Sum and Backward

## Puzzle 8: Manual Conv.

## Puzzle 9: Matrix Mult

## Puzzle 10: Quantized Matrix Mult 

In [None]:
@triton_viz.trace
@triton.jit
def dot_kernel(x_ptr, y_ptr, z_ptr, BLOCK_SIZE: tl.constexpr):
    """Batched tiled matmul: each program stores one BLOCK_SIZE x BLOCK_SIZE tile of z.

    Program ids 0 / 1 / 2 select the tile row, the tile column and the batch
    entry. The addressing uses a row stride of 2 * BLOCK_SIZE for x and y and
    of 2 * BLOCK_SIZE - 10 for z; output columns at or beyond that width are
    masked out of the store. The reduction axis is covered by two
    BLOCK_SIZE-wide steps.
    """
    row0 = tl.program_id(0) * BLOCK_SIZE
    col0 = tl.program_id(1) * BLOCK_SIZE
    batch = tl.program_id(2)

    in_width = 2 * BLOCK_SIZE
    out_width = 2 * BLOCK_SIZE - 10
    down = tl.arange(0, BLOCK_SIZE)[:, None]    # (BLOCK_SIZE, 1)
    across = tl.arange(0, BLOCK_SIZE)[None, :]  # (1, BLOCK_SIZE)

    # Start of this batch entry in x and y.
    x_base = x_ptr + batch * 4 * BLOCK_SIZE * BLOCK_SIZE
    y_base = y_ptr + batch * 4 * BLOCK_SIZE * BLOCK_SIZE
    x_rows = (row0 + down) * in_width

    # First reduction step: columns [0, BLOCK_SIZE) of x, rows [0, BLOCK_SIZE) of y.
    x_lo = tl.load(x_base + x_rows + across)
    y_lo = tl.load(y_base + down * in_width + across + col0)
    acc = tl.dot(x_lo, y_lo)

    # Second reduction step: columns / rows [BLOCK_SIZE, 2 * BLOCK_SIZE).
    x_hi = tl.load(x_base + x_rows + across + BLOCK_SIZE)
    y_hi = tl.load(y_base + (BLOCK_SIZE + down) * in_width + across + col0)
    acc = acc + tl.dot(x_hi, y_hi)

    out_cols = across + col0
    tl.store(
        z_ptr + batch * in_width * out_width + (row0 + down) * out_width + out_cols,
        acc,
        mask=out_cols < out_width,
    )


def perform_dot(device, BATCH, BLOCK_SIZE):
    """Build random inputs, launch ``dot_kernel`` on a (2, 2, BATCH) grid and return ``(x, y, z)``.

    ``x`` and ``y`` have shape (BATCH, 2 * BLOCK_SIZE, 2 * BLOCK_SIZE); ``z`` is
    zero-initialised with ten fewer columns than the inputs.
    """
    side = 2 * BLOCK_SIZE
    in_shape = (BATCH, side, side)
    out_shape = (BATCH, side, side - 10)

    x = torch.randn(in_shape, device=device)
    y = torch.randn(in_shape, device=device)
    z = torch.zeros(out_shape, device=device)

    launch_grid = (2, 2, BATCH)
    dot_kernel[launch_grid](x, y, z, BLOCK_SIZE)
    return x, y, z
# Tile width for the matmul demo; perform_dot builds (12, 64, 64) inputs from it.
BLOCK_SIZE = 32
# NOTE(review): `device` is assigned further down in this file — confirm it is defined before this cell runs.
input_matrix1, input_matrix2, result = perform_dot(device, 12, BLOCK_SIZE)
triton_viz.launch()

In [None]:
@triton_viz.trace
@triton.jit
def add_kernel(
    X,  # First input vector.
    Y,  # Second input vector.
    Z,  # Output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    """Element-wise ``Z = X + Y``, one ``BLOCK_SIZE``-wide chunk per program."""
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # `offsets` holds element offsets; they become addresses once added to a base pointer.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load both inputs, masking out any extra elements in case the size is not a
    # multiple of the block size.
    x = tl.load(X + offsets, mask=mask)
    y = tl.load(Y + offsets, mask=mask)
    output = x + y
    # Write x + y back through the output pointer.
    tl.store(Z + offsets, output, mask=mask)


## Puzzle 2: Blocking

In [None]:
@triton_viz.trace
@triton.jit
def add_kernel(
    x_ptr,  # first input vector
    y_ptr,  # second input vector
    output_ptr,  # destination vector
    n_elements,  # number of elements in each vector
    BLOCK_SIZE: tl.constexpr,  # elements handled per program; constexpr so it can size tl.arange
):
    """Element-wise vector add: each program handles one ``BLOCK_SIZE``-wide slice.

    With a 1D launch grid, program ``p`` covers offsets
    ``[p * BLOCK_SIZE, (p + 1) * BLOCK_SIZE)``; offsets at or beyond
    ``n_elements`` are masked out of every load and store.
    """
    first = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first + tl.arange(0, BLOCK_SIZE)
    # Guards the tail when n_elements is not a multiple of BLOCK_SIZE.
    in_bounds = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)


def add(x: torch.Tensor, y: torch.Tensor):
    """Launch ``add_kernel`` on ``x`` and ``y`` and return ``(output, grid)``.

    ``output`` is allocated with ``torch.empty_like(x)``. ``grid`` is the
    callable handed to the kernel launch: it maps the launch meta-parameters
    to a 1D grid with one program per ``BLOCK_SIZE`` elements.
    """
    output = torch.empty_like(x)
    n_elements = output.numel()

    def grid(meta):
        # Round up so a partial final block still gets a program.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Meta-parameters such as BLOCK_SIZE are passed by keyword.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output, grid




In [None]:
# Device and vector length used by the demo call below.
device = "cpu"
size = 5000
# NOTE(review): `perform_vec_add` is not defined in the visible part of this file (the wrapper
# defined in this file is named `add` and takes two tensors) — confirm where this helper comes from.
input_vector1, input_vector2, output_triton = perform_vec_add(device, size)
triton_viz.launch()

In [None]:
@triton_viz.trace
@triton.jit
def dot_kernel(x_ptr, y_ptr, z_ptr, BLOCK_SIZE: tl.constexpr):
    """Batched tiled matmul: each program stores one BLOCK_SIZE x BLOCK_SIZE tile of z.

    Program ids 0 / 1 / 2 select the tile row, the tile column and the batch
    entry. The addressing uses a row stride of 2 * BLOCK_SIZE for x and y and
    of 2 * BLOCK_SIZE - 10 for z; output columns at or beyond that width are
    masked out of the store. The reduction axis is covered by two
    BLOCK_SIZE-wide steps.
    """
    row0 = tl.program_id(0) * BLOCK_SIZE
    col0 = tl.program_id(1) * BLOCK_SIZE
    batch = tl.program_id(2)

    in_width = 2 * BLOCK_SIZE
    out_width = 2 * BLOCK_SIZE - 10
    down = tl.arange(0, BLOCK_SIZE)[:, None]    # (BLOCK_SIZE, 1)
    across = tl.arange(0, BLOCK_SIZE)[None, :]  # (1, BLOCK_SIZE)

    # Start of this batch entry in x and y.
    x_base = x_ptr + batch * 4 * BLOCK_SIZE * BLOCK_SIZE
    y_base = y_ptr + batch * 4 * BLOCK_SIZE * BLOCK_SIZE
    x_rows = (row0 + down) * in_width

    # First reduction step: columns [0, BLOCK_SIZE) of x, rows [0, BLOCK_SIZE) of y.
    x_lo = tl.load(x_base + x_rows + across)
    y_lo = tl.load(y_base + down * in_width + across + col0)
    acc = tl.dot(x_lo, y_lo)

    # Second reduction step: columns / rows [BLOCK_SIZE, 2 * BLOCK_SIZE).
    x_hi = tl.load(x_base + x_rows + across + BLOCK_SIZE)
    y_hi = tl.load(y_base + (BLOCK_SIZE + down) * in_width + across + col0)
    acc = acc + tl.dot(x_hi, y_hi)

    out_cols = across + col0
    tl.store(
        z_ptr + batch * in_width * out_width + (row0 + down) * out_width + out_cols,
        acc,
        mask=out_cols < out_width,
    )


def perform_dot(device, BATCH, BLOCK_SIZE):
    """Build random inputs, launch ``dot_kernel`` on a (2, 2, BATCH) grid and return ``(x, y, z)``.

    ``x`` and ``y`` have shape (BATCH, 2 * BLOCK_SIZE, 2 * BLOCK_SIZE); ``z`` is
    zero-initialised with ten fewer columns than the inputs.
    """
    side = 2 * BLOCK_SIZE
    in_shape = (BATCH, side, side)
    out_shape = (BATCH, side, side - 10)

    x = torch.randn(in_shape, device=device)
    y = torch.randn(in_shape, device=device)
    z = torch.zeros(out_shape, device=device)

    launch_grid = (2, 2, BATCH)
    dot_kernel[launch_grid](x, y, z, BLOCK_SIZE)
    return x, y, z
# Run the batched matmul demo with 32-wide tiles and a batch of 12, then call triton_viz.launch().
BLOCK_SIZE = 32
input_matrix1, input_matrix2, result = perform_dot(device, 12, BLOCK_SIZE)
triton_viz.launch()