# "Let's build GPT: from scratch, in code, spelled out" Andrej Karpathy lecture

My own implementation, in PyTorch, of a generative pre-trained transformer (GPT) that generates Shakespeare-like text.

Based on https://www.youtube.com/watch?v=kCc8FmEb1nY.

In [46]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import math

# Load Shakespeare text and tokenize

Load the Tiny Shakespeare dataset, a roughly 1 MB file of Shakespeare's text.

In [57]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [30]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(f'Length of dataset in characters: {len(text)}')

Length of dataset in characters: 1115394


In [31]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



Our model's tokens are individual characters (not subwords or words). How many unique characters do we have? This is the size of our 'vocabulary', and also the number of logits the model outputs for each position before the softmax.

In [32]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


Now we need to translate our input to a sequence of numbers. This is tokenization. In most LLMs tokens are sub-words, so there's a step to find those sub-words, but we're not doing that here - our vocab is just characters. Translating to a numeric sequence is then just a matter of looking up a numeric value for each char.

Remember the tradeoff here: a larger vocab (GPT-4's is about 100K tokens) encodes a given amount of text as a shorter sequence of numbers, while a smaller vocab, like ours, needs a longer sequence. That is, taking individual characters as tokens makes for a small vocab but longer sequences for the same text, which is ok to keep things simple when learning.
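
To make the tradeoff concrete, here is a minimal sketch comparing character tokens with crude whitespace-split 'word' tokens on a few lines of the excerpt printed earlier. The whitespace split is an assumption for illustration only, not a real tokenizer; actual subword tokenizers sit somewhere between the two extremes.

```python
# A few lines of the excerpt printed above, used as a small sample.
sample = (
    "We are accounted poor citizens, the patricians good. "
    "What authority surfeits on would relieve us: if they "
    "would yield us but the superfluity, while it were "
    "wholesome, we might guess they relieved us humanely; "
    "but they think we are too dear: the leanness that "
    "afflicts us, the object of our misery, is as an "
    "inventory to particularise their abundance;"
)

# Character-level: small vocabulary, long sequence.
char_tokens = list(sample)
char_vocab = sorted(set(char_tokens))

# Crude word-level split (illustrative assumption): larger vocabulary, short sequence.
word_tokens = sample.split()
word_vocab = sorted(set(word_tokens))

print(f'chars: vocab={len(char_vocab)}, sequence length={len(char_tokens)}')
print(f'words: vocab={len(word_vocab)}, sequence length={len(word_tokens)}')
```

Even on this small sample the word-level vocab is already larger than the character vocab, and it keeps growing with more text, while the character vocab stays capped at 65 for this dataset.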

In [33]:
# two hash tables to map between chars and numbers
stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }
# and funcs to map a full sequence of chars or numbers to the other side
encode = lambda s: [stoi[c] for c in s] # given an input stream, convert to a list of ints
decode = lambda l: ''.join([itos[i] for i in l]) # given a list of ints, convert to a string

print(encode('hii there'))
print(decode(encode('hii there')))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


Now tokenize the full input text data.

In [9]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [10]:
# set aside the last 10% of the data for validation
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

## Explore multiple training examples from a single sequence of text

For training, we use chunks of tokens. Karpathy calls the size of the chunks 'block size'; it's also known as 'context length', among other names.

Also, fundamentally we define the ground truth as the next token after a sequence of tokens. With a block size of eight, we take a chunk of nine tokens (block_size + 1), which gives us eight training examples: given the first char, the right answer is the second char; given the first two chars, the right answer is the third char; and so on.

In [11]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [12]:
x = train_data[:block_size]
y = train_data[1:block_size + 1]
x, y

(tensor([18, 47, 56, 57, 58,  1, 15, 47]),
 tensor([47, 56, 57, 58,  1, 15, 47, 58]))

In [14]:
for t in range(block_size):
    context = x[:t + 1]
    target = y[t]
    print(f'when input is {context} the target is: {target}')

when input is tensor([18]) the target is: 47
when input is tensor([18, 47]) the target is: 56
when input is tensor([18, 47, 56]) the target is: 57
when input is tensor([18, 47, 56, 57]) the target is: 58
when input is tensor([18, 47, 56, 57, 58]) the target is: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58


Next, add the batch dimension: for efficiency (GPU parallelism), we train on multiple chunks of text at the same time.

In [None]:
torch.manual_seed(1337)
batch_size = 4 # number of independent sequences we'll process in parallel
block_size = 8 # max context length for predictions

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,)) # get four random starting points (ints), where each has at least a full eight chars after
    x = torch.stack([data[i:i + block_size] for i in ix]) # gen a 2D tensor of size (4, 8) with each row being the eight chars for the random chunk
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix]) # gen a 2D tensor of size (4, 8) that's shifted one char to the right - each element is the target for the prev chars
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('-----')

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-----


In [21]:
# show all the train/target examples we can get from the prev (4, 8) tensors
print(f'{batch_size} batches with {block_size} tokens in each gives us {batch_size * block_size} indiv training examples')
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t + 1]
        target = yb[b, t]
        print(f'when input is {context.tolist()} the target: {target}')

4 batches with 8 tokens in each gives us 32 indiv training examples
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target: 46
when input is [44, 53, 56, 1, 58, 46] the target: 39
when input is [44, 53, 56, 1, 58, 46, 39] the target: 58
when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1
when input is [52] the target: 58
when input is [52, 58] the target: 1
when input is [52, 58, 1] the target: 58
when input is [52, 58, 1, 58] the target: 46
when input is [52, 58, 1

According to Karpathy, we feed the two (4, 8) tensors to the model and PyTorch processes all 32 input/target pairs in parallel (on a GPU, if one is available). I'm curious to see how this looks in torch code.
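
One way to picture this, as a plain-Python sketch rather than what PyTorch does internally: every (batch, time) position is its own classification problem, and the loss is the mean over all B*T of them. The three-token vocabulary, the lookup table `table`, and the `xb`/`yb` values below are made up for illustration. With a bigram-style lookup table, the logits at a position depend only on the token at that position.

```python
import math

def cross_entropy(logits, target):
    # negative log of the softmax probability assigned to the target class
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# hypothetical lookup table: row i holds the logits for the token that follows token i
table = [[0.0, 1.0, 2.0],
         [2.0, 0.0, 1.0],
         [1.0, 2.0, 0.0]]

xb = [[0, 1, 2], [2, 2, 0]]  # (B=2, T=3) inputs
yb = [[1, 2, 0], [2, 0, 1]]  # (B=2, T=3) targets, one per input position

# flatten (B, T) into B*T independent (logits, target) examples
examples = [(table[x], y) for xs, ys in zip(xb, yb) for x, y in zip(xs, ys)]
loss = sum(cross_entropy(lg, t) for lg, t in examples) / len(examples)
print(len(examples), loss)
```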

In [23]:
xb

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])

# Bigram language model

Starting super simple, try a bigram language model, which predicts the next token (here, a character) based only on the previous token. (There's more on bigram models in the makemore series.)

In [74]:
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()

        # we want to learn an embedding/vector for each token 
        # each token directly reads the logits for the next token from a lookup tbl
        # given a token, we need the probabilities for each other token, so the size is (vocab_size, vocab_size) 
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensors of integers - they're our xb and yb from above
        # since idx is (B, T) and we want an embedding vector for each element of that 2D matrix, I think that's where the 'C' comes in
        # i.e., if each embedding vector is of size 10, we get back (4, 8, 10) - I think
        # also, while conceptually I think of an embedding as a representation in N-dimensional space for a given token, here we specifically take the embedding 
        # to be the probability of every possible token given the current/a single token - so C/size of the embedding
        # vector is the same as the vocab_size - i.e., ultimately 'logits' is (4,8,65)
        logits = self.token_embedding_table(idx) # (batch, time, channel) 

        # logits is a prediction, and we want to evaluate the goodness of that prediction
        # cross entropy (aka negative log likelihood) is a good way to eval that for multi-class classification probs
        # note that it expects logits - i.e., unnormalized scores not normalized probabilities
        # conceptually, since logits is (4,8,65) (length 65 vector for each of 4*8=32 chars/numbers) it makes sense that targets is (4,8) (32 correct answers one for each char/number)
        # also, torch cross_entropy expects multi-dim input to be (B, C, T) not the (B, T, C) we have
        # so, to keep simple we'll just flatten the 4,8 2D matrix into a single 1D vector of length 4*8
        if targets is None:
            loss = None # enable calling of forward for generation w/o needing to calc loss
        else:
            B, T, C = logits.shape # size of each dimension, i.e., 4, 8, 65
            logits = logits.view(B*T, C) # matrix of 32 rows each with 65 cols, each w/ logit for a particular char/number
            targets = targets.view(B*T) # vector of length 32 each with a char/number for the correct char, I think could also use .view(-1) as the param and let torch infer shape?
            #print(logits.shape, targets.shape)
            loss = F.cross_entropy(logits, targets) # (32,65)

        return logits, loss # logits is (B*T, C) if targets were given, else (B, T, C); loss is a scalar (floating point number) or None
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context - i.e, all 4*8=32 numbers
        for _ in range(max_new_tokens):
            # note that we feed in the full context for each seq even though we only use the 
            # prev/last char of the seq, since we're starting w/ a bigram model - but in the future we'll want to use the full context, so we'll keep things complicated here for now
            # get the predictions - i.e., prob of each of 65 chars; again for all 32 values
            logits, _ = self(idx) # logits is (4,8,65)
            # focus only on the last time step - go from (4,8,65) to (4,65), 65 logits for the last char in each of the four sequences in the batch
            logits = logits[:, -1, :] # becomes (B, C), -1 again is Python slice notation for last element
            # apply softmax to get probabilities, one prob for each of the 65 chars
            probs = F.softmax(logits, dim=-1) # (B, C) - -1 says apply over the last dimension, which here (and usually) is the 'features or class' dimension
            # sample from distribution, go from (4, 65) to (4, 1) with a predicted num/char for each sequence
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to running sequence - on the first iter of this loop, append the 
            # (4,1) column of new tokens to the (4,8) matrix to make a (4,9) matrix, then (4,10), etc.
            # dim=1 means to concat along the second dimension - i.e., the time dimension here, which is why we get (B,T+1) 
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        
        return idx


m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb) # this actually calls the forward method, PyTorch impls __call__ and does stuff before/after calling forward
print(logits.shape, loss.shape, loss)

# now generate some chars using the model
# start with a single sequence and one char in that seq w/ number 0 (which is newline)
start_idx = torch.zeros((1, 1), dtype=torch.long) 

def generate_some_text(m, max_new_tokens):
    return decode(m.generate(start_idx, max_new_tokens=max_new_tokens)[0].tolist())

print(generate_some_text(m, 100))

torch.Size([256, 65]) torch.Size([]) tensor(4.7462, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [49]:
-math.log(1/65)

4.174387269895637

Since we haven't trained the model yet, we'd guess that our loss should be about -ln(1/65) ≈ 4.17, since we have 65 possibilities and a uniform guess assigns probability 1/65 to the correct output. Our loss is about 4.75, a bit higher than the expected 4.17, which just says that our predictions aren't perfectly uniform. That's not surprising given that the embedding table is randomly initialized. (Karpathy says 'our initial predictions aren't super diffuse - they have a bit of entropy'.)
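
A quick sanity check of that reasoning, as a plain-Python sketch: all-equal logits give exactly ln(65), while logits drawn from a standard normal (which is how `nn.Embedding` initializes its weights by default) give a somewhat higher average loss. The sample count and seed are arbitrary choices.

```python
import math
import random

def cross_entropy(logits, target):
    # negative log of the softmax probability assigned to the target class
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]

vocab_size = 65

# perfectly uniform prediction: every logit is the same
uniform_loss = cross_entropy([0.0] * vocab_size, target=0)

# random N(0, 1) logits with a random target, averaged over many draws
random.seed(0)
n_samples = 2000
random_loss = sum(
    cross_entropy([random.gauss(0.0, 1.0) for _ in range(vocab_size)],
                  random.randrange(vocab_size))
    for _ in range(n_samples)
) / n_samples

print(uniform_loss, math.log(vocab_size))
print(random_loss)
```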

# 35m - start of training

The optimizer object takes the parameters and, after each backward pass, uses their gradients to update them. In makemore we typically used plain stochastic gradient descent, but here we'll use AdamW, a variant of the more powerful and more commonly used Adam optimizer.
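
As a rough picture of what `optimizer.step()` does, here is a sketch of an Adam update with AdamW-style decoupled weight decay for a single scalar parameter, in plain Python. The toy loss (p - 3)^2, the learning rate, and the step count are assumptions chosen for illustration, and `adamw_step` is a hypothetical helper, not a torch function. `weight_decay` is set to 0.0 here so that the toy minimum stays at 3.

```python
import math

def adamw_step(p, grad, state, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    # one AdamW-style update of a single scalar parameter
    state['t'] += 1
    state['m'] = beta1 * state['m'] + (1 - beta1) * grad          # running mean of gradients
    state['v'] = beta2 * state['v'] + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = state['m'] / (1 - beta1 ** state['t'])                # bias correction for the zero init
    v_hat = state['v'] / (1 - beta2 ** state['t'])
    p = p - lr * weight_decay * p                                 # decoupled weight decay
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)

p = 0.0
state = {'t': 0, 'm': 0.0, 'v': 0.0}
for _ in range(500):
    grad = 2 * (p - 3.0)  # gradient of the toy loss (p - 3)^2, the role loss.backward() plays
    p = adamw_step(p, grad, state)
print(p)
```

Note that the running averages live in `state`, which is one reason a stale optimizer object and a freshly created model don't mix well.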


NOTE to future self: if in playing around the loss isn't going down, run the above code to create the model instance in m, and then don't forget to run the optimizer and training loop. I think once when I saw the loss not changing, I might have forgotten to recreate the optimizer - this is a result of keeping the structure of this notebook like Andrej's notebook vs rewriting the code to make sure it reruns when it should automatically.

In [81]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3) # pretty high LR, ok for very small networks like this (3e-4 is better normally)

In [97]:
batch_size = 32 # a lot larger than the four from above, for actual training

for steps in range(10000):
    xb, yb = get_batch('train')

    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())
print(generate_some_text(m, 100))

2.5055267810821533

DAnerery lgens w I tel I aidok I otot raleckin; y, cos st'lllig non thinde s g; st.
Tofan tan yed d 


In [98]:
print(generate_some_text(m, 300))


S hed


Y:
Thivitef ve bes.
t, arde!
INouou f beth rakeruthicol IZAneelst dar ir pou t mut orig
ORINo
O:
An w; Bend:'l engined LAthy, wes,

Wh qur henetheto d bouthele fitrttar myortoksord wileve DUCle,
To the wndetheg thatheerutube bed tise frst an to frd antil sw 'Tul astrshue eingiryous sun d har


He notes in the lecture that yeah we're getting output that's better than random, but it's obviously still not Shakespeare, and says that's because the prediction's not using all of what's come before - it's only using one char, and that's just not enough information to predict anything reasonable. So, we'll move to making a more complex model that looks at more of the context, by moving toward a transformer model rather than just a bigram model. 

At this point, Karpathy moved to using a script in bigram.py, and I'll do the same.

# "The mathematical trick in self-attention"

At about 42.5m in the video.

In [104]:
# a toy example

torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time (seq location), channels (some information)
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [None]:
x[0] # (8, 2) since we take the zeroth element of something that's (4, 8, 2)

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

So, at some fundamental level, what we want is to take into account information that exists in other tokens: not all other tokens, just the ones that came 'before' the current token in 'time'. So, if we're the fifth token (index 4), we want to take into account information from tokens 0-4, but not from tokens 5+ (at a minimum, we can't, because they won't have been generated yet).

What's a super simple way of getting information from previous tokens? Take an average of the previous token values and use it.

In [None]:
# here's an inefficient loop-based way to calc, showing shapes
# we'll do a more efficient I presume tensor-based version below

# we want x[b, t] = mean_{i <= t} x[b, i]
xbow = torch.zeros((B, T, C)) # "bow" for 'bag of words': each position holds an average over the 'words' (for us, chars/numbers) up to and including it
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1] # (t + 1, C)
        xbow[b, t] = torch.mean(xprev, 0) # mean over the zeroth (time) dimension -> 1D tensor/vector of size C (here 2)

In [101]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [102]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In the above, vector 0 of size two is the average of just that first element of that sequence, so the same; the second vector is the average of the first two vectors, etc. (element-wise).

And now, the 'trick' is that we can very efficiently do the same operation using matrix multiplication. 

In [116]:
torch.manual_seed(42)
# a = torch.ones(3, 3) # each element of c is the sum of all the cols of b
a = torch.tril(torch.ones(3, 3)) # the lower triangle of the square matrix, now each el of c is the running total (only the last row of c is the total of all vals in the col)
a = a / torch.sum(a, 1, keepdim=True) # sum over second dim keeping same # of dims, now instead of [1, 1, 0] (second row) we have [0.5, 0.5, 0] and for third row instead of [1,1,1] we have [0.333,0.333,0.333] -> this gives us the mean of the particular set of vals like we calculated w/ loops above
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print('a')
print(a)
print('-----')
print('b')
print(b)
print('-----')
print('c')
print(c)

a
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
-----
b
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
-----
c
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


So, using the technique above, we can vectorize/make more efficient the calculation of the averages of the prev+current elements, w/o loops.

In [None]:
# weights array - how much of each row we want to average up
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [None]:
# This is (T, T) @ (B, T, C). torch broadcasts wei across the batch dimension,
# treating it as (B, T, T), so we get (B, T, T) @ (B, T, C) -> (B, T, C)
# he calls this a 'batched matrix multiply': each of the B sequences gets its own
# ordinary (T, T) @ (T, C) multiply, all using the same wei
xbow2 = wei @ x 
xbow2.shape

torch.Size([4, 8, 2])
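
A plain-Python sketch of what the batched multiply computes: the same (T, T) weight matrix is applied to each of the B sequences independently. The helper `matmul` and the small sizes and values are illustrative assumptions.

```python
def matmul(a, b):
    # (n, k) @ (k, m) -> (n, m) for nested lists
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

B, T, C = 2, 3, 2
x = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],   # sequence 0: (T, C)
     [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]]]   # sequence 1: (T, C)

# lower-triangular averaging weights, (T, T): row t averages positions 0..t
wei = [[1.0 / (t + 1) if i <= t else 0.0 for i in range(T)] for t in range(T)]

# (T, T) @ (B, T, C): apply the same wei to every sequence -> (B, T, C)
out = [matmul(wei, x[b]) for b in range(B)]
print(out[0])
print(out[1])
```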

In [None]:
torch.allclose(xbow, xbow2) # via loops and via matrix mult are the same

True

And, here's a third way to get xbow, with softmax.

In [None]:
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf')) # set every element where tril is zero to -inf
wei = F.softmax(wei, dim=-1) # normalize -> softmax exponentiates each element, sums, and divides by the sum -> exp(0) is 1 and exp(-inf) is 0, so each row sums some number of 1s and divides those items by that sum, which is the same op we did above w/ xbow2
xbow3 = wei @ x

torch.allclose(xbow, xbow3)

True

The third approach above is more interesting and is what we're going to use in our implementation. That's because the entries of 'wei' become 'interaction strengths' or 'affinities' - i.e., the amount by which we want to 'listen' to each token from the past. Above we start with all zeros (saying we listen equally to every token we're allowed to listen to), with -inf saying 'tokens from the future cannot communicate with us'. What we'll do later is learn these affinities from the data. That is, 'some tokens will find other tokens more or less interesting/have less or more affinity for other tokens.' 
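
The softmax step can be checked by hand for a single row. In this plain-Python sketch, a row of affinities that starts at zero and has its future positions set to -inf turns into a uniform average over the allowed positions; unequal affinities (made-up values here) turn into an unequal weighting.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(z - m) for z in row]  # exp(-inf) is 0.0, so masked positions get zero weight
    total = sum(exps)
    return [e / total for e in exps]

neg_inf = float('-inf')

# third token in a sequence of five: positions 3 and 4 are in the future and masked out
equal_weights = softmax([0.0, 0.0, 0.0, neg_inf, neg_inf])
print(equal_weights)

# same mask, but with hypothetical non-zero affinities for the visible positions
unequal_weights = softmax([2.0, 0.0, 1.0, neg_inf, neg_inf])
print(unequal_weights)
```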

In summary, you can get weighted averages of past elements with the matrix multiplication approach above, using a lower triangular weight matrix. The values in the lower triangular portion say how much each element should matter, and these 'how much' numbers will be learned from the data. 

# "Small self-attention for a single individual head"

In [None]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32 
x = torch.randn(B, T, C)

# single self-attention head
head_size = 16
key = nn.Linear(C, head_size, bias=False)   # key vector: 'what do i contain'
query = nn.Linear(C, head_size, bias=False) # query vector: 'what am I looking for'
value = nn.Linear(C, head_size, bias=False) # what we actually aggregate - 'what i will communicate to you if i find you interesting'

# to get the affinities (wei), do a dot product between the queries and keys
# when a key and a query are aligned, the affinity will be higher for those tokens

k = key(x)   # (B, T, 16) head size is 16
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T), for every batch we have a (T, T) square matrix that's the affinities I think between the tokens

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros( (T, T) )
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v
#out = wei @ x

out.shape

torch.Size([4, 8, 16])

So, rather than starting from all zeros as we did above, which gives a plain average of the current and previous tokens, we want each token to pay attention to particular other tokens to different degrees, where the degree to which it 'listens' depends on the data. One example he gave: if a token is a vowel, it might be looking for consonants in its past, so positions holding consonants would get higher attention.
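
A toy plain-Python sketch of that idea for the last token of a three-token sequence. The query, key, and value vectors are made-up numbers standing in for the outputs of the `query`, `key`, and `value` linear layers; the point is only that the key most aligned with the query receives the most weight.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(row):
    m = max(row)
    exps = [math.exp(z - m) for z in row]
    total = sum(exps)
    return [e / total for e in exps]

# hypothetical per-token vectors (head_size = 2)
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]      # 'what do I contain'
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]  # 'what I will communicate'
query_last = [0.0, 3.0]                          # 'what am I looking for', for the last token

# the last token may attend to every position up to and including itself,
# so no masking is needed for this particular row
affinities = [dot(query_last, k) for k in keys]
weights = softmax(affinities)

# weighted sum of the value vectors, one output per channel
out = [sum(w * v[c] for w, v in zip(weights, values)) for c in range(2)]
print(affinities, weights, out)
```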

In [129]:
# now wei is not a constant, instead it's different for each sequence in the batch because each
# sequence has different tokens at diff positions - i.e., it's data-dependent
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

His specific additional notes, from the lecture:

- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across the batch dimension is of course processed completely independently; examples never "talk" to each other.
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additionally multiplies `wei` by 1/sqrt(head_size), i.e., divides it by sqrt(head_size). This makes it so that when the inputs Q, K are unit variance, wei will be unit variance too, and softmax will stay diffuse and not saturate too much. Illustration below.

In [139]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5

In [140]:
k.var()

tensor(1.0104)

In [141]:
q.var()

tensor(1.0204)

In [142]:
wei.var()

tensor(1.1053)
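
For comparison, a plain-Python sketch of the same experiment with and without the scaling factor. With unit-variance queries and keys, the raw dot product over `head_size` terms has variance close to `head_size`, and multiplying by `head_size**-0.5` brings it back to about 1. The sample count and seed are arbitrary.

```python
import random

random.seed(0)
head_size = 16
n_samples = 5000

def variance(xs):
    # unbiased sample variance
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / (len(xs) - 1)

raw = []
for _ in range(n_samples):
    q = [random.gauss(0.0, 1.0) for _ in range(head_size)]
    k = [random.gauss(0.0, 1.0) for _ in range(head_size)]
    raw.append(sum(a * b for a, b in zip(q, k)))  # unscaled query-key dot product

scaled = [r * head_size ** -0.5 for r in raw]     # the 'scaled' attention score

print(variance(raw), variance(scaled))
```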

In [None]:
# example of how 'sharpening' inputs to softmax makes the result less diffuse/more peaked
# we don't want peaked because it means we'll aggregate info from just one node vs. many 
print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1))
print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]) * 8, dim=-1))


tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


At about 1:19:00 in the video, start of updating the .py code for self-attention.