This project follows this tutorial: https://www.youtube.com/watch?v=kCc8FmEb1nY. Its goal is to practice building a simple GPT model that generates Shakespeare-like text.

**Reading the input**

As input, we will use a raw text file containing Shakespeare's dialogues.

In [1]:
# Start by downloading the file
!curl -o input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

100 1089k  100 1089k    0     0   178k      0  0:00:06  0:00:06 --:--:--  224k


In [2]:
# Read input to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print('Length of dataset in characters: ', len(text))

Length of dataset in characters:  1115394


In [4]:
# Let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
# Create a set of all chars in the text, then create a sorted list out of it
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


**Tokenize the characters**

Tokenizing means converting the raw text (a string) into a sequence of integers, according to a vocabulary of possible elements. Here the vocabulary is the set of 65 characters found above, so each character becomes one integer.

In [6]:
# Create a dictionary mapping characters to their corresponding indices.
stoi = {ch: i for i, ch in enumerate(chars)}

# Create a dictionary mapping indices to their corresponding characters.
itos = {i: ch for i, ch in enumerate(chars)}

# Encode a string into a list of integers using the character-to-index mapping.
encode = lambda s: [stoi[c] for c in s]

# Decode a list of integers into a string using the index-to-character mapping.
decode = lambda l: ''.join(itos[i] for i in l)

print(encode("Hello world"))
print(decode(encode("Hello world")))

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
Hello world


In [7]:
# Encode the entire text dataset and store it into a torch.Tensor
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

**Splitting the dataset into training and validation sets**

An important step is to split the data into train and validation sets. We will use 90% of the data as training, and the rest as validation.

In [8]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

**Data loader**

We won't feed the entire text into the transformer all at once, because that would be computationally very expensive.
Instead, we only work with chunks of the dataset, of length `block_size`.

In [9]:
block_size = 8 # Chunk size
train_data[:block_size]

tensor([18, 47, 56, 57, 58,  1, 15, 47])

In [10]:
x = train_data[:block_size] # Input to the transformer
y = train_data[1:block_size+1] # Targets: the same chunk shifted one position ahead of x
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context} the target is: {target}")

When input is tensor([18]) the target is: 47
When input is tensor([18, 47]) the target is: 56
When input is tensor([18, 47, 56]) the target is: 57
When input is tensor([18, 47, 56, 57]) the target is: 58
When input is tensor([18, 47, 56, 57, 58]) the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58


What we've just seen is useful for the transformer: each chunk contains training examples for every context length from 1 up to block_size. This way, at generation time we can start sampling with as little as one character of context, and the transformer knows how to predict the next character from any context length up to block_size. Beyond block_size we need to truncate the context, because the transformer never receives more than block_size inputs when predicting the next character.
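A minimal sketch of that truncation during generation. It assumes a model with the same `logits, loss = model(idx)` interface as the `BigramLanguageModel` defined later in this notebook; the names `TinyPositionalModel` and `generate_cropped` are hypothetical. The stand-in model adds a position embedding table with only `block_size` rows, so it genuinely cannot accept longer inputs, and the crop `idx[:, -block_size:]` is what keeps generation working past 8 tokens.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

block_size = 8   # maximum context length, as in this notebook
vocab_size = 65  # size of the character vocabulary, as in this notebook

class TinyPositionalModel(nn.Module):
    """Hypothetical stand-in model: a bigram-style lookup table plus a position
    embedding table, so it cannot handle more than block_size positions."""

    def __init__(self):
        super().__init__()
        self.token_table = nn.Embedding(vocab_size, vocab_size)
        self.position_table = nn.Embedding(block_size, vocab_size)

    def forward(self, idx, targets=None):
        _, T = idx.shape
        pos = torch.arange(T)
        # position_table raises IndexError if T > block_size
        logits = self.token_table(idx) + self.position_table(pos)  # (B,T,C)
        return logits, None

def generate_cropped(model, idx, max_new_tokens):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                     # keep only the last block_size tokens
        logits, _ = model(idx_cond)
        probs = F.softmax(logits[:, -1, :], dim=-1)         # next-token distribution, (B,C)
        idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
        idx = torch.cat((idx, idx_next), dim=1)             # the full sequence keeps growing
    return idx

torch.manual_seed(0)
demo_model = TinyPositionalModel()
demo_start = torch.zeros((1, 1), dtype=torch.long)
demo_out = generate_cropped(demo_model, demo_start, max_new_tokens=20)
print(demo_out.shape)  # torch.Size([1, 21])
```

The returned sequence keeps all 21 tokens; only the input to each forward pass is cropped.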

**Batch dimension**

For efficiency, we'll feed the transformer batches of multiple chunks of text at once; the chunks are processed in parallel and independently of each other.

In [11]:
torch.manual_seed(1337) # Set this manually to match the tutorial's samples
batch_size = 4 # How many independent sequences will we process in parallel?
block_size = 8 # What is the maximum context length for predictions?

def get_batch(split):
    # Generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    # Generate batch_size numbers between 0 (inclusive) and (len(data)-block_size) (exclusive)
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # Stack the 1-D chunks as the rows of a (batch_size, block_size) tensor
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x,y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('-----')

for b in range(batch_size): # Batch dimension
    for t in range(block_size): # Time dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When input is {context.tolist()} the target is: {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-----
When input is [24] the target is: 43
When input is [24, 43] the target is: 58
When input is [24, 43, 58] the target is: 5
When input is [24, 43, 58, 5] the target is: 57
When input is [24, 43, 58, 5, 57] the target is: 1
When input is [24, 43, 58, 5, 57, 1] the target is: 46
When input is [24, 43, 58, 5, 57, 1, 46] the target is: 43
When input is [24, 43, 58, 5, 57, 1, 46, 43] the target is: 39
When input is [44] the target is: 53
When input is [44, 53] the target is: 56
When input is [44, 53, 56] the target is: 1
When input is [44, 53, 56, 1] the target is: 58
When input is [44, 53, 56, 1, 58]

In [12]:
print(xb) # Our input to the transformer

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


**Bigram Language Model**

We start by implementing a simple neural network: a Bigram Language Model.
A Bigram Language Model is a type of statistical language model used in natural language processing and computational linguistics. It predicts the likelihood of the next token in a sequence (here, the next character) given only the previous token: it assumes that the probability of a token depends only on its immediate predecessor.
Bigram models are straightforward and computationally cheaper than more complex language models, such as higher-order n-gram models or neural language models that look at longer contexts. However, because their context is purely local, they cannot capture long-range dependencies or semantics.
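To make "depends only on its immediate predecessor" concrete, here is a minimal count-based sketch on a short hypothetical string (not the Shakespeare dataset). It estimates P(next char | current char) by counting character pairs. The neural model below learns a table of the same shape (vocab_size x vocab_size) by gradient descent instead of by counting. The `toy_` prefixes avoid overwriting the notebook's `text`, `chars` and `stoi`.

```python
import torch

toy_text = "to be or not to be"  # hypothetical toy corpus, not the Shakespeare dataset
toy_chars = sorted(set(toy_text))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
V = len(toy_chars)

# counts[i, j] = how many times character j immediately follows character i
counts = torch.zeros((V, V))
for prev, nxt in zip(toy_text, toy_text[1:]):
    counts[toy_stoi[prev], toy_stoi[nxt]] += 1

# Normalize each row into a probability distribution over the next character.
# (Every character in this toy string has at least one successor, so no row is all zeros.)
probs = counts / counts.sum(dim=1, keepdim=True)

p_o_given_t = probs[toy_stoi['t'], toy_stoi['o']].item()
print(p_o_given_t)  # 'o' follows 't' in 2 of its 3 occurrences, so about 0.667
```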

Each integer in our input indexes the embedding table and retrieves the row at that index.
E.g. integer 24 retrieves row 24.
PyTorch then arranges the result in a B x T x C (Batch x Time x Channels) tensor, where:
- Batch = 4
- Time = 8
- Channels = 65 (the vocab_size)

This tensor contains the logits (= scores) for the next character in the sequence.
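A quick shape check of the lookup described above, using random token indices as a stand-in for `xb` (an assumption: any (4, 8) integer tensor with values below 65 would do). The random calls are wrapped in `torch.random.fork_rng()` so that this extra cell does not change the random numbers drawn by the later cells.

```python
import torch
import torch.nn as nn

with torch.random.fork_rng():  # isolate the RNG so later cells keep their random streams
    torch.manual_seed(0)
    toy_table = nn.Embedding(65, 65)        # vocab_size x vocab_size, as in the bigram model
    toy_idx = torch.randint(0, 65, (4, 8))  # stand-in for a (B,T) = (4,8) batch like xb

toy_logits = toy_table(toy_idx)             # one row of 65 scores per integer
print(toy_logits.shape)                     # torch.Size([4, 8, 65])

# The scores at position [b, t] are exactly row toy_idx[b, t] of the table
print(torch.equal(toy_logits[1, 3], toy_table.weight[toy_idx[1, 3]]))  # True
```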


In [13]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of integers
        logits = self.token_embedding_table(idx) # (B,T,C) tensor, containing the logits (=scores) for the next character in the sequence

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects the class dimension C second, e.g. (N,C) or (B,C,T), so the (B,T,C) logits must be reshaped
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # Flatten to 2 dimensions: (B*T, C)
            targets = targets.view(B*T) # Flatten to 1 dimension: (B*T,)
            loss = F.cross_entropy(logits, targets) # Negative log likelihood loss: measures the quality of the logits w.r.t. the targets

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            # Get the predictions
            logits, loss = self(idx)
            # Focus only on the last time step (ignore all previous tokens: a bigram model only needs the very last one)
            logits = logits[:, -1, :] # Becomes (B,C)
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B,C)
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            # Append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb) # Pass the inputs and the targets
print(logits.shape)
print(loss)

idx = torch.zeros((1,1), dtype = torch.long) # This 1x1 tensor contains 0, which kicks off the generation. Remember that in our vocabulary 0 is the newline character '\n'
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist())) # Ask 100 tokens

torch.Size([32, 65])
tensor(5.0364, grad_fn=<NllLossBackward0>)

l-QYjt'CL?jLDuQcLzy'RIo;'KdhpV
vLixa,nswYZwLEPS'ptIZqOZJ$CA$zy-QTkeMk x.gQSFCLg!iW3fO!3DGXAqTsq3pdgq


**Train the model**

In [14]:
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [15]:
batch_size = 32 # Update batch_size to 32 instead of 4
for steps in range(10000):
    # Sample a batch of data
    xb, yb = get_batch('train')
    # Evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.5589075088500977


In [16]:
print(decode(m.generate(idx, max_new_tokens=1000)[0].tolist())) # Ask 1000 tokens


Ong h hasbe pave pirance
RDe hicomyonthar's
PES:
AKEd ith henourzincenonthioneir thondy, y heltieiengerofo'dsssit ey
KINld pe wither vouprroutherccnohathe; d!
My hind tt hinig t ouchos tes; st yo hind wotte grotonear 'so itJas
Waketancotha:
h hay.JUCLUKn prids, r loncave w hollular s O:
HIs; ht anjx?

DUThinqunt.

LaZEESTEORDY:
h l.
KEONGBUCHandspo be y,-JZNEEYowddy scace, tridesar, wnl'shenous s ls, theresseys
PlorseelapinghienHen yof GLANCHI me. strsithisgothers jveere!-e!
QUCotouciullle's fldrwertho s?
NDan'spererds cist ripl chyreer orlese;
Yo jowof h hecere ek? wferommot mowo soaf you f;
Ane his, t, f at. fal whetrimy bupof tor atha Bu!
JOutho fplimimave.
NEDUSt cir selle p wie wede
Ro n apenor f'Y tover witys an sh d w t e w!
CilttiretoaveE IINGAwe n ck. cung.
ORDUSURes hacin benqurd bll, d a r w wistatsowor ath
Fivet bloll ang aror;
ARKIOULemee tsce larry t I Ane szF t
LCay thit,
n.
Faure ds ppplirn!
Whotou ow pyofalondrwist th;thomayo war gmenco, An he waro whiougou he s imaro

**The mathematical trick in self-attention**

In [17]:
# Consider the following toy example:
torch.manual_seed(1337)
B, T, C = 4,8,2 # Batch, Time, Channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

We need to make sure that tokens only communicate with previous tokens, and not with the ones that are in the future.
E.g. token number 5 should only communicate with tokens 1, 2, 3, 4 (and itself).

In [18]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # bow = bag of words
for b in range(B): # Iterate over all batches
    for t in range(T): # Iterate over time
        xprev = x[b,:t+1] # Previous tokens, up to and including the t-th token; shape (t+1,C)
        xbow[b,t] = torch.mean(xprev, 0) # Average over dimension 0 (time), giving a C-dimensional vector

In [19]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [20]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

The first row is identical in x[0] and xbow[0], because it is the average of a single element.
The 2nd row of xbow[0] is the average of the 1st and 2nd rows of x[0].
The last row is the average of all the rows of x[0].

The nested loops are inefficient, so we use a mathematical trick based on matrix multiplication. Let's start with an example.

In [21]:
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0,10,(3,2)).float()
c = a @ b # Normal matrix multiplication
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [22]:
torch.tril(torch.ones(3,3)) # tril = lower triangular: keeps the lower triangular part of the matrix and zeroes everything above the diagonal

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

What if we apply that to the previous example?

In [23]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
b = torch.randint(0,10,(3,2)).float()
c = a @ b # Normal matrix multiplication
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In the first row, c contains exactly the first row of b.
In the 2nd row, we have (2+6) and (7+4).
In the 3rd row, we have (2+6+6) and (7+4+5).
So each row of c is the sum of the rows of b up to and including that row, i.e. a cumulative sum.
We can use the same approach to compute an incremental average:
we just need to normalize the rows of a so that each row sums to 1.
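A quick check of the claim above: multiplying by the lower-triangular matrix of ones gives the same result as a cumulative sum down the rows (`torch.cumsum`), and normalizing each row of that matrix turns the sum into a running mean. The values of `b_vals` are copied from the printed `b` above.

```python
import torch

b_vals = torch.tensor([[2., 7.],
                       [6., 4.],
                       [6., 5.]])       # the b printed above
tri = torch.tril(torch.ones(3, 3))
running_sum = tri @ b_vals
print(running_sum)                      # rows: [2, 7], [8, 11], [14, 16]
assert torch.equal(running_sum, torch.cumsum(b_vals, dim=0))

# Normalize each row of the triangular matrix so that it sums to 1
tri_avg = tri / tri.sum(dim=1, keepdim=True)
running_mean = tri_avg @ b_vals
n_averaged = torch.arange(1, 4).unsqueeze(1)  # 1, 2, 3 rows averaged in each output row
assert torch.allclose(running_mean, torch.cumsum(b_vals, dim=0) / n_averaged)
```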


In [24]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b # Normal matrix multiplication
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


This is exactly what we computed before with the two for loops, so we can use this approach to write more efficient code.

**Version 2: using matrix multiply**

In [25]:
wei = torch.tril(torch.ones(T,T)) # Wei = weights
wei = wei / wei.sum(1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [26]:
xbow2 = wei @ x # (T,T) @ (B,T,C): PyTorch broadcasts the first matrix to (B,T,T), so for each batch element a (T,T) matrix multiplies a (T,C) one, exactly like the previous for loops, and the result is (B,T,C)
torch.allclose(xbow, xbow2) # checks whether all elements of two tensors, xbow and xbow2, are close within a certain tolerance. It returns a boolean value indicating whether the condition is true for all elements or not.

True

In [27]:
xbow[0], xbow2[0]

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

**Version 3: adding SoftMax**

Softmax is also a normalization operation: softmax(z)_j = e^(z_j) / sum_k e^(z_k).
So if we apply it to rows that contain -inf, as below, we get exactly the same matrix as before.
This is because softmax exponentiates each element of the row and divides by the sum of the exponentials.
For example, on the first row the sum is:
e^0 + e^(-inf) + e^(-inf) + ... = 1 + 0 + 0 + ... = 1
while the exponentiated row is:
1 0 0 0 0 0 0 0
so dividing the row by its sum leaves it unchanged.
On the 2nd row the sum is e^0 + e^0 = 2, so we get:
0.5 0.5 0 0 0 0 0 0
and so on.

Why is this more interesting than version 2?
Because wei starts as a matrix of zeros, which we can read as "interaction strengths" or affinities between tokens.
They say how much of each past token we want to aggregate and average.
Setting the entries above the diagonal to -inf means that a token cannot aggregate information from tokens in its future.
This is a preview of self-attention: in the next versions the affinities will no longer be all zeros, but will depend on the tokens themselves.
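To illustrate why starting from a matrix of affinities is useful, here is a minimal sketch where the all-zero `wei` is replaced by arbitrary random scores (a stand-in for the data-dependent affinities that self-attention will compute). After masking and softmax, every row is still a valid set of averaging weights, and no weight falls on a future token.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
toy_T = 5
scores = torch.randn(toy_T, toy_T)                      # hypothetical non-uniform affinities
mask = torch.tril(torch.ones(toy_T, toy_T))
scores = scores.masked_fill(mask == 0, float('-inf'))  # forbid looking at the future
toy_wei = F.softmax(scores, dim=-1)
print(toy_wei)

# Each row sums to 1, and everything above the diagonal is exactly 0
assert torch.allclose(toy_wei.sum(dim=-1), torch.ones(toy_T))
assert torch.all(toy_wei.triu(diagonal=1) == 0)
```

Unlike the uniform rows of version 2, these rows weight past tokens unevenly, which is exactly what self-attention exploits.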

In [28]:
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) # Wherever tril is 0, set the corresponding element of wei to -inf
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [29]:
wei = F.softmax(wei, dim=-1) # Take a softmax on every single row.
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [30]:
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

**Version 4: self-attention**

Now it's time to make the aggregation data-dependent by using self-attention. Every token, at each position, will emit two vectors:

1. Query vector -> What am I looking for?
2. Key vector -> What do I contain?

The affinity between two tokens is the dot product between the query vector of the current token and the key vector of the other token; computing it against all the keys gives one row of affinities.
These dot products become `wei`. A third vector, the value, is what each token actually communicates: `wei` is used to take a weighted sum of the values.

We'll implement what is called a single head of self-attention.

**Notes about self-attention**

1. It is a communication mechanism. Picture a number of nodes in a directed graph, with edges pointing between them. Every node holds a vector of information and aggregates information, via a weighted sum, from all the nodes that point to it. The weights are data-dependent, i.e. they depend on the data stored in the nodes.
In our case we have 8 nodes (because T = 8): the first node is pointed to only by itself, the second node is pointed to by the first node and by itself, and so on.
2. There is no notion of space: attention acts over a set of vectors, so the nodes have no idea where they are positioned. That's why we need to encode positions explicitly, giving each token some information anchored to its position. This differs from convolution, where the filters are applied over a spatial layout by construction.
3. Elements across the batch dimension, which are independent samples, never talk to each other. So it's as if we had 4 (because B = 4) separate directed graphs, each with 8 nodes.
4. In our directed graph, future tokens do not communicate with past tokens. This constraint is not necessary in general: in many cases, for example sentiment analysis, all the nodes are allowed to talk to each other.
5. It's called "self-attention" because queries, keys and values all come from the same source, `x`. In general, attention is broader than that: queries can come from `x` while keys and values come from a separate source. This is called "cross-attention".
6. "Scaled" attention (from the original paper) additionally divides `wei` by sqrt(head_size), i.e. multiplies it by 1/sqrt(head_size). This way, when the inputs Q and K have unit variance, `wei` has unit variance too, and softmax stays diffuse instead of saturating.
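A minimal sketch of point 6, assuming unit-variance Gaussian queries and keys (random stand-ins, not outputs of trained layers; the sizes 64 x 64 are chosen only so the variance estimate is stable). Without scaling, the variance of the dot products grows roughly like `head_size`; dividing by sqrt(head_size) brings it back to about 1. The last lines show why this matters: multiplying the same scores by a large factor makes softmax concentrate on the maximum.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
n_batch, n_time, head_size = 64, 64, 16
q_demo = torch.randn(n_batch, n_time, head_size)  # unit-variance stand-in for queries
k_demo = torch.randn(n_batch, n_time, head_size)  # unit-variance stand-in for keys

wei_unscaled = q_demo @ k_demo.transpose(-2, -1)  # (B,T,T) dot products
wei_scaled = wei_unscaled * head_size ** -0.5     # divide by sqrt(head_size)

var_unscaled = wei_unscaled.var().item()
var_scaled = wei_scaled.var().item()
print(var_unscaled)  # roughly head_size, i.e. about 16
print(var_scaled)    # roughly 1

# Softmax of large-magnitude scores concentrates on the maximum (saturation)
scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
print(F.softmax(scores, dim=-1))      # fairly diffuse
print(F.softmax(scores * 8, dim=-1))  # much sharper, most weight on the 0.5 entry
```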

In [35]:
torch.manual_seed(1337)
B, T, C = 4,8,32
x = torch.randn(B,T,C)

# Let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
# Forward the linear on top of x
# All the tokens in all of the positions in the BxT arrangement, in parallel and independently, produce a key and a query. No communication has happened yet.
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)

# Communication will happen here, with the dot product
wei = q @ k.transpose(-2, -1) # Transpose the last two dimensions (the first dimension is the batch): (B,T,16) @ (B,16,T) ---> (B,T,T)

tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)
v = value(x) # v is the vector that we aggregate, instead of the raw x
out = wei @ v
out.shape

torch.Size([4, 8, 16])

In [37]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)
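The single-head computation above can be packaged as a reusable `nn.Module`. This is a minimal sketch: the class name `Head`, its constructor arguments, and registering the causal mask as a buffer are illustrative choices, and unlike the cell above it applies the 1/sqrt(head_size) scaling from note 6, so its numbers will not match the `wei[0]` printed above. The final check confirms causality: changing the last token does not change the outputs at earlier positions.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """One head of causal (masked) self-attention (illustrative sketch)."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask stored as a buffer: not a trainable parameter
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        _, T, _ = x.shape
        k = self.key(x)    # (B,T,head_size)
        q = self.query(x)  # (B,T,head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5            # scaled affinities, (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # no looking ahead
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)  # (B,T,head_size)
        return wei @ v     # (B,T,head_size)

torch.manual_seed(1337)
head_demo = Head(n_embd=32, head_size=16, block_size=8)
x_demo = torch.randn(4, 8, 32)
out_demo = head_demo(x_demo)
print(out_demo.shape)  # torch.Size([4, 8, 16])

# Causality check: modifying the last position leaves all earlier outputs unchanged
x_mod = x_demo.clone()
x_mod[:, -1, :] = 0.0
assert torch.allclose(head_demo(x_mod)[:, :-1], out_demo[:, :-1])
```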