In [None]:
import numpy as np
import torch
from torch.nn import functional as F
import torch.nn as nn

# In this part we are going to implement a small self-attention for a single, individual head

# Right now we are working with a 4 by 8 arrangement of tokens (B=4, T=8), where the information of each token is 32-dimensional
* The code below takes a simple average of all past tokens and the current token
* So the previous information and the current information are mixed together by averaging

In [None]:
# Seed the generator so the toy batch is reproducible.
torch.manual_seed(1337)
B = 4   # batch size
T = 8   # number of token positions
C = 32  # channels per token
x = torch.randn(B, T, C)

# Lower-triangular mask: position t may only look at positions 0..t.
tril = torch.ones(T, T).tril()
# Start from all-zero scores and hide the future with -inf, so that the
# softmax turns every row into a uniform distribution over the visible part.
wei = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
wei = wei.softmax(dim=-1)
# (T, T) @ (B, T, C) broadcasts over the batch -> (B, T, C)
out = torch.matmul(wei, x)

out.shape

torch.Size([4, 8, 32])

# So far we have built a lower-triangular structure, where the affinities between tokens are initialized to zero
* After the softmax, wei has uniform numbers in every single row
* And the matrix multiply with x then computes a simple average

In [None]:
# Display the (T, T) weight matrix produced by the masked softmax above.
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

# But we don't want it to be uniform, because different tokens will find other tokens more or less interesting, so we want the weights to be data-dependent
* We want to gather information from the past, but in a data-dependent way

# Self-attention solves this problem: every single token at each position will emit two vectors (a query and a key)
* The query says "what am I looking for"
* The key says "what do I contain"
* The affinities between tokens are then obtained as the dot product between the keys and the queries
* So my query is dot-producted with the keys of all the other tokens, and that dot product becomes wei

# So now, when a key and a query are aligned, they interact to a very high degree, and I will get to learn more about that specific token as opposed to any other token in the sequence

In [None]:
# Same seed and toy batch as before, so results are comparable across cells.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
# Bias-free projections: each one is a plain matrix multiply on the channels.
# (Creation order matters: it fixes which random numbers each layer receives.)
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
query = nn.Linear(in_features=C, out_features=head_size, bias=False)
k, q = key(x), query(x)                     # each (B, T, head_size)
# Query-key dot products: (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
wei = torch.matmul(q, k.transpose(1, 2))

# Hide future positions, then normalize every row into a distribution.
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, float("-inf")).softmax(dim=-1)
out = torch.matmul(wei, x)

out.shape

torch.Size([4, 8, 32])

# I introduce the hyperparameter head_size and initialize the linear modules with bias=False, so they just apply a matrix multiply with some fixed weights
* Then we get k and q by forwarding these modules on x

# When we forward these linear modules on x, all tokens at all positions in the B by T arrangement produce a key and a query, in parallel and independently
* So no communication has happened yet

# Communication comes when all queries are dot-producted with all keys
* We want wei, the affinities between tokens, to be q @ k; q and k have the same size, (B, T, 16)
* But we need to transpose the last two dimensions of k, from (B, T, 16) to (B, 16, T)
* Then the matrix multiply q @ k.transpose(-2, -1) gives us (B, T, T)


In [None]:
# Attention weights of the first batch element (causally masked and softmax-normalized above).
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [None]:
# Same seed and toy batch as before, so results are comparable across cells.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
# Bias-free projections; creation order fixes which random numbers each gets.
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
query = nn.Linear(in_features=C, out_features=head_size, bias=False)
k, q = key(x), query(x)                     # each (B, T, head_size)
# Query-key dot products: (B, T, T)
wei = torch.matmul(q, k.transpose(1, 2))

# The mask is built but intentionally NOT applied in this cell, and no softmax
# is taken either: the goal is to inspect the raw query-key scores.
tril = torch.ones(T, T).tril()
out = torch.matmul(wei, x)

out.shape

torch.Size([4, 8, 32])

# So for every batch element (each of the B rows) we have a T by T matrix giving us the affinities, which are our wei
* They are not zeros like before; they now come from the dot product between keys and queries
* The weighted aggregation is now a data-dependent function of the keys and queries of the tokens

# Without softmax and masking, the result of the dot-product operation looks like below
* These are the raw affinities (interactions); here they take values roughly between -3.4 and 3.4

In [None]:
# Raw query-key dot products of the first batch element (no mask, no softmax applied above).
wei[0]

tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5899, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

In [None]:
# Same seed and toy batch as before, so results are comparable across cells.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
# Bias-free projections; creation order fixes which random numbers each gets.
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
query = nn.Linear(in_features=C, out_features=head_size, bias=False)
k, q = key(x), query(x)                     # each (B, T, head_size)
# Query-key dot products: (B, T, T)
wei = torch.matmul(q, k.transpose(1, 2))

# Apply only the causal mask (future positions become -inf); the softmax is
# deliberately skipped so the masked scores themselves can be inspected.
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, float("-inf"))
# NOTE: wei still contains -inf here, so `out` is only computed for its shape.
out = torch.matmul(wei, x)

out.shape

torch.Size([4, 8, 32])

# But if I am the 5th token, I don't want to aggregate anything from future tokens
* So the masking operation makes sure they are not allowed to communicate

In [None]:
# Scores of the first batch element after causal masking: future positions were filled with -inf above.
wei[0]

tensor([[-1.7629,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-3.3334, -1.6556,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.0226, -1.2606,  0.0762,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.7836, -0.8014, -0.3368, -0.8496,    -inf,    -inf,    -inf,    -inf],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,    -inf,    -inf,    -inf],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,    -inf,    -inf],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,    -inf],
        [-1.8044, -0.4126, -0.8306,  0.5899, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

# We also want a nice distribution that sums to one, so we use softmax to exponentiate and normalize

# And now every batch element gets a different wei, because it contains different tokens; wei is now data-dependent
* For example, in the last row, the 8th token (weight 0.2391) knows what content it has and what position it is in
* Based on that it creates a query (saying what it is looking for, given the information it contains)
* Then all tokens get to emit keys, and maybe one of the channels contains exactly the information that the query of the 8th token is looking for
* That key would then have a high number in that specific channel

# That's how a query and a key can find each other when they are dot-producted, and create a high affinity
* Like the weights 0.2297 and 0.2391 in the last row; through softmax, we end up aggregating a lot of those tokens' information into the current position
* So we end up learning a lot about them

In [None]:
# `wei` still holds the masked but un-normalized scores from the last code cell
# (its softmax line is commented out), so normalize here to display the
# attention weights. `wei` itself is not reassigned, so re-running this cell
# always gives the same result.
F.softmax(wei, dim=-1)[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

# And there is one more thing needed to complete the single self-attention head
* That is, when we do the aggregation we don't aggregate the tokens themselves, but their values
* We create value just like key and query
* So now we don't aggregate x but v, which is obtained by applying a linear module to x
* Then we output wei @ v, so v becomes the element that we aggregate, instead of the raw x

# This makes the output of this single head 16-dimensional, because that is head_size

In [None]:
# Same seed and toy batch as before, so results are comparable across cells.
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
# Three bias-free projections of the token channels. They are created in the
# order key, query, value, which fixes the random numbers each one receives.
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
query = nn.Linear(in_features=C, out_features=head_size, bias=False)
value = nn.Linear(in_features=C, out_features=head_size, bias=False)
k, q = key(x), query(x)                     # each (B, T, head_size)
# Query-key dot products: (B, T, T)
wei = torch.matmul(q, k.transpose(1, 2))

# Hide future positions, then normalize every row into a distribution.
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, float("-inf")).softmax(dim=-1)

# Aggregate the projected values instead of the raw tokens:
# (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
v = value(x)
out = torch.matmul(wei, v)

out.shape

torch.Size([4, 8, 16])

# We can think of x as private information of the token; for example, the 5th token has some identity and its information is kept in the vector x
* It tells what it is interested in (query), what it has (key), and, if any other token finds it interesting, what it will communicate to it (value)

# So v is the thing that is aggregated between the different tokens for the purposes of the single self-attention head