In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# a mathematical trick example
# --------
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [3]:
xbow = torch.zeros_like(x)
for b in range(B):
    for t in range(T):
        xpre = x[b, :t+1, :]  # shape of (t+1, C) max: (8, 2)
        xbow[b, t, :] = xpre.mean(0)
print(xbow, xbow.shape)

tensor([[[ 0.3985,  0.6269],
         [ 0.0140,  0.4470],
         [ 0.0497, -0.2374],
         [ 0.0138, -0.3427],
         [ 0.2345, -0.0921],
         [ 0.0558,  0.0701],
         [ 0.0624,  0.2481],
         [-0.0887, -0.0222]],

        [[ 1.9692, -1.3647],
         [ 0.6544, -0.2590],
         [ 0.4700,  0.2303],
         [ 0.3728,  0.4367],
         [ 0.1888,  0.4668],
         [ 0.4476,  0.5956],
         [ 0.4177,  0.6097],
         [ 0.2087,  0.6970]],

        [[ 1.6457, -0.3854],
         [ 1.1802,  0.2686],
         [ 1.2405, -0.1026],
         [ 1.3393, -0.1850],
         [ 0.9978, -0.3931],
         [ 0.8247, -0.2600],
         [ 0.5636, -0.1821],
         [ 0.4217,  0.1186]],

        [[-1.4700, -1.6396],
         [-0.8222, -1.0085],
         [-0.0022, -0.8849],
         [-0.0171, -0.7637],
         [-0.2281, -0.3650],
         [-0.2700, -0.0518],
         [-0.4623,  0.1826],
         [-0.6522,  0.3009]]]) torch.Size([4, 8, 2])


In [4]:
xbow1 = torch.tril(torch.ones((B, T, T)))
xbow1 = xbow1 / xbow1.sum(2, keepdim=True)
print(xbow1.shape)
print(xbow1)
print(torch.allclose(xbow1 @ x, xbow))

torch.Size([4, 8, 8])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
         [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
 

In [5]:
# Adding Softmax
tril = torch.tril(torch.ones(T, T))  # will be broadcasted to (B, T, T)
wei = tril.masked_fill(tril == 0, float('-inf'))  # (T, T)
print('wei before softmax:')
print(wei)
wei = F.softmax(wei, dim=1)  # (T, T)
print('wei after softmax:')
print(wei)
xbow2 = wei @ x  # (T, T) @ (B, T, C) --> (B, T, T) @ (B, T, C) --> (B, T, C)
torch.allclose(xbow, xbow2)

wei before softmax:
tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
wei after softmax:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0

False

In [6]:
# Self-attention
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size)
query = nn.Linear(C, head_size)
value = nn.Linear(C, head_size)

k = key(x)  # (B, T, head_size)
q = query(x)  # (B, T, head_size)
v = value(x)  # (B, T, head_size)

wei = q @ k.transpose(-2, -1)  # (B, T, T)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # (B, T, T)
wei = F.softmax(wei, dim=-1)  # (B, T, T)
out = wei @ x  # (B, T, T) @ (B, T, C) -> (B, T, C)
out, out.shape

(tensor([[[ 0.1808, -0.0700, -0.3596,  ..., -0.8016,  1.5236,  2.5086],
          [ 0.0431, -0.0996, -0.1361,  ..., -0.4207,  1.5377,  2.0335],
          [-0.7365,  0.3676,  0.1696,  ..., -0.0146, -0.2657, -0.2375],
          ...,
          [-1.5323,  1.0555,  0.8590,  ..., -0.1288,  0.9983, -0.4741],
          [-0.9746,  0.5715,  0.4782,  ..., -0.2807, -0.0986, -0.3467],
          [-0.9873,  0.0690,  0.1581,  ..., -0.5465, -0.3491, -0.1816]],
 
         [[ 0.4562, -1.0917, -0.8207,  ...,  0.0512, -0.6576, -2.5729],
          [ 0.1049,  0.6014, -1.1665,  ...,  0.6442, -1.0551,  0.5637],
          [ 1.9679, -0.2913,  0.3591,  ..., -0.5409,  0.8500, -0.9528],
          ...,
          [ 0.8053,  0.5475,  0.1131,  ...,  0.1434,  0.1629, -0.1028],
          [ 0.2223,  0.2819,  0.0030,  ..., -0.5282,  0.7499,  0.5875],
          [ 0.5424, -0.0802, -0.4083,  ..., -0.1604,  0.1579, -0.0927]],
 
         [[-0.6067,  1.8328,  0.2931,  ...,  1.0041,  0.8656,  0.1688],
          [-0.4175,  0.7678,