# Vector addition with Triton

Source: [Triton docs vector addition](https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py)

Let's put together a basic kernel with Triton (following an example in the documentation linked above).

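We can use Triton to launch a 1-d grid of blocks (Triton calls each one a "program instance") and use the blocks to add 2 vectors. We write a Python function that launches the kernel (`add`) and the kernel itself (`add_kernel`). Each instance of the kernel processes a single block of `BLOCK_SIZE = 1024` elements. Triton decides how those elements are spread across the GPU threads that run the block, so we never index individual threads ourselves.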
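
To make the partitioning concrete, here is a plain-Python sketch (no GPU, no Triton) of how a vector of `n_elements` is split into 1-d blocks of `BLOCK_SIZE` elements, one block per program instance. The helper `block_ranges` is a hypothetical name used only for illustration.

```python
def block_ranges(n_elements: int, block_size: int) -> list[tuple[int, int]]:
    """Return the half-open [start, stop) index range each program instance covers."""
    ranges = []
    num_blocks = (n_elements + block_size - 1) // block_size  # ceil(n_elements / block_size)
    for pid in range(num_blocks):
        start = pid * block_size
        stop = min(start + block_size, n_elements)  # the last block may be partial
        ranges.append((start, stop))
    return ranges

print(block_ranges(2500, 1024))  # [(0, 1024), (1024, 2048), (2048, 2500)]
```

Only the last block is ever partial, which is why the kernel below needs a mask.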

First, the Python wrapper:

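1. We allocate the output vector first, since the kernel writes its results into it through a pointer.
2. `grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)` creates a 1-d launch grid. `meta` holds the kernel's meta-parameters, so `meta["BLOCK_SIZE"]` is `1024` here and the number of blocks is `ceil(n_elements / 1024)`.
3. We then launch `add_kernel` by indexing it with the grid (`add_kernel[grid](...)`). Triton passes each tensor argument to the kernel as a pointer to its first element, and `add_kernel` does the main work.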
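
`triton.cdiv` is ceiling division. The sketch below re-implements it in plain Python (`cdiv` here is our own stand-in, not an import from Triton) to show what the grid lambda returns when it is called with a meta-parameter dict containing `BLOCK_SIZE=1024`. The vector length `98432` is an arbitrary example value.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for non-negative integers, the same result triton.cdiv computes
    return (a + b - 1) // b

n_elements = 98432  # example length, not a multiple of 1024
# The grid callable receives the launch's meta-parameters as a dict
grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)

print(grid({"BLOCK_SIZE": 1024}))  # (97,)
```

Making `grid` a function of `meta` rather than a fixed tuple means the block count stays correct if `BLOCK_SIZE` is changed at the launch site.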

```python
import torch

import triton
import triton.language as tl

def add(x, y):
    output = torch.empty_like(x)
    n_elements = output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output
```

Next, the kernel. The `@triton.jit` decorator compiles it for the GPU the first time it is launched. Remember that each instance of the kernel operates on a single block:

1. The kernel receives pointers to the first element of each vector, not the data itself, so we use the pointers to load the relevant data.
2. Use `tl.constexpr` to pass the block size into the kernel as a compile-time constant. We need it to load the right amount of data, and `tl.arange` only accepts compile-time bounds.
3. Get the block id using `tl.program_id`. You could also call `tl.num_programs` to see how many blocks were launched.
4. Compute the start index of the data this block will process (`block_start`).
5. Use [`tl.arange`](https://triton-lang.org/main/python-api/generated/triton.language.arange.html#triton.language.arange) plus `block_start` to build the offsets of all `1024` elements per vector that this block will process.
6. Build a mask to avoid accessing elements that don't exist. We launched `ceil(n_elements / 1024)` blocks, so the last block may cover fewer than `1024` valid elements.
7. Load the x and y elements we need.
8. Add them element-wise and assign the result to `output`. The addition is written once for the whole block, and Triton spreads it across the block's threads.
9. Store the result through the output pointer.
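
The nine steps above can be mirrored in plain Python to see what one program instance does. This is a sequential CPU emulation under our own assumptions, not how Triton executes: `masked_load`, `masked_store`, `emulate_add_kernel` and `emulate_add` are hypothetical helpers, Python lists stand in for device memory, and integer offsets stand in for pointers. A real program instance processes its whole block at once, and a real masked `tl.load` without `other=` leaves masked lanes undefined, whereas this sketch fills them with `0.0`.

```python
def masked_load(buf, offsets, mask, other=0.0):
    # Stand-in for tl.load with a mask: masked-off lanes read `other`, never buf
    return [buf[o] if m else other for o, m in zip(offsets, mask)]

def masked_store(buf, offsets, values, mask):
    # Stand-in for tl.store with a mask: masked-off lanes write nothing
    for o, v, m in zip(offsets, values, mask):
        if m:
            buf[o] = v

def emulate_add_kernel(pid, x, y, output, n_elements, BLOCK_SIZE):
    block_start = pid * BLOCK_SIZE                          # step 4: first index of this block
    offsets = [block_start + i for i in range(BLOCK_SIZE)]  # step 5: like block_start + tl.arange
    mask = [o < n_elements for o in offsets]                # step 6: guard the partial last block
    xs = masked_load(x, offsets, mask)                      # step 7
    ys = masked_load(y, offsets, mask)
    out = [a + b for a, b in zip(xs, ys)]                   # step 8: element-wise add
    masked_store(output, offsets, out, mask)                # step 9

def emulate_add(x, y, BLOCK_SIZE=4):
    n_elements = len(x)
    output = [0.0] * n_elements
    num_programs = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
    for pid in range(num_programs):  # a GPU would run these instances in parallel
        emulate_add_kernel(pid, x, y, output, n_elements, BLOCK_SIZE)
    return output

print(emulate_add([1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 20.0, 30.0, 40.0, 50.0], BLOCK_SIZE=4))
# [11.0, 22.0, 33.0, 44.0, 55.0]
```

With `BLOCK_SIZE=4` and 5 elements, the second instance has offsets `[4, 5, 6, 7]` and mask `[True, False, False, False]`, so it touches only index 4.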

```python
@triton.jit
def add_kernel(
    x_ptr,       # pointer to the first element of x
    y_ptr,       # pointer to the first element of y
    output_ptr,  # pointer to the first element of output
    n_elements,  # total number of elements in each vector
    BLOCK_SIZE: tl.constexpr,  # elements per block (compile-time constant)
):
    pid = tl.program_id(axis=0)  # which block (program instance) this is
    block_start = pid * BLOCK_SIZE  # the first index of the data we'll need for the block

    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # all indices the block will process
    mask = offsets < n_elements  # skip indices past n_elements (block size * block count can exceed n_elements)

    x = tl.load(x_ptr + offsets, mask=mask)  # load the x data
    y = tl.load(y_ptr + offsets, mask=mask)  # load the y data
    output = x + y  # element-wise add over the whole block
    tl.store(output_ptr + offsets, output, mask=mask)  # write results; the mask prevents out-of-bounds stores
```

We can test the kernel by comparing its output with a plain PyTorch sum of the same vectors:

```python
torch.manual_seed(0)
size = 1000
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(
    f'The maximum difference between torch and triton is '
    f'{torch.max(torch.abs(output_torch - output_triton))}'
)
print(output_torch[:10])
print(output_triton[:10])
```

```
The maximum difference between torch and triton is 0.0
tensor([1.3713, 1.3076, 0.4940, 1.2701, 1.2803, 1.1750, 1.1790, 1.4607, 0.3393,
        1.2689], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, 1.2701, 1.2803, 1.1750, 1.1790, 1.4607, 0.3393,
        1.2689], device='cuda:0')
```
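
Because `size = 1000` is smaller than `BLOCK_SIZE = 1024`, this test launches a single program instance and masks off 24 of its lanes, so it never exercises more than one block. The plain-Python arithmetic below (the helper `launch_stats` is a hypothetical name) computes the block count and the number of masked lanes for a few sizes; the other sizes are arbitrary examples, and trying a multi-block size such as 98432 in the test above would also cover the case where several blocks run.

```python
def launch_stats(n_elements: int, block_size: int = 1024) -> tuple[int, int]:
    """Return (number of program instances, lanes masked off in the last block)."""
    num_programs = (n_elements + block_size - 1) // block_size
    masked_lanes = num_programs * block_size - n_elements
    return num_programs, masked_lanes

for n in (1000, 1024, 98432):
    print(n, launch_stats(n))
# 1000 (1, 24)
# 1024 (1, 0)
# 98432 (97, 896)
```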
