<a href="https://colab.research.google.com/github/TechDailyNotes/study-notes-triton/blob/main/triton_02_fused_softmax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fused Softmax

## Motivations

In [5]:
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly

Looking in indexes: https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
Collecting triton-nightly
  Downloading https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/3.post20240626041721/triton_nightly-3.0.0.post20240626041721-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (139.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.1/139.1 MB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: triton-nightly
Successfully installed triton-nightly-3.0.0.post20240626041721


In [1]:
import torch
import triton
import triton.language as tl
from triton.runtime import driver

def naive_softmax(x):
    """Softmax along dim 1 of `x`, written as a chain of separate PyTorch ops.

    Each op below is annotated with the element reads/writes it performs for
    an M x N input.  Summed over the whole function this comes to
    5MN + 2M reads and 3MN + 2M writes.
    """
    # Per-row maximum, kept as a column so it broadcasts against x.
    # Reads MN, writes M.
    row_max = torch.max(x, dim=1, keepdim=True).values
    # Shift every row by its maximum to prevent overflow in exp().
    # Reads MN + M, writes MN.
    shifted = x - row_max

    # Reads MN, writes MN.
    exps = torch.exp(shifted)
    # Per-row normaliser.  Reads MN, writes M.
    row_sum = torch.sum(exps, dim=1, keepdim=True)
    # Reads MN + M, writes MN.
    return exps / row_sum

## Compute Kernel

In [44]:
@triton.jit
def softmax_kernel(input_ptr, output_ptr, n_rows, n_cols,
                   input_row_stride, output_row_stride,
                   BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
    """Row-wise softmax; every loop iteration handles one complete row.

    A program starts at the row equal to its program id and advances by the
    number of launched programs until it reaches `n_rows`.  Each row is read
    as one BLOCK_SIZE-wide vector addressed with unit stride from the row's
    start pointer; lanes at or beyond `n_cols` are masked, and columns beyond
    BLOCK_SIZE are not processed.
    """
    first_row = tl.program_id(0)
    row_step = tl.num_programs(0)

    for row in tl.range(first_row, n_rows, row_step, num_stages=num_stages):
        # Column lanes for this row and the subset that holds real data; the
        # same offsets and mask serve both the load and the store.
        cols = tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols

        # Masked lanes read as -inf: they cannot win the max, and exp() maps
        # them to 0 so they add nothing to the sum.
        src_ptrs = input_ptr + row * input_row_stride + cols
        values = tl.load(src_ptrs, mask=in_bounds, other=-float("inf"))

        # Subtract the row maximum before exponentiating, then normalise.
        shifted = values - tl.max(values, axis=0)
        exps = tl.exp(shifted)
        probs = exps / tl.sum(exps, axis=0)

        # Write back only the valid lanes.
        dst_ptrs = output_ptr + row * output_row_stride + cols
        tl.store(dst_ptrs, probs, mask=in_bounds)

In [52]:
# Query the current CUDA device once at cell-execution time; `softmax` below
# reads these module-level values to pick its launch configuration.
device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
# Device limits taken from Triton's device-properties dict.  Presumably
# shared memory is in bytes and `max_num_regs` is a register count — TODO
# confirm against the Triton driver.
MAX_SHARED_MEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
MAX_NUM_REGS = properties["max_num_regs"]
SM_COUNT = properties["multiprocessor_count"]
# NOTE(review): `target` is not referenced by the cells visible here —
# confirm whether a later cell still needs it.
target = triton.runtime.driver.active.get_current_target()

def softmax(x):
    """Row-wise softmax of a 2-D tensor, computed with `softmax_kernel`.

    The kernel is compiled via `warmup` so its register and shared-memory
    usage can be read back; those figures, together with the device limits
    captured at module level, bound how many programs are launched.  Each
    launched program then loops over several rows.

    Args:
        x: 2-D tensor (rows x columns).  `softmax_kernel` addresses the
            elements of a row with unit stride, so `x` is assumed to be
            contiguous along its last dimension.

    Returns:
        A new tensor created with `torch.empty_like(x)` holding the softmax
        of every row of `x`.

    Raises:
        ValueError: if `x` is not 2-D (the shape unpacking fails).
    """
    # Step 1: Init output.
    y = torch.empty_like(x)

    # Step 2: Set kernel parameters.
    n_rows, n_cols = x.shape
    if n_rows == 0 or n_cols == 0:
        # Nothing to compute: skip compilation and launch for empty inputs.
        return y

    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    num_stages = 4 if MAX_SHARED_MEM > 200000 else 2
    num_warps = 8

    # Step 3: Compile the kernel and read back its resource usage.
    # A dict created inside this function can never produce a cache hit, so
    # the kernel handle is simply requested on every call.
    kernel = softmax_kernel.warmup(x, y, n_rows, n_cols,
                                   x.stride(0), y.stride(0),
                                   BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages,
                                   num_warps=num_warps, grid=(1, ))
    kernel._init_handles()

    # Programs that fit per multiprocessor, limited first by registers ...
    occupancy = MAX_NUM_REGS // (num_warps * WARP_SIZE * kernel.n_regs)
    # ... and then by shared memory.  A kernel that reports no shared-memory
    # use puts no limit here (and must not divide by zero).
    shared_mem = kernel.metadata.shared
    if shared_mem > 0:
        occupancy = min(occupancy, MAX_SHARED_MEM // shared_mem)

    # Never launch more programs than rows, and always at least one so the
    # output is actually written.
    n_programs = max(1, min(SM_COUNT * occupancy, n_rows))

    # Step 4: Launch kernel function.
    kernel[(n_programs, 1, 1)](x, y, n_rows, n_cols, x.stride(0), y.stride(0))

    # Step 5: Return output.
    return y

# device = torch.cuda.current_device()
# properties = driver.active.utils.get_device_properties(device)
# NUM_SM = properties["multiprocessor_count"]
# NUM_REGS = properties["max_num_regs"]
# SIZE_SMEM = properties["max_shared_mem"]
# WARP_SIZE = properties["warpSize"]
# target = triton.runtime.driver.active.get_current_target()
# kernels = {}


# def softmax(x):
#     n_rows, n_cols = x.shape

#     # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
#     BLOCK_SIZE = triton.next_power_of_2(n_cols)

#     # Another trick we can use is to ask the compiler to use more threads per row by
#     # increasing the number of warps (`num_warps`) over which each row is distributed.
#     # You will see in the next tutorial how to auto-tune this value in a more natural
#     # way so you don't have to come up with manual heuristics yourself.
#     num_warps = 8

#     # Number of software pipelining stages.
#     num_stages = 4 if SIZE_SMEM > 200000 else 2

#     # Allocate output
#     y = torch.empty_like(x)

#     # pre-compile kernel to get register usage and compute thread occupancy.
#     kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
#     if kernel is None:
#         kernel = softmax_kernel.warmup(x, y, n_rows, n_cols, x.stride(0), y.stride(0), BLOCK_SIZE=BLOCK_SIZE,
#                                        num_stages=num_stages, num_warps=num_warps, grid=(1, ))
#         kernel._init_handles()
#         n_regs = kernel.n_regs
#         size_smem = kernel.metadata.shared
#         occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
#         occupancy = min(occupancy, SIZE_SMEM // size_smem)
#         num_programs = NUM_SM * occupancy
#         kernels[BLOCK_SIZE] = (kernel, num_programs)

#     num_programs = min(num_programs, n_rows)

#     # Create a number of persistent programs.
#     kernel[(num_programs, 1, 1)](
#         x,
#         y,
#         n_rows,
#         n_cols,
#         x.stride(0),
#         y.stride(0),
#     )
#     return y

## Unit Test

In [55]:
# Seeded random input whose dimensions are not powers of two, so the kernel's
# column masking is exercised.
torch.manual_seed(0)
x = torch.randn((1823, 781), dtype=torch.float32, device="cuda")

# Triton implementation vs. the PyTorch reference on the same input.
output_triton = softmax(x)
output_torch = torch.softmax(x, dim=1)

assert torch.allclose(output_triton, output_torch), (output_triton, output_torch)
print(f"Max diff is {(output_triton - output_torch).abs().max()}.")

Max diff is 7.450580596923828e-09.


## Benchmark