In [None]:
import time
import torch
import triton
import triton.language as tl

In [None]:
def softmax_naive_torch(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax assembled from individual PyTorch ops.

    The maximum along ``dim=1`` is subtracted before exponentiating, so the
    largest exponent in every row is ``exp(0)``; the exponentials are then
    divided by their sum along ``dim=1``.

    Args:
        x: Tensor normalised along dimension 1. The shape comments below
            follow the original annotations and assume a 2-D ``(N, M)`` input.

    Returns:
        Tensor with the same shape as ``x``.
    """
    row_peak, _ = x.max(dim=1, keepdim=True)             # (N, 1)
    numer = torch.exp(x - row_peak)                      # (N, M)
    row_total = numer.sum(dim=1, keepdim=True)           # (N, 1)
    return numer / row_total                             # (N, M)

In [None]:
def softmax_pytorch(x: torch.Tensor) -> torch.Tensor:
    """Softmax along dimension 1 using PyTorch's built-in operator."""
    normalised = x.softmax(dim=1)
    return normalised

In [None]:
@triton.jit
def _softmax(Y, stride_y_row, X, stride_x_row, M, N, BLOCK_SIZE : tl.constexpr):
    """Softmax of one row of X per program instance, written to the same row of Y.

    Y, X:          base pointers of the output / input buffers.
    stride_y_row,
    stride_x_row:  element offset between the starts of consecutive rows.
    M:             supplied by the launcher but not read by the kernel.
    N:             number of valid columns in a row.
    BLOCK_SIZE:    compile-time tile width. The whole row is processed as a
                   single tile, so it has to be >= N.

    Columns are addressed as ``row_ptr + col``, i.e. neighbouring columns are
    assumed to be neighbouring elements in memory.
    """
    # program idx gives row idx; widened to int64 so that row_idx * stride
    # cannot wrap around on very large tensors
    row_idx = tl.program_id(0).to(tl.int64)

    input_row_ptr = X + (row_idx * stride_x_row)
    col_indices = tl.arange(0, BLOCK_SIZE)
    input_ptrs = input_row_ptr + col_indices

    mask = col_indices < N      # lanes past the end of the row are inactive

    # Inactive lanes read -inf: they can never win the max, and exp(-inf) == 0
    # so they contribute nothing to the denominator.
    row = tl.load(input_ptrs, mask=mask, other=float("-inf"))

    # Do the arithmetic in float32 even when the data is stored in a narrower
    # type: summing thousands of half-precision terms loses accuracy, and some
    # Triton releases only accept fp32/fp64 operands for tl.exp.  float64 data
    # is left alone so it keeps its full precision.  (The dtype comparison is
    # resolved at compile time, so only one branch is emitted.)
    if row.dtype != tl.float64:
        row = row.to(tl.float32)

    row = row - tl.max(row, axis=0)
    row = tl.exp(row)
    denom = tl.sum(row, axis=0)
    row = row / denom

    output_row_ptr = Y + (row_idx * stride_y_row)
    output_ptrs = output_row_ptr + col_indices

    # Cast back to the output buffer's element type before writing.
    tl.store(output_ptrs, row.to(Y.dtype.element_ty), mask=mask)

def softmax_triton(X: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax of a 2-D tensor computed by the ``_softmax`` Triton kernel.

    One kernel program is launched per row, and each program covers its whole
    row with a single tile of ``triton.next_power_of_2(cols)`` lanes.

    Args:
        X: 2-D tensor of shape ``(rows, cols)``. It is handed directly to a
            Triton kernel, so it is presumably expected to live on a device
            Triton can target (the notebook uses CUDA).

    Returns:
        A newly allocated, row-major tensor with the shape and dtype of ``X``
        holding the softmax of every row.

    Raises:
        ValueError: if ``X`` is not 2-D.
    """
    if X.dim() != 2:
        raise ValueError(f"softmax_triton expects a 2-D tensor, got {X.dim()}-D")
    rows, cols = X.shape

    # The kernel walks along a row with unit stride, so the columns must be
    # adjacent in memory. Row padding is fine: the row stride is passed in.
    if X.stride(1) != 1:
        X = X.contiguous()

    # Output buffer. Force a row-major layout instead of inheriting whatever
    # strides X happens to have.
    Y = torch.empty_like(X, memory_format=torch.contiguous_format)

    if X.numel() == 0:
        # Nothing to normalise; this also avoids a zero-width kernel tile.
        return Y

    BLOCK_SIZE = triton.next_power_of_2(cols)

    # Wider rows get more warps per program. The default has to be bound on
    # every path: narrow rows (BLOCK_SIZE < 2048) take neither branch below.
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16

    # SPMD launch grid: one program per row
    grid = (rows, )
    # enqueue GPU kernel
    _softmax[grid](Y, Y.stride(0),
                X, X.stride(0),
                rows, cols, BLOCK_SIZE,
                num_warps=num_warps)

    return Y

In [None]:
# -----------------------------
# Timing helpers
# -----------------------------
@torch.no_grad()
def time_ms(fn, x, iters=100, warmup=25):
    """Average time of one ``fn(x)`` call in milliseconds, measured with CUDA events.

    Args:
        fn: Callable invoked as ``fn(x)``; its return value is discarded.
        x: The single argument passed to ``fn``.
        iters: Number of timed calls.
        warmup: Number of untimed calls made first.

    Returns:
        Elapsed time between the two recorded events divided by ``iters``.
    """
    def _repeat(times):
        for _ in range(times):
            fn(x)

    # Untimed runs first, then wait for the device so none of that work
    # spills into the measured interval.
    _repeat(warmup)
    torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    tic.record()
    _repeat(iters)
    toc.record()

    # The end event must have completed before its timestamp can be read.
    torch.cuda.synchronize()
    total_ms = tic.elapsed_time(toc)
    return total_ms / iters

In [None]:
torch.manual_seed(0)
device = "cuda"

# Choose shapes that make IO matter:
# M large enough to keep SMs busy, N moderately large
M = 81920
N = 4096  # try 1024, 2048, 4096, 8192

x = torch.randn((M, N), device=device, dtype=torch.float16)

# Softmax outputs lie in [0, 1], so an absolute bound of about two float16
# machine epsilons (eps ~ 9.8e-4) tolerates rounding differences while still
# catching a wrong result. A NaN error also fails the comparison.
MAX_ABS_ERR = 2e-3

# Correctness check: torch.softmax is the reference for both implementations
# (it is also timed below as the fused PyTorch baseline).
ref = softmax_pytorch(x)
naive = softmax_naive_torch(x)
max_err = (naive - ref).abs().max().item()
print(f"max abs error naive vs torch.softmax: {max_err:.3e}")
assert max_err < MAX_ABS_ERR, "naive softmax disagrees with torch.softmax"

y_triton = softmax_triton(x)
max_err = (y_triton - ref).abs().max().item()
print(f"max abs error triton vs torch.softmax: {max_err:.3e}")
assert max_err < MAX_ABS_ERR, "triton softmax disagrees with torch.softmax"

# Benchmarks (only reached when both implementations match the reference)
t_naive  = time_ms(softmax_naive_torch, x)
t_pytorch = time_ms(softmax_pytorch, x)
t_triton = time_ms(softmax_triton, x)

print(f"naive torch composition: {t_naive:.3f} ms")
print(f"fused pytorch softmax:    {t_pytorch:.3f} ms")
print(f"fused triton softmax:    {t_triton:.3f} ms")
print(f"speedup pytorch native:                {t_naive / t_pytorch:.2f}x")
print(f"speedup triton:                {t_naive / t_triton:.2f}x")


max abs error naive vs torch.softmax: 3.052e-05
max abs error triton vs torch.softmax: 3.815e-06
naive torch composition: 4.587 ms
fused pytorch softmax:    1.823 ms
fused triton softmax:    0.773 ms
speedup pytorch native:                2.52x
speedup triton:                5.93x
