In [1]:
import torch

# Fix the RNG seed so the random Q/K/V tensors below are reproducible.
torch.manual_seed(456)

# Problem sizes: K and V have `nrows` rows of width BLOCK_DHEAD and are
# processed BLOCK_N rows at a time.
nrows = 16
BLOCK_DHEAD = 8
BLOCK_N = 2
n_segments = 4  # number of equally sized row segments K/V are split into

# The rows must split evenly into segments ...
assert nrows % n_segments == 0
nrows_per_segment = nrows // n_segments
# ... and every segment must split evenly into blocks of BLOCK_N rows.
assert nrows_per_segment % BLOCK_N == 0
# NOTE: despite its name, this is the number of BLOCK_N-row blocks per segment.
nrows_per_block = nrows_per_segment // BLOCK_N

# One query row attending over `nrows` key/value rows; the output starts as zeros.
Q_mat = torch.rand(1, BLOCK_DHEAD)
K_mat = torch.rand(nrows, BLOCK_DHEAD)
V_mat = torch.rand(nrows, BLOCK_DHEAD)
O_mat = Q_mat.new_zeros(Q_mat.shape)

In [2]:
# GM variables: one slot per (segment, block-within-segment) pair.
d_j_split = torch.zeros((n_segments, nrows_per_block, Q_mat.shape[0]))  # softmax denominators
l_j_split = torch.zeros_like(d_j_split)  # row maxima of the logits
acc_splitted = torch.zeros((n_segments, nrows_per_block, *O_mat.shape))  # unnormalized outputs

q = Q_mat  # load from GM
# Both loops are parallelized: they are implicit in the Triton implementation.
for segment_index in range(n_segments):
    for block_index_N in range(nrows_per_block):
        # Rows of K/V covered by this block of this segment.
        block_start_N = segment_index * nrows_per_segment + block_index_N * BLOCK_N
        block_end_N = block_start_N + BLOCK_N

        k = K_mat[block_start_N:block_end_N]  # load from GM
        v = V_mat[block_start_N:block_end_N]  # load from GM

        qk = torch.matmul(q, k.T)
        l_j = qk.max(dim=1).values
        # Safe softmax numerator: the block maximum is subtracted before exponentiating.
        numerators = (qk - l_j.unsqueeze(1)).exp()
        d_j = numerators.sum(dim=1)
        o_segment = torch.matmul(numerators, v)

        # saving to GM
        l_j_split[segment_index, block_index_N] = l_j
        d_j_split[segment_index, block_index_N] = d_j
        acc_splitted[segment_index, block_index_N] = o_segment


In [3]:
# Shared memory variables: the running state of the reduction.
l_i = torch.full((Q_mat.shape[0],), float("-inf"))  # running maximum of the logits
d_i = torch.zeros((Q_mat.shape[0],))  # running softmax denominator
acc = torch.zeros_like(O_mat)  # running output, normalized by d_i

# Fold every per-block partial result into the running state.
for block_index_N in range(nrows_per_block):
    for segment_index in range(n_segments):
        l_j = l_j_split[segment_index, block_index_N]
        d_j = d_j_split[segment_index, block_index_N]

        # Re-express the running and the incoming denominators relative to the new maximum.
        l_new = torch.maximum(l_i, l_j)
        alpha = torch.exp(l_i - l_new)
        beta = torch.exp(l_j - l_new)
        d_new = alpha * d_i + beta * d_j

        # Indexing acc_splitted returns a view, so the scaling is done out-of-place:
        # an in-place `*=` would rescale the saved partial results and make this
        # cell give a different answer every time it is re-run.
        p_scale = beta / d_new
        acc_i = acc_splitted[segment_index, block_index_N] * p_scale[:, None]

        # `acc` is normalized by the old denominator d_i; move it to the new one.
        acc_scale = d_i / d_new * alpha
        acc = acc * acc_scale[:, None] + acc_i  # accumulating in shared memory
        d_i = d_new
        l_i = l_new

O_mat = acc  # write to GM

# Reference result: plain softmax attention over all rows at once.
expected = torch.nn.functional.softmax(Q_mat @ K_mat.T, dim=1) @ V_mat
print(O_mat)
print(expected)
assert torch.allclose(O_mat, expected)

tensor([[0.4563, 0.6146, 0.5144, 0.3769, 0.5411, 0.6278, 0.4206, 0.4573]])
tensor([[0.4563, 0.6146, 0.5144, 0.3769, 0.5411, 0.6278, 0.4206, 0.4573]])
