In [1]:
import torch
import triton
from triton import language as tl
import numpy as np
import numpy.typing as npt

First, let's implement softmax in NumPy. Note that this version normalizes along axis 0, so each column of the input sums to 1.

In [2]:
def softmax_np(x: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
    """Compute a numerically stable softmax along axis 0.

    For a 2-D input this normalizes every column independently, so each
    column of the result sums to 1.

    Args:
        x: Input array (array-likes are accepted and converted).

    Returns:
        An array with the same shape as ``x`` holding the softmax values.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalize; np.max would raise on an empty reduction axis.
        return np.exp(x)

    # softmax(x) == softmax(x - c) for any per-column constant c. Choosing
    # c = max(x) keeps every exponent <= 0, so np.exp can neither overflow to
    # inf (inf / inf -> nan) nor underflow the whole column to 0 (0 / 0 -> nan).
    x_shifted = x - np.max(x, axis=0, keepdims=True)
    x_exp = np.exp(x_shifted)
    return x_exp / np.sum(x_exp, axis=0, keepdims=True)

In [3]:
# Sanity check on a small 3x2 input; normalization runs down each column.
softmax_np(
    np.array(
        [
            [1.0, 3.0],
            [2.0, 2.0],
            [3.0, 1.0],
        ]
    )
)

array([[0.09003057, 0.66524096],
       [0.24472847, 0.24472847],
       [0.66524096, 0.09003057]])

Next, let's implement the same column-wise softmax in Triton, launching one program per column.

In [4]:
@triton.jit
def softmax_kernel(x_ptr, output_ptr, x_num_rows: tl.constexpr, x_num_columns: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Column-wise softmax: each program normalizes one column of the input.

    The addressing below (``row * x_num_columns + column``) assumes that both
    ``x_ptr`` and ``output_ptr`` point to row-major contiguous buffers of
    ``x_num_rows * x_num_columns`` elements.

    Args:
        x_ptr: Pointer to the input matrix.
        output_ptr: Pointer to the output matrix (same layout as the input).
        x_num_rows: Number of rows in the matrix.
        x_num_columns: Number of columns in the matrix; also the row stride.
        BLOCK_SIZE: Number of rows handled by one program. It must be at least
            ``x_num_rows``; rows at index >= BLOCK_SIZE are never loaded or
            stored. ``tl.arange`` requires it to be a power of two.
    """
    # One program per column: the program id is the column index, and the
    # whole column is processed by this single program.
    col_idx = tl.program_id(0)
    row_idx = tl.arange(0, BLOCK_SIZE)
    offsets = row_idx * x_num_columns + col_idx
    # Lanes past the last row are padding and must not touch memory.
    mask = row_idx < x_num_rows

    # Padding lanes read -inf so they are ignored by the max and add
    # exp(-inf) = 0 to the sum.
    x = tl.load(x_ptr + offsets, mask=mask, other=-float('inf'))
    # Subtract the column max before exponentiating so that large inputs do
    # not overflow to inf (which would turn the division into nan).
    x_shifted = x - tl.max(x, axis=0)
    x_exp = tl.exp(x_shifted)
    denom = tl.sum(x_exp, axis=0)
    y = x_exp / denom

    tl.store(output_ptr + offsets, y, mask=mask)
    

In [5]:
def softmax_triton(x: torch.Tensor) -> torch.Tensor:
    """Compute a column-wise softmax of a 2-D tensor with the Triton kernel.

    Args:
        x: 2-D tensor. The call site in this notebook passes a float32 CUDA
            tensor.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x`` in
        which every column sums to 1.

    Raises:
        ValueError: If ``x`` is not 2-D.
    """
    if x.dim() != 2:
        raise ValueError(f"softmax_triton expects a 2-D tensor, got {x.dim()}-D")

    # The kernel addresses elements as row * num_cols + col, which is only
    # valid for a row-major contiguous buffer (e.g. not for a transposed view).
    x = x.contiguous()
    y = torch.empty_like(x)
    num_rows, num_cols = x.shape
    if x.numel() == 0:
        # Nothing to compute; avoid launching an empty grid / empty block.
        return y

    # The kernel reduces a whole column inside one block, so the block must
    # cover every row. A fixed block size would silently leave the rows beyond
    # it uninitialized in `y`. tl.arange needs a power-of-two extent.
    block_size = triton.next_power_of_2(num_rows)

    # One program per column.
    softmax_kernel[(num_cols,)](x, y, num_rows, num_cols, BLOCK_SIZE=block_size)
    return y

In [6]:
# Same 3x2 values as the NumPy check above, placed on the first CUDA device.
softmax_triton(
    torch.tensor(
        [[1.0, 3.0], [2.0, 2.0], [3.0, 1.0]],
        dtype=torch.float32,
        device="cuda:0",
    )
)

tensor([[0.0900, 0.6652],
        [0.2447, 0.2447],
        [0.6652, 0.0900]], device='cuda:0')