In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [45]:
# Toy input dimensions: B batch elements, T sequence positions, C values per position.
B, T, C = 4, 8, 2
# Draw x from a dedicated, seeded generator so that Restart & Run All always
# reproduces the same tensor, without modifying the global RNG state.
SEED = 1337
x = torch.randn(B, T, C, generator=torch.Generator().manual_seed(SEED))
x.shape

torch.Size([4, 8, 2])

In [46]:
# x has shape (B, T, C):
#   B indexes the batch; batching is purely a computational convenience and the
#     batch elements are processed independently of each other.
#   T indexes the position within the sequence.
#   C indexes the values stored at each position (random numbers in this toy example).
# xbow ("x bag-of-words") stores, at position t, the mean of x over positions 0..t,
# i.e. a running average along the sequence dimension.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        # Everything up to and including position t: shape (t+1, C).
        xbow[b, t] = x[b, : t + 1].mean(dim=0)

In [47]:
# Same running average expressed as a single matrix multiplication.
wei = torch.ones(T, T).tril()  # ones on and below the main diagonal, zeros above
wei = wei / wei.sum(dim=1, keepdim=True)  # normalise every row so it sums to 1
# (T, T) @ (B, T, C): PyTorch broadcasts the weight matrix over the batch dimension,
# effectively computing (B, T, T) @ (B, T, C) -> (B, T, C).
xbow2 = torch.matmul(wei, x)

In [None]:
# Same averaging weights again, built with a mask + softmax instead of a division.
tril = torch.ones(T, T).tril()  # ones on and below the main diagonal, zeros above
# Start from all-zero scores and set every position where tril is 0 (above the
# diagonal) to -inf; the entries on and below the diagonal stay 0.
wei2 = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
# Row-wise softmax: exp(0) = 1 and exp(-inf) = 0, so row t becomes uniform over its
# first t+1 entries, which is exactly the row-normalised lower-triangular matrix.
wei2 = wei2.softmax(dim=-1)
print(wei2)
xbow3 = torch.matmul(wei2, x)  # (T, T) @ (B, T, C) -> (B, T, C)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [49]:
# First batch element: the raw input, then the loop-based and matmul-based averages.
print(x[0], xbow[0], xbow2[0], sep="\n")

# allclose checks element-wise equality up to floating-point tolerance.
print(torch.allclose(xbow, xbow2))  # loop version vs. normalised-tril matmul
print(torch.allclose(xbow, xbow3))  # loop version vs. masked-softmax matmul

tensor([[-0.2726, -0.5487],
        [-0.3130, -1.5860],
        [ 0.5492, -1.8602],
        [-0.5576, -2.0523],
        [-0.1684, -0.1545],
        [-2.1030,  0.7475],
        [-0.2567,  0.5828],
        [ 0.4637, -0.6562]])
tensor([[-0.2726, -0.5487],
        [-0.2928, -1.0674],
        [-0.0122, -1.3317],
        [-0.1485, -1.5118],
        [-0.1525, -1.2404],
        [-0.4776, -0.9090],
        [-0.4460, -0.6959],
        [-0.3323, -0.6910]])
tensor([[-0.2726, -0.5487],
        [-0.2928, -1.0674],
        [-0.0122, -1.3317],
        [-0.1485, -1.5118],
        [-0.1525, -1.2404],
        [-0.4776, -0.9090],
        [-0.4460, -0.6959],
        [-0.3323, -0.6910]])
True
True
