In [None]:
import torch
import triton
import triton.language as tl


In [None]:
@triton.jit
def safeSoftmaxKernel(
    outs, ins,
    outs_stride, ins_stride,
    cols, BLOCK_SIZE: tl.constexpr
):
    """Numerically stable softmax over one row per program instance.

    Args:
        outs: base pointer of the output buffer.
        ins: base pointer of the input buffer.
        outs_stride: element offset between consecutive output rows.
        ins_stride: element offset between consecutive input rows.
        cols: number of valid columns in each row.
        BLOCK_SIZE: compile-time number of lanes processed per row.

    The row is handled in a single pass of BLOCK_SIZE lanes, so columns at
    index >= BLOCK_SIZE are neither read nor written; callers must launch
    with BLOCK_SIZE >= cols. Columns are addressed as consecutive elements
    from the row start (unit column stride).
    """
    row_idx = tl.program_id(axis=0)

    # Lane indices and the validity mask are shared by the load and the store.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    in_bounds = col_offsets < cols

    src_ptrs = ins + row_idx * ins_stride + col_offsets
    dst_ptrs = outs + row_idx * outs_stride + col_offsets

    # Padding lanes read -inf so they cannot win the max and become
    # exp(-inf) == 0 in the sum.
    row = tl.load(src_ptrs, mask=in_bounds, other=float('-inf'))

    # Subtract the row maximum before exponentiating to avoid overflow.
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)

    tl.store(dst_ptrs, numerator / denominator, mask=in_bounds)


In [None]:
def triton_safe_softmax(x):
    """Row-wise softmax of a 2-D tensor computed with ``safeSoftmaxKernel``.

    Args:
        x: 2-D tensor of shape ``(n_rows, n_cols)``; softmax is taken over
            the last dimension. Must live on a device Triton can launch on.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``x`` is not 2-D.
    """
    if x.dim() != 2:
        raise ValueError(
            f"triton_safe_softmax expects a 2-D tensor, got shape {tuple(x.shape)}"
        )
    n_rows, n_cols = x.shape

    # The kernel addresses columns as consecutive elements, so both buffers
    # need unit column stride (a transposed view would be read incorrectly).
    if x.stride(1) != 1:
        x = x.contiguous()
    output = torch.empty_like(x, memory_format=torch.contiguous_format)

    # Nothing to compute; also avoids launching an empty grid.
    if n_rows == 0 or n_cols == 0:
        return output

    # The kernel processes a whole row in one block, so BLOCK_SIZE must cover
    # every column. Capping it below n_cols would leave the tail columns
    # unnormalised and unwritten.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Wider rows get more warps so each thread holds fewer elements.
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16

    grid = (n_rows,)  # one program instance per row

    safeSoftmaxKernel[grid](
        output,
        x,
        output.stride(0),
        x.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return output

In [None]:
# Fix the seed so the random input is reproducible across runs.
torch.manual_seed(0)
x = torch.randn((256, 1024), device='cuda')

# 3.1 Reference result from PyTorch's built-in softmax
torch_result = x.softmax(dim=1)

# 3.2 Result from the Triton implementation
triton_result = triton_safe_softmax(x)

# 3.3 Largest absolute element-wise difference between the two results
max_diff = (torch_result - triton_result).abs().max()
print(f"pytorch与triton的safe-softmax最大误差为: {max_diff:.2e}")

# 3.4 Check that the results agree within the given tolerances
is_close = torch.allclose(torch_result, triton_result, atol=1e-5, rtol=1e-5)
print(f"结果正确否: {is_close}")