
## Introduction

Thanks to Callum McDougall for creating much of this notebook's content!

This exercise should get you familiar with concepts in circuits interpretability in transformers.

## Content & Learning Objectives

#### 1️⃣ Find Induction Heads via attention patterns

You'll plot attention patterns to test the theory of induction heads, and see if you can find them in the model you're working with.
> ##### Learning objectives
> 
> - Use `circuitsvis` to visualise attention heads
> - Understand what the theory of induction heads predicts about attention patterns
> - Use attention patterns to identify induction heads
> - Automate this process to find induction heads in a larger model

#### 2️⃣ Logit Attribution

Here, you'll learn how to use TransformerLens to implement direct logit attribution, a tool for attributing logit values to specific components of the model. You'll also learn how to use this tool to identify the attention heads that are important for induction tasks.
> ##### Learning objectives
> - Perform direct logit attribution to figure out which heads are writing to the residual stream in a significant way
> - Cross-check your earlier results against the logit attribution results




In [1]:
import os

os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
from pathlib import Path
import torch as t
from torch import Tensor
import numpy as np
import einops
from tqdm.notebook import tqdm
import plotly.express as px
import webbrowser
import re
import itertools
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set, Union
from functools import partial
from IPython.display import display, HTML
from rich.table import Table, Column
from rich import print as rprint
import circuitsvis as cv
from pathlib import Path
from einops import repeat
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from transformer_lens import HookedTransformerConfig
from huggingface_hub import hf_hub_download
t.set_grad_enabled(False)
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")



## Loading and Running Models

We will be using two models today:

- GPT-2 small, the smallest GPT-2 model, with 12 layers and roughly 85 million non-embedding parameters (about 124 million in total).
- A two-layer attention-only transformer model.

This toy model is a 2L attention-only transformer trained specifically for today. Some changes make it easier to interpret:
- It has only attention blocks.
- The positional embeddings are only added to the residual stream before each key and query vector in the attention layers as opposed to the token embeddings - i.e. we compute queries as `Q = (resid + pos_embed) @ W_Q + b_Q` and same for keys, but values as `V = resid @ W_V + b_V`. This means that **the residual stream can't directly encode positional information**.
    - This turns out to make it *way* easier for induction heads to form: they appear 2-3x earlier in training - [see the comparison of two training runs](https://wandb.ai/mechanistic-interpretability/attn-only/reports/loss_ewma-22-08-24-11-08-83---VmlldzoyNTI0MDMz?accessToken=8ap8ir6y072uqa4f9uinotdtrwmoa8d8k2je4ec0lyasf1jcm3mtdh37ouijgdbm). (The bump in each curve is the formation of induction heads.)
    - The argument that does this below is `positional_embedding_type="shortformer"`.
- It has no MLP layers, no LayerNorms, and no biases.
- There are separate embed and unembed matrices (i.e. the weights are not tied).
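To make the positional-embedding change concrete, here is a minimal sketch in plain Python (toy sizes, no biases). The helper names `matmul` and `add` and all the numbers are illustrative assumptions, not the TransformerLens implementation. It shows that two positions holding the same token get different queries but identical values.

```python
# Sketch of "shortformer"-style positional embeddings with toy 2x2 matrices.
# Everything here is illustrative; the real model uses learned weights.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

# Two positions holding the SAME token, so the residual rows are identical.
resid = [[1.0, 2.0], [1.0, 2.0]]
pos_embed = [[0.5, 0.0], [0.0, 0.5]]  # differs between positions
W_Q = [[1.0, 0.0], [0.0, 1.0]]
W_V = [[2.0, 0.0], [0.0, 2.0]]

queries = matmul(add(resid, pos_embed), W_Q)  # Q = (resid + pos_embed) @ W_Q
values = matmul(resid, W_V)                   # V = resid @ W_V

print("queries:", queries)  # rows differ: queries can see position
print("values: ", values)   # rows identical: values carry no position
```

Because only the values are written back into the residual stream, positional information never enters it directly.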


In [2]:
gpt2_small: HookedTransformer = HookedTransformer.from_pretrained("gpt2-small")
gpt2_small.set_use_attn_result(True)

cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True,  # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b",
    seed=398,
    use_attn_result=True,
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer",
)

REPO_ID = "callummcdougall/attn_only_2L_half"
FILENAME = "attn_only_2L_half.pth"

weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
toy_transformer: HookedTransformer = HookedTransformer(cfg)

toy_transformer.load_state_dict(t.load(weights_path, map_location=device))


Loaded pretrained model gpt2-small into HookedTransformer


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


<All keys matched successfully>


## Caching all Activations

The first basic operation when doing mechanistic interpretability is to break open the black box of the model and look at all of its internal activations. This can be done with `logits, cache = model.run_with_cache(tokens)`. Let's try this out on a short prompt that contains a repeated phrase.

<details>
<summary>Aside - a note on <code>remove_batch_dim</code></summary>

Every activation inside the model begins with a batch dimension. Here, because we only entered a single sequence, that dimension is always length 1 and kinda annoying, so passing in the `remove_batch_dim=True` keyword removes it.

`gpt_cache_no_batch_dim = gpt_cache.remove_batch_dim()` would have achieved the same effect.
</details>

In [3]:
text = "Harry James Potter Evans-Verres went to ML4good, to learn about transformers, that is why Harry James Potter Evans-Verres did this notebook in ML4good."

gpt_tokens = gpt2_small.to_tokens(text)
toymodel_tokens = toy_transformer.to_tokens(text)

gpt_logits, gpt_cache = gpt2_small.run_with_cache(gpt_tokens, remove_batch_dim=True)
toymodel_logits, toymodel_cache = toy_transformer.run_with_cache(toymodel_tokens, remove_batch_dim=True)

If you inspect the `gpt_cache` object, you should see that it contains a very large number of keys, each one corresponding to a different activation in the model. You can access the keys by indexing the cache directly, or by a more convenient indexing shorthand. For instance, the code:


In [4]:
gpt_attn_patterns_layer_0 = gpt_cache["pattern", 0]
toymodel_attn_patterns_layer_0 = toymodel_cache["pattern", 0]

In [5]:
# Print shapes: [n_heads, seq_dest, seq_src]
print(gpt_attn_patterns_layer_0.shape)
print(toymodel_attn_patterns_layer_0.shape)

torch.Size([12, 38, 38])
torch.Size([12, 38, 38])


returns the same thing as:

In [6]:
gpt_attn_patterns_layer_0_copy = gpt_cache["blocks.0.attn.hook_pattern"]

t.testing.assert_close(gpt_attn_patterns_layer_0, gpt_attn_patterns_layer_0_copy)


<details>
<summary>Aside: <code>utils.get_act_name</code></summary>

The reason these are the same is that, under the hood, the first example actually indexes by `utils.get_act_name("pattern", 0)`, which evaluates to `"blocks.0.attn.hook_pattern"`.

In general, `utils.get_act_name` is a useful function for getting the full name of an activation, given its short name and layer number.

You can use the diagram from the **Transformer Architecture** section to help you find activation names.
</details>
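As a rough mental model of the shorthand, the mapping for attention hooks can be sketched as below. `act_name_sketch` is a hypothetical, simplified stand-in written for this note. The real `utils.get_act_name` handles many more cases (MLP hooks, residual stream hooks, aliases), so treat this only as an illustration of the naming scheme.

```python
def act_name_sketch(name: str, layer: int) -> str:
    """Hypothetical simplification: full hook name for an attention activation."""
    return f"blocks.{layer}.attn.hook_{name}"

print(act_name_sketch("pattern", 0))  # blocks.0.attn.hook_pattern
print(act_name_sketch("result", 1))   # blocks.1.attn.hook_result
```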




## Visualising Attention Heads

A key insight from the *A Mathematical Framework for Transformer Circuits* paper is that we should focus on interpreting the parts of the model that are intrinsically interpretable - the input tokens, the output logits and the attention patterns. Everything else (the residual stream, keys, queries, values, etc) is a compressed intermediate state used when calculating meaningful things. So a natural place to start is classifying heads by their attention patterns on various texts.

When doing interpretability, it's always good to begin by visualising your data, rather than taking summary statistics. Summary statistics can be super misleading! But once we have visualised the attention patterns, we can create some basic summary statistics and use our visualisations to validate them! (Accordingly, being good at web dev/data visualisation is a surprisingly useful skillset! Neural networks are very high-dimensional objects.)

Let's visualize the attention pattern of all the heads in layer 0, using [Alan Cooney's CircuitsVis library](https://github.com/alan-cooney/CircuitsVis) (based on Anthropic's PySvelte library). If you did the previous set of exercises, you'll have seen this library before.

We will use the function `cv.attention.attention_patterns`, which takes the following arguments:

* `attention`: Attention head activations. 
    * This should be a tensor of shape `[nhead, seq_dest, seq_src]`, i.e. the `[i, :, :]`th element is the grid of attention patterns (probabilities) for the `i`th attention head.
    * We get this by indexing our `gpt_cache` object.
* `tokens`: List of tokens (e.g. `["A", "person"]`). 
    * Sequence length must match that inferred from `attention`.
    * This is used to label the grid.
    * We get this by using the `gpt2_small.to_str_tokens` method.
* `attention_head_names`: Optional list of names for the heads.

There are also other circuitsvis functions, e.g. `cv.attention.attention_pattern` (for just a single head) or `cv.attention.attention_heads` (which has the same syntax but presents the information in a different form).

<details>
<summary>Help - my <code>attention_heads</code> plots are behaving weirdly (e.g. they continually shrink after I plot them).</summary>

This seems to be a bug in `circuitsvis` - on VSCode, the attention head plots continually shrink in size.

Until this is fixed, one way to get around it is to open the plots in your browser. You can do this inline with the `webbrowser` library:

```python
attn_heads = cv.attention.attention_heads(
    tokens=gpt2_str_tokens, 
    attention=attention_pattern,
    attention_head_names=[f"L0H{i}" for i in range(12)],
)

path = "attn_heads.html"

with open(path, "w") as f:
    f.write(str(attn_heads))

webbrowser.open(path)
```

To check exactly where this is getting saved, you can print your current working directory with `os.getcwd()`.
</details>

This visualization is interactive! Try hovering over a token or head, and click to lock. The grid on the top left and for each head is the attention pattern as a destination position by source position grid. It's lower triangular because GPT-2 has **causal attention**: attention can only look backwards, so information can only move forwards in the network.
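The lower-triangular shape can be reproduced with a small sketch: mask out every source position after the destination, then apply a softmax to each row. The function `causal_attention` and the score values are made up for illustration. This is plain Python, not the model's actual attention code.

```python
import math

def causal_attention(scores):
    """Row-wise softmax over scores, hiding source positions after the destination."""
    pattern = []
    for dest, row in enumerate(scores):
        # Only sources 0..dest are visible to this destination position.
        exps = [math.exp(s) if src <= dest else 0.0 for src, s in enumerate(row)]
        total = sum(exps)
        pattern.append([e / total for e in exps])
    return pattern

# Made-up attention scores, indexed as scores[dest][src].
scores = [
    [0.1, 2.0, -1.0],
    [0.3, 0.3, 5.0],
    [1.0, 0.0, 1.0],
]
pattern = causal_attention(scores)
for row in pattern:
    print([round(p, 3) for p in row])
```

Every row still sums to 1, and everything above the diagonal is exactly zero, which is the triangle you see in the plots.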


In [7]:
print(type(gpt_cache))
layer = 10
attention_pattern = gpt_cache["pattern", layer]
print(attention_pattern.shape)
str_tokens = gpt2_small.to_str_tokens(text)

# Label the plot with the layer that is actually being shown
print(f"Layer {layer} Head Attention Patterns:")
display(
    cv.attention.attention_patterns(
        tokens=str_tokens,
        attention=attention_pattern
    )
)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 38, 38])
Layer 10 Head Attention Patterns:


In [8]:
print(type(toymodel_cache))
layer = 1
attention_pattern = toymodel_cache["pattern", layer]
print(attention_pattern.shape)
str_tokens = toy_transformer.to_str_tokens(text)

# Label the plot with the layer that is actually being shown
print(f"Layer {layer} Head Attention Patterns:")
display(
    cv.attention.attention_patterns(
        tokens=str_tokens,
        attention=attention_pattern
    )
)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 38, 38])
Layer 1 Head Attention Patterns:



Hover over heads to see the attention patterns; click on a head to lock it. Hover over each token to see which other tokens it attends to (or which other tokens attend to it - you can see this by changing the dropdown to `Destination <- Source`).

<details>
<summary>Other circuitsvis functions - neuron activations</summary>

The `circuitsvis` library also has a number of cool visualisations for **neuron activations**. Here are some more of them (you don't have to understand them all now, but you can come back to them later).

The function below visualises neuron activations. The example shows just one sequence, but it can also show multiple sequences (if `tokens` is a list of lists of strings, and `activations` is a list of tensors).

```python
# String tokens for labelling, and the cache created earlier in this notebook
gpt2_str_tokens = gpt2_small.to_str_tokens(text)
neuron_activations_for_all_layers = t.stack([
    gpt_cache["post", layer] for layer in range(gpt2_small.cfg.n_layers)
], dim=1)
# shape = (seq_pos, layers, neurons)

cv.activations.text_neuron_activations(
    tokens=gpt2_str_tokens,
    activations=neuron_activations_for_all_layers
)
```

The next function shows which words each of the neurons activates most / least on (note that it requires some weird indexing to work correctly).

```python
neuron_activations_for_all_layers_rearranged = utils.to_numpy(einops.rearrange(neuron_activations_for_all_layers, "seq layers neurons -> 1 layers seq neurons"))

cv.topk_tokens.topk_tokens(
    # Some weird indexing required here ¯\_(ツ)_/¯
    tokens=[gpt2_str_tokens], 
    activations=neuron_activations_for_all_layers_rearranged,
    max_k=7, 
    first_dimension_name="Layer", 
    third_dimension_name="Neuron",
    first_dimension_labels=list(range(12))
)
```
</details>

# Finding induction heads



             
Use the [diagram at this link](https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/small-merm.svg) to remind yourself of the relevant hook names.


### Exercise - visualise attention patterns

```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to ~10 minutes on this exercise.

It's important to be comfortable using circuitsvis, and the cache object.
```

*This exercise should be very quick - you can reuse code from the previous section. You should look at the solution if you're still stuck after 5-10 minutes.*

Visualise the attention patterns for both layers of the toy model, on the prompt `text` defined above:


In [None]:
# YOUR CODE HERE - visualize attention

<details>
<summary>Solution </summary>

We visualise attention patterns with the following code:

```python
# Uses the toy model and its cache as defined earlier in this notebook
str_tokens = toy_transformer.to_str_tokens(text)
for layer in range(toy_transformer.cfg.n_layers):
    attention_pattern = toymodel_cache["pattern", layer]
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))
```

We notice that there are three basic patterns which repeat quite frequently:

* `prev_token_heads`, which attend mainly to the previous token (e.g. head `0.7`)
* `current_token_heads`, which attend mainly to the current token (e.g. head `1.6`)
* `first_token_heads`, which attend mainly to the first token (e.g. heads `0.3` or `1.4`, although these are a bit less clear-cut than the other two)

The `prev_token_heads` and `current_token_heads` are perhaps unsurprising, because words that are close together in a sequence probably have a lot more mutual information (i.e. we could get quite far using bigram or trigram prediction). 

The `first_token_heads` are a bit more surprising. The basic intuition here is that the first token in a sequence is often used as a resting or null position for heads that only sometimes activate (since our attention probabilities always have to add up to 1).
</details>


Now that we've observed our three basic attention patterns, it's time to make detectors for those patterns!


### Exercise - write your own detectors

```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You shouldn't spend more than 15-20 minutes on these exercises.

These exercises shouldn't be too challenging, if you understand attention patterns. Use the hint if stuck on things like how to correctly index your tensors, or how to access the activation patterns from the cache.
```

You should fill in the functions below, which act as detectors for particular types of heads. Validate your detectors by comparing these results to the visual attention patterns above - summary statistics on their own can be dodgy, but are much more reliable if you can validate them by directly playing with the data.

Tasks like this are useful, because we need to be able to take our observations / intuitions about what a model is doing, and translate these into quantitative measures. As the exercises proceed, we'll be creating some much more interesting tools and detectors!

Note - there's no objectively correct answer for which heads are doing which tasks, and which detectors can spot them. You should just try and come up with something plausible-seeming, which identifies the kind of behaviour you're looking for.



In [9]:

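def average_over_condition(tensor, condition):
    """
    For each head in a [head, dest, src] tensor, averages the entries whose
    (dest, src) position satisfies `condition`.
    """
    n_heads, n_dest, n_src = tensor.shape
    # The selected positions are the same for every head, so compute them once
    selected = [(j, k) for j in range(n_dest) for k in range(n_src) if condition(j, k)]
    return [
        sum(tensor[i, j, k] for j, k in selected) / len(selected)
        for i in range(n_heads)
    ]


def over_threshhold_attn(cache, condition, threshhold=0.5, source="pattern"):
    """
    Returns labels "L{layer}H{head}" (layer is 1-indexed, head is 0-indexed) for every
    head whose average `source` activation over `condition` exceeds `threshhold`.
    """
    return_values = []

    for layer, pattern in enumerate(cache.stack_activation(source)):
        scores = average_over_condition(pattern, condition)
        indices = [i for i, s in enumerate(scores) if s > threshhold]
        for i in indices:
            return_values.append(f"L{layer+1}H{i}")
    return return_values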


def current_attn_detector(cache: ActivationCache, threshhold = .3) -> List[str]:
    """
    Returns a list e.g. ["L1H2", "L2H4"] of heads (1-indexed layer, 0-indexed head)
    which you judge to be current-token heads
    """

    def cond(i, j):
        # destination position attends to itself
        return i == j

    return over_threshhold_attn(cache, cond, threshhold=threshhold)


def prev_attn_detector(cache: ActivationCache, threshhold = .3) -> List[str]:
    """
    Returns a list e.g. ["L1H2", "L2H4"] of heads (1-indexed layer, 0-indexed head)
    which you judge to be prev-token heads
    """

    def cond(i, j):
        # source position is exactly one before the destination
        return i - j == 1

    return over_threshhold_attn(cache, cond, threshhold=threshhold)


def first_attn_detector(cache: ActivationCache, threshhold = .3) -> List[str]:
    """
    Returns a list e.g. ["L1H2", "L2H4"] of heads (1-indexed layer, 0-indexed head)
    which you judge to be first-token heads
    """

    def cond(i, j):
        # source position is the first token
        return j == 0

    return over_threshhold_attn(cache, cond, threshhold=threshhold)

def find_repeating_rows(tensor):
    """
    Finds repeated token ids in the first row of a 2D token tensor.

    Args:
    tensor (torch.Tensor): A 2D tensor of token ids with shape [batch, seq]; only row 0 is used.

    Returns:
    dict: Maps each position whose token has appeared before to the position
          of that token's most recent earlier occurrence.
    """
    last_occurrence = {}
    repeats = {}

    for pos, token in enumerate(tensor[0]):
        token_id = token.item()

        if token_id in last_occurrence:
            repeats[pos] = last_occurrence[token_id]
        last_occurrence[token_id] = pos

    return repeats

def induction_attn_detector(cache: ActivationCache, tokens, off_by_one = True, threshhold = .3) -> List[str]:
    """
    Returns a list e.g. ["L1H2", "L2H4"] of heads (1-indexed layer, 0-indexed head)
    which you judge to be induction heads

    `tokens` is the [batch, seq] token tensor that was used to generate `cache`.
    With off_by_one=True we look for attention to the token *after* the previous
    mention (the induction pattern); with off_by_one=False, to the previous mention itself.
    """
    # `tokens` is already a tensor, so it does not need wrapping in t.tensor(...)
    repeat_dict = find_repeating_rows(tokens)

    def cond(i, j):
        if i not in repeat_dict.keys():
            return False
        to_add = 1 if off_by_one else 0
        return repeat_dict[i] + to_add == j

    return over_threshhold_attn(cache, cond, threshhold=threshhold)

print("Heads attending to current token  = ", ", ".join(current_attn_detector(gpt_cache)))
print("Heads attending to previous token = ", ", ".join(prev_attn_detector(gpt_cache)))
print("Heads attending to first token    = ", ", ".join(first_attn_detector(gpt_cache)))
print("Heads attending to previous mention = ", ", ".join(induction_attn_detector(gpt_cache, gpt_tokens, off_by_one=False)))
print("Heads attending to one after previous mention = ", ", ".join(induction_attn_detector(gpt_cache, gpt_tokens, off_by_one=True)))


print("Heads attending to current token  = ", ", ".join(current_attn_detector(toymodel_cache)))
print("Heads attending to previous token = ", ", ".join(prev_attn_detector(toymodel_cache)))
print("Heads attending to previous mention = ", ", ".join(induction_attn_detector(toymodel_cache, toymodel_tokens, off_by_one=False)))
print("Heads attending to one after previous mention = ", ", ".join(induction_attn_detector(toymodel_cache, toymodel_tokens, off_by_one=True)))

Heads attending to current token  =  L1H1, L1H3, L1H4, L1H5, L2H11, L5H7
Heads attending to previous token =  L2H0, L3H2, L3H9, L4H2, L4H3, L4H6, L4H7, L5H11, L6H6, L7H8, L8H0
Heads attending to first token    =  L1H2, L1H9, L2H3, L2H5, L2H6, L2H7, L2H8, L2H9, L3H1, L3H4, L3H6, L3H11, L4H0, L4H1, L4H3, L4H4, L4H5, L4H9, L4H10, L4H11, L5H0, L5H1, L5H2, L5H3, L5H4, L5H5, L5H6, L5H8, L5H9, L5H10, L6H0, L6H1, L6H2, L6H3, L6H4, L6H5, L6H6, L6H7, L6H8, L6H9, L6H10, L6H11, L7H0, L7H1, L7H2, L7H3, L7H4, L7H5, L7H6, L7H7, L7H8, L7H9, L7H10, L7H11, L8H0, L8H1, L8H2, L8H3, L8H4, L8H5, L8H6, L8H7, L8H8, L8H9, L8H10, L8H11, L9H0, L9H1, L9H2, L9H3, L9H4, L9H5, L9H6, L9H7, L9H8, L9H9, L9H10, L9H11, L10H0, L10H1, L10H2, L10H3, L10H4, L10H5, L10H6, L10H7, L10H8, L10H9, L10H10, L10H11, L11H0, L11H1, L11H2, L11H3, L11H4, L11H5, L11H6, L11H7, L11H8, L11H9, L11H10, L11H11, L12H1, L12H2, L12H3, L12H4, L12H5, L12H6, L12H7, L12H9, L12H10, L12H11
Heads attending to previous mention =  L1H1, L1H5, L4H0
Heads at




<details>
<summary>Hint</summary>

Try and compute the average attention probability along the relevant tokens. For instance, you can get the tokens just below the diagonal by using `t.diagonal` with appropriate `offset` parameter:

```python
>>> arr = t.arange(9).reshape(3, 3)
>>> arr
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> arr.diagonal()
tensor([0, 4, 8])

>>> arr.diagonal(-1)
tensor([3, 7])
```

Remember that you should be using `cache["pattern", layer]` to get all the attention probabilities for a given layer, and then indexing on the 0th dimension to get the correct head.
</details>
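To see what these summary scores look like without loading a model, here is a minimal sketch on a hand-made attention pattern. The helpers `diagonal_mean` and `first_column_mean` and the numbers in `prev_token_pattern` are assumptions made up for illustration. They mirror the "average along the relevant entries" idea from the hint using plain Python lists.

```python
def diagonal_mean(pattern, offset=0):
    """Mean of pattern[dest][dest + offset] over all valid dest (use offset <= 0)."""
    vals = [pattern[d][d + offset] for d in range(len(pattern)) if d + offset >= 0]
    return sum(vals) / len(vals)

def first_column_mean(pattern):
    """Mean attention paid to source position 0."""
    return sum(row[0] for row in pattern) / len(pattern)

# Made-up pattern for one head that mostly attends to the previous token.
# Rows are destination positions, columns are source positions.
prev_token_pattern = [
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.1, 0.8, 0.1, 0.0],
    [0.0, 0.1, 0.8, 0.1],
]

scores = {
    "current": diagonal_mean(prev_token_pattern, 0),
    "previous": diagonal_mean(prev_token_pattern, -1),
    "first": first_column_mean(prev_token_pattern),
}
for name, score in scores.items():
    print(f"{name:>8}: {score:.3f}")
```

Note that the first-token score is fairly high here even though this is a previous-token head: on a very short sequence, position 0 must attend to itself and position 1's previous token *is* the first token. This is one reason to check summary statistics against the plots.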

<details>
<summary>Solution</summary>

```python
def current_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be current-token heads
    '''
    attn_heads = []
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of diagonal elements
            score = attention_pattern.diagonal().mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads

def prev_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be prev-token heads
    '''
    attn_heads = []
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of sub-diagonal elements
            score = attention_pattern.diagonal(-1).mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads

def first_attn_detector(cache: ActivationCache) -> List[str]:
    '''
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be first-token heads
    '''
    attn_heads = []
    for layer in range(model.cfg.n_layers):
        for head in range(model.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of 0th elements
            score = attention_pattern[:, 0].mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads
```

Note - choosing `0.4` as a threshold is a bit arbitrary, but it seems to work well enough. In this particular case, a threshold of `0.5` results in no head being classified as a current-token head.
</details>


Compare the printouts to your attention visualisations above. Do they seem to make sense?
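For the induction detector in particular, it can help to list which (destination, source) pairs count as "one after the previous mention". The sketch below does this for a short list of made-up token ids. `induction_targets` is a hypothetical helper written for this note, and it follows the same "last occurrence plus one" rule as the detector above.

```python
def induction_targets(token_ids):
    """
    For each position whose token appeared earlier, return the source position
    an induction head is expected to attend to: one after the most recent
    earlier occurrence of that token.
    """
    last_seen = {}
    targets = {}
    for pos, tok in enumerate(token_ids):
        if tok in last_seen:
            targets[pos] = last_seen[tok] + 1
        last_seen[tok] = pos
    return targets

# Made-up token ids: the sequence [7, 3, 9] repeated twice.
tokens = [7, 3, 9, 7, 3, 9]
targets = induction_targets(tokens)
print(targets)  # {3: 1, 4: 2, 5: 3}
```

At position 3 (the second `7`), the expected source is position 1: the token that followed the first `7`. An induction head attending there can copy that token forward as its prediction.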

# Logit Attribution
A consequence of the residual stream is that the output logits are the sum of the contributions of each layer, and thus the sum of the results of each head. This means we can decompose the output logits into a term coming from each head and directly do attribution like this!

<details>
<summary>A concrete example</summary>

Let's say that our model knows that the token Harry is followed by the token Potter, and we want to figure out how it does this. The logits on Harry are `residual @ W_U`. But this is a linear map, and the residual stream is the sum of all previous layers `residual = embed + attn_out_0 + attn_out_1`. So `logits = (embed @ W_U) + (attn_out_0 @ W_U) + (attn_out_1 @ W_U)`

We can be even more specific, and *just* look at the logit of the Potter token - this corresponds to a column of `W_U`, and so a direction in the residual stream - our logit is now a single number that is the sum of `(embed @ potter_U) + (attn_out_0 @ potter_U) + (attn_out_1 @ potter_U)`. Even better, we can decompose each attention layer output into the sum of the result of each head, and use this to get many terms.
</details>
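The decomposition in the example above can be checked numerically with a minimal sketch. The vectors below are made-up toy values with `d_model = 3`, and `potter_U` stands in for the column of `W_U` belonging to the correct token.

```python
def dot(u, v):
    """Dot product of two equally sized vectors."""
    return sum(a * b for a, b in zip(u, v))

# Toy contributions written into the residual stream at one position.
components = {
    "embed":      [0.5, -1.0, 0.0],
    "attn_out_0": [1.0, 0.5, 2.0],
    "attn_out_1": [-0.5, 2.0, 1.0],
}
# Unembedding direction of the correct token (one column of W_U).
potter_U = [1.0, 2.0, -1.0]

# The residual stream is the sum of all components...
residual = [sum(vals) for vals in zip(*components.values())]
total_logit = dot(residual, potter_U)

# ...so the logit splits into one term per component.
per_component = {name: dot(vec, potter_U) for name, vec in components.items()}

print("total logit:  ", total_logit)
print("per component:", per_component)
print("sum of terms: ", sum(per_component.values()))
```

In this toy case, `attn_out_1` contributes most of the logit and `embed` actually pushes against it. That per-component breakdown is the kind of reading logit attribution gives you.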

Your mission here is to write a function to look at how much each component contributes to the correct logit. Your components are:

* The direct path (i.e. the residual connections from the embedding to unembedding),
* Each layer 0 head (via the residual connection and skipping layer 1)
* Each layer 1 head

To emphasise, these are not paths from the start to the end of the model, these are paths from the output of some component directly to the logits - we make no assumptions about how each path was calculated!

A few important notes for this exercise:

* Here we are just looking at the DIRECT effect on the logits, i.e. the thing that this component writes / embeds into the residual stream - if heads compose with other heads and affect logits like that, or inhibit logits for other tokens to boost the correct one we will not pick up on this!
* By looking at just the logits corresponding to the correct token, our data is much lower dimensional because we can ignore all other tokens other than the correct next one (Dealing with a 50K vocab size is a pain!). But this comes at the cost of missing out on more subtle effects, like a head suppressing other plausible logits, to increase the log prob of the correct one.
    * There are other situations where our job might be easier. For instance, in the IOI task (which we'll discuss shortly) we're just comparing the logits of the indirect object to the logits of the direct object, meaning we can use the **difference between these logits**, and ignore all the other logits.
* When calculating correct output logits, we will get tensors with a dimension `(position - 1,)`, not `(position,)` - we remove the final element of the output (logits), and the first element of labels (tokens). This is because we're predicting the *next* token, and we don't know the token after the final token, so we ignore it.
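The off-by-one alignment in the last bullet can be sketched with plain lists. The strings in `predictions_at` are placeholders for the model's per-position outputs, not real logits.

```python
tokens = ["Harry", " James", " Potter", " went"]

# Stand-in for per-position model outputs: position i predicts token i + 1.
predictions_at = [f"prediction after {tok!r}" for tok in tokens]

inputs = predictions_at[:-1]  # drop the final position: nothing follows it
labels = tokens[1:]           # drop the first token: nothing predicts it

for pred, label in zip(inputs, labels):
    print(f"{pred} -> scored against {label!r}")
```

Both slices have length `position - 1`, which is where that dimension in the attribution tensor comes from.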

<details>

<summary>Question - why don't we do this to the log probs instead?</summary>

Because log probs aren't linear, they go through `log_softmax`, a non-linear function.
</details>
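A quick numerical check of this point, using made-up logit vectors over a two-token vocabulary: applying `log_softmax` to a sum of contributions does not give the sum of the individual `log_softmax` outputs, so log probs cannot be split per component the way logits can.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

# Two made-up "component" contributions to the logits.
a = [2.0, 0.0]
b = [0.0, 1.0]
summed = [x + y for x, y in zip(a, b)]

lhs = log_softmax(summed)                                        # log probs of the sum
rhs = [x + y for x, y in zip(log_softmax(a), log_softmax(b))]    # sum of log probs

print("log_softmax(a + b):             ", [round(x, 4) for x in lhs])
print("log_softmax(a) + log_softmax(b):", [round(x, 4) for x in rhs])
```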



### Exercise - build logit attribution tool

```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You shouldn't spend more than 10-15 minutes on this exercise.

This exercise is important, but has quite a few messy einsums, so you might get more value from reading the solution than doing the exercises.
```

You should implement the `logit_attribution` function below. This should return the contribution of each component in the "correct direction". We've already given you the unembedding vectors for the correct direction, `W_U_correct_tokens` (note that we take the `[1:]` slice of tokens, for reasons discussed above).

The code below this function will check your logit attribution function is working correctly, by taking the sum of logit attributions and comparing it to the actual values in the residual stream at the end of your model.



In [10]:
import plotly.graph_objects as go

def plot_attribution_pattern(attribution_scores: Float[Tensor, "layers heads"]):
    num_layers, num_heads = attribution_scores.shape
    
    fig = go.Figure(data=go.Heatmap(
        z=utils.to_numpy(attribution_scores),  # convert the (possibly CUDA) tensor for plotly
        x=[f"Head {i+1}" for i in range(num_heads)],
        y=[f"Layer {i+1}" for i in range(num_layers)],
        colorscale='Viridis',
        hoverongaps=False
    ))
    
    fig.update_layout(
        title={
            'text': "Attribution Scores",
            'font': {'size': 16}
        },
        xaxis_title="Heads",
        yaxis_title="Layers",
        xaxis={'tickfont': {'size': 12}, 'tickangle': -45},
        yaxis={'tickfont': {'size': 12}},
        coloraxis_colorbar={
            'title': '',
            'tickfont': {'size': 12}
        }
    )
    
    fig.show()

In [11]:
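def logit_attribution(tokens: Int[Tensor, "batch seq"], model: HookedTransformer, cache: ActivationCache, token_position: int) -> Float[Tensor, "layers heads"]:
    """
    Direct logit attribution of every attention head to the correct next token
    at `token_position`. The final LayerNorm (present in GPT-2) is ignored here,
    so for GPT-2 the scores are only approximate.
    """
    # Per-head outputs for every layer, stacked to [seq, layers, heads, d_model]
    results = [cache[f"blocks.{i}.attn.hook_result"] for i in range(len(model.blocks))]
    results = t.stack(results, dim=1)
    # Keep only the position of interest: [layers, heads, d_model]
    results = results[token_position, :, :, :]

    next_token_id = tokens[0, token_position + 1]

    # Project each head's output onto the unembedding direction of the correct
    # next token. We use W_U directly (not model.unembed) so that the unembedding
    # bias is not added once per head.
    attribution_scores = results @ model.W_U[:, next_token_id]
    return attribution_scores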


In [21]:
print("Heads attending to current token  = ", ", ".join(current_attn_detector(gpt_cache)))
print("Heads attending to previous token = ", ", ".join(prev_attn_detector(gpt_cache)))
print("Heads attending to first token    = ", ", ".join(first_attn_detector(gpt_cache)))
print("Heads attending to previous mention = ", ", ".join(induction_attn_detector(gpt_cache, repeat_seq, off_by_one=False)))
print("Heads attending to one after previous mention = ", ", ".join(induction_attn_detector(gpt_cache, repeat_seq, off_by_one=True, threshhold=.3)))


print("Heads attending to current token  = ", ", ".join(current_attn_detector(toymodel_cache)))
print("Heads attending to previous token = ", ", ".join(prev_attn_detector(toymodel_cache)))
print("Heads attending to previous mention = ", ", ".join(induction_attn_detector(toymodel_cache, repeat_seq, off_by_one=False)))
print("Heads attending to one after previous mention = ", ", ".join(induction_attn_detector(toymodel_cache, repeat_seq, off_by_one=True)))

Heads attending to current token  =  L1H1, L1H3, L1H4, L1H5, L1H10, L2H11, L5H7, L12H8
Heads attending to previous token =  L3H2, L3H3, L3H5, L3H9, L4H2, L4H3, L4H6, L4H7, L4H8, L5H11, L6H6, L7H8, L8H0
Heads attending to first token    =  L1H9, L2H3, L2H4, L2H5, L2H6, L2H7, L2H8, L2H9, L3H0, L3H1, L3H4, L3H6, L3H8, L3H11, L4H0, L4H1, L4H3, L4H4, L4H5, L4H9, L4H10, L4H11, L5H0, L5H1, L5H2, L5H3, L5H4, L5H5, L5H6, L5H8, L5H9, L5H10, L6H0, L6H1, L6H2, L6H3, L6H4, L6H5, L6H6, L6H7, L6H8, L6H9, L6H10, L6H11, L7H0, L7H1, L7H2, L7H3, L7H4, L7H5, L7H6, L7H7, L7H8, L7H9, L7H10, L7H11, L8H0, L8H1, L8H2, L8H3, L8H4, L8H5, L8H6, L8H7, L8H8, L8H9, L8H10, L8H11, L9H0, L9H1, L9H2, L9H3, L9H4, L9H5, L9H6, L9H7, L9H8, L9H9, L9H10, L9H11, L10H0, L10H1, L10H2, L10H3, L10H4, L10H5, L10H6, L10H7, L10H8, L10H9, L10H10, L10H11, L11H0, L11H1, L11H2, L11H3, L11H4, L11H5, L11H6, L11H7, L11H8, L11H9, L11H10, L11H11, L12H0, L12H1, L12H2, L12H3, L12H4, L12H5, L12H6, L12H7, L12H9, L12H10, L12H11
Heads attending to 





In [19]:
# Build a sequence of 10 random tokens repeated twice (no BOS token is prepended)
rand_seq = t.randint(1000, 10000, (1, 10))
repeat_seq = t.cat([rand_seq, rand_seq], dim=1)


gpt_logits, gpt_cache = gpt2_small.run_with_cache(repeat_seq, remove_batch_dim=True)
toymodel_logits, toymodel_cache = toy_transformer.run_with_cache(repeat_seq, remove_batch_dim=True)

# Position 14 lies in the second copy, so its next token can be predicted by induction
token_position = 14
attribution_scores = logit_attribution(repeat_seq, toy_transformer, toymodel_cache, token_position)
plot_attribution_pattern(attribution_scores)

attribution_scores = logit_attribution(repeat_seq, gpt2_small, gpt_cache, token_position)
plot_attribution_pattern(attribution_scores)

In [16]:
repeat_seq

tensor([[142, 800, 731, 726, 586, 386, 409, 964, 798, 231, 142, 800, 731, 726,
         586, 386, 409, 964, 798, 231]])

In [20]:

# Attention patterns of all heads in layer 9 of GPT-2 small
attention_pattern = gpt_cache["pattern", 9]
print(attention_pattern.shape)
# Tokenize with the model whose cache is being plotted
str_tokens = gpt2_small.to_str_tokens(repeat_seq)

print("Layer 9 Head Attention Patterns:")
display(
    cv.attention.attention_patterns(
        tokens=str_tokens,
        attention=attention_pattern
    )
)

torch.Size([12, 20, 20])
Layer 9 Head Attention Patterns:


In [None]:
import matplotlib.pyplot as plt

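def plot_attribution_pattern(attribution_scores: Float[Tensor, "layers heads"]):
    '''Plot a (layers, heads) heatmap of logit attribution scores.'''
    num_layers, num_heads = attribution_scores.shape

    fig, ax = plt.subplots(figsize=(10, 8))
    # Convert to NumPy so tensors on the GPU or with gradients can be plotted
    scores = attribution_scores.detach().cpu().numpy()
    im = ax.imshow(scores, cmap='viridis', aspect='auto')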

    ax.set_xticks(range(num_heads))
    ax.set_yticks(range(num_layers))
    ax.set_xticklabels([f"Head {i+1}" for i in range(num_heads)], fontsize=12)
    ax.set_yticklabels([f"Layer {i+1}" for i in range(num_layers)], fontsize=12)

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    ax.set_title("Attribution Scores", fontsize=16)
    ax.set_xlabel("Heads", fontsize=14)
    ax.set_ylabel("Layers", fontsize=14)

    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.tick_params(labelsize=12)

    fig.tight_layout()
    plt.show()

plot_attribution_pattern(attribution_scores)


<details>
<summary>Solution</summary>


```python
def logit_attribution(
    embed: Float[Tensor, "seq d_model"],
    l1_results: Float[Tensor, "seq nheads d_model"],
    l2_results: Float[Tensor, "seq nheads d_model"],
    W_U: Float[Tensor, "d_model d_vocab"],
    tokens: Int[Tensor, "seq"]
) -> Float[Tensor, "seq-1 n_components"]:
    '''
    Inputs:
        embed: the embeddings of the tokens (i.e. token + position embeddings)
        l1_results: the outputs of the attention heads at layer 0 (with head as one of the dimensions)
        l2_results: the outputs of the attention heads at layer 1 (with head as one of the dimensions)
        W_U: the unembedding matrix
        tokens: the token ids of the sequence

    Returns:
        Tensor of shape (seq_len-1, n_components)
        represents the concatenation (along dim=-1) of logit attributions from:
            the direct path (seq-1,1)
            layer 0 logits (seq-1, n_heads)
            layer 1 logits (seq-1, n_heads)
        so n_components = 1 + 2*n_heads
    '''
    W_U_correct_tokens = W_U[:, tokens[1:]]
    # SOLUTION
    direct_attributions = einops.einsum(W_U_correct_tokens, embed[:-1], "emb seq, seq emb -> seq")
    l1_attributions = einops.einsum(W_U_correct_tokens, l1_results[:-1], "emb seq, seq nhead emb -> seq nhead")
    l2_attributions = einops.einsum(W_U_correct_tokens, l2_results[:-1], "emb seq, seq nhead emb -> seq nhead")
    return t.concat([direct_attributions.unsqueeze(-1), l1_attributions, l2_attributions], dim=-1)
```
</details>
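
One way to convince yourself that the decomposition is sound: when the residual stream is just the sum of its components, the attributions at each position should add up to the logit of the correct next token. The sketch below checks this on random stand-in arrays. It uses NumPy rather than torch to stay self-contained, and every value and dimension in it is hypothetical rather than taken from a real model. It also assumes no LayerNorm, MLPs or biases, so the check would not hold exactly for a model such as GPT-2 small.

```python
import numpy as np

rng = np.random.default_rng(0)
seq, n_heads, d_model, d_vocab = 5, 3, 8, 11

# Random stand-ins for cached activations (hypothetical, not from a real model)
embed = rng.normal(size=(seq, d_model))
l1_results = rng.normal(size=(seq, n_heads, d_model))
l2_results = rng.normal(size=(seq, n_heads, d_model))
W_U = rng.normal(size=(d_model, d_vocab))
tokens = rng.integers(0, d_vocab, size=seq)

# Unembedding columns for the token that actually comes next at each position
W_U_correct = W_U[:, tokens[1:]]  # (d_model, seq-1)

# Same contractions as in the solution, written with np.einsum
direct = np.einsum("es,se->s", W_U_correct, embed[:-1])
l1 = np.einsum("es,she->sh", W_U_correct, l1_results[:-1])
l2 = np.einsum("es,she->sh", W_U_correct, l2_results[:-1])
logit_attr = np.concatenate([direct[:, None], l1, l2], axis=-1)

# Assumption: the final residual stream is the plain sum of all components
resid = embed + l1_results.sum(axis=1) + l2_results.sum(axis=1)
logits = resid @ W_U
correct_logits = logits[np.arange(seq - 1), tokens[1:]]

assert logit_attr.shape == (seq - 1, 1 + 2 * n_heads)
assert np.allclose(logit_attr.sum(axis=-1), correct_logits)
```
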


Once you've got the tests working, you can visualise the logit attributions for each path through the model. We've provided you with the helper function `plot_logit_attribution`, which presents the results in a nice way.



In [None]:
embed = cache["embed"]
l1_results = cache["result", 0]
l2_results = cache["result", 1]
logit_attr = logit_attribution(embed, l1_results, l2_results, model.W_U, tokens[0])

plot_logit_attribution(model, logit_attr, tokens)


#### Question - what is the interpretation of this plot?

You should find that the most variation in the logit attribution comes from the direct path. In particular, some of the tokens in the direct path have a very high logit attribution (e.g. tokens 12, 24 and 46). Can you guess what gives them in particular such a high logit attribution? 

<details>
<summary>Answer</summary>

The tokens with very high logit attribution are the ones which "offer very probable bigrams". For instance, the highest contribution on the direct path comes from `| manip|`, because this is very likely to be followed by `|ulative|` (or presumably a different stem like `| ulation|`). `| super|` -> `|human|` is another example of a bigram formed when the tokenizer splits one word into multiple tokens.

There are also examples that come from two different words, rather than a single word split by the tokenizer. These include:

* `| more|` -> `| likely|`
* `| machine|` -> `| learning|`
* `| by|` -> `| default|`
* `| how|` -> `| to|`

See later for a discussion of all the ~infuriating~ fun quirks of tokenization!
</details>
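
To see why the direct path behaves like a bigram table, note that it is just the embedding followed immediately by the unembedding, so it maps each current token to a row of next-token logits. The toy below illustrates this with a hypothetical four-token vocabulary and hand-picked weights. The numbers are made up for illustration, and positional embeddings are ignored.

```python
import numpy as np

# Hypothetical 4-token vocabulary, chosen only for illustration
vocab = [" manip", "ulative", " more", " likely"]
d_model = 4

# Toy weights: one-hot embeddings, and an unembedding built by hand so that
# two of the tokens strongly predict their usual successor
W_E = np.eye(len(vocab), d_model)
W_U = np.zeros((d_model, len(vocab)))
W_U[0, 1] = 5.0  # " manip" -> "ulative"
W_U[2, 3] = 3.0  # " more"  -> " likely"

# Direct path: row = current token, column = candidate next token
direct_path_logits = W_E @ W_U

# The most likely next token under the direct path alone
next_token = direct_path_logits.argmax(axis=-1)
for i in (0, 2):
    print(f"{vocab[i]!r} -> {vocab[next_token[i]]!r}")
```

In a trained model the same product is learned rather than hand-written, which is consistent with the largest direct-path attributions showing up on very predictable bigrams.
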

Another feature of the plot - the heads in the second layer seem to have much higher contributions than the heads in the first layer. Why do you think this might be?

<details>
<summary>Hint</summary>

Think about what this graph actually represents, in terms of paths through the transformer.
</details>

<details>
<summary>Answer</summary>

This is because of a point we discussed earlier - this plot doesn't pick up on things like a head's effect in composition with another head. So the attribution for layer-0 heads won't involve any composition, whereas the attributions for layer-1 heads will involve not only the single-head paths through those attention heads, but also the 2-layer compositional paths through heads in layer 0 and layer 1.
</details>
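
The point about composition can be made concrete with a small linear toy. The sketch below treats each head as a fixed linear map acting on a single position, which is an assumption made for illustration: it holds the attention pattern fixed and ignores mixing between positions. All matrices are random stand-ins, not weights from a real model. It shows that the attribution assigned to the later head is the sum of its single-head path and the path composed with the earlier head.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 6

x0 = rng.normal(size=d_model)             # embedding at one position
A0 = rng.normal(size=(d_model, d_model))  # stand-in for a layer-0 head (fixed linear map)
A1 = rng.normal(size=(d_model, d_model))  # stand-in for a layer-1 head (fixed linear map)
u = rng.normal(size=d_model)              # unembedding direction of the correct token

h0 = x0 @ A0  # layer-0 head output: reads only the embedding
x1 = x0 + h0  # residual stream seen by layer 1
h1 = x1 @ A1  # layer-1 head output: reads the embedding and the layer-0 output

attr_layer0 = h0 @ u  # what the plot shows for the layer-0 head
attr_layer1 = h1 @ u  # what the plot shows for the layer-1 head

single_head_path = (x0 @ A1) @ u    # embed -> layer-1 head -> unembed
composed_path = (x0 @ A0 @ A1) @ u  # embed -> layer-0 head -> layer-1 head -> unembed

# The layer-1 attribution bundles both paths together
assert np.isclose(attr_layer1, single_head_path + composed_path)
```
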
