In [29]:
import torch
import triton
import triton.language as tl
import time

# Vector Addition 

## Introduction

The goal of this exercise is to understand how to build a Triton kernel that adds two vectors of floating-point numbers. We will compare its compute time with the vanilla PyTorch addition operation.

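A Triton kernel is a function that is executed on a GPU. It is launched as many parallel *program instances*. Each instance adds a contiguous block of element pairs from the two vectors and writes the results to the matching positions of the output vector. Triton maps each instance onto a group of GPU threads itself, so the kernel is written per block of elements rather than per thread.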

The kernel is launched over a *grid* of program instances, in the same spirit as a CUDA grid of thread blocks. Each instance is identified by its program id and processes `BLOCK_SIZE` elements. The instances run concurrently on the GPU, in no guaranteed order.

As an example, consider vectors of length 256 and a block size of 64. This gives 256 / 64 = 4 program instances, and each instance of `add_kernel` reads and computes one of the ranges $[0:64]$, $[64:128]$, $[128:192]$, and $[192:256]$ of the input vectors. Each instance then writes its result to the same range of the output vector.
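
The block decomposition can be sketched in plain Python. This sketch runs on the CPU and only mirrors the index arithmetic. `block_ranges` is a hypothetical helper, not part of Triton. The second call uses a length that is not a multiple of the block size, so the last instance has fewer valid elements. That is the case the mask in the kernel handles.

```python
from typing import List, Tuple


def block_ranges(n_elements: int, block_size: int) -> List[Tuple[int, int]]:
    """[start, end) range of valid elements handled by each program id (list index = pid)."""
    return [
        (start, min(start + block_size, n_elements))  # the last block may be partial
        for start in range(0, n_elements, block_size)
    ]


print(block_ranges(256, 64))  # [(0, 64), (64, 128), (128, 192), (192, 256)]
print(block_ranges(250, 64))  # [(0, 64), (64, 128), (128, 192), (192, 250)]
```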

In [30]:
@triton.jit  # Marks the function as a Triton kernel, JIT-compiled for the GPU on first launch.
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector (it needs to be preallocated).
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process (`constexpr` so it can be used as a shape value).
):
    # add_kernel runs as multiple program instances, so we read the program id to know which block of the vector to process.
    # For instance: program 0 has pid = 0, program 1 has pid = 1, ...
    pid = tl.program_id(axis=0) 
    
    # offsets holds integer element indices, not pointers; x_ptr + offsets is the block of pointers.
    # E.g. with BLOCK_SIZE = 64 and pid = 1: block_start = 64, offsets = [64, 65, ..., 127].
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements

    # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y

    # Write x + y back to DRAM.
    # Each program writes to its own disjoint slice of the single preallocated output buffer, so no concatenation step is needed.
    tl.store(output_ptr + offsets, output, mask=mask) 
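
To make the indexing and masking concrete, here is a CPU-only emulation of what the instances of `add_kernel` do together. `emulate_add_kernel` is a hypothetical illustration that uses Python lists. It does not model Triton's execution: a real launch runs the instances concurrently and operates on whole blocks at once. Masked-out lanes of `tl.load` are undefined unless an `other=` value is given, and this sketch fills them with 0.0.

```python
def emulate_add_kernel(x, y, output, n_elements, BLOCK_SIZE):
    """CPU emulation of add_kernel's indexing, masking, and stores (not Triton code)."""
    num_programs = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # same as triton.cdiv
    # Real instances run concurrently in no guaranteed order; iterating in reverse
    # gives the same result because every instance writes a disjoint slice.
    for pid in reversed(range(num_programs)):
        block_start = pid * BLOCK_SIZE
        offsets = [block_start + i for i in range(BLOCK_SIZE)]  # like tl.arange(0, BLOCK_SIZE)
        mask = [off < n_elements for off in offsets]
        # Masked loads: out-of-bounds lanes are never read; 0.0 plays the role of `other=`.
        xs = [x[off] if m else 0.0 for off, m in zip(offsets, mask)]
        ys = [y[off] if m else 0.0 for off, m in zip(offsets, mask)]
        block_out = [a + b for a, b in zip(xs, ys)]
        # Masked store: only in-bounds lanes are written into the shared output buffer.
        for off, m, v in zip(offsets, mask, block_out):
            if m:
                output[off] = v
    return output


x_list = [float(i) for i in range(10)]
y_list = [0.5] * 10
out = emulate_add_kernel(x_list, y_list, [0.0] * 10, n_elements=10, BLOCK_SIZE=4)
print(out)  # [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]
```

With 10 elements and `BLOCK_SIZE=4`, the third instance covers offsets 8 to 11, but the mask stops it from touching offsets 10 and 11.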

We now write the `add` wrapper. It preallocates the output and launches one instance of `add_kernel` per block of the input, as described by the launch grid. Together, the instances add the two vectors and write the result to the output vector.

In [31]:
def add(x: torch.Tensor, y: torch.Tensor):

    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()

    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # Here it is a callable: given the meta-parameters, it returns a one-element tuple holding the number of
    # program instances needed to cover n_elements in blocks of BLOCK_SIZE elements (ceiling division via triton.cdiv).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # Each torch.Tensor argument is implicitly converted into a pointer to its first element.
    # A triton.jit function indexed with a launch grid (a tuple or a callable) gives a callable GPU kernel.
    # Meta-parameters such as BLOCK_SIZE must be passed as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to `output`, but since `torch.cuda.synchronize()` hasn't been called, the kernel may still be running asynchronously at this point.
    return output
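
The grid callable in `add` is resolved by Triton at launch time. The CPU-only sketch here mimics that step for the 100,000-element vectors used below. A local `cdiv` stands in for `triton.cdiv`, and a plain dict stands in for the meta-parameters Triton passes to the callable (an assumption: only `BLOCK_SIZE` is included). The sketch also shows why ceiling division is needed. The last block is only partially filled, and the mask in `add_kernel` covers that remainder.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive ints, the result triton.cdiv gives."""
    return (a + b - 1) // b


def make_grid(n_elements: int):
    # Mirrors `grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)` in `add`.
    return lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)


n = 100_000
grid = make_grid(n)
for block_size in (256, 1024, 4096):
    # Stand-in for the meta-parameters Triton passes to the callable (assumed: BLOCK_SIZE only).
    (num_programs,) = grid({"BLOCK_SIZE": block_size})
    # Valid elements in the last, partially filled block; floor division would drop them.
    tail = n - (num_programs - 1) * block_size
    print(f"BLOCK_SIZE={block_size}: {num_programs} programs, last block has {tail} valid elements")
# BLOCK_SIZE=256: 391 programs, last block has 160 valid elements
# BLOCK_SIZE=1024: 98 programs, last block has 672 valid elements
# BLOCK_SIZE=4096: 25 programs, last block has 1696 valid elements
```

Because the grid is a callable, changing `BLOCK_SIZE` at the launch site automatically changes the number of instances.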

We will now compare the compute time for vector addition using Triton and PyTorch.

In [32]:
x = torch.randn(100000, device='cuda')
y = torch.randn(100000, device='cuda')

# Warm-up call: the first launch JIT-compiles the kernel, which should not be included in the timing.
add(x, y)
torch.cuda.synchronize()

# Time the Triton kernel
start = time.time()
output = add(x, y)
torch.cuda.synchronize() # Wait for the GPU operations to finish
end = time.time()

print(f"Triton: {(end - start)*1e3:.6f} ms")    

start = time.time()
output = x + y
torch.cuda.synchronize() # Wait for the GPU operations to finish
end = time.time()

print(f"Torch: {(end - start)*1e3:.6f} ms")

Triton: 1.060724 ms
Torch: 1.079082 ms


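The Triton kernel runs in roughly the same time as the PyTorch implementation (about 1.06 ms vs. 1.08 ms here). That is expected. Vector addition is a simple, memory-bound operation that both implementations handle efficiently, so a custom kernel has little room to win. At 100,000 elements, fixed launch and timing overhead likely dominates the measurement. Both numbers also come from a single call timed with `time.time()`, so the small difference between them should not be read as meaningful.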
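
A single timed call is noisy. The sketch below makes the measurement somewhat more robust: it warms up, synchronizes after every call, and reports the median over several repetitions. `bench_ms` is a hypothetical helper. Triton also ships `triton.testing.do_bench`, which times with CUDA events and is likely the better choice on a GPU.

```python
import statistics
import time
from typing import Callable


def bench_ms(fn: Callable[[], object], sync: Callable[[], None] = lambda: None,
             warmup: int = 3, reps: int = 50) -> float:
    """Return the median wall-clock time of fn() in milliseconds.

    `sync` must block until all queued work has finished. For GPU code pass
    torch.cuda.synchronize, otherwise only the asynchronous launch is timed.
    """
    for _ in range(warmup):  # e.g. absorbs JIT compilation on the first call
        fn()
    sync()
    times = []
    for _ in range(reps):
        start = time.perf_counter()
        fn()
        sync()
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times)


# On a CUDA machine with `add`, `x`, and `y` defined as above:
#   bench_ms(lambda: add(x, y), sync=torch.cuda.synchronize)
#   bench_ms(lambda: x + y, sync=torch.cuda.synchronize)
# CPU-only smoke test:
print(f"{bench_ms(lambda: sum(range(10_000))):.4f} ms")
```

Taking the median rather than the mean keeps occasional outliers, such as a one-off allocation or OS scheduling hiccup, from skewing the result.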