# Triton Atomic Add

Below is a test of Triton's atomic add functionality. We'll compare two different approaches for atomic addition:
1. `tl.atomic_add` - Atomic addition
2. `tl.atomic_cas` - Atomic compare and swap guard

First we start with a non-atomic operation. We will produce a vector of values and `tl.store` them to the same address.

In [1]:
%xmode Minimal
import triton
import triton.language as tl
import torch
from torch.testing import assert_close


@triton.jit
def kernel_non_atomic(x_p, o_p, BLOCK: tl.constexpr):
    # Load x
    ptrs = tl.arange(0, BLOCK)
    x = tl.load(x_p + ptrs)
    # Store all values of vector x into single scalar o
    ptrs = tl.zeros_like(ptrs)
    tl.store(o_p + ptrs, x)


def non_atomic():
    torch.random.manual_seed(0)
    BLOCK = 32
    x = torch.randn(BLOCK, device="cuda")
    o = x.new_zeros(1)
    kernel_non_atomic[(1,)](x, o, BLOCK)
    exp = x.sum()
    assert_close(o.view(1), exp.view(1))


non_atomic()

Exception reporting mode: Minimal


AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 0.38307708501815796 at index (0,) (up to 1e-05 allowed)
Greatest relative difference: 0.2929307222366333 at index (0,) (up to 1.3e-06 allowed)

In [5]:
@triton.jit
def kernel(a_p, b_p, o_p, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    a = tl.load(a_p + tl.arange(0, BLOCK_M)[:, None] * BLOCK_K + tl.arange(0, BLOCK_K)[None, :])
    b = tl.load(b_p + tl.arange(0, BLOCK_N)[:, None] * BLOCK_K + tl.arange(0, BLOCK_K)[None, :])
    diff = a[:, None, :] - b[None, :, :]
    diff *= diff
    diff = tl.sum(diff, 2)
    tl.store(o_p + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :], diff)

def launch():
    M, N, K = 32, 32, 16
    a = torch.randn(M, K, device="cuda")
    b = torch.randn(N, K, device="cuda")
    o = a.new_zeros(M, N)
    kernel[(1,)](a, b, o, M, N, K)
    exp = torch.cdist(a, b)
    assert_close(o, exp)

launch()

CompilationError: at 2:69:def kernel(a_p, b_p, o_p, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    a = tl.load(a_p + tl.arange(0, BLOCK_M) * BLOCK_K + tl.arange(0, BLOCK_K))
                                                                     ^
ValueError('Cannot make_shape_compatible: incompatible dimensions at index 0: 32 and 16')

As expected, `tl.store` does not use atomic addition when storing a vector of values into a single address. Nor will it atomically add across parallel programs.

Now let's try atomic addition. We know this will atomically add across parallel programs, but how does it behave when we have one program writing to the same address?

In [3]:
@triton.jit
def kernel_atomic(x_p, o_p, BLOCK: tl.constexpr):
    # Load x
    ptrs = tl.arange(0, BLOCK)
    x = tl.load(x_p + ptrs)
    # Store all values of vector x into single scalar o
    ptrs = tl.zeros_like(ptrs)
    tl.atomic_add(o_p + ptrs, x)


def atomic_add():
    torch.random.manual_seed(0)
    BLOCK = 32
    x = torch.randn(BLOCK, device="cuda")
    o = x.new_zeros(1)
    kernel_atomic[(1,)](x, o, BLOCK)
    exp = x.sum()
    assert_close(o.view(1), exp.view(1))


atomic_add()

`tl.atomic_add` properly accumulates the values of the vector `x` into the scalar `o`. Meaning it is atomic not just across parallel calls, but also across repeated pointers within a single program.

Now let's compare the runtime of `tl.atomic_add` and `tl.atomic_cas`. For this example we'll assume that the program has no duplicate addresses within a single program, i.e. we only need atomicity across parallel programs.

In [3]:
@triton.jit
def kernel_add(x_p, o_p, BLOCK: tl.constexpr):
    # Load x
    ptrs = tl.arange(0, BLOCK)[:, None] * BLOCK + tl.arange(0, BLOCK)[None, :]
    x = tl.load(x_p + ptrs)
    # Store x into o
    tl.atomic_add(o_p + ptrs, x)


@triton.jit
def kernel_cas(x_p, o_p, lock_p, BLOCK: tl.constexpr):
    # Load x
    ptrs = tl.arange(0, BLOCK)[None, :] * BLOCK + tl.arange(0, BLOCK)[:, None]
    x = tl.load(x_p + ptrs)

    # Acquire lock
    while tl.atomic_cas(lock_p, 0, 1) != 0:
        pass
    x += tl.load(o_p + ptrs)
    tl.store(o_p + ptrs, x)
    tl.atomic_xchg(lock_p, 0)


def launch(method, dtype=torch.float32, check=True):
    torch.random.manual_seed(0)
    BLOCK = 64
    GRID = 2048
    x = torch.randn(BLOCK, BLOCK, device="cuda", dtype=dtype)
    o = torch.zeros_like(x)

    if method == "add":
        kernel = kernel_add[(GRID,)](x, o, BLOCK, num_warps=4)
    else:
        lock = torch.zeros(1, device="cuda", dtype=torch.int32)
        kernel = kernel_cas[(GRID,)](x, o, lock, BLOCK, num_warps=4)
    if check:
        exp = x * GRID
        assert_close(o, exp, rtol=1e-4, atol=0)
    return kernel

print("Testing atomic_add")
launch("add")
print("Testing atomic_cas")
launch("cas")

add_ms = triton.testing.do_bench(lambda: launch("add"))
cas_ms = triton.testing.do_bench(lambda: launch("cas"))
print(f"Atomic add took {add_ms} ms")
print(f"Atomic cas took {cas_ms} ms")

Testing atomic_add
Testing atomic_cas
Atomic add took 0.45505693554878235 ms
Atomic cas took 2.5705363750457764 ms


There are some limitations to `tl.atomic_add`. Notably it does not support `tl.bfloat16`. It appears that bfloat16 support only exists on SM_90 (Hopper) and above.

In [2]:
import triton
import triton.language as tl
import torch
from torch.testing import assert_close

@triton.jit
def kernel(
    x_p, y_p, o_p, lock_p,
    M, N, K,
    stride_x_b, stride_x_m, stride_x_k,
    stride_y_b, stride_y_n, stride_y_k,
    stride_o_b, stride_o_m, stride_o_n,
    stride_lock_b,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ATOMIC_ADD: tl.constexpr = True,
):
    b = tl.program_id(0)
    start_m = tl.program_id(1) * BLOCK_M

    x_p += b * stride_x_b 
    y_p += b * stride_y_b 
    o_p += b * stride_o_b
    lock_p += b * stride_lock_b

    X_ptr = tl.make_block_ptr(x_p, (M, K), (stride_x_m, stride_x_k), (start_m, 0), (BLOCK_M, BLOCK_K), (1, 0))
    x = tl.load(X_ptr)

    Y_ptr = tl.make_block_ptr(y_p, (N, K), (stride_y_n, stride_y_k), (0, 0), (BLOCK_N, BLOCK_K), (1, 0))
    for _ in range(0, N, BLOCK_N):
        y = tl.load(Y_ptr)
        o = tl.dot(x, tl.trans(y))
        offsets = tl.arange(0, BLOCK_M)[:, None] * stride_o_m + tl.arange(0, BLOCK_N)[None, :] * stride_o_n
        if ATOMIC_ADD:
            tl.atomic_add(o_p + offsets, o)
        else:
            while tl.atomic_cas(lock_p, 0, 1) != 0:
                pass
            o += tl.load(o_p + offsets)
            tl.store(o_p + offsets, o)
            tl.atomic_xchg(lock_p, 0)
            lock_p += 1
        o_p += BLOCK_N * stride_o_n


def launch(method, dtype=torch.float32):
    torch.random.manual_seed(0)
    B, M, N, K = 4, 512, 512, 64
    x = torch.randn(B, M, K, device="cuda", dtype=dtype)
    y = torch.randn(B, N, K, device="cuda", dtype=dtype)
    o = x.new_zeros(B, M, N)

    if method == "cas":
        lock = torch.zeros(B, M, device="cuda", dtype=torch.int32)
    else:
        lock = torch.empty(1, device="cuda", dtype=torch.int32)

    def grid(META):
        return (B, triton.cdiv(M, META["BLOCK_M"]))

    return kernel[grid](
        x, y, o, lock,
        M, N, K,
        x.stride(0), x.stride(1), x.stride(2),
        y.stride(0), y.stride(1), y.stride(2),
        o.stride(0), o.stride(1), o.stride(2),
        lock.stride(0),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=K, 
        ATOMIC_ADD=(method == "add"),
        num_warps=4,
    )

cas_ms = triton.testing.do_bench(lambda: launch("cas"))
add_ms = triton.testing.do_bench(lambda: launch("add"))
print(f"Atomic add took {add_ms} ms")
print(f"Atomic cas took {cas_ms} ms")