<a href="https://colab.research.google.com/github/EffiSciencesResearch/ML4G-2.0/blob/master/workshops/induction_heads/induction_heads_hard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction

Thanks to Callum McDougall for creating much of this notebook's content!

This exercise should get you familiar with concepts in circuit interpretability in transformers.

## Content & Learning Objectives

#### 1️⃣ Find Induction Heads via attention patterns

You'll plot attention patterns to test what the theory of induction heads predicts, and see whether you can find such heads in the model you're working with.

> ##### Learning objectives
> 
> - Use `circuitsvis` to visualize attention heads
> - Understand what the theory of induction heads predicts about attention patterns
> - Use attention patterns to identify induction heads
> - Automate this process to find induction heads in a larger model

#### 2️⃣ Logit Attribution

Here, you'll learn how to use `TransformerLens` to implement direct logit attribution, a tool for attributing logit values to specific components of the model. You'll also use this tool to identify the attention heads that matter most for induction tasks.

> ##### Learning objectives
> 
> - Perform direct logit attribution to figure out which heads are writing to the residual stream in a significant way
> - Cross-check these results against the heads you identified from their attention patterns

In [None]:
try:
    import google.colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Install packages
    %pip install einops transformer_lens circuitsvis -q

    !wget -q https://github.com/EffiSciencesResearch/ML4G-2.0/archive/refs/heads/master.zip
    !unzip -o /content/master.zip 'ML4G-2.0-master/workshops/induction_heads/*'
    !mv --no-clobber ML4G-2.0-master/workshops/induction_heads/* .
    !rm -r ML4G-2.0-master

    print("Imports & installations complete!")
else:
    %load_ext autoreload
    %autoreload 2

In [None]:
import circuitsvis as cv
import torch
import typeguard
from typing import Callable
import plotly.express as px
from huggingface_hub import hf_hub_download
from IPython.display import display
from jaxtyping import Float, Int, jaxtyped
from tests import (
    test_average_over_condition,
    test_current_attn_detector,
    test_first_attn_detector,
    test_induction_attn_detector,
    test_logit_attribution,
    test_prev_attn_detector,
)
from torch import Tensor
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig

# Decorate a function with @typechecked to check argument types and tensor shapes at runtime.
typechecked = jaxtyped(typechecker=typeguard.typechecked)

# Disable the gradients globally. We don't want them for this notebook.
torch.set_grad_enabled(False)

# Device used to load the toy model's weights (kept on the CPU here)
device = torch.device("cpu")

## Loading and Running Models

Today, we will be using two models:

1. GPT-2 Small: the smallest version of the GPT-2 model, with 12 layers and about 124 million parameters (roughly 85 million excluding the embeddings).

2. A two-layer attention-only transformer model.
This toy model is a 2L attention-only transformer trained specifically for today. Some changes make it easier to interpret:
- It has only attention blocks.
- The positional embeddings are only added to the residual stream before each key and query vector in the attention layers, instead of being added to the token embeddings - i.e. we compute queries as `Q = (resid + pos_embed) @ W_Q + b_Q` and the same for keys, but values as `V = resid @ W_V + b_V`. This means that **the residual stream can't directly encode positional information**.
    - This turns out to make it *way* easier for induction heads to form; it happens 2-3x earlier - [see the comparison of two training runs](https://wandb.ai/mechanistic-interpretability/attn-only/reports/loss_ewma-22-08-24-11-08-83---VmlldzoyNTI0MDMz?accessToken=8ap8ir6y072uqa4f9uinotdtrwmoa8d8k2je4ec0lyasf1jcm3mtdh37ouijgdbm). (The bump in each curve is the formation of induction heads.)
    - The argument that does this below is `positional_embedding_type="shortformer"`.
- It has no MLP layers, no LayerNorms, and no biases.
- There are separate embed and unembed matrices (i.e. the weights are not tied).
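
To make the positional-embedding change concrete, here is a minimal sketch of how queries, keys and values might be computed for a single head under this scheme. All names and sizes (`resid`, `pos_embed`, `W_Q`, ...) are illustrative assumptions filled with random values, not the TransformerLens internals, and biases are omitted for brevity.

```python
import torch

torch.manual_seed(0)
seq_len, d_model, d_head = 5, 16, 4  # illustrative sizes, not the real model's

resid = torch.randn(seq_len, d_model)      # residual stream (token information only)
pos_embed = torch.randn(seq_len, d_model)  # positional embeddings
W_Q, W_K, W_V = (torch.randn(d_model, d_head) for _ in range(3))

# Queries and keys see position; values do not.
q = (resid + pos_embed) @ W_Q
k = (resid + pos_embed) @ W_K
v = resid @ W_V

# Swapping in different positional embeddings changes the queries,
# but leaves the values untouched.
other_pos = torch.randn(seq_len, d_model)
q_other = (resid + other_pos) @ W_Q
v_other = resid @ W_V
print(torch.allclose(q, q_other))  # False
print(torch.equal(v, v_other))     # True
```

Because the values never include `pos_embed`, whatever a head copies between positions carries no direct positional signal.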


In [None]:
# Don't read, just run

# Load GPT-2
gpt2: HookedTransformer = HookedTransformer.from_pretrained("gpt2-small")
gpt2.set_use_attn_result(True)

# Load the toy transformer
cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True,  # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b",
    seed=398,
    use_attn_result=True,
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer",
)

REPO_ID = "callummcdougall/attn_only_2L_half"
FILENAME = "attn_only_2L_half.pth"

weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

model = HookedTransformer(cfg)
model.load_state_dict(torch.load(weights_path, map_location=device))

## Caching all Activations

The first basic operation when doing mechanistic interpretability is to break open the black box of the model and look at all of its internal activations. This can be done with `logits, cache = model.run_with_cache(tokens)`. Let's try this out on a short sentence repeated twice (the repetition is what lets us look for induction heads later).

<details>
<summary>Aside - a note on <code>remove_batch_dim</code></summary>

Every activation inside the model begins with a batch dimension. Here, because we only passed in a single sequence, that dimension always has length 1 and is mostly a nuisance, so passing the `remove_batch_dim=True` keyword removes it.

Calling `cache.remove_batch_dim()` on a cache that still has its batch dimension would have achieved the same effect.
</details>

In [None]:
text = " I can't repeat enough that ML4Good is the best place to learn about AI safety." * 2

# Tokenize the text
tokens: Int[Tensor, "batch=1 seq_len"] = model.to_tokens(text)
print(f"{tokens.shape=}")

# Gather the logits and caches
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

If you inspect the `cache` object, you should see that it contains a very large number of keys, each one corresponding to a different activation in the model. You can access an activation by indexing the cache with its full name, or by using a more convenient shorthand. For instance:

In [None]:
pattern = cache["pattern", 0]
print(f"{pattern.shape=}")

<details>
<summary>What is this activation ["pattern", 0]? Can you explain its shape?</summary>

It holds the attention probabilities (the attention scores after the softmax) between pairs of tokens in the sentence. The 0 refers to the first layer of the model.
Its shape is `(number of heads, query token, key token)`.
</details>
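
As a rough illustration of what such a pattern tensor contains, the sketch below builds one by hand from random scores. The sizes and the `scores` tensor are hypothetical; a real model computes the scores from its queries and keys.

```python
import torch

torch.manual_seed(0)
n_heads, seq_len = 2, 4  # illustrative sizes

# Hypothetical pre-softmax attention scores of shape (head, query, key)
scores = torch.randn(n_heads, seq_len, seq_len)

# Causal mask: a query at position q may only attend to keys at positions <= q
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, float("-inf"))

# Softmax over the key dimension turns scores into probabilities
pattern = scores.softmax(dim=-1)

print(pattern.shape)        # torch.Size([2, 4, 4])
print(pattern.sum(dim=-1))  # every row sums to 1 (up to rounding)
print(pattern[0])           # lower triangular: zeros above the diagonal
```

Each row of `pattern[h]` is a probability distribution over the keys that query is allowed to see, which is why the first row always puts all of its weight on the first token.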

You can use the following diagram to find the names of activations that you want (click to see it larger)
[ ![activations](https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-full-updated.png) ](https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/transformer-full-updated.png)

## Visualizing Attention Heads

A key insight from the paper *A Mathematical Framework for Transformer Circuits* is that we should focus on interpreting the parts of the model that are intrinsically interpretable - the input tokens, the output logits, and the attention patterns. Everything else (the residual stream, keys, queries, values, etc.) is a compressed intermediate state used when calculating meaningful things. So a natural place to start is classifying heads by their attention patterns on various texts.

When doing interpretability, it's always good to begin by visualizing your data rather than relying on summary statistics. Summary statistics can be super misleading! Once we have visualized the attention patterns, we can create some basic summary statistics and use our visualizations to validate them. (Neural networks are very high-dimensional objects, so being proficient in web development and data visualization is a surprisingly useful skill set.)

Let's visualize the attention pattern of all the heads in one layer, using [Alan Cooney's CircuitsVis library](https://github.com/alan-cooney/CircuitsVis) (based on Anthropic's PySvelte library). If you did the previous set of exercises, you'll have seen this library before.

We will use the function `cv.attention.attention_patterns`, which takes the following arguments:

* `attention`: Attention head activations.
    * This should be a tensor of shape `[nhead, seq_dest, seq_src]`, i.e. the `[i, :, :]`th element is the grid of attention patterns (probabilities) for the `i`th attention head.
    * We get this by indexing our `cache` object.
* `tokens`: List of tokens (e.g. `["A", "person"]`). 
    * Sequence length must match that inferred from `attention`.
    * This is used to label the grid.
    * We get this by using the `model.to_str_tokens` method.
* `attention_head_names`: Optional list of names for the heads.

There are also other circuitsvis functions, e.g. `cv.attention.attention_pattern` (for just a single head) or `cv.attention.attention_heads` (which has the same syntax but presents the information in a different form).

<details>
<summary>Help - my <code>attention_heads</code> plots are behaving weirdly (e.g. they continually shrink after I plot them).</summary>
This seems to be a bug in `circuitsvis` - on VSCode, the attention head plots continually shrink in size.

Until this is fixed, one way to get around it is to open the plots in your browser. You can do this inline with the `webbrowser` library:

```python
import webbrowser

attn_heads = cv.attention.attention_heads(
    tokens=str_tokens,
    attention=attention_pattern,
    attention_head_names=[f"L0H{i}" for i in range(12)],
)

path = "attn_heads.html"

# Save the rendered HTML, then open it in the default browser
with open(path, "w") as f:
    f.write(str(attn_heads))

webbrowser.open(path)
```

To check exactly where this is getting saved, you can print your current working directory with `os.getcwd()`.
</details>

This visualization is interactive! Try hovering over a token or head, and click to lock. The grid on the top left, and the grid for each head, shows the attention pattern as a destination position by source position grid. It's lower triangular because the model uses **causal attention**: attention can only look backward, so information can only move forward in the network.


In [None]:
layer = 1  # zero-indexed, so this selects the second layer

attention_pattern = cache["pattern", layer]
str_tokens = model.to_str_tokens(text)

cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

Hover over heads to see the attention patterns; click on a head to lock it. Hover over each token to see which other tokens it attends to (or which other tokens attend to it - you can see this by changing the dropdown to `Destination <- Source`).

# Finding induction heads

             
Use the [diagram at this link](https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/small-merm.svg) to remind yourself of the relevant hook names.


### Exercise - visualise attention patterns

```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to ~10 minutes on this exercise.

It's important to be comfortable using circuitsvis, and the cache object.
```

*This exercise should be very quick - you can reuse code from the previous section. You should look at the solution if you're still stuck after 5-10 minutes.*

Visualise the attention patterns for **both** layers of your model, with the same repeated text.


In [None]:
# 1. Convert the text to tokens, both as integers and as strings (for the plot)
...  # TODO: ~8 words

# 2. Pass the tokens through the model, get the cache
...  # TODO: ~7 words

# 3. For each of the model.cfg.n_layers, display the attention pattern
...  # TODO: ~24 words

<details>
<summary>Show solution</summary>

```python
# 1. Convert the text to tokens, both as integers and as strings (for the plot)
tokens = model.to_tokens(text)
str_tokens = model.to_str_tokens(text)

# 2. Pass the tokens through the model, get the cache
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

# 3. For each of the model.cfg.n_layers, display the attention pattern
for layer in range(model.cfg.n_layers):
    print(f"Layer {layer+1}")
    attention_pattern = cache["pattern", layer]
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))
```

</details>



<details>
<summary>What patterns do you see? (Find at least 3.)</summary>

We notice that there are three basic patterns which repeat quite frequently:

* `prev_token_heads`, which mainly attend to the previous token (e.g. head `0.7`)
* `current_token_heads`, which mainly attend to the current token (e.g. head `1.6`)
* `first_token_heads`, which attend a lot to the first token (e.g. heads `0.3` or `1.4`, although these are a bit less clear-cut than the other two)

The `prev_token_heads` and `current_token_heads` are perhaps unsurprising because words that are close together in a sequence probably have a lot more mutual information (i.e. we could get quite far using bigram or trigram prediction). 

The `first_token_heads` are a bit more surprising. The basic intuition here is that the first token in a sequence is often used as a resting or null position for heads that only sometimes activate (since our attention probabilities always have to add up to 1).
</details>

Now that we've observed our three basic attention patterns, it's time to make detectors for those patterns!

### Exercise - Write Your Own Detectors

```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You shouldn't spend more than 15-20 minutes on these exercises.
These exercises shouldn't be too challenging if you understand attention patterns.
```

You should fill in the functions below, which act as detectors for particular types of heads. Validate your detectors by comparing these results to the visual attention patterns above. Summary statistics on their own can be unreliable, but are much more reliable if you can validate them by directly engaging with the data.

Tasks like this are useful because we need to be able to interpret our observations/intuitions about what a model is doing and translate these into quantitative measures. As the exercises proceed, we'll be creating some much more interesting tools and detectors!

Note - there's no objectively correct answer for which heads are performing which tasks, and which detectors can spot them. You should try to come up with something plausible that identifies the kind of behavior you're looking for.


In [None]:
THRESHOLD = 0.35


@typechecked
def average_over_condition(
    tensor: Float[Tensor, "head query key"],
    condition: Callable[[int, int], bool],
) -> Float[Tensor, "head"]:
    """
    Return the average of the values in the tensor whose indices satisfy the condition.

    Args:
        tensor: A 3D tensor of shape (n_head, n_query, n_key).
        condition: A function that takes two indices, for the query and key, and determines whether to consider them.

    Returns:
        A 1D tensor of shape (n_head,) with the average of the values in the tensor that satisfy the condition.
    """
    n_head, n_query, n_key = tensor.shape

    # Make a 2D tensor with True where the condition is met, and use it to index the tensor
    ...  # TODO: ~33 words


def heads_with_high_average_pattern(
    cache, condition: Callable[[int, int], bool], threshold: float
) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels (layers 1-indexed, heads 0-indexed, e.g. "L1H7")
    for the heads whose average pattern value, over the entries whose indices satisfy the
    condition, exceeds the threshold.
    """
    return_values = []

    for layer, pattern in enumerate(cache.stack_activation("pattern")):
        ...  # TODO: ~21 words
    return return_values


def current_attn_detector(cache: ActivationCache, threshold=THRESHOLD) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels, e.g. ["L1H2", "L2H4", "L2H9"], for the heads you judge to be current-token heads
    """

    def is_attn_to_self(query: int, key: int): ...  # TODO: ~3 words

    return heads_with_high_average_pattern(cache, is_attn_to_self, threshold=threshold)


def prev_attn_detector(cache: ActivationCache, threshold=THRESHOLD) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels, e.g. ["L1H2", "L2H4", "L2H9"], for the heads you judge to be prev-token heads
    """

    def is_attn_to_prev_token(query: int, key: int): ...  # TODO: ~4 words

    return heads_with_high_average_pattern(cache, is_attn_to_prev_token, threshold=threshold)


def first_attn_detector(cache: ActivationCache, threshold=THRESHOLD) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels, e.g. ["L1H2", "L2H4", "L2H9"], for the heads you judge to be first-token heads
    """

    def is_attn_to_first_token(query: int, key: int): ...  # TODO: ~3 words

    return heads_with_high_average_pattern(cache, is_attn_to_first_token, threshold=threshold)


def find_repeating_rows(tensor):
    """
    Finds repeated token IDs in the first row of a 2D token tensor.

    Args:
        tensor (torch.Tensor): A 2D tensor of token IDs with shape (batch, seq).
            Only the first row is used.

    Returns:
        dict: A dictionary mapping each position that holds a repeated token
            to the position where that token last occurred.
    """
    last_occurrence = {}
    repeats = {}

    for pos, token in enumerate(tensor[0]):
        token_id = token.item()  # avoid shadowing the built-in `id`

        if token_id in last_occurrence:
            repeats[pos] = last_occurrence[token_id]
        last_occurrence[token_id] = pos

    return repeats


@typechecked
def induction_attn_detector(
    cache: ActivationCache, tokens: Int[Tensor, "batch seq"], off_by_one=True, threshold=THRESHOLD
) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels, e.g. ["L1H2", "L2H4", "L2H9"], for the heads you judge to be induction heads

    Args:
        off_by_one: If True, the heads attending to the token after the previous occurrence are returned,
            otherwise the heads attending directly to the previous occurrence are returned.
    """
    repeat_dict = find_repeating_rows(tokens)

    def is_attn_to_last_occurrence(query, key):
        if query not in repeat_dict.keys():
            return False
        to_add = 1 if off_by_one else 0
        return repeat_dict[query] + to_add == key

    return heads_with_high_average_pattern(cache, is_attn_to_last_occurrence, threshold=threshold)


def print_head_detection_report(cache, tokens, threshold=THRESHOLD):
    print("Heads attending to the current token:", *current_attn_detector(cache, threshold))
    print("Heads attending to the previous token:", *prev_attn_detector(cache, threshold))
    print("Heads attending to the first token:", *first_attn_detector(cache, threshold))
    print(
        "Heads attending to previous occurrence:",
        *induction_attn_detector(cache, tokens, off_by_one=False, threshold=threshold),
    )
    print(
        "Heads attending to one after previous occurrence:",
        *induction_attn_detector(cache, tokens, off_by_one=True, threshold=threshold),
    )


test_average_over_condition(average_over_condition)
test_current_attn_detector(current_attn_detector, model)
test_first_attn_detector(first_attn_detector, model)
test_induction_attn_detector(induction_attn_detector, model)
test_prev_attn_detector(prev_attn_detector, model)
print()

print_head_detection_report(cache, tokens, threshold=0.35)

<details>
<summary>Show solution</summary>

```python
THRESHOLD = 0.35


@typechecked
def average_over_condition(
    tensor: Float[Tensor, "head query key"],
    condition: Callable[[int, int], bool],
) -> Float[Tensor, "head"]:
    """
    Return the average of the values in the tensor whose indices satisfy the condition.

    Args:
        tensor: A 3D tensor of shape (n_head, n_query, n_key).
        condition: A function that takes two indices, for the query and key, and determines whether to consider them.

    Returns:
        A 1D tensor of shape (n_head,) with the average of the values in the tensor that satisfy the condition.
    """
    n_head, n_query, n_key = tensor.shape

    # Make a 2D tensor with True where the condition is met, and use it to index the tensor
    condition_tensor = torch.tensor(
        [[condition(j, k) for k in range(n_key)] for j in range(n_query)]
    )
    assert condition_tensor.shape == (n_query, n_key)

    # For each head, average over the condition
    return torch.stack([tensor[head][condition_tensor].mean() for head in range(n_head)])


def heads_with_high_average_pattern(
    cache, condition: Callable[[int, int], bool], threshold: float
) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels (layers 1-indexed, heads 0-indexed, e.g. "L1H7")
    for the heads whose average pattern value, over the entries whose indices satisfy the
    condition, exceeds the threshold.
    """
    return_values = []

    for layer, pattern in enumerate(cache.stack_activation("pattern")):
        scores = average_over_condition(pattern, condition)
        for head, score in enumerate(scores):
            if score > threshold:
                return_values.append(f"L{layer+1}H{head}")
    return return_values


def current_attn_detector(cache: ActivationCache, threshold=THRESHOLD) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels, e.g. ["L1H2", "L2H4", "L2H9"], for the heads you judge to be current-token heads
    """

    def is_attn_to_self(query: int, key: int):
        return query == key

    return heads_with_high_average_pattern(cache, is_attn_to_self, threshold=threshold)


def prev_attn_detector(cache: ActivationCache, threshold=THRESHOLD) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels, e.g. ["L1H2", "L2H4", "L2H9"], for the heads you judge to be prev-token heads
    """

    def is_attn_to_prev_token(query: int, key: int):
        return query == key + 1

    return heads_with_high_average_pattern(cache, is_attn_to_prev_token, threshold=threshold)


def first_attn_detector(cache: ActivationCache, threshold=THRESHOLD) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels, e.g. ["L1H2", "L2H4", "L2H9"], for the heads you judge to be first-token heads
    """

    def is_attn_to_first_token(query: int, key: int):
        return key == 0

    return heads_with_high_average_pattern(cache, is_attn_to_first_token, threshold=threshold)


def find_repeating_rows(tensor):
    """
    Finds repeated token IDs in the first row of a 2D token tensor.

    Args:
        tensor (torch.Tensor): A 2D tensor of token IDs with shape (batch, seq).
            Only the first row is used.

    Returns:
        dict: A dictionary mapping each position that holds a repeated token
            to the position where that token last occurred.
    """
    last_occurrence = {}
    repeats = {}

    for pos, token in enumerate(tensor[0]):
        token_id = token.item()  # avoid shadowing the built-in `id`

        if token_id in last_occurrence:
            repeats[pos] = last_occurrence[token_id]
        last_occurrence[token_id] = pos

    return repeats


@typechecked
def induction_attn_detector(
    cache: ActivationCache, tokens: Int[Tensor, "batch seq"], off_by_one=True, threshold=THRESHOLD
) -> list[str]:
    """
    Returns a list of "L{layer}H{head}" labels, e.g. ["L1H2", "L2H4", "L2H9"], for the heads you judge to be induction heads

    Args:
        off_by_one: If True, the heads attending to the token after the previous occurrence are returned,
            otherwise the heads attending directly to the previous occurrence are returned.
    """
    repeat_dict = find_repeating_rows(tokens)

    def is_attn_to_last_occurrence(query, key):
        if query not in repeat_dict.keys():
            return False
        to_add = 1 if off_by_one else 0
        return repeat_dict[query] + to_add == key

    return heads_with_high_average_pattern(cache, is_attn_to_last_occurrence, threshold=threshold)


def print_head_detection_report(cache, tokens, threshold=THRESHOLD):
    print("Heads attending to the current token:", *current_attn_detector(cache, threshold))
    print("Heads attending to the previous token:", *prev_attn_detector(cache, threshold))
    print("Heads attending to the first token:", *first_attn_detector(cache, threshold))
    print(
        "Heads attending to previous occurrence:",
        *induction_attn_detector(cache, tokens, off_by_one=False, threshold=threshold),
    )
    print(
        "Heads attending to one after previous occurrence:",
        *induction_attn_detector(cache, tokens, off_by_one=True, threshold=threshold),
    )


test_average_over_condition(average_over_condition)
test_current_attn_detector(current_attn_detector, model)
test_first_attn_detector(first_attn_detector, model)
test_induction_attn_detector(induction_attn_detector, model)
test_prev_attn_detector(prev_attn_detector, model)
print()

print_head_detection_report(cache, tokens, threshold=0.35)
```

</details>
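
The loop over heads in `average_over_condition` can also be written with a single boolean mask shared by all heads. The sketch below is one possible alternative, not the reference solution; the function name `average_over_condition_vectorized` and the toy `pattern` tensor are made up for illustration.

```python
from typing import Callable

import torch
from torch import Tensor


def average_over_condition_vectorized(
    tensor: Tensor, condition: Callable[[int, int], bool]
) -> Tensor:
    """Average, per head, the entries whose (query, key) indices satisfy the condition."""
    n_head, n_query, n_key = tensor.shape
    mask = torch.tensor(
        [[condition(q, k) for k in range(n_key)] for q in range(n_query)],
        dtype=torch.bool,
    )
    # Boolean indexing on the last two dims keeps the head dim: (n_head, n_selected)
    return tensor[:, mask].mean(dim=-1)


# A toy "pattern" with 2 heads over 3 positions:
# head 0 always attends to itself, head 1 always attends to the first token.
pattern = torch.stack([torch.eye(3), torch.tensor([[1.0, 0.0, 0.0]] * 3)])
print(average_over_condition_vectorized(pattern, lambda q, k: q == k))  # tensor([1.0000, 0.3333])
print(average_over_condition_vectorized(pattern, lambda q, k: k == 0))  # tensor([0.3333, 1.0000])
```

On the toy input, the "self" condition scores head 0 at 1 and the "first token" condition scores head 1 at 1, which is the separation the detectors rely on.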



Compare the printouts to your attention visualizations above. Do they seem to make sense?
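
To see exactly which entries of the attention pattern the induction detector looks at, here is a small pure-Python sketch on a hand-made sequence. The helper `last_occurrences` mirrors the logic of `find_repeating_rows` on a plain list, and the token IDs are hypothetical.

```python
def last_occurrences(token_ids: list[int]) -> dict[int, int]:
    """Map each position holding a repeated token to the position of its previous occurrence."""
    last_seen: dict[int, int] = {}
    repeats: dict[int, int] = {}
    for pos, token_id in enumerate(token_ids):
        if token_id in last_seen:
            repeats[pos] = last_seen[token_id]
        last_seen[token_id] = pos
    return repeats


# Hypothetical token IDs for the sequence "A B C A B C"
token_ids = [10, 11, 12, 10, 11, 12]
repeats = last_occurrences(token_ids)
print(repeats)  # {3: 0, 4: 1, 5: 2}

# With off_by_one=True, the detector scores attention from each repeated position
# to the token just after the previous occurrence.
induction_targets = {query: prev + 1 for query, prev in repeats.items()}
print(induction_targets)  # {3: 1, 4: 2, 5: 3}
```

For example, the second `A` (position 3) is expected to attend to position 1, the `B` that followed the first `A`, which is the token an induction head would then predict.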

You can now do the same thing with GPT-2 and find all the induction heads in the model!

In [None]:
layer = 8  # @param {type:"slider", min:1, max:12}

# 1. Convert the text to tokens, both as integers and as strings (for the plot)
tokens = gpt2.to_tokens(text)
str_tokens = gpt2.to_str_tokens(text)

# 2. Pass the tokens through the model, get the cache
logits, cache = gpt2.run_with_cache(tokens, remove_batch_dim=True)

# 3. Display the attention pattern for the selected layer
print(f"Layer {layer}")
attention_pattern = cache["pattern", layer - 1]
display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))

print_head_detection_report(cache, tokens, threshold=0.5)

# Logit Attribution

We now implement a second method to identify induction heads, which will hopefully give similar results to the first method. This method is based on the idea of **logit attribution**.

A consequence of the residual stream is that the output logits are the sum of the contributions of each layer, and thus the sum of the results of each head. This means we can decompose the output logits into a term coming from each head and directly do attribution like this!

<details>
<summary>A concrete example</summary>

Let's say that our model knows that the token Harry is followed by the token Potter, and we want to figure out how it does this. The logits on Harry are `residual @ W_U`. But this is a linear map, and the residual stream is the sum of the outputs of all previous layers: `residual = embed + attn_out_0 + attn_out_1`. So `logits = (embed @ W_U) + (attn_out_0 @ W_U) + (attn_out_1 @ W_U)`.

We can be even more specific, and *just* look at the logit of the Potter token - this corresponds to a column of `W_U`, and so a direction in the residual stream - our logit is now a single number that is the sum of `(embed @ potter_U) + (attn_out_0 @ potter_U) + (attn_out_1 @ potter_U)`. Even better, we can decompose each attention layer output into the sum of the result of each head, and use this to get many terms.
</details>
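
The decomposition above can be checked numerically. The following is a minimal sketch with random tensors standing in for the real activations; the sizes, the tensor values and the index used for "Potter" are all assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
d_model, d_vocab = 8, 20  # illustrative sizes

# Hypothetical contributions to the residual stream at one position
embed = torch.randn(d_model)
attn_out_0 = torch.randn(d_model)
attn_out_1 = torch.randn(d_model)
W_U = torch.randn(d_model, d_vocab)

residual = embed + attn_out_0 + attn_out_1
logits = residual @ W_U

# Unembedding each component separately and summing gives the same logits
per_component = torch.stack([embed @ W_U, attn_out_0 @ W_U, attn_out_1 @ W_U])
print(torch.allclose(logits, per_component.sum(dim=0), atol=1e-5))  # True

# Attribution for a single token (index 3 stands in for "Potter")
potter = 3
print(per_component[:, potter])  # contributions of embed, layer 0 and layer 1
```

The three numbers in the last line sum to `logits[potter]`, so each one can be read as that component's direct contribution to the Potter logit.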

Your mission here is to write a function to look at how much each component contributes to the correct logit. Your components are:

* The direct path (i.e. the residual connections from the embedding to unembedding),
* Each layer 0 head (via the residual connection and skipping layer 1)
* Each layer 1 head

To emphasize, these are not paths from the start to the end of the model; they are paths from the output of some component directly to the logits - we make no assumptions about how each path was calculated!

A few important notes for this exercise:

* Here we are only examining the DIRECT effect on the logits, i.e. the information that this component writes/embeds into the residual stream - we will not capture interactions where heads combine with other heads to influence logits, or suppress logits for other tokens to enhance the correct one.
* When focusing only on the logits corresponding to the correct token, our data is lower-dimensional because we can disregard all other tokens except the correct one (handling a 50K vocab size can be tedious!). However, this approach may overlook more nuanced effects, such as a head suppressing other plausible logits to elevate the log probability of the correct one.
    * There are other situations where our job might be easier. For instance, in the IOI task (which we'll discuss shortly) we're just comparing the logits of the indirect object to the logits of the direct object, meaning we can use the **difference between these logits**, and ignore all the other logits.
* When calculating correct output logits, we will get tensors with a dimension `(position - 1,)`, not `(position,)` - we remove the final element of the output (logits), and the first element of labels (tokens). This is because we're predicting the *next* token, and we don't know the token after the final token, so we ignore it.
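
The off-by-one alignment in the last note is easy to get wrong, so here is a short sketch of it. The `tokens` and `logits` tensors are random stand-ins with made-up sizes, not outputs of the models in this notebook.

```python
import torch

torch.manual_seed(0)
seq_len, d_vocab = 6, 10  # illustrative sizes

tokens = torch.randint(0, d_vocab, (seq_len,))  # hypothetical token IDs
logits = torch.randn(seq_len, d_vocab)          # hypothetical model output

# logits[i] predicts tokens[i + 1]: drop the last logit row and the first token
correct_logits = logits[:-1].gather(dim=-1, index=tokens[1:].unsqueeze(-1)).squeeze(-1)

print(correct_logits.shape)  # torch.Size([5])
```

Entry `i` of `correct_logits` is the logit that position `i` assigned to the token that actually appears at position `i + 1`.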

<details>

<summary>Question - why don't we do this to the log probs instead?</summary>

Because log probs aren't linear, they go through `log_softmax`, a non-linear function.
</details>
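
A quick numerical check of this point, using two made-up logit contributions: logits add up across components, but their log probs do not.

```python
import torch

a = torch.tensor([2.0, 0.0, -1.0])  # hypothetical logit contribution of one component
b = torch.tensor([0.5, 1.5, 0.0])   # hypothetical logit contribution of another

# Logits add up across components...
total_logits = a + b

# ...but log probs do not, because log_softmax is non-linear
log_probs_of_sum = total_logits.log_softmax(dim=-1)
sum_of_log_probs = a.log_softmax(dim=-1) + b.log_softmax(dim=-1)
print(torch.allclose(log_probs_of_sum, sum_of_log_probs))  # False
```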


### Exercise - build logit attribution tool



You should implement the `logit_attribution` function below. This should return the contribution of each component in the "correct direction". 




In [None]:
def plot_attribution_pattern(attribution_scores: Float[Tensor, "layers heads"]):
    num_layers, num_heads = attribution_scores.shape

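    fig = px.imshow(
        attribution_scores.detach().cpu().numpy(),  # pass plotly a plain array on the CPU
        labels=dict(x="Heads", y="Layers"),
        color_continuous_scale="Viridis",
        # Layers are 1-indexed and heads 0-indexed, like the "L{layer}H{head}" labels of the detectors
        x=[f"Head {i}" for i in range(num_heads)],
        y=[f"Layer {i+1}" for i in range(num_layers)],
    )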

    fig.update_layout(
        title={"text": "Attribution Scores", "font": {"size": 16}},
        xaxis_title="Heads",
        yaxis_title="Layers",
        xaxis={"tickfont": {"size": 12}, "tickangle": -45},
        yaxis={"tickfont": {"size": 12}},
        coloraxis_colorbar={"title": "", "tickfont": {"size": 12}},
    )

    fig.show()

In [None]:
def logit_attribution(
    tokens: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    cache: ActivationCache,
    token_position: int,
) -> Float[Tensor, "layers heads"]:
    """
    Computes the logit attribution for a specific token position in the input sequence.

    Args:
        tokens (Int[Tensor, "batch seq"]): The input token IDs tensor with shape (batch_size, sequence_length).
        model (HookedTransformer): The HookedTransformer model instance.
        cache (ActivationCache): The activation cache containing the intermediate results.
        token_position (int): The position of the token in the input sequence for which to compute the attribution.

    Returns:
        Float[Tensor, "layers heads"]: The logit attribution tensor with shape (num_layers, num_heads).

    Description:
        This function computes the logit attribution for a specific token position in the input sequence.
        It unembeds the output of each attention head in each layer, and measures how much it adds to the logit of the correct next token.

    Note:
        - The input `tokens` tensor is assumed to have a batch size of 1.
        - The `cache` is assumed to have been created with `remove_batch_dim=True`.
        - The `token_position` is zero-indexed, meaning the first token in the sequence has a position of 0.
        - The returned tensor has shape (num_layers, num_heads), representing the attribution scores
          for each layer and attention head.
    """
    # Retrieve the attention results from the activation cache for each transformer block
    ...  # TODO: ~14 words

    # Stack the attention results along the layer dimension
    ...  # TODO: ~6 words

    # Select the attention results corresponding to the specified token position
    ...  # TODO: ~3 words

    # Pass the selected attention results through the model's unembed function to obtain the logits
    ...  # TODO: ~4 words

    # Get the ID of the next token in the sequence
    ...  # TODO: ~5 words

    # Extract the logits corresponding to the next token ID
    ...  # TODO: ~3 words

    return attributions


test_logit_attribution(logit_attribution, model)

<details>
<summary>Show solution</summary>

```python
def logit_attribution(
    tokens: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    cache: ActivationCache,
    token_position: int,
) -> Float[Tensor, "layers heads"]:
    """
    Computes the logit attribution for a specific token position in the input sequence.

    Args:
        tokens (Int[Tensor, "batch seq"]): The input token IDs tensor with shape (batch_size, sequence_length).
        model (HookedTransformer): The HookedTransformer model instance.
        cache (ActivationCache): The activation cache containing the intermediate results.
        token_position (int): The position of the token in the input sequence for which to compute the attribution.

    Returns:
        Float[Tensor, "layers heads"]: The logit attribution tensor with shape (num_layers, num_heads).

    Description:
        This function computes the logit attribution for a specific token position in the input sequence.
        It unembeds the output of each attention head in each layer, and measures how much it adds to the logit of the correct next token.

    Note:
        - The input `tokens` tensor is assumed to have a batch size of 1.
        - The `cache` is assumed to have been created with `remove_batch_dim=True`.
        - The `token_position` is zero-indexed, meaning the first token in the sequence has a position of 0.
        - The returned tensor has shape (num_layers, num_heads), representing the attribution scores
          for each layer and attention head.
    """
    # Retrieve the attention results from the activation cache for each transformer block
    results = [cache[f"blocks.{i}.attn.hook_result"] for i in range(len(model.blocks))]

    # Stack the attention results along the layer dimension
    results = torch.stack(results, dim=1)

    # Select the attention results corresponding to the specified token position
    results = results[token_position, :, :, :]

    # Pass the selected attention results through the model's unembed function to obtain the logits
    logits = model.unembed(results)

    # Get the ID of the next token in the sequence
    next_token_id = tokens[0, token_position + 1]

    # Extract the logits corresponding to the next token ID
    attributions = logits[:, :, next_token_id]

    return attributions


test_logit_attribution(logit_attribution, model)
```

</details>



Now that you have the tool to see which heads have which effects on the logits, run some experiments to see which heads are useful for induction.
Are they the same as those that have the pattern you identified earlier?

In [None]:
repeat_seq = gpt2.to_tokens(text)

...  # TODO: ~32 words

<details>
<summary>Show solution</summary>

```python
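repeat_seq = gpt2.to_tokens(text)
# The toy model uses a different tokenizer, so it needs its own token IDs
toy_seq = model.to_tokens(text)

gpt_logits, gpt_cache = gpt2.run_with_cache(repeat_seq, remove_batch_dim=True)
toymodel_logits, toymodel_cache = model.run_with_cache(toy_seq, remove_batch_dim=True)

# Choose a position near the end of the sequence, so that it falls inside the
# second repetition of the sentence, where induction can help
toy_position = toy_seq.shape[1] - 3
attribution_scores = logit_attribution(toy_seq, model, toymodel_cache, toy_position)
plot_attribution_pattern(attribution_scores)

gpt_position = repeat_seq.shape[1] - 3
attribution_scores = logit_attribution(repeat_seq, gpt2, gpt_cache, gpt_position)
plot_attribution_pattern(attribution_scores)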

```

</details>
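
To cross-check against the detectors, it can help to list the highest-scoring heads instead of reading them off the heatmap. The helper below is a hypothetical convenience function, not part of the workshop code; it labels heads as `L{layer}H{head}` with 1-indexed layers and 0-indexed heads, the same format the detector solution returns.

```python
import torch


def top_heads(attribution_scores: torch.Tensor, k: int = 3) -> list[tuple[str, float]]:
    """Return the k (label, score) pairs with the largest attribution."""
    n_layers, n_heads = attribution_scores.shape
    values, flat_indices = attribution_scores.flatten().topk(k)
    return [
        (f"L{idx // n_heads + 1}H{idx % n_heads}", round(value, 3))
        for idx, value in zip(flat_indices.tolist(), values.tolist())
    ]


# Hypothetical scores for a 2-layer, 3-head model
scores = torch.tensor([[0.1, -0.2, 0.0], [2.5, 0.3, 1.7]])
print(top_heads(scores, k=2))  # [('L2H0', 2.5), ('L2H2', 1.7)]
```

Passing in the output of `logit_attribution` would give a short list that can be compared directly with the heads printed by `print_head_detection_report`.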

