In [None]:
!nvidia-smi
!pip -q install triton

In [None]:
import torch
import triton
import triton.language as tl

In [None]:
# Environment check, one value per line: whether CUDA is usable, the CUDA
# version reported by torch.version.cuda, and the installed PyTorch version.
print(torch.cuda.is_available(), torch.version.cuda, torch.__version__, sep="\n")

# Vector Addition

In [None]:
@triton.jit                                                                     # Decorator that turns a Python function into a GPU kernel
def add_kernel(x_ptr, y_ptr, out_ptr,                                           # Pointers to global memory
               N,                                                               # Total number of elements to add
               BLOCK_SIZE: tl.constexpr):                                       # tl.constexpr marks a parameter as compile-time constant
    """Compute out[i] = x[i] + y[i] for every i in [0, N).

    Each program instance handles one contiguous chunk of BLOCK_SIZE elements,
    so launch with at least cdiv(N, BLOCK_SIZE) programs on axis 0.
    x_ptr, y_ptr, out_ptr point to flat buffers holding at least N elements.
    """
    pid = tl.program_id(axis=0)                                                 # pid ∈ {0, 1, …, G−1} where G = ceil(N/BLOCK_SIZE)
    # program_id is int32; multiplying it by BLOCK_SIZE in int32 wraps around
    # once N exceeds 2**31 - 1 elements, so promote to int64 first.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offs = block_start + tl.arange(0, BLOCK_SIZE)                               # [block_start + 0, ..., block_start + BLOCK_SIZE-1]
    mask = offs < N                                                             # Only the last block can contain out-of-range lanes
    x = tl.load(x_ptr + offs, mask=mask)                                        # Masked-off lanes are not read
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)                                  # Masked-off lanes are not written



def add_triton(x: torch.Tensor,                                                 # Python wrapper that defines the launch grid and calls the kernel
               y: torch.Tensor,
               block_size: int = 1024) -> torch.Tensor:
    """Elementwise ``x + y`` computed by ``add_kernel``.

    Args:
        x, y: CUDA tensors of identical shape, dtype and device (any shape;
            they are processed as flat 1D buffers).
        block_size: Elements handled per program instance; passed to the
            kernel as the compile-time constant ``BLOCK_SIZE``. ``tl.arange``
            requires it to be a power of two.

    Returns:
        A new contiguous tensor with the shape and dtype of ``x``.

    Raises:
        AssertionError: if the inputs are not CUDA tensors, or differ in
            shape, dtype or device.
    """
    assert x.is_cuda and y.is_cuda, "Use CUDA tensors"
    assert x.shape == y.shape and x.dtype == y.dtype
    assert x.device == y.device, "x and y must be on the same device"

    # The kernel indexes memory linearly, so give it dense row-major buffers.
    # .contiguous() is free for already-contiguous tensors and copies otherwise.
    x_flat = x.contiguous().view(-1)
    y_flat = y.contiguous().view(-1)
    # Allocate the output contiguously as well: torch.empty_like(x) would copy
    # x's strides (e.g. a transpose), and reshape(-1) on such a tensor returns
    # a *copy*, so the kernel would write into a temporary and the returned
    # tensor would stay uninitialised.
    out = torch.empty_like(x, memory_format=torch.contiguous_format)
    N = out.numel()
    if N == 0:
        return out                                                              # Nothing to compute; skip launching an empty grid

    # One program instance per BLOCK_SIZE chunk of elements. `meta` is the dict
    # of compile-time constants Triton passes to the grid callable.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    # Tensors passed as kernel arguments are turned into device pointers by
    # Triton. Make the inputs' device current so the launch targets it.
    with torch.cuda.device(x.device):
        add_kernel[grid](x_flat, y_flat, out.view(-1),
                         N,
                         BLOCK_SIZE=block_size)
    return out

In [None]:
torch.manual_seed(0)                                                            # Fixed seed so Restart & Run All reproduces the same inputs
N_ELEMENTS = 1_000_000                                                          # Not a multiple of 1024, so the kernel's masked tail block is exercised
x = torch.randn(N_ELEMENTS, device="cuda", dtype=torch.float32)
y = torch.randn_like(x)

out = add_triton(x, y, block_size=1024)

# Fail loudly with a mismatch report instead of printing False and carrying on.
torch.testing.assert_close(out, x + y)
print(torch.allclose(out, x + y))

import time
def bench(fn, *args, warmup=10, iters=50):
    """Return the average wall-clock seconds per call of ``fn(*args)``.

    ``fn`` is first called ``warmup`` times untimed (e.g. to trigger JIT
    compilation), then ``iters`` timed calls are made back to back. Queued
    CUDA work is synchronized before the timer starts and after the last
    call, so asynchronous kernel launches are fully counted without adding a
    host round-trip to every iteration.

    Raises:
        ValueError: if ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    def _sync():
        # Wait for queued GPU work; skipped when CUDA is unavailable so that
        # CPU-only callables can be benchmarked as well.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(warmup):
        fn(*args)
    _sync()
    start = time.perf_counter()                                                 # Monotonic, high-resolution timer
    for _ in range(iters):
        fn(*args)
    _sync()
    return (time.perf_counter() - start) / iters

# Average time per call in milliseconds (bench returns seconds): the Triton
# wrapper versus PyTorch's built-in elementwise add on the same inputs.
print("Triton ms:", bench(add_triton, x, y) * 1e3)
print("PyTorch ms:", bench(lambda lhs, rhs: lhs + rhs, x, y) * 1e3)