# 参考

- https://nagi.fun/triton-intro-softmax

- https://github.com/lessw2020/triton_kernels_for_fun_and_profit/blob/main/demos/demo_softmax.py

In [35]:
import torch

import triton
import triton.language as tl

# Softmax 实现

## PyTorch Eager Mode

In [36]:
def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    """Compute a softmax along ``dim=1`` of ``x`` using plain PyTorch ops.

    The per-row maximum is subtracted before exponentiating so that the
    largest exponent is zero, which avoids overflow in ``exp``. The
    exponentials are then normalised by their sum along ``dim=1``.

    Args:
        x: Input tensor; reductions are taken over ``dim=1``.

    Returns:
        A tensor of the same shape as ``x`` whose entries along ``dim=1``
        sum to one.
    """
    # ``max`` with a ``dim`` argument returns (values, indices); only the
    # values are needed here.
    row_max, _ = x.max(dim=1, keepdim=True)
    exps = (x - row_max).exp()
    return exps / exps.sum(dim=1, keepdim=True)

In [37]:
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    """Softmax forward pass; each program instance handles one row.

    The row is selected by the program id along axis 0. ``block_size``
    column offsets are generated per row, and offsets at or beyond
    ``num_cols`` are masked out of both the load and the store.

    Args:
        output_ptr: Base pointer of the output buffer.
        stride_output_row: Offset between consecutive output rows.
        input_ptr: Base pointer of the input buffer.
        stride_input_row: Offset between consecutive input rows.
        num_cols: Number of valid columns in each row.
        block_size: Compile-time number of column offsets per row.
    """
    pid = tl.program_id(0)

    # Column offsets are added to the row base pointer directly, i.e. the
    # elements of a row are addressed with a step of one.
    cols = tl.arange(0, block_size)
    in_bounds = cols < num_cols

    src = input_ptr + (pid * stride_input_row) + cols
    dst = output_ptr + (pid * stride_output_row) + cols

    # Masked-out lanes read as -inf, so they become exp(-inf) == 0 below and
    # do not contribute to the max or the sum.
    vals = tl.load(src, mask=in_bounds, other=float("-inf"))

    # Subtract the row maximum before exponentiating.
    shifted = vals - tl.max(vals, axis=0)
    exps = tl.exp(shifted)
    result = exps / tl.sum(exps, axis=0)

    tl.store(dst, result, mask=in_bounds)

In [38]:
def softmax(x: torch.Tensor) -> torch.Tensor:
    """Triton implementation of a row-wise softmax (forward pass only).

    One kernel program is launched per row, and each program covers a whole
    row in a single block whose size is the next power of two at or above
    the column count.

    Args:
        x: 2D input tensor. Inputs whose column stride is not 1 are copied
            to a contiguous layout first, because the kernel addresses the
            elements of a row with a step of one.

    Returns:
        A new tensor with the same shape and dtype as ``x`` holding the
        softmax of each row.

    Raises:
        ValueError: If ``x`` is not 2-dimensional.
    """
    # Validate before unpacking the shape so the caller sees this message
    # rather than a tuple-unpacking error.
    if x.dim() != 2:
        raise ValueError("only accepts 2D tensors for now")

    # Nothing to compute for empty inputs; skip the launch, which would
    # otherwise get an empty grid or a zero block size.
    if x.numel() == 0:
        return torch.empty_like(x)

    # The kernel only takes row strides; columns must be densely packed.
    if x.stride(1) != 1:
        x = x.contiguous()

    rows, cols = x.shape
    block_size = triton.next_power_of_2(cols)

    # Use more warps (32 threads each) for wider rows.
    if block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    else:
        num_warps = 4

    grid = (rows,)

    # Allocate the output with an explicitly contiguous layout so that its
    # column stride is 1 regardless of the input's strides.
    sm_out = torch.empty_like(x, memory_format=torch.contiguous_format)

    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps,
    )

    return sm_out

In [39]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=[
            'triton',
            'torch-native',
            'torch-jit',
        ],
        line_names=[
            "Triton",
            "Torch (native)",
            "Torch (jit)",
        ],
        styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    """Time one softmax implementation on a random ``(M, N)`` float32 CUDA tensor.

    Args:
        M: Number of rows of the input.
        N: Number of columns of the input.
        provider: Which implementation to time: ``'triton'`` (``softmax``),
            ``'torch-native'`` (``torch.softmax``) or ``'torch-jit'``
            (``naive_softmax``, called as defined in this file).

    Returns:
        Three throughput figures in GB/s, computed from the first, third and
        second timings returned by ``do_bench`` for the quantiles
        ``[0.5, 0.2, 0.8]``, in that order.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    implementations = {
        'torch-native': lambda t: torch.softmax(t, axis=-1),
        'triton': softmax,
        'torch-jit': naive_softmax,
    }
    # Fail early and clearly, before allocating on the GPU, instead of hitting
    # an unbound ``ms`` further down.
    if provider not in implementations:
        raise ValueError(
            f"unknown provider {provider!r}; expected one of {sorted(implementations)}"
        )
    run = implementations[provider]

    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: run(x), quantiles=quantiles)

    def gbps(ms):
        # Bytes moved are counted as twice the tensor size (presumably one
        # read plus one write per element), converted to GB per second.
        return 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)

In [40]:
# Run the `benchmark` sweep, asking for the results table to be printed
# (print_data=True) and the plot to be displayed (show_plots=True).
benchmark.run(print_data=True, show_plots=True)

ValueError: object __array__ method not producing an array

<Figure size 640x480 with 1 Axes>

softmax-performance:
          N       Triton  Torch (native)  Torch (jit)
0     256.0   546.133347      585.142849   199.804881
1     384.0   819.200021      768.000002   261.446801
2     512.0   910.222190      910.222190   297.890907
3     640.0   975.238103      930.909084   330.322585
4     768.0  1068.521715      983.040025   346.140834
5     896.0  1146.880029     1023.999986   353.975316
6    1024.0  1170.285698     1092.266694   352.344077
7    1152.0  1152.000003      604.327881   354.461542
8    1280.0  1204.705861      660.645170   341.333342
9    1408.0  1251.555511      715.174609   333.748161
10   1536.0  1293.473742      780.190482   336.657521
11   1664.0  1298.731729      806.787872   334.893076
12   1792.0  1333.581395      855.880586   333.395349
13   1920.0  1365.333313      890.434763   333.913036
14   2048.0  1365.333285      949.797080   334.367358
15   2176.0  1365.333358      967.111077   334.769235
16   2304.0  1391.094346     1009.972563   333.610868
17   24