# Setup

In [1]:
from dataclasses import dataclass
import math
import os
from pprint import pprint

import einops
from jaxtyping import Float, Int
import torch as t
from torch import Tensor
import torch.nn as nn
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new
from tqdm.notebook import tqdm

import C1P1_tests as tests

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

device = t.device(
    "mps"
    if t.backends.mps.is_available()
    else "cuda"
    if t.cuda.is_available()
    else "cpu"
)

MAIN = __name__ == "__main__"

reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
    device=device,
)



Loaded pretrained model gpt2-small into HookedTransformer


In [21]:
print(device)

cuda


# Clean Transformer Implementation

### Learning Objectives

- Understand that a transformer is composed of attention heads and MLPs, with each one performing operations on the residual stream.

- Understand that attention heads in a single layer operate independently, and that they have the role of calculating attention patterns. These patterns determine where information is moved to and from in the residual stream.

- Implement the following transformer modules:

  - **Embedding**: a lookup table from tokens to residual stream vectors.

  - **Positional embedding**: a lookup table from position indices to residual stream vectors.

  - **LayerNorm**: transforming the input to have zero mean and unit variance.

  - **Attention**: computing attention patterns for residual stream vectors.

    - **Causal Mask**: how we enforce the prediction of the next token to only depend on preceding tokens in the sequence.

  - **MLP**: the collection of linear and nonlinear transformations that operate on the residual stream vectors in the same way.

  - **Unembedding**: converting the residual stream vectors into a distribution over tokens.

## High-level architecture

- See illustration [here](./C1P1_transformer_architecture.png)
- See Neel Nanda's [Transformer Circuits Walkthrough](https://www.youtube.com/watch?v=KV5gbOmHbjU) for more intuition.

### Tokenisation and embedding

- The input tokens $t$ are integers. We get them from taking a sequence and tokenising it (see previous section).

- The token embedding is a lookup table mapping tokens to vectors. This is implemented as a matrix $W_E$ consisting of a row-stack of token embedding vectors, one per token in the model's vocabulary.
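
A minimal sketch of the lookup-table view, assuming toy sizes and a random `W_E` (not GPT-2's weights). It checks that indexing rows of `W_E` by token ID gives the same result as multiplying a one-hot encoding of the tokens by `W_E`.

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)
d_vocab, d_model = 10, 4  # toy sizes chosen for illustration
W_E = t.randn(d_vocab, d_model)  # one row per token in the vocabulary

tokens = t.tensor([[3, 1, 3]])  # shape (batch=1, seq_len=3)

# Lookup-table view: index the rows of W_E by token ID
embed_lookup = W_E[tokens]  # (1, 3, d_model)

# Matrix view: one-hot encode the tokens, then multiply by W_E
one_hot = F.one_hot(tokens, num_classes=d_vocab).float()  # (1, 3, d_vocab)
embed_matmul = one_hot @ W_E  # (1, 3, d_model)

print(t.allclose(embed_lookup, embed_matmul))
```

Indexing is the cheaper of the two, which is why the lookup form is normally used in practice.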

### Residual Stream

- The sum of all previous outputs of layers in the model, and the input to each new layer.

- Shape `(batch, seq_len, d_model)`, where `d_model` is the length of a single embedding vector.

- Initial value is denoted $x_0$ (see [diagram](./C1P1_transformer_architecture.png)). $x_i$ denotes the $i$-th subsequent value of the residual stream, after further attention and MLP layers have been applied.

- The residual stream is **fundamental**; it is the central object of the transformer. It serves as the model's memory, moves information between layers for composition, and stores the information that attention moves between positions.

  - A key idea of transformers is the residual stream as [output accumulation](https://www.lesswrong.com/posts/X26ksz4p3wSyycKNB/gears-level-mental-models-of-transformer-interpretability#Residual_Stream_as_Output_Accumulation:~:text=The%20Models-,Residual%20Stream%20as%20Output%20Accumulation,-The%20residual%20stream).
    
    - As we move through the model's layers, shifting information around and processing it, the residual stream's values represent the accumulation of all the inferences made by the transformer up to that point.

    - This is neatly illustrated by the **logit lens**, with which we can take the value of the residual stream midway through the model and convert it into a distribution over tokens, rather than only getting predictions from the very end of the model. When we do this, we find surprisingly coherent predictions, especially in the last few layers!

### Transformer blocks

- Transformers consist of a series of `n_layers` **transformer blocks** (sometimes called **residual blocks**).

- A block contains both an attention layer and an MLP, but we say that *a transformer has $k$ layers if it has $k$ blocks* (i.e., $2k$ total layers).

- See diagram [here](./C1P1_transformer_block.png)
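
A rough sketch of how one block writes into the residual stream. The names `attn_stub` and `mlp_stub` are hypothetical stand-ins used only to get the shapes right (a real attention layer mixes positions, which a plain `nn.Linear` does not). The variable names `resid_pre`, `resid_mid` and `resid_post` follow the hook names that appear in the activation cache.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
batch, seq_len, d_model = 2, 5, 16  # toy sizes

# Hypothetical stand-ins: any map (batch, seq, d_model) -> (batch, seq, d_model) works here
ln1, ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
attn_stub = nn.Linear(d_model, d_model)
mlp_stub = nn.Sequential(
    nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
)

with t.no_grad():
    resid_pre = t.randn(batch, seq_len, d_model)
    attn_out = attn_stub(ln1(resid_pre))  # attention reads a normalised copy of the stream
    resid_mid = resid_pre + attn_out      # ...and its output is added back in
    mlp_out = mlp_stub(ln2(resid_mid))    # the MLP does the same
    resid_post = resid_mid + mlp_out

# The stream after the block is the input plus everything the block wrote
print(t.allclose(resid_post, resid_pre + attn_out + mlp_out))
```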

#### Attention

- Moves information from prior positions in the sequence to the current token. 

- Processing is performed for *every* token in parallel using the same parameters. The only difference is that we look *backwards* only (to avoid "cheating"). This means that later tokens have more of the sequence to look at.

- Attention layers are the only component of a transformer that moves information between positions; i.e., between vectors at different sequence positions in the residual stream.

- Each attention layer consists of `n_heads` heads. Each head has its own parameters, attention pattern, and instructions on how to copy information from source to destination.
  - The heads act independently and additively. Their outputs are added together and then to the residual stream.

- Each head:
  - Produces an attention pattern for each destination token, representing a probability distribution over preceding source tokens (including the destination token itself!) that weights how much information to copy.
  - Moves information via a linear map in the same way from each source token to each destination token.

- Note:
  - The information to copy depends on the source token's residual stream, but *this doesn't mean it only depends on the value of that token*. The residual stream can store more information than just the token identity.
    - Remember that the *purpose* of attention heads is to *move* information between vectors at different positions in the residual stream.

  - Each attention head consists of two different circuits:
    - The **QK circuit**, which determines *where to move information to and from*.
      - This is a function of the residual stream for the source and destination tokens.

    - The **OV circuit** determines *what information to move*.
      - This is a function only of the source token's residual stream.

  - We can think of attention as a kind of generalised convolution: 
    - Standard convolution layers work by imposing a "prior of locality". I.e., the assumption that pixels that are close together are more likely to share information.

    - Language has some locality (two words next to each other are more likely to share information than two words 100 tokens apart), but the picture is more nuanced. This is because which tokens are relevant to which others depends on the *context* of the text.

    - E.g., in the sentence `"When Mary and John went to the store, John gave a drink to Mary."`, the names in this sentence are the most important tokens for predicting that the final token will be `"Mary"`. This is because of the particular context of this sentence rather than the token's position.
    
    - Attention layers are a way to tell the transformer: *"Don't impose a prior of locality, but instead develop your own algorithm to figure out which tokens are important to which others in any given sequence."* 

See diagram of an attention layer [here](./C1P1_attention_layer.png).

#### MLP

- A standard neural network, with a single hidden layer and a nonlinear activation function.
  - The specific activation isn't too important conceptually, but [GELU](https://paperswithcode.com/method/gelu) seems to perform best.

- The hidden dimension is normally `d_mlp = 4 * d_model`.
  - The reasons for the ratios aren't too important. People basically cargo-cult what GPT did in the past.

- Importantly, the MLP **operates on positions in the residual stream independently, and in exactly the same way**. It doesn't move information between positions.

- Once attention has moved relevant information to a single position in the residual stream, MLPs can actually do *computation*, *reasoning*, *information lookup*, etc.
  - What is actually going on inside MLPs remains a big open problem in transformer mechanistic interpretability.

  - The [Toy Models of Superposition paper](https://transformer-circuits.pub/2022/toy_model/index.html) helps explain why this is hard.

- See diagram of an MLP layer [here](./C1P1_mlp.png)
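
To illustrate the point that the MLP treats positions independently, here is a small sketch with a toy MLP (random weights, toy sizes, all assumed for illustration). Changing the input at one position should change the output at that position only.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
seq_len, d_model = 4, 8  # toy sizes
mlp = nn.Sequential(
    nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
)

x = t.randn(1, seq_len, d_model)
x_perturbed = x.clone()
x_perturbed[0, 2] += 1.0  # change the residual stream at position 2 only

with t.no_grad():
    out, out_perturbed = mlp(x), mlp(x_perturbed)

# For each position, did the output vector change?
changed = ~t.isclose(out, out_perturbed).all(dim=-1)  # shape (1, seq_len)
print(changed)
```

Only the entry for position 2 should be flagged. An attention layer would not pass this check, since it moves information between positions.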




##### MLPs as key-value pairs

We can write an MLP's output as:

$$
f(x^T W^{in})W^{out}
$$
where,
- $W^{in}$ and $W^{out}$ are the different weights of the MLP, ignoring biases,
- $f$ is the activation function, and
- $x$ is a vector in the residual stream.

This can be rewritten as:

$$
f(x^T W^{in})W^{out} = \sum^{d_{mlp}}_{i=1}f(x^T W^{in}_{[:, i]})W^{out}_{[i, :]}
$$

where we can view $W^{in}_{[:, i]}$ as the **input directions** and $W^{out}_{[i, :]}$ as the **output directions**. We say the input directions are **activated** by certain textual features, and when they are activated, vectors are written in the corresponding output direction. 
- This is very similar to the concept of keys and values in attention layers, which is why these vectors are sometimes called keys and values (e.g., see [this paper](https://arxiv.org/pdf/2012.14913.pdf))

**Note**: sometimes we refer to each of these $d_{mlp}$ input-output pairs as **neurons** - see diagram [here](./C1P1_mlp_neurons.png)

---

A step-by-step breakdown of the linear algebra:

$$
\begin{aligned}
x^T W^{in} &= x^T [W^{in}_{[:, 1]}\,, ...\;, W^{in}_{[:, n]}] \\
&= (x^T W^{in}_{[:, 1]}\,, \; ...\;, \; x^T W^{in}_{[:, n]})
\end{aligned}
$$

where $W^{in}_{[:, i]}$ are the columns of $W^{in}$ and $n = d_{mlp}$. In other words, these values (the pre-GELU activations) are projections of $x$ along the input directions of the neurons.

If we add our activation function and the second matrix, then we get:

$$
\begin{aligned}
f(x^T W^{in})W^{out} &= (f(x^T W^{in}_{[:, 1]})\,, \; ...\;,\; f(x^T W^{in}_{[:, n]})) \begin{bmatrix} \leftarrow W^{out}_{[1, :]} \rightarrow \\ \vdots \\ \leftarrow W^{out}_{[n, :]} \rightarrow \end{bmatrix} \\
&= f(x^T W^{in}_{[:, 1]}) W^{out}_{[1, :]} + \;...\; + f(x^T W^{in}_{[:, n]}) W^{out}_{[n, :]} \\
&= \sum_{i=1}^n f(x^T W^{in}_{[:, i]}) W^{out}_{[i, :]}
\end{aligned}
$$

where $W^{out}_{[i, :]}$ are the rows of $W^{out}$. In other words, our output is a linear combination of the rows of $W^{out}$, with the coefficients of that linear combination given by the activation function applied to the projections of $x$ along the columns of $W^{in}$.
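
The identity above can be checked numerically. This is a small sketch with random toy weights, ignoring biases as in the derivation; `F.gelu` is used as the activation, but any elementwise function would do.

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)
d_model, d_mlp = 8, 32  # toy sizes
x = t.randn(d_model)
W_in = t.randn(d_model, d_mlp)
W_out = t.randn(d_mlp, d_model)

# Left-hand side: the usual two matrix multiplications
full = F.gelu(x @ W_in) @ W_out

# Right-hand side: one term per neuron, i.e. activation(projection onto the
# input direction) times the corresponding output direction
per_neuron = sum(F.gelu(x @ W_in[:, i]) * W_out[i, :] for i in range(d_mlp))

print(t.allclose(full, per_neuron, atol=1e-4))
```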

##### MLPs as knowledge storage

- The attention mechanism is what moves information around between sequence positions.

- The MLPs process this information, and new information is written into the residual stream which is a function of the old information.

- The key-value pairs model can be applied to MLPs as a kind of associative memory system, where the key serves as a unique identifier, and the value holds the related information.

- MLPs can also be modelled as **memory management**.
  - In an idealised case, we might find that the $i$-th neuron satisfies $W^{in}_{[:, i]} \approx -W^{out}_{[i, :]} \approx \vec{v}$ for some unit vector $\vec{v}$, meaning that it may be responsible for erasing the positive component of $\vec{x}$ in the direction $\vec{v}$. This can free up space in the residual stream for other components to write to.
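
A toy version of the idealised erasing neuron. This sketch assumes a ReLU activation (rather than GELU) so that the erasure is exact, and builds `x` so that its component along `v` is positive.

```python
import torch as t

t.manual_seed(0)
d_model = 6  # toy size
v = t.randn(d_model)
v = v / v.norm()  # unit vector

x = t.randn(d_model)
x = x - (x @ v) * v + 2.0 * v  # force the component of x along v to be +2

# Idealised neuron: input direction v, output direction -v
w_in, w_out = v, -v
neuron_out = t.relu(x @ w_in) * w_out  # ReLU is an assumption made for this sketch
x_new = x + neuron_out  # the neuron's output is added to the residual stream

print((x @ v).item(), (x_new @ v).item())
```

The component along `v` goes from roughly 2 to roughly 0, while the components of `x` orthogonal to `v` are left untouched.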

#### Unembedding

- The model's output!

- Applies a linear map $W_U$, going from the final residual stream to a vector of logits.

##### Tied embeddings

- Where the same weights are used for $W_E$ and $W_U$

- If using a tied embedding, to get the logit score for a particular token at some sequence position, we just take the vector in the residual stream at that sequence position and take the inner product with the corresponding token embedding vector.

- This is more training efficient: there are fewer parameters in the model.

- Perhaps seems principled at first?
  - If two words have very similar meanings, shouldn't they have similar embedding vectors because the model will treat them the same? And similar unembedding vectors because they could both be substituted for each other in most output?

- Not principled, since **the direct path involving the embedding and unembedding should approximate bigram frequencies**.

  - Bigram frequencies refer to the frequencies of pairs of words in the English language. E.g., the bigram frequency of "Barack Obama" is much higher than the product of the individual frequencies of the words "Barack" and "Obama".

  - If our model had no attention heads or MLP layers, then all that is left is a linear map from our one-hot encoded token `T` to a probability distribution over the token following `T`. This map is represented by the linear transformation $t \rightarrow t^T W_E W_U$, where $t$ is our one-hot encoded token vector.

  - Since the output of this transformation can only be a function of `T` and no earlier tokens (no attention layers!), the best we can do is have this map approximate the true frequency of bigrams starting with `T` that appear in the training data.

  - Importantly, **this is not a symmetric map**.
    - We want `T = "Barack"` to result in a high probability of the next token being `" Obama"`, but not the other way around!

- A weaker version of this principle still applies in multi-layer models. Although there are more paths through the model than just the direct path, $W_E W_U$, there will still always exist a direct path, and therefore some nonzero incentive for $W_E W_U$ to approximate bigram frequencies.
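
The symmetry problem can be seen directly in a toy example. This sketch assumes that tying means $W_U = W_E^T$ and uses small random matrices; it shows that the direct-path matrix $W_E W_U$ is then forced to be symmetric, whereas an untied $W_U$ is free to be asymmetric.

```python
import torch as t

t.manual_seed(0)
d_vocab, d_model = 6, 4  # toy sizes
W_E = t.randn(d_vocab, d_model)

# Tied: the unembedding is the transpose of the embedding
tied_direct_path = W_E @ W_E.T  # (d_vocab, d_vocab)

# Untied: an independent unembedding matrix
W_U = t.randn(d_model, d_vocab)
untied_direct_path = W_E @ W_U  # (d_vocab, d_vocab)

# Entry [i, j] is the direct-path logit for token j following token i
print(t.allclose(tied_direct_path, tied_direct_path.T))
print(t.allclose(untied_direct_path, untied_direct_path.T))
```

With tied weights, the direct-path logit for `" Obama"` after `"Barack"` would always equal the logit for `"Barack"` after `" Obama"`.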

#### LayerNorm

- A simple normalisation function applied at the start of each layer.
  - I.e., before each attention layer, MLP, and before the unembedding.

- Converts each input vector to have zero mean and unit variance.
  - This is done independently and in parallel for each `(batch, seq)` residual stream vector.

- An elementwise scale and translation is then applied.
  - The scale ($\odot \gamma$) and translate ($+ \beta$) together are just an affine map (a linear map plus a bias).
  - *LayerNorm* is only applied immediately before another linear map; either the MLP, or the query/key/value linear maps in the attention head, or the unembedding $W_U$.
  - Since *linear compose linear = linear*, we can just fold this into a single effective linear layer and ignore it.

- *LayerNorm* is annoying for interpretability: strictly speaking, division by the standard deviation makes it nonlinear, so the contributions of the input to the output cannot be independently decomposed. That said, it is *almost* linear: if you're changing a small part of the input, you can pretend that $\sqrt{\textrm{Var}[x] + \epsilon}$ is constant so that the LayerNorm operation is linear. But if you change $x$ enough to substantially alter the norm, it's not linear.

- See diagram [here](./C1P1_layernorm.png)
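
The "linear compose linear = linear" point can be made concrete. In this sketch (toy sizes, random values, all names hypothetical), `x_norm` stands in for already centred and normalised vectors, `w` and `b` for $\gamma$ and $\beta$, and `W`, `c` for the weights and bias of whichever linear map comes next.

```python
import torch as t

t.manual_seed(0)
d_model, d_out = 8, 5  # toy sizes
x_norm = t.randn(3, d_model)                # stand-in for normalised vectors
w, b = t.randn(d_model), t.randn(d_model)   # LayerNorm scale and bias
W, c = t.randn(d_model, d_out), t.randn(d_out)  # the following linear map

# Unfolded: scale and translate, then apply the next linear map
unfolded = (x_norm * w + b) @ W + c

# Folded: absorb the scale into the weights and the bias into the next bias
W_folded = w[:, None] * W   # same as diag(w) @ W
c_folded = b @ W + c
folded = x_norm @ W_folded + c_folded

print(t.allclose(unfolded, folded, atol=1e-4))
```

Only the scale and translate fold away like this; the centring and the division by the standard deviation still have to be applied to the input.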

#### Positional embeddings

- Attention operates over all pairs of positions. This means that it is symmetric with respect to position; the attention calculations from token 5 to token 1 and from token 5 to token 2 are the same by default.
  - We can do better than this, since nearby tokens are more relevant.

- There are a lot of hacks for this.

- We'll focus on **learned, absolute positional embeddings**.
  - We learn a lookup table mapping the index of the position of each token to a residual stream vector, and add this to the token embedding.
    - Note that we *add* rather than concatenate. The residual stream is shared memory, and likely under significant superposition; i.e., the model compresses more features in there than the model has dimensions.
    - We almost never concatenate inside a transformer, except for, say, efficiently generating text.

- This connects to **attention as generalised convolution**.
  - We argued that language still does have some locality, so it's helpful for transformers to have access to the positional information so that they "know" two tokens are next to each other, and hence probably relevant to each other.
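
A minimal sketch of adding learned absolute positional embeddings to token embeddings, using toy sizes and random tables. It shows that the same token at two different positions ends up with different residual stream vectors.

```python
import torch as t

t.manual_seed(0)
d_vocab, n_ctx, d_model = 10, 8, 4  # toy sizes
W_E = t.randn(d_vocab, d_model)   # token lookup table
W_pos = t.randn(n_ctx, d_model)   # position lookup table

tokens = t.tensor([[7, 2, 7]])  # token 7 appears at positions 0 and 2
seq_len = tokens.shape[1]

token_embed = W_E[tokens]                 # (1, seq_len, d_model)
resid = token_embed + W_pos[:seq_len]     # added, not concatenated; broadcasts over batch

same_token_embed = t.equal(token_embed[0, 0], token_embed[0, 2])
same_resid = t.allclose(resid[0, 0], resid[0, 2])
print(resid.shape, same_token_embed, same_resid)
```

The residual stream keeps width `d_model`, which is the practical consequence of adding rather than concatenating.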

## Actual Code!

### Key

```python
batch = 1
position = 35
d_model = 768
n_heads = 12
n_layers = 12
d_mlp = 3072 # = 4 * d_model
d_head = 64 # = d_model / n_heads
```

### Test run the model

In [2]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"

tokens = reference_gpt2.to_tokens(reference_text).to(device)

logits, cache = reference_gpt2.run_with_cache(tokens, device=device)

### Parameters and Activations

- Important to distinguish these!

#### Parameters

- **are the weights and biases learned during training.**

- Parameters do not change when the model input changes.

#### Activations

- **are temporary numbers calculated during a forward pass that are functions of the input.**

- Activations can be thought of as only existing for the duration of a single forward pass and disappearing afterwards.

- Hooks can be used to access these values during a forward pass, but it doesn't make sense to talk about a model's activations outside the context of some particular input.

- Attention scores and patterns are activations.
  - This is slightly counterintuitive, since they're used in a matrix multiplication with another activation.
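
The distinction can be seen on a tiny layer. This sketch uses a plain PyTorch forward hook on an `nn.Linear` (not the TransformerLens hook points used elsewhere in these notes) to capture outputs for two different inputs.

```python
import torch as t
import torch.nn as nn

t.manual_seed(0)
layer = nn.Linear(4, 3)
captured = []  # activations seen during forward passes


def save_output(module, inputs, output):
    # Forward hooks receive the module, its inputs, and its output
    captured.append(output.detach().clone())


handle = layer.register_forward_hook(save_output)
params_before = [p.detach().clone() for p in layer.parameters()]

layer(t.randn(2, 4))  # first input
layer(t.randn(2, 4))  # a different input
handle.remove()

# Parameters are the same whatever the input; activations are not
params_unchanged = all(
    t.equal(p, q) for p, q in zip(layer.parameters(), params_before)
)
activations_differ = not t.allclose(captured[0], captured[1])
print(params_unchanged, activations_differ)
```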

#### All Activation Shapes of the Reference Model

In [3]:
for activation_name, activation in cache.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(f"{activation_name:30} {tuple(activation.shape)}")

hook_embed                     (1, 35, 768)
hook_pos_embed                 (1, 35, 768)
blocks.0.hook_resid_pre        (1, 35, 768)
blocks.0.ln1.hook_scale        (1, 35, 1)
blocks.0.ln1.hook_normalized   (1, 35, 768)
blocks.0.attn.hook_q           (1, 35, 12, 64)
blocks.0.attn.hook_k           (1, 35, 12, 64)
blocks.0.attn.hook_v           (1, 35, 12, 64)
blocks.0.attn.hook_attn_scores (1, 12, 35, 35)
blocks.0.attn.hook_pattern     (1, 12, 35, 35)
blocks.0.attn.hook_z           (1, 35, 12, 64)
blocks.0.hook_attn_out         (1, 35, 768)
blocks.0.hook_resid_mid        (1, 35, 768)
blocks.0.ln2.hook_scale        (1, 35, 1)
blocks.0.ln2.hook_normalized   (1, 35, 768)
blocks.0.mlp.hook_pre          (1, 35, 3072)
blocks.0.mlp.hook_post         (1, 35, 3072)
blocks.0.hook_mlp_out          (1, 35, 768)
blocks.0.hook_resid_post       (1, 35, 768)
ln_final.hook_scale            (1, 35, 1)
ln_final.hook_normalized       (1, 35, 768)


#### All Parameter Shapes of the Reference Model

In [4]:
for name, param in reference_gpt2.named_parameters():
    print(f"{name:30} {tuple(param.shape)}")

embed.W_E                      (50257, 768)
pos_embed.W_pos                (1024, 768)
blocks.0.ln1.w                 (768,)
blocks.0.ln1.b                 (768,)
blocks.0.ln2.w                 (768,)
blocks.0.ln2.b                 (768,)
blocks.0.attn.W_Q              (12, 768, 64)
blocks.0.attn.W_O              (12, 64, 768)
blocks.0.attn.b_Q              (12, 64)
blocks.0.attn.b_O              (768,)
blocks.0.attn.W_K              (12, 768, 64)
blocks.0.attn.W_V              (12, 768, 64)
blocks.0.attn.b_K              (12, 64)
blocks.0.attn.b_V              (12, 64)
blocks.0.mlp.W_in              (768, 3072)
blocks.0.mlp.b_in              (3072,)
blocks.0.mlp.W_out             (3072, 768)
blocks.0.mlp.b_out             (768,)
blocks.1.ln1.w                 (768,)
blocks.1.ln1.b                 (768,)
blocks.1.ln2.w                 (768,)
blocks.1.ln2.b                 (768,)
blocks.1.attn.W_Q              (12, 768, 64)
blocks.1.attn.W_O              (12, 64, 768)
blocks.1.attn.b_Q 

### Diagram of full annotated TransformerLens architecture [here](./C1P1_transformerlens_full_architecture.png)

### Config

This config object contains all the **hyperparameters** of the model.

In [5]:
# As a reference - note there's a lot of stuff we don't care about in here, to do with library internals or other architectures
print(reference_gpt2.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional

#### A stripped-down version for our model

In [6]:
@dataclass
class Config:
    d_head: int = 64
    d_mlp: int = 3072
    d_model: int = 768
    d_vocab: int = 50257
    debug: bool = True
    init_range: float = 0.02
    layer_norm_eps: float = 1e-5
    n_ctx: int = 1024
    n_heads: int = 12
    n_layers: int = 12


cfg = Config()
pprint(cfg)

Config(d_head=64,
       d_mlp=3072,
       d_model=768,
       d_vocab=50257,
       debug=True,
       init_range=0.02,
       layer_norm_eps=1e-05,
       n_ctx=1024,
       n_heads=12,
       n_layers=12)
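
As a quick sanity check on these hyperparameters, the arithmetic below counts the weight matrices in the blocks. The block total appears to match the `n_params` value printed in the reference config, which suggests (this is an inference from the numbers, not something stated in these notes) that `n_params` counts only the attention and MLP weight matrices, excluding biases, LayerNorm and the embeddings.

```python
d_model, d_head, n_heads, d_mlp, n_layers = 768, 64, 12, 3072, 12
d_vocab, n_ctx = 50257, 1024

attn_weights = 4 * n_heads * d_model * d_head  # W_Q, W_K, W_V, W_O
mlp_weights = 2 * d_model * d_mlp              # W_in, W_out
per_block = attn_weights + mlp_weights
blocks_total = n_layers * per_block

# Token and positional embedding tables, counted separately
embed_weights = d_vocab * d_model + n_ctx * d_model

print(per_block)      # 7077888
print(blocks_total)   # 84934656
print(embed_weights)  # 39383808
```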


## Tests

In [7]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randn(shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape, "\n")


def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randint(100, 1000, shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape, "\n")


def load_gpt2_test(cls, gpt2_layer, input):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    print("Input shape:", input.shape)
    output = layer(input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape)
    try:
        reference_output = gpt2_layer(input)
    except Exception:
        reference_output = gpt2_layer(input, input, input)
    print("Reference output shape:", reference_output.shape, "\n")
    comparison = t.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct\n")

## LayerNorm

Should:
- Make mean zero

- Normalise to have variance 1

- Scale with learned weights

- Translate with learned bias.

Use PyTorch's [LayerNorm docs](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html).

More notes:
- This LayerNorm implementation always has `elementwise_affine=True`. I.e., we will learn the parameters `w` and `b`, represented as $\gamma$ and $\beta$ in the PyTorch docs.

- After centering and normalisation, each *vector* of length `d_model` should have mean 0 and variance 1

- Variance should be computed with `unbiased=False` as per PyTorch docs

- `layer_norm_eps` corresponds to the $\epsilon$ term in the PyTorch docs. It avoids division-by-zero errors.

- If `debug=True`, you can print output like the shape of objects in your `forward` method to help debug.
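
The `unbiased=False` note can be checked against PyTorch's own functional layer norm. This sketch uses random input and no learned scale or bias; it compares a manual normalisation using the biased variance, and one using the unbiased variance, with `F.layer_norm`.

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)
eps = 1e-5
x = t.randn(2, 4, 768)

mean = x.mean(dim=-1, keepdim=True)
var_biased = x.var(dim=-1, keepdim=True, unbiased=False)   # divides by N
var_unbiased = x.var(dim=-1, keepdim=True, unbiased=True)  # divides by N - 1

manual = (x - mean) / t.sqrt(var_biased + eps)
manual_wrong = (x - mean) / t.sqrt(var_unbiased + eps)
reference = F.layer_norm(x, (768,), eps=eps)  # no weight or bias passed

print(t.allclose(manual, reference, atol=1e-5))        # biased variance
print(t.allclose(manual_wrong, reference, atol=1e-5))  # unbiased variance
```

The difference between the two is small (a factor of about `sqrt(767 / 768)`), which is why this mistake tends to show up as "almost all values correct" rather than as an obvious failure.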

In [8]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(
        self, residual: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        mean = residual.mean(dim=-1, keepdim=True)
        var = residual.var(dim=-1, keepdim=True, unbiased=False)

        res = (residual - mean) / t.sqrt(var + self.cfg.layer_norm_eps)
        res_scaled = res * self.w + self.b

        return res_scaled


rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, cache["resid_post", 11])
zero_input = t.zeros_like(cache["resid_post", 11]).to(device)
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, zero_input)

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Embedding

Simple lookup table from tokens to residual stream vectors

In [9]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(
        self, tokens: Int[Tensor, "batch position"]
    ) -> Float[Tensor, "batch position d_model"]:
        return self.W_E[tokens]


rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Positional Embedding

Also a lookup table, but rather than the indices being our token IDs, the indices are just the numbers `0, 1, 2, ..., seq_len-1`. I.e., the position indices of the tokens in the sequence.

In [10]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(
        self, tokens: Int[Tensor, "batch position"]
    ) -> Float[Tensor, "batch position d_model"]:
        batch, seq_len = tokens.shape
        # W_pos has a maximum length of n_ctx. If seq_len is less than n_ctx, we only take the first seq_len positions
        # W_pos is the same for all batches, so we repeat it along the batch dimension
        return self.W_pos[:seq_len].repeat(batch, 1, 1)
        # return einops.repeat(self.W_pos[:seq_len], "position d_model -> batch position d_model", batch=batch)


rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Causal Mask

- A method of the `Attention` class.

- Takes attention scores and applies a mask to them so that the model can only attend to previous positions.

Hints:
- Use `t.where`, or `Tensor.masked_fill_` (in-place) / `Tensor.masked_fill` (out-of-place), for masking.
- `t.triu` is useful for mask creation
- `self.IGNORE` attribute should be used to set masked positions to negative infinity.

Why mask attention scores to neg inf, rather than the attention probabilities to zero?
- If we masked the attention probabilities, then the probabilities would no longer sum to 1.
- We want to mask the scores and *then* apply *softmax*, so that the probabilities are still valid probabilities (i.e., they sum to 1), and the values in the masked positions do not influence the model's output.
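
A small sketch of the two options on random scores for a single head (toy sequence length, no batch dimension), comparing the row sums of the resulting patterns.

```python
import torch as t

t.manual_seed(0)
seq_len = 4
scores = t.randn(seq_len, seq_len)  # (query_pos, key_pos)
future = t.triu(t.ones(seq_len, seq_len), diagonal=1).bool()  # True where key_pos > query_pos

# Option A: mask the scores with -inf, then softmax
pattern_a = scores.masked_fill(future, float("-inf")).softmax(dim=-1)

# Option B: softmax first, then zero out the masked probabilities
pattern_b = scores.softmax(dim=-1).masked_fill(future, 0.0)

print(pattern_a.sum(dim=-1))  # every row sums to 1
print(pattern_b.sum(dim=-1))  # rows with masked entries sum to less than 1
```

In option B the masked scores have also already influenced the softmax denominator, so future positions still affect the surviving probabilities.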

In [11]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.register_buffer(
            "IGNORE", t.tensor(float("-inf"), device=device, dtype=t.float32)
        )

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        mask = t.triu(
            t.ones(attn_scores.shape[-2:], device=attn_scores.device), diagonal=1
        )
        # remember, broadcasting starts with trailing dimensions
        return attn_scores.masked_fill(mask == 1, self.IGNORE)


tests.test_causal_mask(Attention.apply_causal_mask)

All tests in `test_causal_mask` passed!


## Attention

### Step 1
- Produce an attention pattern. I.e., a probability distribution over previous tokens (including current token) for each destination token.

  - Linear map from input -> query, key, each of shape `(batch, seq_pos, head_index, d_head)`

  - Dot product every *pair* of queries and keys to get attention scores `(batch, head_index, query_pos, key_pos)` (query=dest, key=source)

  - Scale and mask `attn_scores` to make it lower triangular; i.e., causal.

  - Softmax along the `key_pos` dimension to get a probability distribution for each query (destination) token
    - This is the attention pattern!

### Step 2
- Move information from source tokens to destination token using attention pattern (move = apply linear map)

  - Linear map from input -> value `(batch, key_pos, head_index, d_head)`

  - Mix along the `key_pos` with attn pattern to get `z`, the weighted average of the value vectors `(batch, query_pos, head_index, d_head)`

  - Map to output, `(batch, position, d_model)` (position=query_pos, we've summed over all heads)
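
The two steps can be sketched with `t.einsum` to make the shapes concrete. This is an illustrative sketch under assumptions, not the reference implementation: toy sizes, random weights, biases omitted, and `x` standing in for the normalised residual stream.

```python
import torch as t

t.manual_seed(0)
batch, seq_len, d_model, n_heads, d_head = 2, 5, 16, 4, 4  # toy sizes
x = t.randn(batch, seq_len, d_model)

W_Q, W_K, W_V = (t.randn(n_heads, d_model, d_head) * 0.1 for _ in range(3))
W_O = t.randn(n_heads, d_head, d_model) * 0.1

# Step 1: attention pattern
q = t.einsum("bpd,hde->bphe", x, W_Q)  # (batch, query_pos, head, d_head)
k = t.einsum("bpd,hde->bphe", x, W_K)  # (batch, key_pos, head, d_head)
attn_scores = t.einsum("bqhe,bkhe->bhqk", q, k) / d_head**0.5  # scale first
future = t.triu(t.ones(seq_len, seq_len), diagonal=1).bool()
attn_scores = attn_scores.masked_fill(future, float("-inf"))   # then mask
pattern = attn_scores.softmax(dim=-1)  # (batch, head, query_pos, key_pos)

# Step 2: move information
v = t.einsum("bpd,hde->bphe", x, W_V)            # (batch, key_pos, head, d_head)
z = t.einsum("bhqk,bkhe->bqhe", pattern, v)      # weighted average of values
out = t.einsum("bqhe,hed->bqd", z, W_O)          # sums over heads and d_head

print(q.shape, pattern.shape, z.shape, out.shape)
```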


### Notes

- Scale means dividing by `sqrt(d_head)`. This avoids vanishing gradients; a big problem when dealing with *softmax*. If one of the values is much larger than all the others, the probabilities will be close to 0 or 1, and the gradients will be close to 0.

- Usually we have the relation `e = n * h` (i.e., `d_model = n_heads * d_head`). There are some computational justifications for this, but mostly this is just convention, just like `d_mlp = 4 * d_model`.

- The names **keys**, **queries**, and **values** come from their analogy to retrieval systems. Broadly speaking:
  - **queries** represent some information that a token is *looking for*

  - **keys** represent the information that a token *contains*
    - So that the attention score being high means that the source (key) token contains the information which the destination (query) token is *looking for*.

  - **values** represent the actual information that is taken from the source token, to be moved to the destination token.

- Should be getting at least 99% accuracy on tests. Small tweaks like the order of `einsum` args can result in slightly different outputs, so not a big deal if 100% isn't achieved.

- Don't forget attention score scaling! Comes before masking.

- Overwrite the earlier `Attention` class with a new implementation that initialises all parameters, but copy in the `apply_causal_mask` function.
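
To see why the scores are divided by `sqrt(d_head)`, here is a rough numerical sketch. It assumes query and key entries are independent with unit variance, which is an idealisation rather than a property of trained weights.

```python
import torch as t

t.manual_seed(0)
d_head, n_samples = 64, 10_000
q = t.randn(n_samples, d_head)
k = t.randn(n_samples, d_head)

raw = (q * k).sum(dim=-1)     # one query-key dot product per sample
scaled = raw / d_head**0.5    # the scaling used for attention scores

# Unscaled dot products have standard deviation around sqrt(d_head) = 8,
# scaled ones around 1
print(raw.std().item(), scaled.std().item())
```

Scores with a spread of around 8 would push softmax much closer to one-hot outputs than scores with a spread of around 1, which is the saturation the note above describes.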


### Useful, highly detailed diagram of attention [here](./C1P1_attention_layer_detailed.png). 

- The diagram helps illustrate the difference between the **QK** and **OV** circuits.

#### The **QK** circuit
- Consists of the operation of the $W_Q$ and $W_K$ matrices. It determines the attention pattern; i.e., where information is moved to and from in the residual stream. Ignoring biases and the causal mask, the functional form of the attention pattern $A$ is:

$$
A=\textrm{softmax}\left(\frac{x W_Q W^T_K x^T}{\sqrt{d_{head}}}\right)
$$

where
- $x$ is the residual stream (shape `(seq_len, d_model)`), and
- $W_Q$ and $W_K$ are the weight matrices for a single head (shape `(d_model, d_head)`)


#### The **OV** circuit
- Consists of the operation of the $W_V$ and $W_O$ matrices. Once the attention patterns are fixed, these matrices operate on the residual stream at the source position. Their output is what gets moved from the source to the destination position.

The functional form of an entire attention head is:

$$
\begin{aligned}
\textrm{output}&=\textrm{softmax}\left(\frac{x W_Q W^T_K x^T}{\sqrt{d_{head}}}\right)(x W_V W_O) \\
&= A x W_V W_O
\end{aligned}
$$

where
- $W_V$ has shape `(d_model, d_head)`
- $W_O$ has shape `(d_head, d_model)`


We can clearly see the QK and OV circuits doing different things. They should be thought of as two distinct parts of the attention head.
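
The split can be checked numerically. The sketch below is a single-head toy with random weights, no biases, and made-up dimensions. It computes the head's output once through the combined matrices `W_QK` and `W_OV` (illustrative names, not attributes of the `Attention` class) and once through explicit `Q`, `K`, `V`:

```python
import math
import torch as t

t.manual_seed(0)
seq_len, d_model, d_head = 5, 16, 4
x = t.randn(seq_len, d_model)  # residual stream for one sequence
W_Q, W_K, W_V = (0.1 * t.randn(d_model, d_head) for _ in range(3))
W_O = 0.1 * t.randn(d_head, d_model)
mask = t.triu(t.ones(seq_len, seq_len), diagonal=1).bool()  # True above the diagonal

# QK circuit: decides *where* information moves.
W_QK = W_Q @ W_K.T  # (d_model, d_model), rank at most d_head
scores = x @ W_QK @ x.T / math.sqrt(d_head)  # (query_pos, key_pos)
A = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)

# OV circuit: decides *what* information is moved.
W_OV = W_V @ W_O  # (d_model, d_model), rank at most d_head
head_out = A @ x @ W_OV  # (seq_len, d_model)

# Reference: the usual route through explicit queries, keys and values.
Q, K, V = x @ W_Q, x @ W_K, x @ W_V
A_ref = (Q @ K.T / math.sqrt(d_head)).masked_fill(mask, float("-inf")).softmax(dim=-1)
out_ref = (A_ref @ V) @ W_O

print(t.allclose(A, A_ref, atol=1e-5), t.allclose(head_out, out_ref, atol=1e-5))
```

Note that `A` depends only on `W_Q` and `W_K`, while `W_V` and `W_O` only enter after `A` has been computed.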

In [12]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer(
            "IGNORE", t.tensor(float("-inf"), device=device, dtype=t.float32)
        )

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        """
        Applies a causal mask to attention scores, and returns masked scores.
        """
        mask = t.triu(
            t.ones(attn_scores.shape[-2:], device=attn_scores.device), diagonal=1
        )
        # remember, broadcasting starts with trailing dimensions
        return attn_scores.masked_fill(mask == 1, self.IGNORE)

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        batch, seq_len, d_model = normalized_resid_pre.shape

        K = (
            einops.einsum(
                self.W_K,
                normalized_resid_pre,
                "n_heads d_model d_head, batch posn d_model -> batch posn n_heads d_head",
            )
            + self.b_K
        )
        Q = (
            einops.einsum(
                self.W_Q,
                normalized_resid_pre,
                "n_heads d_model d_head, batch posn d_model -> batch posn n_heads d_head",
            )
            + self.b_Q
        )

        attn_scores = einops.einsum(
            Q,
            K,
            "batch posn_Q n_heads d_head, batch posn_K n_heads d_head -> batch n_heads posn_Q posn_K",
        )
        attn_scores_scaled = attn_scores / math.sqrt(self.cfg.d_head)
        attn_scores_scaled_masked = self.apply_causal_mask(attn_scores_scaled)

        attn_probs = attn_scores_scaled_masked.softmax(dim=-1)

        V = (
            einops.einsum(
                self.W_V,
                normalized_resid_pre,
                "n_heads d_model d_head, batch posn d_model -> batch posn n_heads d_head",
            )
            + self.b_V
        )
        z = einops.einsum(
            attn_probs,
            V,
            "batch n_heads posn_Q posn_K, batch posn_K n_heads d_head -> batch posn_Q n_heads d_head",
        )

        return (
            einops.einsum(
                z,
                self.W_O,
                "batch posn n_heads d_head, n_heads d_head d_model -> batch posn d_model",
            )
            + self.b_O
        )


tests.test_causal_mask(Attention.apply_causal_mask)
rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"])

All tests in `test_causal_mask` passed!
Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## MLP

Implement:
- A linear layer, with weight `W_in`, bias `b_in`

- A nonlinear function (usually GELU); can use the imported function `gelu_new`

- A linear layer, weight `W_out`, bias `b_out`
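
For reference, GPT-2 uses the tanh approximation of GELU. The sketch below writes that approximation out by hand and compares it with PyTorch's built-in `F.gelu(..., approximate="tanh")` (available in recent PyTorch versions). That the imported `gelu_new` computes this same function is an assumption here, worth confirming against the `transformer_lens` source.

```python
import math
import torch as t
import torch.nn.functional as F


def gelu_tanh(x: t.Tensor) -> t.Tensor:
    """Tanh approximation of GELU (hand-written for illustration)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + t.tanh(inner))


x = t.linspace(-3, 3, 7)
print(gelu_tanh(x))
print(F.gelu(x, approximate="tanh"))  # should agree with the line above
```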

In [13]:
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        h = gelu_new(
            einops.einsum(
                normalized_resid_mid,
                self.W_in,
                "batch posn d_model, d_model d_mlp -> batch posn d_mlp",
            )
            + self.b_in
        )
        return (
            einops.einsum(
                h, self.W_out, "batch posn d_mlp, d_mlp d_model -> batch posn d_model"
            )
            + self.b_out
        )


rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Transformer Block

Put together the attention, MLP and layernorms into a single transformer block.

- Remember to implement the residual connections correctly!
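
A quick way to sanity-check residual connections: if both sublayers output zero, the block should return its input unchanged. The sketch below uses hypothetical stand-ins (`ToyBlock`, with `nn.Linear` in place of attention and the MLP) rather than the classes defined in this notebook:

```python
import torch as t
import torch.nn as nn


class ToyBlock(nn.Module):
    """Pre-LayerNorm residual block with stand-in sublayers (illustration only)."""

    def __init__(self, d_model: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = nn.Linear(d_model, d_model)  # stand-in for attention
        self.mlp = nn.Linear(d_model, d_model)  # stand-in for the MLP

    def forward(self, resid_pre: t.Tensor) -> t.Tensor:
        # Each sublayer reads a normalised copy and *adds* its output back.
        resid_mid = resid_pre + self.attn(self.ln1(resid_pre))
        return resid_mid + self.mlp(self.ln2(resid_mid))


block = ToyBlock(8)
# Zero out both sublayers so that they contribute nothing.
for layer in (block.attn, block.mlp):
    nn.init.zeros_(layer.weight)
    nn.init.zeros_(layer.bias)

x = t.randn(2, 4, 8)
print(t.equal(block(x), x))  # True: the residual stream passes straight through
```

If a block instead returned only the sublayer output (forgetting the `+ resid`), this check would fail.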

In [14]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        normalised_resid_pre = self.ln1(resid_pre)
        attn_output = self.attn(normalised_resid_pre)
        resid_mid = resid_pre + attn_output
        normalised_resid_mid = self.ln2(resid_mid)
        mlp_output = self.mlp(normalised_resid_mid)
        resid_post = resid_mid + mlp_output
        return resid_post


rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Unembedding

Linear layer with weight `W_U` and bias `b_U`

In [15]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
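        # Note: `requires_grad=False` is passed to `t.zeros`, not to `nn.Parameter`,
        # so `b_U` is still a trainable parameter (`nn.Parameter` defaults to requires_grad=True).
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))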

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        return (
            einops.einsum(
                normalized_resid_final,
                self.W_U,
                "batch posn d_model, d_model d_vocab -> batch posn d_vocab",
            )
            + self.b_U
        )


rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257]) 

100.00% of the values are correct



## Full Transformer

In [16]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg.n_layers)]
        )
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(
        self, tokens: Int[Tensor, "batch position"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)

        resid = embed + pos_embed

        for block in self.blocks:
            resid = block(resid)

        resid_final = self.ln_final(resid)
        return self.unembed(resid_final)


rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257]) 

100.00% of the values are correct



### Try it out!

- Create a new instance of our `DemoTransformer`.

- Since our local implementation matches GPT-2's architecture, we can load pretrained parameters into our `DemoTransformer` from our (pre-loaded) reference GPT-2 model.
  - These are contained in the reference model's `state_dict`.
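
One caveat when loading with `strict=False`: parameters whose names don't match are skipped silently. `load_state_dict` returns an object with `missing_keys` and `unexpected_keys`, which can be inspected. A minimal sketch with two hypothetical toy modules (`Source` and `Target` are made up for illustration):

```python
import torch as t
import torch.nn as nn


class Source(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(t.ones(3))
        self.extra = nn.Parameter(t.ones(2))  # no counterpart in Target


class Target(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(t.zeros(3))
        self.b = nn.Parameter(t.zeros(3))  # no counterpart in Source


target = Target()
result = target.load_state_dict(Source().state_dict(), strict=False)
print(result.missing_keys)     # ['b']: present in Target, but not loaded
print(result.unexpected_keys)  # ['extra']: present in the state dict, but ignored
```

Printing the return value of the `load_state_dict` call in the cell below would show in the same way whether any parameters were left at their random initialisation.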

In [17]:
demo_gpt2 = DemoTransformer(Config(debug=False)).to(device)
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)

print("Test set of tokens:")
print(reference_gpt2.to_str_tokens(tokens))

demo_logits = demo_gpt2(tokens)  # Run our custom model on the same input

Test set of tokens:
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


### Calculate the loss

We'll use **cross-entropy loss**. The cross entropy loss between a modelled distribution $Q$ and a target distribution $P$ is given by:

$$
\textrm{loss} = -\sum_x{P(x) \log{Q(x)}}
$$

In the case where $P$ is the empirical distribution from target classes (i.e., $P(x^*)=1$ for the correct class $x^*$ and $0$ otherwise), this becomes:

$$
\textrm{loss} = -\log{Q(x^*)}
$$

I.e., the negative log prob of the true class.
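
As a sanity check on this formula, the sketch below uses random toy logits and tokens with made-up dimensions. It computes the average negative log prob of the true next token by hand and compares it with PyTorch's built-in `F.cross_entropy`:

```python
import torch as t
import torch.nn.functional as F

t.manual_seed(0)
batch, seq_len, d_vocab = 2, 5, 11
logits = t.randn(batch, seq_len, d_vocab)
tokens = t.randint(0, d_vocab, (batch, seq_len))

# Position i predicts token i+1, so drop the last prediction and the first token.
log_probs = logits.log_softmax(dim=-1)
correct_log_probs = log_probs[:, :-1].gather(-1, tokens[:, 1:, None]).squeeze(-1)
manual_loss = -correct_log_probs.mean()

# The built-in expects (N, C) logits and (N,) class indices.
builtin_loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, d_vocab), tokens[:, 1:].reshape(-1)
)
print(manual_loss.item(), builtin_loss.item())  # the two values should agree
```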

In [18]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    log_probs = logits.log_softmax(dim=-1)
    # Get log probs for the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    log_probs_for_tokens = (
        log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    )

    return log_probs_for_tokens


pred_log_probs = get_log_probs(demo_logits, tokens)
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(
    f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.d_vocab):4f}"
)
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():4f}")

Avg cross entropy loss: 4.5647
Avg cross entropy loss for uniform distribution: 10.824905
Avg probability assigned to correct token: 0.087910


The solutions gave the following output:

```terminal
Avg cross entropy loss: 4.0442
Avg cross entropy loss for uniform distribution: 10.824905
Avg probability assigned to correct token: 0.098628
```

These numbers are similar to the results above, which is a good sign! But the loss from the solutions is consistently lower (4.0442 vs. 4.5647 here)... why?

### Generate text

Use a "greedy" approach for now; i.e., take the most likely next token and repeatedly append it to the prompt before feeding it back into the model:

In [19]:
test_string = """The Total Perspective Vortex derives its picture of the whole Universe on the principle of"""
for i in tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).to(device)
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

In [20]:
pprint(test_string)

('The Total Perspective Vortex derives its picture of the whole Universe on '
 'the principle of the total perspective. The total perspective is the view of '
 'the whole Universe from the point of view of the observer. The total '
 'perspective is the view of the whole Universe from the point of view of the '
 'observer. The total perspective is the view of the whole Universe from the '
 'point of view of the observer. The total perspective is the view of the '
 'whole Universe from the point of view of the observer. The total perspective '
 'is the view of the whole Universe from the point of view of the observer. '
 'The')
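
The loop above decodes to a string and re-tokenizes it on every step. An alternative sketch keeps everything as token ids and appends the argmax token directly. `ToyLM` and `greedy_generate` are hypothetical names used only for illustration; any module mapping `(batch, posn)` tokens to `(batch, posn, d_vocab)` logits, such as `demo_gpt2`, should fit the same loop.

```python
import torch as t
import torch.nn as nn


class ToyLM(nn.Module):
    """Stand-in model: tokens (batch, posn) -> logits (batch, posn, d_vocab)."""

    def __init__(self, d_vocab: int = 10, d_model: int = 8):
        super().__init__()
        self.embed = nn.Embedding(d_vocab, d_model)
        self.unembed = nn.Linear(d_model, d_vocab)

    def forward(self, tokens: t.Tensor) -> t.Tensor:
        return self.unembed(self.embed(tokens))


@t.no_grad()
def greedy_generate(model: nn.Module, tokens: t.Tensor, n_new: int) -> t.Tensor:
    """Append the most likely next token n_new times."""
    for _ in range(n_new):
        logits = model(tokens)  # (batch, posn, d_vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # (batch, 1)
        tokens = t.cat([tokens, next_token], dim=-1)
    return tokens


t.manual_seed(0)
prompt = t.tensor([[1, 2, 3]])
out = greedy_generate(ToyLM(), prompt, n_new=5)
print(out.shape)  # torch.Size([1, 8]): 3 prompt tokens + 5 generated
```

Working on token ids avoids any mismatch between decoding and re-tokenizing, and the final sequence can be decoded to text once at the end.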
