In [1]:
import logging

import helion
import helion.language as hl
import torch
from torch import Tensor
from jaxtyping import Float32, Int32

# If you set this to info you will see the output Triton Code
logging.getLogger().setLevel(logging.WARNING)

In [2]:
from triton.testing import do_bench
def test_kernel(kernel_fn, spec_fn, *args):
    """Test a Helion kernel against a reference implementation."""
    # Run our implementation
    result = kernel_fn(*args)
    # Run reference implementation
    expected = spec_fn(*args)

    # Check if results match
    torch.testing.assert_close(result, expected)
    print("✅ Results Match ✅")

def benchmark_kernel(kernel_fn, *args, **kwargs):
    """Benchmark a Helion kernel."""
    no_args = lambda: kernel_fn(*args, **kwargs)
    time_in_ms = do_bench(no_args)
    print(f"⏱ Time: {time_in_ms} ms")

def compare_implementations(kernel_fn, spec_fn, *args, **kwargs):
    """Benchmark a Helion kernel and its reference implementation."""
    kernel_no_args = lambda: kernel_fn(*args, **kwargs)
    spec_no_args = lambda: spec_fn(*args, **kwargs)
    kernel_time = do_bench(kernel_no_args)
    spec_time = do_bench(spec_no_args)
    print(f"⏱ Helion Kernel Time: {kernel_time:.3f} ms, PyTorch Reference Time: {spec_time:.3f} ms, Speedup: {spec_time/kernel_time:.3f}x")

In [None]:
@helion.kernel(config=helion.Config(block_sizes=[128,128]))
def example_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)    
    for tile_m, tile_n in hl.tile([m,n]):
        out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m,tile_n]
    return out

# Create some sample data
x = torch.randn(10, 10, device="cuda")
y = torch.randn(10, 10, device="cuda")

# Run the kernel
result = example_add(x, y)

# Verify result
expected = x + y
torch.testing.assert_close(result, expected)
print("✅ Results Match ✅")
benchmark_kernel(example_add, x, y)
compare_implementations(example_add, torch.add, x, y)

✅ Results Match ✅
⏱ Time: 0.006967028159056312 ms
⏱ Helion Kernel Time: 0.007 ms, PyTorch Reference Time: 0.006 ms, Speedup: 0.907x


In [4]:
@helion.kernel()
def example_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)    
    for tile_m, tile_n in hl.tile([m,n]):
        out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m,tile_n]
    return out

In [5]:
x = torch.randn(10, 10, device="cuda")
y = torch.randn(10, 10, device="cuda")

# Run the kernel
result = example_add(x, y)

# Verify result
expected = x + y
torch.testing.assert_close(result, expected)
print("✅ Results Match ✅")
benchmark_kernel(example_add, x, y)
compare_implementations(example_add, torch.add, x, y)

[0s] Autotune random seed: 499402173
[0s] Starting autotuning process, this may take a while...
[0s] Starting PatternSearch with initial_population=100, copies=5, max_generations=20


[25s] Initial random population of 100, 5 starting points: ok=100 min=0.0051 mid=0.0061 max=0.0072 best=Config(block_sizes=[1, 16], flatten_loops=[True], indexing='block_ptr', l2_groupings=[32], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=8, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[])
[25s] Generation 1 starting: 115 neighbors, 5 active search path(s)


[55s] Generation 1 complete: ok=120 min=0.0051 mid=0.0072 max=0.0072 best=Config(block_sizes=[2, 16], flatten_loops=[True], indexing='block_ptr', l2_groupings=[32], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=8, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[])
[55s] Generation 2 starting: 106 neighbors, 5 active search path(s)


[82s] Generation 2 complete: ok=111 min=0.0061 mid=0.0061 max=0.0072 best=Config(block_sizes=[2, 16], flatten_loops=[True], indexing='block_ptr', l2_groupings=[32], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=8, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[])
[82s] Autotuning complete in 83.0s after searching 321 configs.
One can hardcode the best config and skip autotuning with:
    @helion.kernel(config=helion.Config(block_sizes=[2, 16], flatten_loops=[True], indexing='block_ptr', l2_groupings=[32], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=8, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)



✅ Results Match ✅
⏱ Time: 0.006280533405434754 ms
⏱ Helion Kernel Time: 0.006 ms, PyTorch Reference Time: 0.006 ms, Speedup: 1.001x


## PUZZLE 1: CONSTANT ADD

In [10]:

def add_spec(x: Tensor) -> Tensor:
    """This is the spec that you should implement in the helion kernel below."""
    return x + 10.

# ---- ✨ Is this the best block size? ----
@helion.kernel(config = helion.Config(block_sizes = [1,]))
def add_kernel(x: torch.Tensor) -> torch.Tensor:
    # ---- ✨ Your Code Here ✨----
    # Set up the output buffer which you will return
    out = torch.empty_like(x)
    n = x.size()[0]
    # Use Helion to tile the computation
    for tile_n in hl.tile(n):
         out[tile_n] = x[tile_n] + 10

    return out

# Test the kernel
x = torch.randn(8192, device="cuda")
test_kernel(add_kernel, add_spec, x)
benchmark_kernel(add_kernel, x)
compare_implementations(add_kernel, add_spec, x)

✅ Results Match ✅
⏱ Time: 0.01255107654364613 ms
⏱ Helion Kernel Time: 0.012 ms, PyTorch Reference Time: 0.006 ms, Speedup: 0.507x


In [13]:

def add_spec(x: Tensor) -> Tensor:
    """This is the spec that you should implement in the helion kernel below."""
    return x + 10.

# ---- ✨ Is this the best block size? ----
@helion.kernel(config = helion.Config(block_sizes = [128,]))
def add_kernel(x: torch.Tensor) -> torch.Tensor:
    # ---- ✨ Your Code Here ✨----
    # Set up the output buffer which you will return
    out = torch.empty_like(x)
    n = x.size()[0]
    # Use Helion to tile the computation
    for tile_n in hl.tile(n):
         out[tile_n] = x[tile_n] + 10

    return out

# Test the kernel
x = torch.randn(8192, device="cuda")
test_kernel(add_kernel, add_spec, x)
benchmark_kernel(add_kernel, x)
compare_implementations(add_kernel, add_spec, x)

✅ Results Match ✅
⏱ Time: 0.01174441568081739 ms
⏱ Helion Kernel Time: 0.006 ms, PyTorch Reference Time: 0.006 ms, Speedup: 1.004x


## PUZZLE 2: OUTER VECTOR ADD

In [15]:
def broadcast_add_spec(x: Tensor, y: Tensor) -> Tensor:
    return x[None, :] + y[:, None]

# ---- ✨ Is this the best block size? ----
@helion.kernel(config = helion.Config(block_sizes = [32, 32]))
def broadcast_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Get tensor sizes
     # ---- ✨ Your Code Here ✨----
    n0 = x.size(0)
    n1 = y.size(0)
    out = x.new_empty(n1, n0)

    # Use Helion to tile the computation
    for tile_i, tile_j in hl.tile([n1, n0]):
        # Get tiles from x and y
        y_tile = y[tile_i]
        x_tile = x[tile_j]
        # Compute outer sum
        out[tile_i, tile_j] = y_tile[:, None] + x_tile[None, :]

    return out

# Test the kernel
x = torch.randn(1142, device="cuda")
y = torch.randn(512, device="cuda")
test_kernel(broadcast_add_kernel, broadcast_add_spec, x, y)
benchmark_kernel(broadcast_add_kernel, x, y)
compare_implementations(broadcast_add_kernel, broadcast_add_spec, x, y)

✅ Results Match ✅
⏱ Time: 0.007596276778106888 ms
⏱ Helion Kernel Time: 0.008 ms, PyTorch Reference Time: 0.009 ms, Speedup: 1.047x


## PUZZLE 3

In [17]:
def mul_relu_block_spec(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.relu(x[None,:]*y[:,None])

@helion.kernel(config = helion.Config(block_sizes = [32, 32]))
def mul_relu_block_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    n0 = x.size(0)
    n1 = y.size(0)
    out = x.new_empty(n1,n0)
    
    for tile_i, tile_j in hl.tile([n1, n0]):
        y_tile = y[tile_i]
        x_tile = x[tile_j]
        
        out[tile_i, tile_j] = torch.relu(x_tile[None,:] * y_tile[:,None])
    
    return out

# Test the kernel
x = torch.randn(512, device="cuda")
y = torch.randn(512, device="cuda")
test_kernel(mul_relu_block_kernel, mul_relu_block_spec, x, y)
compare_implementations(mul_relu_block_kernel, mul_relu_block_spec, x, y)

✅ Results Match ✅
⏱ Helion Kernel Time: 0.008 ms, PyTorch Reference Time: 0.011 ms, Speedup: 1.520x


In [3]:
def mul_relu_block_back_spec(x: Tensor, y: Tensor, dz: Tensor) -> Tensor:
    x = x.clone()
    y = y.clone()
    x = x.requires_grad_(True)
    y = y.requires_grad_(True)
    z = torch.relu(x * y[:, None])
    grad_x, _ = torch.autograd.grad(z, [x,y], dz, retain_graph=True)
    return grad_x


@helion.kernel(config=helion.Config(block_sizes=[32, 32]))
def mul_relu_block_back_kernel(
    x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor
) -> torch.Tensor:
    # Get tensor sizes
    n0 = x.size(1)
    n1 = x.size(0)
    # Create output tensor for gradients
    dx = torch.empty_like(x)

    # Use Helion to tile the computation
    for tile_i, tile_j in hl.tile([n1, n0]):
        # Get input tiles
        x_tile = x[tile_i, tile_j]
        y_tile = y[tile_i]
        dz_tile = dz[tile_i, tile_j]

        # Compute gradients for ReLU * multiplication backward
        # For ReLU, gradient is 1 where input > 0, 0 otherwise
        relu_mask = (x_tile * y_tile[:, None]) > 0
        # Chain rule: dx = dz * relu_grad * y
        dx[tile_i, tile_j] = dz_tile * relu_mask * y_tile[:, None]

    return dx


x = torch.randn(512, 1024, device="cuda")
y = torch.randn(512, device="cuda")
dz = torch.randn(512, 1024, device="cuda")
test_kernel(mul_relu_block_back_kernel, mul_relu_block_back_spec, x, y, dz)       
        

✅ Results Match ✅


## LONG SUM

In [None]:

def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]:
    return x.sum(1)

@helion.kernel(config=helion.Config(block_sizes=[32, 32]))
def sum_kernel(x: torch.Tensor) -> torch.Tensor:
    batch, seq_len = x.size()

    out = torch.empty(batch, dtype=x.dtype, device=x.device)
    for tile_batch in hl.tile(batch):
        acc = torch.zeros(tile_batch, dtype=torch.float32, device=x.device)

        for tile_seq in hl.tile(seq_len):
            chunk = x[tile_batch,tile_seq]
            acc += torch.sum(chunk, dim=1)
        out[tile_batch] = acc
    
    return out

# Test the kernel
x = torch.randn(4, 200, device="cuda")
test_kernel(sum_kernel, sum_spec, x)

✅ Results Match ✅


## SOFTMAX 

In [8]:
def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
    x_max = torch.max(x,axis=1, keepdim=True)[0]
    x -= x_max
    x_exp = x.exp()
    return x_exp / x_exp.sum(1,keepdim=True)

@helion.kernel(config=helion.Config(block_sizes=[32, 32]))
def softmax_kernel(x: torch.Tensor) -> torch.Tensor:
    batch, seq_len = x.size()

    out = torch.empty_like(x)
    block_batch = hl.register_block_size(batch)
    block_seq_len = hl.register_block_size(seq_len)

    for tile_batch in hl.tile(batch, block_size=block_batch):
        _max  = hl.full([tile_batch], float("-inf"), dtype=torch.float32)  
        _norm = hl.zeros([tile_batch], dtype=torch.float32)                

        for tile_seq in hl.tile(seq_len, block_size=block_seq_len):
            chunk_f32 = x[tile_batch, tile_seq].to(torch.float32)          
            local_max = torch.amax(chunk_f32, dim=1)                        # [tile_batch]
            new_max   = torch.maximum(_max, local_max)                      # [tile_batch]
            _norm = _norm * torch.exp(_max - new_max)
            _norm = _norm + torch.exp(chunk_f32 - new_max[:, None]).sum(dim=1)
            _max  = new_max

        # Pass 2: normalize
        for tile_seq in hl.tile(seq_len, block_size=block_seq_len):
            chunk_f32 = x[tile_batch, tile_seq].to(torch.float32)
            out[tile_batch, tile_seq] = (
                torch.exp(chunk_f32 - _max[:, None]) / _norm[:, None]
            ).to(x.dtype)  # cast back to input dtype

    return out

# Test the kernel
x = torch.randn(4, 200, device="cuda")
test_kernel(softmax_kernel, softmax_spec, x)


✅ Results Match ✅


# FLASH ATTENTION

In [None]:
def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    x = q[:, None] * k[None, :]
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    soft = x_exp / x_exp.sum(1, keepdim=True)
    return (v[None, :] * soft).sum(1)



@helion.kernel(config=helion.Config(block_sizes=[32, 32]))
def flashatt_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Get tensor size
    seq_len = q.size(0)
    # Create output tensor
    out = torch.empty_like(q)

    # Process each query position
    for tile_q in hl.tile(seq_len):
        q_tile = q[tile_q].to(torch.float32) 

        # Initialize tracking variables for stable softmax
        max_val = torch.full_like(q_tile, float('-inf'))
        sum_exp = torch.zeros_like(q_tile)
        weighted_sum = torch.zeros_like(q_tile)

        # Process in tiles for better cache efficiency
        for tile_kv in hl.tile(seq_len):
            k_tile = k[tile_kv].to(torch.float32)
            v_tile = v[tile_kv].to(torch.float32)

            # Compute attention scores
            scores = q_tile[:, None] * k_tile[None, :]

            # Find max for numerical stability
            batch_max = torch.amax(scores, dim=1)
            new_max = torch.maximum(max_val, batch_max)

            # Scale old accumulations
            scale_factor = torch.exp(max_val - new_max)
            # correct the previous sum (this is for normalization)
            sum_exp = sum_exp * scale_factor
            # correct the previous weighted sum (qk with v)
            weighted_sum = weighted_sum * scale_factor

            # Update with new values
            exp_scores = torch.exp(scores - new_max[:, None])
            sum_exp = sum_exp + torch.sum(exp_scores, dim=1)
            weighted_sum = weighted_sum + torch.sum(exp_scores * v_tile[None, :], dim=1)

            # Update max_val
            max_val = new_max

        # Compute final output
        out[tile_q] = weighted_sum / sum_exp

    return out


# Test the kernel
q = torch.randn(200, device="cuda")
k = torch.randn(200, device="cuda")
v = torch.randn(200, device="cuda")
test_kernel(flashatt_kernel, flashatt_spec, q, k, v)


✅ Results Match ✅


# MATRIX MULTIPLICATION

In [None]:
import torch
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

In [51]:
def dot_spec(x: Float32[Tensor, "4 32 32"], y: Float32[Tensor, "4 32 32"]) -> Float32[Tensor, "4 32 32"]:
    return (x @ y).to(torch.float32)

@helion.kernel(autotune_effort="none")
def dot_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    batch, mx, nx = x.size()
    _, my, ny = y.size()
    assert nx == my, f"Inner dimensions must match but got {nx=} and {my=}"
    out_type = torch.promote_types(x.dtype, y.dtype)
    out = torch.zeros((batch, mx,ny), dtype=out_type,device=x.device)

    for tile_batch, tile_x, tile_y in hl.tile([batch, mx, ny]):
        acc = hl.zeros([tile_batch, tile_x, tile_y],dtype=out_type)
        for tile_k in hl.tile(nx): #or my
            x_tile = x[tile_batch, tile_x, tile_k].to(out_type)
            y_tile = y[tile_batch, tile_k, tile_y].to(out_type)
            acc = torch.baddbmm(acc, x_tile, y_tile)
        out[tile_batch, tile_x,tile_y] = acc
    return out


x = torch.randn(4, 32, 32, device="cuda")
# x = torch.randn(2, 2, 2, device="cuda")
y = torch.randn(4, 32, 32, device="cuda")
# y = torch.randn(2, 2, 2, device="cuda")
test_kernel(dot_kernel, dot_spec, x.to(torch.float32), y.to(torch.float32))

Using default config: @helion.kernel(config=helion.Config(block_sizes=[4, 16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=True)


AssertionError: Tensor-likes are not close!

Mismatched elements: 4085 / 4096 (99.7%)
Greatest absolute difference: 0.021955490112304688 at index (2, 31, 7) (up to 1e-05 allowed)
Greatest relative difference: 0.48506131768226624 at index (2, 22, 21) (up to 1.3e-06 allowed)

In [40]:
dot_kernel(x,y)

tensor([[[-0.5123, -0.9325],
         [-0.7933,  0.0670]],

        [[-0.7205,  0.1165],
         [ 1.0029,  0.2240]]], device='cuda:0')

In [41]:
dot_spec(x,y)

tensor([[[-0.5121, -0.9332],
         [-0.7939,  0.0670]],

        [[-0.7214,  0.1162],
         [ 1.0042,  0.2246]]], device='cuda:0')