In [1]:
pip install triton torch



In [1]:
import torch
import triton
import triton.language as tl
import time

### Basic Softmax Kernel Implementation

In [25]:
@triton.jit

def softmax_kernel(X_ptr, Y_ptr, N, stride_x, stride_y, BLOCK_SIZE: tl.constexpr):
    row_id = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    row_offsets = row_id * stride_x + offsets

    mask = offsets < N
    x = tl.load(X_ptr + row_offsets, mask=mask)
    x_max = tl.max(x, axis=0)
    x = x - x_max
    num = tl.exp(x)
    denom = tl.sum(num, axis=0)
    softmax = num / denom
    tl.store(Y_ptr + row_id * stride_y + offsets, softmax, mask=mask)

In [26]:
B, N = 1024, 4096
x = torch.randn((B, N), device='cuda')
y_triton = torch.empty_like(x)
y_torch = torch.empty_like(x)

In [27]:
def run_triton():
    softmax_kernel[(B,)](
        x, y_triton, N,
        x.stride(0), y_triton.stride(0),
        BLOCK_SIZE=N
    )

In [28]:
def run_torch():
    y_torch.copy_(torch.nn.functional.softmax(x, dim=1))

In [29]:
def time_kernel(fn, *args, repeats=10):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []

    for _ in range(repeats):
        torch.cuda.synchronize()
        start.record()
        fn(*args)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds

    return sum(times) / len(times)

In [31]:
triton_ms = time_kernel(run_triton)
torch_ms = time_kernel(run_torch)

# Accuracy check
max_diff = torch.max(torch.abs(y_triton - y_torch)).item()

# Print results
print(f"Triton softmax time: {triton_ms:.4f} ms")
print(f"PyTorch softmax time: {torch_ms:.4f} ms")
print(f"Max difference: {max_diff:.2e}")

Triton softmax time: 0.2310 ms
PyTorch softmax time: 0.4021 ms
Max difference: 3.73e-09


### Vectorized Softmax in Triton

A softmax kernel where each thread loads 4 floats

In [61]:
@triton.jit

def softmax_kernel_vectorized(
    X, Y, stride_xm, stride_ym, n_cols,
    BLOCK_SIZE: tl.constexpr, VEC: tl.constexpr
):
    row = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE * VEC)
    mask = col_offsets < n_cols

    x_ptrs = X + row * stride_xm + col_offsets
    y_ptrs = Y + row * stride_ym + col_offsets

    x = tl.load(x_ptrs, mask=mask, other=-float("inf"))
    row_max = tl.max(x, axis=0)
    x = x - row_max
    exp_x = tl.exp(x)
    row_sum = tl.sum(exp_x, axis=0)
    result = exp_x / row_sum
    tl.store(y_ptrs, result, mask=mask)

In [62]:
BLOCK_SIZE = 128
VEC = 4  # each thread processes 4 values

def run_triton_vectorized():
  softmax_kernel_vectorized[(B,)](
      x, y_triton, x.stride(0), y_triton.stride(0), N,
      BLOCK_SIZE=BLOCK_SIZE,
      VEC=VEC
  )

In [64]:
B, N = 1024, 4096
x = torch.randn((B, N), device='cuda')
y_triton_naive = torch.empty_like(x)
y_triton_vec = torch.empty_like(x)
y_torch = torch.empty_like(x)

# Run benchmarks
time_naive = time_kernel(run_triton)
time_vec = time_kernel(run_triton_vectorized)
time_torch = time_kernel(run_torch)

# Accuracy check
diff_naive = torch.max(torch.abs(y_triton_naive - y_torch)).item()
diff_vec = torch.max(torch.abs(y_triton_vec - y_torch)).item()

print(f"Max diff (naive): {diff_naive:.2e}")
print(f"Max diff (vectorized): {diff_vec:.2e}")
print("--------------------------------------")

print(f"Triton naive softmax time: {time_naive:.4f} ms")
print(f"Triton vectorized softmax time: {time_vec:.4f} ms")
print(f"PyTorch softmax time: {time_torch:.2e}")

Max diff (naive): 2.48e-02
Max diff (vectorized): 5.01e+00
--------------------------------------
Triton naive softmax time: 0.2953 ms
Triton vectorized softmax time: 0.0655 ms
PyTorch softmax time: 4.00e-01


#### This kernel implements a tiled and vectorized softmax using Triton. Since the input rows can be very large, the row is split into multiple tiles, each processed in chunks of BLOCK_SIZE × VEC. The kernel runs in three passes: first to compute the global row-wise max (for numerical stability), second to compute the sum of exponentials, and third to normalize and store the final softmax values. The use of vectorized memory access improves performance by loading multiple values at once, while tiling ensures the kernel scales to wide input sizes efficiently.

In [53]:
@triton.jit
def softmax_kernel_tiled(
    X, Y, stride_xm, stride_ym, n_cols,
    BLOCK_SIZE: tl.constexpr,
    VEC: tl.constexpr
):
    row = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE * VEC)  # vectorized offsets
    row_max = -float("inf")
    row_sum = 0.0

    # === Pass 1: Row-wise max
    for i in range(0, n_cols, BLOCK_SIZE * VEC):
        cols = col_offsets + i
        mask = cols < n_cols
        x_ptrs = X + row * stride_xm + cols
        x = tl.load(x_ptrs, mask=mask, other=-float("inf"))
        row_max = tl.maximum(row_max, tl.max(x, axis=0))

    # === Pass 2: Sum of exp(x - max)
    for i in range(0, n_cols, BLOCK_SIZE * VEC):
        cols = col_offsets + i
        mask = cols < n_cols
        x_ptrs = X + row * stride_xm + cols
        x = tl.load(x_ptrs, mask=mask, other=-float("inf"))
        exp_x = tl.exp(x - row_max)
        row_sum += tl.sum(exp_x, axis=0)

    # === Pass 3: Normalize and store
    for i in range(0, n_cols, BLOCK_SIZE * VEC):
        cols = col_offsets + i
        mask = cols < n_cols
        x_ptrs = X + row * stride_xm + cols
        y_ptrs = Y + row * stride_ym + cols
        x = tl.load(x_ptrs, mask=mask, other=-float("inf"))
        exp_x = tl.exp(x - row_max)
        softmax = exp_x / row_sum
        tl.store(y_ptrs, softmax, mask=mask)


In [54]:
BLOCK_SIZE = 128
VEC = 4

def run_triton_tiled():
    softmax_kernel_tiled[(B,)](
      x, y_triton, x.stride(0), y_triton.stride(0), N,
      BLOCK_SIZE=BLOCK_SIZE,
      VEC=VEC
  )

In [59]:
B, N = 1024, 4096
x = torch.randn((B, N), device='cuda')
y_triton_naive = torch.empty_like(x)
y_triton_vec = torch.empty_like(x)
y_triton_tiled = torch.empty_like(x)
y_torch = torch.empty_like(x)

# Run benchmarks
time_naive = time_kernel(run_triton)
time_vec = time_kernel(run_triton_vectorized)
time_tiled = time_kernel(run_triton_tiled)
time_torch = time_kernel(run_torch)

# Accuracy check
diff_naive = torch.max(torch.abs(y_triton_naive - y_torch)).item()
diff_vec = torch.max(torch.abs(y_triton_vec - y_torch)).item()
diff_tiled = torch.max(torch.abs(y_triton_tiled - y_torch)).item()

print(f"Max diff (naive): {diff_naive:.2e}")
print(f"Max diff (vectorized): {diff_vec:.2e}")
print(f"Max diff (vectorized + tiled): {diff_tiled:.2e}")
print("--------------------------------------")

print(f"Triton softmax time: {time_naive:.4f} ms")
print(f"Triton vectorized softmax time: {time_vec:.4f} ms")
print(f"Triton vectorized + tiled softmax time: {time_tiled:.4f} ms")
print(f"PyTorch softmax time: {time_torch:.4f} ms")

Max diff (naive): 2.19e-02
Max diff (vectorized): 5.24e+00
Max diff (vectorized + tiled): 2.15e-02
--------------------------------------
Triton softmax time: 0.4511 ms
Triton vectorized softmax time: 0.1155 ms
Triton vectorized + tiled softmax time: 0.3728 ms
PyTorch softmax time: 0.4187 ms
