# Transformer Explainer

The text cells explain the logic behind the key code segments that follow them and serve as notes to build intuition about the critical components of the transformer implementation.

In [None]:
import torch
import torch.nn as nn
import os

from torch.nn import functional as F

In [None]:
path = "gpt/data"
os.makedirs(path, exist_ok=True)

In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -P gpt/data

--2025-01-10 00:19:58--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘gpt/data/input.txt.4’


2025-01-10 00:19:58 (4.82 MB/s) - ‘gpt/data/input.txt.4’ saved [1115394/1115394]



In [None]:
path = os.path.join(path, 'input.txt')

In [None]:
with open(path, "r", encoding='utf-8') as f:
  text = f.read()

In [None]:
print(f"Length of dataset in characters: {len(text)}")

Length of dataset in characters: 1115394


In [None]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [None]:
# Get all the unique characters in the dataset.
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(f"Length of vocabulary: {vocab_size}")


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Length of vocabulary: 65


## Tokenization
A more detailed explanation of the tokenization strategy is covered in the `Tokenizer` repo. This is a quick summary of the relevant concepts. \
\
Tokenization is the process of breaking a string down into components and assigning each component a specific integer token. We could assign each word a token, so `"Hello"` is `1` and `"World"` is `2`. Or we could assign each character a token, so `'a'` is `1` and `'b'` is `2`. \
\
Real-world tokenizers use a **"sub-word"** tokenization strategy, so common words form a single token but uncommon words are broken up into multiple tokens. \
\
The tokenization strategy determines both the size of the vocabulary and the length of the tokenized string. For instance, the string `"Hello world"` becomes 2 tokens under word-level tokenization but 11 tokens under character-level tokenization. In exchange, the vocabulary of the word-level strategy is much larger than that of the character-level strategy. \
\
The right tokenization strategy strikes a balance between the size of the vocabulary and the length of the tokenized input. \
\
In this explainer, to keep things simple, we use a character level tokenizer.
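As a rough illustration of the trade-off, the sketch below counts tokens for `"Hello world"` under two toy strategies. It assumes a naive whitespace split stands in for word-level tokenization; `text_example`, `word_tokens` and `char_tokens` are illustrative names, not part of the notebook.

```python
# Toy comparison of word-level vs character-level tokenization.
# Assumption: a plain whitespace split stands in for a word-level tokenizer.
text_example = "Hello world"

word_tokens = text_example.split()  # ['Hello', 'world']
char_tokens = list(text_example)    # ['H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd']

# Each strategy needs one vocabulary entry per distinct token.
word_vocab = sorted(set(word_tokens))
char_vocab = sorted(set(char_tokens))

print(f"word-level: {len(word_tokens)} tokens, {len(word_vocab)} vocabulary entries")
print(f"char-level: {len(char_tokens)} tokens, {len(char_vocab)} vocabulary entries")
```

On a single short string the character vocabulary happens to be the bigger one; across a whole corpus the word vocabulary grows far larger, while the character vocabulary stays small (65 entries for this dataset).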

In [None]:
# Tokenization strategy - map each character to an integer.
stoi = { ch:i for i, ch in enumerate(chars)}
itos = { i:ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]  # Take an input string and return the character indices.
decode = lambda l: ''.join([itos[i] for i in l]) # Take a list of character indices and return the string.

print(encode("Hello, World!"))
print(decode(encode("Hello, World!")))

[20, 43, 50, 50, 53, 6, 1, 35, 53, 56, 50, 42, 2]
Hello, World!


In [None]:
# Encode the entire text as a torch.Tensor of token indices.

data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])  # Tokens of the first 1000 characters.

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [None]:
# Train and validation split.
_TRAIN_RATIO = 0.9
n = int(_TRAIN_RATIO * len(data))
train_data = data[:n]
validation_data = data[n:]

## Context Length
Training the transformer on all of the available data at once is computationally prohibitive. Instead, we randomly sample sequences of a fixed maximum length from the input data. This maximum length is referred to as the **context length** (`block_size` in the code below). \
\
Each sequence sampled from the input contains multiple training examples. Specifically, a sequence of length `N` contains `N-1` training examples, because the token at each position serves as the label for the sequence of tokens preceding it. So for the token sequence `[1, 2, 3, 4]`:

  *   `2` is the label for input `[1]`,
  *   `3` is the label for input `[1, 2]` and
  *   `4` is the label for input `[1, 2, 3]`.

\
Since no tokens precede the first token (in this case `1`), the first token is never used as a training label. \
\
Besides computational efficiency, this approach has another benefit: it trains the transformer on input sequences of varying length (1 to `block_size`). This is useful at inference time, when we want to generate text from an input of as little as 1 token. \
\
In our implementation, `block_size` represents the maximum input size for the transformer. We sample chunks of length `block_size + 1` so that the final example reaches the maximum input size.



In [None]:
_BLOCK_SIZE = 8
train_data[:_BLOCK_SIZE+1]


tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
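The chunk above holds `_BLOCK_SIZE` training examples. Here is a minimal sketch that enumerates them, with the nine tokens copied from the output above so it runs on its own (`toy_block_size`, `xs`, `ys` and `examples` are illustrative names):

```python
import torch

# The 9 tokens printed above (_BLOCK_SIZE + 1), copied so this runs standalone.
toy_block_size = 8
chunk = torch.tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])

xs = chunk[:toy_block_size]       # inputs: the first block_size tokens
ys = chunk[1:toy_block_size + 1]  # labels: the same tokens shifted left by one

examples = []
for t in range(toy_block_size):
  context = xs[:t + 1]  # every token up to and including position t
  target = ys[t]        # the token that follows that context
  examples.append((context.tolist(), target.item()))
  print(f"input {context.tolist()} -> label {target.item()}")
```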

## Batch Size
As in the rest of ML, we group the data into batches so that multiple sequences can be processed in parallel. Each sequence sampled from the input becomes one row of the batch. Our input data now has the following dimensions:
*  Batch Size (B): The number of sequences sampled from the input text and processed in parallel.
*  Block Size (T): The number of tokens in each input sequence, `block_size`. (Each sampled chunk has `block_size + 1` tokens, which are split into inputs and labels.)
*  Examples: Each sampled sequence contains `block_size` training examples, as explained in the Context Length section. \
\
This results in an input tensor of size `batch_size x block_size`. The labels have the same size: for a sampled chunk of length `N`, the inputs are positions `[0:N-1]` and the labels are positions `[1:N]`.
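To sanity-check the input/label relationship, here is a sketch using a hypothetical helper `sample_batch`, which takes the data and sizes as arguments instead of reading globals, with `torch.arange(100)` as stand-in data. It assumes start indices are drawn so that every chunk of `block_size + 1` tokens fits inside the data.

```python
import torch

def sample_batch(data, batch_size, block_size):
  """Hypothetical helper: same idea as get_batch, with explicit arguments."""
  # Start indices are drawn so that each chunk of block_size + 1 tokens fits.
  idx = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i + block_size] for i in idx])          # (B, T) inputs
  y = torch.stack([data[i + 1:i + block_size + 1] for i in idx])  # (B, T) labels
  return x, y

toy_data = torch.arange(100)  # stand-in for the encoded text
xb_toy, yb_toy = sample_batch(toy_data, batch_size=4, block_size=8)
print(xb_toy.shape, yb_toy.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
# Each label row is its input row shifted left by one position.
print(torch.equal(xb_toy[:, 1:], yb_toy[:, :-1]))  # True
```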

In [None]:
torch.manual_seed(1337)
_BATCH_SIZE = 4
_BLOCK_SIZE = 8

def get_batch(split):
  '''Returns a batch of inputs x and target y.'''
  data = train_data if split == "train" else validation_data
  idx = torch.randint(len(data) - _BLOCK_SIZE - 2, (_BATCH_SIZE,))
  x = torch.stack([data[i:i+_BLOCK_SIZE] for i in idx])
  y = torch.stack([data[i+1:i+_BLOCK_SIZE+1] for i in idx])
  return x, y

In [None]:
xb, yb = get_batch("train")
print('Inputs:')
print(xb.shape)
print(xb)
print('Targets:')
print(yb.shape)
print(yb)

Inputs:
torch.Size([4, 8])
tensor([[56, 41, 46,  1, 53,  5,  1, 58],
        [14, 33, 15, 23, 21, 26, 19, 20],
        [58, 46, 43, 56, 10,  0, 14, 59],
        [47, 60, 43,  1, 51, 43,  1, 58]])
Targets:
torch.Size([4, 8])
tensor([[41, 46,  1, 53,  5,  1, 58, 46],
        [33, 15, 23, 21, 26, 19, 20, 13],
        [46, 43, 56, 10,  0, 14, 59, 58],
        [60, 43,  1, 51, 43,  1, 58, 46]])


## Embedding Table
Embedding tables are a common approach in NLP (and other ML fields) to reduce dimensionality by mapping higher-dimensional spaces to lower-dimensional spaces. *Dimension* here just means the number of elements in the feature vector. \
\
For instance, if we were to use one-hot encoding to represent a movie in the entire universe of movies, the feature vector would have millions of elements. However, we could instead represent each movie as a 3-dimensional vector of [`runtime_in_minutes`, `imdb_rating`, `genre`], which is a far more compact representation. \
\
Refer to [this ~2-min video by Google](https://www.youtube.com/watch?v=my5wFNQpFO0&t=1s) for a more detailed explanation of embeddings. \
\
In our example, each token is represented by an embedding vector of size `vocab_size`. While this is the same dimension as one-hot encoding the tokens, it lets us use the embedding directly as the model's prediction: each row holds the logits (unnormalized log-probabilities) for the next token, and a softmax turns the value at each position into the probability of the corresponding token appearing next. In later sections, once the output of the embedding table is no longer used as the final distribution, this size will change.
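One way to see that an embedding lookup is a cheaper version of one-hot encoding: multiplying a one-hot vector by the table's weight matrix selects the same row. A sketch assuming a freshly initialized `nn.Embedding` (`toy_vocab_size`, `table` and the other names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

toy_vocab_size = 65
table = nn.Embedding(toy_vocab_size, toy_vocab_size)  # (vocab_size, vocab_size) weights

tokens = torch.tensor([[20, 43, 50]])  # (B=1, T=3) token indices for "Hel"
looked_up = table(tokens)              # (1, 3, 65): one row of the table per token

# Same result via one-hot encoding followed by a matrix multiply.
one_hot = F.one_hot(tokens, num_classes=toy_vocab_size).float()  # (1, 3, 65)
via_matmul = one_hot @ table.weight                              # (1, 3, 65)
print(torch.allclose(looked_up, via_matmul))  # True

# Reading each row as logits: a softmax gives a next-token distribution.
probs = F.softmax(looked_up, dim=-1)
print(probs.sum(dim=-1))  # every row sums to 1
```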

### Dimensions
With each token being represented by its embedding, we've added a new dimension to our data:
*  Channels (C): The size of each token's embedding vector, here `vocab_size`.


In [None]:
class BigramLanguageModel(nn.Module):

  def __init__(self, vocab_size):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets=None):
    # Next token prediction for token idx.
    logits = self.token_embedding_table(idx)

    if targets is None:
      loss = None
    else:
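      # Compute the loss: cross-entropy, i.e. the negative log-likelihood of the targets.
      # Expected loss before training is -ln(1/65) ≈ 4.17, since an untrained model
      # assigns the correct token a probability of roughly 1/65.
      # F.cross_entropy expects (N, C) logits, so flatten the batch and time dims.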
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  def generate(self, idx, max_new_tokens):
    """Generate the next token given the context idx.

    For the BigramModel, this generation only needs the last character, but this
    function uses the entire history - i.e. all the concatenated inputs. This
    will be kept constant and used in GPT when we use all the history.
    """
    for _ in range(max_new_tokens):
      # Get predictions.
      logits, loss = self(idx)  # Calls self.forward(); targets defaults to None.

      # Get the output of the last time-step.
      logits = logits[:, -1, :]  # becomes (B, C)

      # Softmax to get probabilities.
      probs = F.softmax(logits, dim=-1) # Still (B, C)

      # Generate a sample from the distribution of tokens.
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

      # Append prediction to index.
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

## Sample
We can use the untrained model created above to generate some text and compute the cross-entropy loss. Since it hasn't been trained yet, we expect it to produce gibberish, and we can also anticipate the loss: a uniform random prediction over 65 possibilities gives $-\ln(1/65) \approx 4.17$.
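The expected starting loss can be checked directly. This sketch computes $-\ln(1/65)$ and confirms that all-zero (uniform) logits give the same cross-entropy for arbitrarily chosen labels. The measured untrained loss below is somewhat higher, presumably because randomly initialized logits are not perfectly uniform.

```python
import math
import torch
import torch.nn.functional as F

num_classes = 65
expected_loss = -math.log(1 / num_classes)  # equals ln(65)
print(round(expected_loss, 4))  # 4.1744

# All-zero logits are a uniform prediction; cross-entropy matches for any labels.
uniform_logits = torch.zeros(4, num_classes)
toy_targets = torch.tensor([0, 10, 20, 64])  # arbitrary labels
print(F.cross_entropy(uniform_logits, toy_targets).item())
```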


In [None]:
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
generated_text = decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist())
print(f"Loss: {loss.item()}")
print(f"Generated Text:{generated_text}")

Loss: 4.404358386993408
Generated Text:
l-QYjt'CL?jLDuQcLzy'RIo;'KdhpV
vLixa,nswYZwLEPS'ptIZqOZJ$CA$zy-QTkeMk x.gQSFCLg!iW3fO!3DGXAqTsq3pdgq


In [None]:
# Create an optimizer.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

_BATCH_SIZE = 32
for steps in range(10000):
  # Sample a batch of data.
  xb, yb = get_batch("train")

  # Evaluate the loss.
  logits, loss = m(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
  if steps % 1000 == 0:
    print(f"Loss: {loss.item()}")

Loss: 4.581765174865723
Loss: 3.638050079345703
Loss: 3.1831507682800293
Loss: 2.703016996383667
Loss: 2.612287759780884
Loss: 2.675128221511841
Loss: 2.5464842319488525
Loss: 2.5170984268188477
Loss: 2.445920944213867
Loss: 2.4285647869110107


## Sample
Here's the loss and a sample of the generated text after training. \
\
**NOTE:** Here we're only using a bigram model, which only looks at pairs of characters. So while the input spans the entire context, only the last token is used to predict the next token.
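A quick way to confirm that a bigram model ignores the history: two different contexts that end in the same token produce identical logits at the final position. This sketch uses a standalone `nn.Embedding` with the same shape as the table in `BigramLanguageModel`; `bigram_table`, `ctx_a` and `ctx_b` are illustrative names.

```python
import torch
import torch.nn as nn

toy_vocab = 65
bigram_table = nn.Embedding(toy_vocab, toy_vocab)  # same shape as BigramLanguageModel's table

ctx_a = torch.tensor([[1, 2, 3, 40]])
ctx_b = torch.tensor([[9, 9, 9, 40]])  # different history, same last token

# Logits used to predict the token that follows the last position.
logits_a = bigram_table(ctx_a)[:, -1, :]
logits_b = bigram_table(ctx_b)[:, -1, :]
print(torch.equal(logits_a, logits_b))  # True: earlier tokens have no effect
```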

In [None]:
generated_text = decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist())
print(f"Loss: {loss.item()}")
print(f"Generated Text: {generated_text}")

Loss: 2.455496072769165
Generated Text: 
Ong h hasbe pave pirance
RDe hicomyonthar's
PES:
AKEd ith henourzincenonthioneir thondy, y heltieiengerofo'dsssit ey
KINld pe wither vouplloutherccnohathe; d!
My hind tt hinig t ouchos tes; st yo hind wotte grotonear 'so itJas
Waketancotha:
h haybet--s n prids, r loncave w hollular s O:
HIs; ht anjx?

DUThinqunt.

LaZEESTEORDY:
h l.
KEONGBUCHandspo be y,-JZNEEYowddy scar t tridesar, wyonthenous s ls, theresseys
PlorseelapinghienHen yof GLUCEN t l-h!E:
I hisgothers je are!-e!
QUCotouciullle's fld


## Beyond Bigram
Until this point, we've looked only at the current token to predict the next one and ignored all of the remaining tokens in the context. We now want to use **all** of the tokens in our context to predict the next token. \
*  First, we need to ensure that for each prediction, we're only using tokens that appear *before* the target token in the context. We shouldn't be able to look at the future tokens because we're trying to predict those.
*  Second, we need a way to aggregate the information from all of the prior tokens. For now, we'll use the average of the embedding vectors up to and including the current position. This is not ideal but is good for illustration. \


We've seen that our input currently has dimensions `B x T x C`. For each sequence in the batch, the vector at position `t` will now be the average of the embedding vectors at positions `0` through `t`. This takes a `(t+1) x C` slice for each position `t` and reduces it to a single `C`-dimensional vector, so the output keeps the shape `B x T x C`.

In [None]:
torch.manual_seed(1337)

B, T, C = 4, 8, 2  # Batch, Time and Channels.

x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [None]:
# If we want x[b, t] = mean_{i<=t} x[b, i]

xbow = torch.zeros((B, T, C))  # Bag of words since we're averaging everything.
for b in range(B):
  for t in range(T):
    xprev = x[b,:t+1]  # (t+1, C)
    xbow[b, t] = torch.mean(xprev, 0)

print("x[0]:\n")
print(x[0])
print(xbow[0])

x[0]:

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


## Vectorization
In the code snippet above, we aggregated the tokens using two nested loops, which is very inefficient. As an analogy, to compute the dot product of two lists of numbers, we could loop over the elements, or we could convert the lists into 1-D PyTorch tensors and use PyTorch's vector operations. Linear algebra operations in frameworks like PyTorch and TensorFlow are highly optimized, making them significantly more compute-efficient than nested loops in Python. So, as much as possible, we want to express our implementation in terms of linear algebra operations. This process is called vectorization. In practice, many of the inefficiencies in training implementations stem from operations that aren't properly vectorized. \
\
To vectorize the above implementation
*  We multiply the inputs by a lower-triangular matrix of `1's`. This ensures that each position only sums over the current and preceding tokens.
*  We then compute the average at each position by dividing each sum by the number of tokens in it (the row sum of the triangular matrix).
*  The above two operations can be merged - instead of a lower triangular matrix of `1's`, we can use a lower-triangular matrix of fractions such that each row adds up to `1`. In the code below, this is represented by the `weights` matrix.
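The three steps above can be checked against each other and against an explicit loop. A sketch on small random data (`Bt`, `Tt`, `Ct` and `xs` are toy names chosen to avoid overwriting the notebook's `B`, `T`, `C` and `x`):

```python
import torch

Bt, Tt, Ct = 2, 5, 3  # toy batch, time and channel sizes
xs = torch.randn(Bt, Tt, Ct)

ones_tril = torch.tril(torch.ones(Tt, Tt))  # step 1: lower-triangular 1's
running_sums = ones_tril @ xs                # (B, T, C): sum over positions 0..t
counts = ones_tril.sum(1, keepdim=True)      # (T, 1): number of tokens in each sum
two_step = running_sums / counts             # step 2: divide to get averages

one_step = (ones_tril / counts) @ xs         # step 3: rows already sum to 1

# Reference: the explicit per-position average.
loop_avg = torch.stack([xs[:, :t + 1].mean(dim=1) for t in range(Tt)], dim=1)
print(torch.allclose(two_step, one_step, atol=1e-6),
      torch.allclose(one_step, loop_avg, atol=1e-6))  # True True
```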

In [None]:
weights = torch.tril(torch.ones(T, T))            # (T, T) lower-triangular 1's
weights = weights / weights.sum(1, keepdim=True)  # each row now sums to 1
# (T, T) @ (B, T, C): weights is broadcast across the batch, so each batch element
# gets a (T, T) @ (T, C) product, giving (B, T, C).
xbow2 = weights @ x
torch.allclose(xbow, xbow2)  # Should match the nested-loop result.

True

## Non-Static Weights
Vectorizing solved the inefficient aggregation over the prior tokens; however, the weights themselves are still static (a uniform average). In the transformer, we want the token at each position to *learn* the weights for the preceding tokens, so that different tokens can weigh the past tokens differently when aggregating them. To do this, we use a different implementation to create the same matrix, using the **Softmax** function. \
\
Every position starts with a weight of 0, and positions that appear *after* the current token are set to `-inf`. We then take a Softmax over each row, $\mathrm{softmax}(w)_i = \frac{e^{w_i}}{\sum_j e^{w_j}}$. The `-inf` entries become 0 (since $e^{-\infty} = 0$), and each of the remaining entries contributes $e^0 = 1$, so in row $t$ (counting from 0) they are normalized to $\frac{1}{t+1}$. This reproduces the averaging matrix from before and forms the framework for learned weights later on.

In [None]:
# 2nd implementation: use Softmax.
# Softmax also normalizes each row: -inf entries become zero and equal entries
# get equal probabilities.
# Softmax is preferred because the zeros can later be replaced by data-dependent
# scores, letting each token learn how strongly to weight each previous token.
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
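The point of the Softmax formulation is that the zeros can be swapped for arbitrary scores while the mask still blocks future tokens. A sketch with random scores standing in for the learned affinities introduced later (`scores` and `attn` are illustrative names):

```python
import torch
import torch.nn.functional as F

Tn = 4
tril_mask = torch.tril(torch.ones(Tn, Tn))

# Random scores stand in for the learned, data-dependent affinities.
scores = torch.randn(Tn, Tn)
scores = scores.masked_fill(tril_mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
print(attn)

# Future positions still get exactly zero weight and each row still sums to 1,
# but the weights within a row are no longer uniform.
print(bool((attn.triu(diagonal=1) == 0).all()))          # True
print(torch.allclose(attn.sum(dim=-1), torch.ones(Tn)))  # True
```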

## Building the Transformer
### Adding a head
So far, we've used the token index to retrieve a corresponding embedding, and we've set the embedding dimension to the size of the vocabulary so that the embedding could directly encode the next-token distribution.\
\
We're now adding another layer to our model: a linear layer that we call a *head* (`lm_head` in the code). We now use the output of this head to encode the next-token probabilities, and its inputs are the token embeddings retrieved from the embedding table. Since the embedding is no longer used as a probability distribution, we can now choose a smaller embedding dimension, as hinted at in the embeddings section before.

### Positional Embeddings
By aggregating the preceding tokens, we can look at the entire context available to the transformer. However, notice that the embedding for each token depends only on the token's value. The positions of the preceding tokens are not yet encoded. \
\
To change this, we add positional embeddings to our token embeddings. Each position is an index from `0` to `block_size - 1` and has its own learned embedding of the same dimension as the token embedding, so the two can be added together.
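A sketch of how the two embeddings combine, assuming the sizes used in this notebook (65 characters, a block size of 8 and 32 embedding dimensions); the table names are illustrative. The `(T, C)` positional embedding broadcasts across the batch, and the same token at two positions ends up with different vectors.

```python
import torch
import torch.nn as nn

toy_vocab, toy_block, toy_embd = 65, 8, 32
tok_table = nn.Embedding(toy_vocab, toy_embd)  # token index -> vector
pos_table = nn.Embedding(toy_block, toy_embd)  # position index -> vector

idx_toy = torch.randint(toy_vocab, (4, toy_block))  # (B, T) token indices
tok = tok_table(idx_toy)                            # (B, T, C)
pos = pos_table(torch.arange(toy_block))            # (T, C), shared by every sequence
combined = tok + pos                                # broadcasts to (B, T, C)
print(combined.shape)  # torch.Size([4, 8, 32])

# The same token at positions 0 and 1 now maps to different vectors.
same_token = torch.tensor([[7, 7]])
vecs = tok_table(same_token) + pos_table(torch.arange(2))  # (1, 2, C)
print(torch.equal(vecs[0, 0], vecs[0, 1]))  # False
```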

In [None]:
n_embd = 32
block_size = _BLOCK_SIZE

In [None]:
class LanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    # Token Embedding.
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    # Positional Embedding.
    self.positional_embedding_table = nn.Embedding(block_size, n_embd)
    # Output Head.
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    B, T = idx.shape

    # Next token prediction for token idx.
    token_embedding = self.token_embedding_table(idx)  # (B, T, C)
    pos_embedding = self.positional_embedding_table(torch.arange(T))  # (T, C)
    x = token_embedding + pos_embedding  # (B, T, C); pos_embedding broadcasts over B
    logits = self.lm_head(x)  # (B, T, vocab_size)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  def generate(self, idx, max_new_tokens):
    """Generate the next token given the context idx."""
    for _ in range(max_new_tokens):
      # Crop idx to the last block_size tokens, since the positional embedding
      # table only has entries for positions 0 to block_size - 1.
      idx_crop = idx[:, -block_size:]
      # Get predictions.
      logits, loss = self(idx_crop)  # Calls self.forward(); targets defaults to None.

      # Get the output of the last time-step.
      logits = logits[:, -1, :]  # becomes (B, C)

      # Softmax to get probabilities.
      probs = F.softmax(logits, dim=-1) # Still (B, C)

      # Generate a sample from the distribution of tokens.
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

      # Append prediction to index.
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

In [None]:
x2, y2 = get_batch("train")
m2 = LanguageModel()
logits, loss = m2(x2, y2)
print(f"Loss: {loss.item()}")
decode(m2.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist())

Loss: 4.323673248291016


"\nGqTIFOtZYh!nNNB33BqJsGsF?-xAIpfcePfyjhYGv.pl'oevfBNZ. qYvj'Y?eC,pVbPM$U,ISeX'ApBXU?z?w'wF\n!fv$3Ztw,H"

In [None]:
_BATCH_SIZE = 32
optimizer = torch.optim.AdamW(m2.parameters(), lr=1e-3)
for steps in range(10000):
  # Sample a batch of data.
  x2, y2 = get_batch("train")

  # Evaluate the loss.
  logits, loss = m2(x2, y2)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
  if steps % 1000 == 0:
    print(f"Loss: {loss.item()}")


Loss: 4.242744445800781
Loss: 2.469791889190674
Loss: 2.4862873554229736
Loss: 2.465446949005127
Loss: 2.400810718536377
Loss: 2.5412375926971436
Loss: 2.5605099201202393
Loss: 2.331833600997925
Loss: 2.573298454284668
Loss: 2.5379536151885986


In [None]:
x2, y2 = get_batch("train")
logits, loss = m2(x2, y2)
print(f"Loss: {loss.item()}")
decode(m2.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist())

Loss: 2.4701757431030273


'\nLOULI f iobettesotrits, the sthy tem ardors hthande che thel shithelthave ureest s, stem INLEENRO:\nS'

# Self-Attention intuition
Now that we have positional embeddings for our tokens, we want to combine the tokens in such a way that any previous tokens that are more relevant to the current token are emphasized and those that aren't relevant are penalized.

## Solution with self-attention
In self-attention, the token at each position generates two vectors: a key and a query. Intuitively, the key represents what the token at this position contains. The query represents "What information would be most relevant to this token?"

Then, the dot product between the query of the current position and the keys of all previous tokens is large when the relevance is high, and small (or negative) otherwise. After masking and a softmax, these dot products become the weights we use when averaging the previous tokens.
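A toy numeric illustration of the dot-product affinities, using hand-picked 2-D vectors rather than learned ones (all names are illustrative):

```python
import torch
import torch.nn.functional as F

q_vec = torch.tensor([1.0, 0.0])     # what the current token is looking for
k_mat = torch.tensor([[0.9, 0.1],    # previous token 0: closely matches the query
                      [0.0, 1.0],    # previous token 1: unrelated
                      [-1.0, 0.0]])  # previous token 2: points the opposite way

affinities = k_mat @ q_vec           # one dot product per previous token: 0.9, 0.0, -1.0
attn_toy = F.softmax(affinities, dim=-1)
print(attn_toy)  # largest weight on token 0, smallest on token 2
```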

## Intuitive role of Q, K and V
For any token, `X` is the information private to the token: it is what the token _really_ is. `Q` is a representation of what this token is looking for, learned from data. `K` is a representation of how the token responds to queries from other tokens (for instance, "I'm a noun"). `V` is what this token provides to other tokens whose `Q·K` affinity with it is high.

In [None]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Single head with self-attention.
head_size = 16
key = nn.Linear(C, head_size, bias=False)  # Bias-free linear layer: a learned matrix multiply.
query = nn.Linear(C, head_size, bias=False)
# Value vector is the one that actually represents the token and this is what we
# aggregate when computing the embedding for the current token.
value = nn.Linear(C, head_size, bias=False)
# For each batch in B, the (T, C) matrix gets multiplied by a (C, head_size)
# matrix, producing a (B, T, head_size) result.
k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)

# Transpose the last (head_size) and second-to-last (T) dims, leaving the batch
# dim untouched. Keys become (B, head_size, T) and the product is (B, T, T):
# each position gets a weight for every other position. Scaling by
# head_size**-0.5 keeps the scores' variance near 1 (see the notes further down).
weights = q @ k.transpose(-2, -1) * head_size**-0.5

tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))  # no attending to future tokens
weights = F.softmax(weights, dim=-1)                     # (B, T, T), rows sum to 1

v = value(x)          # (B, T, head_size)
output = weights @ v  # (B, T, head_size)

In [None]:
class Head(nn.Module):
  """Single self-attention head."""

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)  # Setting bias=False makes this a plain matrix multiply.
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x):
    B, T, C = x.shape
    k = self.key(x)    # (B, T, head_size)
    q = self.query(x)  # (B, T, head_size)
    # Compute attention scores ("affinities"), scaled by 1/sqrt(head_size) as in the paper.
    weights = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, T)
    weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
    weights = F.softmax(weights, dim=-1)  # (B, T, T)
    # Perform the weighted aggregation of the values.
    v = self.value(x)  # (B, T, head_size)
    out = weights @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
    return out



In [None]:
class SelfAttentionLanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    # Token Embedding.
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    # Positional Embedding.
    self.positional_embedding_table = nn.Embedding(block_size, n_embd)
    # Self Attention Head
    self.sa_head = Head(n_embd)
    # Output Head.
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    B, T = idx.shape

    # Next token prediction for token idx.
    token_embedding = self.token_embedding_table(idx)  # (B, T, C)
    pos_embedding = self.positional_embedding_table(torch.arange(T))  # (T, C)
    x = token_embedding + pos_embedding  # (B, T, C)
    x = self.sa_head(x)  # Apply self-attention: (B, T, C), since head_size = n_embd.
    logits = self.lm_head(x)  # (B, T, vocab_size)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  def generate(self, idx, max_new_tokens):
    """Generate the next token given the context idx."""
    for _ in range(max_new_tokens):
      # Crop idx to the last block_size tokens, since the positional embedding
      # table only has entries for positions 0 to block_size - 1.
      idx_crop = idx[:, -block_size:]
      # Get predictions.
      logits, loss = self(idx_crop)  # Calls self.forward(); targets defaults to None.

      # Get the output of the last time-step.
      logits = logits[:, -1, :]  # becomes (B, C)

      # Softmax to get probabilities.
      probs = F.softmax(logits, dim=-1) # Still (B, C)

      # Generate a sample from the distribution of tokens.
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

      # Append prediction to index.
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

In [None]:
x3, y3 = get_batch("train")
m3 = SelfAttentionLanguageModel()
logits, loss = m3(x3, y3)
print(f"Loss: {loss.item()}")
decode(m3.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist())

Loss: 4.424878120422363


"\neDQ3QAKdPobjyHOG$MP;vO-uJTnOKhjxqaaRz.ypk3OFbY'LjDk3&QuN:qR.NkfIODK-KWz$q&IC?BysmjXxBnxQRdLKhEM ?U?-"

In [None]:
_BATCH_SIZE = 32
optimizer = torch.optim.AdamW(m3.parameters(), lr=1e-3)
for steps in range(10000):
  # Sample a batch of data.
  x3, y3 = get_batch("train")

  # Evaluate the loss.
  logits, loss = m3(x3, y3)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
  if steps % 1000 == 0:
    print(f"Loss: {loss.item()}")

Loss: 4.418106555938721
Loss: 2.7019495964050293
Loss: 2.554936647415161
Loss: 2.6280925273895264
Loss: 2.4184787273406982
Loss: 2.428046464920044
Loss: 2.3833444118499756
Loss: 2.4863102436065674
Loss: 2.427128791809082
Loss: 2.4417405128479004


In [None]:
x3, y3 = get_batch("train")
logits, loss = m3(x3, y3)
print(f"Loss: {loss.item()}")
print(
    decode(m3.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=1000)[0].tolist())
)

Loss: 2.4954659938812256

CI cy agist hy hatas,'s 'k touther
Th,
Varo sprer de I ighiody richip spre!
S:
Whateld weth I it ing tl Talouthaitid t alyop I cos winourngond pld he ang'larenncord amer, Go k'sthe r
Thesele; wh st, yotinoowind h t Wiowe sose we h fund y,
Wilsthent;

BRor:
E:
I:
O Byay hal rstye tound d teaie,sh urewimatoornen alfrdthands, d wr t mou is pin at wames ngey th, win.
Jun my f has ifur.
OFo wanghele hath t lldoww II o hiedes?

You s, hinton d, bu whndes d ks
Thig y t I asisemowakealoor.

Yove roy s befathowentid timenousthis pu is ber,
I mer ker'de:
Bee CHUSHe, omesenyorrif ULADWI mo l d il ik; ash, kn h an phengnglosothin spullthitcks kn, rde nde.
I g irerllle sioutave anchanind ar withe d othat n m avecaghrle hind am ld whug GUS:
Fowneatoraiy omolodeack beronon he butous breitothablt,
O:
aty or d fe sengund o f f t, Vareaye.




Asosirf h's:
CULIO:
Thonothene tlofu ouprolldund'den ffolth. cher
CEvempr wansatifoundethelleency hise, EE:

ICEEOLLI igoved;


Ha whin 

# Self-Attention abstractly


*   Self-attention is a communication mechanism between nodes in a directed graph. In a language model, the graph happens to be structured so that each node has incoming edges from all previous nodes (edges indicate the direction of information flow).
*   Attention has no built-in notion of space: it operates only on vectors, which determine the affinities and the weighting of values. This is why positional embeddings are added separately.

*   There is no communication between tokens across the batch dimension; each sequence in a batch is processed independently.
*   In some language models (e.g. BERT), all the nodes are allowed to talk to each other, and the direction of communication is not restricted. These models use an "encoder" block, which doesn't have the lower-triangular mask. Our implementation is a decoder block, which is why it uses the lower-triangular structure.
*   In self-attention, Q, K and V all come from the same tokens (`x`). In other attention mechanisms, the queries are produced from `x`, but K and V come from a separate source, such as an encoder block, which can add additional context. This is called cross-attention: there is a separate set of nodes from which we get the information.
*   In the paper, the attention scores are scaled before the softmax: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$, where $d_k$ is the head size. The reason: if the entries of Q and K are unit Gaussian, the variance of `Q·K` is on the order of `head_size`. If the inputs to the softmax have a large variance, the softmax becomes "sharp" and approaches a one-hot encoding, so each node only takes information from one other node. Dividing by `sqrt(head_size)` brings the variance back to about 1 and fixes this issue.
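The variance argument in the last point can be checked empirically. This sketch draws unit-Gaussian queries and keys directly (no learned projections), compares the variance of the raw and scaled scores, and measures how peaked the resulting softmax rows are; the sizes and names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
hs = 16                          # head size
qs = torch.randn(64, 32, hs)     # unit-Gaussian queries, (B, T, head_size)
ks = torch.randn(64, 32, hs)     # unit-Gaussian keys

raw = qs @ ks.transpose(-2, -1)  # (B, T, T) unscaled scores
scaled = raw * hs**-0.5          # scaled as in the paper

print(raw.var().item())     # roughly head_size, i.e. about 16
print(scaled.var().item())  # roughly 1

# Average of the largest softmax weight in each row: closer to 1 means "sharper".
sharp_raw = F.softmax(raw, dim=-1).max(dim=-1).values.mean().item()
sharp_scaled = F.softmax(scaled, dim=-1).max(dim=-1).values.mean().item()
print(sharp_raw, sharp_scaled)  # the unscaled rows are much closer to one-hot
```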



In [None]:
# The weights are no longer constant: they differ for each sequence in the batch,
# because each sequence contains different tokens and the weights depend on the tokens.
weights

# References
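*  Andrej Karpathy's video [Let's build GPT: from scratch, in code, spelled out](https://www.youtube.com/watch?v=kCc8FmEb1nY)
*  [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (the original Transformer paper)
*  [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165) (the OpenAI GPT-3 paper)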
