In [2]:
import torch
from torch.functional import F
from torch import nn

torch.manual_seed(1337)
B,T,C = 4,8,2 # ---> Batch,Time,Channels
x = torch.randn(B,T,C)
x.shape

  device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),


torch.Size([4, 8, 2])

In [3]:
# xbow: backup of words

# we want xbow[b,t] = mean_{i<=t} x[b, i]
xbow = torch.zeros((B, T, C))

for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]
        xbow[b, t] = torch.mean(xprev, 0)

In [4]:
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [5]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
b = torch.randint(0, 10, (3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [6]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [7]:
w = torch.tril(torch.ones(T, T))
w = w / w.sum(1, keepdim=True)
xbow2 = w @ x # (B, T, T) @ (B, T, C) ---> (B, T, C)

In [8]:
torch.allclose(xbow, xbow2)

False

In [30]:
tril = torch.tril(torch.ones(T,T))
w = torch.zeros(T,T)
w = w.masked_fill(tril==0, float('-inf'))
print(w)
w = F.softmax(w, dim=1)
xbow3 = w @ x

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


In [10]:
torch.allclose(xbow, xbow3)

False

In [35]:
# self-attention!
B,T,C = 4,8,32
x = torch.randn(B, T, C)

# first head self-attention
hidden_size = 16
query = nn.Linear(C, hidden_size, bias=False)
key = nn.Linear(C, hidden_size, bias=False)
value = nn.Linear(C, hidden_size, bias=False)
q = query(x)
k = key(x)
v = value(x) # for first head

w = q @ k.transpose(-2, -1)
tril = torch.tril(torch.ones(T,T))
w = w.masked_fill(tril == 0., float('-inf'))
w = F.softmax(w, dim=-1) # B, T, T
output = w @ v
output.shape

torch.Size([4, 8, 16])