In [1]:
import torch
import triton
from triton import language as tl

In [2]:
@triton.jit
def softmax_kernel(x_ptr, y_ptr, x_num_rows, x_num_cols, BLOCK_SIZE: tl.constexpr):
    x_col = tl.program_id(0) # parallelize on the grid for columns
    tid = tl.arange(0, BLOCK_SIZE)
    x_offsets = tid * x_num_cols + x_col
    mask = tid < x_num_rows
    mask = mask & (x_col < x_num_cols)
    
    x = tl.load(x_ptr + x_offsets, mask=mask)
    x_max = tl.max(x)
    x = x - x_max
    x_exp = tl.exp(x)
    x_exp_sum = tl.sum(x_exp)
    y = x_exp / x_exp_sum
    tl.store(y_ptr + x_offsets, y, mask=mask)

In [3]:
def softmax_triton(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    block_size = triton.next_power_of_2(x.shape[0])
    
    grid = lambda meta: (x.shape[1],)
    softmax_kernel[grid](x, y, x.shape[0], x.shape[1], BLOCK_SIZE = block_size)
    return y

In [4]:
r1 = softmax_triton(torch.tensor([[1, 5], [3, 3], [5, 1]], dtype=torch.float32, device="cuda:0"))

In [5]:
r2 = torch.softmax(torch.tensor([[1, 5], [3, 3], [5, 1]], dtype=torch.float32, device="cuda:0"), 0)

In [6]:
r1 - r2

tensor([[-9.2190e-05, -5.0333e-03],
        [-6.8119e-04, -6.8118e-04],
        [-5.0334e-03, -9.2188e-05]], device='cuda:0')