# Ternary Multiplication in Triton

## Setup

Check the installed triton version.

In [1]:
import triton
assert triton.__version__ == "3.0.0"

Import other needed stuff.

In [75]:
import torch
import triton.language as tl
from jaxtyping import Float32

## Helper Functions

In [3]:
def get_current_target():
    return triton.runtime.driver.active.get_current_target()

In [4]:
import warnings


def is_cuda():
    current_target = get_current_target()
    if current_target.backend != "cuda":
        return False

    if current_target.arch < 70:  # CUDA compute capacity is below 7.0, which is minimum 'stable' supported
        warnings.warn("Compute capcity of CUDA device is below 7.0. The Triton compilation may fail terribly!")

    return True

In [5]:
# def is_hip_mi200():
#     target = triton.runtime.driver.active.get_current_target()
#     return target.backend == "hip" and target.arch == "gfx90a"

In [6]:
# is_hip_mi200()

In [18]:
def get_cuda_autotune_config():
    # return [
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 128,
    #             "BLOCK_SIZE_N": 256,
    #         },
    #         num_stages=3,
    #         num_warps=8,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 64,
    #             "BLOCK_SIZE_N": 256,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 128,
    #             "BLOCK_SIZE_N": 128,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 128,
    #             "BLOCK_SIZE_N": 64,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 64,
    #             "BLOCK_SIZE_N": 128,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 128,
    #             "BLOCK_SIZE_N": 32,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 64,
    #             "BLOCK_SIZE_N": 32,
    #         },
    #         num_stages=5,
    #         num_warps=2,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 32,
    #             "BLOCK_SIZE_N": 64,
    #         },
    #         num_stages=5,
    #         num_warps=2,
    #     ),
    #     # Good config for fp8 inputs.
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 128,
    #             "BLOCK_SIZE_N": 256,
    #         },
    #         num_stages=3,
    #         num_warps=8,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 256,
    #             "BLOCK_SIZE_N": 128,
    #         },
    #         num_stages=3,
    #         num_warps=8,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 256,
    #             "BLOCK_SIZE_N": 64,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 64,
    #             "BLOCK_SIZE_N": 256,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 128,
    #             "BLOCK_SIZE_N": 128,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 128,
    #             "BLOCK_SIZE_N": 64,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 64,
    #             "BLOCK_SIZE_N": 128,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    #     triton.Config(
    #         {
    #             "BLOCK_SIZE_M": 128,
    #             "BLOCK_SIZE_N": 32,
    #         },
    #         num_stages=4,
    #         num_warps=4,
    #     ),
    # ]
    return [triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_N": 2})]

In [19]:
def get_autotune_config():
    if is_cuda():
        return get_cuda_autotune_config()
    else:
        raise ValueError("Not on CUDA... can't use!")

Main ternary multiplication kernel.

In [76]:
# ruff: noqa: N803
@triton.autotune(
    configs=get_autotune_config(),
    key=["M", "N"],
)
@triton.jit
def ternary_blocks_mul_kernel(
    # Pointers to matrices
    x_ptr, w_ptr, s_ptr,
    # Scaling factor
    scale,
    # W matrix dimensions
    M, N,  # So W has shape (M, N)
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr
):
    """
    Kernel for computing ternary multiplication of x and W with the blocked results placed in the
    resultant matrix S.

    x is a (transposed) vector of dimension M, W has shape (M, N), and S has shape
    (M, N / BLOCK_SIZE_N), where the division performed here is ceiling division.
    """
    
    # Map program IDs to the elements that it is supposed to obtain
    pid_m = tl.program_id(axis=0)  # PID for blocks of M
    pid_n = tl.program_id(axis=1)  # PID for blocks of N
    
    range_m = tl.arange(0, BLOCK_SIZE_M)
    range_n = tl.arange(0, BLOCK_SIZE_N)
    
    block_start_row = pid_m * BLOCK_SIZE_M  # First row that this block accesses
    block_start_col = pid_n * BLOCK_SIZE_N  # First column that this block accesses
    
    offsets_x = block_start_row + range_m
    mask_x = offsets_x < M
    x = tl.load(x_ptr + offsets_x, mask=mask_x)
    
    base_coords = range_m[None, :] + range_n[:, None] * M  # This makes a 2D array of base offsets
    base_coords = tl.reshape(base_coords, (BLOCK_SIZE_M * BLOCK_SIZE_N,))  # Reshape it back to 1D
    coords = base_coords * block_start_row + block_start_col * M
    mask_w = range_m[None, :] + block_start_row < M  # Ensure that we don't access elements after the last row...
    mask_w = mask_w & (range_n[:, None] + block_start_col < N)  # ...and don't access elements after the last column
    w = tl.load(w_ptr + coords, mask=mask_w)
    
    print("x", x)
    print("w", w)
    
    # # ----------------------------------------------------------------------------------------------
    # # Map program IDs (`pid`) to the element in z that it is supposed to compute, and load the
    # # appropriate parts of the `x` vector and `w` matrix.
    # pid_m = tl.program_id(axis=0)  # PID for M dimension
    # pid_n = tl.program_id(axis=1)  # PID for N dimension

    # block_start_x = pid_m * BLOCK_SIZE_M
    # range_x = tl.arange(0, BLOCK_SIZE_M)
    # offsets_x = block_start_x + range_x
    # mask_x = offsets_x < M

    # x = tl.load(x_ptr + offsets_x, mask=mask_x)
    # print("x", x)

    # block_start_w = pid  # This is the column number
    # range_w = tl.arange(0, BLOCK_SIZE_M) * N  # Columnwise offsets
    # offsets_w = block_start_w + range_w
    # print("offsets_w", offsets_w)
    # mask_w = offsets_w < M * N  # TODO: Check
    # w = tl.load(w_ptr + offsets_w, mask=mask_w)
    # print("w", w)

    # # ----------------------------------------------------------------------------------------------
    # # Iterate to compute the element in the `z` vector.
    # # Since `w` is ternary, we only really care about the sign of the element in the array, and so
    # # we just need to perform two conditional checks
    # elems_to_sum = tl.where(w > 0, x, tl.where(w < 0, -x, tl.zeros_like(x)))
    # # print("elems", elems_to_sum)

    # # print("x", x)
    # # print("a", scale)
    # total = tl.sum(elems_to_sum)
    # total = total / scale  # Need to apply scale right after

    # # ----------------------------------------------------------------------------------------------
    # # Write the single element to the `z` output vector
    # tl.store(z_ptr + pid, total)  # TODO: We shouldn't need a mask since the PID is OK... right?


SyntaxError: unterminated string literal (detected at line 4) (355720455.py, line 4)

Summation kernel.

In [None]:
# ruff: noqa: N803
@triton.jit
def sum_kernel(
    # Pointers to arrays
    s_ptr,
    z_ptr,
    # S matrix dimensions
    M,
    N,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """
    Kernel that sums elements in the matrix S row-wise, placing the outputs the summed elements in
    the vector z.

    S is a matrix of shape (M, N) and z is a vector with N dimensions.
    """
    
    # Map program IDs to the appropriate elements in z that it needs to compute
    pid = tl.program_id(axis=0)
    block_start_z = pid * BLOCK_SIZE_M
    offsets_z = block_start_z + tl.arange(0, BLOCK_SIZE_M)
    mask_z = offsets_z < M  # Guard against OOM

    # Get the offsets and mask for the row of values that z is supposed to be a sum of
    range_row = tl.arange(0, BLOCK_SIZE_N)[None, :]  # Reshape the `arange` to have shape (1, BLOCK_SIZE_N)
    offsets_row = offsets_z[:, None] * N + range_row
    mask_row = mask_z[:, None]

    # Obtain the sum for those elements
    accumulator = tl.zeros(BLOCK_SIZE_M)
    for i in range(0, N, BLOCK_SIZE_N):
        # Load the elements to sum
        range_s = range_row + i
        offsets_s = offsets_row + i
        mask_s = mask_row * (range_s < N)
        s = tl.load(s_ptr + offsets_s, mask=mask_s, other=0.0)

        # Perform the sum
        accumulator = accumulator + tl.sum(s, axis=1)

    # Store the summed elements
    tl.store(z_ptr + offsets_z, accumulator, mask=mask_z)

We can now create a convenience wrapper function that only takes two input tensors, and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.

In [65]:
# ruff: noqa: E731
def ternary_mul(x, w, scale):
    # Check constraints.
    assert len(x) == w.shape[0], "Incompatible dimensions"
    assert x.is_contiguous(), "x must be contiguous"

    assert x.is_cuda and w.is_cuda

    # Get dimensions
    M, N = w.shape

    # Allocate intermediate output
    s = torch.empty((M, ), device=x.device, dtype=torch.float16)  # TODO: Change precision?

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    ternary_mul_kernel[grid](
        x, w, z, #
        scale, #
        M, N #
    )
    return z

TESTING CODE

In [66]:
X_LEN = 8  # x is the 1D vector
W_LEN = 8  # W is the quantized weights matrix
W_SIZE = (X_LEN, W_LEN)

In [67]:
torch.manual_seed(0)

<torch._C.Generator at 0x7f48745efd90>

In [68]:
# x = torch.rand(X_LEN, device="cuda")
# w = torch.tensor([-1., 0., 1.], device="cuda")[torch.randint(2, W_SIZE)]
s = torch.tensor([1., 2, 4, 8], device="cuda")
w = torch.tensor([
    [1., 0., 0., 0.],
    [0, 1., 1., 0],
    [0, -1., 0., 1.],
    [0, 0., 1., -1.]
], device="cuda")
scale = 1.

In [69]:
torch_output = torch.matmul(s, w) / scale

In [70]:
torch_output

tensor([ 1., -2., 10., -4.], device='cuda:0')

In [71]:
triton_output = ternary_mul(s, w, scale)

In [72]:
triton_output

pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx () pid: 1
pid (1, 0, 0) idx ()

tensor([ 1.,  8.,  0., -2.], device='cuda:0', dtype=torch.float16)