How does attention work? Each token emits two vectors:
* query (what info I'm looking for)
* key (what info I contain)

How do we get the affinities between those tokens?
* dot product between queries and keys
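A minimal sketch with hand-picked numbers (the vectors below are made up for illustration, with head_size = 3). The affinity is just the dot product, which is large when a query and a key point in similar directions and negative when they point in opposite directions:

```python
import torch

# Hypothetical query of one token and keys of two other tokens (head_size = 3)
q = torch.tensor([1.0, 0.0, 2.0])
k_similar = torch.tensor([0.5, 1.0, 1.0])
k_opposite = torch.tensor([-1.0, 0.0, -1.0])

# Affinity = dot product of the query with each key
print((q @ k_similar).item())   # 1*0.5 + 0*1 + 2*1 = 2.5
print((q @ k_opposite).item())  # 1*(-1) + 0*0 + 2*(-1) = -3.0
```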

For a single token, it works like this:
* the token's query is dot-multiplied with its own key and with the keys of all previous tokens
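The per-token view can be checked against the matrix form used in the code cell further down. This is a sketch assuming small random `q` and `k` of shape (T, head_size), with the batch dimension dropped: for token `t`, dot its query with the keys of tokens 0..t, softmax those scores, and compare with row `t` of the masked, softmaxed `q @ k.T`.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
T, head_size = 8, 16
q = torch.randn(T, head_size)  # one query per token
k = torch.randn(T, head_size)  # one key per token

# Matrix form: all queries against all keys, mask the future, softmax each row
scores = q @ k.T                                   # (T, T)
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

# Per-token form: token t only sees its own key and the keys before it
t = 5
row_scores = k[: t + 1] @ q[t]                     # (t+1,): one score per visible token
row_wei = F.softmax(row_scores, dim=-1)

print(torch.allclose(row_wei, wei[t, : t + 1]))   # True
print(wei[t, t + 1 :].sum().item())               # 0.0 (future tokens get no weight)
```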

In self-attention, we separately compute queries (𝑞) and keys (𝑘) to determine how much one token should attend to another. The computed attention weights (wei) capture the compatibility between queries and keys. However, the information that is passed along (aggregated) comes from the “values.” By having a separate linear transformation for values, the network can independently control:
* Which information is used to decide the attention weights (via queries and keys), and
* Which information is aggregated and passed to the next layer (via values).
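To see the decoupling concretely, here is a sketch that keeps the queries and keys fixed and swaps only the value projection. The layer names mirror the code cell below; `value_a`, `value_b`, and the helper `attention_weights` are hypothetical names introduced for this comparison. The same attention weights aggregate different content depending on the value layer.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 2, 16
x = torch.randn(B, T, C)

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value_a = nn.Linear(C, head_size, bias=False)  # hypothetical value layer A
value_b = nn.Linear(C, head_size, bias=False)  # hypothetical value layer B

def attention_weights(x):
    # Uses only queries and keys: decides HOW MUCH each token attends to the others
    wei = query(x) @ key(x).transpose(-2, -1)
    tril = torch.tril(torch.ones(T, T))
    wei = wei.masked_fill(tril == 0, float('-inf'))
    return F.softmax(wei, dim=-1)

wei = attention_weights(x)  # computed without touching value_a or value_b
out_a = wei @ value_a(x)    # WHAT gets aggregated depends on the value layer
out_b = wei @ value_b(x)

print(out_a.shape)                  # torch.Size([4, 8, 16])
print(torch.allclose(out_a, out_b))  # False: same weights, different aggregated content
```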

At the end of an attention block, the output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key (here, the dot product, masked and normalized with softmax).
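The weighted sum can be spelled out with explicit loops, which should match the single matrix product `wei @ v` used in the code cell. A sketch, assuming small random tensors (no batch dimension) stand in for the notebook's `wei` and `v`:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(42)
T, head_size = 8, 16
scores = torch.randn(T, T)
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)  # (T, T) causal weights
v = torch.randn(T, head_size)  # one value vector per token

# out[t] = sum over s of wei[t, s] * v[s]
out_loop = torch.zeros(T, head_size)
for t in range(T):
    for s in range(T):
        out_loop[t] += wei[t, s] * v[s]

out_matmul = wei @ v
print(torch.allclose(out_loop, out_matmul, atol=1e-6))  # True
```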

In summary, passing 𝑥 through the value linear layer transforms the token representations into a space in which the content aggregated by the attention weights is most useful. It decouples the content that is aggregated (values) from the mechanism that decides how to aggregate (queries and keys), giving the model the flexibility to learn both independently.

In [22]:
import torch
from torch.nn import functional as F
import torch.nn as nn


torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.randn(B,T,C)

# single self-attention head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# here each token produces its key and query
k = key(x)    # (B, T, head_size): each token's key vector of length head_size
q = query(x)  # (B, T, head_size): each token's query vector of length head_size
# communication happens now: every query is dot-multiplied with every key, giving a matrix of affinities
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

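# causal mask: token t may only attend to tokens 0..t
tril = torch.tril(torch.ones(T, T))
# with wei = torch.zeros((T, T)) instead, this would give a uniform average over past tokens
wei = wei.masked_fill(tril == 0, float('-inf'))  # future positions -> -inf
wei = F.softmax(wei, dim=-1)  # each row becomes weights that sum to 1
# out = wei @ x would aggregate the raw inputs; here we aggregate the values instead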
v = value(x) # (B, T, head_size)
out = wei @ v

out.shape

torch.Size([4, 8, 16])
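The causal mask built from `tril` means each output position depends only on the tokens up to and including it. A sketch of that check, assuming the cell above is rewrapped in a hypothetical helper `head(x)` (layers are created before `x` here, so the numbers differ from the cell's): perturbing the last three tokens of `x` should leave the outputs of the first five positions unchanged.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 2, 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
tril = torch.tril(torch.ones(T, T))

def head(x):
    # Hypothetical wrapper around the notebook cell: one causal self-attention head
    k, q, v = key(x), query(x), value(x)
    wei = q @ k.transpose(-2, -1)
    wei = wei.masked_fill(tril == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    return wei @ v

x = torch.randn(B, T, C)
x_perturbed = x.clone()
x_perturbed[:, 5:] = torch.randn(B, T - 5, C)  # change only tokens 5, 6, 7

out, out_perturbed = head(x), head(x_perturbed)
print(out.shape)                                         # torch.Size([4, 8, 16])
print(torch.allclose(out[:, :5], out_perturbed[:, :5]))  # True: the past ignores the future
print(torch.allclose(out[:, 5:], out_perturbed[:, 5:]))  # False: later positions change
```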

In [23]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5599, 0.4401, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3220, 0.2016, 0.4764, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1640, 0.0815, 0.2961, 0.4585, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2051, 0.3007, 0.1894, 0.1808, 0.1241, 0.0000, 0.0000, 0.0000],
        [0.0600, 0.1273, 0.0291, 0.0169, 0.0552, 0.7114, 0.0000, 0.0000],
        [0.1408, 0.1025, 0.1744, 0.2038, 0.1690, 0.0669, 0.1426, 0.0000],
        [0.0223, 0.1086, 0.0082, 0.0040, 0.0080, 0.7257, 0.0216, 0.1016]],
       grad_fn=<SelectBackward0>)

Taking the last row (i.e. the eighth token) as an example:
[0.0223, 0.1086, 0.0082, 0.0040, 0.0080, 0.7257, 0.0216, 0.1016]

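The eighth token gives itself a weight of 0.1016 but attends most strongly to the sixth token (0.7257). Each row sums to 1, and the zeros above the diagonal are the masked future positions. The layers here are untrained, so this particular pattern is arbitrary, but it is already data-dependent: each token's weights come from its own query and the other tokens' keys.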
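That reading of the last row can be double-checked numerically. This sketch simply copies the printed values: they sum to 1 (softmax output, up to the 4-decimal rounding of the print), and the largest weight sits at index 5, i.e. the sixth token when counting from one.

```python
import torch

# Last row of wei[0], copied from the output above
last_row = torch.tensor([0.0223, 0.1086, 0.0082, 0.0040, 0.0080, 0.7257, 0.0216, 0.1016])

print(round(last_row.sum().item(), 4))  # 1.0
print(last_row.argmax().item())         # 5 -> the sixth token (0-indexed)
print(round(last_row[-1].item(), 4))    # 0.1016: the weight the eighth token gives itself
```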