In [15]:
import torch

# Toy input for the averaging experiments below.
torch.manual_seed(2024)  # fixed seed so the tensor shown below is reproducible
B = 4  # batch size
T = 8  # sequence length (time steps)
C = 2  # embedding size (channels)
x = torch.randn((B, T, C))
x

tensor([[[-1.2262, -0.0093],
         [ 1.5420, -0.4657],
         [ 0.2795, -0.2610],
         [ 0.6230, -1.1561],
         [ 0.1171, -1.8865],
         [ 2.1822, -0.1930],
         [ 0.5358, -0.8898],
         [-0.3099,  0.7741]],

        [[ 0.1236, -2.1807],
         [ 0.3700,  0.4144],
         [ 1.8567,  1.9776],
         [-0.4322,  1.3667],
         [ 0.8432, -0.0421],
         [ 1.6579, -1.3085],
         [ 0.9962,  0.9391],
         [ 1.4148,  0.6343]],

        [[ 2.7266, -1.4753],
         [-1.4808,  0.0498],
         [ 1.2883, -0.6491],
         [-0.8969,  1.2634],
         [ 0.8273,  0.4594],
         [ 0.3922, -1.0767],
         [-0.0576, -0.0596],
         [ 0.2764, -0.2403]],

        [[ 0.7203, -1.4108],
         [-0.4384,  0.3551],
         [ 0.3730, -1.3050],
         [-0.7983,  1.0442],
         [-0.1227,  0.4022],
         [-1.4295, -0.5656],
         [ 0.6971,  0.1258],
         [-0.0434,  0.5366]]])

In [45]:
# version 1: explicit loops (reference implementation)
# xbow[b, t] is the mean of x[b, 0], ..., x[b, t], i.e. each position only
# averages itself and the positions before it.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]           # (t+1, C): everything up to and including t
        xbow[b, t] = xprev.mean(dim=0)  # (C,): average over the time axis

x[0]

tensor([[-1.2262, -0.0093],
        [ 1.5420, -0.4657],
        [ 0.2795, -0.2610],
        [ 0.6230, -1.1561],
        [ 0.1171, -1.8865],
        [ 2.1822, -0.1930],
        [ 0.5358, -0.8898],
        [-0.3099,  0.7741]])

In [26]:
# Running averages for the first batch element: row t is the mean of x[0, :t+1],
# so row 0 equals x[0, 0] exactly.
xbow[0]

tensor([[-1.2262, -0.0093],
        [ 0.1579, -0.2375],
        [ 0.1984, -0.2453],
        [ 0.3045, -0.4730],
        [ 0.2671, -0.7557],
        [ 0.5862, -0.6619],
        [ 0.5790, -0.6945],
        [ 0.4679, -0.5109]])

In [30]:
# Matrix-multiply warm-up: with an all-ones `a`, every row of `a @ b` is the
# column-wise sum of `b`.
torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3, 2)).to(torch.float32)
c = torch.matmul(a, b)
# One print call; sep="\n" puts each label/tensor on its own line.
print("a=", a, "----", "b=", b, "----", "c=", c, sep="\n")

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
----
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [31]:
# Same product, but `a` is lower-triangular: row i of `a @ b` now sums only
# rows 0..i of `b` (a running sum down the rows).
torch.manual_seed(42)
a = torch.ones(3, 3).tril()
b = torch.randint(0, 10, (3, 2)).to(torch.float32)
c = torch.matmul(a, b)
# One print call; sep="\n" puts each label/tensor on its own line.
print("a=", a, "----", "b=", b, "----", "c=", c, sep="\n")

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
----
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [37]:
# Normalise each row of the lower-triangular matrix so it sums to 1:
# row i of `a @ b` becomes the mean of rows 0..i of `b` (a running average).
torch.manual_seed(42)
a = torch.ones(3, 3).tril()
a = a / a.sum(dim=1, keepdim=True)  # keepdim so the (3, 1) sums broadcast per row
b = torch.randint(0, 10, (3, 2)).to(torch.float32)
c = torch.matmul(a, b)
# One print call; sep="\n" puts each label/tensor on its own line.
print("a=", a, "----", "b=", b, "----", "c=", c, sep="\n")


a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
----
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [41]:
# version 2: the same running average as one matrix multiply
wei = torch.ones(T, T).tril()               # row t keeps positions 0..t
wei = wei / wei.sum(dim=1, keepdim=True)    # each row now sums to 1
# (T, T) is broadcast over the batch: (B, T, T) @ (B, T, C) -> (B, T, C)
xbow2 = torch.matmul(wei, x)

In [43]:
# First batch element of the matrix-multiply version; compare with xbow[0]
# from the loop version.
xbow2[0]

tensor([[-1.2262, -0.0093],
        [ 0.1579, -0.2375],
        [ 0.1984, -0.2453],
        [ 0.3045, -0.4730],
        [ 0.2671, -0.7557],
        [ 0.5862, -0.6619],
        [ 0.5790, -0.6945],
        [ 0.4679, -0.5109]])

In [44]:
# Check that the loop result (version 1) and the matmul result (version 2)
# agree within floating-point tolerance.
torch.allclose(xbow, xbow2)

True

In [49]:
# version 3: the same averaging weights, obtained with a masked softmax
import torch.nn.functional as F

tril = torch.ones(T, T).tril()
# Start from all-zero scores and put -inf above the diagonal. Softmax turns -inf
# into weight 0 and spreads the rest uniformly over the equal (zero) scores.
wei = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow, xbow3)

True

In [50]:
# Softmax weights from version 3: zeros above the diagonal, and row t spreads
# its weight uniformly over positions 0..t.
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [11]:
# version 4: self-attention (single head, data-dependent weights)
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch size, sequence length, channels
x = torch.randn((B, T, C))

# Each position emits a key and a query; their dot product is the affinity
# between a querying position and every other position.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, C) -> (B, T, head_size)
q = query(x)  # (B, T, C) -> (B, T, head_size)

tril = torch.ones(T, T).tril()
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T) raw affinities
wei = torch.matmul(q, k.transpose(1, 2))
# Hide future positions with -inf, then normalise each row to sum to 1.
wei = F.softmax(wei.masked_fill(tril == 0, float("-inf")), dim=-1)
out = torch.matmul(wei, x)  # (B, T, T) @ (B, T, C) -> (B, T, C)

In [12]:
# Attention weights, shape (B, T, T).
# NOTE(review): the saved output for this cell contains -inf entries, which a
# softmax cannot produce; it appears to come from a run without the softmax
# step. Re-run with Restart & Run All to refresh it.
wei

tensor([[[-1.7629e+00,        -inf,        -inf,        -inf,        -inf,
                 -inf,        -inf,        -inf],
         [-3.3334e+00, -1.6556e+00,        -inf,        -inf,        -inf,
                 -inf,        -inf,        -inf],
         [-1.0226e+00, -1.2606e+00,  7.6228e-02,        -inf,        -inf,
                 -inf,        -inf,        -inf],
         [ 7.8359e-01, -8.0143e-01, -3.3680e-01, -8.4963e-01,        -inf,
                 -inf,        -inf,        -inf],
         [-1.2566e+00,  1.8719e-02, -7.8797e-01, -1.3204e+00,  2.0363e+00,
                 -inf,        -inf,        -inf],
         [-3.1262e-01,  2.4152e+00, -1.1058e-01, -9.9305e-01,  3.3449e+00,
          -2.5229e+00,        -inf,        -inf],
         [ 1.0876e+00,  1.9652e+00, -2.6213e-01, -3.1579e-01,  6.0905e-01,
           1.2616e+00, -5.4841e-01,        -inf],
         [-1.8044e+00, -4.1260e-01, -8.3061e-01,  5.8985e-01, -7.9869e-01,
          -5.8560e-01,  6.4332e-01,  6.3028e-01]],

In [13]:
# (T, T) attention matrix for the first batch element.
# NOTE(review): the saved output shows -inf above the diagonal, i.e. masked
# scores before softmax; it looks stale relative to the cell that applies
# F.softmax. Confirm by re-running.
wei[0]

tensor([[-1.7629,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-3.3334, -1.6556,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.0226, -1.2606,  0.0762,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.7836, -0.8014, -0.3368, -0.8496,    -inf,    -inf,    -inf,    -inf],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,    -inf,    -inf,    -inf],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,    -inf,    -inf],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,    -inf],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

In [21]:
# out = wei @ x, i.e. (B, T, T) @ (B, T, C), so it keeps x's (B, T, C) shape.
out.shape

torch.Size([4, 8, 32])

In [None]:
# Illustrate scaled attention (unit variance)
# torch.randn draws unit-variance entries, so each raw q.k dot product sums
# head_size such terms and its variance grows to roughly head_size. Multiplying
# by head_size**-0.5 (1/sqrt(head_size)) brings the scores back to roughly unit
# variance, which keeps a later softmax from collapsing onto a single position.
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, T) scaled scores