In [28]:
import torch
import torch.nn as nn 
from torch.nn import functional as F

In [29]:
# Causal averaging with a masked softmax: token t receives the mean of x over positions 0..t.
torch.manual_seed(133)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Lower-triangular 0/1 mask; its zeros mark the "future" slots a token may not look at.
tril = torch.tril(torch.ones(T, T))

# Start from equal (zero) scores, push every future slot to -inf, then softmax each row:
# exp(-inf) == 0, so row t spreads weight 1/(t+1) evenly over positions 0..t.
future = tril == 0
wei = F.softmax(torch.zeros((T, T)).masked_fill(future, float('-inf')), dim=-1)

out = wei @ x  # (T, T) @ (B, T, C) broadcasts over the batch -> (B, T, C)
out.shape

torch.Size([4, 8, 32])

In [30]:
tril  # lower-triangular 0/1 mask: row i is 1 in columns 0..i and 0 elsewhere

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [31]:
wei  # row i is a uniform average over positions 0..i (softmax of zeros with -inf above the diagonal)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [32]:
#Now, every token (position) will emit a query vector and a key vector. We get the affinity
#between two tokens by taking the dot product of one token's query with the other token's key.

In [55]:
torch.manual_seed(132)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# One self-attention head: a single key / query / value projection each.
# head_size (a.k.a. d_k) is a hyperparameter. The projections are created in the
# order key, query, value so their random initialization matches a fixed seed.
head_size = 16
key, query, value = (nn.Linear(C, head_size, bias=False) for _ in range(3))

k, q = key(x), query(x)  # each (B, T, head_size)

# Scaled dot-product scores. Transposing only the last two dims keeps the batch
# dim intact: (B, T, 16) @ (B, 16, T) -> (B, T, T)
wei = (q @ k.transpose(-2, -1)) * head_size**-0.5




In [56]:
wei[0]  # raw scaled scores q·k / sqrt(head_size) for the first batch element, before masking or softmax

tensor([[-0.2598,  0.5358,  0.1013,  0.0323, -0.0110,  0.0885, -0.2039,  0.2836],
        [-0.1062,  0.4194,  0.7516, -0.2718,  0.3188, -0.0167,  0.1712, -0.0058],
        [-0.0644,  0.1771, -0.1038,  0.2501, -0.4286, -0.2471,  0.0573,  0.6216],
        [-0.2873,  0.0468, -0.0543,  0.4543, -0.2738, -0.3870, -0.0581, -0.6063],
        [-0.1954,  0.2519, -0.1002, -0.4227,  0.3870, -0.1670,  0.1485,  0.0466],
        [-0.0610,  0.7959,  0.5536, -0.0731,  0.3638,  0.3262, -0.0692,  0.0164],
        [ 0.0413,  0.0369, -0.0124, -0.1075,  0.0900,  0.2772,  0.0048,  0.0801],
        [-0.3004,  0.4733,  0.3640, -0.5678,  0.4607,  0.3939,  0.0692,  0.3841]],
       grad_fn=<SelectBackward0>)

In [57]:

tril = torch.tril(torch.ones(T, T))  # lower-triangular causal mask

# Recompute the raw scores from q and k rather than masking `wei` in place.
# If this cell re-read `wei`, re-running it would mask and softmax the
# already-normalized weights a second time and silently change the result.
wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))  # find all spots in tril which are 0 and set the matching spot in wei to -inf
wei = F.softmax(wei, dim=-1)  # each row t becomes a distribution over positions 0..t

v = value(x)  # (B, T, head_size)

out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
out.shape

torch.Size([4, 8, 16])

In [58]:
wei[0]  # attention weights for the first batch element, after masking + softmax.
# Nothing has been trained yet, so these weights come only from the randomly
# initialized key/query projections. Even so, each row is a causal softmax:
# a distribution over positions 0..i that sums to 1.

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3715, 0.6285, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3092, 0.3936, 0.2972, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1737, 0.2426, 0.2192, 0.3646, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1600, 0.2502, 0.1760, 0.1274, 0.2864, 0.0000, 0.0000, 0.0000],
        [0.1088, 0.2562, 0.2011, 0.1075, 0.1663, 0.1602, 0.0000, 0.0000],
        [0.1411, 0.1405, 0.1338, 0.1216, 0.1482, 0.1787, 0.1361, 0.0000],
        [0.0744, 0.1612, 0.1445, 0.0569, 0.1592, 0.1489, 0.1076, 0.1474]],
       grad_fn=<SelectBackward0>)