# Bag of Words

In [24]:
import torch
from torch.nn import functional as F
import torch.nn as nn

In [18]:
# Fix the RNG so the random inputs (and the nn.Linear initialisation in the
# self-attention section) are reproducible on Restart & Run All.
torch.manual_seed(1337)

B, T, C = 4, 8, 32  # batch, time (sequence length), channels
x = torch.randn(B, T, C)

# Naive bag of words: xbow[b, t] is the mean of x[b, 0..t], i.e. each position
# averages itself with every position that precedes it.
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = torch.mean(x[b, :t + 1], 0)

In [19]:
# Same running average, vectorised: row t of a lower-triangular ones matrix,
# normalised to sum to 1, puts weight 1/(t+1) on each of positions 0..t.
# The (T, T) matrix broadcasts over the batch dimension of x in the matmul.
weights = torch.tril(torch.ones(T, T))
weights /= weights.sum(dim=1, keepdim=True)
xbow_vectorized = torch.matmul(weights, x)

In [20]:
# Equivalent construction via softmax: all-zero logits with -inf above the
# diagonal, so each row's softmax is uniform over positions 0..t. This is the
# form self-attention uses, with data-dependent logits in place of zeros.
tril = torch.tril(torch.ones(T, T))
weights_s = F.softmax(
    torch.zeros(T, T).masked_fill(tril == 0, float('-inf')),
    dim=-1,  # last axis == axis 1 for this 2-D (T, T) matrix
)
xbow_softmax = weights_s @ x

In [21]:
# Position 0 only averages itself, so the [0][0] rows shown below agree
# trivially; check the full tensors so the comparison actually tests the
# averaging at every position.
assert torch.allclose(xbow, xbow_vectorized, atol=1e-6)
assert torch.allclose(xbow, xbow_softmax, atol=1e-6)
xbow_vectorized[0][0], xbow_softmax[0][0], xbow[0][0]

(tensor([ 0.9734,  0.3734,  0.7993,  0.5965, -1.0673, -0.1763,  0.7493, -0.8393,
         -0.2653,  0.2085, -0.1127,  1.5200,  1.5352, -0.2248,  0.0479,  0.5205,
         -0.0579, -0.1327, -1.0962, -0.0525, -1.2933, -3.1977,  1.8125, -0.4697,
          0.8258,  0.3248, -0.4220,  0.9590, -0.1772, -0.6007, -0.6869,  1.2002]),
 tensor([ 0.9734,  0.3734,  0.7993,  0.5965, -1.0673, -0.1763,  0.7493, -0.8393,
         -0.2653,  0.2085, -0.1127,  1.5200,  1.5352, -0.2248,  0.0479,  0.5205,
         -0.0579, -0.1327, -1.0962, -0.0525, -1.2933, -3.1977,  1.8125, -0.4697,
          0.8258,  0.3248, -0.4220,  0.9590, -0.1772, -0.6007, -0.6869,  1.2002]),
 tensor([ 0.9734,  0.3734,  0.7993,  0.5965, -1.0673, -0.1763,  0.7493, -0.8393,
         -0.2653,  0.2085, -0.1127,  1.5200,  1.5352, -0.2248,  0.0479,  0.5205,
         -0.0579, -0.1327, -1.0962, -0.0525, -1.2933, -3.1977,  1.8125, -0.4697,
          0.8258,  0.3248, -0.4220,  0.9590, -0.1772, -0.6007, -0.6869,  1.2002]))

# Self-attention

In [39]:
# One self-attention head: project every token of x into key, query and value
# spaces of width head_size. The layers are created in key, query, value order
# (this fixes the order in which they draw their random initial weights).
head_size = 16
key, query, value = (nn.Linear(C, head_size, bias=False) for _ in range(3))

k, q, v = key(x), query(x), value(x)

# Attention scores: (B, T, head_size) @ (B, head_size, T) -> (B, T, T).
# Multiplying by head_size**-0.5 (dividing by sqrt(head_size)) offsets the
# growth of the dot-product variance with head_size, so the later softmax is
# less prone to saturating towards one-hot.
weights = (q @ k.transpose(-2, -1)) * head_size**-0.5

In [40]:
# Causal mask: query position t may only attend to key positions 0..t.
# NOTE: this cell rebinds `weights`; re-run the previous cell before running
# this one again, otherwise the mask and softmax are applied twice.
tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))
# weights is (B, T, T) = (batch, query, key). Normalise over the last axis so
# each query's attention over keys sums to 1. dim=1 would normalise over
# queries instead, which is wrong for a 3-D tensor.
weights = F.softmax(weights, dim=-1)
out = weights @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)