# Fused Layers in Triton

## Setup

In [1]:
# Install a pinned Triton nightly build from the Triton-Nightly package index.
%pip install --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==3.0.0.post20240626041721

Looking in indexes: https://pypi.org/simple, https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os

# Both variables are set in this early cell, ahead of the cells that import Triton, PyTorch and the baseline layer.
os.environ.update(
    {
        # This is to avoid conflicts with PyTorch and the Triton compatibility
        "KERAS_BACKEND": "jax",
        # Ask Triton to report its autotuning results
        "TRITON_PRINT_AUTOTUNING": "1",
    }
)

Check the installed triton version.

In [3]:
import triton

# Stop right away if the kernel is running a different Triton release than the one installed above.
assert (
    triton.__version__ == "3.0.0"
), f"Expected Triton to have a version of 3.0.0, but found {triton.__version__}"

Import the remaining modules used by the kernels.

In [4]:
import torch
import triton.language as tl

## Generic Autotune Config

In [5]:
def get_autotune_config():
    """
    Build the list of launch configurations handed to ``triton.autotune``.

    Only the warp count varies between configurations: one entry for each power of two from 1 to 32,
    each with an empty meta-parameter dictionary.
    """

    warp_counts = (1, 2, 4, 8, 16, 32)
    return [triton.Config({}, num_warps=count) for count in warp_counts]

## Fused RMSNorm with Quantization

We want to fuse the RMSNorm and the quantization into one layer.

### Forward Pass

#### Function Definitions

First define the kernel.

In [6]:
# ruff: noqa: N803
@triton.autotune(
    configs=get_autotune_config(),
    # Both optional-operand flags change the generated code, so both take part in the tuning key
    key=["N", "HAS_GAIN", "HAS_BIAS"],
)
@triton.jit
def quant_rms_norm_fwd_kernel(
    # fmt: off
    # Pointers to arrays
    x_ptr, y_ptr, gain_ptr, bias_ptr, rstd_ptr,
    # Strides
    stride_x_row,  # How much to increase the pointer when moving by 1 row
    stride_y_row,
    # Some constants
    N,  # Number of columns in X
    EPSILON,  # To avoid division by zero
    # Meta-parameters
    BLOCK_SIZE_N: tl.constexpr,
    HAS_GAIN: tl.constexpr,
    HAS_BIAS: tl.constexpr
    # fmt: on
):
    """
    Forward kernel.

    Performs RMSNorm on ``X``, followed by 8-bit quantization.

    Each program instance handles one row of ``X``: it loads up to ``BLOCK_SIZE_N`` columns, computes the
    reciprocal root-mean-square (``rstd``) of the row, normalizes, optionally applies ``gain`` and ``bias``,
    and then rounds the result onto an 8-bit grid scaled by the row's absolute maximum. The value written to
    ``Y`` is the quantized value divided by the scale again, so it stays in the original value range.

    ``rstd`` for the row is written to ``rstd_ptr + pid``.

    ``EPSILON`` is used twice: added to the mean square before the square root, and as a lower bound for the
    absolute maximum so the quantization scale never divides by zero.
    """

    # Map the PID to the row of X that should be loaded
    pid = tl.program_id(0)
    base_x = pid * stride_x_row
    offsets = tl.arange(0, BLOCK_SIZE_N)
    mask = offsets < N  # Lanes at or beyond N are padding up to the block size

    # Padding lanes are filled with 0.0 so that they contribute nothing to the sum of squares
    x = tl.load(x_ptr + base_x + offsets, mask=mask, other=0.0).to(tl.float32)  # Load in higher precision

    # Compute reciprocal standard deviation (rstd); the mean is taken over the N real columns only
    variance = tl.sum(x * x, axis=0) / N
    rstd = 1 / tl.sqrt(variance + EPSILON)
    tl.store(rstd_ptr + pid, rstd)  # We add PID since that is the row that the rstd is corresponding to

    # Normalize
    y = x * rstd

    # Apply gain and bias. `other=0.0` matters: a masked load without `other` leaves the padding lanes
    # unspecified, and those lanes would otherwise flow into the absolute-maximum reduction below.
    if HAS_GAIN:
        gain = tl.load(gain_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        y = y * gain
    if HAS_BIAS:
        bias = tl.load(bias_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        y = y + bias

    # Make sure padding lanes cannot influence the quantization scale
    y = tl.where(mask, y, 0.0)

    # Apply 8-bit quantization
    scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), EPSILON)
    y = tl.extra.cuda.libdevice.round(y * scale)  # TODO: This is CUDA only... can we generalize this?
    y = tl.maximum(tl.minimum(y, 127), -128) / scale  # The nested max and min creates the clamp/clip function

    # Write output
    base_y = pid * stride_y_row
    tl.store(y_ptr + base_y + offsets, y, mask=mask)

Then define the companion function that handles checking and allocation of tensors.

In [7]:
# ruff: noqa: N806, S101
def quant_rms_norm_fwd(x, gain, bias, epsilon):
    """
    Forward pass.

    Performs RMSNorm on ``X``, followed by 8-bit quantization.

    Requires CUDA.

    Args:
        x: 2D CUDA tensor of shape ``(M, N)`` whose last dimension has stride 1.
        gain: Optional tensor of shape ``(N,)`` on the same device as ``x``, multiplied into the normalized
            values. Pass ``None`` to skip it.
        bias: Optional tensor of shape ``(N,)`` on the same device as ``x``, added after the gain. Pass
            ``None`` to skip it.
        epsilon: Small constant forwarded to the kernel to avoid division by zero.

    Returns:
        A tuple ``(y, rstd)``. ``y`` has the shape and dtype of ``x``. ``rstd`` is a float32 tensor of shape
        ``(M,)`` holding the reciprocal standard deviation of each row. Both live on the device of ``x``.

    Raises:
        AssertionError: If ``x`` is not a 2D CUDA tensor with a contiguous last dimension, or if ``gain`` or
            ``bias`` has the wrong shape, stride or device.
        RuntimeError: If one row of ``x`` does not fit in a single block of at most 64KiB.
    """

    assert x.is_cuda, "x must be a CUDA tensor"
    assert x.ndim == 2  # TODO: Support other ndim values?

    # Get dimensions
    M, N = x.shape

    # Validate that the input is OK
    assert x.stride(-1) == 1

    if gain is not None:
        assert gain.shape == (N,)
        assert gain.stride(-1) == 1
        assert gain.device == x.device, "gain must be on the same device as x"
    if bias is not None:
        assert bias.shape == (N,)
        assert bias.stride(-1) == 1
        assert bias.device == x.device, "bias must be on the same device as x"

    # Allocate output. `rstd` is placed on the device of `x`, not on the current default GPU, because the
    # kernel is launched on that device and writes to it.
    y = torch.empty_like(x)
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

    # Enqueue fused kernel if less than 64KiB per feature
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_SIZE_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KiB.")

    # Run the kernel with one program per row, on the device that holds `x`
    with torch.cuda.device(x.device.index):
        quant_rms_norm_fwd_kernel[(M,)](
            # fmt: off
            # Pointers to arrays
            x, y, gain, bias, rstd,
            # Strides
            x.stride(0),
            y.stride(0),
            # Some constants
            N,  # Number of columns in X
            epsilon,  # To avoid division by zero
            # Meta-parameters
            BLOCK_SIZE_N,
            gain is not None,
            bias is not None
            # fmt: on
        )

    # Return stuff
    return y, rstd

#### Testing the Functions

First define the array that we want to normalize and quantize.

In [53]:
import numpy as np

# Seed NumPy's legacy global RNG so the random test matrix is the same on every run.
np.random.seed(8192)

In [54]:
# 8x8 matrix of random samples; the kernel normalizes and quantizes each row separately.
x = np.random.random((8, 8))
x

array([[0.26927402, 0.61529423, 0.09748146, 0.09273818, 0.73007226,
        0.58209141, 0.60969096, 0.61006532],
       [0.12066607, 0.24657084, 0.04623944, 0.30049028, 0.35141794,
        0.13295607, 0.14754041, 0.03140163],
       [0.61991843, 0.14574681, 0.26931555, 0.4847425 , 0.3966334 ,
        0.88168223, 0.83791913, 0.5625612 ],
       [0.87264191, 0.26286043, 0.58907885, 0.99076971, 0.69954819,
        0.60982454, 0.37018737, 0.83690735],
       [0.89161974, 0.01654753, 0.08784621, 0.01367872, 0.98187445,
        0.14133196, 0.06238426, 0.98368333],
       [0.61991595, 0.28636141, 0.08490172, 0.0382211 , 0.48337241,
        0.86436947, 0.33873519, 0.79784515],
       [0.35204787, 0.90411859, 0.0203531 , 0.27302275, 0.32587877,
        0.52114029, 0.52831412, 0.857042  ],
       [0.33801928, 0.81751108, 0.124455  , 0.72949503, 0.08025578,
        0.67488853, 0.30774964, 0.32175416]])

Get the baseline result from the `QuantRMSNorm` layer in `keras_mml`, which runs on the JAX backend selected in the setup section rather than on Triton.

In [55]:
from keras_mml.layers import QuantRMSNorm

In [56]:
# Run the reference layer on the same input and convert its result to a NumPy array for the comparison.
baseline_output = np.array(QuantRMSNorm()(x))
baseline_output

array([[0.52924716, 1.2048818 , 0.19142982, 0.18016924, 1.4300933 ,
        1.1373184 , 1.1936212 , 1.1936212 ],
       [0.5979084 , 1.2094055 , 0.23101005, 1.4811821 , 1.725781  ,
        0.65226364, 0.7202078 , 0.1494771 ],
       [1.0702566 , 0.25253245, 0.46898884, 0.8417748 , 0.6854452 ,
        1.52722   , 1.4550679 , 0.9740537 ],
       [1.2582126 , 0.38195738, 0.85378706, 1.4267231 , 1.0110636 ,
        0.87625515, 0.5279999 , 1.2020423 ],
       [1.5168372 , 0.02637978, 0.14508878, 0.02637978, 1.6751158 ,
        0.237418  , 0.10551911, 1.6751158 ],
       [1.1791687 , 0.5442317 , 0.15549478, 0.07774739, 0.92001075,
        1.645653  , 0.6478949 , 1.5160741 ],
       [0.63585424, 1.6480303 , 0.03892985, 0.49311143, 0.59692436,
        0.94729304, 0.96026963, 1.557194  ],
       [0.6839782 , 1.6389667 , 0.24519974, 1.4582932 , 0.154863  ,
        1.3550512 , 0.619452  , 0.6452625 ]], dtype=float32)

Compare that with the Triton result.

In [57]:
# Copy the NumPy test matrix to the GPU; the forward function is documented as requiring CUDA.
x_torch = torch.tensor(x, device="cuda")

In [58]:
# No gain, no bias; the per-row rstd is returned as well but is not part of the comparison below.
triton_output, triton_rstd = quant_rms_norm_fwd(x_torch, gain=None, bias=None, epsilon=1e-5)
triton_output = np.array(triton_output.cpu())  # Move to host memory for the NumPy comparison
triton_output

array([[0.52923703, 1.20485878, 0.19142616, 0.1801658 , 1.43006599,
        1.13729656, 1.19359839, 1.19359839],
       [0.59783626, 1.20925975, 0.2309822 , 1.48100352, 1.72557282,
        0.65218502, 0.72012097, 0.14945906],
       [1.07024038, 0.25252864, 0.46898174, 0.84176213, 0.68543488,
        1.527197  , 1.45504594, 0.97403902],
       [1.25819969, 0.38195348, 0.85377836, 1.42670858, 1.01105332,
        0.87624621, 0.52799451, 1.20203006],
       [1.51681519, 0.0263794 , 0.14508668, 0.0263794 , 1.67509162,
        0.23741455, 0.10551758, 1.67509162],
       [1.17914748, 0.54422194, 0.15549198, 0.07774599, 0.91999424,
        1.64562345, 0.64788324, 1.51604676],
       [0.63584358, 1.64800274, 0.0389292 , 0.49310318, 0.59691441,
        0.94727719, 0.9602536 , 1.55716801],
       [0.68396449, 1.63893378, 0.24519482, 1.45826387, 0.15485989,
        1.35502398, 0.61943954, 0.64524955]])

In [59]:
# Fail loudly on a mismatch; the tolerance absorbs small numerical differences between the two implementations.
if not np.allclose(triton_output, baseline_output, atol=1e-3):
    raise ValueError("❌ Triton and Baseline differ")

print("✅ Triton and Baseline match")

✅ Triton and Baseline match
