In [1]:
from __future__ import annotations

import torch

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl

In [3]:
@helion.kernel()
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Element-wise addition of two tensors with broadcasting.

    Args:
        x: First operand
        y: Second operand

    Returns:
        Tensor holding x + y, stored in the dtype given by
        torch.promote_types(x.dtype, y.dtype)
    """
    # Bring both operands to their common broadcast shape first.
    x, y = torch.broadcast_tensors(x, y)
    result_dtype = torch.promote_types(x.dtype, y.dtype)
    out = torch.empty_like(x, dtype=result_dtype, device=x.device)
    # Tile over every dimension of the output at once.
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out

def check(m: int, n: int) -> None:
    """
    Compare the Helion add kernel with torch.add on random float16 inputs.

    Args:
        m: Number of rows in each test tensor
        n: Number of columns in each test tensor
    """
    shape = [m, n]
    tensor_options = {"device": DEVICE, "dtype": torch.float16}
    x = torch.randn(shape, **tensor_options)
    y = torch.randn(shape, **tensor_options)
    run_example(add, torch.add, (x, y))

def main() -> None:
    """
    Main entry point that runs the add kernel verification with 10240x10240 tensors.
    """
    check(10240, 10240)


if __name__ == "__main__":
    main()


Testing helion correctness...
[0s] Autotune random seed: 1806828401
[0s] Starting autotuning process, this may take a while...
[0s] Starting PatternSearch with initial_population=100, copies=5, max_generations=20


[26s] Initial random population of 100, 5 starting points: ok=100 min=0.3768 mid=1.2626 max=71.3226 best=Config(block_sizes=[256, 32], flatten_loops=[True], indexing='pointer', l2_groupings=[4], load_eviction_policies=['', 'last'], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[])
[26s] Generation 1 starting: 139 neighbors, 5 active search path(s)


[77s] Generation 1 complete: ok=144 min=0.3758 mid=0.3799 max=21.5281 best=Config(block_sizes=[256, 32], flatten_loops=[True], indexing='pointer', l2_groupings=[4], load_eviction_policies=['last', 'last'], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[])
[77s] Generation 2 starting: 94 neighbors, 4 active search path(s)


[112s] Generation 2 complete: ok=98 min=0.3758 mid=0.3789 max=5.0668 best=Config(block_sizes=[256, 32], flatten_loops=[True], indexing='pointer', l2_groupings=[4], load_eviction_policies=['last', 'last'], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[])
[112s] Generation 3 starting: 73 neighbors, 3 active search path(s)


[139s] Generation 3 complete: ok=77 min=0.3758 mid=0.3789 max=19.9250 best=Config(block_sizes=[256, 32], flatten_loops=[True], indexing='pointer', l2_groupings=[4], load_eviction_policies=['last', 'last'], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[])
[139s] Generation 4 starting: 46 neighbors, 2 active search path(s)


[156s] Generation 4 complete: ok=49 min=0.3758 mid=0.3768 max=20.1600 best=Config(block_sizes=[256, 32], flatten_loops=[True], indexing='pointer', l2_groupings=[4], load_eviction_policies=['last', 'last'], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[])
[156s] Autotuning complete in 156.5s after searching 452 configs.
One can hardcode the best config and skip autotuning with:
    @helion.kernel(config=helion.Config(block_sizes=[256, 32], flatten_loops=[True], indexing='pointer', l2_groupings=[4], load_eviction_policies=['last', 'last'], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[]), static_shapes=True)




Benchmark Results
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion               0.3768       1.01x          
torch                0.3799       1.00x (ref)    



## EXPONENTIAL FUNCTION

In [5]:

@helion.kernel(autotune_effort="none")
def exp_fwd(x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the element-wise exponential.

    Args:
        x: Input tensor

    Returns:
        Tensor shaped like x containing exp(x)
    """
    result = torch.empty_like(x)
    for tile in hl.tile(result.size()):
        block = x[tile]
        result[tile] = torch.exp(block)
    return result

@helion.kernel(autotune_effort="none")
def exp_bwd(dy: torch.Tensor, exp_x: torch.Tensor) -> torch.Tensor:
    """
    Backward pass of the element-wise exponential.

    Args:
        dy: Gradient flowing in from the output
        exp_x: Saved forward result, i.e. exp(x)

    Returns:
        Tensor shaped like exp_x containing dy * exp_x
    """
    grad_input = torch.empty_like(exp_x)
    for tile in hl.tile(exp_x.size()):
        upstream = dy[tile]
        forward_value = exp_x[tile]
        grad_input[tile] = upstream * forward_value
    return grad_input

In [6]:
class ExpFunction(torch.autograd.Function):
    """Autograd function that pairs the exp_fwd and exp_bwd kernels."""

    @staticmethod
    def forward(ctx: object, x: torch.Tensor) -> torch.Tensor:
        """Compute exp(x) with exp_fwd and keep the result for backward."""
        exp_x = exp_fwd(x)
        # The forward output, not the input, is what backward needs.
        ctx.save_for_backward(exp_x)
        return exp_x

    @staticmethod
    def backward(ctx: object, grad_output: torch.Tensor) -> torch.Tensor:
        """Return the input gradient computed by exp_bwd from the saved output."""
        (exp_x,) = ctx.saved_tensors
        return exp_bwd(grad_output, exp_x)
    
def exp(x: torch.Tensor) -> torch.Tensor:
    """
    Differentiable exponential routed through ExpFunction.

    Args:
        x: Input tensor

    Returns:
        Tensor holding the exponential of every element of x
    """
    result = ExpFunction.apply(x)
    return result  # type: ignore[no-any-return]


In [None]:
from typing import Callable
def exp_tritonbench(
    tb_op: object, x: torch.Tensor
) -> Callable[[], dict[str, torch.Tensor]]:
    """
    Adapt exp to the callable format tritonbench expects.

    Args:
        tb_op: TritonBench operator instance (not used by this wrapper)
        x: Input tensor

    Returns:
        Zero-argument callable producing a dict with the result under "output"
    """

    def run() -> dict[str, torch.Tensor]:
        # Evaluation is deferred until the benchmark invokes the callable.
        return {"output": exp(x)}

    return run

In [8]:
def check(n: int) -> None:
    """
    Compare the Helion exp implementation with torch.exp, including backward.

    Args:
        n: Number of elements in the 1D test tensor
    """
    inputs = (
        torch.randn(n, device=DEVICE, dtype=torch.float32, requires_grad=True),
    )
    run_example(exp, torch.exp, inputs, bwd=True)
    
def main() -> None:
    """
    Entry point: verify the exp kernel on 10240 * 10240 elements.
    """
    side = 10240
    check(side * side)


if __name__ == "__main__":
    main()

Testing helion correctness...
Using default config: @helion.kernel(config=helion.Config(block_sizes=[1024], indexing='pointer', load_eviction_policies=[''], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)


Using default config: @helion.kernel(config=helion.Config(block_sizes=[1024], indexing='pointer', load_eviction_policies=['', ''], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]), static_shapes=True)



Benchmark Results
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion               0.5110       1.00x          
torch                0.5110       1.00x (ref)    



In [9]:
@helion.kernel(autotune_effort="none")
def sum_kernel(x: torch.Tensor) -> torch.Tensor:
    """
    Sum a 2D tensor along its last dimension.

    Args:
        x: Input tensor whose shape unpacks into two dimensions

    Returns:
        Tensor of shape [m] with the dtype and device of x, one sum per row
    """
    m, _ = x.shape
    row_sums = torch.empty([m], dtype=x.dtype, device=x.device)

    # Only the row dimension is tiled; each tile reduces over all columns.
    for tile_m in hl.tile(m):
        rows = x[tile_m, :]
        row_sums[tile_m] = rows.sum(-1)
    return row_sums

def sum_tritonbench(tb_op: object, x: torch.Tensor) -> Callable[[], torch.Tensor]:
    """
    Adapt sum_kernel for tritonbench, accepting 1D inputs as well.

    Args:
        tb_op: TritonBench operator instance (not used by this wrapper)
        x: Input tensor (1D or 2D)

    Returns:
        Zero-argument callable returning the sum of x along its last dimension
    """

    def compute_sum() -> torch.Tensor:
        if x.ndim != 1:
            return sum_kernel(x)
        # sum_kernel unpacks a 2D shape, so present the 1D input as a single
        # row and squeeze that row dimension out of the result afterwards.
        as_row = x.unsqueeze(0)
        return sum_kernel(as_row).squeeze()

    return compute_sum

def check(m: int, n: int) -> None:
    """
    Compare sum_kernel with a PyTorch last-dimension sum on random float32 input.

    Args:
        m: Number of rows in the test tensor
        n: Number of columns in the test tensor
    """
    data = torch.randn([m, n], device=DEVICE, dtype=torch.float32)

    def reference(x: torch.Tensor) -> torch.Tensor:
        return x.sum(-1)

    run_example({"helion": sum_kernel}, reference, (data,))

def main() -> None:
    """
    Entry point: verify the sum kernel on a 5120x2560 and a 10240x10240 tensor.
    """
    for m, n in ((5120, 2560), (10240, 10240)):
        check(m, n)


if __name__ == "__main__":
    main()


Testing helion correctness...
Using default config: @helion.kernel(config=helion.Config(block_sizes=[1], indexing='pointer', load_eviction_policies=['', ''], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[], reduction_loops=[None]), static_shapes=True)



Benchmark Results
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion               0.0512       1.08x          
torch                0.0553       1.00x (ref)    

Testing helion correctness...
Using default config: @helion.kernel(config=helion.Config(block_sizes=[1], indexing='pointer', load_eviction_policies=['', ''], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[], reduction_loops=[4096]), static_shapes=True)



Benchmark Results
Implementation       Time (ms)    Speedup        
-----------------------------------------------------------------
helion               0.2714       0.99x          
torch                0.2693       1.00x (ref)    



In [10]:
# NOTE(review): block_sizes presumably pairs with the block sizes in the order
# they are created below (registered n block first, then the m tile) — confirm
# against the Helion docs before reordering any statement in this kernel.
@helion.kernel(
    config=helion.Config(
        block_sizes=[32768, 1], num_warps=16, num_stages=5, indexing="pointer"
    )
)
def longsum_manual(x: torch.Tensor) -> torch.Tensor:
    """
    Sum a 2D tensor along its last dimension with an explicit tiled reduction loop.

    Args:
        x: Input tensor; its size is unpacked into two dimensions (m, n)

    Returns:
        Tensor of shape [m] with the dtype and device of x, one sum per row
    """
    m,n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    
    # Register the reduction block size up front so the accumulator and the
    # inner tile loop share the same width.
    block_size_n = hl.register_block_size(n)
    for tile_m in hl.tile(m):
        # Partial sums per row, one column per position inside an n-tile.
        # Accumulation happens in x.dtype, not in a wider type.
        acc = hl.zeros([tile_m, block_size_n], dtype=x.dtype)
        for tile_n in hl.tile(n, block_size=block_size_n):
            acc += x[tile_m, tile_n]
        # Collapse the partial sums into a single value per row.
        out[tile_m] = acc.sum(-1)
    return out

In [11]:
@helion.kernel()
def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
    """
    Row-wise softmax of a 2D tensor computed in two passes over the columns.

    The first pass keeps a running row maximum and a running denominator
    (online softmax); the second pass writes the normalized exponentials.

    Args:
        x: Input tensor; its size is unpacked into two dimensions (m, n)

    Returns:
        Tensor shaped like x containing the softmax along dim 1
    """
    m, n = x.size()
    out = torch.empty_like(x)
    block_size_m = hl.register_block_size(m)
    block_size_n = hl.register_block_size(n)
    for tile_m in hl.tile(m, block_size=block_size_m):
        # Running maximum and running sum of exponentials, kept in float32.
        mi = hl.full([tile_m], float("-inf"), dtype=torch.float32)
        di = hl.zeros([tile_m], dtype=torch.float32)
        for tile_n in hl.tile(n, block_size=block_size_n):
            values = x[tile_m, tile_n]
            local_amax = torch.amax(values, dim=1)
            mi_next = torch.maximum(mi, local_amax)
            # The sum accumulated so far is relative to the old maximum, so it
            # must be rescaled (multiplied, not added) by exp(mi - mi_next)
            # before the current tile's contribution is added.
            di = di * torch.exp(mi - mi_next) + torch.exp(
                values - mi_next[:, None]
            ).sum(dim=1)
            mi = mi_next
        # Second pass: mi and di are now final for these rows.
        for tile_n in hl.tile(n, block_size=block_size_n):
            values = x[tile_m, tile_n]
            out[tile_m, tile_n] = torch.exp(values - mi[:, None]) / di[:, None]
    return out