Dot-product attention (especially a specific use of it known as self-attention) is a major component of the transformer architecture. The very high-level idea of attention is usually easy to grasp, but I find that the deeper intuition behind it is a little hard to tease out of the conventional mathematical notation used to describe and implement it. I suspect this is because the notation follows the implementation, where matrix multiplication is used for obvious reasons of efficiency and compactness. However, I found that "unraveling" the matrix operations of dot-product attention into vector and scalar operations under a few for-loops helped me understand what is going on. Specifically, it was crucial for thinking more clearly about how the attention mechanism moves information between tokens. After that, it is easy to think of attention as one specific method of "token mixing", for which there could be many alternatives [[1]](https://arxiv.org/abs/2111.11418) (e.g. pooling, MLP-Mixer [[2]](https://arxiv.org/abs/2105.01601), etc.).
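
To make the "token mixing" framing a little more concrete, here is a minimal sketch (not taken from either cited paper) that writes mean pooling and attention in the same form: a matrix of mixing weights applied to the token embeddings. The toy sizes and the names `tokens`, `uniform_mix`, and `attention_mix` are assumptions chosen for illustration, and the attention weights here use raw dot products with no learned projections.

```python
import torch

torch.manual_seed(0)
tokens = torch.randn(4, 8)  # hypothetical toy input: 4 tokens, 8-D embeddings

# a mixing matrix holds one row of weights per output token; with uniform
# weights every token receives the plain average of all tokens (mean pooling)
uniform_mix = torch.full((4, 4), 1.0 / 4)
pooled = uniform_mix @ tokens  # shape = (4, 8)

# attention has the same form, except the weights are computed from the
# tokens themselves instead of being fixed in advance
attention_mix = torch.softmax(tokens @ tokens.T, dim=-1)
attended = attention_mix @ tokens  # shape = (4, 8)
```

In both cases each output row is a weighted sum of the input rows; the only difference is where the weights come from.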

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# imagine we have 10 tokens, each represented with a 256-D embedding
x = torch.randn(10, 256)

In [None]:
# project the sequence into query, key, and value matrices
q_proj = nn.Linear(256, 64)
k_proj = nn.Linear(256, 64)
v_proj = nn.Linear(256, 64)
# we could pack this into one linear layer and split the output, e.g. nn.Linear(256, 3 * 64), but we keep the projections separate for clarity

queries = q_proj(x)  # shape = (10, 64)
keys = k_proj(x)  # shape = (10, 64)
values = v_proj(x)  # shape = (10, 64)
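
For reference, the packed alternative mentioned in the comment above could look roughly like the sketch below. The names `x_demo`, `qkv_proj`, and the `*_packed` variables are hypothetical and chosen so they do not overwrite the variables used in the rest of the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_demo = torch.randn(10, 256)  # 10 tokens, 256-D embeddings

# hypothetical packed layer: a single matmul produces queries, keys, and values
qkv_proj = nn.Linear(256, 3 * 64)
qkv = qkv_proj(x_demo)  # shape = (10, 192)

# split the last dimension into three equal chunks of width 64
q_packed, k_packed, v_packed = qkv.chunk(3, dim=-1)  # each shape = (10, 64)
```

The result is equivalent in form to three separate layers; the packed version simply does the work in one larger matrix multiplication.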

In [None]:
# the vectorized dot product attention
att = queries @ keys.transpose(-2, -1)  # shape = (10, 10)
att = F.softmax(att, dim=-1)
out_vectorized = att @ values  # shape = (10, 64)
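
Note that the cell above uses raw dot products. The variant described in the original transformer paper ("scaled dot-product attention") additionally divides the scores by the square root of the key dimension before the softmax. A minimal sketch of that variant, using stand-in random tensors (`q_demo`, `k_demo`, `v_demo` are assumptions, not the notebook's projections):

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q_demo = torch.randn(10, 64)  # stand-in queries
k_demo = torch.randn(10, 64)  # stand-in keys
v_demo = torch.randn(10, 64)  # stand-in values

d_k = k_demo.size(-1)  # key dimension, 64 here
# divide the raw dot products by sqrt(d_k) before normalizing
scores = q_demo @ k_demo.transpose(-2, -1) / math.sqrt(d_k)  # shape = (10, 10)
weights = F.softmax(scores, dim=-1)
out_scaled = weights @ v_demo  # shape = (10, 64)
```

The rest of this notebook keeps the unscaled form, since the scaling does not change how information moves between tokens.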

In [None]:
# let's unravel this a bit in order to gain a better intuition of what is going on

# init an empty matrix to collect dot products
att = torch.zeros(10, 10)


for q_idx, q in enumerate(queries):
    # for the current token q we want to compute the dot product against every token (including itself)
    for k_idx, k in enumerate(keys):
        qk_similarity = torch.dot(q, k)
        att[q_idx, k_idx] = qk_similarity

# softmax so the sum of each row is 1, i.e. att[0].sum() == 1
att = F.softmax(att, dim=-1)
# each row corresponds to one token, and the values along that row tell us how much that token attends to every token

# init empty out matrix to collect weighted sums of token embeddings
out_looped = torch.zeros(10, 64)

for row, a in enumerate(att):
    # for each token we look at how much it attended to all the other tokens (including itself) and use those attention 
    # values to do a weighted sum of all the token embeddings, which then becomes the new embedding for the current token

    # init intermediate version of the values that we will weight with the current token's attention values
    weighted_values = torch.zeros_like(values)  # shape = (10, 64)
    for v_idx, v in enumerate(values):
        # apply the attention weight
        weighted_values[v_idx] = a[v_idx] * v

    # complete the weighted sum for the current token
    weighted_values_sum = weighted_values.sum(0)  # shape = (64,)

    # now we have a new embedding for the current token!
    out_looped[row] = weighted_values_sum


In [None]:
# let's check if our looped output matches our vectorized output
torch.allclose(out_vectorized, out_looped, atol=0.000001)  # up to some precision difference
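
A possible middle ground between the fully looped and fully vectorized versions is to keep only the loop over query tokens and let each inner loop collapse into a single vector operation. The sketch below uses stand-in random tensors (`q_demo`, `k_demo`, `v_demo` are assumed names) so that it runs on its own.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q_demo = torch.randn(10, 64)  # stand-in queries
k_demo = torch.randn(10, 64)  # stand-in keys
v_demo = torch.randn(10, 64)  # stand-in values

out_rowwise = torch.zeros(10, 64)
for q_idx, q in enumerate(q_demo):
    # dot product of the current query with every key at once
    scores = k_demo @ q  # shape = (10,)
    # normalize this token's scores into attention weights
    weights = F.softmax(scores, dim=-1)  # shape = (10,)
    # weighted sum of all value vectors becomes the token's new embedding
    out_rowwise[q_idx] = weights @ v_demo  # shape = (64,)

# fully vectorized reference for comparison
out_full = F.softmax(q_demo @ k_demo.transpose(-2, -1), dim=-1) @ v_demo
```

Each iteration builds one row of the attention matrix and immediately uses it, which makes it clear that a token's output depends only on its own query plus all keys and values.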

In most language modeling settings, we want each token to attend only to itself and the previous tokens in the sequence, but not to any token that comes after it.

For example, in the sequence "The quick brown fox ...", we want "brown" to attend to "brown", "quick" and "The", but not "fox".

We can achieve this in our loop solution by skipping the dot product when `k_idx > q_idx` and setting the attention value to `-inf` instead (the softmax will later convert this to a weight of 0.0).
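
As a quick sanity check of the `-inf` trick, here is a tiny standalone sketch with made-up scores showing that the softmax assigns zero weight to a masked position while the remaining weights still sum to 1:

```python
import torch
import torch.nn.functional as F

# made-up scores for one token; the last position is masked out
scores = torch.tensor([1.0, 2.0, float('-inf')])
weights = F.softmax(scores, dim=-1)

# exp(-inf) is 0, so the masked position gets no weight and the
# remaining two positions share all of the probability mass
masked_weight = weights[2].item()
total_weight = weights.sum().item()
```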

In [None]:
# init an empty matrix to collect dot products
att = torch.zeros(10, 10)


for q_idx, q in enumerate(queries):
    # for the current token q we want to compute the dot product against every token (in the form of the keys)
    for k_idx, k in enumerate(keys):
        # restrict attention to the current token and the previous tokens; disallow attention to future tokens
        if k_idx > q_idx:
            att[q_idx, k_idx] = float('-inf')
            continue
        qk_similarity = torch.dot(q, k)
        att[q_idx, k_idx] = qk_similarity

# softmax so the sum of each row is 1, i.e. att[0].sum() == 1
att = F.softmax(att, dim=-1)
# each row corresponds to one token, and the values along that row tell us how much that token attends to every token (0.0 for future tokens)

# init empty out matrix to collect weighted sums of token embeddings
out_looped = torch.zeros(10, 64)

for row, a in enumerate(att):
    # for each token we look at how much it attended to itself and the previous tokens (future tokens now have weight 0.0)
    # and use those attention values to do a weighted sum of the token embeddings, which then becomes the new embedding for the current token

    # init intermediate version of the values that we will weight with the current token's attention values
    weighted_values = torch.zeros_like(values)  # shape = (10, 64)
    for v_idx, v in enumerate(values):
        # apply the attention weight
        weighted_values[v_idx] = a[v_idx] * v

    # complete the weighted sum for the current token
    weighted_values_sum = weighted_values.sum(0)  # shape = (64,)

    # now we have a new embedding for the current token!
    out_looped[row] = weighted_values_sum

In [None]:
# and this is how we can achieve this in our vectorized solution
causal_mask = torch.tril(torch.ones(10, 10))
att = queries @ keys.transpose(-2, -1)
att = att.masked_fill(causal_mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
out_vectorized = att @ values

In [None]:
# let's check if our looped output matches our vectorized output
torch.allclose(out_vectorized, out_looped, atol=0.000001)  # up to some precision difference
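
One way to convince ourselves that the mask really blocks information from future tokens is to change the last token and confirm that the outputs of all earlier tokens stay the same. The sketch below assumes stand-in random tensors and a hypothetical helper, `causal_attention`, that wraps the vectorized cell above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q_demo = torch.randn(10, 64)  # stand-in queries
k_demo = torch.randn(10, 64)  # stand-in keys
v_demo = torch.randn(10, 64)  # stand-in values


def causal_attention(q, k, v):
    """Hypothetical helper: unscaled dot-product attention with a causal mask."""
    n = q.size(0)
    causal_mask = torch.tril(torch.ones(n, n))
    scores = q @ k.transpose(-2, -1)
    scores = scores.masked_fill(causal_mask == 0, float('-inf'))
    return F.softmax(scores, dim=-1) @ v


out_before = causal_attention(q_demo, k_demo, v_demo)

# overwrite the key and value of the last token only
k_changed = k_demo.clone()
v_changed = v_demo.clone()
k_changed[-1] = torch.randn(64)
v_changed[-1] = torch.randn(64)
out_after = causal_attention(q_demo, k_changed, v_changed)

# earlier tokens never see the last token, so their outputs should not move
earlier_unchanged = torch.allclose(out_before[:-1], out_after[:-1])
```

A related observation: the first token can only attend to itself, so its output is just its own value vector.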