In [1]:
import torch 
import triton 
from torch import Tensor 
import triton.language as tl
import jaxtyping 
from jaxtyping import Float32, Int32

In [2]:
import triton_viz
import inspect
from triton_viz.interpreter import record_builder

def test(puzzle, puzzle_spec, nelem={}, B={"B0": 32}, viz=True):
    """Run a Triton puzzle kernel under triton_viz and check it against its spec.

    Random inputs are built from the annotations of ``puzzle_spec``, the kernel
    is traced on a grid derived from ``nelem`` and ``B``, and the kernel's
    output buffer is compared with ``puzzle_spec`` applied to the same inputs.

    Args:
        puzzle: Kernel to trace; it is passed to ``triton_viz.trace``.
        puzzle_spec: Reference function. Each parameter annotation and the
            return annotation must expose ``.dims`` (items with ``.size``);
            parameter annotations must also expose ``.dtypes``.
        nelem: Size constants forwarded to the kernel as keyword arguments.
            ``"N0"`` is required by the grid; ``"N1"``/``"N2"`` are optional.
        B: Block-size constants forwarded to the kernel as keyword arguments.
            ``"B1"``/``"B2"`` default to 32 when the matching ``"N1"``/``"N2"``
            is present in ``nelem``.
        viz: When true, call ``triton_viz.launch()`` after the run and treat a
            truthy return value as a failure.

    Returns:
        An ``IPython.display.HTML`` video element on success, otherwise
        ``None`` (after printing diagnostics).
    """
    # One pair of tolerances for both the overall verdict and the per-element
    # diagnostic, so the printed mask always agrees with "Results match".
    rtol = atol = 1e-3

    # Copy before filling in defaults so neither the caller's dict nor the
    # shared default argument is ever mutated. `nelem` is only read.
    B = dict(B)
    for axis in ("1", "2"):
        if "N" + axis in nelem and "B" + axis not in B:
            B["B" + axis] = 32

    triton_viz.interpreter.record_builder.reset()
    torch.manual_seed(0)
    signature = inspect.signature(puzzle_spec)

    # Map "<param>_ptr" -> (shape, parameter); the output buffer goes last.
    args = {}
    for name, param in signature.parameters.items():
        args[name + "_ptr"] = ([d.size for d in param.annotation.dims], param)
    args["z_ptr"] = ([d.size for d in signature.return_annotation.dims], None)

    tt_args = []
    for shape, param in args.values():
        # The float draw is made unconditionally (and then replaced for int32
        # parameters) so the seeded random stream stays the same as before.
        tt_args.append(torch.rand(*shape) - 0.5)
        if param is not None and param.annotation.dtypes[0] == "int32":
            # NOTE(review): torch.randint defaults to int64 although the
            # annotation says int32 - confirm this is intended.
            tt_args[-1] = torch.randint(-100000, 100000, shape)

    def grid(meta):
        # One program per block along each axis; missing axes collapse to 1.
        return (triton.cdiv(nelem["N0"], meta["B0"]),
                triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)),
                triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)))

    triton_viz.trace(puzzle)[grid](*tt_args, **B, **nelem)

    # The last tensor is the output buffer; the rest are the spec's inputs.
    z = tt_args[-1]
    tt_args = tt_args[:-1]
    z_ = puzzle_spec(*tt_args)
    match = torch.allclose(z, z_, rtol=rtol, atol=atol)
    print("Results match:",  match)

    # Printed below as "Invalid Access"; presumably a record of bad memory
    # accesses collected by triton_viz - TODO confirm.
    failures = triton_viz.launch() if viz else False
    if not match or failures:
        print("Invalid Access:", failures)
        print("Yours:", z)
        print("Spec:", z_)
        print(torch.isclose(z, z_, rtol=rtol, atol=atol))
        return

    # PUPPIES!
    # Imported lazily: only the success path needs them.
    from IPython.display import HTML
    import random
    print("Correct!")
    pups = [
        "2m78jPG",
        "pn1e9TO",
        "MQCIwzT",
        "udLK6FS",
        "ZNem5o3",
        "DS2IZ6K",
        "aydRUz8",
        "MVUdQYK",
        "kLvno0p",
        "wScLiVz",
        "Z0TII8i",
        "F1SChho",
        "9hRi2jN",
        "lvzRF3W",
        "fqHxOGI",
        "1xeUYme",
        "6tVqKyM",
        "CCxZ6Wr",
        "lMW0OPQ",
        "wHVpHVG",
        "Wj2PGRl",
        "HlaTE8H",
        "k5jALH0",
        "3V37Hqr",
        "Eq2uMTA",
        "Vy9JShx",
        "g9I2ZmK",
        "Nu4RH7f",
        "sWp0Dqd",
        "bRKfspn",
        "qawCMl5",
        "2F6j2B4",
        "fiJxCVA",
        "pCAIlxD",
        "zJx2skh",
        "2Gdl1u7",
        "aJJAY4c",
        "ros6RLC",
        "DKLBJh7",
        "eyxH0Wc",
        "rJEkEw4"]
    return HTML("""
    <video alt="test" controls autoplay=1>
        <source src="https://openpuppies.com/mp4/%s.mp4"  type="video/mp4">
    </video>
    """%(random.choice(pups)))

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
@triton.jit
def demo(x_ptr):
    """Print 8 consecutive offsets, then the masked values loaded at them."""
    offsets = tl.arange(0, 8)
    print(offsets)
    # Pointer + offsets is not arithmetic on values: it builds the set of
    # memory addresses that the load will read from.
    addresses = x_ptr + offsets
    # Only the first 5 lanes are read; the remaining lanes receive 0.
    keep = offsets < 5
    x = tl.load(addresses, keep, 0)
    print(x)

# Trace `demo` with a single program instance (grid (1, 1, 1)) over a 4x3
# tensor of ones.
triton_viz.trace(demo)[(1, 1, 1)](torch.ones(4, 3))
# Start the triton_viz viewer; its return value is the cell's output.
triton_viz.launch()

[0 1 2 3 4 5 6 7]
[1. 1. 1. 1. 1. 0. 0. 0.]


* Running on public URL: https://2a2e76ba0dc21b70f3.gradio.live


{}

In [4]:
@triton.jit
def demo_2d(x_ptr):
    """Print an 8x8 grid of offsets, then the masked values loaded at them."""
    # Column vector of row indices and row vector of column indices; they
    # broadcast against each other into an 8x8 grid.
    rows = tl.arange(0, 8)[:, None]
    cols = tl.arange(0, 8)[None, :]
    # Flat offset of element (row, col) using a row stride of 4.
    offsets = rows * 4 + cols
    print(offsets)
    addresses = x_ptr + offsets
    # Read only the top-left 4x3 corner; every other lane receives 0.
    keep = (rows < 4) & (cols < 3)
    x = tl.load(addresses, keep, 0)
    print(x)

# Trace `demo_2d` with a single program instance over a 4x4 tensor of ones;
# its 4 columns match the row stride of 4 used inside the kernel.
triton_viz.trace(demo_2d)[(1, 1, 1)](torch.ones(4, 4))
# Start the triton_viz viewer; its return value is the cell's output.
triton_viz.launch()

[[ 0  1  2  3  4  5  6  7]
 [ 4  5  6  7  8  9 10 11]
 [ 8  9 10 11 12 13 14 15]
 [12 13 14 15 16 17 18 19]
 [16 17 18 19 20 21 22 23]
 [20 21 22 23 24 25 26 27]
 [24 25 26 27 28 29 30 31]
 [28 29 30 31 32 33 34 35]]
[[1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]]


* Running on public URL: https://8f7507f160b0b5cce2.gradio.live


{}

In [6]:
# Launching several program instances: each one loads its own block of 8 elements.
# The grid below has 3 programs along axis 0.

@triton.jit
def demo_blocks(x_ptr):
    """Load and print the block of 8 elements owned by this program instance.

    Program ``pid`` covers flat offsets ``pid * 8`` to ``pid * 8 + 7``;
    offsets at or beyond 20 are masked out and filled with 0.
    """
    pid = tl.program_id(0)
    offsets = tl.arange(0, 8) + pid * 8
    # Pass an explicit fill value of 0, as the other demos do: without
    # `other`, the contents of masked-out lanes are not defined by tl.load.
    x = tl.load(x_ptr + offsets, offsets < 20, 0)
    print(f'for each {pid}, {x}')

# 2 * 4 * 4 = 32 ones; the kernel addresses them as a flat buffer.
x = torch.ones(2, 4, 4)
# Three program instances along axis 0, i.e. 3 blocks of 8 offsets each.
triton_viz.trace(demo_blocks)[(3, 1, 1)](x)
# Start the triton_viz viewer; its return value is the cell's output.
triton_viz.launch()

for each [0], [1. 1. 1. 1. 1. 1. 1. 1.]
for each [1], [1. 1. 1. 1. 1. 1. 1. 1.]
for each [2], [1. 1. 1. 1. 0. 0. 0. 0.]


* Running on public URL: https://14de7241a88d80285f.gradio.live


{}