# Simple Ring Reduce

Here is the very simplest ring reduce. Our goal is to sum up $n$ vectors in $\mathbb{R}^n$. Each worker $j$ holds one column `xss[:, j]` of the input matrix, and we're summing the columns together, so component $c$ of the result is the sum of row $c$. We'll end up with the correct sum for component $c$ held by worker $(c + 2) \bmod n$, i.e. on the 2nd superdiagonal of the matrix, wrapping around cyclically. (It just works out that way; see the picture in the slides.)

This is meant to correspond to the illustration in [these slides](https://dlsys.cs.washington.edu/pdf/lecture11.pdf).
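A short way to see why the sums land there (my own reading of the indices in `ring_reduce`, not taken from the slides): fix a component $c$, and write $x_{c,w}$ for worker $w$'s original entry in that component. At stage $i$, the only worker that touches component $c$ is $j = (c - i) \bmod n$, and its donor $(c - i + 1) \bmod n$ is exactly the worker that touched component $c$ at stage $i - 1$. (Before stage 0, the donor $(c + 1) \bmod n$ holds just its own entry, the $t = -1$ term below.) By induction,

$$
\texttt{xss}[c][(c - i) \bmod n] = \sum_{t=-1}^{i} x_{c,\,(c - t) \bmod n} \qquad \text{after stage } i.
$$

The last stage is $i = n - 2$, where the sum has $n$ terms (one from every worker) and sits in column $(c - n + 2) \bmod n = (c + 2) \bmod n$.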

```python
import numpy as np

def ring_reduce(xss):
    """Simulate a ring reduce in place.

    Worker j holds column xss[:, j]. After n - 1 stages, the full sum of
    component (row) c ends up in xss[c][(c + 2) % n].
    """
    n = len(xss)
    # Assume the number of workers is equal to the number of components
    assert n == len(xss[0])
    for i in range(n - 1):
        # Looping over j sequentially is safe: within stage i, no worker
        # reads an entry that another worker writes in that same stage.
        for j in range(n):
            # At time stage i, worker j:
            #
            # - sends data to worker (j - 1) % n in component
            #   (j - 1 + i) % n, and receives data from worker
            #   (j + 1) % n in component (j + i) % n
            donor = (j + 1) % n
            component = (j + i) % n
            received_data = xss[component][donor]

            # - reduces (adds) the received data into the component it
            #   just received, (j + i) % n
            xss[component][j] += received_data

def create_matrix(n):
    return np.random.randint(1, 10, size=(n, n))

n = 5
original_matrix = create_matrix(n)
print("Original Matrix:")
print(original_matrix)
matrix_to_modify = original_matrix.copy()
ring_reduce(matrix_to_modify)

print("\nReduced Matrix:")
print(matrix_to_modify)

# The correct results end up on the second superdiagonal (wrapping
# around): component i is held by worker (i + 2) % n
for i in range(n):
    assert sum(original_matrix[i]) == matrix_to_modify[i][(i + 2) % n]
```

```
Original Matrix:
[[9 6 4 7 6]
 [3 5 8 2 9]
 [9 9 4 5 8]
 [9 6 9 1 8]
 [3 8 1 2 5]]

Reduced Matrix:
[[15  6 32 28 21]
 [16 13  8 27 25]
 [27 18  9  5 35]
 [33 24 18  9  8]
 [ 3 19 11 10  8]]
```
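Because `ring_reduce` updates workers one at a time inside a stage, it's fair to ask whether the result depends on that order. Here is a minimal sketch (the helpers `ring_reduce_simultaneous` and `ring_reduce_sequential` are hypothetical names of mine) that models every worker sending at the same moment: it takes a snapshot of the matrix at the start of each stage and reads all received data from the snapshot. It then checks, for a handful of sizes, that this agrees with the one-at-a-time loop and that component `c` lands in column `(c + 2) % n`.

```python
import numpy as np

def ring_reduce_simultaneous(xss):
    """Ring reduce where all sends in a stage happen at once (hypothetical helper)."""
    xss = np.array(xss, copy=True)
    n = xss.shape[0]
    for i in range(n - 1):
        # Messages in flight: every worker reads from the state at the
        # start of the stage, not from entries updated earlier in the stage.
        snapshot = xss.copy()
        for j in range(n):
            donor = (j + 1) % n
            component = (j + i) % n
            xss[component, j] += snapshot[component, donor]
    return xss

def ring_reduce_sequential(xss):
    """Same indices as ring_reduce above, updating workers one at a time."""
    xss = np.array(xss, copy=True)
    n = xss.shape[0]
    for i in range(n - 1):
        for j in range(n):
            xss[(j + i) % n, j] += xss[(j + i) % n, (j + 1) % n]
    return xss

rng = np.random.default_rng(0)
for n in range(2, 9):
    x = rng.integers(1, 10, size=(n, n))
    simultaneous = ring_reduce_simultaneous(x)
    sequential = ring_reduce_sequential(x)
    assert np.array_equal(simultaneous, sequential)
    for c in range(n):
        # Full sum of component (row) c sits with worker (c + 2) % n
        assert simultaneous[c, (c + 2) % n] == x[c].sum()
```

Copying the input in both helpers keeps the caller's matrix untouched, which makes the side-by-side comparison straightforward.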


# Extending to Attention

For reference, let's look at the top of page 4 of [FlashAttention2](https://arxiv.org/pdf/2307.08691.pdf), where they describe an iterative approach that loads blocks $S^{(i)}$ and $V^{(i)}$ from high-bandwidth memory (HBM) one at a time and incorporates each one into a running approximation of the output vectors $\mathbf{O}$. Once the algorithm has loaded and incorporated all the blocks, the output vectors are correct. (Note, however, that I think there are two typos in the algorithm as printed; I have pasted them into the ring-attention channel.)
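As a concrete picture of that loop, here is a minimal NumPy sketch of the blockwise online softmax for a single query vector `q`, assuming the keys and values are split into row blocks of `block_size`. The names (`online_softmax_attention`, `reference_attention`, `block_size`) are mine, not the paper's, and the softmax scale is left out (think of it as folded into `q`). Each block plays the role of $S^{(i)}$, $V^{(i)}$; we keep a running max `m`, a running normalizer `l`, and an output `o` that is always normalized by the current `l`, which is the same convention `update` uses below.

```python
import numpy as np

def online_softmax_attention(q, k, v, block_size):
    """Attention output for one query, visiting (k, v) one block at a time."""
    m = -np.inf                 # running max of the scores seen so far
    l = 0.0                     # running sum of exp(score - m)
    o = np.zeros(v.shape[1])    # running output, already divided by l
    for start in range(0, k.shape[0], block_size):
        s_blk = k[start:start + block_size] @ q     # this block's scores
        v_blk = v[start:start + block_size]
        m_new = max(m, s_blk.max())
        p = np.exp(s_blk - m_new)                   # unnormalized weights
        l_new = np.exp(m - m_new) * l + p.sum()
        # Rescale the old (normalized) output, then add this block's share.
        o = (l * np.exp(m - m_new) * o + p @ v_blk) / l_new
        m, l = m_new, l_new
    return o

def reference_attention(q, k, v):
    s = k @ q
    w = np.exp(s - s.max())
    return (w / w.sum()) @ v

rng = np.random.default_rng(0)
q = rng.standard_normal(3)
k = rng.standard_normal((10, 3))
v = rng.standard_normal((10, 4))
for block_size in (1, 3, 10):
    assert np.allclose(online_softmax_attention(q, k, v, block_size),
                       reference_attention(q, k, v))
```

The block size only changes how many scores are folded in per step; the final output is the same softmax-weighted average either way.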

In ring attention, we imagine this same stagewise computation taking place separately on each worker. Now, instead of loading blocks $S^{(i)}$ and $V^{(i)}$ from HBM at each stage, each worker receives a block of $\mathbf{k}$'s and $\mathbf{v}$'s from its neighbor. The $\mathbf{v}$'s correspond exactly to the block $V^{(i)}$ loaded from high-bandwidth memory in FlashAttention. The $\mathbf{k}$'s are not the same as the block $S^{(i)}$, however; rather, they're some of the ingredients that you, the worker, need in order to calculate $S^{(i)}$. The other ingredients are the $\mathbf{q}$ vectors that you, the worker, are holding: each entry of $S^{(i)}$ is a scaled dot product of one of your $\mathbf{q}$'s with one of the received $\mathbf{k}$'s. (Throughout the whole algorithm, each worker keeps the same $\mathbf{q}$'s.)

By this process, each worker gradually builds up its output vectors $\mathbf{O}$.
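Here is a minimal simulated sketch of that process with blocks larger than 1, written in NumPy and under some assumptions of mine: there are `num_workers` workers, the rows of `q`, `k`, `v` split evenly among them, worker `w` starts out holding its own key/value block `w`, and after each stage every worker passes its block to worker `(w - 1) % W`. The function names are hypothetical. It checks the result against ordinary softmax attention.

```python
import numpy as np

def softmax_attention(q, k, v, scale):
    s = (q @ k.T) * scale
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

def blocked_ring_attention(q, k, v, scale, num_workers):
    """Simulated ring attention: worker w keeps query block w; (k, v) blocks rotate."""
    q_blks = np.split(q, num_workers)   # assumes the rows divide evenly
    kv = list(zip(np.split(k, num_workers), np.split(v, num_workers)))
    m = [np.full(qb.shape[0], -np.inf) for qb in q_blks]   # running maxes
    l = [np.zeros(qb.shape[0]) for qb in q_blks]           # running normalizers
    o = [np.zeros((qb.shape[0], v.shape[1])) for qb in q_blks]
    for stage in range(num_workers):
        for w in range(num_workers):
            k_blk, v_blk = kv[w]                    # block worker w holds now
            s = (q_blks[w] @ k_blk.T) * scale       # this worker's S^(i)
            m_new = np.maximum(m[w], s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            l_new = np.exp(m[w] - m_new) * l[w] + p.sum(axis=1)
            rescale = (l[w] * np.exp(m[w] - m_new))[:, None]
            o[w] = (rescale * o[w] + p @ v_blk) / l_new[:, None]
            m[w], l[w] = m_new, l_new
        if stage < num_workers - 1:
            # Every worker sends its (k, v) block to worker (w - 1) % W,
            # so worker w now holds what worker (w + 1) % W held.
            kv = kv[1:] + kv[:1]
    return np.concatenate(o)

rng = np.random.default_rng(0)
n, d, d_v, num_workers = 8, 3, 4, 4   # 4 workers, 2 rows each
q = rng.standard_normal((n, d))
k = rng.standard_normal((n, d))
v = rng.standard_normal((n, d_v))
scale = d ** -0.5
out = blocked_ring_attention(q, k, v, scale, num_workers)
assert np.allclose(out, softmax_attention(q, k, v, scale))
```

With `num_workers == n` every block has one row, which is the block-size-1 case below, except that here worker `w` starts with block `w` rather than block `(w - 1) % n`. Since the online softmax ends at the same answer whatever order the blocks arrive in, that starting offset shouldn't change the output beyond floating-point rounding.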

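Let's look at a simplified version where every block size is 1, so with $n$ workers, each worker holds one $\mathbf{q}$ and, at each stage, one $\mathbf{k}$ and one $\mathbf{v}$. In many ways, it should remind you of the simple ring reduce above: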

```python
import torch

def update(m, l, O, s, v):
    """Fold one score s and value v into the running softmax state.

    m: running max of the scores, l: running sum of exp(score - m),
    O: running output, normalized by l.
    """
    m_new = torch.max(m, s)
    l_new = torch.exp(m - m_new) * l + torch.exp(s - m_new)
    p = torch.exp(s - m_new) / l_new
    # Rescale the old output from normalizer l (at max m) to l_new (at
    # max m_new). Writing l / l_new avoids dividing by l, which is 0 on
    # the first update.
    O_new = (l / l_new) * torch.exp(m - m_new) * O + p * v
    return m_new, l_new, O_new

def naive_attn(q, k, v, scale):
    s = q @ k.mT * scale
    a = torch.softmax(s, dim=-1)
    return a @ v

def naive_ring_attn(q, k, v, scale):
    n = q.size(0)  # one worker per query (block size 1)
    O = torch.zeros_like(v)
    m = torch.full(size=(n,), fill_value=torch.finfo(q.dtype).min, dtype=q.dtype)
    l = torch.zeros_like(m)
    for i in range(n):
        for j in range(n):
            #
            # At time stage i, worker j:
            #
            # - holds the k and v vectors with index (i + j - 1) % n
            #   (at stage 0, index (j - 1) % n)
            #
            # - computes the dot product of its q (i.e. q_j) with the
            #   k it's currently holding, and folds the result into
            #   its running (m, l, O)
            k_index = (i + j - 1) % n
            s = torch.dot(q[j], k[k_index]) * scale
            m[j], l[j], O[j] = update(m[j], l[j], O[j], s, v[k_index])

            # - if the time stage is not already (n - 1), we imagine
            #   sending the k and v we're currently holding to our
            #   neighbor in the ring, worker (j - 1) % n, and
            #   receiving a new k and v from our other neighbor,
            #   worker (j + 1) % n
    return O

def naive_attn_incremental(q, k, v, scale):
    # Same online-softmax update, one query at a time and with no ring:
    # query i visits keys 0, 1, ..., n - 1 in order.
    n = q.size(0)
    O = torch.zeros_like(v)
    for i in range(n):
        m = torch.tensor(torch.finfo(q.dtype).min, dtype=q.dtype)
        l = torch.tensor(0.0, dtype=q.dtype)
        for j in range(n):
            s = torch.dot(q[i], k[j]) * scale
            m, l, O[i] = update(m, l, O[i], s, v[j])
    return O


n = 5
d = 3
d_v = 4

k = torch.randn(n, d)
q = torch.randn(n, d)
v = torch.randn(n, d_v)
scale = d**-0.5

o1 = naive_attn(q.clone(), k.clone(), v.clone(), scale)
o2 = naive_ring_attn(q.clone(), k.clone(), v.clone(), scale)
o3 = naive_attn_incremental(q.clone(), k.clone(), v.clone(), scale)
print("naive:", o1)
print("ring:", o2)
print("incremental:", o3)
print("delta:", torch.abs(o1 - o2) > 1e-6)

print("delta naive_attn vs naive_ring_attn:", torch.abs(o1 - o2).sum())
print("delta naive_attn vs naive_attn_incremental:", torch.abs(o1 - o3).sum())
```


```
naive: tensor([[-1.0470,  0.4340,  0.2762,  0.6966],
        [-0.0038,  0.0370,  0.6868,  0.1482],
        [-0.3025,  0.1263,  0.7148,  1.4139],
        [-0.7548,  0.1541,  0.5451,  1.1918],
        [-1.2017,  0.6329,  0.1033,  0.4751]])
ring: tensor([[-1.0470,  0.4340,  0.2762,  0.6966],
        [-0.0038,  0.0370,  0.6868,  0.1482],
        [-0.3025,  0.1263,  0.7148,  1.4139],
        [-0.7548,  0.1541,  0.5451,  1.1918],
        [-1.2017,  0.6329,  0.1033,  0.4751]])
incremental: tensor([[-1.0470,  0.4340,  0.2762,  0.6966],
        [-0.0038,  0.0370,  0.6868,  0.1482],
        [-0.3025,  0.1263,  0.7148,  1.4139],
        [-0.7548,  0.1541,  0.5451,  1.1918],
        [-1.2017,  0.6329,  0.1033,  0.4751]])
delta: tensor([[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])
delta naive_attn vs naive_ring_attn: tensor(1.0058e-06)
delta naive_attn vs naive_a
```

Both `naive_ring_attn` and `naive_attn_incremental` produce results close to our reference `naive_attn` result.
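For reference, here is the recurrence that `update` implements, written for a single query whose scaled scores $s_1, s_2, \dots$ and values $\mathbf{v}_1, \mathbf{v}_2, \dots$ arrive one at a time. The notation $m_t, \ell_t, \mathbf{O}_t$ is mine; the starting state is $m_0$ effectively $-\infty$ (the code uses the most negative finite float), $\ell_0 = 0$, $\mathbf{O}_0 = \mathbf{0}$:

$$
\begin{aligned}
m_t &= \max(m_{t-1}, s_t), \\
\ell_t &= e^{m_{t-1} - m_t}\,\ell_{t-1} + e^{s_t - m_t}, \\
\mathbf{O}_t &= \frac{\ell_{t-1}}{\ell_t}\, e^{m_{t-1} - m_t}\,\mathbf{O}_{t-1} + \frac{e^{s_t - m_t}}{\ell_t}\,\mathbf{v}_t .
\end{aligned}
$$

By induction, $\ell_t = \sum_{u \le t} e^{s_u - m_t}$ and $\mathbf{O}_t = \sum_{u \le t} \frac{e^{s_u - m_t}}{\ell_t}\,\mathbf{v}_u$, which is exactly softmax attention over the first $t$ scores. So once all $n$ scores have been folded in, in any order, $\mathbf{O}_n$ equals the full softmax output. That is why the ring version can visit the keys in a different order from `naive_attn_incremental` and still agree with `naive_attn` up to floating-point rounding.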