In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.set_default_device("mps")  # use gpu
torch.manual_seed(1337)  # seed for reproducibility

<torch._C.Generator at 0x106245f90>

In [2]:
# some stuff from last time


In [3]:
# the mathematical trick in self-attention
#   - we want the tokens to "talk to each other" (communicate information between themselves)
#   - as we process tokens, we want information from past tokens to only flow forwards to the current token,
#     but not backwards from future tokens to the current token (which wouldn't causally make sense).
#   - important idea: information from all past tokens needs to be preserved and used to predict the next token.
B, T, C = 4, 8, 2  # (batch, time, channels) or (batch, token, embedding)
x = torch.randn(B, T, C)
print(x.shape)

torch.Size([4, 8, 2])


In [4]:
# version 1: averaging past context with for loops (weakest form of aggregation) a.k.a. "bag of words"
#   - xbow1[b, t] = mean_{i <= t} x[b, i]
xbow1 = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # (t, C)
        xprev_mean = torch.mean(xprev, dim=0)  # (C,)
        xbow1[b, t] = xprev_mean

print("--- x[0] ---")
print(x[0])
print("--- xbow[0] ---")
print(xbow1[0])

--- x[0] ---
tensor([[-0.8739, -0.8078],
        [ 0.1450,  0.3556],
        [-1.1429, -0.6025],
        [-0.5363, -1.2178],
        [-2.2164,  0.6130],
        [-0.7868,  0.3607],
        [ 0.3933,  0.7113],
        [ 1.4555, -1.0590]], device='mps:0')
--- xbow[0] ---
tensor([[-0.8739, -0.8078],
        [-0.3645, -0.2261],
        [-0.6239, -0.3516],
        [-0.6020, -0.5681],
        [-0.9249, -0.3319],
        [-0.9019, -0.2165],
        [-0.7168, -0.0839],
        [-0.4453, -0.2058]], device='mps:0')


In [5]:
# version 2: faster implementation of version 1 ("bag of words") using matrix multiplication and tril matrix
#   - xbow2[b, t] = mean_{i <= t} x[b, i]
tril = torch.tril(torch.ones(T, T))  # sums the aggregated context
tril_avg = tril / torch.sum(tril, dim=1, keepdim=True)  # averages the aggregated context
print("--- tril_avg ---")
print(tril_avg)

# my impl
# xbow2 = (x.transpose(1, 2) @ tril_avg.T).transpose(1, 2)  # [(B, C, T) @ (1, T, T)]^T  -->  [(B, C, T)]^T  -->  (B, T, C)

# Karpathy impl (wow! this works??)
#   - yes! it works cuz:
#       1. `tril_avg` gets broadcast to shape (1, T, T)
#       2. only the last two dimensions of `tril_avg` and `x` get matrix multiplied
xbow2 = tril_avg @ x  # (1, T, T) @ (B, T, C)  -->  (B, T, C)

print("--- x[0] ---")
print(x[0])
print("--- xbow[0] ---")
print(xbow1[0])

--- tril_avg ---
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]],
       device='mps:0')
--- x[0] ---
tensor([[-0.8739, -0.8078],
        [ 0.1450,  0.3556],
        [-1.1429, -0.6025],
        [-0.5363, -1.2178],
        [-2.2164,  0.6130],
        [-0.7868,  0.3607],
        [ 0.3933,  0.7113],
        [ 1.4555, -1.0590]], device='mps:0')
--- xbow[0] ---
tensor([[-0.8739, -0.8078],
        [-0.3645, -0.2261],
        [-0.6239, -0.3516],
        [-0.6

In [6]:
# confirm equality of version 1 and version 2
print("xbow1 == xbow2:", torch.allclose(xbow1, xbow2))

xbow1 == xbow2: True


In [7]:
# version 3: implementation of version 2 using softmax
#   - xbow3[b, t] = mean_{i <= t} x[b, i]
tril = torch.tril(torch.ones(T, T))  # sums the aggregated context

wei_logits = torch.zeros((T, T))
wei_logits = wei_logits.masked_fill(tril == 0, float('-inf'))  # softmax(-inf) = 0
tril_avg = F.softmax(wei_logits, dim=-1)  # averages the aggregated context

xbow3 = tril_avg @ x

In [8]:
# confirm equality of version 3 with version 1 and version 2
print("xbow3 == xbow1:", torch.allclose(xbow3, xbow1))
print("xbow3 == xbow2:", torch.allclose(xbow3, xbow2))

xbow3 == xbow1: True
xbow3 == xbow2: True


In [9]:
# version 3 is a "preview" of self-attention.

# specifically, take a look at `wei` - it's very similar to the key/query matrices in attention blocks:
#   - during training of a key/query matrix in an attention block:
#       - tokens will start looking at each other.
#       - some tokens will find other tokens more or less interesting.
#           - i.e. key tokens and query tokens will develop some variable "affinity" to each other.
#           - "affinity" = a measure of how related the key token is to the query token.
#               - similar to the weight in a linear layer's weight matrix.
#   - notice that `wei` roughly represents this key/query matrix of affinities:
#       - `wei[i, j] = -inf` when i > j, which basically says that (key) tokens from the
#          future can't communicate with (query) tokens (which are in the present).
#           - simply, "the future cannot communicate with the past".
#       - `wei[i, j]` where i <= j, represents the affinity between the i-th (query) token
#         and the j-th (key) token of context.
#           - in our case, all the affinities are set to 0 because we're averaging over the context.
#               - i.e. all the connections have the same weight (affinity).
#           - when we train an attention block, these affinities will be learned by the model.
#           - softmax will take these affinities and convert them to probabilities.
#               - these probabilities represent the percent correlation between the tokens (keys and queries).
#               - note that an affinity of `-inf` will be converted to a percent correlation of `0` as we'd like.
print("--- tril_avg ---")
print(tril_avg)
print("--- wei ---")
print(wei_logits)

--- tril_avg ---
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]],
       device='mps:0')
--- wei ---
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
     

In [10]:
# version 4: self-attention! (the crux)
#   - we don't want uniform affinities between tokens (averaging won't work).
#       - Ex. if I'm a vowel, maybe I want to look at consonants in my past and I want that information to flow to me.
#   - problem that self-attention solves:
#       - I want to gather data from the past, but I want to do it in a data-dependant way.
#   - here's how self-attention solves it:
#       - every single token, at each position, will emit two vectors:
#           - query vector - what am I looking for?
#           - key vector - what do I contain?
#       - also every single token has an additional vector:
#           - value vector - if you find me "interesting", here's what I will communicate to you
#       - the way we get affinities between tokens in a sequence now, is we basically just do a "dot product" between
#         the keys and the queries.
#           - my query "dot products" with all the keys of all the other tokens.
#               - if the key and query are "aligned", they will "interact" (dot product) to a very high amount.
#               - I will get to learn more about that specific token as opposed to any other token in the sequence.
#           - this dot product "becomes" `wei` from version 3.
#       - we pass these affinities though a softmax to get a probability distribution over the prior tokens (i.e. context)
#           - now we can aggregate the context up to the current token using this probability distribution.
#               - we essentially get a weighted sum of the context tokens according to their affinity.
#               - this aggregated context will end up being much more relevant and useful than those of versions 3
#                 or below where we used flat averaging.
#           - note: this aggregation is done in the "value" embedding space for this head.
#               - "value" embedding space = the space that value vectors are defined and operate in for this head.
#               - we take the value embeddings of the context tokens and squash them into a single value embedding that
#                 represents the "aggregate" of all the context tokens (i.e. aggregated context).

B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B, T, C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
wei_logits = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T)  -->  (B, T, T)
scaled_wei_logits = wei_logits * head_size**-0.5

tril = torch.tril(torch.ones(T, T))  # sums the aggregated context
scaled_wei_logits = scaled_wei_logits.masked_fill(tril == 0, float('-inf'))  # softmax(-inf) = 0
wei_probs = F.softmax(scaled_wei_logits, dim=-1)  # averages the aggregated context

# xatt = wei_probs @ x
v = value(x)
xatt = wei_probs @ v

print("--- wei_logits[0] ---")
print(wei_logits[0])
print("--- wei_probs[0] ---")
print(wei_probs[0])

--- wei_logits[0] ---
tensor([[ 2.4787,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0791,  0.4195,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1452, -2.2417,  2.3225,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.3146, -0.9574, -0.6660, -0.9418,    -inf,    -inf,    -inf,    -inf],
        [ 3.0616,  0.9381, -0.8151,  0.1239,  0.7581,    -inf,    -inf,    -inf],
        [-1.2767, -0.7902,  1.0707,  0.4485,  2.8819,  1.7936,    -inf,    -inf],
        [ 1.6953,  0.4527, -0.9625,  1.3604, -2.5567, -2.4997,  2.0441,    -inf],
        [-0.9596, -1.8050,  2.0731, -1.1374,  0.4673,  0.9959, -0.3580, -0.9838]],
       device='mps:0', grad_fn=<SelectBackward0>)
--- wei_probs[0] ---
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4157, 0.5843, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1009, 0.0093, 0.8899, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1726, 0.2467, 0.3

In [13]:
# Notes on attention from Karpathy (the myth, the legend) with some additions by me (yours truly):
# - Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each
#   other and aggregating information with a weighted sum from all nodes that point to them, with
#   data-dependent weights.
#       - I.e. we could apply attention on any directed graph to perform arbitrary calculations.
# - There is no notion of space. Attention simply acts over a set of vectors.
#       - This is why we need to positionally encode tokens.
#       - We have a set of vectors. they communicate. If you want them to have a notion of space (positional
#         information), you have to specifically add it.
#       - Note: this is different from convolution, where positional information is implicitly "baked into"
#         the operation. this is a key difference between the two.
# - Each example across batch dimension is of course processed completely independently and never "talk"
#   to each other.
# - In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all
#   tokens to communicate.
#       - The block we've built here is called a "decoder" attention block because it has triangular masking,
#         and is usually used in autoregressive settings, like language modeling.
# - "self-attention" just means that the keys and values are produced from the same source as queries.
#       - In "cross-attention", the queries still get produced from x, but the keys and values come from some
#         other, external source (e.g. an encoder module)
# - "Scaled" attention additional divides wei by `1 / sqrt(head_size)`. This makes it so when input Q,K are
#   unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much.
#       - Without scaling, the variance will be on the order of `head_size` (so in our case 16).
#       - With scaling, the variance will be on the order of 1.
#       - Note: as the variance of the logits fed to softmax increases, the output of softmax converges to
#         the max function. This is bad because we lose the relationships between the tokens, so we generally
#         want to keep the variance of the logits from "exploding", hence the scaling term.

# Ex. softmax converging to one-hot max
print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1))
print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1)) # gets too peaky, converges to one-hot max

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872], device='mps:0')
tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000], device='mps:0')
