In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:

# Reproducible toy batch for the averaging experiments below.
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch size, sequence length, channels per position

x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [3]:
# Naive reference implementation: x_bow[b, t] is the mean of x[b, 0..t],
# i.e. a causal "bag of words" average over the current and all earlier positions.
x_bow = torch.zeros(B, T, C)
for b in range(B):
    seq = x[b]  # (T, C) slice for this batch element
    for t in range(T):
        x_bow[b, t] = torch.mean(seq[: t + 1], dim=0)

In [4]:
# Inspect batch element 1: raw vectors first, then their running means.
# Row t of x_bow[1] should be the average of rows 0..t of x[1]
# (row 0 is unchanged, row 1 is the mean of the first two rows, ...).
print(x[1, :, :])
x_bow[1, :, :]

tensor([[ 1.3488, -0.1396],
        [ 0.2858,  0.9651],
        [-2.0371,  0.4931],
        [ 1.4870,  0.5910],
        [ 0.1260, -1.5627],
        [-1.1601, -0.3348],
        [ 0.4478, -0.8016],
        [ 1.5236,  2.5086]])


tensor([[ 1.3488, -0.1396],
        [ 0.8173,  0.4127],
        [-0.1342,  0.4395],
        [ 0.2711,  0.4774],
        [ 0.2421,  0.0694],
        [ 0.0084,  0.0020],
        [ 0.0712, -0.1128],
        [ 0.2527,  0.2149]])

In [5]:
# Row-normalised lower-triangular matrix: row t holds 1/(t+1) in columns 0..t and
# zeros after, so `wei @ x` reproduces the running mean computed by the loop.
wei = torch.tril(torch.ones(T, T))
wei /= wei.sum(dim=-1, keepdim=True)

In [6]:
# Vectorised bag-of-words: batched matmul broadcasts the (T, T) weights over the
# batch dimension, so x_bow2[b] == wei @ x[b] for every b, with no Python loop.
x_bow2 = wei @ x  # (T, T) @ (B, T, C) -> (B, T, C)

In [7]:
x_bow2.size()  # expected to match x: (B, T, C)

torch.Size([4, 8, 2])

In [8]:
# Vectorised result vs. loop result for the same batch element.
(x_bow2[1], x_bow[1])

(tensor([[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]]),
 tensor([[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]]))

In [9]:
x_bow.allclose(x_bow2, rtol=1e-05, atol=1e-06)  # loop vs. matmul, up to float error

True

In [10]:
# Same running average expressed as attention: start from equal scores, set
# future positions to -inf, and let softmax normalise each row. Masked entries
# get weight 0; the remaining t+1 entries of row t share the mass equally.
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(
    torch.ones(T, T).masked_fill(tril == 0, float('-inf')),
    dim=-1,
)
print(wei)
x_bow3 = wei @ x

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [11]:
x_bow.allclose(x_bow3, rtol=1e-05, atol=1e-06)  # loop vs. softmax-masked weights

True

In [12]:
# torch.stack adds a new leading dimension: two (3, 2) tensors -> (2, 3, 2).
ten = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(ten.shape)
ten = torch.stack((ten, ten), dim=0)
print(ten.shape)

torch.Size([3, 2])
torch.Size([2, 3, 2])


## Cross-entropy test

In [13]:
# F.cross_entropy with a probability-distribution target (same shape as the
# logits) instead of class indices (supported since PyTorch 1.10). The target
# here is one-hot on class 2, so the loss equals -log_softmax(logits)[0, 2].
logits = torch.tensor([[0, 0, 0, 1]], dtype=torch.float)
targets = torch.tensor([[0, 0, 1, 0]], dtype=torch.float)
print(f"logits shape: {logits.shape} - {logits}")
print(f"Target shape: {targets.shape} - {targets}")
ret = F.cross_entropy(logits, targets)

# Cross-check against the definition: batch mean of -sum(p * log_softmax(z)).
expected = -(targets * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
assert torch.allclose(ret, expected), (ret, expected)
print(ret)


logits shape: torch.Size([1, 4]) - tensor([[0., 0., 0., 1.]])
Target shape: torch.Size([1, 4]) - tensor([[0., 0., 1., 0.]])
tensor(1.7437)


In [14]:

# Despite the name, these are non-negative sampling weights (here already a
# probability distribution), not logits: multinomial samples indices in
# proportion to the values as given. Draw 10 class indices with replacement.
logits_m = torch.tensor([[0.2, 0.2, 0.4, 0.2]], dtype=torch.float)
idx_next = logits_m.multinomial(num_samples=10, replacement=True)
print(idx_next)

tensor([[2, 1, 2, 1, 2, 2, 3, 0, 0, 2]])


## Back to Self Attention

In [15]:
# Fresh toy batch for the attention section: same seed, wider channel dim.
# (`torch.nn` is imported with the other imports in the first cell.)
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch size, sequence length, channels per position

x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 32])

In [31]:
### Single-head causal self-attention on the toy batch x of shape (B, T, C)

head_size = 16  # H: output dimension of each key/query/value projection
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, C) -> (B, T, H)
q = query(x)  # (B, T, C) -> (B, T, H)
v = value(x)  # (B, T, C) -> (B, T, H)

# Scaled dot-product affinities. Transposing the last two dims (instead of dims 1
# and 2) keeps this correct if extra leading dims such as heads are added later.
wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, H) @ (B, H, T) -> (B, T, T)

# Causal mask: position t may only attend to positions 0..t.
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # (B, T, T)
wei = F.softmax(wei, dim=-1)
out = wei @ v  # (B, T, T) @ (B, T, H) -> (B, T, H)
assert out.shape == (x.shape[0], T, head_size)

In [32]:
# Display the (B, T, T) attention weights from the previous cell.
# NOTE(review): the saved output below has nonzero weights above the diagonal
# (e.g. row 0 is not [1, 0, ..., 0]), so it does not reflect the causal mask in
# the previous cell — likely produced before the mask was added. Re-run to refresh.
wei

tensor([[[0.1097, 0.0911, 0.0841, 0.1163, 0.1970, 0.0668, 0.1125, 0.2224],
         [0.0607, 0.1189, 0.1101, 0.1670, 0.1863, 0.1133, 0.1264, 0.1173],
         [0.0943, 0.1137, 0.1480, 0.1122, 0.1300, 0.1291, 0.1243, 0.1485],
         [0.1474, 0.1636, 0.0977, 0.1487, 0.1151, 0.1510, 0.1118, 0.0647],
         [0.1788, 0.1536, 0.1348, 0.0648, 0.1254, 0.0996, 0.1036, 0.1395],
         [0.0770, 0.1333, 0.1739, 0.0686, 0.1598, 0.0756, 0.1332, 0.1786],
         [0.1853, 0.1340, 0.1158, 0.1295, 0.1070, 0.1192, 0.1124, 0.0968],
         [0.1069, 0.0991, 0.1179, 0.1340, 0.1730, 0.1059, 0.1212, 0.1419]],

        [[0.2106, 0.1120, 0.1458, 0.1403, 0.1225, 0.0921, 0.0849, 0.0917],
         [0.1155, 0.1646, 0.1409, 0.1305, 0.0768, 0.1074, 0.1497, 0.1145],
         [0.2243, 0.0637, 0.1222, 0.0937, 0.1749, 0.0955, 0.0942, 0.1315],
         [0.1593, 0.0987, 0.1577, 0.0714, 0.1391, 0.0910, 0.1019, 0.1810],
         [0.1316, 0.1208, 0.1327, 0.1202, 0.1110, 0.1234, 0.1209, 0.1394],
         [0.0131, 0.296