<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Head_Detector_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# TransformerLens Head Detector Demo

A common technique in mechanistic interpretability of transformer-based neural networks is the identification of specialized attention heads, based on the attention patterns elicited by one or more prompts. The most basic examples of such heads are the previous token head, the duplicate token head, and the induction head ([more info](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_Jzi6YHRHKP1JziwdE02qdYZ)). Usually, such heads are identified manually, by going through visualizations of attention patterns layer by layer, head by head, and trying to recognize the patterns by eye.

The purpose of the `TransformerLens.head_detector` feature is to automate part of that workflow. The pattern characterizing a head of a particular type/function is specified as a `Tensor` in the form of a `seq_len x seq_len` [lower triangular matrix](https://en.wikipedia.org/wiki/Triangular_matrix). It can be passed to the `detect_head` function either directly or as a string naming one of several pre-defined detection patterns.
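
As a minimal sketch of what such a detection pattern looks like, the snippet below builds a previous-token pattern by hand for a five-token sequence and checks that it is lower triangular. The sequence length and variable names are illustrative assumptions, not part of the library.

```python
import torch

seq_len = 5  # illustrative; in practice this is the number of tokens in the prompt

# Rows are query positions, columns are key positions.
# A previous-token head at position i attends to position i - 1,
# so the ones sit on the diagonal just below the main diagonal.
previous_token_pattern = torch.zeros(seq_len, seq_len)
previous_token_pattern[1:, :-1] = torch.eye(seq_len - 1)

# A causal model can only attend to the current and earlier positions,
# so a valid detection pattern has no entries above the main diagonal.
pattern_is_lower_triangular = previous_token_pattern.equal(previous_token_pattern.tril())

print(previous_token_pattern)
print(pattern_is_lower_triangular)
```

The same construction is used later in the notebook by the pre-defined `"previous_token_head"` pattern.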

## How to use this notebook

Go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.

Tips for reading this Colab:

* You can run all this code for yourself!
* The graphs are interactive!
* Use the table of contents pane in the sidebar to navigate
* Collapse irrelevant sections with the dropdown arrows
* Search the page using the search in the sidebar, not CTRL+F

## Setup (Ignore)

In [1]:
# NBVAL_IGNORE_OUTPUT
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

DEVELOPMENT_MODE = True
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

if IN_COLAB or IN_GITHUB:
    %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git
    # Install Neel's personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    # Needed for PySvelte to work, v3 came out and broke things...
    %pip install typeguard==2.13.3
    %pip install typing-extensions

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/TransformerLensOrg/TransformerLens.git
  Cloning https://github.com/TransformerLensOrg/TransformerLens.git to /tmp/pip-req-build-v3x96q_b
  Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/TransformerLens.git /tmp/pip-req-build-v3x96q_b
  Resolved https://github.com/TransformerLensOrg/TransformerLens.git to commit 0ffcc8ad647d9e991f4c2596557a9d7475617773
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-u8mujxc3
  Running command 

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [3]:
import torch
import einops
import pysvelte
from tqdm import tqdm

import transformer_lens
from transformer_lens import HookedTransformer, ActivationCache
from neel_plotly import line, imshow, scatter

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{device = }")

device = 'cuda'


### Some plotting utils

In [5]:
# Util for plotting head detection scores

def plot_head_detection_scores(
    scores: torch.Tensor,
    zmin: float = -1,
    zmax: float = 1,
    xaxis: str = "Head",
    yaxis: str = "Layer",
    title: str = "Head Matches"
) -> None:
    imshow(scores, zmin=zmin, zmax=zmax, xaxis=xaxis, yaxis=yaxis, title=title)

def plot_attn_pattern_from_cache(cache: ActivationCache, layer_i: int):
    attention_pattern = cache["pattern", layer_i, "attn"].squeeze(0)
    attention_pattern = einops.rearrange(attention_pattern, "heads seq1 seq2 -> seq1 seq2 heads")
    print(f"Layer {layer_i} Attention Heads:")
    return pysvelte.AttentionMulti(tokens=model.to_str_tokens(prompt), attention=attention_pattern)

## Head detector

Utils: these will be in `transformer_lens.utils` after merging the fork to the main repo

In [6]:
def is_square(x: torch.Tensor) -> bool:
    """Checks if `x` is a square matrix."""
    return x.ndim == 2 and x.shape[0] == x.shape[1]

def is_lower_triangular(x: torch.Tensor) -> bool:
    """Checks if `x` is a lower triangular matrix."""
    if not is_square(x):
        return False
    return x.equal(x.tril())

The code below is copy-pasted from the expanded (not yet merged) version of `transformer_lens.head_detector`.

After merging the code below can be replaced with simply

```py
from transformer_lens.head_detector import *
```

(but please don't use star-imports in production ;))

In [7]:
from collections import defaultdict
import logging
from typing import cast, Dict, List, Optional, Tuple, Union
from typing_extensions import get_args, Literal

import numpy as np
import torch

from transformer_lens import HookedTransformer, ActivationCache
# from transformer_lens.utils import is_lower_triangular, is_square

HeadName = Literal["previous_token_head", "duplicate_token_head", "induction_head"]
HEAD_NAMES = cast(List[HeadName], get_args(HeadName))
ErrorMeasure = Literal["abs", "mul"]

LayerHeadTuple = Tuple[int, int]
LayerToHead = Dict[int, List[int]]

INVALID_HEAD_NAME_ERR = (
    f"detection_pattern must be a Tensor or one of head names: {HEAD_NAMES}; got %s"
)

SEQ_LEN_ERR = (
    "The sequence must be non-empty and must fit within the model's context window."
)

DET_PAT_NOT_SQUARE_ERR = "The detection pattern must be a lower triangular matrix of shape (sequence_length, sequence_length); sequence_length=%d; got detection pattern of shape %s"


def detect_head(
    model: HookedTransformer,
    seq: Union[str, List[str]],
    detection_pattern: Union[torch.Tensor, HeadName],
    heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,
    cache: Optional[ActivationCache] = None,
    *,
    exclude_bos: bool = False,
    exclude_current_token: bool = False,
    error_measure: ErrorMeasure = "mul",
) -> torch.Tensor:
    """Searches the model (or a set of specific heads, for circuit analysis) for a particular type of attention head.
    This head is specified by a detection pattern, a (sequence_length, sequence_length) tensor representing the attention pattern we expect that type of attention head to show.
    The detection pattern can be also passed not as a tensor, but as a name of one of pre-specified types of attention head (see `HeadName` for available patterns), in which case the tensor is computed within the function itself.

    There are two error measures available for quantifying the match between the detection pattern and the actual attention pattern.

    1. `"mul"` (default) multiplies both tensors element-wise and divides the sum of the result by the sum of the attention pattern.
    Typically, the detection pattern should in this case contain only ones and zeros, which allows a straightforward interpretation of the score:
    what fraction of this head's attention is allocated to these specific query-key pairs?
    Using values other than 0 or 1 is not prohibited but will raise a warning (which can be disabled, of course).
    2. `"abs"` calculates the mean element-wise absolute difference between the detection pattern and the actual attention pattern.
    The "raw result" ranges from 0 to 2 where lower score corresponds to greater accuracy. Subtracting it from 1 maps that range to (-1, 1) interval,
    with 1 being perfect match and -1 perfect mismatch.

    **Which one should you use?** `"mul"` is likely better for quick or exploratory investigations. For precise examinations where you're trying to
    reproduce as much functionality as possible or really test your understanding of the attention head, you probably want to switch to `"abs"`.

    The advantage of `"abs"` is that you can make more precise predictions, and have that measured in the score.
    You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and your score will be better if your prediction is closer.
    The "mul" metric does not allow this, you'll get the same score if attention is 0.2, 0.8 or 0.5, 0.5 or 0.8, 0.2.

    Args:
    ----------
        model: Model being used.
        seq: String or list of strings being fed to the model.
        detection_pattern: (sequence_length, sequence_length) Tensor representing what attention pattern corresponds to the head we're looking for **or** the name of a pre-specified head. Currently available heads are: `["previous_token_head", "duplicate_token_head", "induction_head"]`.
        heads: If specific attention heads are given here, all other heads' scores are set to -1. Useful for IOI-style circuit analysis. Heads can be specified as a list of tuples (layer, head) or a dictionary mapping a layer to the heads within that layer that we want to analyze.
        cache: Include the cache to save time if you want.
        exclude_bos: Exclude attention paid to the beginning of sequence token.
        exclude_current_token: Exclude attention paid to the current token.
        error_measure: `"mul"` for using element-wise multiplication (default). `"abs"` for using absolute values of element-wise differences as the error measure.

    Returns:
    ----------
    A (n_layers, n_heads) Tensor representing the score for each attention head.

    Example:
    --------
    .. code-block:: python

        >>> from transformer_lens import HookedTransformer,  utils
        >>> from transformer_lens.head_detector import detect_head
        >>> import plotly.express as px

        >>> def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
        >>>     px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

        >>> model = HookedTransformer.from_pretrained("gpt2-small")
        >>> sequence = "This is a test sequence. This is a test sequence."

        >>> attention_score = detect_head(model, sequence, "previous_token_head")
        >>> imshow(attention_score, zmin=-1, zmax=1, xaxis="Head", yaxis="Layer", title="Previous Head Matches")
    """

    cfg = model.cfg
    tokens = model.to_tokens(seq).to(cfg.device)
    seq_len = tokens.shape[-1]
    
    # Validate error_measure
    
    assert error_measure in get_args(ErrorMeasure), f"Invalid {error_measure=}; valid values are {get_args(ErrorMeasure)}"

    # Validate detection pattern if it's a string
    if isinstance(detection_pattern, str):
        assert detection_pattern in HEAD_NAMES, (
            INVALID_HEAD_NAME_ERR % detection_pattern
        )
        if isinstance(seq, list):
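            # Score each prompt on its own, forwarding the scoring options, then average.
            kwargs = dict(heads=heads, exclude_bos=exclude_bos, exclude_current_token=exclude_current_token, error_measure=error_measure)
            batch_scores = [detect_head(model, single_seq, detection_pattern, **kwargs) for single_seq in seq]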
            return torch.stack(batch_scores).mean(0)
        detection_pattern = cast(
            torch.Tensor,
            eval(f"get_{detection_pattern}_detection_pattern(tokens.cpu())"),
        ).to(cfg.device)

    # if we're using "mul", detection_pattern should consist of zeros and ones
    if error_measure == "mul" and not set(detection_pattern.unique().tolist()).issubset(
        {0, 1}
    ):
        logging.warning(
            "Using detection pattern with values other than 0 or 1 with error_measure 'mul'"
        )

    # Validate inputs and detection pattern shape
    assert 1 < tokens.shape[-1] < cfg.n_ctx, SEQ_LEN_ERR
    assert (
        is_lower_triangular(detection_pattern) and seq_len == detection_pattern.shape[0]
    ), DET_PAT_NOT_SQUARE_ERR % (seq_len, detection_pattern.shape)

    if cache is None:
        _, cache = model.run_with_cache(tokens, remove_batch_dim=True)

    if heads is None:
        layer2heads = {
            layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers)
        }
    elif isinstance(heads, list):
        layer2heads = defaultdict(list)
        for layer, head in heads:
            layer2heads[layer].append(head)
    else:
        layer2heads = heads

    matches = -torch.ones(cfg.n_layers, cfg.n_heads)

    for layer, layer_heads in layer2heads.items():
        # [n_heads q_pos k_pos]
        layer_attention_patterns = cache["pattern", layer, "attn"]
        for head in layer_heads:
            head_attention_pattern = layer_attention_patterns[head, :, :]
            head_score = compute_head_attention_similarity_score(
                head_attention_pattern,
                detection_pattern=detection_pattern,
                exclude_bos=exclude_bos,
                exclude_current_token=exclude_current_token,
                error_measure=error_measure,
            )
            matches[layer, head] = head_score
    return matches


# Previous token head
def get_previous_token_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Outputs a detection pattern for [previous token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=0O5VOHe9xeZn8Ertywkh7ioc).

    Args:
      tokens: Tokens being fed to the model.
    """
    detection_pattern = torch.zeros(tokens.shape[-1], tokens.shape[-1])
    # Adds a diagonal of 1's below the main diagonal.
    detection_pattern[1:, :-1] = torch.eye(tokens.shape[-1] - 1)
    return torch.tril(detection_pattern)


# Duplicate token head
def get_duplicate_token_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Outputs a detection pattern for [duplicate token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=2UkvedzOnghL5UHUgVhROxeo).

    Args:
      tokens: Tokens being fed to the model.
    """
    # [pos x pos]
    token_pattern = tokens.repeat(tokens.shape[-1], 1).numpy()

    # If token_pattern[i][j] matches its transpose, then token j and token i are duplicates.
    eq_mask = np.equal(token_pattern, token_pattern.T).astype(int)

    np.fill_diagonal(
        eq_mask, 0
    )  # Current token is always a duplicate of itself. Ignore that.
    detection_pattern = eq_mask.astype(int)
    return torch.tril(torch.as_tensor(detection_pattern).float())


# Induction head
def get_induction_head_detection_pattern(
    tokens: torch.Tensor,  # [batch (1) x pos]
) -> torch.Tensor:
    """Outputs a detection pattern for [induction heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_tFVuP5csv5ORIthmqwj0gSY).

    Args:
      tokens: Tokens being fed to the model.
    """
    duplicate_pattern = get_duplicate_token_head_detection_pattern(tokens)

    # Shift all items one to the right
    shifted_tensor = torch.roll(duplicate_pattern, shifts=1, dims=1)

    # Replace first column with 0's
    # we don't care about bos but shifting to the right moves the last column to the first,
    # and the last column might contain non-zero values.
    zeros_column = torch.zeros(duplicate_pattern.shape[0], 1)
    result_tensor = torch.cat((zeros_column, shifted_tensor[:, 1:]), dim=1)
    return torch.tril(result_tensor)


def get_supported_heads() -> None:
    """Prints the list of supported heads."""
    print(f"Supported heads: {HEAD_NAMES}")


def compute_head_attention_similarity_score(
    attention_pattern: torch.Tensor,  # [q_pos k_pos]
    detection_pattern: torch.Tensor,  # [seq_len seq_len] (seq_len == q_pos == k_pos)
    *,
    exclude_bos: bool,
    exclude_current_token: bool,
    error_measure: ErrorMeasure,
) -> float:
    """Compute the similarity between `attention_pattern` and `detection_pattern`.

    Args:
      attention_pattern: Lower triangular matrix (Tensor) representing the attention pattern of a particular attention head.
      detection_pattern: Lower triangular matrix (Tensor) representing the attention pattern we are looking for.
      exclude_bos: `True` if the beginning-of-sentence (BOS) token should be omitted from comparison. `False` otherwise.
      exclude_current_token: `True` if the current token at each position should be omitted from comparison. `False` otherwise.
      error_measure: "abs" for using absolute values of element-wise differences as the error measure. "mul" for using element-wise multiplication (legacy code).
    """
    assert is_square(
        attention_pattern
    ), f"Attention pattern is not square; got shape {attention_pattern.shape}"

    # mul

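    if error_measure == "mul":
        # Copy first: the masking below is in-place and would otherwise modify the cached pattern.
        attention_pattern = attention_pattern.clone()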
        if exclude_bos:
            attention_pattern[:, 0] = 0
        if exclude_current_token:
            attention_pattern.fill_diagonal_(0)
        score = attention_pattern * detection_pattern
        return (score.sum() / attention_pattern.sum()).item()

    # abs

    abs_diff = (attention_pattern - detection_pattern).abs()
    assert (abs_diff - torch.tril(abs_diff).to(abs_diff.device)).sum() == 0

    size = len(abs_diff)
    if exclude_bos:
        abs_diff[:, 0] = 0
    if exclude_current_token:
        abs_diff.fill_diagonal_(0)

    return 1 - round((abs_diff.mean() * size).item(), 3)


## Using Head Detector For Premade Heads



Load the model

In [8]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


See what heads are supported out of the box

In [9]:
get_supported_heads()

Supported heads: ('previous_token_head', 'duplicate_token_head', 'induction_head')
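
To make the three pre-defined patterns concrete, here is a small sketch that builds them for a toy sequence of token ids. It mirrors the logic of the `get_*_detection_pattern` functions above but uses broadcasting instead of NumPy; the token ids are made up, and position 0 merely stands in for the BOS token.

```python
import torch

# Toy token ids (an assumption for illustration): positions 3 and 4 repeat positions 1 and 2.
token_ids = torch.tensor([0, 5, 7, 5, 7])
seq_len = token_ids.shape[0]

# Previous token head: each position attends to the position right before it.
previous = torch.zeros(seq_len, seq_len)
previous[1:, :-1] = torch.eye(seq_len - 1)

# Duplicate token head: entry (i, j) is 1 when an earlier position j holds the same token as i.
same_token = (token_ids[:, None] == token_ids[None, :]).float()
duplicate = torch.tril(same_token, diagonal=-1)  # diagonal=-1 drops the trivial self-match

# Induction head: attend to the token right AFTER an earlier occurrence of the current token,
# i.e. the duplicate pattern shifted one column to the right.
induction = torch.zeros(seq_len, seq_len)
induction[:, 1:] = duplicate[:, :-1]
induction = torch.tril(induction)

print(duplicate.nonzero().tolist())  # (query, key) pairs of the duplicate pattern
print(induction.nonzero().tolist())  # (query, key) pairs of the induction pattern
```

For this toy sequence the duplicate pattern marks the pairs (3, 1) and (4, 2), while the induction pattern marks (3, 2) and (4, 3).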


Let's test detecting previous token head in the following prompt.

In [10]:
prompt = "The head detector feature for TransformerLens allows users to check for various common heads automatically, reducing the cost of discovery."
head_scores = detect_head(model, prompt, "previous_token_head")
plot_head_detection_scores(head_scores, title="Previous Head Matches")

We can see both L2H2 and L4H11 are doing a fair bit of previous token detection. Let's take a look and see if that pans out.

In [11]:
_, cache = model.run_with_cache(prompt)

In [12]:
plot_attn_pattern_from_cache(cache, 2)

Layer 2 Attention Heads:


In [13]:
plot_attn_pattern_from_cache(cache, 4)

Layer 4 Attention Heads:


As we expected, L2H2 is doing a lot of previous token detection, but doesn't appear to be a sharp previous token detection head. L4H11, on the other hand, is pretty much perfect. In fact, the only place it seems to be putting any other attention is the very first token, where it pays attention to the BOS (*beginning-of-sentence*) token.

Mechanistic interpretability is still a very new field, and we don't know the best ways to measure things yet. Ignoring attention paid to BOS allows us to solve problems like the above, but may also give us artificially high results for a head like L4H10, which doesn't appear to be doing much of anything, but does have a bit of previous token attention going on if you squint carefully.

As such, the head detector supports both an `exclude_bos` and an `exclude_current_token` argument, which ignore all BOS attention and all current token attention respectively. By default these are `False`, but this is a pretty arbitrary decision, so feel free to try things out! You don't need a good reason to change these arguments - pick whatever best helps you find out useful things!

In [14]:
head_scores = detect_head(model, prompt, "previous_token_head", exclude_bos=True, exclude_current_token=True)
plot_head_detection_scores(head_scores, title="Previous Head Matches")

Now we have a lot more detection, including L0H3 and L5H6 which were unremarkable before. Let's check them out!

In [15]:
plot_attn_pattern_from_cache(cache, 5)

Layer 5 Attention Heads:


In [16]:
plot_attn_pattern_from_cache(cache, 0)

Layer 0 Attention Heads:


Here, we see some interesting results. L5H6 does very little, but happens to react quite strongly to the first token of "Trans|former". (Capital letters? Current word detection? We don't know)

L0H3 reacts almost entirely to the current token, but what little it does outside of this pays attention to the previous token. Again, it seems to be caring about the first token of "Trans|former".

In order to more fully automate these heads, we'll need to discover more principled ways of expressing these scores. For now, you can see how while scores may be misleading, different scores lead us to interesting results.

## Using Head Detector for Custom Heads

These heads are great, but sometimes there are more than three things going on in Transformers. [citation needed] As a result, we may want to use our head detector for things that aren't pre-included in TransformerLens. Fortunately, the head detector provides support for this, via **detection patterns**.


A detection pattern is simply a matrix of the same size as our attention pattern, which specifies the attention pattern exhibited by the kind of head we're looking for.

There are two error measures available for quantifying the match between the detection pattern and the actual attention pattern. You can choose it by passing the right value to the `error_measure` argument.



### 1. `"mul"` (default) multiplies both tensors element-wise and divides the sum of the result by the sum of the attention pattern.

Typically, the detection pattern should in this case contain only ones and zeros, which allows a straightforward interpretation of the score: what fraction of this head's attention is allocated to these specific query-key pairs? Using values other than 0 or 1 is not prohibited but will log a warning (which can be disabled, of course).

<br>

$$
\begin{pmatrix}
1 & 0 & 0 & 0 \\
0.5 & 0.5 & 0 & 0 \\
0.2 & 0.3 & 0.5 & 0 \\
0.1 & 0.15 & 0.5 & 0.25
\end{pmatrix}
\odot
\begin{pmatrix}
0 & 0 & 0 & 0 \\
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & 1 & 0
\end{pmatrix}
=
\begin{pmatrix}
0 & 0 & 0 & 0 \\
0.5 & 0 & 0 & 0 \\
0 & 0.3 & 0 & 0 \\
0 & 0 & 0.5 & 0
\end{pmatrix}
$$

<br>

0.5, 0.3, and 0.5 all get multiplied by 1, so they get kept. All the others go to 0 and are removed. (Note: You can use values other than 0 or 1 when creating your own heads)

Our total score would then be 1.3 / 4, or 0.325. If we ignore bos and current token, it would be 0.8 / 0.95 instead, or ~0.842. (This is a large difference, but the difference generally gets smaller as the matrices get bigger)
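
The arithmetic above can be checked with a few lines of PyTorch. This is a sketch that follows the `"mul"` branch of `compute_head_attention_similarity_score`; the helper name `mul_score` is ours, not the library's.

```python
import torch

attention = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.2, 0.3, 0.5, 0.0],
    [0.1, 0.15, 0.5, 0.25],
])
detection = torch.tensor([
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
])

def mul_score(attention, detection, exclude_bos=False, exclude_current_token=False):
    """Fraction of the (optionally masked) attention that lands on the detection pattern."""
    attention = attention.clone()  # keep the caller's tensor untouched
    if exclude_bos:
        attention[:, 0] = 0  # drop attention paid to position 0 (BOS)
    if exclude_current_token:
        attention.fill_diagonal_(0)  # drop attention paid to the current token
    return ((attention * detection).sum() / attention.sum()).item()

plain = mul_score(attention, detection)
masked = mul_score(attention, detection, exclude_bos=True, exclude_current_token=True)
print(f"{plain:.3f} {masked:.3f}")  # 0.325 0.842
```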

This is how the head detector works under the hood - each existing head just has its own detection pattern. Thus, we can pass in our own detection pattern using the `detection_pattern` argument.


### 2. `"abs"` calculates the mean element-wise absolute difference between the detection pattern and the actual attention pattern.

The "raw result" ranges from 0 to 2 where lower score corresponds to greater accuracy. Subtracting it from 1 maps that range to (-1, 1) interval, with 1 being perfect match and -1 perfect mismatch.

We take the attention pattern and compute its absolute element-wise difference with our detection pattern. Since every number in any of the two patterns has a value between -1 and 1, the maximum absolute difference of any pair is 2 and the minimum is 0:

$$|-1-1|=|1-(-1)|=2$$

$$|x-x|=0$$

That number tells us how much our expectation and the real attention pattern diverge, i.e., the error.

$$
M_{diff}=
\left|
\begin{pmatrix}
1 & 0 & 0 & 0
\\
0.5 & 0.5 & 0 & 0 
\\
0.2 & 0.3 & 0.5 & 0 
\\
0.1 & 0.15 & 0.5 & 0.25 
\end{pmatrix}
-
\begin{pmatrix}
0 & 0 & 0 & 0
\\
1 & 0 & 0 & 0 
\\
0 & 1 & 0 & 0 
\\
0 & 0 & 1 & 0 
\end{pmatrix}
\right|
=
\begin{pmatrix}
1 & 0 & 0 & 0
\\
0.5 & 0.5 & 0 & 0 
\\
0.2 & 0.7 & 0.5 & 0
\\
0.1 & 0.15 & 0.5 & 0.25 
\end{pmatrix}
$$


We take the mean and multiply it by the number of rows.

We subtract the result from 1 in order to map the (0, 2) interval where lower is better to the (-1, 1) interval where higher is better.

$$1 - \text{n_rows} \times \text{mean}(M_{diff}) = 1 - 4 \times 0.275 = 1 - 1.1 = -.1$$

Our final score would then be -0.1. If we ignore `BOS` and current token, it would be 0.6625. (This is a large difference, but the difference generally gets smaller as the matrices get bigger.)
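
The same toy matrices can be pushed through the `"abs"` computation to confirm these numbers. This is a sketch following the `"abs"` branch of `compute_head_attention_similarity_score` (minus the rounding step); the helper name `abs_score` is ours.

```python
import torch

attention = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.2, 0.3, 0.5, 0.0],
    [0.1, 0.15, 0.5, 0.25],
])
detection = torch.tensor([
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
])

def abs_score(attention, detection, exclude_bos=False, exclude_current_token=False):
    """1 minus the row-normalised sum of element-wise absolute differences."""
    abs_diff = (attention - detection).abs()  # new tensor, so in-place masking is safe
    n_rows = abs_diff.shape[0]
    if exclude_bos:
        abs_diff[:, 0] = 0  # ignore errors in the BOS column
    if exclude_current_token:
        abs_diff.fill_diagonal_(0)  # ignore errors on the current token
    return 1 - (abs_diff.mean() * n_rows).item()

plain = abs_score(attention, detection)
masked = abs_score(attention, detection, exclude_bos=True, exclude_current_token=True)
print(f"{plain:.4f} {masked:.4f}")  # -0.1000 0.6625
```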

I'm curious what's going on with this L0H3 result, where we mostly focus on the current token but occasionally focus on the "Trans" token in "Trans|former". Let's make a **current word head** detection pattern, which returns 1 for previous tokens that are part of the current word being looked at, and 0 for everything else.

### **Which one should you use?**

`"mul"` is likely better for quick or exploratory investigations. For precise examinations where you're trying to reproduce as much functionality as possible or really test your understanding of the attention head, you probably want to switch to `"abs"`.

The advantage of `"abs"` is that you can make more precise predictions, and have that measured in the score. You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and your score will be better if your prediction is closer. The `"mul"` metric does not allow this: you'll get the same score whether the attention is split 0.2/0.8, 0.5/0.5, or 0.8/0.2.
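
A toy comparison makes this difference visible. In the sketch below (all names and values are illustrative assumptions), the last query position splits its attention between two earlier tokens X and Y in three different ways. The binary `"mul"` pattern cannot tell the splits apart, while a graded `"abs"` pattern predicting 0.2/0.8 rewards the closest match.

```python
import torch

def mul_score(attention, detection):
    return ((attention * detection).sum() / attention.sum()).item()

def abs_score(attention, detection):
    abs_diff = (attention - detection).abs()
    return 1 - (abs_diff.mean() * abs_diff.shape[0]).item()

def attention_with_last_row(to_x, to_y):
    # Rows 0 and 1 attend fully to position 0; row 2 splits its attention between X and Y.
    return torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [to_x, to_y, 0.0]])

# "mul" pattern: binary, it only says WHERE attention should go.
binary_pattern = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
# "abs" pattern: graded, it also says HOW MUCH attention each position should get.
graded_pattern = attention_with_last_row(0.2, 0.8)

splits = [(0.2, 0.8), (0.5, 0.5), (0.8, 0.2)]
mul_scores = [mul_score(attention_with_last_row(x, y), binary_pattern) for x, y in splits]
abs_scores = [abs_score(attention_with_last_row(x, y), graded_pattern) for x, y in splits]

print([f"{s:.3f}" for s in mul_scores])  # ['1.000', '1.000', '1.000']
print([f"{s:.3f}" for s in abs_scores])  # ['1.000', '0.800', '0.600']
```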

Below we show how differently these two measures can score the same prompt. After that, we will proceed with `"abs"` and get back to `"mul"` at the end of the notebook.

In [17]:
prompt = "The following lexical sequence has been optimised for the maximisation of loquaciously multitoken letter combinations."
tokens = model.to_str_tokens(prompt)
print(len(tokens), tokens)
detection_pattern = []
for i in range(2):
    detection_pattern.append([0 for t in tokens])  # Ignore BOS token and first token.
for i in range(2, len(tokens)):
    current_token = i
    previous_tokens_in_word = 0
    # A token that does not start with a space continues the current word, so walk back to the word's first token.
    while not tokens[current_token].startswith(' '):
        previous_tokens_in_word += 1
        current_token -= 1
    # Hacky code that adds in some 1's where needed, and fills the rest of the row with 0's.
    detection_pattern.append([0 for j in range(i - previous_tokens_in_word)] + [1 for j in range(previous_tokens_in_word)] + [0 for j in range(i+1, len(tokens)+1)])
detection_pattern = torch.as_tensor(detection_pattern).to(device)
detection_pattern.shape

23 ['<|endoftext|>', 'The', ' following', ' lex', 'ical', ' sequence', ' has', ' been', ' optim', 'ised', ' for', ' the', ' maxim', 'isation', ' of', ' lo', 'qu', 'aciously', ' multit', 'oken', ' letter', ' combinations', '.']


torch.Size([23, 23])

In [18]:
_, cache = model.run_with_cache(prompt)

`"mul"`

In [19]:
head_scores = detect_head(
    model, 
    prompt, 
    detection_pattern=detection_pattern, 
    exclude_bos=False, 
    exclude_current_token=True, 
    error_measure="mul"
)
plot_head_detection_scores(head_scores, title="Current Word Head Matches (mul)")

`"abs"`

In [20]:
head_scores = detect_head(
    model, 
    prompt, 
    detection_pattern=detection_pattern, 
    exclude_bos=False, 
    exclude_current_token=True, 
    error_measure="abs"
)
plot_head_detection_scores(head_scores, title="Current Word Head Matches (abs)")

75% match for L0H3 - only 16% for L5H6. Let's check them out with our new sequence!

In [21]:
plot_attn_pattern_from_cache(cache, 5)

Layer 5 Attention Heads:


In [22]:
plot_attn_pattern_from_cache(cache, 0)

Layer 0 Attention Heads:


As we can see, L5H6 appears to be doing something totally different than we expected, whereas L0H3 is mostly doing what we expected. By our original hypothesis, we would expect "lo|qu|aciously" to receive a lot of attention, and likewise "combinations|.", which didn't happen. However, our two-token words were exactly as we expected. Could this be a two-token detector (that doesn't work on punctuation)? A "current word" detector that just doesn't understand an obscure word like "loquaciously"? The field is full of such problems, just waiting to be answered!

So, why do this at all? For just a couple of sentences, it's easier to just look at the attention patterns directly and see what we get. But as we can see, heads react differently to different sentences. What we might want to do is give an entire dataset or distribution of sentences to our attention head and see that it consistently does what we want - that's something that would be much harder without this feature!

So what if we gave it a whole distribution? Rather than actually create one, which is not the point of this demo, we're just going to repeat our last sentence a hundred times.

In [23]:
scores = []
for i in tqdm(range(100)):
    scores.append(detect_head(model, prompt, detection_pattern=detection_pattern, exclude_bos=False, exclude_current_token=True, error_measure="abs"))
scores = torch.stack(scores).mean(dim=0)
plot_head_detection_scores(scores, title="Current Word Head Matches")

100%|██████████| 100/100 [00:13<00:00,  7.64it/s]


## Processing Many Prompts

`detect_head` can also take more than one prompt. The resulting attention score is the mean of scores for each prompt.
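
The averaging step itself is simple: for the named patterns, `detect_head` scores each prompt separately (each prompt gets a detection pattern matching its own length) and then takes the mean over prompts. A minimal sketch with made-up per-prompt scores for a hypothetical model with 2 layers and 3 heads:

```python
import torch

# Hypothetical per-prompt score tensors, each of shape (n_layers, n_heads) = (2, 3).
per_prompt_scores = [
    torch.tensor([[0.1, 0.9, 0.0], [0.2, 0.4, 0.6]]),
    torch.tensor([[0.3, 0.7, 0.0], [0.0, 0.4, 0.2]]),
]

# Stack to (n_prompts, n_layers, n_heads) and average over the prompt dimension.
mean_scores = torch.stack(per_prompt_scores).mean(dim=0)
print(mean_scores)
```

Note that every prompt carries the same weight in this mean, regardless of its length.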

In [24]:
prompts = [
    "This is the first the test prompt.",
    "This is another test prompt, being just a sequence of tokens.",
    "If you're interested in mechanistic interpretability, this is how the sausage REALLY is made."
]

In [25]:
head_scores = detect_head(model, prompts, "previous_token_head", error_measure="abs")
plot_head_detection_scores(head_scores, title="Previous token head; average across 3 prompts")

L4H11 emerges again as the dominant head, exactly as expected.

What about duplicate token heads?

In [26]:
head_scores = detect_head(model, prompts, "duplicate_token_head", error_measure="abs")
plot_head_detection_scores(head_scores, title="Duplicate token head; average across 3 prompts")

Nothing stands out, but in hindsight this should be expected, since our prompts contain few duplicate tokens. Let's try three other prompts that contain many.

In [27]:
prompts = [
    "one two three one two three one two three",
    "1 2 3 4 5 1 2 3 4 1 2 3 1 2 3 4 5 6 7",
    "green ideas sleep furiously; green ideas don't sleep furiously"
]

In [28]:
head_scores = detect_head(model, prompts, "duplicate_token_head", exclude_bos=False, exclude_current_token=False, error_measure="abs")
plot_head_detection_scores(head_scores, title="Duplicate token head; average across 3 prompts")

3 or 4 heads seem to do something that we would expect from a duplicate token head, but the signal is not very strong. You can tweak the `exclude_bos` and `exclude_current_token` flags if you want, but it doesn't change much.

Let's hunt for induction heads now!

In [34]:
head_scores = detect_head(model, prompts, "induction_head", exclude_bos=False, exclude_current_token=False, error_measure="abs")
plot_head_detection_scores(head_scores, title="Induction head; average across 3 prompts")

The result is similar: no head stands out strongly, at least on average.

Try running the script on different prompts and see if you can get high values for duplicate token or induction heads.

## Why not element-wise multiplication - robustness against [Goodharting](https://en.wikipedia.org/wiki/Goodhart%27s_law)

Initially, the error measure was not the mean element-wise absolute value error (normalized to the number of rows) but the mean [element-wise product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)). However, it had its problems, such as susceptibility to Goodharting. You can specify a pattern consisting of all ones and in this way achieve a perfect match for all layers and heads in the model.

More generally, using element-wise product causes the score to go down when we narrow our hypothesis. We can get a maximum score by just predicting 1 for everything. 
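
The effect is easy to reproduce without a model. The sketch below builds a random causal attention pattern (an assumption standing in for a real head) and scores it against an all-ones lower triangular pattern with both measures.

```python
import torch

torch.manual_seed(0)
seq_len = 6

# Random causal attention: mask out future positions, then softmax each row.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
logits = torch.randn(seq_len, seq_len)
attention = torch.softmax(logits.masked_fill(~causal_mask, float("-inf")), dim=-1)

# The "predict 1 everywhere" pattern.
ones_pattern = torch.tril(torch.ones(seq_len, seq_len))

# "mul": all attention falls on the pattern, so the ratio is 1 for ANY head.
mul = ((attention * ones_pattern).sum() / attention.sum()).item()

# "abs": every entry is penalised for being below 1, so the score is strongly negative.
abs_diff = (attention - ones_pattern).abs()
abs_ = 1 - (abs_diff.mean() * seq_len).item()

print(f"{mul:.3f} {abs_:.3f}")  # 1.000 -1.500
```

Because each attention row sums to 1, the `"abs"` score against this pattern works out to `1 - (seq_len - 1) / 2` when nothing is excluded, which also shows how it can fall below -1 for longer sequences.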

In [30]:
prompt = "The head detector feature for TransformerLens allows users to check for various common heads automatically, reducing the cost of discovery."
seq_len = len(model.to_str_tokens(prompt))
# torch.tril to make the pattern lower triangular
ones_detection_pattern = torch.tril(torch.ones(seq_len, seq_len).to(device))

In [31]:
ones_head_scores = detect_head(
    model, 
    prompt, 
    ones_detection_pattern, 
    exclude_bos=True, 
    exclude_current_token=True, 
)
plot_head_detection_scores(ones_head_scores, title="Transformers Have Now Been Solved, We Can All Go Home")

The new error measure also yields a uniform score, but this time it is uniformly and extremely negative because **not a single head in the model matches this pattern**.

*(It's true that the scores descend below -9 whereas in theory they should remain within the (-1, 1) range. It's not yet clear if that matters for real-world uses.)*

An alternative would be to demand that *predictions add up to 1 for each row* but that seems unnecessarily nitpicky considering that your score will get reduced in general for not doing that anyway.

Mean squared errors were also tried before we converged on the absolute ones. The problem with MSE is that the scores get lower as attention gets more diffuse. An error value of 1 would stay 1, 0.5 would become 0.25, etc.
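
A small numeric sketch of that squaring effect, using made-up error values: the same total absolute error counts for less under squaring once it is spread over more entries.

```python
import torch

# Two hypothetical rows of per-entry errors with the same total absolute error (1.0).
concentrated = torch.tensor([1.0, 0.0, 0.0, 0.0])  # one entry is completely wrong
diffuse = torch.tensor([0.25, 0.25, 0.25, 0.25])   # every entry is slightly wrong

abs_concentrated = concentrated.abs().sum().item()
abs_diffuse = diffuse.abs().sum().item()
sq_concentrated = concentrated.pow(2).sum().item()
sq_diffuse = diffuse.pow(2).sum().item()

print(abs_concentrated, abs_diffuse)  # 1.0 1.0
print(sq_concentrated, sq_diffuse)    # 1.0 0.25
```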

In [32]:
ones_head_scores = detect_head(
    model, 
    prompt, 
    ones_detection_pattern, 
    exclude_bos=True, 
    exclude_current_token=True, 
    error_measure="abs" # we specify the error measure here
)
plot_head_detection_scores(ones_head_scores, title="Transformers Have Not Been Solved Yet, Get Back To Work!")

## Further improvements

**Performance for large distributions** isn't as good as it could be. The head detector could be rewritten to support taking in a list of sequences and performing these computations in parallel, but 1000 sequences per minute is certainly adequate for most use cases. If having this be faster would help your research, please write up an issue on TransformerLens, mention it on the Open Source Mechanistic Interpretability Slack, or e-mail jaybaileycs@gmail.com.
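
As a rough sketch of what a parallel version could look like (not the library's implementation), the `"mul"` score can be computed for a whole batch of sequences and all heads of a layer in one tensor expression. The sketch assumes that all sequences have the same length, so no padding has to be handled, and it uses random causal attention in place of real model activations.

```python
import torch

torch.manual_seed(0)
batch, n_heads, seq_len = 3, 4, 5

# Stand-in for cached attention patterns of one layer: [batch, n_heads, q_pos, k_pos].
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
logits = torch.randn(batch, n_heads, seq_len, seq_len)
patterns = torch.softmax(logits.masked_fill(~causal_mask, float("-inf")), dim=-1)

# One detection pattern per sequence: [batch, q_pos, k_pos] (previous-token pattern here).
previous = torch.zeros(seq_len, seq_len)
previous[1:, :-1] = torch.eye(seq_len - 1)
detections = previous.expand(batch, -1, -1)

# Vectorised "mul" score: broadcast the detection pattern over the head dimension.
vectorised = (patterns * detections[:, None]).sum(dim=(-2, -1)) / patterns.sum(dim=(-2, -1))
mean_scores = vectorised.mean(dim=0)  # average over sequences -> [n_heads]

# Reference: the same scores computed one sequence and one head at a time.
looped = torch.stack([
    torch.stack([
        (patterns[b, h] * detections[b]).sum() / patterns[b, h].sum()
        for h in range(n_heads)
    ])
    for b in range(batch)
])
print(torch.allclose(vectorised, looped))
```

Handling prompts of different lengths would additionally require masking out padded positions, which this sketch deliberately leaves out.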

### Other

- Extending to few-shot learning/translation heads
- More pre-specified heads?
- For inspiration, see [this post from Neel](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj/p/btasQF7wiCYPsr5qw)