In [None]:
!nvidia-smi
!pip -q install triton

In [None]:
import torch
import triton
import triton.language as tl

In [None]:
# Environment sanity check, one value per line: CUDA availability,
# the CUDA version PyTorch was built against, and the PyTorch version.
for info in (torch.cuda.is_available(), torch.version.cuda, torch.__version__):
    print(info)

# Vector Addition

In [None]:
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr,
               N,
               BLOCK_SIZE: tl.constexpr):
    # Element-wise vector addition: out[i] = x[i] + y[i] for every i < N.
    #
    # Each program instance owns one contiguous chunk of BLOCK_SIZE elements,
    # so the launcher starts ceil(N / BLOCK_SIZE) instances along axis 0.
    #   x_ptr, y_ptr, out_ptr : base pointers of the (flattened) operands
    #   N                     : number of valid elements
    #   BLOCK_SIZE            : compile-time chunk width (tl.arange requires a power of two)
    chunk_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = chunk_start + tl.arange(0, BLOCK_SIZE)     # element indices owned by this instance
    in_bounds = idx < N                              # guards the last, partially filled chunk

    lhs = tl.load(x_ptr + idx, mask=in_bounds)       # global memory -> registers
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)  # masked lanes are never written



def add_triton(x: torch.Tensor,                                                 # Python wrapper that defines the launch grid and calls the kernel
               y: torch.Tensor,
               block_size: int = 1024) -> torch.Tensor:
    """Element-wise ``x + y`` computed by ``add_kernel``.

    Args:
        x, y: CUDA tensors of identical shape, dtype and device. Any shape and
            any memory layout is accepted; inputs are made contiguous and
            flattened before launch, and the result is reshaped back.
        block_size: elements handled per Triton program instance. Must be a
            positive power of two (``tl.arange`` requirement).

    Returns:
        A new contiguous tensor with the same shape, dtype and device as ``x``.

    Raises:
        AssertionError: if the inputs are not CUDA tensors, disagree in
            shape/dtype/device, or ``block_size`` is not a power of two.
    """
    assert x.is_cuda and y.is_cuda, "Use CUDA tensors"
    assert x.shape == y.shape and x.dtype == y.dtype
    assert x.device == y.device, "x and y must be on the same device"
    assert block_size > 0 and (block_size & (block_size - 1)) == 0, \
        "block_size must be a positive power of two"

    # The kernel indexes memory linearly, so every operand must be contiguous.
    # The output must NOT come from torch.empty_like(x): for a permuted or
    # transposed x, empty_like keeps x's strides. out.reshape(-1) would then be
    # a *copy*, the kernel would write into that temporary, and the returned
    # tensor would stay uninitialised.
    x_flat = x.contiguous().view(-1)
    y_flat = y.contiguous().view(-1)
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    out_flat = out.view(-1)          # a true view (never a copy): out is contiguous

    N = x_flat.numel()
    if N == 0:
        return out                   # nothing to compute, skip the launch

    # One program instance per BLOCK_SIZE chunk of elements. Triton calls the
    # grid callable with `meta`, a dict of the launch's compile-time constants.
    # Tensors passed as kernel arguments are converted to their device pointers.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x_flat, y_flat, out_flat,
                     N,
                     BLOCK_SIZE=block_size)
    return out

In [None]:
torch.manual_seed(0)                       # reproducible inputs across Restart & Run All
x = torch.randn(1_000_000, device="cuda", dtype=torch.float32)
y = torch.randn_like(x)

# 1_000_000 is not a multiple of 1024, so the last program instance also
# exercises the masked tail of add_kernel.
out = add_triton(x, y, block_size=1024)

# Element-wise fp32 addition should agree with PyTorch; fail loudly otherwise
# instead of only printing False.
matches_torch = torch.allclose(out, x + y)
print(matches_torch)
assert matches_torch, "add_triton disagrees with x + y"

import time
def bench(fn, *args, warmup=10, iters=50):
    """Return the mean wall-clock seconds per call of ``fn(*args)``.

    ``fn`` first runs ``warmup`` untimed times, e.g. to absorb Triton's JIT
    compilation on the first launch. It then runs ``iters`` back-to-back timed
    calls. The device is synchronised before the timer starts and once after
    the last call, so all queued GPU work is counted without adding a host
    round-trip to every iteration. ``time.perf_counter`` is used because it is
    a monotonic, high-resolution clock meant for interval timing.
    """
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()                 # start timing from an idle device
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()                 # wait for every timed launch to finish
    return (time.perf_counter() - t0) / iters

# Time both implementations on the same inputs; bench reports seconds per call.
for label, impl in (("Triton ms:", add_triton), ("PyTorch ms:", lambda a,b: a+b)):
    print(label, 1e3 * bench(impl, x, y))

# Matrix Multiplication

In [None]:
@triton.jit
def matmul_kernel(A_ptr, B_ptr, C_ptr,                                          # Pointers to global memory
                  M, N, K,                                                      # Shapes: A → (M, K) B → (K, N) C → (M, N)
                  stride_am, stride_ak,                                         # address = base + i * stride_row + j * stride_col
                  stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr,                                        # Compile-time tile sizes along M, N, K
                  BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    # Tiled matrix multiply C = A @ B.
    #
    # Launch on a 2-D grid (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)). Program
    # (pid_m, pid_n) computes the BLOCK_M x BLOCK_N output tile whose top-left
    # element is C[pid_m * BLOCK_M, pid_n * BLOCK_N]. It walks the shared K
    # dimension in BLOCK_K-wide slices: C_ij = sum_k A_ik * B_kj.
    # Partial sums are accumulated in float32.
    #
    # M, N and K need NOT be multiples of the block sizes. Loads that run past
    # a matrix edge are zero-padded (zeros add nothing to the dot product), and
    # the final store is masked, so edge tiles never touch memory outside A, B or C.
    pid_m = tl.program_id(0)                                                    # row-tile index of C
    pid_n = tl.program_id(1)                                                    # column-tile index of C

    # A's window starts at (pid_m * BLOCK_M, 0) and later slides right along K.
    # B's window starts at (0, pid_n * BLOCK_N) and later slides down along K.
    # order=(1, 0) marks dim 1 (columns) as fastest-varying, i.e. row-major data.
    A_block_ptr = tl.make_block_ptr(base=A_ptr, shape=(M, K),
                                    strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_M, 0),
                                    block_shape=(BLOCK_M, BLOCK_K),
                                    order=(1, 0))
    B_block_ptr = tl.make_block_ptr(base=B_ptr, shape=(K, N),
                                    strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_N),
                                    block_shape=(BLOCK_K, BLOCK_N),
                                    order=(1, 0))

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)                        # running sum for this output tile

    for k in range(0, K, BLOCK_K):
        # boundary_check on both dims + zero padding: rows/cols beyond M, N or K
        # read as 0 instead of reading out of bounds.
        a = tl.load(A_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (BLOCK_M, BLOCK_K)
        b = tl.load(B_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (BLOCK_K, BLOCK_N)
        acc += tl.dot(a, b)
        A_block_ptr = tl.advance(A_block_ptr, (0, BLOCK_K))                     # next K-slice: move right in A
        B_block_ptr = tl.advance(B_block_ptr, (BLOCK_K, 0))                     # next K-slice: move down in B

    # Global addresses of this C tile. The mask drops rows >= M and
    # columns >= N of edge tiles.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)

In [None]:
def matmul_triton(A: torch.Tensor, B: torch.Tensor,
                  BLOCK_M=128, BLOCK_N=128, BLOCK_K=32) -> torch.Tensor:
    """Compute ``A @ B`` with ``matmul_kernel``.

    Args:
        A: CUDA tensor of shape (M, K).
        B: CUDA tensor of shape (K, N), same dtype and device as ``A``.
        BLOCK_M, BLOCK_N, BLOCK_K: output-tile and K-slice sizes, passed to
            the kernel as compile-time constants.

    Returns:
        A new (M, N) tensor on ``A.device``. It is always float32, whatever
        the input dtype, because the wrapper allocates C as float32 for the
        kernel's float32 accumulator.

    Raises:
        AssertionError: on non-CUDA inputs, mismatched dtypes or devices,
            non-2-D inputs, or mismatched inner dimensions.
    """
    assert A.is_cuda and B.is_cuda, "Use CUDA tensors"
    assert A.dtype == B.dtype, "Dtypes must match"
    assert A.device == B.device, "A and B must be on the same device"
    # Check the rank before indexing shapes, so a 1-D input gets a readable
    # assertion instead of an IndexError.
    assert A.dim() == 2 and B.dim() == 2, "matmul_triton expects 2-D matrices"
    assert A.shape[1] == B.shape[0], "Inner dims must match (A: MxK, B: KxN)"

    # Contiguous copies give the kernel plain row-major strides
    # (a no-op when the inputs are already contiguous).
    A_ = A.contiguous()
    B_ = B.contiguous()

    M, K = A_.shape
    N = B_.shape[1]

    C = torch.empty((M, N), device=A.device, dtype=torch.float32)

    stride_am, stride_ak = A_.stride()
    stride_bk, stride_bn = B_.stride()
    stride_cm, stride_cn = C.stride()

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))                   # Triton launch grid: (#tiles along M, #tiles along N)

    matmul_kernel[grid](                                                        # Triton launch syntax: kernel[grid](...)
        A_, B_, C,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K                       # tl.constexpr params
    )
    return C

In [None]:
import time

def bench(fn, *args, warmup=5, iters=30, sync_cuda=False):
    """Return the mean wall-clock seconds per call of ``fn(*args)``.

    This redefines, and from here on shadows, the ``bench`` from the vector
    addition section. It has fewer default iterations, and synchronisation is
    optional so the same helper can also time CPU-only work.

    ``fn`` runs ``warmup`` untimed times, then ``iters`` timed back-to-back
    calls. With ``sync_cuda=True``, the device is synchronised before the timer
    starts and once after the last call, so queued GPU work is counted without
    adding a host round-trip to every iteration.
    """
    for _ in range(warmup):
        fn(*args)
    if sync_cuda:
        torch.cuda.synchronize()             # start timing from an idle device
    start = time.perf_counter()              # monotonic, high-resolution clock
    for _ in range(iters):
        fn(*args)
    if sync_cuda:
        torch.cuda.synchronize()             # wait for every timed launch to finish
    return (time.perf_counter() - start) / iters

# sizes
M, K, N = 1024, 1536, 768

torch.manual_seed(0)                       # reproducible inputs across Restart & Run All
A_cpu = torch.randn(M, K, dtype=torch.float32)
B_cpu = torch.randn(K, N, dtype=torch.float32)

A_gpu = A_cpu.to("cuda", dtype=torch.float16)
B_gpu = B_cpu.to("cuda", dtype=torch.float16)

# Build the reference from the *same* fp16-rounded values the kernel reads,
# accumulated in fp32. Comparing against A_cpu @ B_cpu would also count the
# fp16 input-rounding error. Summed over K=1536 terms, that error can exceed
# atol=1e-2 and make the check fail even for a correct kernel.
C_ref = A_gpu.float().cpu() @ B_gpu.float().cpu()
C_tri = matmul_triton(A_gpu, B_gpu)
print("allclose (tri vs ref):", torch.allclose(C_tri.cpu(), C_ref, atol=1e-2, rtol=0))

t_tri = bench(matmul_triton, A_gpu, B_gpu, sync_cuda=True)
print(f"Triton GPU: {1e3*t_tri:.2f} ms")

# NOTE: this baseline upcasts to fp32 first, so it also times two cast kernels
# and an fp32 GEMM rather than an fp16 one.
def torch_cuda_mm(A,B): return (A.float() @ B.float())
t_torch = bench(torch_cuda_mm, A_gpu, B_gpu, sync_cuda=True)
print(f"Torch CUDA: {1e3*t_torch:.2f} ms")

# Like-for-like baseline: fp16 inputs straight into torch's matmul
# (its output is fp16, whereas matmul_triton returns fp32).
t_torch16 = bench(lambda a, b: a @ b, A_gpu, B_gpu, sync_cuda=True)
print(f"Torch CUDA fp16: {1e3*t_torch16:.2f} ms")

# CPU baseline (single-threaded vs multithreaded may vary)
t_cpu = bench(lambda a,b: a @ b, A_cpu, B_cpu, sync_cuda=False)
print(f"Torch CPU : {1e3*t_cpu:.2f} ms")