# Fused Softmax
Adapted from https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html#sphx-glr-getting-started-tutorials-02-fused-softmax-py

In [None]:
!nvidia-smi

In [None]:
import torch
import triton
import triton.language as tl

In [None]:
# Device string passed as `device=` to every `torch.randn` call in this notebook.
DEVICE = 'cuda'

## Naive softmax

In [None]:
def naive_softmax(x):
    """Row-wise softmax of a 2-D tensor built from separate PyTorch ops.

    Each row is shifted by its own maximum before exponentiation so that
    ``exp`` cannot overflow; the shift cancels out in the final ratio, so the
    result is unchanged.

    The comments below count element traffic for an ``M x N`` input when every
    op runs as its own pass over memory.
    """
    # Row maxima -- reads MN elements, writes M.
    row_max = x.max(dim=1).values
    # Shift every row by its maximum -- reads MN + M, writes MN.
    shifted = x - row_max.unsqueeze(1)
    # Exponentiate -- reads MN, writes MN.
    exps = shifted.exp()
    # Row sums of the exponentials -- reads MN, writes M.
    row_sum = torch.sum(exps, dim=1)
    # Normalise -- reads MN + M, writes MN.
    # Grand total: 5MN + 2M elements read, 3MN + 2M elements written.
    return exps / row_sum.unsqueeze(1)

## Test the correctness

In [None]:
def test_softmax(BT, V, f):
    """Compare softmax implementation ``f`` against ``torch.softmax``.

    A random ``(BT, V)`` tensor is created on ``DEVICE``, both implementations
    are run on it, and ``torch.testing.assert_close`` raises ``AssertionError``
    if the results differ beyond its default tolerances.

    Args:
        BT: Number of rows of the random input.
        V: Number of columns of the random input (the softmax axis).
        f: Callable taking the input tensor and returning its softmax.
    """
    shape = (BT, V)  # the name is printed by the `{shape=}` f-string below
    x = torch.randn(shape, device=DEVICE)

    expected = torch.softmax(x, dim=-1)
    actual = f(x)

    print(f"Testing {shape=}")
    torch.testing.assert_close(actual, expected)
    print("✅ Triton kernel is correct!")

In [None]:
# Check the eager-mode baseline against torch.softmax for V = 2**9 .. 2**15.
BT = 1024
for V in [2 ** i for i in range(9, 16)]:
    try:
        test_softmax(BT, V, naive_softmax)
    except AssertionError as e:
        # Must be an f-string: without the prefix the literal text "{e}" was
        # printed and the actual mismatch report was lost.
        print(f"AssertionError occurred: {e}")

## Benchmark

In [None]:
import os

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2 ** i for i in range(3, 14, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['naive', 'torch'],  # Possible values for `line_arg`.
        line_names=['naive', 'torch'],  # Label name for the lines.
        styles=[('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='naive-softmax-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(V, provider):
    """Benchmark `naive_softmax` against `torch.softmax` on a (4096, V) float32 input.

    Args:
        V: Number of columns of the random input.
        provider: Which implementation to time: 'naive' or 'torch'.

    Returns:
        Three GB/s figures derived from the 0.5, 0.8 and 0.2 quantiles of the
        measured time, in that order (a longer time gives a lower bandwidth).

    Raises:
        ValueError: If `provider` is not one of the supported names.
    """
    BT = 4096
    x = torch.randn((BT, V), device=DEVICE, dtype=torch.float32)

    # Pick the callable first so an unknown provider fails with a clear message
    # instead of an UnboundLocalError on `ms` further down.
    if provider == 'torch':
        fn = lambda: torch.softmax(x, dim=-1)
    elif provider == 'naive':
        fn = lambda: naive_softmax(x)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    def gbps(ms):
        # Counts one read plus one write of the whole tensor (hence the factor 2);
        # bytes -> GB via 1e-9, milliseconds -> seconds via 1e-3.
        return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path=os.path.abspath("../benchmark"))

## Write Triton kernel

In [None]:
@triton.jit
def softmax_kernel_v0(x_ptr, output_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    """Softmax of one row per program, with the whole row loaded in a single block.

    Row `r` is addressed as `ptr + r * n_cols`, i.e. the kernel assumes the rows
    of both buffers are laid out back to back. Only the first `BLOCK_SIZE`
    columns are ever touched, so the caller must pick `BLOCK_SIZE >= n_cols`.
    """
    row_idx = tl.program_id(0)

    # Column indices handled by this program; lanes at or past n_cols are padding.
    col_offsets = tl.arange(0, BLOCK_SIZE)      # shape: (BLOCK_SIZE,)
    in_bounds = col_offsets < n_cols

    # Element pointers for this row in the input and output buffers.
    row_base = row_idx * n_cols
    src_ptrs = x_ptr + row_base + col_offsets
    dst_ptrs = output_ptr + row_base + col_offsets

    # Padding lanes load -inf, so they cannot win the max and contribute exp(-inf) = 0 to the sum.
    row = tl.load(src_ptrs, mask=in_bounds, other=float('-inf'))   # shape: (BLOCK_SIZE,)
    exps = tl.exp(row - tl.max(row, axis=0))                      # max is a scalar
    result = exps / tl.sum(exps, axis=0)                          # sum is a scalar

    tl.store(dst_ptrs, result, mask=in_bounds)

## Helper function to allocate tensors

In [None]:
def triton_softmax_v0(x):
    """Row-wise softmax of a 2-D tensor using `softmax_kernel_v0`.

    One program is launched per row and each program loads its whole row in a
    single block, so the row length is limited to 65536 bytes' worth of elements.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as `x`.

    Raises:
        ValueError: If `x` is not 2-D, or a row does not fit in a single block.
    """
    # The kernel addresses row r at `ptr + r * n_cols`, which is only valid for a
    # dense row-major layout. `contiguous()` is a no-op when x already has one.
    x = x.contiguous()
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)

    # Nothing to compute; also avoids asking for a zero-sized block.
    if x.numel() == 0:
        return output

    MAX_BLOCK_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_BLOCK_SIZE, triton.next_power_of_2(n_cols))

    # Explicit raise rather than `assert`: the check must survive `python -O`,
    # otherwise columns past BLOCK_SIZE would be silently ignored by the kernel.
    if n_cols > BLOCK_SIZE:
        raise ValueError(
            f"This implementation does not support more than {BLOCK_SIZE} "
            f"elements in the last dimension. Got: {n_cols}"
        )

    grid = (n_rows,)
    softmax_kernel_v0[grid](x, output, n_cols, BLOCK_SIZE=BLOCK_SIZE)

    return output


## Benchmark with `triton.testing.do_bench`

In [None]:
import os

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2 ** i for i in range(3, 14, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
        line_names=['triton', 'torch'],  # Label name for the lines.
        styles=[('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='softmax-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(V, provider):
    """Benchmark `triton_softmax_v0` against `torch.softmax` on a (4096, V) float32 input.

    Args:
        V: Number of columns of the random input.
        provider: Which implementation to time: 'triton' or 'torch'.

    Returns:
        Three GB/s figures derived from the 0.5, 0.8 and 0.2 quantiles of the
        measured time, in that order (a longer time gives a lower bandwidth).

    Raises:
        ValueError: If `provider` is not one of the supported names.
    """
    BT = 4096
    x = torch.randn((BT, V), device=DEVICE, dtype=torch.float32)

    # Pick the callable first so an unknown provider fails with a clear message
    # instead of an UnboundLocalError on `ms` further down.
    if provider == 'torch':
        fn = lambda: torch.softmax(x, dim=-1)
    elif provider == 'triton':
        fn = lambda: triton_softmax_v0(x)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    def gbps(ms):
        # Counts one read plus one write of the whole tensor (hence the factor 2);
        # bytes -> GB via 1e-9, milliseconds -> seconds via 1e-3.
        return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path=os.path.abspath("../benchmark"))

## Simple for loop approach


In [None]:
@triton.jit
def for_loop_softmax_kernel(x_ptr, output_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    """Softmax of one row per program, walking the row in BLOCK_SIZE-wide chunks.

    The row is traversed three times: once for its maximum, once for the sum of
    shifted exponentials, and once to write the normalised values. Row `r` is
    addressed as `ptr + r * n_cols`, i.e. rows are assumed to be laid out back
    to back in both buffers.
    """
    row_idx = tl.program_id(0)

    # Base pointers of this program's row in the input and output buffers.
    row_offset = row_idx * n_cols
    in_row = x_ptr + row_offset
    out_row = output_ptr + row_offset

    # Pass 1: row maximum, folded chunk by chunk. Out-of-range lanes load -inf
    # so they can never become the maximum.
    running_max = float("-inf")
    for start in tl.range(0, n_cols, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        chunk = tl.load(in_row + cols, mask=cols < n_cols, other=float('-inf'))
        running_max = tl.maximum(running_max, tl.max(chunk, axis=0))

    # Pass 2: sum of exp(x - max). Out-of-range lanes contribute exp(-inf) = 0.
    exp_sum = 0.0
    for start in tl.range(0, n_cols, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        chunk = tl.load(in_row + cols, mask=cols < n_cols, other=float('-inf'))
        exp_sum += tl.sum(tl.exp(chunk - running_max))

    # Pass 3: recompute the exponentials, normalise, and write the result.
    for start in tl.range(0, n_cols, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        valid = cols < n_cols
        chunk = tl.load(in_row + cols, mask=valid, other=float('-inf'))
        tl.store(out_row + cols, tl.exp(chunk - running_max) / exp_sum, mask=valid)



def triton_for_loop_softmax(x):
    """Row-wise softmax of a 2-D tensor using `for_loop_softmax_kernel`.

    One program is launched per row. The block size is the next power of two of
    the row length, capped at 65536 bytes' worth of elements; longer rows are
    handled by the kernel's chunk loop.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as `x`.

    Raises:
        ValueError: If `x` is not 2-D (shape unpacking fails).
    """
    # The kernel addresses row r at `ptr + r * n_cols`, which is only valid for a
    # dense row-major layout. `contiguous()` is a no-op when x already has one.
    x = x.contiguous()
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)

    # Nothing to compute; also avoids asking for a zero-sized block.
    if x.numel() == 0:
        return output

    MAX_BLOCK_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_BLOCK_SIZE, triton.next_power_of_2(n_cols))

    grid = (n_rows,)
    for_loop_softmax_kernel[grid](x, output, n_cols, BLOCK_SIZE=BLOCK_SIZE)

    return output

In [None]:
# Check the chunked kernel against torch.softmax for V = 2**9 .. 2**15.
BT = 1024
for V in (1 << i for i in range(9, 16)):
    test_softmax(BT, V, triton_for_loop_softmax)

In [33]:
import os

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2 ** i for i in range(3, 18, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
        line_names=['triton', 'torch'],  # Label name for the lines.
        styles=[('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='softmax-naive-for-loop-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(V, provider):
    """Benchmark `triton_for_loop_softmax` against `torch.softmax` on a (4096, V) float32 input.

    Args:
        V: Number of columns of the random input.
        provider: Which implementation to time: 'triton' or 'torch'.

    Returns:
        Three GB/s figures derived from the 0.5, 0.8 and 0.2 quantiles of the
        measured time, in that order (a longer time gives a lower bandwidth).

    Raises:
        ValueError: If `provider` is not one of the supported names.
    """
    BT = 4096
    x = torch.randn((BT, V), device=DEVICE, dtype=torch.float32)

    # Pick the callable first so an unknown provider fails with a clear message
    # instead of an UnboundLocalError on `ms` further down.
    if provider == 'torch':
        fn = lambda: torch.softmax(x, dim=-1)
    elif provider == 'triton':
        fn = lambda: triton_for_loop_softmax(x)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    def gbps(ms):
        # Counts one read plus one write of the whole tensor (hence the factor 2);
        # bytes -> GB via 1e-9, milliseconds -> seconds via 1e-3.
        return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path=os.path.abspath("../benchmark"))

KeyboardInterrupt: 

In [None]:
def calculate_settings(n, element_size):
    """Choose a block size and warp count for a row of `n` elements.

    The block size is the next power of two of `n`, capped at 65536 bytes'
    worth of elements. The warp count grows with the block size; the tier
    thresholds are byte budgets converted to element counts.

    Args:
        n: Number of elements in a row.
        element_size: Size of one element in bytes.

    Returns:
        Tuple `(BLOCK_SIZE, num_warps)`.
    """
    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
    block_cap = 65536 // element_size
    BLOCK_SIZE = min(block_cap, triton.next_power_of_2(n))

    # (byte budget, warps), largest first; the first tier the block reaches wins.
    warp_tiers = ((32768, 32), (8192, 16), (2048, 8))
    for byte_budget, warps in warp_tiers:
        if BLOCK_SIZE >= byte_budget // element_size:
            return BLOCK_SIZE, warps

    # Blocks below every tier use the smallest warp count.
    return BLOCK_SIZE, 4

In [None]:
def triton_tuned_for_loop_softmax(x):
    """Row-wise softmax using `for_loop_softmax_kernel` with a tuned launch.

    Same as the plain for-loop wrapper, except that the block size and the
    number of warps both come from `calculate_settings`.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as `x`.

    Raises:
        ValueError: If `x` is not 2-D (shape unpacking fails).
    """
    # The kernel addresses row r at `ptr + r * n_cols`, which is only valid for a
    # dense row-major layout. `contiguous()` is a no-op when x already has one.
    x = x.contiguous()
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)

    # Nothing to compute; also avoids asking for a zero-sized block.
    if x.numel() == 0:
        return output

    BLOCK_SIZE, num_warps = calculate_settings(n_cols, x.element_size())

    grid = (n_rows,)
    for_loop_softmax_kernel[grid](x, output, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)

    return output

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2 ** i for i in range(3, 18, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
        line_names=['triton', 'torch'],  # Label name for the lines.
        styles=[('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='softmax-tuned-for-loop-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(V, provider):
    """Benchmark `triton_tuned_for_loop_softmax` against `torch.softmax` on a (4096, V) float32 input.

    Args:
        V: Number of columns of the random input.
        provider: Which implementation to time: 'triton' or 'torch'.

    Returns:
        Three GB/s figures derived from the 0.5, 0.8 and 0.2 quantiles of the
        measured time, in that order (a longer time gives a lower bandwidth).

    Raises:
        ValueError: If `provider` is not one of the supported names.
    """
    BT = 4096
    x = torch.randn((BT, V), device=DEVICE, dtype=torch.float32)

    # Pick the callable first so an unknown provider fails with a clear message
    # instead of an UnboundLocalError on `ms` further down.
    if provider == 'torch':
        fn = lambda: torch.softmax(x, dim=-1)
    elif provider == 'triton':
        fn = lambda: triton_tuned_for_loop_softmax(x)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    def gbps(ms):
        # Counts one read plus one write of the whole tensor (hence the factor 2);
        # bytes -> GB via 1e-9, milliseconds -> seconds via 1e-3.
        return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path=os.path.abspath("../benchmark"))

## Better Algorithm: Online Softmax

https://github.com/NVIDIA/online-softmax

https://arxiv.org/pdf/1805.02867

In [None]:
@triton.jit
def online_softmax_kernel(x_ptr, output_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    """Softmax of one row per program using a fused max/denominator pass.

    The row is walked in BLOCK_SIZE-wide chunks twice: the first loop keeps a
    running maximum and a running denominator together, the second loop writes
    the normalised values. Row `r` is addressed as `ptr + r * n_cols`, i.e. rows
    are assumed to be laid out back to back in both buffers.
    """
    
    # Program index along grid axis 0; used below as the row index.
    pid = tl.program_id(0)

    # Calculate the starting pointer of each row
    x_row_start = x_ptr + pid * n_cols
    output_row_start = output_ptr + pid * n_cols

    # First Pass: Find statistics maximum m and denominator d.
    x_max = float('-inf')
    denominator = 0.0

    for i in tl.range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        # Lanes at or past n_cols load -inf: they cannot raise the max and add exp(-inf) = 0 to the sum.
        mask = offsets < n_cols
        x_block = tl.load(x_row_start + offsets, mask=mask, other=float('-inf')) 
        block_max = tl.max(x_block)

        new_x_max = tl.maximum(x_max, block_max)
        # The old denominator was accumulated relative to x_max; multiplying by
        # exp(x_max - new_x_max) re-expresses it relative to new_x_max before the
        # current block's terms are added. On the first iteration x_max is -inf,
        # so the rescale factor is exp(-inf) = 0 and the initial 0.0 stays 0.
        denominator = denominator * tl.exp(x_max - new_x_max) + tl.sum(tl.exp(x_block - new_x_max))
        x_max = new_x_max


    # Now we have the correct denominator

    # Final Pass: Calculate output and store

    for i in tl.range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        # The input is read a second time; the exponentials are recomputed rather than kept from the first pass.
        x_block = tl.load(x_row_start + offsets, mask=mask, other=float('-inf')) 

        numerator = tl.exp(x_block - x_max)
        output_block = numerator / denominator

        # Masked store: only lanes inside the row are written.
        tl.store(output_row_start + offsets, output_block, mask=mask)




def triton_online_softmax(x):
    """Row-wise softmax of a 2-D tensor using `online_softmax_kernel`.

    One program is launched per row; the block size and the number of warps
    come from `calculate_settings`.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as `x`.

    Raises:
        ValueError: If `x` is not 2-D (shape unpacking fails).
    """
    # The kernel addresses row r at `ptr + r * n_cols`, which is only valid for a
    # dense row-major layout. `contiguous()` is a no-op when x already has one.
    x = x.contiguous()
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)

    # Nothing to compute; also avoids asking for a zero-sized block.
    if x.numel() == 0:
        return output

    BLOCK_SIZE, num_warps = calculate_settings(n_cols, x.element_size())

    grid = (n_rows,)
    online_softmax_kernel[grid](x, output, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)

    return output

In [None]:
# Check the online-softmax kernel against torch.softmax for V = 2**9 .. 2**15.
BT = 1024
for V in (1 << i for i in range(9, 16)):
    test_softmax(BT, V, triton_online_softmax)

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2 ** i for i in range(3, 18, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
        line_names=['triton', 'torch'],  # Label name for the lines.
        styles=[('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='online-softmax-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(V, provider):
    """Benchmark `triton_online_softmax` against `torch.softmax` on a (4096, V) float32 input.

    Args:
        V: Number of columns of the random input.
        provider: Which implementation to time: 'triton' or 'torch'.

    Returns:
        Three GB/s figures derived from the 0.5, 0.8 and 0.2 quantiles of the
        measured time, in that order (a longer time gives a lower bandwidth).

    Raises:
        ValueError: If `provider` is not one of the supported names.
    """
    BT = 4096
    x = torch.randn((BT, V), device=DEVICE, dtype=torch.float32)

    # Pick the callable first so an unknown provider fails with a clear message
    # instead of an UnboundLocalError on `ms` further down.
    if provider == 'torch':
        fn = lambda: torch.softmax(x, dim=-1)
    elif provider == 'triton':
        fn = lambda: triton_online_softmax(x)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    def gbps(ms):
        # Counts one read plus one write of the whole tensor (hence the factor 2);
        # bytes -> GB via 1e-9, milliseconds -> seconds via 1e-3.
        return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path=os.path.abspath("../benchmark"))

## Summary

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2 ** i for i in range(3, 18, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton_online_softmax', 'triton_naive_softmax', 'torch'],  # Possible values for `line_arg`.
        line_names=['triton_online_softmax', 'triton_naive_softmax', 'torch'],  # Label name for the lines.
        styles=[('red', '-'), ('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='all-softmax-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(V, provider):
    """Benchmark the online and tuned for-loop Triton softmaxes against `torch.softmax`.

    The input is a random (4096, V) float32 tensor.

    Args:
        V: Number of columns of the random input.
        provider: Which implementation to time: 'triton_online_softmax',
            'triton_naive_softmax' (the tuned three-pass for-loop kernel) or 'torch'.

    Returns:
        Three GB/s figures derived from the 0.5, 0.8 and 0.2 quantiles of the
        measured time, in that order (a longer time gives a lower bandwidth).

    Raises:
        ValueError: If `provider` is not one of the supported names.
    """
    BT = 4096
    x = torch.randn((BT, V), device=DEVICE, dtype=torch.float32)

    # Pick the callable first so an unknown provider fails with a clear message
    # instead of an UnboundLocalError on `ms` further down.
    if provider == 'torch':
        fn = lambda: torch.softmax(x, dim=-1)
    elif provider == 'triton_online_softmax':
        fn = lambda: triton_online_softmax(x)
    elif provider == 'triton_naive_softmax':
        fn = lambda: triton_tuned_for_loop_softmax(x)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    def gbps(ms):
        # Counts one read plus one write of the whole tensor (hence the factor 2);
        # bytes -> GB via 1e-9, milliseconds -> seconds via 1e-3.
        return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True, save_path=os.path.abspath("../benchmark"))