# Ternary Multiplication in Triton

## Setup

Check the installed Triton version; this notebook was written against Triton 3.0.0.

In [1]:
import triton
# The notebook was written against Triton 3.0.0; fail early with a readable message on any other version.
assert triton.__version__ == "3.0.0", f"Expected Triton 3.0.0, found {triton.__version__}"

Import the remaining dependencies.

In [2]:
import torch
import triton.language as tl
from jaxtyping import Float32

## Helper Functions

In [3]:
def get_current_target():
    """Return the compilation target (exposing e.g. `backend` and `arch`) of the active Triton driver."""
    active_driver = triton.runtime.driver.active
    return active_driver.get_current_target()

In [4]:
import warnings


def is_cuda():
    """
    Return whether the active Triton driver targets a CUDA device.

    Returns:
        False if the backend is not "cuda"; True otherwise. A warning is emitted (but True is
        still returned) when the device's compute capability is below 7.0.
    """
    current_target = get_current_target()
    if current_target.backend != "cuda":
        return False

    # `arch` encodes the CUDA compute capability as an integer (70 == 7.0), which is the minimum
    # 'stable' supported capability.
    if current_target.arch < 70:
        warnings.warn(
            "Compute capability of CUDA device is below 7.0. The Triton compilation may fail terribly!",
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )

    return True

In [5]:
# def is_hip_mi200():
#     target = triton.runtime.driver.active.get_current_target()
#     return target.backend == "hip" and target.arch == "gfx90a"

In [6]:
# is_hip_mi200()

In [7]:
def get_cuda_autotune_config(tuning=False):
    """
    Return the list of `triton.Config`s that `ternary_mul_kernel` is autotuned over on CUDA.

    Args:
        tuning: If False (the default), return the single 2x2 block configuration this notebook
            uses. If True, return a sweep of larger block sizes and pipeline settings for the
            autotuner to choose from. Each combination appears only once.

    Returns:
        A list of `triton.Config` objects.
    """
    if not tuning:
        return [triton.Config({"BLOCK_SIZE_M": 2, "BLOCK_SIZE_N": 2})]

    # (BLOCK_SIZE_M, BLOCK_SIZE_N, num_stages, num_warps), each combination listed exactly once.
    sweep = [
        (128, 256, 3, 8),
        (64, 256, 4, 4),
        (128, 128, 4, 4),
        (128, 64, 4, 4),
        (64, 128, 4, 4),
        (128, 32, 4, 4),
        (64, 32, 5, 2),
        (32, 64, 5, 2),
        (256, 128, 3, 8),
        (256, 64, 4, 4),
    ]
    return [
        triton.Config({"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn}, num_stages=stages, num_warps=warps)
        for bm, bn, stages, warps in sweep
    ]

In [8]:
def get_autotune_config():
    """
    Return the autotuning configurations for the current Triton target.

    Raises:
        ValueError: if the active target is not a CUDA device.
    """
    if not is_cuda():
        raise ValueError("Not on CUDA... can't use!")
    return get_cuda_autotune_config()

The main ternary multiplication kernel.

Each Triton program computes one block of `BLOCK_SIZE_N` consecutive outputs. The rough pseudocode algorithm is as follows.
```python
# Each iteration of this outer loop runs as its own Triton program, in parallel
for n in range(0, N, BLOCK_SIZE_N):
    acc = zeros((BLOCK_SIZE_N,), dtype=float32)
    for m in range(0, M, BLOCK_SIZE_M):
        x_block = x[m : m+BLOCK_SIZE_M, None]  # column vector, broadcasts against `w_block`
        w_block = w[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N]

        # Since `w` is ternary, only the sign of each element matters, so instead of
        # multiplying we just need to perform two conditional checks
        elems_to_sum = tl.where(w_block > 0, x_block, tl.where(w_block < 0, -x_block, tl.zeros_like(x_block)))
        acc += tl.sum(elems_to_sum, axis=0)  # Sum along the M direction

    acc = acc / scale
    z[n : n+BLOCK_SIZE_N] = acc
```

In [26]:
# ruff: noqa: N803, PLR2044
@triton.autotune(
    configs=get_autotune_config(),
    key=["M", "N"],
)
@triton.jit
def ternary_mul_kernel(
    # Pointers to the input vector `x`, the weight matrix `W` and the output vector `z`
    x_ptr,
    w_ptr,
    z_ptr,
    # Scaling factor; the product is divided by it
    scale,
    # `W` matrix dimensions
    M,
    N,
    # The stride variables represent how much to increase a pointer by when moving by 1
    # element in a particular dimension. E.g. `stride_wm` is how much to increase `w_ptr`
    # by to get the element one row down (W has M rows).
    stride_xm,
    stride_wm,
    stride_wn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """
    Kernel for computing the scaled ternary vector-matrix product
        z = (x W) / scale
    `x` is a vector of length `M`, `W` has shape `(M, N)` and `z` is a vector of length `N`.
    Only the sign of each element of `W` is used (positive -> +x, negative -> -x, zero -> 0).
    Each program computes one block of `BLOCK_SIZE_N` consecutive elements of `z`.
    """

    # -----------------------------------------------------------
    # Map `pid` to the block of `z` that it should compute.
    pid = tl.program_id(axis=0)

    # ----------------------------------------------------------
    # Create pointers for the first blocks of `x` and `W`.
    # We will advance these pointers as we move in the `M` direction and accumulate.
    # - `x_ptrs` is a block of `BLOCK_SIZE_M` pointers
    # - `w_ptrs` is a block of pointers with shape `(BLOCK_SIZE_M, BLOCK_SIZE_N)`
    offs_m = tl.arange(0, BLOCK_SIZE_M)
    # Column indices past `N` wrap back into [0, N) so every `W` load stays in bounds;
    # the corresponding outputs are masked off when storing.
    offs_n = (pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    x_ptrs = x_ptr + offs_m * stride_xm
    w_ptrs = w_ptr + (offs_m[:, None] * stride_wm + offs_n[None, :] * stride_wn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the `z` vector.
    # We accumulate into a block of `BLOCK_SIZE_N` fp32 values for higher accuracy.
    accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)
    for m in range(0, tl.cdiv(M, BLOCK_SIZE_M)):
        # Rows of this block at or beyond `M` (only possible in the last block) are masked
        # out and read as 0, so they contribute nothing to the sum.
        mask_m = offs_m < M - m * BLOCK_SIZE_M
        # Upcast `x` to fp32 so the per-block sum is done in fp32 whatever the input dtype;
        # `[:, None]` makes it a column that broadcasts against the `(BLOCK_SIZE_M, BLOCK_SIZE_N)` `W` block.
        x = tl.load(x_ptrs, mask=mask_m, other=0.0).to(tl.float32)[:, None]
        w = tl.load(w_ptrs, mask=mask_m[:, None], other=0.0)

        # Since `w` is ternary, we only really care about the sign of the element in the array, and so
        # we just need to perform two conditional checks
        elements_to_sum = tl.where(w > 0, x, tl.where(w < 0, -x, tl.zeros_like(x)))
        accumulator = accumulator + tl.sum(elements_to_sum, axis=0)  # Sum along the `M` direction

        # Advance the ptrs to the next `M` block.
        x_ptrs += BLOCK_SIZE_M * stride_xm
        w_ptrs += BLOCK_SIZE_M * stride_wm

    accumulator = accumulator / scale
    # Convert to whatever element type `z_ptr` points to.
    z = accumulator.to(z_ptr.dtype.element_ty)

    # -----------------------------------------------------------
    # Write back the block of the output vector `z` with masks.
    offs_z = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    z_ptrs = z_ptr + offs_z
    z_mask = offs_z < N
    tl.store(z_ptrs, z, mask=z_mask)



We can now create a convenience wrapper function that takes the input vector, the ternary weight matrix and the scale, and (1) checks the shape, layout and device constraints; (2) allocates the output; (3) launches the above kernel.

In [27]:
# ruff: noqa: E731
def ternary_mul(x, w, scale, out_dtype=torch.float16):
    """
    Compute `z = (x @ w) / scale` with `ternary_mul_kernel`.

    Args:
        x: 1-D, contiguous CUDA tensor of length `M`.
        w: 2-D CUDA tensor of shape `(M, N)`; only the sign of each element is used.
        scale: Scalar the product is divided by.
        out_dtype: dtype of the returned tensor (defaults to fp16).

    Returns:
        A new 1-D tensor of length `N` on `x.device` with dtype `out_dtype`.

    Raises:
        AssertionError: if the shapes, layout or devices of `x` and `w` are unsupported.
    """
    # Check constraints. `x` must be a true 1-D vector: the kernel addresses it with `x.stride(0)` only.
    assert x.dim() == 1, "x must be a 1-D vector"
    assert w.dim() == 2, "w must be a 2-D matrix"
    assert x.shape[0] == w.shape[0], "Incompatible dimensions"
    assert x.is_contiguous(), "x must be contiguous"
    assert x.is_cuda and w.is_cuda, "x and w must be CUDA tensors"
    assert x.device == w.device, "x and w must be on the same device"

    # Get dimensions
    M, N = w.shape

    # Allocate output
    z = torch.empty((N,), device=x.device, dtype=out_dtype)

    # 1D launch grid: one program per block of `BLOCK_SIZE_N` outputs.
    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK_SIZE_N"]),)

    ternary_mul_kernel[grid](
        x, w, z,  #
        scale,  #
        M, N,  #
        x.stride(0),  #
        w.stride(0), w.stride(1)
    )
    return z

## Testing

In [28]:
# Length of the input vector `x`, i.e. the number of rows `M` of `W`.
X_LEN = 8
# Number of columns `N` of the ternary (quantized) weight matrix `W`.
W_LEN = 8
W_SIZE = X_LEN, W_LEN  # shape `(M, N)` of `W`

In [29]:
# Seed PyTorch's RNG for reproducible random inputs; the trailing `;` hides the Generator repr.
torch.manual_seed(0);

<torch._C.Generator at 0x7f3c2968fdb0>

In [30]:
# Set to True to test on random inputs of shape `W_SIZE` instead of the small hand-checkable example.
USE_RANDOM_INPUTS = False

if USE_RANDOM_INPUTS:
    s = torch.rand(X_LEN, device="cuda")
    # Draw indices from [0, 3) so that all three ternary values -1, 0 and +1 can occur.
    w = torch.tensor([-1.0, 0.0, 1.0], device="cuda")[torch.randint(3, W_SIZE, device="cuda")]
else:
    # Hand-checkable example: s @ w == [1, -2, 10, -4].
    s = torch.tensor([1., 2, 4, 8], device="cuda")
    w = torch.tensor([
        [ 1.,  0.,  0.,  0.],
        [ 0.,  1.,  1.,  0.],
        [ 0., -1.,  0.,  1.],
        [ 0.,  0.,  1., -1.]
    ], device="cuda")
scale = 1.

In [31]:
# Reference result computed with plain PyTorch.
torch_output = (s @ w) / scale

In [32]:
torch_output  # PyTorch reference result, shown for comparison with the Triton output

tensor([ 1., -2., 10., -4.], device='cuda:0')

In [33]:
# Run the Triton ternary multiplication on the same inputs.
triton_output = ternary_mul(x=s, w=w, scale=scale)

In [34]:
triton_output  # Triton kernel result (returned in fp16 by `ternary_mul`)

tensor([ 1., -2., 10., -4.], device='cuda:0', dtype=torch.float16)

In [18]:
# Verify the Triton result against the PyTorch reference; tolerances allow for fp16 output rounding.
torch.testing.assert_close(triton_output.to(torch_output.dtype), torch_output, rtol=1e-3, atol=1e-3)