In [1]:
import torch 
import triton 
from torch import Tensor 
import triton.language as tl
import jaxtyping 
from jaxtyping import Float32, Int32

In [2]:
import triton_viz
import inspect
from triton_viz.interpreter import record_builder

import triton_viz
import inspect
from triton_viz.interpreter import record_builder

def test(puzzle, puzzle_spec, nelem=None, B=None, viz=True):
    """Launch a puzzle kernel under triton_viz and check it against its spec.

    Args:
        puzzle: the @triton.jit kernel under test. It is launched with one
            tensor per ``puzzle_spec`` parameter followed by the output tensor,
            then the block sizes (``B``) and problem sizes (``nelem``) as
            keyword arguments.
        puzzle_spec: reference PyTorch function. Its jaxtyping annotations
            (e.g. ``Float32[Tensor, "90 100"]``) define input/output shapes.
        nelem: problem sizes such as ``{"N0": 100, "N1": 90}``. ``N0`` must be
            present because it sizes grid axis 0. Defaults to ``{}``.
        B: block sizes; defaults to ``{"B0": 32}``. ``B1``/``B2`` default to
            32 when ``N1``/``N2`` are given in ``nelem``.
        viz: if True, open the triton_viz visualizer; a truthy return value
            from it is reported as "Invalid Access" and fails the test.

    Returns:
        An IPython ``HTML`` puppy video on success; ``None`` after printing
        diagnostics on failure.
    """
    # None sentinels instead of mutable dict defaults; both dicts are copied so
    # the caller's dicts are never modified.
    nelem = {} if nelem is None else dict(nelem)
    B = {"B0": 32} if B is None else dict(B)
    if "N1" in nelem and "B1" not in B:
        B["B1"] = 32
    if "N2" in nelem and "B2" not in B:
        B["B2"] = 32

    triton_viz.interpreter.record_builder.reset()
    torch.manual_seed(0)
    signature = inspect.signature(puzzle_spec)
    # "<param>_ptr" -> (shape taken from the jaxtyping annotation, Parameter).
    # The output buffer "z_ptr" is inserted last, so it is the last kernel arg.
    args = {}
    for n, p in signature.parameters.items():
        print(p)
        args[n + "_ptr"] = ([d.size for d in p.annotation.dims], p)
    args["z_ptr"] = ([d.size for d in signature.return_annotation.dims], None)

    # Random floats in [-0.5, 0.5); int32-annotated inputs get random integers.
    tt_args = []
    for k, (v, t) in args.items():
        tt_args.append(torch.rand(*v) - 0.5)
        if t is not None and t.annotation.dtypes[0] == "int32":
            tt_args[-1] = torch.randint(-100000, 100000, v)
    # One program per block along each axis; axes 1 and 2 collapse to size 1
    # when N1 / N2 are not given.
    grid = lambda meta: (triton.cdiv(nelem["N0"], meta["B0"]),
                         triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)),
                         triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)))

    triton_viz.trace(puzzle)[grid](*tt_args, **B, **nelem)
    z = tt_args[-1]
    tt_args = tt_args[:-1]
    z_ = puzzle_spec(*tt_args)
    match = torch.allclose(z, z_, rtol=1e-3, atol=1e-3)
    print("Results match:",  match)
    failures = False
    if viz:
        failures = triton_viz.launch()
    if not match or failures:
        print("Invalid Access:", failures)
        print("Yours:", z)
        print("Spec:", z_)
        print(torch.isclose(z, z_))
        return
    # PUPPIES!
    from IPython.display import HTML
    import random
    print("Correct!")
    pups = [
        "2m78jPG",
        "pn1e9TO",
        "MQCIwzT",
        "udLK6FS",
        "ZNem5o3",
        "DS2IZ6K",
        "aydRUz8",
        "MVUdQYK",
        "kLvno0p",
        "wScLiVz",
        "Z0TII8i",
        "F1SChho",
        "9hRi2jN",
        "lvzRF3W",
        "fqHxOGI",
        "1xeUYme",
        "6tVqKyM",
        "CCxZ6Wr",
        "lMW0OPQ",
        "wHVpHVG",
        "Wj2PGRl",
        "HlaTE8H",
        "k5jALH0",
        "3V37Hqr",
        "Eq2uMTA",
        "Vy9JShx",
        "g9I2ZmK",
        "Nu4RH7f",
        "sWp0Dqd",
        "bRKfspn",
        "qawCMl5",
        "2F6j2B4",
        "fiJxCVA",
        "pCAIlxD",
        "zJx2skh",
        "2Gdl1u7",
        "aJJAY4c",
        "ros6RLC",
        "DKLBJh7",
        "eyxH0Wc",
        "rJEkEw4",
    ]
    return HTML("""
    <video alt="test" controls autoplay=1>
        <source src="https://openpuppies.com/mp4/%s.mp4"  type="video/mp4">
    </video>
    """ % random.choice(pups))

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
@triton.jit
def demo(x_ptr):
    # Masked 1-D load demo: build 8 offsets, read only the first 5 elements.
    # `offsets` replaces the original local name `range`, which shadowed the
    # Python builtin inside the kernel.
    offsets = tl.arange(0, 8)
    print(offsets)
    # `x_ptr + offsets` is pointer arithmetic: it builds a block of 8 addresses
    # (x_ptr + 0 ... x_ptr + 7) to load from, rather than adding to loaded values.
    # Lanes where `offsets < 5` is False are not read; they take the fallback
    # value 0 passed as the third positional argument (`other`).
    x = tl.load(x_ptr + offsets, offsets < 5, 0)
    print(x)

triton_viz.trace(demo)[(1, 1, 1)](torch.ones(4, 3))
triton_viz.launch()

[0 1 2 3 4 5 6 7]
[1. 1. 1. 1. 1. 0. 0. 0.]


* Running on public URL: https://2a2e76ba0dc21b70f3.gradio.live


{}

In [4]:
@triton.jit
def demo_2d(x_ptr):
    # 2-D addressing demo: an (8, 1) row index broadcast against a (1, 8) column
    # index yields an (8, 8) block of offsets. The row stride 4 matches the
    # contiguous 4x4 tensor passed below. `offsets` replaces the original local
    # name `range`, which shadowed the Python builtin.
    rows = tl.arange(0, 8)[:, None]
    cols = tl.arange(0, 8)[None, :]
    offsets = rows * 4 + cols
    print(offsets)
    # Only rows 0-3 and columns 0-2 are read; every other lane gets the fallback
    # value 0. The mask also stops lanes with cols >= 4 (which would alias the
    # next row) and rows >= 4 (past the 16 elements) from being loaded.
    x = tl.load(x_ptr + offsets, (rows < 4) & (cols < 3), 0)
    print(x)

triton_viz.trace(demo_2d)[(1, 1, 1)](torch.ones(4, 4))
triton_viz.launch()

[[ 0  1  2  3  4  5  6  7]
 [ 4  5  6  7  8  9 10 11]
 [ 8  9 10 11 12 13 14 15]
 [12 13 14 15 16 17 18 19]
 [16 17 18 19 20 21 22 23]
 [20 21 22 23 24 25 26 27]
 [24 25 26 27 28 29 30 31]
 [28 29 30 31 32 33 34 35]]
[[1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]]


* Running on public URL: https://8f7507f160b0b5cce2.gradio.live


{}

In [6]:
# Loading multiple blocks at once: launch 3 programs, each reading its own
# 8-element slice of the 32-element tensor (2 * 4 * 4).

@triton.jit
def demo_blocks(x_ptr):
    pid = tl.program_id(0)
    # Program `pid` covers elements [pid * 8, pid * 8 + 8). `offsets` replaces
    # the original local name `range`, which shadowed the Python builtin.
    offsets = tl.arange(0, 8) + pid * 8
    # Only the first 20 elements are read, so program 2 loads just 4 of its lanes.
    x = tl.load(x_ptr + offsets, offsets < 20)
    print(f'for each {pid}, {x}')

x = torch.ones(2, 4, 4)
triton_viz.trace(demo_blocks)[(3, 1, 1)](x)
triton_viz.launch()

for each [0], [1. 1. 1. 1. 1. 1. 1. 1.]
for each [1], [1. 1. 1. 1. 1. 1. 1. 1.]
for each [2], [1. 1. 1. 1. 0. 0. 0. 0.]


* Running on public URL: https://14de7241a88d80285f.gradio.live


{}

## Puzzle 1: Constant Add

- add a constant to a vector 

In [13]:
def add_spec(x: Float32[Tensor, "32"]) -> Float32[Tensor, "32"]:
    """Reference implementation for Puzzle 1: add the constant 10 to every element.

    The jaxtyping annotations carry the tensor sizes; the `test` harness reads
    them to build random inputs of shape (32,).
    """
    shifted = x + 10.0
    return shifted

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    # Single program: the B0 lanes cover the whole vector (the test below uses
    # B0 == N0 == 32), so no bounds mask is applied and N0 is not read.
    idx = tl.arange(0, B0)
    tl.store(z_ptr + idx, tl.load(x_ptr + idx) + 10.0)

test(add_kernel, add_spec, nelem={"N0": 32}, viz=True)

x: jaxtyping.Float32[Tensor, '32']
Results match: True


* Running on public URL: https://f7419bc7ad4d442a13.gradio.live


Correct!


## Puzzle 2: Constant Add Block

- add a constant to a vector whose length N0 is larger than the block size B0, so several programs and a bounds mask are needed

In [4]:
def add2_spec(x: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    """Reference for Puzzle 2: add 10 to each of the 200 elements of x."""
    result = x + 10.0
    return result

@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0:tl.constexpr):
    # Each program handles one B0-wide tile of the vector. The last tile is
    # partial when N0 is not a multiple of B0, so lanes at or past N0 are masked.
    start = tl.program_id(0) * B0
    idx = start + tl.arange(0, B0)
    in_bounds = idx < N0
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(z_ptr + idx, vals + 10.0, mask=in_bounds)

test(add_mask2_kernel, add2_spec, nelem={"N0": 200})


x: jaxtyping.Float32[Tensor, '200']
Results match: True


* Running on public URL: https://190133e694ec25a77a.gradio.live


Correct!


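## Puzzle 3: Outer Vector Add
- add a row vector to a column vector, producing the matrix z[i, j] = y[i] + x[j] in a single block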

In [3]:
def add_vec_spec(x: Float32[Tensor, "32"], y: Float32[Tensor, "32"]) -> Float32[Tensor, "32 32"]:
    """Reference for Puzzle 3: outer sum, out[i, j] = x[j] + y[i]."""
    row = x[None, :]  # (1, 32): varies along columns
    col = y[:, None]  # (32, 1): varies along rows
    return row + col

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # Outer sum in a single program: z[i, j] = y[i] + x[j], with z a row-major
    # (N1, N0) matrix. The blocks must cover both vectors (B0 >= N0, B1 >= N1).
    # Lanes beyond N0 / N1 are masked, so oversized blocks stay in bounds.
    offset_x = tl.arange(0, B0)
    offset_y = tl.arange(0, B1)
    mask_x = offset_x < N0
    mask_y = offset_y < N1
    # The row stride of z is its column count N0. The block size B0 is only
    # equal to it when the block exactly fits the vector.
    offset_z = offset_y[:, None] * N0 + offset_x[None, :]
    mask_z = mask_y[:, None] & mask_x[None, :]
    x = tl.load(x_ptr + offset_x, mask=mask_x)
    y = tl.load(y_ptr + offset_y, mask=mask_y)
    z = y[:, None] + x[None, :]
    tl.store(z_ptr + offset_z, z, mask=mask_z)

test(add_vec_kernel, add_vec_spec, nelem={"N0": 32, "N1": 32})

x: jaxtyping.Float32[Tensor, '32']
y: jaxtyping.Float32[Tensor, '32']
Results match: True
* Running on public URL: https://e2296d6d5b1e90f6fb.gradio.live


Correct!


## Puzzle 4: Outer Vector Add Block
- add a row vector to a column vector when the vectors are longer than the block sizes, so the output is computed in 2-D tiles

In [3]:
def add_vec_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    """Reference for Puzzle 4: outer sum of shape (90, 100), out[i, j] = x[j] + y[i]."""
    row = x[None, :]  # (1, 100)
    col = y[:, None]  # (90, 1)
    return row + col

@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # 2-D tiled outer sum: program (pid 0, pid 1) writes the (B1, B0) tile of the
    # row-major (N1, N0) output z[i, j] = y[i] + x[j]. Edge tiles are masked.
    cols = tl.program_id(0) * B0 + tl.arange(0, B0)
    rows = tl.program_id(1) * B1 + tl.arange(0, B1)
    col_ok = cols < N0
    row_ok = rows < N1

    x = tl.load(x_ptr + cols, col_ok)
    y = tl.load(y_ptr + rows, row_ok)

    tile = rows[:, None] * N0 + cols[None, :]
    tile_ok = row_ok[:, None] & col_ok[None, :]
    tl.store(z_ptr + tile, y[:, None] + x[None, :], mask=tile_ok)

test(add_vec_block_kernel, add_vec_block_spec, nelem={"N0": 100, "N1": 90})

x: jaxtyping.Float32[Tensor, '100']
y: jaxtyping.Float32[Tensor, '90']
Results match: True
* Running on public URL: https://2296f9708ac724b883.gradio.live


Correct!


## Puzzle 5: Fused Outer Multiplication

- multiply a row vector by a column vector (outer product) and apply ReLU, in one fused kernel

In [3]:
def mul_relu_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    """Reference for Puzzle 5: out[i, j] = relu(x[j] * y[i]), shape (90, 100)."""
    outer = x[None, :] * y[:, None]
    return torch.relu(outer)

@triton.jit
def mul_relu_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # Fused outer product + ReLU, tiled: program (pid 0, pid 1) writes the
    # (B1, B0) tile of the row-major (N1, N0) output z[i, j] = relu(x[j] * y[i]).
    cols = tl.program_id(0) * B0 + tl.arange(0, B0)
    rows = tl.program_id(1) * B1 + tl.arange(0, B1)
    col_ok = cols < N0
    row_ok = rows < N1
    tile = rows[:, None] * N0 + cols[None, :]
    tile_ok = row_ok[:, None] & col_ok[None, :]

    x = tl.load(x_ptr + cols, col_ok)
    y = tl.load(y_ptr + rows, row_ok)
    prod = x[None, :] * y[:, None]
    # ReLU: keep positive products and replace the rest with 0.
    tl.store(z_ptr + tile, tl.where(prod > 0, prod, 0), mask=tile_ok)

test(mul_relu_block_kernel, mul_relu_block_spec, nelem={"N0": 100, "N1": 90})



x: jaxtyping.Float32[Tensor, '100']
y: jaxtyping.Float32[Tensor, '90']
Results match: True
* Running on public URL: https://1d609ba1488941bb2a.gradio.live


Correct!


## Puzzle 6: Fused Outer Multiplication - Backwards
- backward pass of a function that multiplies a matrix by a column vector (broadcast across each row) and applies ReLU; compute the gradient with respect to the matrix

In [3]:
def mul_relu_block_back_spec(x: Float32[Tensor, "90 100"], y: Float32[Tensor, "90"],
                             dz: Float32[Tensor, "90 100"]) -> Float32[Tensor, "90 100"]:
    """Reference for Puzzle 6: gradient w.r.t. x of relu(x * y[:, None]) given upstream dz.

    Works on detached copies so the caller's tensors are never marked as
    requiring grad.
    """
    x_leaf = x.clone().requires_grad_(True)
    y_leaf = y.clone().requires_grad_(True)
    out = torch.relu(x_leaf * y_leaf[:, None])
    out.backward(dz)
    return x_leaf.grad

@triton.jit
def mul_relu_block_back_kernel(x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # Backward of z = relu(x * y[:, None]) with respect to x, tiled over the
    # row-major (N1, N0) matrices x, dz and dx; y has one entry per row.
    # Chain rule: dx = relu'(x * y) * y * dz, where relu'(u) = 1 if u > 0 else 0.
    cols = tl.program_id(0) * B0 + tl.arange(0, B0)
    rows = tl.program_id(1) * B1 + tl.arange(0, B1)
    col_ok = cols < N0
    row_ok = rows < N1
    tile = rows[:, None] * N0 + cols[None, :]
    tile_ok = row_ok[:, None] & col_ok[None, :]

    x = tl.load(x_ptr + tile, mask=tile_ok)
    y = tl.load(y_ptr + rows, mask=row_ok)
    dz = tl.load(dz_ptr + tile, mask=tile_ok)

    y_col = y[:, None]
    relu_grad = tl.where(x * y_col > 0, 1.0, 0.0)
    dx = relu_grad * y_col * dz

    tl.store(dx_ptr + tile, dx, mask=tile_ok)

test(mul_relu_block_back_kernel, mul_relu_block_back_spec, nelem={"N0": 100, "N1": 90})

x: jaxtyping.Float32[Tensor, '90 100']
y: jaxtyping.Float32[Tensor, '90']
dz: jaxtyping.Float32[Tensor, '90 100']
Results match: True
* Running on public URL: https://954bbbd5046cb3fe95.gradio.live


Correct!


## Puzzle 7: Long Sum
- sum each row of a batch of long vectors, streaming each row in blocks of B1

In [None]:
def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]:
    """Reference for Puzzle 7: sum each of the 4 rows of x over its 200 columns."""
    return torch.sum(x, dim=1)

@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    # Row sums of a row-major (N0, T) matrix. Each program owns B0 rows and
    # streams across the T columns in B1-wide tiles, accumulating in float32.
    # N1 is accepted but not used.
    pid_i = tl.program_id(0)
    off_i = pid_i * B0 + tl.arange(0, B0)
    mask_i = off_i < N0

    z = tl.zeros([B0], dtype=tl.float32)

    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        off_ij = off_i[:, None] * T + off_j[None, :]
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]
        # The last tile is partial when T % B1 != 0. Masked-off lanes must read
        # 0.0 so they add nothing to the sum; without `other`, tl.load does not
        # specify their value.
        x = tl.load(x_ptr + off_ij, mask=mask_ij, other=0.0)
        z += tl.sum(x, axis=1)

    tl.store(z_ptr + off_i, z, mask=mask_i)

test(sum_kernel, sum_spec, B={"B0": 1, "B1": 32}, nelem={"N0": 4, "N1": 32, "T": 200})

## Puzzle 8: Long Softmax
- row-wise softmax over a batch of long logit vectors, streaming each row in blocks of B1


In [None]:
def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
    """Reference for Puzzle 8: softmax along dim 1, shifted by the row max for stability."""
    row_max = x.max(dim=1, keepdim=True).values
    exps = (x - row_max).exp()
    return exps / exps.sum(dim=1, keepdim=True)

@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    # Row-wise softmax of a row-major (N0, T) matrix in two passes ("online
    # softmax"):
    #   Pass 1 streams each row in B1-wide tiles, keeping a running max and a
    #   running sum of exp(x - max) that is rescaled whenever the max grows.
    #   Pass 2 re-reads the row and writes exp(x - max) / sum.
    # exp(v) is computed as exp2(v * log2(e)). N1 is accepted but not used.
    pid_0 = tl.program_id(0)
    log2_e = 1.44269504
    off_i = pid_0 * B0 + tl.arange(0, B0)
    mask_i = off_i < N0

    exp_sum = tl.zeros([B0], dtype=tl.float32)
    x_max = tl.full([B0], -float("inf"), dtype=tl.float32)

    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        off_ij = off_i[:, None] * T + off_j[None, :]
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]
        # Out-of-range lanes read -inf: they never win the max, and
        # exp2(-inf) == 0, so they add nothing to exp_sum.
        x = tl.load(x_ptr + off_ij, mask=mask_ij, other=-float("inf"))

        # exp(x - new_max) = exp(x - old_max) * exp(old_max - new_max), so the sum
        # accumulated so far is rescaled by `factor` when the running max grows.
        new_x_max = tl.maximum(x_max, tl.max(x, axis=1))
        new_exp_x = tl.exp2(log2_e * (x - new_x_max[:, None]))
        factor = tl.exp2(log2_e * (x_max - new_x_max))
        exp_sum = exp_sum * factor + tl.sum(new_exp_x, axis=1)
        x_max = new_x_max

    # Rows past N0 (only possible when N0 % B0 != 0) see only -inf and may
    # produce NaN internally. Their stores are masked off below.
    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        off_ij = off_i[:, None] * T + off_j[None, :]
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]
        x = tl.load(x_ptr + off_ij, mask=mask_ij, other=-float("inf"))
        exp_x = tl.exp2(log2_e * (x - x_max[:, None]))
        z = exp_x / exp_sum[:, None]
        tl.store(z_ptr + off_ij, z, mask=mask_ij)

@triton.jit
def softmax_kernel_brute_force(
    x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr
):
    # Three-pass row-wise softmax of a row-major (N0, T) matrix, streaming each
    # row in B1-wide tiles:
    #   Pass 1 finds the row max.
    #   Pass 2 accumulates sum(exp(x - max)).
    #   Pass 3 writes exp(x - max) / sum.
    # exp(v) is computed as exp2(v * log2(e)). N1 is accepted but not used.
    block_id_i = tl.program_id(0)
    log2_e = 1.44269504
    off_i = block_id_i * B0 + tl.arange(0, B0)
    mask_i = off_i < N0

    exp_sum = tl.zeros([B0], dtype=tl.float32)
    # Start the running max at -inf (not 0) so it becomes the true row max even
    # when every entry is negative. Shifting a row of large negative values by 0
    # would underflow every exp to 0 and make exp_sum 0, giving NaN output.
    x_max = tl.full([B0], -float("inf"), dtype=tl.float32)

    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        off_ij = off_i[:, None] * T + off_j[None, :]
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]
        # Out-of-range lanes read -inf so they never win the max.
        x = tl.load(x_ptr + off_ij, mask=mask_ij, other=-float("inf"))
        x_max = tl.maximum(x_max, tl.max(x, axis=1))

    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        off_ij = off_i[:, None] * T + off_j[None, :]
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]
        # -inf padding gives exp2(-inf) == 0, so padding adds nothing to exp_sum.
        x = tl.load(x_ptr + off_ij, mask=mask_ij, other=-float("inf"))
        exp_x = tl.exp2(log2_e * (x - x_max[:, None]))
        exp_sum += tl.sum(exp_x, axis=1)

    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        off_ij = off_i[:, None] * T + off_j[None, :]
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]
        x = tl.load(x_ptr + off_ij, mask=mask_ij, other=-float("inf"))
        exp_x = tl.exp2(log2_e * (x - x_max[:, None]))
        z = exp_x / exp_sum[:, None]
        tl.store(z_ptr + off_ij, z, mask=mask_ij)

    return

test(softmax_kernel_brute_force, softmax_spec, B={"B0": 1, "B1":32},
     nelem={"N0": 4, "N1": 32, "T": 200})