In [None]:
import torch
import triton
import triton.language as tl

In [None]:
# Triton kernel: element-wise addition of two vectors, output = x + y.
# This code runs on the GPU. The kernel only receives a pointer to the first
# element of each tensor in memory; it is up to us to compute the indices
# (offsets) of all the elements each program instance should access.
# NOTE(review): the indexing below treats x, y and output as flat arrays of
# n_elements consecutive elements — callers should pass contiguous tensors
# of matching size.

@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to the first input vector
    y_ptr,  # *Pointer* to the second input vector
    output_ptr,  # *Pointer* to the output vector
    n_elements,  # Total number of elements to process
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
    # NOTE: 'constexpr' so it can be used as a shape value (e.g. in tl.arange)
):
    # Several 'programs' (kernel instances) run in parallel, each on a different
    # chunk of the data. program_id tells us which program this is.
    # This is analogous to the block id (blockIdx) in CUDA.
    pid = tl.program_id(axis = 0) # We use a 1D launch grid, so axis is 0

    # This program will process the inputs that are offset from the start of the data.
    # E.g. for a vector of length 256 and BLOCK_SIZE 64, the programs would
    # access the elements [0:64], [64:128], [128:192] and [192:256] respectively.
    # *Note that offsets is a block of integer offsets, not pointers; adding it
    # to a base pointer (e.g. x_ptr + offsets) yields a block of pointers.*

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE) # indices of the elements this program handles
    # With BLOCK_SIZE = 1024:
    # for pid 0 --> offsets are 0, 1, 2, ..., 1023
    # for pid 1 --> offsets are 1024, 1025, 1026, ..., 2047
    # for pid 2 --> offsets are 2048, 2049, 2050, ..., 3071

    # Create a mask to guard memory operations against out-of-bounds accesses.
    # We need a mask because n_elements may not be a multiple of BLOCK_SIZE,
    # so the block of the last program (the one with the largest pid) can extend
    # past the end of the tensor. The mask restricts loads and stores to the
    # elements that actually exist in the tensor.
    # E.g. with 2060 elements (valid indices 0..2059), pid 2 covers offsets
    # 2048..3071, but only 2048..2059 are in bounds; the mask disables the
    # lanes for offsets 2060..3071.
    # i.e. of all the offsets, only those with value < n_elements are used.
    mask = offsets < n_elements
    
    # Load x and y from DRAM, masking out any extra elements in case the input
    # is not a multiple of BLOCK_SIZE
    x = tl.load(x_ptr + offsets, mask = mask)
    y = tl.load(y_ptr + offsets, mask = mask)
    # In CUDA each thread would compute output[i] = x[i] + y[i]; here the whole
    # block is added at once as a single vector operation.
    output = x + y # Shape: (BLOCK_SIZE,)
    
    # Write x + y back to DRAM (only the in-bounds offsets, per the mask)
    tl.store(output_ptr + offsets, output, mask = mask)

In [None]:
def add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Compute ``x + y`` element-wise with the Triton ``add_kernel``.

    Args:
        x: First input tensor; must be a CUDA tensor.
        y: Second input tensor; must have the same shape and device as ``x``.
        block_size: Number of elements each kernel program processes (passed
            to the kernel as ``BLOCK_SIZE``). Must be a positive power of two
            because the kernel uses it as a ``tl.arange`` extent. Defaults to 1024.

    Returns:
        A new tensor with the shape of ``x`` holding ``x + y``. The kernel is
        launched asynchronously, so the result may still be in flight when this
        function returns; reading it on the host (``print``, ``.cpu()``, ...)
        or calling ``torch.cuda.synchronize()`` waits for completion.

    Raises:
        AssertionError: if ``x`` or ``y`` is not a CUDA tensor.
        ValueError: if ``x`` and ``y`` differ in shape or device, or if
            ``block_size`` is not a positive power of two.
    """
    assert x.is_cuda and y.is_cuda, "add() expects CUDA tensors"
    # The kernel reads n_elements from both inputs via flat offsets, so a
    # smaller y would be read out of bounds and a larger one silently truncated.
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    if x.device != y.device:
        raise ValueError(f"x and y must be on the same device, got {x.device} and {y.device}")
    if block_size <= 0 or block_size & (block_size - 1):
        raise ValueError(f"block_size must be a positive power of two, got {block_size}")

    # The kernel indexes every tensor as a flat array starting at its first
    # element, so strided / non-contiguous views would be read incorrectly.
    # .contiguous() returns the tensor itself when it is already contiguous.
    x = x.contiguous()
    y = y.contiguous()

    # Preallocate the output; empty_like of a contiguous tensor is contiguous.
    output = torch.empty_like(x)
    n_elements = output.numel()  # total number of elements in the tensor
    if n_elements == 0:
        return output  # nothing to compute; skip launching an empty grid

    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids and can be either Tuple[int] or
    # Callable(metaparameters) -> Tuple[int]. Here we use a 1D grid whose size is
    # the number of blocks, i.e. ceil(n_elements / meta['BLOCK_SIZE']).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # NOTE:
    # - Each torch.Tensor argument is implicitly converted into a pointer to its first element.
    # - triton.jit functions can be indexed with a launch grid to obtain a callable GPU kernel.
    # - Meta-parameters must be passed as keyword arguments.
    # What each block does is defined in the kernel. Triton launches on the
    # *current* CUDA device, so make it the inputs' device (matters on multi-GPU).
    with torch.cuda.device(x.device):
        add_kernel[grid](x, y, output, n_elements=n_elements, BLOCK_SIZE=block_size)

    # No synchronization here: the kernel may still be running asynchronously.
    return output


In [None]:
# Seed PyTorch's RNG so the random inputs generated below are reproducible.
torch.manual_seed(0)
# Vector length, deliberately NOT a multiple of the 1024 block size used by
# add(): 98432 = 96 * 1024 + 128, so the last program exercises the bounds mask.
size = 98_432

In [None]:
# Validate the Triton kernel against PyTorch's native addition on random inputs.
x = torch.rand(size, device = 'cuda')
y = torch.rand(size, device = 'cuda')
output_torch = x + y
output_triton = add(x, y)
print(f"output: {output_torch}, output_triton: {output_triton}")
print(
    f"The maximum difference between torch and triton is {torch.max(torch.abs(output_torch - output_triton))}"
)
# Fail loudly on a mismatch instead of relying on a reader inspecting the
# printed difference; this makes "Restart & Run All" an actual test.
torch.testing.assert_close(output_triton, output_torch)