# Ternary Multiplication in Triton

## Setup

Only need to run the first time. Works with latest triton. Sorry, this takes a minute to install.

In [1]:
!pip install jaxtyping~=0.2.31
!pip install git+https://github.com/Deep-Learning-Profiling-Tools/triton-viz@v1
!pip install --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==3.0.0.post20240626041721

Collecting jaxtyping~=0.2.31
  Downloading jaxtyping-0.2.31-py3-none-any.whl (41 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting typeguard==2.13.3 (from jaxtyping~=0.2.31)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, jaxtyping
Successfully installed jaxtyping-0.2.31 typeguard-2.13.3
Collecting git+https://github.com/Deep-Learning-Profiling-Tools/triton-viz@v1
  Cloning https://github.com/Deep-Learning-Profiling-Tools/triton-viz (to revision v1) to /tmp/pip-req-build-q8m3xwy1
  Running command git clone --filter=blob:none --quiet https://github.com/Deep-Learning-Profiling-Tools/triton-viz /tmp/pip-req-build-q8m3xwy1
  Running command git checkout -q 1772b6ead27a3218c9c2c9ad88bd4e94623fb74c
  Resolved https://github.com/Deep-Learnin

Check the installed triton version.

In [2]:
import triton
assert triton.__version__ == "3.0.0"

Import other needed stuff.

In [3]:
import torch
import triton.language as tl

## Helper Functions

In [4]:
def get_current_target():
    return triton.runtime.driver.active.get_current_target()

In [5]:
import warnings


def is_cuda():
    current_target = get_current_target()
    if current_target.backend != "cuda":
        return False

    if current_target.arch < 70:  # CUDA compute capacity is below 7.0, which is minimum 'stable' supported
        warnings.warn("Compute capcity of CUDA device is below 7.0. The Triton compilation may fail terribly!")

    return True

In [6]:
# def is_hip_mi200():
#     target = triton.runtime.driver.active.get_current_target()
#     return target.backend == "hip" and target.arch == "gfx90a"

In [7]:
# is_hip_mi200()

In [8]:
def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32}, num_stages=5, num_warps=2),
        # Good config for fp8 inputs.
        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=4)
    ]

In [9]:
def get_autotune_config():
    if is_cuda():
        return get_cuda_autotune_config()
    else:
        raise ValueError("Not on CUDA... can't use!")

In [10]:
@triton.autotune(
    configs=get_autotune_config(),
    key=['M'],
)
@triton.jit
def ternary_mul_kernel(
        # Pointers to matrices
        x_ptr, w_ptr, z_ptr,
        # Scaling factor
        scale,
        # W matrix dimensions
        M, N,  # So W has shape (M, N)
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr
):
    """
    Kernel for computing the ternary multiplication
        z = xW
    x has shape (1, M), W has shape (M, N), and z has shape (1, N)
    """

    # ----------------------------------------------------------------------------------------------
    # Map program IDs (`pid`) to the element in z that it is supposed to compute, and load the
    # appropriate parts of the `x` vector and `w` matrix.
    pid = tl.program_id(axis=0)

    block_start_x = pid * BLOCK_SIZE_M
    range_x = tl.arange(0, BLOCK_SIZE_M)
    offsets_x = block_start_x + range_x
    mask_x = offsets_x < M
    x = tl.load(x_ptr + offsets_x, mask=mask_x)

    block_start_w = pid  # This is the column number
    range_w = tl.arange(0, BLOCK_SIZE_M) * N  # Columnwise offsets
    offsets_w = block_start_w + range_w
    mask_w = offsets_w < M * N  # TODO: Check
    w = tl.load(w_ptr + offsets_w, mask=mask_w)

    # ----------------------------------------------------------------------------------------------
    # Iterate to compute the element in the `z` vector.
    # Since `w` is ternary, we only really care about the sign of the element in the array, and so
    # we just need to perform two conditional checks
    elems_to_sum = tl.where(w > 0, x, tl.where(w < 0, -x, tl.zeros_like(x)))
    print("elems", elems_to_sum)
    print("w", w)
    print("x", x)
    print("a", scale)
    total = tl.sum(elems_to_sum)
    total = total / scale  # Need to apply scale right after

    # ----------------------------------------------------------------------------------------------
    # Write the single element to the `z` output vector
    print("pid", pid)
    tl.store(z_ptr + pid, total)  # TODO: We shouldn't need a mask since the PID is OK... right?


We can now create a convenience wrapper function that only takes two input tensors, and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.

In [11]:
def ternary_mul(x, w, scale):
    # Check constraints.
    assert len(x) == w.shape[0], "Incompatible dimensions"
    assert x.is_contiguous(), "x must be contiguous"

    assert x.is_cuda and w.is_cuda

    # Get dimensions
    M, N = w.shape

    # Allocate output
    z = torch.empty((N,), device=x.device, dtype=torch.float16)  # TODO: Change precision?

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), )
    ternary_mul_kernel[grid](
        x, w, z, #
        scale, #
        M, N #
    )
    return z

TESTING CODE

In [12]:
X_LEN = 8  # x is the 1D vector
W_LEN = 8  # W is the quantized weights matrix
W_SIZE = (X_LEN, W_LEN)

In [13]:
torch.manual_seed(0)

<torch._C.Generator at 0x7d2ca805a470>

In [24]:
# x = torch.rand(X_LEN, device="cuda")
# w = torch.tensor([-1., 0., 1.], device="cuda")[torch.randint(2, W_SIZE)]
x = torch.tensor([1., 2, 4, 8], device="cuda")
w = torch.tensor([
    [1., 0., 0., 0.],
    [0, 1., 1., 0],
    [0, 1., 0., 1.],
    [0, 0., 1., 1.]
], device="cuda")
scale = 1.

In [25]:
torch_output = torch.matmul(x, w) / scale

In [26]:
torch_output

tensor([ 1.,  6., 10., 12.], device='cuda:0')

In [29]:
triton_output = ternary_mul(x, w, scale)

In [30]:
triton_output

tensor([1.0000, 1.8750, 0.0000, 1.9844], device='cuda:0', dtype=torch.float16)