In [2]:
import torch

In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-01-18 06:50:22--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-01-18 06:50:22 (4.95 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
print('Total number of characters: ', len(text), '\n')
print(text[:100])

Total number of characters:  1115394 

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [6]:
# Get the unique characters
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


### Tokenize
Take the raw text as a string and convert it to a sequence of integers according to a vocabulary of possible elements.

In [7]:
# Create mapping from characters to integers. This is a simple tokenizer that yields a small codebook.
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]           # encoder: takes in a string, outputs a list of integers
decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string

print(encode('hi there'))
print(decode(encode('hi there')))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


In [8]:
# Tokenize the entire dataset
data = torch.tensor(encode(text), dtype=torch.long)

print(data.shape, data.dtype)
print(data[:100]) # This is what the 100 characters we looked at earlier will look like

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


### Train/Val Split

In [9]:
n = int(0.9 * len(data)) # First 90% is train, rest is val
train_data = data[:n]
val_data = data[n:]

### Notes on Training Transformers
* We never feed the entire text into a transformer all at once. That would be very computationally expensive.
* When we actually train the transformer on these datasets, we only work with chunks of the dataset. We sample small random chunks out of the training set and train on one chunk at a time. These chunks have some maximum length, known as the block size, or context length.
* Below we see the first 9 characters in the training set. When we sample a sequence of 9 characters like the one below, we must remember that this actually has _multiple_ examples packed into it, because all of these characters _follow_ each other.
* So, when plugging it into a transformer, we will actually simultaneously train it to make predictions at every single one of these positions. This means for a chunk of `9` characters we have `8` examples packed in there.
* For example, in the context of `18`, `47` likely comes next. In the context of `18, 47`, `56` likely comes next. And so on. 

In [10]:
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [11]:
# x and y are simply off-by-1 (shifted) versions of each other
x = train_data[:block_size]
y = train_data[1: block_size + 1]
for t in range(block_size):
    context = x[:t + 1]
    target = y[t]
    print(f'When input is {context} the target is: {target}')

When input is tensor([18]) the target is: 47
When input is tensor([18, 47]) the target is: 56
When input is tensor([18, 47, 56]) the target is: 57
When input is tensor([18, 47, 56, 57]) the target is: 58
When input is tensor([18, 47, 56, 57, 58]) the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58


It is worth mentioning that here we train on examples with context lengths from 1 all the way up to the block size. We do this not simply for efficiency, but also so that the transformer gets used to seeing contexts of every length in between. This is very useful during _inference_: while sampling, we can start with as little as one character of context, and the model can keep predicting until the context reaches block size. At that point we need to start truncating, because the transformer will never receive more than block size inputs when it is predicting the next character.
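The truncation step at inference time can be sketched as follows. This is a minimal illustration, assuming a running context tensor of shape `(B, T)`; the names `idx` and `idx_cond` and the toy values are made up for this example and are not part of the code in this notebook.

```python
import torch

block_size = 8  # maximum context length, as used in this notebook

# A hypothetical running context of 11 token indices for one sequence: shape (1, 11)
idx = torch.arange(11).unsqueeze(0)

# Keep only the last `block_size` tokens before handing the context to the model
idx_cond = idx[:, -block_size:]

print(idx_cond.shape)  # torch.Size([1, 8])
print(idx_cond)        # tensor([[ 3,  4,  5,  6,  7,  8,  9, 10]])
```

If the context is shorter than `block_size`, the same slice simply returns the whole context, so no special case is needed.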

At this point we have looked at the "**time**" dimension of the data we will be feeding into the transformer. There is one more dimension that we are going to care about and that is the **batch** dimension. Batching is done entirely for efficiency, so that we can keep our GPUs busy :). Note though that each chunk is processed completely independently; the chunks do not talk to each other.

In [12]:
torch.manual_seed(1337)
batch_size = 4 # How many independent sequences will we process in parallel?
block_size = 8 # What is the maximum context length for predictions?

def get_batch(split):
    # Generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data

    # Get `batch_size` random integers in the range of 0 to len(data) - block size
    # Must subtract block size to ensure any example has a full context window
    ix = torch.randint(len(data) - block_size, (batch_size, ))

    # For each integer, that marks the start of an example. Index into data to grab
    # that integer up to integer + blocksize
    x = torch.stack([data[i:i + block_size] for i in ix])
    
    # Do the same for y, but just shift by 1 (remember this is an autoregressive model)
    y = torch.stack([data[i + 1: i + 1 + block_size] for i in ix])
    
    return x, y


xb, yb = get_batch('train')

print(f'Inputs:\n{xb.shape}\n{xb}')
print(f'Targets:\n{yb.shape}\n{yb}')

Inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
Targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [13]:
xb

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])

Given our `(4, 8)` tensor above, we can see that we actually have `32` examples! These are all completely independent (as far as the transformer is concerned). We can see all 32 independent examples printed below:

In [67]:
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, 0:t + 1]
        target = yb[b, t]
        print(f"When input is {context.tolist()} the target is: {target}")

When input is [24] the target is: 43
When input is [24, 43] the target is: 58
When input is [24, 43, 58] the target is: 5
When input is [24, 43, 58, 5] the target is: 57
When input is [24, 43, 58, 5, 57] the target is: 1
When input is [24, 43, 58, 5, 57, 1] the target is: 46
When input is [24, 43, 58, 5, 57, 1, 46] the target is: 43
When input is [24, 43, 58, 5, 57, 1, 46, 43] the target is: 39
When input is [44] the target is: 53
When input is [44, 53] the target is: 56
When input is [44, 53, 56] the target is: 1
When input is [44, 53, 56, 1] the target is: 58
When input is [44, 53, 56, 1, 58] the target is: 46
When input is [44, 53, 56, 1, 58, 46] the target is: 39
When input is [44, 53, 56, 1, 58, 46, 39] the target is: 58
When input is [44, 53, 56, 1, 58, 46, 39, 58] the target is: 1
When input is [52] the target is: 58
When input is [52, 58] the target is: 1
When input is [52, 58, 1] the target is: 58
When input is [52, 58, 1, 58] the target is: 46
When input is [52, 58, 1, 58, 46

So, this `x` will be fed into the transformer. The transformer will simultaneously process all of the examples (independently) and look up the correct integers to predict in every one of these positions in the tensor `y`. 

### 1. Simplest Implementation Possible: Bigram Language Model

In [79]:
print(xb) # Our input into the transformer

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


In this case we are predicting what comes next based on just the individual identity of a single token. For this basic model the tokens are not talking to each other and they are not seeing any context; they are just seeing themselves. For instance, `xb[0][3] = 5`. Given token `5`, you can actually make decent predictions just by knowing that you are token `5`.

A note on the embedding table below: In our `xb` example above, the first entry `xb[0][0]` is `24`. In that case `24` will be passed in and will pluck out the row at index `24` of the embedding table. Also, note that after using our embedding table our `xb` will go from being `(4, 8)` to `(4, 8, 65)`. In other words, _each entry_ of `xb` will effectively be _mapped_ to a `65` dimensional "embedding". This embedding is really just a vector holding the logits associated with the next character in the sequence we are trying to predict.
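A quick sketch of that lookup in isolation. The seed, the name `table`, and the small `(2, 3)` batch are arbitrary choices for illustration; only `nn.Embedding` itself comes from the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only to make this sketch reproducible

vocab_size = 65
table = nn.Embedding(vocab_size, vocab_size)  # one row of next-token logits per token

# A hypothetical (2, 3) batch of token indices
idx = torch.tensor([[24, 43, 58],
                    [44, 53, 56]])

logits = table(idx)
print(logits.shape)  # torch.Size([2, 3, 65])

# The lookup for token 24 is exactly the row at index 24 of the weight matrix
print(torch.equal(logits[0, 0], table.weight[24]))  # True
```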

In [114]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337) 

class BigramLanguageModel(nn.Module):
    """Bigram language model class
    
    Notes:
    - B: batch, T: time, C: channel. In this case: B: 4, T: 8, C: 65 (vocab size)
    """

    def __init__(self, vocab_size):
        super().__init__()
        
        # Each token directly reads off the "predicted" logits from a lookup table for what token comes next in the sequence
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        
        # Pluck out rows, arrange them in the shape (B, T, C), and interpret them as the logits, which are effectively
        # the scores for which character comes next in the sequence. 
        logits = self.token_embedding_table(idx) 

        if targets is None:
            loss = None
        else:
            # Cross entropy expects (B*T, C). In other words, it wants a 2 dimensional tensor as input where each row 
            # holds a prediction (logits) and it can be matched up with the correct targets. We can think of this as
            # "stretching out" our array
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)


            loss = F.cross_entropy(logits, targets) # Here we are passing in target indices, not one hot encoded probabilities

        return logits, loss

    def generate(self, idx, max_new_tokens):

        # idx is (B, T) array of indices in the current context. E.g. this is xb.
        # This function is meant to take (B, T) and generate subsequent tokens, making it 
        # (B, T+1), then (B, T+2), and so on, up to T+max_new_tokens
        for _ in range(max_new_tokens):
            # Get the predictions
            logits, _ = self.forward(idx)

            # focus only on the last time step (grab last element of time dimension, predictions for what comes next)
            logits = logits[:, -1, :] # becomes (B, C)

            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)

            # Sample from distribution. Get a single prediction for what comes next for each batch dimension
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

            # Append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        
        return idx

m = BigramLanguageModel(vocab_size)

# Remember, nn.Module has a method __call__ that points to forward, so this is a forward pass
logits, loss = m(xb, yb) 

print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)


Because we have 65 vocabulary elements, a model that predicts them uniformly at random would have a loss of roughly `-ln(1/65) = 4.17`. Our initial loss is 4.88, which is higher. This means the randomly initialized predictions are not perfectly _diffuse_: they put uneven mass on some tokens, so they carry a bit of unwarranted confidence, and we end up doing worse than a uniform guess.
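The uniform baseline can be checked numerically. This is a small sketch; the all-zero logits and the random targets are stand-ins chosen for illustration (all-zero logits give a uniform softmax, so the targets do not affect the result).

```python
import math
import torch
from torch.nn import functional as F

vocab_size = 65

# Closed form: negative log-likelihood of a uniform guess over 65 tokens
baseline = -math.log(1 / vocab_size)
print(round(baseline, 4))  # 4.1744

# Same value from cross entropy: all-zero logits produce a uniform distribution
logits = torch.zeros(32, vocab_size)
targets = torch.randint(vocab_size, (32,))
uniform_loss = F.cross_entropy(logits, targets)
print(round(uniform_loss.item(), 4))  # 4.1744
```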

Let's give `generate` a try:

In [115]:
x_sample_input = torch.zeros((1, 1), dtype=torch.long) 
x_sample_input

tensor([[0]])

In [116]:
torch.manual_seed(1337) 

generated_output = m.generate(x_sample_input, max_new_tokens=100)
print(generated_output.shape)
generated_output

torch.Size([1, 101])


tensor([[ 0, 15, 54, 64, 50, 37, 39, 31,  1, 15, 41, 64, 42, 43, 15,  5, 39, 17,
         20, 37,  6, 37, 64, 56, 64, 28, 19, 22, 33, 12, 37, 43, 46, 27, 24, 47,
          0, 15, 48,  0, 15, 64, 53, 36,  3, 49, 33, 58, 31, 55, 13, 33, 33,  5,
         64, 28, 63, 34, 22, 47, 19,  1, 20, 32, 53, 57, 58, 58, 62, 11, 39,  2,
         33, 34, 64, 63, 13, 52, 57, 64,  7, 47, 51, 57, 64, 28, 13, 49, 44, 63,
         39, 11, 10, 45, 33, 11, 30, 32, 26, 14, 19]])

In [117]:
# Convert to list
generated_output = generated_output[0].tolist()

In [118]:
# Make human readable
decode(generated_output)    


"\nCpzlYaS CczdeC'aEHY,YzrzPGJU?YehOLi\nCj\nCzoX$kUtSqAUU'zPyVJiG HTosttx;a!UVzyAnsz-imszPAkfya;:gU;RTNBG"

Yikes - this prediction is awful! But in fairness we just passed it a newline character and our model hasn't been trained at all haha. Note that the way we have written the `generate` function, we feed the _entire_ history into the model at every step, even though the bigram model only uses the current (last) character when making a prediction. It is written this way so that when we start working with more advanced models we can effectively keep the `generate` function fixed. This way we provide all of the context to the model, and the model gets to decide how much (if any) it wants to use to make its predictions.

### 1.1 Training the Bigram Model

In [128]:
# Create a pytorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3) # For very small networks we can get away with 1e-3 (larger nets would need 1e-4)

batch_size = 32

for steps in range(10000):

    # Sample a batch of data
    xb, yb = get_batch('train')

    # Evaluate loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())



2.4315860271453857


In [133]:
generated_output = m.generate(x_sample_input, max_new_tokens=500)[0].tolist()

print(decode(generated_output))


IARI he umin,
YORISer Clco? cas tht magic m bou? s-efuril win d fiouby fon, asheche MELor? as.
MED theean G sthith widangushenen, u athoonchar

Aieavece I bu pls thayolste
Thr ld wen tchty' lle, toug ouites 'Wout ath y se tharere gn.
YCH:
Whasonong mpprrde fofan'shano,
Q! by hinomoor, heatis, trmbuty, wofron o,


Wivit t,

An.
BEDUShea
IDokliseathime sethen s winoirepe g,
I borevo ncay! h cowhot'literdicow je wier tinthaloug t s spuer emmom'd thatheny-taweallove waby angeal tethexperd mes thacot


#### A Few reminders about our training process!
First just recall what a single entry of our inputs, mini batch `xb`, and their targets, `yb`, look like. `yb` is simply `xb` shifted by `1`:

In [139]:
print(xb[0])
print(yb[0])

tensor([50, 53, 59, 57,  1, 46, 53, 53])
tensor([53, 59, 57,  1, 46, 53, 53, 42])


Next, recall what is happening during our training process. We specifically: 
1. Get a batch of 32 examples. Remember an example is currently `8` consecutive characters. For each of these characters we know the target (the subsequent character). So, for each of the 8 characters we can "generate" a prediction by keying into the "embedding" table.
2. So, for our batch of `32` examples we have `8` examples for each (I know, 'example' is overloaded here). That means we really have a total of `256` examples associated with a batch.
3. We generate predictions for all `256` examples and compare them with the true targets (remember, the target is simply the token (its index, specifically) that follows our current token).
4. We now have `256` predictions (each prediction is a `65` dimensional array of logits) and targets. We compute the loss on this set of predictions and targets.
5. We then call `loss.backward()`, which computes the gradients, and `optimizer.step()`, which updates our weights accordingly.
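To make the last two steps concrete, here is a small sketch that separates what `loss.backward()` does from what `optimizer.step()` does. The random `xb`/`yb` tensors and the bare `nn.Embedding` named `table` are stand-ins for the real batch and model, assumed only for this illustration.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)  # arbitrary seed for this sketch

vocab_size, B, T = 65, 32, 8
table = nn.Embedding(vocab_size, vocab_size)
optimizer = torch.optim.AdamW(table.parameters(), lr=1e-3)

# Random stand-ins for a real (xb, yb) batch
xb = torch.randint(vocab_size, (B, T))
yb = torch.randint(vocab_size, (B, T))

logits = table(xb).view(B * T, vocab_size)  # 256 predictions, 65 logits each
loss = F.cross_entropy(logits, yb.view(B * T))

before = table.weight.detach().clone()

optimizer.zero_grad(set_to_none=True)
loss.backward()                                # fills table.weight.grad only
unchanged = torch.equal(before, table.weight)  # weights are still the same here

optimizer.step()                               # this is where the weights move
changed = not torch.equal(before, table.weight)

print(logits.shape, unchanged, changed)  # torch.Size([256, 65]) True True
```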

#### Note on data
I understand _why_ the code was written this way. However, it could have been easier to reason about if it was converted into a _tabular_ form and stored in a pandas dataframe. Or is it just that that is what I am used to? 

#### Note on embeddings
I really don't like calling this an embedding, although it is one. The reason is that it doesn't conform with the usual use of the term embedding. Here we are taking our _objects_, the tokens in our vocabulary, and coming up with a `65` dimensional representation of them. Each dimension in this representation space corresponds to one of the tokens. The _value_ in each dimension of a particular token's embedding represents _how likely_ this token is to be followed by the token associated with that dimension.

Now, as we talk through this, calling this an embedding is really not that crazy. We have taken objects, tokens, and embedded them in a continuous space ($\mathbb{R}^{65}$). Say we then have two tokens that have _similar_ embeddings (close by in Euclidean space). The interpretation here is that these two tokens tend to be followed by similar tokens!
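A toy sketch of this interpretation, using a made-up 3-token vocabulary and hand-picked logits (none of these numbers come from the trained model):

```python
import torch
from torch.nn import functional as F

# Hypothetical 3-token vocabulary; row i holds the next-token logits for token i
table = torch.tensor([[ 2.0, 0.0, -1.0],
                      [ 1.9, 0.1, -1.0],
                      [-1.0, 0.0,  2.0]])

# Row i of probs is the distribution over what follows token i
probs = F.softmax(table, dim=-1)
print(probs)

# Tokens 0 and 1 have nearby rows, so their next-token distributions are similar;
# token 2 sits far away and predicts something quite different
dist_01 = torch.dist(table[0], table[1]).item()
dist_02 = torch.dist(table[0], table[2]).item()
print(dist_01 < dist_02)  # True
```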

# 2. Transformer
### 2.1 Mathematical Trick of Self Attention
We currently have `8` tokens in each sequence of a _batch_ and they currently are _not_ talking to each other. We would like for them to talk to each other. We would like to couple them.

We want to couple them in a very specific way. The token in the 5th location should not communicate with tokens in the 6th, 7th and 8th location because those are **future** tokens in the sequence. The token in the 5th location should only talk to the tokens in the 4th, 3rd, 2nd and 1st. So information only flows from the previous context to the current time step. We cannot get any information from the future because we are about to try to predict the future. 

What is the easiest way that we can go about this? If we are at the 5th token and we would like to communicate with our past, the simplest thing we can do is take an _average_ of all the preceding elements. So we can take information from the _channels_ at the current (5th) step, as well as the 4th, 3rd, 2nd and 1st steps, and average them. That gives a feature vector that summarizes the current step (5th) in the context of its history.

Now, of course simply doing a sum or average is an _extremely weak_ form of interaction. This communication is extremely **lossy**. We have lost a ton of information about the spatial arrangement of those tokens. However, that is okay for now. We will see how we can bring that information back later!

For now we would like to calculate:
* For every single batch element independently 
* For every $t$th token in that sequence
* We would like to calculate the average of all of the vectors in all of the previous tokens and also at this token. 


In [182]:
torch.manual_seed(1337)

B, T, C = 4, 8, 2 # Batch, Time, Channels

x = torch.randn(B, T, C)
x.shape


torch.Size([4, 8, 2])

In [183]:
# Version 1: No vectorization

# We want xbow[b, t] = mean_{i <= t} x[b, i]
xbow = torch.zeros((B, T, C))  # Calling this x bag of words (since bow is a term used when averaging up things)
for b in range(B):
    for t in range(T):
        xprev = x[b, 0:t + 1] # (t + 1, C)
        xbow[b, t] = torch.mean(xprev, 0)


In [185]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [155]:
x[0][0:3].mean(axis=0)

tensor([ 0.1490, -0.3199])

In [156]:
xbow[0][0:3]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199]])

In [160]:
# Vectorize
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)

wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [165]:
# For a single example in the batch
wei @ x[0] 

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [166]:
# For all examples in the batch (PyTorch is smart with broadcasting)
wei @ x

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

PyTorch sees that the two operands do not have the same number of dimensions, so it treats the `(T, T)` matrix as if it had a leading batch dimension and broadcasts (copies) it to shape `(B, T, T)`, which is then batch-multiplied with the `(B, T, C)` tensor. So, for each element of the batch, we have a `(T, T)` matrix multiplied by a `(T, C)` matrix, which results in a `(T, C)` matrix.
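The broadcast can be made explicit to check that both forms agree. This is a small sketch; the seed and the names `out_broadcast`, `wei_batched`, and `out_explicit` are chosen just for this example.

```python
import torch

torch.manual_seed(0)  # arbitrary seed for this sketch
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)

# Implicit broadcasting: (T, T) @ (B, T, C) -> (B, T, C)
out_broadcast = wei @ x

# Explicit version: give wei a batch dimension, then use a batched matmul
wei_batched = wei.unsqueeze(0).expand(B, T, T)
out_explicit = torch.bmm(wei_batched, x)

print(out_broadcast.shape)                          # torch.Size([4, 8, 2])
print(torch.allclose(out_broadcast, out_explicit))  # True
```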

In [167]:
# Version 2: Vectorized 

wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x  # (T, T) @ (B, T, C) ----> Broadcast ----> (B, T, T) @ (B, T, C) ----> (B, T, C)                

In [168]:
torch.allclose(xbow, xbow2)

True

So the trick is that we use a _batched matrix multiply_ to perform a _weighted aggregation_, where the weights are specified in this `(T, T)` array. We are effectively doing weighted sums, where the token at the $t$th position only gets information from itself and the tokens preceding it.

We can finally rewrite this in one more way. 

In [171]:
# Version 3: Use Softmax

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))                       # Currently set by us to be 0
wei = wei.masked_fill(tril == 0, float('-inf')) # Tokens from the past cannot communicate with future
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x                                 # Aggregation 

torch.allclose(xbow, xbow3)


True

The reason that we will use this form in **self attention** is that the weights begin as `0`. You can think of this as an **interaction strength** (i.e. an **affinity**). It is telling us _how much of each token from the past do we want to aggregate (i.e. average up)_. 

So in the line:
```
wei = torch.zeros((T, T))       
```

It is currently just set by us to be `0`s. But these affinities between the tokens are not just going to be constant at `0`. They are going to be _data dependent_. These tokens are going to start looking at each other, and some tokens will find other tokens more or less interesting. Depending on what their values are, they are going to find each other interesting to different amounts (affinities).

So, the TLDR is: 
> You can do **weighted aggregations** of your past elements by using matrix multiplication with a lower triangular matrix. The elements in the lower triangular part tell you how much of each element "fuses" into this position.
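A tiny worked example of such a weighted aggregation, with hand-picked (non-uniform) weights. The numbers are arbitrary and chosen so the arithmetic is easy to follow by hand.

```python
import torch

# Lower-triangular weights: row t says how much of positions 0..t to mix in.
# Each row sums to 1, so every output row is a weighted average of the past.
wei = torch.tensor([[1.00, 0.00, 0.0],
                    [0.50, 0.50, 0.0],
                    [0.25, 0.25, 0.5]])

# Three positions (time steps), two channels each
x = torch.tensor([[ 2.0,  4.0],
                  [ 6.0,  8.0],
                  [10.0, 12.0]])

out = wei @ x
print(out.tolist())  # [[2.0, 4.0], [4.0, 6.0], [7.0, 9.0]]
```

The last row, for example, is `0.25 * [2, 4] + 0.25 * [6, 8] + 0.5 * [10, 12] = [7, 9]`.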

## Self Attention
Different tokens will find other tokens more or less interesting, and we want that to be _data dependent_. For instance, say the current token is a vowel. Then maybe it is specifically interested in the consonants in its past and it would like that information to flow to it.

So, we want to gather information from the past and do it in a data dependent way. This is the problem that **self attention** solves! 

Self attention solves this as follows:
1. Every single node (token), at each position, will emit two vectors: a **query** and a **key**. 
  * The query vector is: what am I looking for?
  * The key vector is: what do I contain?
2. The way that we get _affinities_ between the tokens in a sequence is we do a **dot product** between the keys and the queries. Each token's query is dotted with the keys of all the tokens, and those dot products become `wei`, below.
3. If the key and the query are aligned to a high degree they will interact to a high amount, and we will be able to learn more about that specific token as opposed to any other token in the sequence.

Remember: So far we have been saying that we will aggregate information that could be useful for a token, say `5`, via taking the average of the previous tokens in the sequence (e.g. the 4th, 3rd, 2nd and 1st). This average vector will be a feature vector we can use. 

What we are saying now is that the average vector will be a _weighted average_, where the weights are based on how interesting/useful a certain token is with respect to our current token. This is a _learned_ process (i.e. our network learns which tokens a specific current token should focus on!)

Also, we can think of `x` as private information to this token. So we could say: "I'm the 5th token, I have some identity. My information is kept in vector `x`. And (for the purposes of this single head) here is what I am interested in, `query(x)`, here is what I have, `key(x)`, and if you find me interesting here is what I will communicate to you, `value(x)`." So `v` is the thing that gets aggregated for the purposes of this single head.


In [172]:
# Version 4: Self Attention
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C) # Batch, time, channels

# A single head performing self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# All tokens, in all positions of the (B, T) arrangement, in parallel
# and independently, produce a key and a query. So NO communication
# has happened yet
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)

# Communication comes NOW!
# For every row, b in B, we will have a (T, T) matrix giving us the 
# affinities
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) ---> (B, T, T) 

# This is the same as version 3
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf')) # Tokens from the past cannot communicate with future
wei = F.softmax(wei, dim=-1)

v = value(x)        # (B, T, 16)
out = wei @ v       # Aggregation: (B, T, T) @ (B, T, 16) ---> (B, T, 16)

out.shape

torch.Size([4, 8, 16])

In [178]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Consider the row below. This was the 8th token. The 8th token knows what **content** it has, and it knows what **position** it is in. Now, based on that, the 8th token creates a query that says: "hey, I'm looking for this kind of stuff - I'm a vowel, in the 8th position, I'm looking for any consonants up to position 4...". Then all of the nodes get to emit keys. Maybe one of the channels could be: "I am a consonant and I am in a position up to 4". Then that key would have a high number in that specific channel. That is how a query and a key can find each other: their dot product is large, which creates a high affinity.

If they have a high affinity, then via the softmax the 8th token will end up aggregating a lot of that token's information into its own vector.

In [179]:
wei[0][-1]

tensor([0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391],
       grad_fn=<SelectBackward0>)

Now, every single batch element will contain different weights/affinities! This is because every single batch element will contain different tokens! 
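A few properties of `wei` can be checked directly. The sketch below recomputes the attention weights with the same shapes as Version 4; the flag names (`rows_sum_to_one`, `no_future`, `differs_across_batch`) are made up for this illustration.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

wei = query(x) @ key(x).transpose(-2, -1)  # (B, T, T) raw affinities
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)

# Each row is a probability distribution over the current and past positions
rows_sum_to_one = torch.allclose(wei.sum(dim=-1), torch.ones(B, T))
# Nothing above the diagonal: no weight is ever placed on future tokens
no_future = bool((wei * (1 - tril) == 0).all())
# Different batch elements contain different tokens, hence different weights
differs_across_batch = not torch.allclose(wei[0], wei[1])

print(rows_sum_to_one, no_future, differs_across_batch)  # True True True
```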

### Notes
1. Attention is a **communication mechanism**. 
    * Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights. 
    * Put another way: every node has some vector of information and it gets to aggregate information, via a weighted sum, of all the nodes that point to it. This is done in a data dependent manner. 
    * In our scenario the directed graph is very simple. The first node in a single example of our batch (i.e. 8 tokens) is pointed to only by itself. The second node is pointed to by the first node and itself. The third node is pointed to by the 1st and 2nd nodes, and itself. And so on, all the way up to the 8th node, which is pointed to by the first 7 nodes and itself. In this way our `tril` matrix effectively acts as the adjacency matrix of this directed graph.
    * In principle attention can be applied to _any arbitrary directed graph_! It is simply a communication mechanism between the nodes. 
2. There is no notion of **space**. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens. 
3. Each example across the batch dimensions is of course processed completely independently and they never "talk" to each other. 
4. In the case of language modeling we have a constraint that future tokens will never talk to past tokens. In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate (you may want this to be the case in sentiment analysis). This block we have implemented above is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
5. "self-attention" just means that the keys and values are produced from the same source as queries. In principle, however, attention is much more general than that. For instance, the queries could be based on `x`, but the keys and values could come from an entirely separate source. In "cross-attention", the queries still get produced from `x`, but the keys and values come from some other, external source (e.g. an encoder module)
6. "Scaled" attention additionally multiplies `wei` by `1/sqrt(head_size)`. This makes it so that when the inputs Q, K have unit variance, `wei` will have unit variance too, and the softmax will stay diffuse and not saturate too much.
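The effect of the scaling in note 6 can be illustrated with random unit-variance keys and queries. This is a sketch under that assumption; the seed, the example logits, and the multiplier `8` are arbitrary choices, and no exact printed values are claimed.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)  # arbitrary seed for this sketch
B, T, head_size = 4, 8, 16

# Unit-variance stand-ins for keys and queries
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)

wei_raw = q @ k.transpose(-2, -1)         # variance on the order of head_size
wei_scaled = wei_raw * head_size ** -0.5  # variance on the order of 1

print(wei_raw.var().item(), wei_scaled.var().item())

# Softmax saturates toward one-hot as the logits grow in magnitude
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
diffuse = F.softmax(logits, dim=-1)
peaked = F.softmax(logits * 8, dim=-1)
print(diffuse)  # fairly spread out
print(peaked)   # most of the mass on the largest entry
```

Without the scaling, the softmax inputs at initialization would be large enough that each token effectively attends to a single other token, which is not a good starting point for learning.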


In [180]:
encode('mynameis')

[51, 63, 52, 39, 51, 43, 47, 57]

### Other Optimizations
#### Residual Connections
The supervision that comes from the loss hops through every addition node all the way back to the input, and also forks off into the residual blocks. We can think of this as a "gradient super highway" that goes all the way to the input, unimpeded. The blocks are initialized in such a way that in the beginning they contribute very little to the output of the residual pathway. So at initialization the gradient flows directly from the loss to the input, and over time the blocks "come online" and start to contribute.
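A minimal sketch of this idea. The `ResidualBlock` class is hypothetical, and zero-initializing the branch is an assumption made here to get the cleanest possible demonstration; in practice the branch is usually small at initialization rather than exactly zero.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Hypothetical block: output = x + f(x), with f starting out as zero."""

    def __init__(self, n_embd):
        super().__init__()
        self.f = nn.Linear(n_embd, n_embd)
        # Assumption for this sketch: zero-init so the branch contributes nothing at first
        nn.init.zeros_(self.f.weight)
        nn.init.zeros_(self.f.bias)

    def forward(self, x):
        return x + self.f(x)  # skip path plus residual branch

torch.manual_seed(0)  # arbitrary seed for this sketch
x = torch.randn(4, 8, 32, requires_grad=True)
block = ResidualBlock(32)
out = block(x)

out.sum().backward()

print(torch.equal(out, x))                      # True: the block starts as the identity
print(torch.equal(x.grad, torch.ones_like(x)))  # True: gradient passes straight through
```

As the weights of `f` move away from zero during training, the branch starts to contribute, while the addition still gives the gradient a direct path back to the input.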