In [1]:
import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tensor dimensions used throughout the notebook.
B = 4   # Batch: number of independent blocks of text
T = 8   # Time: number of tokens in each block
C = 32  # Channels: latent dimension of each embedded token

In [3]:
# Seed the global RNG before drawing the random embeddings so that x, and every
# output derived from it, is reproducible under Restart Kernel -> Run All.
torch.manual_seed(1337)
x = torch.randn(B, T, C)  # standard-normal values standing in for a batch of token embeddings
x.shape

torch.Size([4, 8, 32])

In [4]:
"""
t=45m
If you want to give the current token some information from all prior tokens, a weak information algo could
be just to average the current, and all previous token embeddings
"""
x_bag_of_words = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        x_prev = x[b, :t+1]
        x_bag_of_words[b, t] = torch.mean(x_prev, dim=0)

In [5]:
"""
t=48m
matmul
"""
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b

In [6]:
# Row i of c is the average of rows 0..i of b (a is a row-normalised lower-triangular matrix).
c

tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])

In [8]:
# Same averaging trick at full size: after normalisation, row t of `weights`
# holds 1/(t+1) in columns 0..t and zeros elsewhere.
weights = torch.ones(T, T).tril()
weights = weights / weights.sum(dim=1, keepdim=True)  # [T, T]
# [T, T] @ [B, T, C]: the 2-D weights are broadcast over the batch dimension,
# i.e. evaluated as [B, T, T] @ [B, T, C] -> [B, T, C].
x_bag_of_words_2 = torch.matmul(weights, x)

In [11]:
# Running averages for the first sequence in the batch (a T x C slice).
x_bag_of_words[0]

tensor([[-1.8227, -1.1273,  1.5287,  0.6238, -0.8932,  0.8044,  0.0726, -1.0536,
          0.1823,  0.4875, -0.6720, -1.6526, -0.8162,  0.4470, -0.4603, -0.2107,
         -0.0129, -1.1054,  0.5221,  0.7951,  1.0665, -2.1980, -0.6582, -0.4577,
         -1.1629,  0.9793, -1.8782, -0.3928,  1.3737,  0.9752,  1.0107, -0.6734],
        [-1.7003, -0.9309,  0.1334, -0.0312, -0.2632, -0.0542,  0.4349,  0.1707,
          0.4232, -0.3119, -0.6285, -0.6307, -0.9627,  0.0278, -0.0460,  0.0322,
          0.4694, -1.5157,  0.5696,  0.5325,  0.3490, -0.8446, -0.6399, -0.5391,
         -0.6487,  0.6789, -0.5102, -1.2916,  0.6075, -0.0218,  0.9780,  1.0115],
        [-1.0051, -0.4442, -0.3487, -0.0123, -0.2081,  0.1498,  0.2510, -0.1040,
          0.0475,  0.3782, -0.3582,  0.0159, -0.7279,  0.3315, -0.3390, -0.4231,
          0.4025, -1.1518,  0.4305,  0.2544, -0.3452, -0.4493, -0.5750, -0.2918,
         -0.5466,  0.4084, -0.6851, -0.9648,  1.1235,  0.0195,  0.0843,  0.5277],
        [-0.7769, -0.3613

In [12]:
# The explicit loop and the matmul are different float32 reductions, so the two
# results may differ by rounding error. torch.allclose's default atol (1e-8) is
# tight enough that entries close to zero can fail on rounding alone;
# assert_close applies float32-appropriate tolerances and reports the largest
# mismatch if the check fails.
torch.testing.assert_close(x_bag_of_words_2, x_bag_of_words)

In [13]:
"""
56m
softmax
"""
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))  # Wherever the lower diagonal is zero, set to -inf
weights = F.softmax(weights, dim=-1)
x_bag_of_words_3 = weights @ x


In [14]:
# The softmax-weighted matmul and the explicit loop are different float32
# computations, so small rounding differences are expected. torch.allclose's
# default atol (1e-8) can reject entries close to zero for that reason alone;
# assert_close applies float32-appropriate tolerances and reports the largest
# mismatch if the check fails.
torch.testing.assert_close(x_bag_of_words_3, x_bag_of_words)

In [15]:
# Row t spreads its weight uniformly over positions 0..t; masked (future) positions get 0.
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [19]:
"""
1hr5min
"""
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.rand(B, T, C)

head_size = 16
key = torch.nn.Linear(C, head_size, bias=False)  # -> x [B, T, C] @ W [C, h] (matrix multiplication) -> [B, T, h]
query = torch.nn.Linear(C, head_size, bias=False)
value = torch.nn.Linear(C, head_size, bias=False)

k = key(x)
q = query(x)

"""
Each embedded token vector generates its own key and query vector (k & q)
No information has been shared between tokens yet
"""

'\nEach embedded token vector generates its own key and query vector (k & q)\nNo information has been shared between tokens yet\n'

In [20]:
"""
1hr6m13s
This computes the affinity by dot producting the query and key matricies
    IE each embedding token query vector gets multiplied by every other (and it's own) embedded token key vector
Element B,0,0 is query row 0 dotted with key row 0 (what info does the first token relate to itself)
Element B,1,0 is query row 1 dotted with key row 0 (what info does the second token relate to the first token)
"""
weights = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) -> (B, T, T)

In [21]:
"""
Remove the communication to future tokens
"""
tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril==0, float('-inf'))
"""
The softmax happens AFTER setting future token information to -inf, therefore even though the communication has happened
No information in the final weights matrix is introduced during the softmax
"""
weights = F.softmax(weights, dim=-1)
v = value(x)
out = weights @ v

In [22]:
out.shape  # same torch.Size that out.size() returns

torch.Size([4, 8, 16])