# Transformers from scratch

<a href="https://colab.research.google.com/github/EffiSciencesResearch/ML4G-2.0/blob/master/workshops/transformer/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Thanks to Callum McDougall for this Notebook.

We also advise you not to look too much at the bonus or collapsed nerd-sniping stuff.

<img src="https://raw.githubusercontent.com/callummcdougall/TransformerLens-intro/main/images/page_images/transformer-building.png" width="350">


# Introduction

This is a clean, first principles implementation of GPT-2 in PyTorch. The architectural choices closely follow those used by the TransformerLens library (which you'll be using a lot more in later exercises).

Each exercise will have a difficulty and importance rating out of 5, as well as an estimated maximum time you should spend on these exercises and sometimes a short annotation. You should interpret the ratings & time estimates relatively (e.g. if you find yourself spending about 50% longer on the exercises than the time estimates, adjust accordingly). Please do skip exercises / look at solutions if you don't feel like they're important enough to be worth doing, and you'd rather get to the good stuff!


## Content & Learning Objectives


#### 1️⃣ Understanding Inputs & Outputs of a Transformer

In this section, we'll take a first look at transformers - what their function is, how information moves inside a transformer, and what inputs & outputs they take.

> ##### Learning objectives
>
> - Understand what a transformer is used for
> - Understand causal attention, and what a transformer's output represents
> - Learn what tokenization is, and how models do it
> - Understand what logits are, and how to use them to derive a probability distribution over the vocabulary

#### 2️⃣ Clean Transformer Implementation

Here, we'll implement a transformer from scratch, using only PyTorch's tensor operations. This will give us a good understanding of how transformers work, and how to use them. We do this by going module-by-module, in an experience which should feel somewhat similar to last week's ResNet exercises. Much like with ResNets, you'll conclude by loading in pretrained weights and verifying that your model works as expected.

> ##### Learning objectives
>
> * Understand that a transformer is composed of attention heads and MLPs, with each one performing operations on the residual stream
> * Understand that the attention heads in a single layer operate independently, and that they have the role of calculating attention patterns (which determine where information is moved to & from in the residual stream)
> * Learn about & implement the following transformer modules:
>     * (Bonus) LayerNorm (transforming the input to have zero mean and unit variance)
>     * Positional embedding (a lookup table from position indices to residual stream vectors)
>     * Attention (the method of computing attention patterns for residual stream vectors)
>     * MLP (the collection of linear and nonlinear transformations which operate on each residual stream vector in the same way)
>     * Embedding (a lookup table from tokens to residual stream vectors)
>     * Unembedding (a matrix for converting residual stream vectors into a distribution over tokens)

## Setup (don't read, just run!)


In [1]:
%pip install transformer_lens einops jaxtyping circuitsvis -q

Note: you may need to restart the kernel to use updated packages.


In [30]:
# os.environ['ACCELERATE_DISABLE_RICH'] = "1"
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new
import torch
from torch import Tensor
import torch.nn as nn
import math
from tqdm.notebook import tqdm
from jaxtyping import Float, Int
import os
import pickle

device = torch.device("cpu")

if os.path.exists('reference_gpt2.pkl'):
    # Load the reference_gpt2 model from the file
    with open('reference_gpt2.pkl', 'rb') as f:
        reference_gpt2 = pickle.load(f)
else:
    reference_gpt2 = HookedTransformer.from_pretrained(
        "gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False
    ).to(device)

    # Save the reference_gpt2 model to a file
    with open('reference_gpt2.pkl', 'wb') as f:
        pickle.dump(reference_gpt2, f)


# 1️⃣ Understanding Inputs & Outputs of a Transformer


> ### Learning Objectives
>
> * Understand what a transformer is used for
> * Understand causal attention, and what a transformer's output represents
> * Learn what tokenization is, and how models do it
> * Understand what logits are, and how to use them to derive a probability distribution over the vocabulary


## What is the point of a transformer?


**Transformers exist to model text!**

We're going to focus on GPT-2 style transformers. Key feature: they generate text! You feed in language, and the model generates a probability distribution over tokens. And you can repeatedly sample from this to generate text!

(To explain this in more detail - you feed in a sequence of length $N$, then sample from the probability distribution over the $N+1$-th word, use this to construct a new sequence of length $N+1$, then feed this new sequence into the model to get a probability distribution over the $N+2$-th word, and so on.)

### How is the model trained?

You give it a bunch of text, and train it to predict the next token.

Importantly, if you give a model 100 tokens in a sequence, it predicts the next token for *each* prefix, i.e. it produces 100 logit vectors (= probability distributions) over the set of all words in our vocabulary, with the `i`-th logit vector representing the probability distribution over the token *following* the `i`-th token in the sequence.
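To make the "one prediction per prefix" idea concrete, here is a minimal sketch of a next-token loss. The random `logits` and `tokens` tensors and the tiny `d_vocab` are stand-ins made up for this example (no real model is involved); the point is the one-position shift between predictions and targets.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_vocab = 2, 5, 11  # toy sizes, chosen only for illustration

# Stand-ins for a model's input and output (assumed, not from a real model)
tokens = torch.randint(0, d_vocab, (batch, seq_len))
logits = torch.randn(batch, seq_len, d_vocab)

# The logits at position i predict the token at position i + 1, so we drop
# the last prediction (it has no target) and the first token (nothing predicts it)
pred_logits = logits[:, :-1]  # [batch, seq_len - 1, d_vocab]
targets = tokens[:, 1:]       # [batch, seq_len - 1]

# Average cross-entropy over every (sequence, position) pair
loss = F.cross_entropy(pred_logits.reshape(-1, d_vocab), targets.reshape(-1))
print(loss.item())
```

Every position in the sequence contributes a training signal, which is part of why this objective is so data-efficient.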

<details>
<summary>Aside - logits</summary>

If you haven't encountered the term "logits" before, here's a quick refresher.

Given an arbitrary vector $x$, we can turn it into a probability distribution via the **softmax** function: $x_i \to \frac{e^{x_i}}{\sum e^{x_j}}$. The exponential makes everything positive; the normalization makes it add to one.

The model's output is the vector $x$ (one for each prediction it makes). We call the entries of this vector **logits**: they represent a probability distribution, and are related to the actual probabilities via the softmax function.
</details>
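As a quick illustration of the softmax, here is a dependency-free sketch. The logit values are made up; the example also shows that adding a constant to every logit leaves the probabilities unchanged, so logits are only meaningful up to an additive shift.

```python
import math

def softmax(logits):
    """Convert a list of logits into a list of probabilities."""
    # Subtracting the max doesn't change the result, but avoids overflow in exp
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, -1.0]  # made-up values
probs = softmax(logits)
print(probs)

# Shifting every logit by the same constant gives the same distribution
shifted_probs = softmax([x + 100.0 for x in logits])
print(shifted_probs)
```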

How do we stop the transformer from "cheating" by just looking at the tokens it's trying to predict? Answer - we make the transformer have *causal attention* (as opposed to *bidirectional attention*). Causal attention only allows information to move forwards in the sequence, never backwards. The prediction of what comes after token 50 is only a function of the first 50 tokens, *not* of token 51. We say the transformer is **autoregressive**, because it only predicts future words based on past data.
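One common way to get this behaviour is to mask out "future" positions before a softmax. The sketch below assumes a toy matrix of random attention scores (the name `scores` and the sizes are invented for the example); it is an illustration of the idea rather than the exact code used in any particular library.

```python
import torch

torch.manual_seed(0)
seq_len = 4
# Rows = destination (the position doing the predicting), cols = source
scores = torch.randn(seq_len, seq_len)

# True above the diagonal, i.e. wherever the source comes later than the destination
future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Setting future scores to -inf gives them zero weight after the softmax
masked_scores = scores.masked_fill(future, float("-inf"))
pattern = masked_scores.softmax(dim=-1)
print(pattern)
```

Each row is still a probability distribution, but it only puts weight on the current and earlier positions.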


<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-overview-new.png" width="900">


## Tokens - Transformer Inputs


Our transformer's input is natural language (i.e. a sequence of characters, strings, etc). But ML models generally take vectors as input, not language. How do we convert language to vectors?

We can factor this into 2 questions:

1. How do we split up language into small sub-units?
2. How do we convert these sub-units into vectors?

Let's start with the second of these questions.


### Converting sub-units to vectors

We basically make a massive lookup table, which is called an **embedding**. It has one vector for each possible sub-unit of language we might get (we call this set of all sub-units our **vocabulary**). We label every element in our vocabulary with an integer (this labelling never changes), and we use this integer to index into the embedding.

A key intuition is that one-hot encodings let you think about each integer independently. We don't bake in any relation between words when we perform our embedding, because every word has a completely separate embedding vector.

<details>
<summary>Aside - one-hot encodings</summary>

We sometimes think about **one-hot encodings** of words. These are vectors with zeros everywhere, except for a single one in the position corresponding to the word's index in the vocabulary. This means that indexing into the embedding is equivalent to multiplying the **embedding matrix** by the one-hot encoding (where the embedding matrix is the matrix we get by stacking all the embedding vectors on top of each other).

$$
\begin{aligned}
W_E &= \begin{bmatrix}
\leftarrow v_0 \rightarrow \\
\leftarrow v_1 \rightarrow \\
\vdots \\
\leftarrow v_{d_{vocab}-1} \rightarrow \\
\end{bmatrix} \quad \text{is the embedding matrix (size }d_{vocab} \times d_{embed}\text{),} \\
\\
t_i &= (0, \dots, 0, 1, 0, \dots, 0) \quad \text{is the one-hot encoding for the }i\text{th word (length }d_{vocab}\text{)} \\
\\
v_i &= t_i W_E \quad \text{is the embedding vector for the }i\text{th word (length }d_{embed}\text{).} \\
\end{aligned}
$$

</details>
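The equivalence between indexing and one-hot multiplication can be checked directly. This is a small sketch with a random embedding matrix and toy sizes (all values here are made up for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_vocab, d_embed = 10, 4             # toy sizes for illustration
W_E = torch.randn(d_vocab, d_embed)  # one row per vocabulary item
tokens = torch.tensor([3, 7, 3])

# Option 1: index directly into the lookup table
by_index = W_E[tokens]

# Option 2: multiply one-hot row vectors by the embedding matrix
one_hot = F.one_hot(tokens, num_classes=d_vocab).float()
by_matmul = one_hot @ W_E

print(torch.allclose(by_index, by_matmul))
```

In practice we index rather than multiply, since building the one-hot vectors is wasteful, but the matrix view is often the more useful one when reasoning about the model.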

Now, let's answer the first question - how do we split language into sub-units?


### (Bonus) Splitting language into sub-units

We need to define a standard way of splitting up language into a series of substrings, where each substring is a member of our **vocabulary** set.

Could we use a dictionary, and have our vocabulary be the set of all words in the dictionary? No, because this couldn't handle arbitrary text (e.g. URLs, punctuation, etc). We need a more general way of splitting up language.

Could we just use the 256 possible byte values? This fixes the previous problem (any text can be written as a sequence of bytes), but it loses the structure of language - some sequences of characters are more meaningful than others. For example, "language" is a lot more meaningful than "hjksdfiu". We want "language" to be a single token, but not "hjksdfiu" - this is a more efficient use of our vocab.

What actually happens? The most common strategy is called **byte-pair encoding** (BPE).

We begin with the 256 byte values as our tokens, and then repeatedly find the most common pair of adjacent tokens and merge it into a new token. Note that we do have a space character as one of our 256 tokens, and merges using space are very common. For instance, here are the first five merges for the tokenizer used by GPT-2 (you'll be able to verify this below).

```
" t"
" a"
"he"
"in"
"re"
```

Note - you might see the character `Ġ` in front of some tokens. This is a special character that indicates that the token begins with a space. Tokens with a leading space vs not are different.
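To get a feel for the merging procedure, here is a toy sketch that runs a few merges on a made-up one-line corpus. It works on characters and ignores many details of real tokenizers (such as how text is pre-split before merging), and the helper names `most_common_pair` and `merge_pair` are invented for this example.

```python
from collections import Counter

def most_common_pair(sequences):
    """Count adjacent token pairs across all sequences and return the most frequent."""
    counts = Counter()
    for seq in sequences:
        counts.update(zip(seq, seq[1:]))
    return counts.most_common(1)[0][0]

def merge_pair(seq, pair):
    """Replace every occurrence of `pair` in `seq` with a single merged token."""
    merged, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            merged.append(seq[i] + seq[i + 1])
            i += 2
        else:
            merged.append(seq[i])
            i += 1
    return merged

# Toy corpus (made up for this example), starting from single characters
corpus = [list(" the then the")]
merges = []
for _ in range(3):
    pair = most_common_pair(corpus)
    merges.append(pair)
    corpus = [merge_pair(seq, pair) for seq in corpus]

print(merges)
print(corpus[0])
```

After three merges, the frequent string `" the"` has become a single token, while the rarer `"n"` is left on its own.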

You can run the code below to see some more of GPT-2's tokenizer's vocabulary:


In [3]:
sorted_vocab = sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda n: n[1])
print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])
print()

[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



As you get to the end of the vocabulary, you'll be producing some pretty weird-looking esoteric tokens (because you'll already have exhausted all of the short frequently-occurring ones):


In [4]:
print(sorted_vocab[-20:])

[('Revolution', 50237), ('Ġsnipers', 50238), ('Ġreverted', 50239), ('Ġconglomerate', 50240), ('Terry', 50241), ('794', 50242), ('Ġharsher', 50243), ('Ġdesolate', 50244), ('ĠHitman', 50245), ('Commission', 50246), ('Ġ(/', 50247), ('âĢ¦."', 50248), ('Compar', 50249), ('Ġamplification', 50250), ('ominated', 50251), ('Ġregress', 50252), ('ĠCollider', 50253), ('Ġinformants', 50254), ('Ġgazed', 50255), ('<|endoftext|>', 50256)]


Transformers in the `transformer_lens` library have a `to_tokens` method that converts text to numbers. It also prepends them with a special token called BOS (beginning of sequence) to indicate the start of a sequence. You can disable this with the `prepend_bos=False` argument.



### Some tokenization annoyances

There are a few funky and frustrating things about tokenization, which cause it to behave differently than you might expect. For instance:

#### Whether a word begins with a capital or space matters!


In [5]:
print(reference_gpt2.to_str_tokens("Ralph"))
print(reference_gpt2.to_str_tokens(" Ralph"))
print(reference_gpt2.to_str_tokens(" ralph"))
print(reference_gpt2.to_str_tokens("ralph"))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


> ### Key Takeaways
>
> * We learn a vocabulary of tokens (sub-words).
> * We (approx) losslessly convert language to integers via tokenizing it.
> * We convert integers to vectors via a lookup table.
> * Note: input to the transformer is a sequence of *tokens* (ie integers), not vectors


## Text generation

Now that we understand the basic ideas here, let's go through the entire process of text generation, from our original string to a new token which we can append to our string and plug back into the model.

#### **Step 1:** Convert text to tokens

The sequence gets tokenized, so it has shape `[batch, seq_len]`. Here, the batch dimension is just one (because we only have one sequence).


In [6]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text).to(device)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0]])
torch.Size([1, 35])
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


#### **Step 2:** Map tokens to logits


From our input of shape `[batch, seq_len]`, we get output of shape `[batch, seq_len, vocab_size]`. The `[i, j, :]`-th element of our output is a vector of logits representing our prediction for the `j+1`-th token in the `i`-th sequence.


In [7]:
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 35, 50257])


(`run_with_cache` tells the model to cache all intermediate activations. This isn't important right now; we'll look at it in more detail later.)


#### **Step 3:** Convert the logits to a distribution with a softmax

This doesn't change the shape, it is still `[batch, seq_len, vocab_size]`.


In [8]:
probs = logits.softmax(dim=-1)
print(probs.shape)

torch.Size([1, 35, 50257])


#### **Bonus step:** What is the most likely next token at each position?


In [9]:
most_likely_next_tokens = reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])


print("          token", "-->", "next_token")
for token, next_token in zip(reference_gpt2.to_str_tokens(tokens), most_likely_next_tokens):
    print(f"{token!r:>15} --> {next_token!r}")

          token --> next_token
'<|endoftext|>' --> '\n'
            'I' --> "'m"
          ' am' --> ' a'
          ' an' --> ' avid'
     ' amazing' --> ' person'
         ' aut' --> 'od'
          'ore' --> 'sp'
     'gressive' --> '.'
            ',' --> ' and'
         ' dec' --> 'ently'
         'oder' --> ','
            '-' --> 'driven'
         'only' --> ' programmer'
            ',' --> ' and'
           ' G' --> 'IM'
           'PT' --> '-'
            '-' --> 'only'
            '2' --> '.'
       ' style' --> ','
 ' transformer' --> '.'
            '.' --> ' I'
         ' One' --> ' of'
         ' day' --> ' I'
           ' I' --> ' will'
        ' will' --> ' be'
      ' exceed' --> ' my'
       ' human' --> 'ly'
       ' level' --> ' of'
' intelligence' --> ' and'
         ' and' --> ' I'
        ' take' --> ' over'
        ' over' --> ' the'
         ' the' --> ' world'
       ' world' --> '.'
            '!' --> ' I'


We can see that, in a few cases (particularly near the end of the sequence), the model accurately predicts the next token in the sequence. We might guess that `"take over the world"` is a common phrase that the model has seen in training, which is why the model can predict it.


#### **Step 4:** Map distribution to a token


In [10]:
next_token = logits[0, -1].argmax(dim=-1)
next_char = reference_gpt2.to_string(next_token)
print(repr(next_char))

' I'


Note that we're indexing `logits[0, -1]`. This is because logits have shape `[1, sequence_length, vocab_size]`, so this indexing returns the vector of length `vocab_size` representing the model's prediction for what token follows the **last** token in the input sequence.

In this case, we can see that the model predicts the token `' I'`.


#### **Step 5:** Add this to the end of the input, re-run

There are more efficient ways to do this (e.g. where we cache some of the values each time we run our input, so we don't have to do as much calculation each time we generate a new value), but this doesn't matter conceptually right now.


In [11]:
print(f"Sequence so far: {reference_gpt2.to_string(tokens)[0]!r}")

for i in range(10):
    print(f"{tokens.shape[-1]+1}th char = {next_char!r}")
    # Define new input sequence, by appending the previously generated token
    tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
    # Pass our new sequence through the model, to get new output
    logits = reference_gpt2(tokens)
    # Get the predicted token at the end of our sequence
    next_token = logits[0, -1].argmax(dim=-1)
    # Decode and print the result
    next_char = reference_gpt2.to_string(next_token)

Sequence so far: '<|endoftext|>I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!'
36th char = ' I'
37th char = ' am'
38th char = ' a'
39th char = ' very'
40th char = ' talented'
41th char = ' and'
42th char = ' talented'
43th char = ' person'
44th char = ','
45th char = ' and'


## Key takeaways

* Transformer takes in language, predicts next token (for *each* token in a causal way)
* We convert language to a sequence of integers with a tokenizer.
* We convert integers to vectors with a lookup table.
* The output is a vector of logits (one for each input token). We convert it to a probability distribution with a softmax, and can then convert this to a token (e.g. by taking the largest logit, or by sampling).
* We append this to the input + run again to generate more text (Jargon: *autoregressive*)
* Meta level point: Transformers are sequence operation models, they take in a sequence, do processing in parallel at each position, and use attention to move information between positions!
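The two ways of turning a distribution into a token mentioned above (taking the largest logit, or sampling) can be sketched as follows. The logits are made up and cover a pretend 4-token vocabulary.

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([1.0, 3.0, 0.5, 2.0])  # made-up logits over a 4-token vocabulary

# Greedy decoding: always take the token with the largest logit
greedy_token = logits.argmax(dim=-1)

# Sampling: draw a token in proportion to its softmax probability
probs = logits.softmax(dim=-1)
sampled_token = torch.multinomial(probs, num_samples=1)

print(greedy_token.item(), sampled_token.item())
```

Greedy decoding is deterministic (which is why the generation loop earlier repeats itself), while sampling can give different continuations on each run.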


# 2️⃣ Clean Transformer Implementation


> ##### Learning objectives
>
> * Understand that a transformer is composed of attention heads and MLPs, with each one performing operations on the residual stream
> * Understand that the attention heads in a single layer operate independently, and that they have the role of calculating attention patterns (which determine where information is moved to & from in the residual stream)
> * Learn about & implement the following transformer modules:
>     * (Bonus) LayerNorm (transforming the input to have zero mean and unit variance)
>     * Positional embedding (a lookup table from position indices to residual stream vectors)
>     * Attention (the method of computing attention patterns for residual stream vectors)
>     * MLP (the collection of linear and nonlinear transformations which operate on each residual stream vector in the same way)
>     * Embedding (a lookup table from tokens to residual stream vectors)
>     * Unembedding (a matrix for converting residual stream vectors into a distribution over tokens)


## High-Level architecture

Go watch Neel's [Transformer Circuits walkthrough](https://www.youtube.com/watch?v=KV5gbOmHbjU) if you want more intuitions!

(Diagram is bottom to top.)


<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-new.png" width="850">


### Tokenization & Embedding

The input tokens $t$ are integers. We get them from taking a sequence, and tokenizing it (like we saw in the previous section).

The token embedding is a lookup table mapping tokens to vectors, which is implemented as a matrix $W_E$. The matrix consists of a stack of token embedding vectors (one for each token).


### Residual stream

The residual stream is the sum of the outputs of all previous layers of the model, and is the input to each new layer. It has shape `[batch, seq_len, d_model]` (where `d_model` is the length of a single embedding vector).

The initial value of the residual stream is denoted $x_0$ in the diagram, and $x_i$ are later values of the residual stream (after more attention and MLP layers have been applied to the residual stream).

The residual stream is *really* fundamental. It's the central object of the transformer. It's how the model remembers things and moves information between layers for composition, and it's the medium used to store the information that attention moves between positions.


### Transformer blocks

Then we have a series of `n_layers` **transformer blocks** (also sometimes called **residual blocks**).

Note - a block contains an attention layer *and* an MLP layer, but we say a transformer has $k$ layers if it has $k$ blocks (i.e. $2k$ total layers).
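As a rough picture of how blocks interact with the residual stream, here is a sketch in which each sublayer reads the stream and adds its output back in. `ToyBlock` is a hypothetical stand-in: its "attention" and "MLP" are plain linear layers used as placeholders, and LayerNorm is left out, so this only shows the additive structure and not a real transformer block.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block: each sublayer reads from
    the residual stream and adds its output back in."""

    def __init__(self, d_model: int):
        super().__init__()
        # Placeholders only: a real block uses attention and an MLP here
        self.attn = nn.Linear(d_model, d_model)
        self.mlp = nn.Linear(d_model, d_model)

    def forward(self, resid: torch.Tensor) -> torch.Tensor:
        resid = resid + self.attn(resid)  # sublayer output is added to the stream
        resid = resid + self.mlp(resid)   # and again for the second sublayer
        return resid

torch.manual_seed(0)
batch, seq_len, d_model, n_layers = 1, 5, 16, 3  # toy sizes
blocks = nn.ModuleList([ToyBlock(d_model) for _ in range(n_layers)])

x = torch.randn(batch, seq_len, d_model)  # plays the role of x_0
for block in blocks:
    x = block(x)
print(x.shape)
```

Because every sublayer only ever adds to the stream, the shape `[batch, seq_len, d_model]` is preserved all the way through the model.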


#### Attention

First we have attention. This moves information from prior positions in the sequence to the current token.

We do this for *every* token in parallel using the same parameters. The only difference is that we look backwards only (to avoid "cheating"). This means later tokens have more of the sequence that they can look at.

Attention layers are the only bit of a transformer that moves information between positions (i.e. between vectors at different sequence positions in the residual stream).

Attention layers are made up of `n_heads` heads, each with its own parameters, its own attention pattern, and its own rule for what information to copy from source to destination. The heads act independently and additively: we just add their outputs together, and add the result back to the residual stream.

Each head does the following:
* Produces an **attention pattern** for each destination token: a probability distribution over prior source tokens (including the current one), weighting how much information to copy from each.
* Moves information (via a linear map) in the same way from each source token to each destination token.

A few key points:

* What information we copy depends on the source token's *residual stream*, but this doesn't mean it only depends on the value of that token, because the residual stream can store more information than just the token identity (the purpose of the attention heads is to move information between tokens).
* We can think of each attention head as consisting of two different **circuits**:
    * One circuit determines **where to move information to and from** (this is a function of the residual stream for the source/key and destination/query tokens)
    * The other circuit determines **what information to move** (this is a function of only the source token's residual stream)
    * For reasons which will become clear later, we refer to the first circuit as the **QK circuit**, and the second circuit as the **OV circuit**


Below is a schematic diagram of the attention layers. Don't worry if you don't follow this right now, we'll go into more detail during implementation.


<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-attn-new.png" width="1100">


### MLP

The MLP layers are just a standard neural network, with a single hidden layer and a nonlinear activation function. The exact activation isn't conceptually important ([GELU](https://paperswithcode.com/method/gelu) seems to perform best).

Our hidden dimension is normally `d_mlp = 4 * d_model`. Exactly why the ratios are what they are isn't super important (people basically cargo-cult what GPT did back in the day!).

Importantly, **the MLP operates on positions in the residual stream independently, and in exactly the same way**. It doesn't move information between positions.
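One way to check the "same operation at every position" claim is to shuffle the positions and see that the MLP's output shuffles in exactly the same way. This sketch uses toy sizes and PyTorch's built-in `nn.GELU` (the reference model uses a slightly different `gelu_new` variant), so treat it as an illustration rather than the exercise solution.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_mlp, seq_len = 8, 32, 5  # toy sizes, with d_mlp = 4 * d_model

mlp = nn.Sequential(
    nn.Linear(d_model, d_mlp),
    nn.GELU(),
    nn.Linear(d_mlp, d_model),
)

resid = torch.randn(1, seq_len, d_model)
perm = torch.randperm(seq_len)

# Shuffling positions before the MLP gives the same result as shuffling after,
# because each position is processed on its own
out_then_shuffle = mlp(resid)[:, perm]
shuffle_then_out = mlp(resid[:, perm])
print(torch.allclose(out_then_shuffle, shuffle_then_out, atol=1e-5))
```

The same check would fail for an attention layer, since attention is precisely the part that mixes information across positions.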

Intuition - once attention has moved relevant information to a single position in the residual stream, MLPs can actually do computation, reasoning, lookup information, etc. *What the hell is going on inside MLPs* is a pretty big open problem in transformer mechanistic interpretability - see the [Toy Model of Superposition Paper](https://transformer-circuits.pub/2022/toy_model/index.html) for more on why this is hard.



<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-mlp-new-2.png" width="650">


### Unembedding

Finally, we unembed!

This just consists of applying a linear map $W_U$, going from final residual stream to a vector of logits - this is the output.



<details>
<summary>Bonus things - less conceptually important but key technical details</summary>

#### LayerNorm

* Simple normalization function applied at the start of each layer (i.e. before each MLP, attention layer, and before the unembedding)
* Converts each input vector (independently in parallel for each batch x position residual stream vector) to have mean zero and variance 1.
* Then applies an elementwise scaling and translation
* Cool maths tangent: The scale & translate is just a linear map. LayerNorm is only applied immediately before another linear map. Linear compose linear = linear, so we can just fold this into a single effective linear layer and ignore it.
    * `fold_ln=True` flag in `from_pretrained` does this for you.
* LayerNorm is annoying for interpretability - the scale part is not linear, so you can't think about different bits of the input independently. But it's *almost* linear - if you're changing a small part of the input it's linear, but if you're changing enough to alter the norm substantially it's not linear.



#### Positional embeddings

* **Problem:** Attention operates over all pairs of positions. This means it's symmetric with regard to position - the attention calculation from token 5 to token 1 is the same as from token 5 to token 2 by default
    * This is dumb because nearby tokens are more relevant.
* There are a lot of dumb hacks for this.
* We'll focus on **learned, absolute positional embeddings**. This means we learn a lookup table mapping the index of the position of each token to a residual stream vector, and add this to the embed.
    * Note that we *add* rather than concatenate. This is because the residual stream is shared memory, and likely under significant superposition (the model compresses more features in there than the model has dimensions)
    * We basically never concatenate inside a transformer, unless doing weird shit like generating text efficiently.
* This connects to **attention as generalized convolution**
    * We argued that language does still have locality, and so it's helpful for transformers to have access to the positional information so they "know" two tokens are next to each other (and hence probably relevant to each other).
</details>
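For the LayerNorm steps listed in the bonus section above (normalize each vector, then scale and translate), here is a minimal numerical sketch. The tensor sizes and values are made up, the scale and translation are set to the identity, and the result is cross-checked against PyTorch's built-in `layer_norm`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
d_model = 8
resid = torch.randn(2, 3, d_model) * 5 + 2  # [batch, position, d_model], arbitrary scale and shift

# Normalize each residual stream vector separately (over the last dimension)
mean = resid.mean(dim=-1, keepdim=True)
var = resid.var(dim=-1, keepdim=True, unbiased=False)
normalized = (resid - mean) / (var + eps).sqrt()

# Learned elementwise scale and translation (set to the identity here)
w = torch.ones(d_model)
b = torch.zeros(d_model)
out = normalized * w + b

# Cross-check against PyTorch's built-in functional LayerNorm
reference = F.layer_norm(resid, (d_model,), w, b, eps)
print(torch.allclose(out, reference, atol=1e-5))
```

Note that the division by the standard deviation depends on the whole vector, which is the nonlinear part that makes LayerNorm awkward to reason about.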

## Actual Code!

### Parameters and Activations

It's important to distinguish between parameters and activations in the model.

* **Parameters** are the weights and biases that are learned during training.
    * These don't change when the model input changes.
* **Activations** are temporary numbers calculated during a forward pass, that are functions of the input.
    * We can think of these values as only existing for the duration of a single forward pass, and disappearing afterwards.
    * We can use hooks to access these values during a forward pass (more on hooks later), but it doesn't make sense to talk about a model's activations outside the context of some particular input.
    * Attention scores and patterns are activations (this is slightly non-intuitive because they're used in a matrix multiplication with another activation).
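The distinction can be seen on a tiny example. The sketch below uses a plain PyTorch forward hook on a toy `nn.Linear` layer (this is standard PyTorch, not the TransformerLens hook system you'll meet later; the names `captured` and `save_output` are invented for the example).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
captured = {}

def save_output(module, inputs, output):
    # Forward hooks run right after the module's forward pass
    captured["out"] = output.detach()

handle = layer.register_forward_hook(save_output)

x1, x2 = torch.randn(2, 4), torch.randn(2, 4)
weights_before = layer.weight.detach().clone()

layer(x1)
act1 = captured["out"]
layer(x2)
act2 = captured["out"]
handle.remove()

# Parameters are the same whatever the input; activations depend on the input
print(torch.equal(layer.weight, weights_before))
print(torch.equal(act1, act2))
```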



#### Print All Parameter Shapes of Reference Model


In [12]:
for name, param in reference_gpt2.named_parameters():
    # Only print for first layer
    if ".0." in name or "blocks" not in name:
        print(f"{name:18} {tuple(param.shape)}")

embed.W_E          (50257, 768)
pos_embed.W_pos    (1024, 768)
blocks.0.ln1.w     (768,)
blocks.0.ln1.b     (768,)
blocks.0.ln2.w     (768,)
blocks.0.ln2.b     (768,)
blocks.0.attn.W_Q  (12, 768, 64)
blocks.0.attn.W_O  (12, 64, 768)
blocks.0.attn.b_Q  (12, 64)
blocks.0.attn.b_O  (768,)
blocks.0.attn.W_K  (12, 768, 64)
blocks.0.attn.W_V  (12, 768, 64)
blocks.0.attn.b_K  (12, 64)
blocks.0.attn.b_V  (12, 64)
blocks.0.mlp.W_in  (768, 3072)
blocks.0.mlp.b_in  (3072,)
blocks.0.mlp.W_out (3072, 768)
blocks.0.mlp.b_out (768,)
ln_final.w         (768,)
ln_final.b         (768,)
unembed.W_U        (768, 50257)
unembed.b_U        (50257,)


### Config

The config object contains all the hyperparameters of the model. We can print the config of the reference model to see what it contains:


In [13]:
# As a reference - note there's a lot of stuff we don't care about in here, to do with library internals or other architectures
print(reference_gpt2.cfg)

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': np.float64(8.0),
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': 'cpu',
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': np.float64(0.02886751345948129),
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LN',
 'num_experts': None,
 'o

We define a stripped down config for our model:


In [14]:
@dataclass
class Config:
    d_model: int = 768  # dimension of the residual_stream
    debug: bool = True
    d_vocab: int = 50257
    n_ctx: int = 1024  # max nb of tokens that the model can handle
    d_head: int = 64  # dimension of each key/query/value
    d_mlp: int = 3072  # dimension of the hidden layer inside the MLPs
    n_heads: int = 12  # Nb of heads
    n_layers: int = 12  # Nb of (Attention+ MLP) in the GPT

    layer_norm_eps: float = 1e-5  # (Bonus)
    init_range: float = 0.02  # (bonus) standard deviation of 0.02 for weight initialization


cfg = Config()
print(cfg)

Config(d_model=768, debug=True, d_vocab=50257, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12, layer_norm_eps=1e-05, init_range=0.02)
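
As a quick sanity check on these hyperparameters, the sketch below counts the attention and MLP weight matrices implied by the config. The names `TinyConfig` and `count_weight_matrix_params` are hypothetical, and the count deliberately ignores biases, embeddings and LayerNorm parameters. Under that assumption, the total equals the `n_params` value printed in the reference config above.

```python
from dataclasses import dataclass


@dataclass
class TinyConfig:
    # Same values as the Config above; redefined so this snippet runs on its own
    d_model: int = 768
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12


def count_weight_matrix_params(cfg: TinyConfig) -> int:
    """Count attention and MLP weight matrices only (no biases, embeddings or LayerNorm)."""
    attn = 4 * cfg.n_heads * cfg.d_model * cfg.d_head  # W_Q, W_K, W_V, W_O
    mlp = 2 * cfg.d_model * cfg.d_mlp  # W_in, W_out
    return cfg.n_layers * (attn + mlp)


tiny_cfg = TinyConfig()
total = count_weight_matrix_params(tiny_cfg)
print(total)  # 84934656
print(tiny_cfg.d_model == tiny_cfg.n_heads * tiny_cfg.d_head)  # True
```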


## Tests


Tests are great. Write lightweight ones to use as you go!

**Naive test:** Generate random inputs of the right shape, feed them to your layer, check that no error is raised, and print the output shape.


In [15]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = torch.randn(shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape, "\n")


def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = torch.randint(100, 1000, shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape, "\n")


def load_gpt2_test(cls, gpt2_layer, input):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    print("Input shape:", input.shape)
    output = layer(input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape)
    try:
        reference_output = gpt2_layer(input)
    except TypeError:
        # The reference attention layer expects separate query, key and value inputs
        reference_output = gpt2_layer(input, input, input)
    print("Reference output shape:", reference_output.shape, "\n")
    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct\n")

## Embedding

```c
Difficulty: 🟠🟠⚪⚪⚪
Importance: 🟠🟠🟠⚪⚪

You should spend up to 5-10 minutes on this exercise.
```

This is basically a lookup table from tokens to residual stream vectors.

(Hint - you can implement this in just one line, without any complicated functions.)


In [16]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch token"]) -> Float[Tensor, "batch token d_model"]:
        """Compute the embedding of the input tokens."""
        # Hide: all
        return self.W_E[tokens]
        # Hide: None


rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 45])
Output shape: torch.Size([1, 45, 768])
Reference output shape: torch.Size([1, 45, 768]) 

100.00% of the values are correct



<details>
<summary>Help - I keep getting <code>RuntimeError: CUDA error: device-side assert triggered</code>.</summary>

This is a uniquely frustrating type of error message, because it (1) forces you to restart the kernel, and (2) often won't tell you where the error message actually originated from!

You can fix the second problem by adding the line `os.environ['CUDA_LAUNCH_BLOCKING'] = "1"` to the very top of your file (after importing `os`). This won't fix your bug, but it makes sure the correct origin point is identified.

As for actually fixing the bug, this error is usually the result of bad indexing, e.g. applying an embedding layer to token IDs that are greater than or equal to the number of rows in your embedding matrix.
</details>
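
To make the lookup-table view and the indexing failure concrete, here is a minimal sketch in plain Python (no PyTorch), using a made-up 4-token vocabulary. The matrix values and the `embed` helper are illustrative assumptions, not part of the notebook's code.

```python
# Toy embedding matrix with d_vocab = 4 rows and d_model = 3 columns (made-up values)
W_E = [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
]


def embed(tokens: list[int]) -> list[list[float]]:
    """Look up one row of W_E per token ID."""
    return [W_E[t] for t in tokens]


embedded = embed([2, 0, 3])
print(embedded)

# A token ID equal to d_vocab is out of range. In plain Python this is a clear
# IndexError; on a GPU the same mistake can surface as a device-side assert.
try:
    embed([4])
    out_of_range_caught = False
except IndexError:
    out_of_range_caught = True
print(out_of_range_caught)  # True
```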

## Positional Embedding

```c
Difficulty: 🟠🟠⚪⚪⚪
Importance: 🟠🟠🟠⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

Positional embedding can also be thought of as a lookup table, but rather than the indices being our token IDs, the indices are just the numbers `0`, `1`, `2`, ..., `seq_len-1` (i.e. the position indices of the tokens in the sequence).
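
As a rough illustration in plain Python (no PyTorch), the sketch below builds a positional embedding for a batch from a made-up table. The table values and the `pos_embed` helper are assumptions for illustration; the point is that the output depends only on the shape of the input, not on the token values.

```python
# Toy positional embedding with n_ctx = 4 positions and d_model = 2 (made-up values)
W_pos = [
    [0.0, 0.1],
    [1.0, 1.1],
    [2.0, 2.1],
    [3.0, 3.1],
]


def pos_embed(tokens: list[list[int]]) -> list[list[list[float]]]:
    """Return W_pos[0:seq_len] once per batch element; token values are ignored."""
    seq_len = len(tokens[0])
    return [[W_pos[p] for p in range(seq_len)] for _ in tokens]


batch_a = pos_embed([[5, 9, 2], [7, 7, 7]])
batch_b = pos_embed([[0, 0, 0], [1, 2, 3]])
print(batch_a == batch_b)  # True: only the shape of the input matters
print(batch_a[0])  # [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
```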


In [17]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch token"]) -> Float[Tensor, "batch token d_model"]:
        # Hint: You can use einops.repeat or torch.Tensor.repeat
        # to repeat the positional embedding along the batch dimension.
        # Hide: hard
        # The value of tokens is not important here, only the size of the tensor!
        # Hide: all
        batch, seq_len = tokens.shape
        return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)
        # Or self.W_pos[:seq_len].repeat(batch, 1, 1)
        # Hide: none


rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 45])
Output shape: torch.Size([1, 45, 768])
Reference output shape: torch.Size([1, 45, 768]) 

100.00% of the values are correct



## Attention

```c
Difficulty: 🟠🟠🟠🟠⚪
Importance: 🟠🟠🟠🟠🟠

You should spend up to 30-45 minutes on this exercise.
```

* **Step 1:** Produce an attention pattern - for each destination token, a probability distribution over previous tokens (including the current token)
    * Linear map from input -> query, key shape `[batch, seq_posn, head_index, d_head]`
    * Dot product every *pair* of queries and keys to get attn_scores `[batch, head_index, query_pos, key_pos]` (query = dest, key = source)
    * **Scale** and mask `attn_scores` to make it lower triangular, i.e. causal
    * Softmax along the `key_pos` dimension, to get a probability distribution for each query (destination) token - this is our attention pattern!
* **Step 2:** Move information from source tokens to destination token using attention pattern
    * Linear map from input -> value `[batch, key_pos, head_index, d_head]`
    * Mix along the `key_pos` with attn pattern to get `z`, which is a weighted average of the value vectors `[batch, query_pos, head_index, d_head]`
    * Map to output, `[batch, position, d_model]` (position = query_pos, we've summed over all heads)

Note - when we say **scale**, we mean dividing by `sqrt(d_head)`. The purpose of this is to avoid vanishing gradients (which is a big problem when we're dealing with a function like softmax - if one of the values is much larger than all the others, the probabilities will be close to 0 or 1, and the gradients will be close to 0).
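
The effect of scaling can be sketched numerically in plain Python. The scores below are made up: they assume query and key entries with roughly unit variance, so that unscaled dot products over `d_head = 64` dimensions have a standard deviation of about `sqrt(64) = 8`. The sketch compares the largest softmax probability and its gradient with and without scaling.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


d_head = 64
raw_scores = [8.0, 0.0, -8.0]  # illustrative unscaled dot products
scaled_scores = [s / math.sqrt(d_head) for s in raw_scores]  # [1.0, 0.0, -1.0]

p_raw = softmax(raw_scores)
p_scaled = softmax(scaled_scores)

# d p_i / d s_i = p_i * (1 - p_i), which is near 0 when p_i is near 0 or 1
grad_raw = p_raw[0] * (1 - p_raw[0])
grad_scaled = p_scaled[0] * (1 - p_scaled[0])

print(f"unscaled: p = {p_raw[0]:.4f}, gradient = {grad_raw:.4f}")
print(f"scaled:   p = {p_scaled[0]:.4f}, gradient = {grad_scaled:.4f}")
```

Without scaling, the softmax is almost one-hot and the gradient through it is close to zero; with scaling, the distribution stays soft enough for gradients to flow.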

Below is a much larger, more detailed version of the attention head diagram from earlier. This should give you an idea of the actual tensor operations involved. A few clarifications on this diagram:

* Whenever there is a third dimension shown in the pictures, this refers to the `head_index` dimension. We can see that all operations within the attention layer are done independently for each head.
* The objects in the box are activations; they have a batch dimension (for simplicity, we assume the batch dimension is 1 in the diagram). The objects to the right of the box are our parameters (weights and biases); they have no batch dimension.

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-attn-21.png" width="1400">


<details>
<summary><b>A few extra notes on attention (optional)</b></summary>

Usually we have the relation `e = n * h` (i.e. `d_model = num_heads * d_head`). There are some computational justifications for this, but mostly this is just done out of convention (just like how we usually have `d_mlp = 4 * d_model`!).

---

The names **keys**, **queries** and **values** come from their analogy to retrieval systems. Broadly speaking:

* The **queries** represent some information that a token is **"looking for"**
* The **keys** represent the information that a token **"contains"**
    * So the attention score being high basically means that the source (key) token contains the information which the destination (query) token **is looking for**
* The **values** represent the information that is actually taken from the source token, to be moved to the destination token

---

This diagram can better help us understand the difference between the **QK** and **OV** circuit. We'll discuss this just briefly here, and will go into much more detail later on.

The **QK** circuit consists of the operation of the $W_Q$ and $W_K$ matrices. In other words, it determines the attention pattern, i.e. where information is moved to and from in the residual stream. The functional form of the attention pattern $A$ is:

$$
A = \text{softmax}\left(\frac{x W_Q W_K^T x^T}{\sqrt{d_{head}}}\right)
$$

where $x$ is the residual stream (shape `[seq_len, d_model]`), and $W_Q$, $W_K$ are the weight matrices for a single head (i.e. shape `[d_model, d_head]`).

The **OV** circuit consists of the operation of the $W_V$ and $W_O$ matrices. Once attention patterns are fixed, these matrices operate on the residual stream at the source position, and their output is the thing which gets moved from source to destination position.

The functional form of an entire attention head is:

$$
\begin{aligned}
\text{output} &= \text{softmax}\left(\frac{x W_Q W_K^T x^T}{\sqrt{d_{head}}}\right) (x W_V W_O) \\
    &= Ax W_V W_O
\end{aligned}
$$

where $W_V$ has shape `[d_model, d_head]`, and $W_O$ has shape `[d_head, d_model]`.

Here, we can clearly see that the **QK circuit** and **OV circuit** are doing conceptually different things, and should be thought of as two distinct parts of the attention head.

Again, don't worry if you don't follow all of this right now - we'll go into **much** more detail on all of this in subsequent exercises. The purpose of the discussion here is just to give you a flavour of what's to come!

</details>
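
For readers who opened the optional notes, the single-head formulas there can be sketched in plain Python on a tiny example. All matrices below are made up, the helpers `matmul`, `transpose` and `causal_softmax` are hypothetical, and biases are omitted. Giving masked positions a probability of exactly 0 here stands in for masking the scores with a large negative number before the softmax.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def causal_softmax(scores):
    """Row-wise softmax where query i may only attend to keys j <= i."""
    pattern = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]
        m = max(visible)
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        pattern.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
    return pattern


# Toy sizes: seq_len = 3, d_model = 2, d_head = 2 (all values made up)
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_Q = [[1.0, 0.0], [0.0, 1.0]]
W_K = [[0.5, 0.0], [0.0, 0.5]]
W_V = [[1.0, 2.0], [3.0, 4.0]]
W_O = [[1.0, 0.0], [0.0, 1.0]]
d_head = 2

# QK circuit: decides where information moves (the attention pattern A)
scores = matmul(matmul(x, W_Q), transpose(matmul(x, W_K)))
A = causal_softmax([[s / math.sqrt(d_head) for s in row] for row in scores])

# OV circuit: decides what information is moved from each source position
output = matmul(A, matmul(matmul(x, W_V), W_O))

for row in A:
    print([round(p, 3) for p in row])
```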


First, it's useful to visualize and play around with attention patterns - what exactly are we looking at here? (Click on a head to lock onto just showing that head's pattern, it'll make it easier to interpret)


In [18]:
import circuitsvis as cv
from IPython.display import display

html = cv.attention.attention_patterns(
    tokens=reference_gpt2.to_str_tokens(reference_text),
    attention=cache["pattern", 0][0],
)
display(html)

You can also use the `attention_heads` function, which has similar syntax but presents the information in a different (sometimes more helpful) way.


<details>
<summary>Help - my <code>attention_heads</code> plots are behaving weirdly.</summary>

This seems to be a bug in `circuitsvis` - on VSCode, the attention head plots continually shrink in size.

Until this is fixed, one way to get around it is to open the plots in your browser. You can do this inline with the `webbrowser` library:

```python
import webbrowser

attn_heads = cv.attention.attention_heads(
    tokens=reference_gpt2.to_str_tokens(reference_text),
    attention=cache["pattern", 0][0]
)

path = "attn_heads.html"

with open(path, "w") as f:
    f.write(str(attn_heads))

webbrowser.open(path)
```

To check exactly where this is getting saved, you can print your current working directory with `os.getcwd()`.
</details>


---

Note - don't worry if you don't get 100% accuracy here; the tests are pretty stringent. Even things like having your `einsum` input arguments in a different order might result in the output being very slightly different. You should be getting at least 99% accuracy though, so if the value is lower than this, it probably means you've made a mistake somewhere.

Also, this implementation will probably be the most challenging exercise on this page, so don't worry if it takes you some time! You should look at parts of the solution or hints if you're stuck.


<details>
<summary>Hint: high level steps</summary>

```python
# Calculate query, key and value vectors
# Calculate attention scores
# Then scale and apply mask and apply softmax on the correct dimension to get probabilities
# Take weighted sum of value vectors, according to attention probabilities
# Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
```
</details>

<details>
<summary>Hint: detailed steps with only a few blanks to fill.</summary>

```python
# Calculate query, key and value vectors
q = (
    einops.einsum(
        normalized_resid_pre,
        self.W_Q,
        "batch posn d_model, nheads d_model d_head -> ???",
    )
    + self.b_Q
)

k = ...

v = ...

# Calculate attention scores
attn_scores = einops.einsum(
    q,
    k,
    "???,??? -> batch nheads posn_Q posn_K",
)

# then scale and apply mask and apply softmax on the correct dimension to get probabilities
attn_scores_masked = self.apply_causal_mask(attn_scores / self.cfg.d_head**0.5)
attn_pattern = attn_scores_masked.softmax(dim=...)

# Take weighted sum of value vectors, according to attention probabilities
z = einops.einsum(
    v,
    attn_pattern,
    "???,??? -> batch posn_Q nheads d_head",
)

# Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
attn_out = (
    einops.einsum(
        z,
        self.W_O,
        "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model",
    )
    + self.b_O
)
return attn_out
```
</details>





In [19]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device=device))

    # fmt: off
    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch token d_model"]
    ) -> Float[Tensor, "batch token d_model"]:
        # Use einops!
        # And to help us understand your code quickly, try to use only the following names with einops:
        # batch head token token_k token_q d_model d_head

        # Hide: hard
        # Calculate query, key and value vectors
        # Hide: all
        q = einops.einsum(normalized_resid_pre, self.W_Q,
                "batch token d_model, head d_model d_head -> batch token head d_head",
            ) + self.b_Q
        k = einops.einsum(normalized_resid_pre, self.W_K,
                "batch token d_model, head d_model d_head -> batch token head d_head",
            ) + self.b_K
        v = einops.einsum(normalized_resid_pre, self.W_V,
                "batch token d_model, head d_model d_head -> batch token head d_head",
            ) + self.b_V
        # Hide: hard

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        # Hide: all
        attn_scores = einops.einsum(q, k,
            "batch token_q head d_head, batch token_k head d_head -> batch head token_q token_k")
        attn_scores_masked = self.apply_causal_mask(attn_scores / self.cfg.d_head**0.5)
        attn_pattern = attn_scores_masked.softmax(-1)
        # Hide: hard

        # Take weighted sum of value vectors, according to attention probabilities
        # Hide: all
        z = einops.einsum(v, attn_pattern,
            "batch token_k head d_head, batch head token_q token_k -> batch token_q head d_head")
        # Hide: hard

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        # Hide: all
        attn_out = einops.einsum(z, self.W_O,
                "batch token_q head d_head, head d_head d_model -> batch token_q d_model",
            ) + self.b_O

        return attn_out
        # Hide: none

    # fmt: on
    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch head token_q token_k"]
    ) -> Float[Tensor, "batch head token_q token_k"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        # Define a mask that is True for all positions we want to set probabilities to zero for
        all_ones = torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = torch.triu(all_ones, diagonal=1).bool()
        # Apply the mask to attention scores, then return the masked scores
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores


rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



Note: The `"IGNORE"` buffer is a very large negative number. This is the value you should mask your attention scores with (i.e. set them to this number wherever you want the probabilities to be zero).

<details>
<summary>Question - why do you think we mask the attention scores by setting them to a large negative number, rather than the attention probabilities by setting them to zero?</summary>

If we masked the attention probabilities, then the probabilities would no longer sum to 1.

We want to mask the scores and *then* take softmax, so that the probabilities are still valid probabilities (i.e. they sum to 1), and the values in the masked positions have no influence on the model's output.
</details>
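
The difference between the two masking orders can be checked with a small plain-Python sketch. The scores are made up, and `IGNORE` mirrors the value of the buffer in the `Attention` class above.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


IGNORE = -1e5  # same large negative value as the buffer in the Attention class
scores = [2.0, 1.0, 3.0]  # one query's scores; suppose the last key is in the future
is_future = [False, False, True]

# Option 1: mask the scores, then softmax
masked_scores = [IGNORE if f else s for s, f in zip(scores, is_future)]
probs_mask_first = softmax(masked_scores)

# Option 2: softmax, then zero out the masked probabilities
probs_mask_last = [0.0 if f else p for p, f in zip(softmax(scores), is_future)]

print(probs_mask_first, sum(probs_mask_first))  # sums to 1
print(probs_mask_last, sum(probs_mask_last))  # sums to less than 1
```

In the second option, the future token's large score still shrinks the other probabilities before it is zeroed out, so it influences the output even though it is "masked".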


## MLP

```c
Difficulty: 🟠🟠⚪⚪⚪
Importance: 🟠🟠🟠🟠⚪

You should spend up to 10-15 minutes on this exercise.
```

Next, you should implement the MLP layer, which consists of:

* A linear layer, with weight `W_in`, bias `b_in`
* A nonlinear function (we usually use GELU; the function `gelu_new` has been imported for this purpose)
* A linear layer, with weight `W_out`, bias `b_out`
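
The three steps above can be sketched on a single vector in plain Python. This assumes the imported `gelu_new` is the tanh approximation of GELU used by GPT-2; the version below is a re-implementation for illustration, and the weights and the `linear` helper are made up.

```python
import math


def gelu_new(x: float) -> float:
    """Tanh approximation of GELU (assumed to match the imported gelu_new)."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def linear(vec, weight, bias):
    """Compute vec @ weight + bias for a single vector."""
    return [sum(v * w for v, w in zip(vec, col)) + b for col, b in zip(zip(*weight), bias)]


# Toy sizes: d_model = 2, d_mlp = 4 (made-up weights)
W_in = [[1.0, -1.0, 0.5, 0.0], [0.0, 1.0, -0.5, 2.0]]
b_in = [0.0, 0.0, 0.0, 0.0]
W_out = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
b_out = [0.0, 0.0]

x = [1.0, 2.0]
pre = linear(x, W_in, b_in)  # shape [d_mlp]
post = [gelu_new(p) for p in pre]  # elementwise nonlinearity
mlp_out = linear(post, W_out, b_out)  # back to shape [d_model]

print(pre)
print([round(p, 4) for p in post])
print([round(o, 4) for o in mlp_out])
```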


In [20]:
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch token d_model"]
    ) -> Float[Tensor, "batch token d_model"]:
        # Hide: all
        pre = (
            einops.einsum(
                normalized_resid_mid,
                self.W_in,
                "batch token d_model, d_model d_mlp -> batch token d_mlp",
            )
            + self.b_in
        )
        post = gelu_new(pre)
        mlp_out = (
            einops.einsum(
                post,
                self.W_out,
                "batch token d_mlp, d_mlp d_model -> batch token d_model",
            )
            + self.b_out
        )
        return mlp_out
        # Hide: none


rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## (Bonus) LayerNorm

```c
Difficulty: 🟠🟠🟠⚪⚪
Importance: 🟠🟠🟠⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

You should fill in the code below, and then run the tests to verify that your layer is working correctly.

Your LayerNorm should do the following:

* Make mean 0
* Normalize to have variance 1
* Scale with learned weights
* Translate with learned bias

You can use the PyTorch [LayerNorm documentation](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) as a reference. A few more notes:

* Your layernorm implementation always has `elementwise_affine=True`, i.e. you do learn parameters `w` and `b` (which are represented as $\gamma$ and $\beta$ respectively in the PyTorch documentation).
* Remember that, after the centering and normalization, each vector of length `d_model` in your input should have mean 0 and variance 1.
* As the PyTorch documentation page says, your variance should be computed using `unbiased=False`.
* The `layer_norm_eps` argument in your config object corresponds to the $\epsilon$ term in the PyTorch documentation (it is included to avoid division-by-zero errors).
* We've given you a `debug` argument in your config. If `debug=True`, then you can print output like the shape of objects in your `forward` function to help you debug (this is a very useful trick to improve your coding speed).
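
The four steps can be checked by hand on one vector. The sketch below is plain Python rather than PyTorch, with made-up input values and a hypothetical `layer_norm` helper; it uses the biased variance, matching `unbiased=False`.

```python
import math


def layer_norm(vec, w, b, eps=1e-5):
    """LayerNorm over one vector of length d_model."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)  # biased variance (unbiased=False)
    normalized = [(v - mean) / math.sqrt(var + eps) for v in vec]
    return [n * wi + bi for n, wi, bi in zip(normalized, w, b)]


residual = [1.0, 2.0, 3.0, 6.0]
out = layer_norm(residual, w=[1.0] * 4, b=[0.0] * 4)

# With w = 1 and b = 0, the output should have mean 0 and variance close to 1
out_mean = sum(out) / len(out)
out_var = sum((o - out_mean) ** 2 for o in out) / len(out)
print([round(o, 4) for o in out])
print(round(out_mean, 6), round(out_var, 4))
```

The variance comes out very slightly below 1 because of the `eps` term in the denominator.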

In [21]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(
        self, residual: Float[Tensor, "batch token d_model"]
    ) -> Float[Tensor, "batch token d_model"]:
        # Hide: hard
        residual_mean = residual.mean(dim=-1, keepdim=True)
        residual_std = (
            residual.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps
        ).sqrt()

        residual = (residual - residual_mean) / residual_std
        return residual * self.w + self.b
        # Hide: none


rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, cache["resid_post", 11])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Transformer Block

```c
Difficulty: 🟠🟠⚪⚪⚪
Importance: 🟠🟠🟠⚪⚪

You should spend up to ~10 minutes on this exercise.
```

Now, we can put together the attention, MLP and layernorms into a single transformer block. Remember to implement the residual connections correctly!


In [22]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch token d_model"]
    ) -> Float[Tensor, "batch token d_model"]:
        # Hide: hard
        # First, we add in the attention, but the residual stream needs to be normalized beforehand
        # Hide: all
        resid_mid = resid_pre + self.attn(self.ln1(resid_pre))
        # Hide: hard

        # Then, we add in the MLP, again, the input of the MLP needs to be normalized beforehand
        # Hide: all
        resid_post = resid_mid + self.mlp(self.ln2(resid_mid))
        # Hide: hard

        return resid_post
        # Hide: none


rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Unembedding

```c
Difficulty: 🟠🟠⚪⚪⚪
Importance: 🟠🟠🟠⚪⚪

You should spend up to ~10 minutes on this exercise.
```

The unembedding is just a linear layer (with weight `W_U` and bias `b_U`).


In [23]:
class Unembed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        # requires_grad must be passed to nn.Parameter (not torch.zeros) for the bias to be frozen
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab), requires_grad=False)

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch token d_model"]
    ) -> Float[Tensor, "batch token d_vocab"]:
        # Hide: all
        return (
            einops.einsum(
                normalized_resid_final,
                self.W_U,
                "batch token d_model, d_model d_vocab -> batch token d_vocab",
            )
            + self.b_U
        )
        # Or, could just do `normalized_resid_final @ self.W_U + self.b_U`
        # Hide: none


rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257]) 

100.00% of the values are correct



## Full Transformer

```c
Difficulty: 🟠🟠⚪⚪⚪
Importance: 🟠🟠🟠⚪⚪

You should spend up to ~10 minutes on this exercise.
```


In [24]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch token"]) -> Float[Tensor, "batch token d_vocab"]:
        # Hint: modules are defined in the order they should be used
        # Hide: all
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        logits = self.unembed(self.ln_final(residual))
        return logits
        # Hide: none


rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 45])
Output shape: torch.Size([1, 45, 50257])
Reference output shape: torch.Size([1, 45, 50257]) 

100.00% of the values are correct



**Try it out!**


In [25]:
demo_gpt2 = DemoTransformer(Config(debug=False)).to(device)
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)

demo_logits = demo_gpt2(tokens)

Let's take a test string, and calculate the loss!

We're using the formula for **cross-entropy loss**. The cross entropy loss between a modelled distribution $Q$ and target distribution $P$ is:

$$
-\sum_x P(x) \log Q(x)
$$

In the case where $P$ is just the empirical distribution from target classes (i.e. $P(x^*) = 1$ for the correct class $x^*$) then this becomes:

$$
-\log Q(x^*)
$$

in other words, the negative log prob of the true classification.
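
As a worked example of this formula, the plain-Python sketch below computes the average loss for a made-up 4-token vocabulary and compares it with the loss of a uniform prediction. The logits and token IDs are illustrative assumptions. Note that the prediction made at position `i` is scored against the token at position `i + 1`, so the last prediction has no target.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


# Toy example: vocabulary of 4 tokens, sequence of 3 tokens (made-up logits)
tokens = [2, 0, 3]
logits = [
    [0.0, 0.0, 2.0, 0.0],  # prediction made at position 0, target is tokens[1] = 0
    [1.0, 0.0, 0.0, 3.0],  # prediction made at position 1, target is tokens[2] = 3
    [0.0, 0.0, 0.0, 0.0],  # prediction at the last position has no target
]

# Pair predictions at positions 0..seq_len-2 with the tokens at positions 1..seq_len-1
log_probs_for_tokens = [
    log_softmax(position_logits)[target]
    for position_logits, target in zip(logits[:-1], tokens[1:])
]
loss = -sum(log_probs_for_tokens) / len(log_probs_for_tokens)
uniform_loss = math.log(len(logits[0]))  # loss of a model that predicts uniformly

print(f"loss = {loss:.4f}, uniform baseline = {uniform_loss:.4f}")
```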


In [26]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    log_probs = logits.log_softmax(dim=-1)
    # Get the log probs of the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    # Hide: hard
    log_probs_for_tokens = (
        log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    )
    # Hide: none

    return log_probs_for_tokens


pred_log_probs = get_log_probs(demo_logits, tokens)
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.d_vocab):.4f}")
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():.4f}")

Avg cross entropy loss: 4.0442
Avg cross entropy loss for uniform distribution: 10.8249
Avg probability assigned to correct token: 0.0986


We can also greedily generate text, by taking the most likely next token and continually appending it to our prompt before feeding it back into the model:


In [27]:
test_string = (
    """The Total Perspective Vortex derives its picture of the whole Universe on the principle of"""
)
for i in tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).to(device)
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)

The Total Perspective Vortex derives its picture of the whole Universe on the principle of the total perspective. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The total perspective is the view of the whole Universe from the point of view of the observer. The
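
The repetition in the output above is common with greedy decoding, since always taking the argmax is deterministic. The sketch below illustrates the loop with a hypothetical bigram model in plain Python, where the next token depends only on the last one. The score table is made up, and the analogy is loose because a transformer conditions on the whole context rather than just the last token.

```python
# Hypothetical next-token scores for a toy bigram model (made-up numbers):
# scores[current_token][candidate] is the score of candidate following current_token
scores = {
    "the": {"total": 0.6, "view": 0.3, "observer": 0.1},
    "total": {"perspective": 0.9, "the": 0.1},
    "perspective": {"is": 0.7, "the": 0.3},
    "is": {"the": 0.8, "total": 0.2},
    "view": {"is": 1.0},
    "observer": {"the": 1.0},
}


def greedy_generate(prompt: list[str], n_new_tokens: int) -> list[str]:
    """Repeatedly append the highest-scoring next token (argmax) to the sequence."""
    sequence = list(prompt)
    for _ in range(n_new_tokens):
        candidates = scores[sequence[-1]]
        sequence.append(max(candidates, key=candidates.get))
    return sequence


generated = greedy_generate(["the"], 8)
print(" ".join(generated))
```

Once the toy model returns to a token it has already produced, it follows the same path again, which is the same kind of cycle seen in the generated text.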


If you've finished this, congrats!
You should ask the TA what to do next. One option is
to look at training and sampling from transformers in the rest of the arena notebook, which you can find at https://colab.research.google.com/github/EffiSciencesResearch/ML4G-2.0/blob/master/workshops/transformer/transformer-arena.ipynb.