# Path Patching & Activation Patching

This notebook contains my implementation of **path patching**. Eventually I'll add it to TransformerLens as a demo, but for now it lives in a personal repo. I expect it to be useful for the research sprint I'm about to start for Neel Nanda's SERI MATS stream.

***This code is designed to be low-effort for mechanistic interpretability researchers doing exploratory analysis***. It's not been optimized for e.g. fitting into larger systems like ACDC. I've tried to follow the general "research as play" philosophy of TransformerLens.

You can find the public GitHub repo [here](https://github.com/callummcdougall/path_patching). The code below uses code directly from that repo.

### Contents

* **Setup code** - run this to `wget` the repo, import libs, and define datasets.
* **Activation Patching** - introduces the main `act_patch` function, and shows how to use it (with 2 examples).
* **Path Patching** - introduces the main `path_patch` function, and shows how to use it (with 4 examples).
* **Appendix** - explains what path patching is, and includes some thoughts about the pros/cons of my implementation.

*Note - I'd recommend you focus more on the examples than on the explanations that precede them, because I expect staring at these to be a faster way of understanding how the library works.*

Please comment or send me a message if you have any feedback. I hope you find this useful!

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/pathpatching.png" width="320">

# Setup code (don't read, just run)

When you run this, it'll download the `path_patching.py` file from my repo and put it in local storage; you'll be able to see it in the Files menu on the left.

In [None]:
from transformer_lens.cautils.utils import *

# from ioi_dataset import NAMES, IOIDataset
# from path_patching import Node, IterNode, path_patch, act_patch

device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

update_layout_set = {
    "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat",
    "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor",
    "showlegend", "xaxis_tickmode", "yaxis_tickmode", "xaxis_tickangle", "yaxis_tickangle", "margin", "xaxis_visible", "yaxis_visible"
}

def imshow(tensor, renderer=None, **kwargs):
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    facet_labels = kwargs_pre.pop("facet_labels", None)
    border = kwargs_pre.pop("border", False)
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, **kwargs_pre)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    if border:
        fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
        fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
    # Settings like `xaxis_tickangle` should apply to all subplots, so copy them to xaxis2, xaxis3, ... (janky, but I'm under time pressure)
    for setting in ["tickangle"]:
        if f"xaxis_{setting}" in kwargs_post:
            i = 2
            while f"xaxis{i}" in fig["layout"]:
                kwargs_post[f"xaxis{i}_{setting}"] = kwargs_post[f"xaxis_{setting}"]
                i += 1
    fig.update_layout(**kwargs_post)
    fig.show(renderer=renderer)

def hist(tensor, renderer=None, **kwargs):
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    names = kwargs_pre.pop("names", None)
    if "barmode" not in kwargs_post:
        kwargs_post["barmode"] = "overlay"
    if "bargap" not in kwargs_post:
        kwargs_post["bargap"] = 0.0
    if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
    fig = px.histogram(x=tensor, **kwargs_pre).update_layout(**kwargs_post)
    if names is not None:
        for i in range(len(fig.data)):
            fig.data[i]["name"] = names[i // 2]
    fig.show(renderer=renderer)

In [None]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
model.set_use_split_qkv_input(True)

Neel's data (and metric functions):

In [None]:
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
name_pairs = [
    (" John", " Mary"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# Define 8 prompts, in 4 groups of 2 (with adjacent prompts having answers swapped)
prompts = [
    prompt.format(name)
    for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]
]
# Define the answers for each prompt, in the form (correct, incorrect)
answers = [names[::i] for names in name_pairs for i in (1, -1)]
# Define the answer tokens (same shape as the answers)
answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in answers
])

def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False
):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''
    # Only the final logits are relevant for the answer
    final_logits: Float[Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Get the logits corresponding to the indirect object / subject tokens respectively
    answer_logits: Float[Tensor, "batch 2"] = final_logits.gather(dim=-1, index=answer_tokens)
    # Find logit difference
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

clean_tokens = model.to_tokens(prompts, prepend_bos=True).to(device)
flipped_indices = [i+1 if i % 2 == 0 else i-1 for i in range(len(clean_tokens))]
flipped_tokens = clean_tokens[flipped_indices]

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
flipped_logits, flipped_cache = model.run_with_cache(flipped_tokens)

clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
flipped_logit_diff = logits_to_ave_logit_diff(flipped_logits, answer_tokens)

print(
    "Clean string 0:    ", model.to_string(clean_tokens[0]), "\n"
    "Flipped string 0:", model.to_string(flipped_tokens[0])
)
print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Flipped logit diff: {flipped_logit_diff:.4f}")

def ioi_metric_denoising(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch 2"] = answer_tokens,
    flipped_logit_diff: float = flipped_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> float:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is
    same as on flipped input, and 1 when performance is same as on clean input.
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    return ((patched_logit_diff - flipped_logit_diff) / (clean_logit_diff - flipped_logit_diff)).item()

labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

Below are the data (and metric functions) from the authors of the IOI paper.

> *Sorry for the hilariously low-tech way of getting a large number of datapoints. I just did it in a for loop cause I didn't want to break memory and I don't know any better ways cause I'm pretty naive when it comes to doing anything at scale (-:*


In [None]:
def _logits_to_ave_logit_diff(logits: Float[Tensor, "batch seq d_vocab"], ioi_dataset: IOIDataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()



def _ioi_metric_noising(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float,
    corrupted_logit_diff: float,
    ioi_dataset: IOIDataset,
) -> float:
    '''
    We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),
    and -1 when performance has been destroyed (i.e. is same as ABC dataset).
    '''
    patched_logit_diff = _logits_to_ave_logit_diff(logits, ioi_dataset)
    return ((patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff)).item()



def generate_data_and_caches(N: int, verbose: bool = False, seed: int = 42):

    ioi_dataset = IOIDataset(
        prompt_type="mixed",
        N=N,
        tokenizer=model.tokenizer,
        prepend_bos=False,
        seed=seed,
        device=str(device)
    )

    abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->XYZ, BAB->XYZ")

    model.reset_hooks(including_permanent=True)

    ioi_logits_original, ioi_cache = model.run_with_cache(ioi_dataset.toks)
    abc_logits_original, abc_cache = model.run_with_cache(abc_dataset.toks)

    ioi_average_logit_diff = _logits_to_ave_logit_diff(ioi_logits_original, ioi_dataset).item()
    abc_average_logit_diff = _logits_to_ave_logit_diff(abc_logits_original, ioi_dataset).item()

    if verbose:
        print(f"Average logit diff (IOI dataset): {ioi_average_logit_diff:.4f}")
        print(f"Average logit diff (ABC dataset): {abc_average_logit_diff:.4f}")

    ioi_metric_noising = partial(
        _ioi_metric_noising,
        clean_logit_diff=ioi_average_logit_diff,
        corrupted_logit_diff=abc_average_logit_diff,
        ioi_dataset=ioi_dataset,
    )

    return ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising



N = 30
ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric_noising = generate_data_and_caches(N, verbose=True)
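Both metric functions are affine rescalings of the logit difference, so their calibration can be checked at the endpoints. Below is a minimal plain-Python restatement (the function names and the logit-diff values are hypothetical, chosen only for illustration). It also shows that, for the same pair of endpoints, the noising value is always the denoising value minus 1.

```python
def denoising_metric(patched_ld: float, clean_ld: float, corrupted_ld: float) -> float:
    # 0 when patched_ld == corrupted_ld (nothing recovered), 1 when patched_ld == clean_ld
    return (patched_ld - corrupted_ld) / (clean_ld - corrupted_ld)


def noising_metric(patched_ld: float, clean_ld: float, corrupted_ld: float) -> float:
    # 0 when patched_ld == clean_ld (nothing harmed), -1 when patched_ld == corrupted_ld
    return (patched_ld - clean_ld) / (clean_ld - corrupted_ld)


# Hypothetical logit differences, for illustration only
clean_ld, corrupted_ld = 3.0, -1.0

for patched_ld in (3.0, 1.0, -1.0):
    d = denoising_metric(patched_ld, clean_ld, corrupted_ld)
    n = noising_metric(patched_ld, clean_ld, corrupted_ld)
    print(f"patched={patched_ld:+.1f}  denoising={d:+.2f}  noising={n:+.2f}")
```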

# Activation Patching

*I'd recommend skimming the explanation below, and then jumping into the examples. I expect that they'll be pretty self-explanatory, and reading them will help you understand how to use this library faster than reading all the details below.*

## `Node` and `IterNode` (at a high level)

The two main classes which are used to specify the patching operations are `Node` and `IterNode`.

`Node` allows you to specify an activation, or a particular slice of an activation (e.g. corresponding to a particular attention head in an attention layer, a particular neuron in an MLP, or a particular sequence position). Here are some examples:

```python
Node("q", layer=9, head=9)           # -> query input of head 9.9
Node("blocks.9.attn.hook_q", head=9) # -> same as above
Node("attn_out", 8)                  # -> output of all attention heads in layer 8
Node("post", 7, neuron=100)          # -> post-activation function output of neuron L7N100
Node("resid_pre", 1, seq_pos=0)      # -> first sequence position of all resid_pre activations in layer 1
```

Initialising a node works just like indexing into the cache (i.e. the first argument is the component name, and the optional second argument is the layer). The only difference is the extra ability to specify a head, neuron, or sequence position (all three default to `None`, meaning "patch over all indices").
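To make "just like indexing into the cache" concrete, here is a minimal sketch of how a short component name plus a layer could expand to a full TransformerLens hook name. `NodeSketch` and its name tables are hypothetical stand-ins (not the library's `Node` class); TransformerLens's own `utils.get_act_name` performs a similar expansion.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical short-name tables, modelled on TransformerLens hook names
_BLOCK_HOOKS = {"resid_pre", "resid_mid", "resid_post", "attn_out", "mlp_out"}
_ATTN_HOOKS = {"q", "k", "v", "z", "pattern", "attn_scores"}
_MLP_HOOKS = {"pre", "post"}


@dataclass
class NodeSketch:
    component: str
    layer: Optional[int] = None
    head: Optional[int] = None     # None = all heads
    neuron: Optional[int] = None   # None = all neurons
    seq_pos: Optional[int] = None  # None = all positions

    @property
    def hook_name(self) -> str:
        if self.component.startswith("blocks."):
            return self.component  # already a full hook name
        if self.layer is None:
            raise ValueError(f"layer is required for short name {self.component!r}")
        if self.component in _BLOCK_HOOKS:
            return f"blocks.{self.layer}.hook_{self.component}"
        if self.component in _ATTN_HOOKS:
            return f"blocks.{self.layer}.attn.hook_{self.component}"
        if self.component in _MLP_HOOKS:
            return f"blocks.{self.layer}.mlp.hook_{self.component}"
        raise KeyError(f"unknown component {self.component!r}")


print(NodeSketch("q", layer=9, head=9).hook_name)            # blocks.9.attn.hook_q
print(NodeSketch("blocks.9.attn.hook_q", head=9).hook_name)  # blocks.9.attn.hook_q
print(NodeSketch("post", 7, neuron=100).hook_name)           # blocks.7.mlp.hook_post
```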

Each instance of patching will be on a node or set of nodes.

`IterNode` allows you to iterate through multiple nodes at once. In theory you could just do this by defining a list of `Node` objects and iterating through them yourself, but doing it via the `IterNode` object is more convenient.

*Reading the examples below should help, if any of these design choices seem confusing right now.*

## How does the function work?

### Single activation patching

We perform a single instance of patching using the `Node` class, which was discussed above. A few more details:

`head` and `neuron` can be `None` (the default, meaning all heads/neurons), or an int, or a list of ints. `seq_pos` can take 4 possible values:

* `None`, meaning we patch at all sequence positions
* An integer, meaning we patch at that sequence position for all sequences in the batch
* A 1D tensor of shape `(batch,)` meaning we patch at `seq_pos[i]` for the `i`th sequence in the batch
* A 2D tensor of shape `(batch, pos)`, meaning we patch at all positions `seq_pos[i, :]` for the `i`th sequence in the batch
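The four `seq_pos` options can be summarised by which positions end up patched in each sequence. The sketch below spells this out with plain Python lists standing in for the tensors; `positions_to_patch` is a hypothetical helper for illustration, not part of the library.

```python
from typing import List, Sequence, Union

SeqPos = Union[None, int, Sequence[int], Sequence[Sequence[int]]]


def positions_to_patch(seq_pos: SeqPos, batch_size: int, seq_len: int) -> List[List[int]]:
    """Return, for each sequence in the batch, the positions that would be patched."""
    if seq_pos is None:
        # every position of every sequence
        return [list(range(seq_len)) for _ in range(batch_size)]
    if isinstance(seq_pos, int):
        # the same single position for every sequence
        return [[seq_pos] for _ in range(batch_size)]
    if len(seq_pos) != batch_size:
        raise ValueError("a per-sequence seq_pos needs one entry per batch element")
    if all(isinstance(p, int) for p in seq_pos):
        # 1D, shape (batch,): position seq_pos[i] for sequence i
        return [[p] for p in seq_pos]
    # 2D, shape (batch, pos): positions seq_pos[i, :] for sequence i
    return [list(row) for row in seq_pos]


print(positions_to_patch(10, batch_size=2, seq_len=15))                # [[10], [10]]
print(positions_to_patch([10, 11], batch_size=2, seq_len=15))          # [[10], [11]]
print(positions_to_patch([[2, 4], [4, 10]], batch_size=2, seq_len=15)) # [[2, 4], [4, 10]]
```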

Running `act_patch` when `patching_nodes` is a `Node` object (or list of `Node` objects) will perform a single instance of activation patching, and return a float.

### Multiple instances of patching

We perform multiple instances of patching using the `IterNode` class. This class takes two arguments: a component name or list of names (e.g. `z`, `q`, `resid_pre`, etc.), and a `seq_pos` argument. `seq_pos` works just like the `seq_pos` argument for `Node`, except there is a fifth option: `"each"`, which means we iterate over all sequence positions (as opposed to using the same sequence positions for each iteration).

```python
IterNode("resid_post", seq_pos="each") # -> gives (seq_len, layers) tensor, with results of patching resid_post at each (position, layer)
IterNode(["q", "k", "v"])              # -> gives (layers, heads) tensors, one for each of "q", "k", "v"
```

Running `act_patch` when `patching_nodes` is an `IterNode` object will return a dictionary of results, where keys are component names (e.g. `"q"`, `"k"`, `"v"` for the example above), and values are the tensors of results.
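Conceptually, `IterNode` is a nested loop of single patches whose results are collected into a grid per component. The sketch below reproduces the result shapes seen in the examples in this notebook: (layers, heads) for head components, (seq_len, layers) for `seq_pos="each"`, and one value per layer otherwise. `iterate_patching`, `HEAD_COMPONENTS` and the `run_single_patch` callback are hypothetical; nested lists stand in for tensors.

```python
from typing import Callable, Dict, List, Union

HEAD_COMPONENTS = {"q", "k", "v", "z", "pattern"}  # assumption: components with a head dimension


def iterate_patching(
    components: Union[str, List[str]],
    n_layers: int,
    n_heads: int,
    seq_len: int,
    seq_pos: Union[None, int, str],
    run_single_patch: Callable[..., float],
) -> Dict[str, list]:
    if isinstance(components, str):
        components = [components]
    results: Dict[str, list] = {}
    for comp in components:
        if comp in HEAD_COMPONENTS:
            if seq_pos == "each":
                raise NotImplementedError("head components with seq_pos='each' are not covered here")
            # one patch per (layer, head): a (layers, heads) grid
            results[comp] = [
                [run_single_patch(comp, layer=l, head=h, seq_pos=seq_pos) for h in range(n_heads)]
                for l in range(n_layers)
            ]
        elif seq_pos == "each":
            # one patch per (position, layer): a (seq_len, layers) grid
            results[comp] = [
                [run_single_patch(comp, layer=l, seq_pos=p) for l in range(n_layers)]
                for p in range(seq_len)
            ]
        else:
            # one patch per layer: a (layers,) list
            results[comp] = [run_single_patch(comp, layer=l, seq_pos=seq_pos) for l in range(n_layers)]
    return results


def fake_patch(comp: str, **node_kwargs) -> float:
    # stands in for "patch this one node and return the metric"
    return 0.0


out = iterate_patching(["q", "resid_post"], 12, 12, 15, None, fake_patch)
print(len(out["q"]), len(out["q"][0]), len(out["resid_post"]))  # 12 12 12
```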

### Additional details

* You can use `apply_metric_to_cache` to have the metric function take in a cache rather than a logit output. This is sometimes useful e.g. if you want to see what the effect of patching is on some intermediate attention patterns (as opposed to the final logit output).
* When using iterative patching, you can set `verbose=True` to have the shape of your output printed out (with dimensions named).
* You have to specify `new_input` (which is the thing we patch in). You can supply `orig_input` or `orig_cache` (in the former case, a cache will be created during the function call).
* The `patching_metric` function can return a tensor rather than a float. If you don't return a float *and* you use `IterNode`, the return type will be `Dict[str, list]`, i.e. your results will be stored in a flattened list rather than converted to a tensor. In this case you should use `verbose=True`, so you know what the order of the dimensions would be.
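When results come back as a flat list, they can be regrouped using the dimensions printed by `verbose=True`. A minimal sketch, assuming (not confirmed here, so check it against the verbose output) that the list is in row-major order over those dimensions; `unflatten` is a hypothetical helper.

```python
from math import prod
from typing import Any, Sequence


def unflatten(flat: Sequence[Any], shape: Sequence[int]) -> Any:
    """Regroup a flat, row-major list into nested lists with the given shape."""
    if prod(shape) != len(flat):
        raise ValueError(f"cannot reshape {len(flat)} items into {tuple(shape)}")
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])  # number of items in each block along the first dimension
    return [unflatten(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


print(unflatten(list(range(6)), (2, 3)))  # [[0, 1, 2], [3, 4, 5]]
```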

### Terminology note - `orig` vs `new`

I used the terminology `orig_input` and `new_input`, rather than `clean_input` and `corrupted_input`, because the latter is ambiguous: **noising algorithms** (where we run the model on the clean input and patch in activations from the corrupted input) are more common, but there are also **denoising algorithms** (where we run the model on the corrupted input and patch in activations from the clean input). In my terminology, we always run the model on `orig` and patch in activations from `new` (the same goes for path patching).
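A toy illustration of the `orig`/`new` convention: the same patching routine does noising or denoising depending only on which input is `orig`. The two-step "model", the dict of overrides (standing in for TransformerLens hooks) and `act_patch_sketch` are all hypothetical.

```python
from typing import Callable, Dict, Optional, Tuple


def toy_model(x: float, patch: Optional[Dict[str, float]] = None) -> Tuple[float, Dict[str, float]]:
    """Toy two-step 'model'; `patch` overrides named activations (standing in for hooks)."""
    patch = patch or {}
    cache: Dict[str, float] = {}
    a = patch["a"] if "a" in patch else 2.0 * x  # first activation
    cache["a"] = a
    b = patch["b"] if "b" in patch else a * a    # second activation, computed from a
    cache["b"] = b
    return b + 1.0, cache


def act_patch_sketch(orig_x: float, new_x: float, node: str, metric: Callable[[float], float]) -> float:
    """Run on orig_x, but overwrite activation `node` with its value from a run on new_x."""
    _, new_cache = toy_model(new_x)
    patched_out, _ = toy_model(orig_x, patch={node: new_cache[node]})
    return metric(patched_out)


clean_x, corrupted_x = 1.0, 0.0
clean_out, corrupted_out = toy_model(clean_x)[0], toy_model(corrupted_x)[0]


def denoise_metric(out: float) -> float:
    # 0 at corrupted-level output, 1 at clean-level output
    return (out - corrupted_out) / (clean_out - corrupted_out)


def noise_metric(out: float) -> float:
    # 0 at clean-level output, -1 at corrupted-level output
    return (out - clean_out) / (clean_out - corrupted_out)


# Denoising: orig = corrupted, new = clean
print(act_patch_sketch(orig_x=corrupted_x, new_x=clean_x, node="a", metric=denoise_metric))  # 1.0
# Noising: orig = clean, new = corrupted
print(act_patch_sketch(orig_x=clean_x, new_x=corrupted_x, node="a", metric=noise_metric))    # -1.0
```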

## Examples

### 1️⃣ Activation patching at head outputs (denoising)

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/e1.png" width="900">

In [None]:
results = act_patch(
    model=model,
    orig_input=flipped_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode("z"), # iterating over all heads' output in all layers
    patching_metric=ioi_metric_denoising,
    verbose=True,
)

In [None]:
imshow(
    results['z'] * 100,
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

These results tell us which heads are probably important. We can see:

* the main early heads (duplicate token head 3.0 and induction head 5.5),
* the S-inhibition heads in layers 7 and 8,
* the name mover heads like 9.9,
* the negative name mover heads 10.7 and 11.10 (which show up negative, because they make performance worse).

Now an example of how you could get a single one of these results (rather than iterating over nodes). You can compare this result to the one you can see in the plot above.

In [None]:
# patching at output of head 9.9 (one of the name mover heads)

act_patch(
    model=model,
    orig_input=flipped_tokens,
    new_cache=clean_cache,
    patching_nodes=Node("z", layer=9, head=9),
    patching_metric=ioi_metric_denoising,
    verbose=True,
)

Lastly, an example of how you could do this at multiple different parts of the attention head at once (by adding more components to the `IterNode` object):

In [None]:
# iterating over all heads in all layers, for each of the components z, q, k, v and pattern

results = act_patch(
    model=model,
    orig_input=flipped_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["z", "q", "k", "v", "pattern"]),
    patching_metric=ioi_metric_denoising,
    verbose=True,
)

In [None]:
assert results.keys() == {"z", "q", "k", "v", "pattern"}
assert all([r.shape == (12, 12) for r in results.values()])

imshow(
    t.stack(tuple(results.values())) * 100,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Patching different components of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1500,
    margin={"r": 100, "l": 100}
)

This tells us not just which heads are important (the left-side plot is the same as the one we generated previously), but also a bit more about **why** they're important. For example:

* The value inputs of the S-inhibition heads are important, because they move information about the duplicated subject name from `S2` to `end`, and both the identity and position of the subject token got flipped when we created our flipped dataset.
* The query inputs of the earlier heads are important, because their job is to attend from `S2` to `S1` (or to `S1+1` in the case of the induction heads), and the destination token at `S2` got flipped when we created our flipped dataset.

### 2️⃣ Patching at residual stream & block outputs, by sequence position (denoising)

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/e2.png" width="700">

In [None]:
# patching at each (layer, sequence position) for each of (resid_pre, attn_out, mlp_out) in turn

results = act_patch(
    model=model,
    orig_input=flipped_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["resid_pre", "attn_out", "mlp_out"], seq_pos="each"),
    patching_metric=ioi_metric_denoising,
    verbose=True,
)

In [None]:
assert results.keys() == {"resid_pre", "attn_out", "mlp_out"}

imshow(
    t.stack([r.T for r in results.values()]) * 100, # we transpose so layer is on the y-axis
    facet_col=0,
    facet_labels=["resid_pre", "attn_out", "mlp_out"],
    title="Patching at resid stream & layer outputs (corrupted -> clean)",
    labels={"x": "Sequence position", "y": "Layer", "color": "Logit diff variation"},
    x=labels,
    xaxis_tickangle=45,
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1300,
    margin={"r": 100, "l": 100}
)

*Note - the tokens on the x-axis are just the tokens for the first sequence in the batch. The batch has size 8, and the other sequences have the same basic structure (including position of second subject token).*

This shows us that the important information for solving the IOI task is localised. It starts at the second subject token (the only token which differs between the clean and flipped data), and then moves to the end token around layers 7 and 8 (this is done by the S-inhibition heads).

We can also see:

* No MLPs really matter except MLP0 (MLP0 is important in GPT-2, where it acts as a sort of extended embedding)
* There are some negative name mover heads in layers 10 and 11, which suppress the correct answer (we can see them in the middle plot)

Lastly, this is a bit artificial, but let's see how we can specify a particular sequence position for `IterNode`. These results will match the `Mary 10` slice of the results above.

In [None]:
# Only patching at the 'Mary 10' position

results = act_patch(
    model=model,
    orig_input=flipped_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["resid_pre", "attn_out", "mlp_out"], seq_pos=10),
    patching_metric=ioi_metric_denoising,
    verbose=True,
)

In [None]:
imshow(
    t.stack(list(results.values())).T * 100, # we transpose so layer is on the y-axis
    title="Patching at 'Mary_10'",
    labels={"x": "Component", "y": "Layer", "color": "Logit diff variation"},
    x=["resid_pre", "attn_out", "mlp_out"],
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=400,
)

What if we always want to patch at the second subject token, but that's not always at `seq_pos=10`? We can supply a tensor as `seq_pos` too! The code below will give identical results (and unlike the code above, it could in theory accept different patching positions for each sequence).

In [None]:
# Only patching at the 'Mary 10' position

seq_pos = t.tensor([10 for _ in range(len(flipped_tokens))])

results = act_patch(
    model=model,
    orig_input=flipped_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["resid_pre", "attn_out", "mlp_out"], seq_pos=seq_pos),
    patching_metric=ioi_metric_denoising,
    verbose=True,
)

In [None]:
imshow(
    t.stack(list(results.values())).T * 100, # we transpose so layer is on the y-axis
    title="Patching at 'Mary_10'",
    labels={"x": "Component", "y": "Layer", "color": "Logit diff variation"},
    x=["resid_pre", "attn_out", "mlp_out"],
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=400,
)
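In this batch every second-subject token sits at position 10, so a constant tensor works. If the positions varied, the per-sequence values could be computed from the token IDs. A minimal sketch, assuming single-token names and exactly one repeated name per sequence; the helper and the token IDs are hypothetical. The resulting list could then be wrapped in `t.tensor(...)` and passed as `seq_pos`.

```python
from typing import AbstractSet, List, Sequence


def second_occurrence_positions(token_ids: Sequence[Sequence[int]], name_ids: AbstractSet[int]) -> List[int]:
    """For each sequence, return the position where a name token appears for the second time (S2)."""
    positions = []
    for seq in token_ids:
        seen = set()
        for pos, tok in enumerate(seq):
            if tok in name_ids:
                if tok in seen:
                    positions.append(pos)  # second time we see this name
                    break
                seen.add(tok)
        else:
            raise ValueError("no repeated name token found in sequence")
    return positions


# Hypothetical token IDs: 1 and 2 stand for two name tokens
batch = [[0, 9, 1, 5, 2, 7, 2, 8], [0, 1, 5, 2, 1, 8]]
print(second_occurrence_positions(batch, {1, 2}))  # [6, 4]
```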

### 3️⃣ Patching at multiple nodes at once

If you pass a list of nodes, then they all get patched. Note that **this is different to using `IterNode` (which iterates through all nodes in `IterNode`, one at a time)**. If you want to do a form of iteration which isn't covered by `IterNode`, you'll need to make your own for loop for it.

To illustrate the difference, I'll perform activation patching in a for loop *and* with a list of `Node` objects. Specifically, I'm going to patch all the queries of the three **name mover heads**, then all the keys, then all the values (i.e. getting three different results).
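The same distinction in a toy setting: patching a list of nodes is one forward pass with all of them patched (one result), whereas iterating gives one result per node. Because this hypothetical toy model's output is nonlinear in the heads' summed outputs, the joint result is not the sum of the individual ones. All names here are made up for illustration.

```python
from typing import Dict, Optional, Sequence, Tuple

HEADS = ["h0", "h1", "h2"]


def toy_forward(x: float, patch: Optional[Dict[str, float]] = None) -> Tuple[float, Dict[str, float]]:
    """Three toy 'heads' write to a shared stream; the output is a nonlinear function of the sum."""
    patch = patch or {}
    cache: Dict[str, float] = {}
    total = 0.0
    for i, name in enumerate(HEADS):
        out = patch[name] if name in patch else (i + 1) * x
        cache[name] = out
        total += out
    return total * total, cache  # toy nonlinearity


def patch_together(orig_x: float, new_x: float, nodes: Sequence[str]) -> float:
    """Like passing a list of Nodes: every node is patched in the same forward pass."""
    _, new_cache = toy_forward(new_x)
    out, _ = toy_forward(orig_x, {n: new_cache[n] for n in nodes})
    return out


def patch_each(orig_x: float, new_x: float, nodes: Sequence[str]) -> Dict[str, float]:
    """Like IterNode: each node is patched on its own, giving one result per node."""
    return {n: patch_together(orig_x, new_x, [n]) for n in nodes}


print(patch_together(0.0, 1.0, HEADS))  # 36.0
print(patch_each(0.0, 1.0, HEADS))      # {'h0': 1.0, 'h1': 4.0, 'h2': 9.0}
```

Here the joint result (36) differs from the sum of the individual results (1 + 4 + 9 = 14), which is why the two modes answer different questions.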

In [None]:
NAME_MOVER_HEADS = [(9, 6), (9, 9), (10, 10)]

for component in "qkv":
    result = act_patch(
        model=model,
        orig_input=flipped_tokens,
        new_cache=clean_cache,
        patching_nodes=[Node(component, layer=layer, head=head) for layer, head in NAME_MOVER_HEADS],
        patching_metric=ioi_metric_denoising,
        verbose=True,
    )
    print(f"Patching name movers at {component}: {result}")

We can see that the activation patching results for queries are much larger than for keys or values. This is because the name mover heads attend from the `end` token to the `IO` token (in order to copy and predict it), and it's query composition from earlier heads which tells them which token to attend to. The flipped dataset changes both the position and the identity of the `IO` token, so on the flipped input the name mover heads copy the wrong name. Patching in the clean queries causes them to attend to the right token, hence we get a positive score. Patching at keys or values won't help, because the flipped dataset only flips the `{}` token in prompts like `"When John and Mary went to the shops,{} gave the bag to"`. In other words, the first mentions of the subject and indirect object don't change, so neither do the keys and values at those positions.

# Path Patching

*I'd recommend skimming the explanation below, and then jumping into the examples. I expect that they'll be pretty self-explanatory, and reading them will help you understand how to use this library faster than reading all the details below.*

*Make sure you've at least skimmed the "activation patching" section first.*

## How does the function work?

### Single path patching

Just like activation patching, we use the `Node` class for single instances. We specify `sender_nodes` and `receiver_nodes` (both can be `Node` objects, or lists of `Node` objects). In this case, we get a float returned.

### Multiple instances of patching

Just like activation patching, we use the `IterNode` class for multiple instances. We always have one of `sender_nodes` and `receiver_nodes` be fixed, and the other be iterated over (if you try to iterate over both, you get an error).

### Sequence positions

In activation patching, we can have different positions for each `Node`. But in path patching, it wouldn't make sense to have e.g. a different sequence position for a sender and receiver node (because direct paths can't connect different sequence positions). Instead, we use a single argument `seq_pos` in our `path_patch` function, which applies to all nodes (sender and receiver, fixed or iterated). This `seq_pos` argument has the same syntax as for activation patching (i.e. it can be `None` (default), an int, a 1D tensor of shape `(batch,)`, or a 2D tensor of shape `(batch, pos)`).

So in path patching, **you shouldn't use `seq_pos` for individual `Node` objects.** Similarly, you shouldn't set `seq_pos` for `IterNode` (unless you're specifying `seq_pos="each"`, in which case this will override the `seq_pos` argument). If you try to do this, you'll get an error.
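The rules above (iterate over senders or receivers but not both; no per-node `seq_pos`; `IterNode` only with `seq_pos="each"`) amount to a few argument checks. A minimal sketch, where `NodeSpec`, `IterSpec` and `check_path_patch_args` are hypothetical stand-ins rather than the library's own classes and validation code.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class NodeSpec:  # hypothetical stand-in for Node
    component: str
    layer: Optional[int] = None
    seq_pos: Optional[int] = None


@dataclass
class IterSpec:  # hypothetical stand-in for IterNode
    components: List[str]
    seq_pos: Optional[str] = None


Spec = Union[NodeSpec, IterSpec, List[NodeSpec]]


def check_path_patch_args(sender_nodes: Spec, receiver_nodes: Spec) -> None:
    """Raise ValueError if the arguments break the path-patching rules."""
    if isinstance(sender_nodes, IterSpec) and isinstance(receiver_nodes, IterSpec):
        raise ValueError("iterate over senders or receivers, not both")
    for spec in (sender_nodes, receiver_nodes):
        nodes = spec if isinstance(spec, list) else [spec]
        for node in nodes:
            if isinstance(node, NodeSpec) and node.seq_pos is not None:
                raise ValueError("set seq_pos on path_patch itself, not on individual nodes")
            if isinstance(node, IterSpec) and node.seq_pos not in (None, "each"):
                raise ValueError("IterNode only accepts seq_pos='each' in path patching")


check_path_patch_args(IterSpec(["z"]), NodeSpec("resid_post", 11))  # fine: only senders iterate
try:
    check_path_patch_args(IterSpec(["z"]), IterSpec(["q"]))
except ValueError as err:
    print(err)  # iterate over senders or receivers, not both
```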

### Direct includes MLPs?

There are 2 main ways to do path patching, which are supported by this code.

1. **MLPs count as direct paths**, in other words a direct path is any path which doesn't go through an attention head. This is what the IOI paper did. It requires a more complicated algorithm (2 forward passes), because the MLPs between sender and receiver need to be recomputed.
2. **MLPs don't count as direct paths**, in other words a direct path is synonymous with "skip connection". This is twice as fast (only requiring one forward pass).

You can specify which option using the boolean argument `direct_includes_mlps`. The default is `True` (i.e. the first of these two options).
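A toy sketch of the two options, under loudly stated assumptions: a scalar residual stream, two layers of two "heads" plus an MLP each, and the receiver fixed to the final residual stream. With `direct_includes_mlps=True`, every head is frozen at its `orig` value except the sender (set to its `new` value) while the MLPs recompute; because the receiver here is the final stream, that frozen run's output is already the answer (a receiver in the middle of the network would need a second pass that patches the receiver's input with its value from the frozen run). With `False`, the sender's change simply adds along the skip connection. The model, weights and function names are hypothetical; this is my toy reading of the description above, not the library's implementation.

```python
import math
from typing import Dict, Optional, Tuple

N_LAYERS, N_HEADS = 2, 2
HEAD_W = [[0.5, -0.3], [0.8, 0.2]]  # toy per-head weights, indexed [layer][head]

Cache = Dict[Tuple[int, int], float]


def forward(x: float, frozen: Optional[Cache] = None) -> Tuple[float, Cache]:
    """Toy model with a scalar stream; heads listed in `frozen` output the given value."""
    frozen = frozen or {}
    resid, cache = x, {}
    for layer in range(N_LAYERS):
        attn_out = 0.0
        for head in range(N_HEADS):
            key = (layer, head)
            out = frozen[key] if key in frozen else HEAD_W[layer][head] * math.tanh(resid)
            cache[key] = out
            attn_out += out
        resid += attn_out
        resid += 0.5 * (layer + 1) * max(0.0, resid)  # toy MLP, always recomputed
    return resid, cache


def path_patch_to_final(orig_x: float, new_x: float, sender: Tuple[int, int],
                        direct_includes_mlps: bool = True) -> float:
    orig_final, orig_cache = forward(orig_x)
    _, new_cache = forward(new_x)
    if not direct_includes_mlps:
        # direct path = skip connection only: the sender's change adds linearly to the final stream
        return orig_final + (new_cache[sender] - orig_cache[sender])
    # direct paths may pass through MLPs: freeze all heads at orig values except the sender
    frozen = dict(orig_cache)
    frozen[sender] = new_cache[sender]
    return forward(orig_x, frozen)[0]


orig_final = forward(1.0)[0]
with_mlps = path_patch_to_final(1.0, -1.0, sender=(0, 0), direct_includes_mlps=True)
skip_only = path_patch_to_final(1.0, -1.0, sender=(0, 0), direct_includes_mlps=False)
print(orig_final, with_mlps, skip_only)
```

In this toy the two settings give noticeably different numbers, because the later MLPs react to the sender's change only in the first setting; on the real model above, the two settings gave similar-looking plots.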

### Additional details

* You can use `apply_metric_to_cache` to have the metric function take in a cache rather than a logit output. This is sometimes useful e.g. if you want to see what the effect of patching is on some intermediate attention patterns (as opposed to the final logit output).
* Like before, you can use `verbose=True` to have the shape of your output printed out (with dimensions named).
* You have to specify `new_input` **and** `orig_input`. Both `orig_cache` and `new_cache` are optional (if you don't supply them, then they will be computed by the function).

## Examples

### 1️⃣ Patching from attention head -> final residual stream value

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/e3.png" width="400">

This captures the direct effect of one of our heads. The most important head in the IOI circuit (in a direct sense) is `9.9`, so we expect this one to have a pretty large value.

The metric `ioi_metric_noising` is -1 when performance is destroyed, and 0 when performance is exactly the same, so we expect a clearly negative result for this head.

In [None]:
model.reset_hooks()

results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=Node('z', 9, head=9), # This is the output of head 9 at layer 9
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=ioi_metric_noising,
)

print(results)

Now let's do something similar, but for `direct_includes_mlps=False`. We expect to get something similar, but not necessarily exactly the same (because it's a slightly different methodology).

In [None]:
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=Node("z", 9, head=9), # This is the output of head 9 at layer 9
    receiver_nodes=Node("resid_post", 11), # This is resid_post at layer 11
    patching_metric=ioi_metric_noising,
    direct_includes_mlps=False
)

print(results)

Now let's plot the results for all heads:

In [None]:
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode('z'), # This means iterate over all heads in all layers
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=ioi_metric_noising,
    verbose=True
)

In [None]:
imshow(
    results['z'],
    title="Direct effect on logit diff (patch from head output -> final resid)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

This matches the results in figure 3(b) in the paper:

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/fig_3b.png" width="700">

Let's see if it's approximately the same for `direct_includes_mlps=False` (which it should be, if this method is at all principled):

In [None]:
results_direct = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode('z'), # This means iterate over all heads in all layers
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=ioi_metric_noising,
    direct_includes_mlps=False,
    verbose=True
)

In [None]:
imshow(
    t.stack([results['z'], results_direct['z']]),
    facet_col=0, facet_labels=["Direct path includes MLPs", "Direct path is just skip connections"],
    title="Each attention head's direct effect on logit difference",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    border=True,
    width=950,
    margin={"r": 100, "l": 100}
)

They look about the same - great!

Lastly, let's see what happens when we path patch from all the main name mover heads (9.6, 9.9 and 10.10) at once. The effect should be larger than the results we got above.

In [None]:
NAME_MOVER_HEADS = [(9, 6), (9, 9), (10, 10)]

results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=[Node("z", layer, head=head) for layer, head in NAME_MOVER_HEADS], # Output of all name mover heads
    receiver_nodes=Node("resid_post", 11), # This is resid_post at layer 11
    patching_metric=ioi_metric_noising,
)

print(results)

Performance is worse than zero! (This might be because of the **negative name mover heads**).

### 2️⃣ Patching from residual stream -> final residual stream value (for each sequence position)

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/e4.png" width="800">

Here we iterate over sender nodes, one for each (component, layer, position) triple, where the component is `resid_pre`, `attn_out` or `mlp_out`. The receiver node is the final value of the residual stream.

Note, this example is a bit artificial, but I just wanted to give an example of patching by sequence position.

In [None]:
results = path_patch(
    model,
    orig_input=flipped_tokens,
    new_input=clean_tokens,
    sender_nodes=IterNode(["resid_pre", "attn_out", "mlp_out"], seq_pos="each"),
    receiver_nodes=Node("resid_post", 11),
    patching_metric=ioi_metric_denoising,
    direct_includes_mlps=False, # gives similar results to direct_includes_mlps=True
    verbose=True,
)

In [None]:
# We get a dictionary where each key is a node name, and each value is a tensor of shape (seq_pos, layer)
assert list(results.keys()) == ['resid_pre', 'attn_out', 'mlp_out']
assert results["resid_pre"].shape == (15, 12)

# Transpose each result to (layer, seq_pos) so layers go on the y-axis, then stack the 3 components
results_stacked = t.stack([
    res.T for res in results.values()
])

imshow(
    results_stacked,
    facet_col=0,
    facet_labels=['resid_pre', 'attn_out', 'mlp_out'],
    title="Results of denoising patching at residual stream",
    labels={"x": "Sequence position", "y": "Layer", "color": "Logit diff variation"},
    x=labels,
    xaxis_tickangle=45,
    width=1300,
    margin={"r": 100, "l": 100},
    border=True,
)

This shows how the attention heads get us the right result at layer 9, and then that information is stored in the residual stream at the position of the `END` token. It also shows the effect of the negative name mover heads (negative values in the middle plot for the last 2 layers).

This is different from our previous plot because it only looks at **direct logit attribution** from the sender component to the final value of the residual stream, rather than other downstream effects of patching at that position (which *do* get captured by activation patching).

Note - I chose to return the output as a dict rather than a tensor. I think this makes more sense, because e.g. stacking a `z`-tensor (with a head dimension) and an `mlp_out`-tensor (without one) would be weird. Better to have the return type be easily understandable, and let users decide how to do the plotting.

### 3️⃣ Patching from head to head

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/e5.png" width="380">


Here we'll replicate figure 5(b) of the [IOI paper](https://arxiv.org/abs/2211.00593) by patching from (each head's output in turn) to the value inputs of the S-Inhibition heads.

In [None]:
S_INHIBITION_HEADS = [(7, 3), (7, 9), (8, 6), (8, 10)]

results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in S_INHIBITION_HEADS],
    patching_metric=ioi_metric_noising,
    verbose=True,
)

In [None]:
imshow(
    results["z"][:8] * 100,
    title="Direct effect on S-Inhibition Heads' values",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=700,
    margin={"r": 100, "l": 100}
)

This matches figure 5(b) in the paper:

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/fig_5b.png" width="700">

Again, how does this change when we don't include MLPs in the direct path?

In [None]:
results_direct = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in S_INHIBITION_HEADS],
    patching_metric=ioi_metric_noising,
    direct_includes_mlps=False,
    verbose=True,
)

In [None]:
imshow(
    t.stack([results['z'], results_direct['z']])[:, :8] * 100,
    facet_col=0, facet_labels=["Direct path includes MLPs", "Direct path is just skip connections"],
    title="Direct effect on S-Inhibition Heads' values",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1100
)

While we're on a roll, let's also do figure 4 (it won't require anything that we haven't already seen).

In [None]:
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("q", layer, head=head) for layer, head in NAME_MOVER_HEADS],
    patching_metric=ioi_metric_noising,
    verbose=True,
)

In [None]:
imshow(
    results['z'][:9] * 100,
    title="Direct effect on Name Mover Heads' queries",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=700,
)

Compare to figure 4:

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/fig_4b.png" width="700">

### 4️⃣ Applying patching metric to the cache

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/e6.png" width="400">

In this section, I'm going to include some code I made as part of the initial investigations into my SERI MATS project for Neel Nanda's stream. It involved investigating **negative name mover heads**, and seeing what they might be doing off-distribution. My theory was that they were query-composing with the name mover heads - in other words, they were detecting that the `IO` token had been copied from `IO1` to `end`, and suppressing this prediction. The best way of testing this would be for the `patching_metric` function to look at attention patterns (which could be found from the cache) rather than the output logits. I realised that I hadn't built this feature into my path patching code yet, so I decided to do that!

The `apply_metric_to_cache` argument, if True, means the `patching_metric` function takes in an `ActivationCache` (the one we get from the final step of the path patching algorithm) instead of logits. Below, you can see the patching metric I'm using: `get_io_vs_s_attn_for_nmh`. It looks at the difference between the attention probabilities `end -> IO` and `end -> S1` for the negative name mover head 10.7. This difference should be positive on the clean distribution, and if my theory about composition was correct then it should drop to approximately zero when I path patch from the outputs of the three name mover heads to the query input of 10.7. The histogram shows that we do indeed see this.

> **Note** - the `get_io_vs_s_attn_for_nmh` function actually returns a tuple of tensors rather than a float. This is supported by my path patching implementation: if the result isn't a float, it's returned as-is (and if non-float results are produced while iterating over an `IterNode`, they're returned in a list rather than being rearranged into tensors). Here, though, I'm just performing a single patching run rather than iterating.
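The dispatch and return behaviour described in the note can be sketched in plain Python. This is a hypothetical illustration, not the library's code: `apply_patching_metric` and `collect_iter_results` are made-up names, lists stand in for tensors, and a dict stands in for an `ActivationCache`.

```python
from typing import Any, Callable, Dict, List, Sequence, Union


def apply_patching_metric(
    metric: Callable[[Any], Any],
    logits: Any,
    cache: Dict[str, Any],
    apply_metric_to_cache: bool = False,
) -> Any:
    """Hypothetical dispatch: the metric sees either the patched cache or the patched logits."""
    return metric(cache if apply_metric_to_cache else logits)


def collect_iter_results(
    values: Sequence[Any], n_rows: int, n_cols: int
) -> Union[List[List[float]], List[Any]]:
    """Hypothetical aggregation over the runs of an iteration.

    If every metric value is a plain number, arrange them into an (n_rows, n_cols) grid
    (e.g. layer x head). Otherwise return them untouched, as a flat list.
    """
    if all(isinstance(v, (int, float)) for v in values):
        return [list(values[r * n_cols:(r + 1) * n_cols]) for r in range(n_rows)]
    return list(values)


def attn_diff(cache: Dict[str, Any]) -> tuple:
    # Toy cache-based metric returning a tuple, in the spirit of `get_io_vs_s_attn_for_nmh`
    return (cache["pattern"][0][0] - cache["pattern"][0][1],)


toy_cache = {"pattern": [[0.75, 0.25]]}
print(apply_patching_metric(attn_diff, logits=None, cache=toy_cache, apply_metric_to_cache=True))  # (0.5,)
print(collect_iter_results([0.1, 0.2, 0.3, 0.4], 2, 2))  # [[0.1, 0.2], [0.3, 0.4]]
print(collect_iter_results([(0.5,), (0.1,)], 1, 2))      # [(0.5,), (0.1,)]
```

The point of returning non-numeric results untouched is that the caller knows their structure (here a tuple of per-prompt values) better than a generic reshaping step would.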

#### Setup

*As mentioned, this is a hilariously low-tech way of getting 1000 datapoints. I just did it in a for loop because I didn't want to run out of memory, and I don't know any better ways because I'm pretty naive when it comes to doing anything at scale (-:*


In [None]:
NEG_NMH = (10, 7)
NAME_MOVER_HEADS = [(9, 6), (9, 9), (10, 10)]


def get_io_vs_s_attn_for_nmh(
    patched_cache: ActivationCache,
    ioi_dataset: IOIDataset,
    ioi_cache: ActivationCache,
    neg_nmh = NEG_NMH,
) -> tuple[Float[Tensor, "batch"], Float[Tensor, "batch"]]:
    '''
    Returns the difference between patterns[END, IO] and patterns[END, S1], where patterns
    are the attention patterns for the negative name mover head.

    This is returned in the form of a tuple of 2 tensors: one for the patched distribution
    (calculated using `patched_cache` which is returned by the path patching algorithm), and
    one for the clean IOI distribution (which is just calculated directly from that cache).
    '''
    layer, head = neg_nmh
    attn_pattern_patched = patched_cache["pattern", layer][:, head]
    attn_pattern_clean = ioi_cache["pattern", layer][:, head]
    # both are (batch, seq_Q, seq_K), and I want all the "end -> IO" attention probs

    N = ioi_dataset.toks.size(0)
    io_seq_pos = ioi_dataset.word_idx["IO"]
    s1_seq_pos = ioi_dataset.word_idx["S1"]
    end_seq_pos = ioi_dataset.word_idx["end"]

    return (
        attn_pattern_patched[range(N), end_seq_pos, io_seq_pos] - attn_pattern_patched[range(N), end_seq_pos, s1_seq_pos],
        attn_pattern_clean[range(N), end_seq_pos, io_seq_pos] - attn_pattern_clean[range(N), end_seq_pos, s1_seq_pos],
    )


results_patched = []
results_clean = []

for seed in tqdm(range(50)):

    ioi_dataset, abc_dataset, ioi_cache, abc_cache, ioi_metric = generate_data_and_caches(20, seed=seed)

    result_patched, result_clean = path_patch(
        model,
        orig_input=ioi_dataset.toks,
        new_input=abc_dataset.toks,
        orig_cache=ioi_cache,
        new_cache=abc_cache,
        sender_nodes=[Node("z", layer=layer, head=head) for layer, head in NAME_MOVER_HEADS], # Output of all name mover heads
        receiver_nodes=Node("q", NEG_NMH[0], head=NEG_NMH[1]), # To query input of negative name mover head
        patching_metric=partial(get_io_vs_s_attn_for_nmh, ioi_dataset=ioi_dataset, ioi_cache=ioi_cache),
        apply_metric_to_cache=True,
        direct_includes_mlps=False,
    )
    results_patched.extend(result_patched.tolist())
    results_clean.extend(result_clean.tolist())

    t.cuda.empty_cache()

In [None]:
hist(
    [results_patched, results_clean],
    labels={"variable": "Version", "value": "Attn diff (positive ⇒ more attn paid to IO than S1)"},
    title="Difference in attn from END➔IO vs. END➔S1 (path-patched vs clean)",
    names=["Patched", "Clean"],
    width=800,
    height=600,
    opacity=0.7,
    marginal="box",
    template="simple_white"
)

# Appendix

## What is path patching?

Firstly, let's consider the version of path patching where "direct paths" mean skip connections, i.e. the output of one head being used as direct input for another head.

Path patching from a **sender node** (usually the output of a component) to a **receiver node** (usually the input of a component) means we patch into the **receiver node** and change its value to *"what it would have been if the direct path from sender -> receiver was the same as it was on `new_input`, but all other paths to the receiver were the same as on `orig_input`"*.

The phrase ***"what it would have been"*** hides a bit of nuance here. The key idea to have in mind is that **transformers are the sum of multiple paths**, with the number of paths growing exponentially as the transformer gets larger. If we want to examine a specific path (e.g. composition between head $h_1$ and a later head $h_2$) then we might want to see what happens when we change the direct input into $h_2$ which comes from $h_1$, but keep all the other inputs the same.

There are 2 steps to our algorithm:

1. Gather all activations we need on orig and new distributions.
2. Run model on original, with **receiver nodes** patched (adding the difference between new and orig outputs on the **sender nodes**).
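The two steps above can be checked numerically on a toy model. This is a minimal sketch under made-up assumptions, not the library's implementation: a scalar residual stream at a single position, one head per layer, no MLPs or LayerNorm, and hypothetical head functions for `0.0` and `1.0`. It path patches from `0.0` to the input of `2.0` and, for comparison, activation patches `0.0`.

```python
# Toy model (hypothetical): scalar residual stream, one position, one head per layer,
# no MLPs or LayerNorm. Each head reads the residual stream and writes a scalar back.
HEADS = {
    0: lambda x: 2.0 * x,  # head 0.0 (the sender)
    1: lambda x: x * x,    # head 1.0 (non-linear, so it reacts to 0.0's output)
}


def run(embed, h0_override=None):
    """Forward pass up to the input of head 2.0 (the receiver).

    Returns (dict of head outputs, receiver input). If h0_override is given,
    head 0.0's output is replaced by it and everything downstream recomputes.
    """
    resid = embed
    outs = {}
    for layer, head_fn in HEADS.items():
        out = head_fn(resid)
        if layer == 0 and h0_override is not None:
            out = h0_override
        outs[layer] = out
        resid = resid + out
    return outs, resid


def path_patch_toy(orig_embed, new_embed):
    # Step 1: gather activations on the orig and new inputs.
    orig_outs, orig_receiver_in = run(orig_embed)
    new_outs, _ = run(new_embed)
    # Step 2: orig receiver input, plus the change along the direct path 0.0 -> 2.0 only.
    return orig_receiver_in + (new_outs[0] - orig_outs[0])


def act_patch_toy(orig_embed, new_embed):
    # For comparison: replace 0.0's output everywhere, so head 1.0 also sees the new value.
    new_outs, _ = run(new_embed)
    _, receiver_in = run(orig_embed, h0_override=new_outs[0])
    return receiver_in


print(path_patch_toy(1.0, 3.0))  # 16.0
print(act_patch_toy(1.0, 3.0))   # 56.0
```

Activation patching lets the new output of `0.0` also flow through head `1.0` (giving 56.0), whereas path patching only changes the direct `0.0 -> 2.0` term (giving 16.0).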

Example, patching from the output of head `0.0` to the input of head `2.0`:

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/path-patching-alg-transformers-direct.png" width="440">

Why does this work? The value we patch into head `2.0` (i.e. the blue) is the value it would have if we wrote it as a sum of paths, with every path the same as in the original distribution, except for the direct path from `0.0`:

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/path-patching-decomp-one.png" width="1050">
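Written out for a toy case (my own notation and simplifying assumptions: a single position, no MLPs or LayerNorm, only heads `0.0` and `1.0` upstream of the receiver, $e$ the embedding, and $h_{0.0}(\cdot)$, $h_{1.0}(\cdot)$ each head's output as a function of its residual-stream input), the input to `2.0` on any one input is

$$x_{2.0} = e + h_{0.0}(e) + h_{1.0}\big(e + h_{0.0}(e)\big)$$

and path patching from `0.0` swaps only the direct $h_{0.0}$ term for its value on the new input:

$$x_{2.0}^{\text{patched}} = e^{\text{orig}} + h_{0.0}(e^{\text{new}}) + h_{1.0}\big(e^{\text{orig}} + h_{0.0}(e^{\text{orig}})\big) = x_{2.0}^{\text{orig}} + h_{0.0}(e^{\text{new}}) - h_{0.0}(e^{\text{orig}})$$

The right-hand side is step 2 of the algorithm: the original receiver input plus the difference between the sender's new and orig outputs.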

Now, let's consider a different form of the algorithm, where **we count paths including MLPs as direct**. This is the version of the algorithm used in the IOI paper.

Why might we want to do this? Maybe head `2.0` depends on the output of head `0.0`, but with one of the MLPs acting as a mediator. In other words (to oversimplify), `0.0` writes the vector $v$ into the residual stream, some neuron detects $v$ and writes $w$ to the residual stream, and `2.0` detects $w$. If we used the algorithm above, then we wouldn't catch this causal relationship. The drawback is that things get a bit messier, because now we're essentially passing a "fake input" into our MLPs, and it's dangerous to assume that an operation as clean as the one described above (with vectors $v$, $w$) would still happen under these new circumstances.

If we include MLPs as direct paths, then we need an extra step in our algorithm, to compute the effect of all the direct paths going through MLPs. Our new algorithm is:

1. Gather all activations we need on orig and new distributions.
2. Run model on orig with **sender nodes** patched from new and all other attention heads frozen. Cache the receiver nodes.
    * Note, we **don't** freeze the MLPs, because we want to include the effect of paths from sender -> receiver that go through MLPs.
3. Run model on original, with **receiver nodes** patched from previously cached values.
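The three steps above, and the $v \to$ neuron $\to w$ mediator story, can be checked on a toy model. This is a minimal sketch under made-up assumptions, not the library's implementation: a three-direction residual stream `[a, v, w]`, one position, no LayerNorm, and hand-picked components with hypothetical names. Head `0.0` writes along $v$, an MLP neuron fires on $v$ and writes along $w$, and we compare skip-connection-only path patching with the MLP-inclusive version.

```python
from typing import Dict, List, Optional, Tuple

Vec = List[float]  # toy residual stream with three directions: [a, v, w]


def add(x: Vec, y: Vec) -> Vec:
    return [xi + yi for xi, yi in zip(x, y)]


def sub(x: Vec, y: Vec) -> Vec:
    return [xi - yi for xi, yi in zip(x, y)]


# Hypothetical toy components, chosen to reproduce the v -> neuron -> w story.
def embed(tok: float) -> Vec:
    return [tok, 0.0, 0.0]


def head_00(r: Vec) -> Vec:  # sender: copies a into the v direction
    return [0.0, r[0], 0.0]


def mlp_0(r: Vec) -> Vec:  # mediator: a neuron that fires on v and writes w
    return [0.0, 0.0, max(r[1] - 1.0, 0.0)]


def head_10(r: Vec) -> Vec:  # another head: reads w, writes into a
    return [r[2], 0.0, 0.0]


def run_to_receiver(tok: float, overrides: Optional[Dict[str, Vec]] = None) -> Tuple[Dict[str, Vec], Vec]:
    """Run up to the input of receiver head 2.0. `overrides` pins head outputs;
    MLPs are never pinned, so they always recompute from their (possibly patched) input."""
    overrides = overrides or {}
    r = embed(tok)
    heads = {"0.0": overrides.get("0.0", head_00(r))}
    r = add(r, heads["0.0"])
    r = add(r, mlp_0(r))
    heads["1.0"] = overrides.get("1.0", head_10(r))
    r = add(r, heads["1.0"])
    return heads, r


def path_patch_skip_only(orig_tok: float, new_tok: float) -> Vec:
    orig_heads, orig_recv = run_to_receiver(orig_tok)
    new_heads, _ = run_to_receiver(new_tok)
    return add(orig_recv, sub(new_heads["0.0"], orig_heads["0.0"]))


def path_patch_with_mlps(orig_tok: float, new_tok: float) -> Vec:
    # Step 1: gather activations on orig and new inputs.
    orig_heads, _ = run_to_receiver(orig_tok)
    new_heads, _ = run_to_receiver(new_tok)
    # Step 2: sender patched from new, every other head frozen at its orig value; MLPs recompute.
    _, recv = run_to_receiver(orig_tok, {"0.0": new_heads["0.0"], "1.0": orig_heads["1.0"]})
    # Step 3 would patch `recv` into the receiver on an orig run; the toy stops here.
    return recv


print(path_patch_skip_only(0.0, 3.0))  # [0.0, 3.0, 0.0]
print(path_patch_with_mlps(0.0, 3.0))  # [0.0, 3.0, 2.0]
```

The skip-connection-only version leaves the $w$ component of the receiver's input unchanged, so a receiver reading $w$ would see no effect, while the MLP-inclusive version keeps the mediated path. Freezing head `1.0` also matters: plain activation patching of `0.0` would let `1.0` react too, giving `[2.0, 3.0, 2.0]`.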

Illustration, in the case where we're patching from head to head:

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/path-patching-alg-transformers-5.png" width="550">

Why does this work? After step 2, we can see from the following decomposition that all the direct paths from the sender into `2.0` (including those passing through MLPs) have the value they would have on the new distribution (green), and all other paths have the value they would have on the original distribution (grey).

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/path-patching-decomp-two.png" width="800">
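The same toy bookkeeping for the MLP-inclusive version (again my own notation and assumptions, now with an MLP $m_0$, $m_1$ after each head, and $h_{1.0}^{\text{orig}}$ denoting head `1.0`'s output cached on the original input) gives the receiver input cached in step 2:

$$
\begin{aligned}
r_0 &= e^{\text{orig}} + h_{0.0}(e^{\text{new}}) \\
r_1 &= r_0 + m_0(r_0) + h_{1.0}^{\text{orig}} \\
x_{2.0}^{\text{patched}} &= r_1 + m_1(r_1)
\end{aligned}
$$

The sender's new output reaches `2.0` directly and through $m_0$ and $m_1$, which recompute on their patched inputs, while the path through head `1.0` stays frozen at its original value.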

## Comparison to TransformerLens' activation patching

Using this is very different to using TransformerLens' current activation patching. TransformerLens has one general function `generic_activation_patch` and several `partial`s defined from it, each one for a specific use case (e.g. `get_act_patch_attn_out`). I think the way I've done things here has 3 advantages over the current method:

* **Fewer functions to remember** (the general `Node` and `IterNode` syntax lets you express all the different forms of patching in more or less the same way).
* **More customizable**. You aren't restricted to using the `partial`-defined functions.
* **More suited for path patching**. My understanding of `generic_activation_patch` is that it wasn't designed to be used regularly; its specialisations were.

The main disadvantage of my method is that there are more abstractions to keep in mind (in particular the `Node` and `IterNode` classes). To mitigate this, I tried to make the syntax for defining `Node` and `IterNode` as close to the syntax for indexing into the cache as possible, so people don't feel like they're learning an entirely new way of doing things.
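To make the "close to cache indexing" point concrete, here is a hypothetical `ToyNode` dataclass (not the library's `Node`) that maps a component name and layer both to the `(name, layer)` key used to index an `ActivationCache`, e.g. `cache["pattern", layer]` as in the code above, and to the corresponding TransformerLens hook name.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# TransformerLens hook-name layout: per-head attention hooks live under `attn`,
# per-neuron MLP hooks under `mlp`, and everything else directly on the block.
ATTN_HOOKS = {"q", "k", "v", "z", "pattern"}
MLP_HOOKS = {"pre", "post"}
BLOCK_HOOKS = {
    "resid_pre", "resid_mid", "resid_post", "attn_out", "mlp_out",
    "q_input", "k_input", "v_input",
}


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical stand-in for a `Node`: component name, layer, optional head.

    Not the library's class; it only illustrates how `Node("z", 9, head=6)` lines up
    with `cache["z", 9]` and with a TransformerLens hook name.
    """

    name: str
    layer: int
    head: Optional[int] = None

    @property
    def cache_key(self) -> Tuple[str, int]:
        # The same (name, layer) pair you'd use to index an ActivationCache
        return (self.name, self.layer)

    @property
    def hook_name(self) -> str:
        if self.name in ATTN_HOOKS:
            return f"blocks.{self.layer}.attn.hook_{self.name}"
        if self.name in MLP_HOOKS:
            return f"blocks.{self.layer}.mlp.hook_{self.name}"
        if self.name in BLOCK_HOOKS:
            return f"blocks.{self.layer}.hook_{self.name}"
        raise ValueError(f"Unknown component name: {self.name!r}")


node = ToyNode("z", 9, head=6)
print(node.cache_key)  # ('z', 9)
print(node.hook_name)  # blocks.9.attn.hook_z
```

In TransformerLens itself, `utils.get_act_name` performs the real name-to-hook mapping; the sets above only reproduce its naming scheme for the components listed in this notebook.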

## Limitations / drawbacks

* The version of path patching with a single forward pass (rather than two) doesn't work when you're patching individual neurons in an MLP. This is because (unlike the inputs to attention heads) there's no "residual stream splits into a separate value for each neuron" option in transformer models. Obviously, this would be ridiculous, because you might be duplicating the residual stream several thousands of times! If you try to path patch individual neurons, then a version of the two-forward-passes algorithm will be used instead (but one which freezes MLPs, so the effect is the same as the single forward pass algorithm). I mention it here just to explain why the algorithm is twice as slow if you patch individual neurons, which might otherwise seem inexplicable.

* Path patching into attention patterns by sequence position is pretty questionable (I'm not sure what version of this makes sense, or whether it even makes sense at all to path patch into attention patterns). For that reason, I'd recommend just not using this option.

* I haven't yet written code to support `z` as a receiver node (I would interpret this as doing `q`, `k`, `v` all at once, and again it wouldn't make sense to only do this at a single sequence position). For now, you should just pass a list of nodes for `q`, `k` and `v` manually.

  * The things you can path patch into (i.e. receiver components) are:
    * `resid_pre`, `resid_mid`, `resid_post`
    * Attention head inputs: `q`, `k`, `v`, or `q_input`, `k_input`, `v_input`
    * MLP input: `pre`

  * The things you can patch from (i.e. sender components) are:
    * `resid_pre`, `resid_mid`, `resid_post`
    * Attention head output, by head `z` or over all heads `attn_out`
    * MLP output, by neuron `post` or over all neurons `mlp_out`

* Arthur Conmy's ACDC uses a different methodology for path patching, and this method wouldn't fit as well here. But this is by design. ***My activation & path patching implementation was designed to be low-effort for mechanistic interpretability researchers doing exploratory analysis***. It's not been optimized for e.g. fitting into larger systems like ACDC.
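The neuron limitation in the first bullet can be illustrated with a toy sketch (plain Python, made-up weights `W_in`, a single position, and a linear MLP input with no LayerNorm; none of this is the library's code). Every neuron reads the same residual stream, so patching that shared input would move all neurons at once. Restricting the change to one receiver neuron means first computing its patched pre-activation, then overwriting only that entry of `pre` on a separate pass.

```python
from typing import List

Vec = List[float]


def mlp_pre(resid: Vec, W_in: List[Vec]) -> Vec:
    """Pre-activations of every neuron: one shared input, one row of W_in per neuron."""
    return [sum(w * x for w, x in zip(row, resid)) for row in W_in]


def path_patch_into_neuron(orig_resid: Vec, sender_orig: Vec, sender_new: Vec,
                           W_in: List[Vec], neuron: int) -> Vec:
    # Pass 1: the MLP input we'd get if only the sender's direct contribution changed
    patched_resid = [r + (n - o) for r, o, n in zip(orig_resid, sender_orig, sender_new)]
    patched_pre = mlp_pre(patched_resid, W_in)
    # Pass 2: the original pre-activations, with just the receiver neuron overwritten
    pre = mlp_pre(orig_resid, W_in)
    pre[neuron] = patched_pre[neuron]
    return pre


W_in = [[1.0, 0.0], [0.0, 1.0]]
orig_resid, sender_orig, sender_new = [1.0, 1.0], [0.0, 0.0], [2.0, 2.0]
print(path_patch_into_neuron(orig_resid, sender_orig, sender_new, W_in, neuron=0))  # [3.0, 1.0]
# Editing the shared input instead changes every neuron:
print(mlp_pre([3.0, 3.0], W_in))  # [3.0, 3.0]
```

Attention heads avoid this because each head can be given its own copy of the residual stream (the `q_input`, `k_input` and `v_input` receivers listed above), so a single pass suffices for them.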