In [3]:
import torch
import triton
import triton.language as tl

In [7]:
# Device string handed to `torch.rand(..., device=DEVICE)` when the demo
# tensors are allocated further down in the notebook.
DEVICE = "cuda"

In [11]:
@triton.jit
def _add_(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    block_size: tl.constexpr,
):
    # Element-wise addition kernel: each program instance sums one
    # `block_size`-wide slice of the two inputs and writes it to the output.

    # Linear indices covered by this program along launch-grid axis 0.
    idx = tl.program_id(0) * block_size + tl.arange(0, block_size)

    # The final slice may extend past the data; disable those lanes so they
    # neither load nor store.
    in_bounds = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)

In [14]:
def add(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    block_size: int = 1024,
):
    """Return the element-wise sum ``x + y`` computed by the ``_add_`` Triton kernel.

    Args:
        x: First operand; must be a CUDA tensor.
        y: Second operand; must be a CUDA tensor with exactly the same shape
            as ``x`` (no broadcasting is performed).
        block_size: Number of elements each kernel program processes. Must be
            a positive power of two because the kernel builds its index range
            with ``tl.arange(0, block_size)``.

    Returns:
        A new contiguous tensor with the shape and dtype of ``x``.

    Raises:
        ValueError: If ``block_size`` is not a positive power of two, or if
            ``x`` and ``y`` have different shapes.
        AssertionError: If either operand is not on a CUDA device.
    """
    if block_size <= 0 or block_size & (block_size - 1):
        raise ValueError(
            f"block_size must be a positive power of two, got {block_size}"
        )
    # The kernel is masked only by the output's element count, so a shorter
    # `y` would be read out of bounds and a longer one silently truncated.
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    assert x.is_cuda and y.is_cuda, "add() expects both inputs on a CUDA device"

    # The kernel addresses memory as `ptr + linear_index`, which only matches
    # logical element order for contiguous tensors. `.contiguous()` is a no-op
    # when the operand already is contiguous.
    x = x.contiguous()
    y = y.contiguous()
    output = torch.empty_like(x, memory_format=torch.contiguous_format)

    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; skip the kernel launch entirely.
        return output

    # One program per `block_size` chunk, rounding up to cover the tail.
    grid = lambda meta: (
        triton.cdiv(n_elements, meta['block_size']),
    )

    _add_[grid](
        x,
        y,
        output,
        n_elements,
        block_size=block_size,
    )

    return output

In [15]:
# Sanity check: the Triton result must agree with PyTorch's native addition.
torch.manual_seed(0)  # fixed seed so the random operands are reproducible
size = 8192
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)

output_torch = x + y
output_triton = add(x, y)
assert torch.allclose(output_torch, output_triton)

print(output_torch)
print(output_triton)

# Largest element-wise deviation between the two implementations.
max_abs_diff = (output_torch - output_triton).abs().max()
print(f'The maximum difference between torch and triton is {max_abs_diff}')


tensor([1.3713, 1.3076, 0.4940,  ..., 0.2920, 1.5087, 0.9388], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.2920, 1.5087, 0.9388], device='cuda:0')
The maximum difference between torch and triton is 0.0
