In [1]:
import torch
import triton
import triton.language as tl

In [2]:
@triton.jit
def vector_add(
            x_ptr: tl.tensor,
            y_ptr: tl.tensor,
            output_ptr: tl.tensor,
            n_elements: int,
            BLOCK_SIZE: tl.constexpr
        ):
    """Element-wise addition kernel: ``output[i] = x[i] + y[i]`` for ``i < n_elements``.

    Every program instance handles one run of ``BLOCK_SIZE`` consecutive flat
    offsets, selected by its index along axis 0 of the launch grid.

    Args:
        x_ptr: Base pointer of the first operand.
        y_ptr: Base pointer of the second operand.
        output_ptr: Base pointer the sums are written to.
        n_elements: Number of elements to process; offsets at or beyond this
            value are masked out of every load and store.
        BLOCK_SIZE: Compile-time number of offsets covered by one program instance.
    """
    # The launch grid is 1D, so the program index lives on axis 0.
    program_index = tl.program_id(axis=0)

    # Flat offsets owned by this program instance.
    first_offset = program_index * BLOCK_SIZE
    element_offsets = first_offset + tl.arange(0, BLOCK_SIZE)

    # The last block may run past the end of the data; mask those lanes off.
    in_bounds = element_offsets < n_elements

    lhs = tl.load(x_ptr + element_offsets, mask=in_bounds)
    rhs = tl.load(y_ptr + element_offsets, mask=in_bounds)
    tl.store(output_ptr + element_offsets, lhs + rhs, mask=in_bounds)

In [4]:
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the element-wise sum of ``x`` and ``y`` computed by the ``vector_add`` kernel.

    Args:
        x: First operand. The result takes its shape, dtype and device from ``x``.
        y: Second operand. Must have the same shape as ``x`` and be on the same device.

    Returns:
        A newly allocated, contiguous tensor shaped like ``x`` holding the sums.

    Raises:
        ValueError: If ``x`` and ``y`` differ in shape (no broadcasting is performed).
        AssertionError: If ``x`` and ``y`` are on different devices.
    """
    # The kernel only masks against x's element count, so a smaller `y` would be
    # read out of bounds. Reject mismatched shapes up front.
    if x.shape != y.shape:
        raise ValueError(
            f"add() requires tensors of equal shape, got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    assert x.device == y.device, "x and y must be on the same device"

    # The kernel addresses every buffer with the same flat offsets from its base
    # pointer, which is only correct for contiguous memory. `.contiguous()` is a
    # no-op for tensors that already are contiguous.
    x = x.contiguous()
    y = y.contiguous()
    output = torch.empty_like(x)

    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; skip launching a zero-sized grid.
        return output

    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # Here it is 1D, with one program instance per BLOCK_SIZE-wide block (rounded up).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    vector_add[grid](x, y, output, n_elements, BLOCK_SIZE = 1024)
    return output

In [10]:
# Smoke test: add two small float32 vectors that live on the first CUDA device.
result = add(
    torch.tensor([1.0, 2.0, 3.0], device="cuda:0", dtype=torch.float32),
    torch.tensor([4.0, 5.0, 6.0], device="cuda:0", dtype=torch.float32),
)

In [11]:
# The output buffer is allocated like the first input, so the dtype stays float32.
result.dtype

torch.float32

In [12]:
# Expected element-wise sums: [1 + 4, 2 + 5, 3 + 6].
result

tensor([5., 7., 9.], device='cuda:0')

In [13]:
# Move the result to the CPU and convert it to a NumPy array.
result.to("cpu").numpy()

array([5., 7., 9.], dtype=float32)