# Paper Replication
- Replicate most of the other results from the [IOI paper](https://arxiv.org/abs/2211.00593)

- Practice more open-ended, less guided coding

## Imports & Setup

In [None]:
import os

os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
from pathlib import Path
import torch as t
from torch import Tensor
import numpy as np
import einops
from tqdm.notebook import tqdm
import plotly.express as px
import webbrowser
import re
import itertools
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set, Union
from functools import partial
from IPython.display import display, HTML
from rich.table import Table, Column
from rich import print as rprint
import circuitsvis as cv
from pathlib import Path
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
from ioi_dataset import NAMES, IOIDataset
from solutions import format_prompt, make_table

t.set_grad_enabled(False)

from arena3.chapter1_transformer_interp.exercises.plotly_utils import (
    imshow,
    line,
    scatter,
    bar,
)
import tests

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

# True when this code is executed as the main module (e.g. run directly or in a notebook kernel).
MAIN = __name__ == "__main__"

## Load model, set up dataset

In [None]:
# Load GPT-2 small. The weight-processing flags below (centering the unembedding and
# the residual-writing weights, folding LayerNorm, refactoring the factored attention
# matrices) are passed straight through to HookedTransformer.from_pretrained.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Small IOI dataset: N prompts of the "mixed" prompt type, drawn with a fixed seed,
# tokenized with the model's tokenizer and without a prepended BOS token.
N = 25
ioi_dataset = IOIDataset(
    prompt_type="mixed",
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=str(device),
)

### Additional setup

In [None]:
def logits_to_ave_logit_diff_2(
    logits: Float[Tensor, "batch seq d_vocab"],
    ioi_dataset: IOIDataset = ioi_dataset,
    per_prompt=False,
) -> Union[Float[Tensor, ""], Float[Tensor, "batch"]]:
    """
    Compute the indirect-object minus subject logit difference at each prompt's END token.

    Args:
        logits: model logits for every prompt, indexed [batch, seq, d_vocab].
        ioi_dataset: supplies the END position of each prompt (word_idx["end"]) and
            the per-prompt IO / S token ids.
        per_prompt: when True, return one difference per prompt; otherwise return
            the mean over prompts.
    """
    prompts = range(logits.size(0))

    # The answer is predicted at END, so pull out that position's logits first.
    end_logits = logits[prompts, ioi_dataset.word_idx["end"]]  # [batch d_vocab]

    # Correct (IO) minus incorrect (S) logit, prompt by prompt
    diffs = (
        end_logits[prompts, ioi_dataset.io_tokenIDs]
        - end_logits[prompts, ioi_dataset.s_tokenIDs]
    )  # [batch]

    if per_prompt:
        return diffs
    return diffs.mean()

In [None]:
# Corrupted "ABC" counterpart of the IOI dataset, built with the flip spec below
# (presumably every name is swapped for a fresh one so no name repeats — confirm
# against IOIDataset.gen_flipped_prompts).
abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->XYZ, BAB->XYZ")

# Clean slate: remove every hook on the model, permanent ones included.
model.reset_hooks(including_permanent=True)

ioi_logits_original, ioi_cache = model.run_with_cache(ioi_dataset.toks)
abc_logits_original, abc_cache = model.run_with_cache(abc_dataset.toks)

# Both runs are scored with logits_to_ave_logit_diff_2's default dataset, i.e. against
# the IOI prompts' IO / S tokens. The dataset-level averages are just the means of the
# per-prompt differences, so derive them from those instead of recomputing.
ioi_per_prompt_diff = logits_to_ave_logit_diff_2(ioi_logits_original, per_prompt=True)
abc_per_prompt_diff = logits_to_ave_logit_diff_2(abc_logits_original, per_prompt=True)

ioi_average_logit_diff = ioi_per_prompt_diff.mean().item()
abc_average_logit_diff = abc_per_prompt_diff.mean().item()

print(f"Average logit diff (IOI dataset): {ioi_average_logit_diff:.4f}")
print(f"Average logit diff (ABC dataset): {abc_average_logit_diff:.4f}")

# Column header -> column values, laid out side by side
table_columns = {
    "IOI prompt": map(format_prompt, ioi_dataset.sentences),
    "IOI logit diff": ioi_per_prompt_diff,
    "ABC prompt": map(format_prompt, abc_dataset.sentences),
    "ABC logit diff": abc_per_prompt_diff,
}
make_table(
    colnames=list(table_columns),
    cols=list(table_columns.values()),
    title="Sentences from IOI vs ABC distribution",
)

## Copying & writing direction results
- Start by replicating the paper's analysis of the Name Mover Heads and Negative Name Mover Heads.

- Our previous analysis should have pretty much convinced us that these heads are copying / negatively copying our indirect object token, but the results here show this with a bit more rigour.

### Exercise: replicate writing direction results

- Figure 3(c) from the paper plots the output of the strongest name mover and negative name mover heads against the attention probabilities for `END` attending to `IO` or `S` (color-coded).

- Some clarifications:
  - "Projection" here is being used synonymously with "dot product".

  - We're projecting onto the name's embedding direction, i.e. the vector for the name token to which attention is being paid. (In the code below this is the token's unembedding vector, a row of `model.W_U.T`.)
    - This is not the same as the logit diff, which we got by projecting the heads' output onto the difference between the unembedding vectors for `IO` and `S`.

  - We're doing this because the question we're trying to answer is: *"does the attention head copy (or anti-copy) the names to which it pays attention?"*

In [None]:
def scatter_embedding_vs_attn(
    attn_from_end_to_io: Float[Tensor, "batch"],
    attn_from_end_to_s: Float[Tensor, "batch"],
    projection_in_io_dir: Float[Tensor, "batch"],
    projection_in_s_dir: Float[Tensor, "batch"],
    layer: int,
    head: int,
):
    """
    Scatter-plot a head's output projection onto each name against the attention paid to that name.

    IO points and S points are drawn on the same axes and colour-coded by name type.

    Args:
        attn_from_end_to_io / attn_from_end_to_s: attention probability from END to
            the IO / S name, one entry per prompt (x-axis).
        projection_in_io_dir / projection_in_s_dir: dot product of the head's END
            output with the IO / S name direction, one entry per prompt (y-axis);
            must have the same lengths as the matching attention tensors.
        layer, head: identify the head, used only in the title.
    """
    # Derive the label counts from the inputs themselves (not a global batch size),
    # so that datasets of any size produce one colour label per plotted point.
    n_io_points = len(attn_from_end_to_io)
    n_s_points = len(attn_from_end_to_s)

    scatter(
        x=t.concat([attn_from_end_to_io, attn_from_end_to_s], dim=0),
        y=t.concat([projection_in_io_dir, projection_in_s_dir], dim=0),
        color=["IO"] * n_io_points + ["S"] * n_s_points,
        title=f"Projection of the output of {layer}.{head} along the name<br>embedding vs attention probability on name",
        title_x=0.5,
        labels={
            "x": "Attn prob on name",
            "y": "Dot w Name Embed",
            "color": "Name type",
        },
        color_discrete_sequence=["#72FF64", "#C9A5F7"],
        width=650,
    )

In [None]:
def calculate_and_show_scatter_embedding_vs_attn(
    layer: int,
    head: int,
    cache: ActivationCache = ioi_cache,
    dataset: IOIDataset = ioi_dataset,
) -> None:
    """
    Creates and plots a figure equivalent to 3(c) in the paper.

    Computes four 1D tensors, one entry per prompt:
        attn_from_end_to_io   attention prob from END to the IO name
        attn_from_end_to_s    attention prob from END to the first subject name (S1)
        projection_in_io_dir  head output at END dotted with the IO token's unembedding
        projection_in_s_dir   head output at END dotted with the S token's unembedding
    and passes them to scatter_embedding_vs_attn.

    Reads W_O and W_U from the global `model`; `cache` must come from running that
    model on `dataset`.
    """
    end_positions = dataset.word_idx["end"]

    # This head's z activations: [batch seq d_head]
    z = cache[utils.get_act_name("z", layer)][:, :, head]
    prompt_idx = t.arange(z.size(0))

    # Only END matters, so select it *before* projecting through W_O rather than
    # projecting every sequence position and discarding all but one.
    output_on_end_token = z[prompt_idx, end_positions] @ model.W_O[layer, head]  # [batch d_model]

    # Directions to project onto: the unembedding vectors of each prompt's IO / S token
    io_unembedding = model.W_U.T[dataset.io_tokenIDs]  # [batch d_model]
    s_unembedding = model.W_U.T[dataset.s_tokenIDs]  # [batch d_model]

    # Per-prompt dot products over d_model
    projection_in_io_dir = (output_on_end_token * io_unembedding).sum(-1)  # [batch]
    projection_in_s_dir = (output_on_end_token * s_unembedding).sum(-1)  # [batch]

    # Attention probabilities from END -> IO and END -> S1
    attn_probs = cache["pattern", layer][:, head]  # [batch seqQ seqK]
    attn_from_end_to_io = attn_probs[prompt_idx, end_positions, dataset.word_idx["IO"]]  # [batch]
    attn_from_end_to_s = attn_probs[prompt_idx, end_positions, dataset.word_idx["S1"]]  # [batch]

    scatter_embedding_vs_attn(
        attn_from_end_to_io,
        attn_from_end_to_s,
        projection_in_io_dir,
        projection_in_s_dir,
        layer,
        head,
    )


# Strongest name mover head (layer 9, head 9)
nmh = (9, 9)
calculate_and_show_scatter_embedding_vs_attn(*nmh)

# Strongest negative name mover head (layer 11, head 10)
nnmh = (11, 10)
calculate_and_show_scatter_embedding_vs_attn(*nnmh)

#### Interpretation

- For head `9.9`, both the `S` and `IO` tokens exhibit a positive correlation between the head's output projected onto their name direction and the attention probability on that name.
  - This suggests the head is just copying names it attends to from the name to the `END` token.
  - We can see that it is paying more attention to the `IO` token and less to `S`, which is what we expect (thanks to Q-composition with the S-inhibition heads).

- The same is true for the negative name mover head `11.10`, only it works in the opposite direction: actively suppressing the logit score for the names it attends to.
  - **Note**: it's important that we observe this negative correlation, because this shows us that the head really is anti-copying the IO token (rather than just copying the S token).

### Exercise: replicate copying score results

In [None]:
def get_copying_scores(
    model: HookedTransformer, k: int = 5, names: list = NAMES
) -> Float[Tensor, "2 layer head"]:
    """
    Gets copying scores (both positive and negative) as described in page 6 of the
    IOI paper, for every (layer, head) pair in the model.

    Each name is embedded and passed through MLP0 (treated as an extended
    embedding), then through the head's OV circuit, then ln_final and the
    unembedding. The positive score is the fraction of names whose own token is
    among the top-k logits. The negative score is the same fraction computed with
    the negated OV output, which makes it a "bottom-k" test.

    Returns a tensor of shape (2, n_layers, n_heads). Index 0 of the first dimension
    holds positive scores and index 1 holds negative scores.

    Layer 0 is omitted, because its heads sit before MLP0 (which we're claiming acts
    as an extended embedding). Its row is left at zero so that
    results[:, layer, head] still indexes by the true layer number.

    Assumes every entry of `names` is a single token — TODO confirm for custom name lists.
    """
    results = t.zeros((2, model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

    name_tokens: Int[Tensor, "batch 1"] = model.to_tokens(names, prepend_bos=False)
    name_embeddings: Float[Tensor, "batch 1 d_model"] = model.embed(name_tokens)

    # Residual stream after MLP0, treated as part of the extended name embedding
    resid_after_mlp0 = name_embeddings + model.blocks[0].mlp(
        model.blocks[0].ln2(name_embeddings)
    )

    # Start at layer 1: layer-0 heads act before MLP0 (see docstring)
    for layer in tqdm(range(1, model.cfg.n_layers), desc="Layers"):
        for head in range(model.cfg.n_heads):

            resid_after_OV_pos = (
                resid_after_mlp0 @ model.W_V[layer, head] @ model.W_O[layer, head]
            )
            # Negate *before* unembedding, not after: the unembedding can include a
            # bias term, so unembed(-x) is not simply -unembed(x). Negating the
            # product is exactly equal to multiplying by -W_V.
            resid_after_OV_neg = -resid_after_OV_pos

            # squeeze(1) drops only the length-1 position axis -> [batch d_vocab]
            logits_pos = model.unembed(model.ln_final(resid_after_OV_pos)).squeeze(1)
            logits_neg = model.unembed(model.ln_final(resid_after_OV_neg)).squeeze(1)

            # Is each name's own token among the top-k (resp. bottom-k) logits?
            topk_tokens = t.topk(logits_pos, dim=-1, k=k).indices  # [batch k]
            in_topk = (topk_tokens == name_tokens).any(-1)

            bottomk_tokens = t.topk(logits_neg, dim=-1, k=k).indices  # [batch k]
            in_bottomk = (bottomk_tokens == name_tokens).any(-1)

            results[:, layer, head] = t.stack(
                [in_topk.float().mean(), in_bottomk.float().mean()]
            )
    return results


copying_results = get_copying_scores(model)

imshow(
    copying_results,
    facet_col=0,
    facet_labels=["Positive copying scores", "Negative copying scores"],
    title="Copying scores of attention heads' OV circuits",
    width=800,
)


heads = {
    "name mover": [(9, 9), (10, 0), (9, 6)],
    "negative name mover": [(10, 7), (11, 10)],
}

# copying_results[0] holds positive and copying_results[1] negative copying scores,
# matching the order of the head types iterated below.
#
# The "Average" row is a model-wide baseline over every head in layers 1+. Layer 0 is
# excluded because its heads act before MLP0, so its entries are not meaningful
# copying scores and would distort the baseline.
for i, name in enumerate(["name mover", "negative name mover"]):
    baseline_score = copying_results[i, 1:].mean()
    make_table(
        title=f"Copying Scores ({name} heads)",
        colnames=["Head", "Score"],
        cols=[
            list(map(str, heads[name])) + ["[dark_orange bold]Average"],
            [f"{copying_results[i, layer, head]:.2%}" for (layer, head) in heads[name]]
            + [f"[dark_orange bold]{baseline_score:.2%}"],
        ],
    )

## Validation of early heads

Three different kinds of heads early in the circuit:

1. Previous token heads

2. Induction heads

3. Duplicate token heads.

We can validate all three at once using a sequence of `n` random tokens followed by the same `n` tokens repeated:

1. Prev token heads, by measuring the attention patterns with an offset of one (i.e. one below the diagonal).

2. Induction heads, by measuring the attention patterns with an offset of `n-1` (i.e. the second instance of a token paying attention to the token after its first instance).

3. Duplicate token heads, by measuring the attention patterns with an offset of `n` (i.e. a token paying attention to its previous instance).

In all three cases, if heads score close to 1 on these metrics, it's strong evidence that they are working as this type of head.

**Note**: it's a leaky abstraction to say things like "head X is an induction head", since we're only observing it on a certain distribution. For instance, it's not clear what the role of induction heads and duplicate token heads is when there are no duplicates (they could in theory do something completely different).


### Exercise: perform head validation

In [None]:
def generate_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch: int = 1
) -> Int[Tensor, "batch 2*seq_len"]:
    """
    Generates `batch` sequences of `seq_len` random tokens, each followed by the same
    `seq_len` tokens again (no start token).

    Token ids are drawn uniformly from [0, model.cfg.d_vocab) as int64.

    Returns an int64 tensor of shape (batch, 2 * seq_len) on the model's device, so it
    can be fed straight into the model.
    """
    # Draw on the CPU (keeps the CPU RNG stream), then move to wherever the model lives
    first_half = t.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=t.int64)
    return t.cat([first_half, first_half], dim=-1).to(model.cfg.device)


def get_attn_scores(
    model: HookedTransformer,
    seq_len: int,
    batch: int,
    head_type: Literal["duplicate", "prev", "induction"],
) -> Float[Tensor, "n_layers n_heads"]:
    """
    Returns attention scores for sequence of duplicated tokens, for every head.

    The model is run on `batch` sequences of `seq_len` random tokens repeated twice.
    For each head, the score is its mean attention probability over the
    (destination, source) position pairs selected by `head_type`:
        "duplicate": dest in [seq_len, 2*seq_len), src = dest - seq_len
        "prev":      dest in [1, seq_len],         src = dest - 1
        "induction": dest in [seq_len, 2*seq_len), src = dest - seq_len + 1

    Raises:
        ValueError: if `head_type` is not one of the values above.
    """
    # Pick the position pairs first, so an invalid head_type fails before the
    # (comparatively expensive) forward pass.
    if head_type == "duplicate":
        src_indices = range(seq_len)
        dest_indices = range(seq_len, 2 * seq_len)

    elif head_type == "prev":
        src_indices = range(seq_len)
        dest_indices = range(1, seq_len + 1)

    elif head_type == "induction":
        src_indices = range(1, seq_len + 1)
        dest_indices = range(seq_len, 2 * seq_len)

    else:
        raise ValueError(f"Invalid head type: '{head_type}'")

    tokens = generate_repeated_tokens(model, seq_len, batch)

    _, cache = model.run_with_cache(
        tokens, return_type=None, names_filter=lambda name: name.endswith("pattern")
    )

    dest_list, src_list = list(dest_indices), list(src_indices)
    result = t.zeros(
        (model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device, dtype=t.float32
    )
    for layer in tqdm(range(model.cfg.n_layers), desc="Layers"):
        pattern = cache["pattern", layer]  # [batch head dest src]
        # Paired advanced indexing picks (dest_list[i], src_list[i]) for every head at
        # once -> [batch head n_pairs]; average over batch and pairs for each head.
        result[layer] = pattern[:, :, dest_list, src_list].mean(dim=(0, 2))

    return result


def plot_early_head_validation_results(seq_len: int = 50, batch: int = 50):
    """
    Plot a figure in the style of the paper's Figure 18.

    There is one facet per early-head type. Each shows every head's mean attention
    probability on the corresponding diagonal of repeated random tokens.
    """
    head_types = ["duplicate", "prev", "induction"]

    per_type_scores = []
    for head_type in head_types:
        per_type_scores.append(get_attn_scores(model, seq_len, batch, head_type=head_type))

    facet_titles = []
    for head_type in head_types:
        facet_titles.append(
            f"{head_type.capitalize()} token attention prob.<br>on sequences of random tokens"
        )

    imshow(
        t.stack(per_type_scores),
        facet_col=0,
        facet_labels=facet_titles,
        labels={"x": "Head", "y": "Layer"},
        width=1300,
    )


# Clear the model's hooks before measuring attention patterns, then plot the scores.
model.reset_hooks()
plot_early_head_validation_results()

## Minimal Circuit

### Background: faithfulness, completeness, and minimality

The IOI paper's authors developed three criteria for validating their circuit explanations:

1. **Faithful**: the circuit can perform as well as the model
   - equivalent to $|F(C) - F(M)|$ being small, where $C$ is the circuit, $M$ is the model, and $F$ is the performance metric function

2. **Complete**: the circuit contains all nodes used to perform the task
   - equivalent to $|F(C \setminus K) - F(M \setminus K)|$ being small for any subset $K \subset C$, including when $K$ is the empty set, showing that completeness implies faithfulness.

   - Although completeness implies faithfulness, faithfulness does *not* imply completeness.

   - Backup name mover heads illustrate this point. They are used in the task, and without understanding the role they play you'll have an incorrect model of reality.
     - E.g., you'll think ablating the name mover heads would destroy performance, which turns out not to be true.

     - If you define a circuit that doesn't contain backup name mover heads then it will be faithful (the backup name mover heads won't be used) but not complete.

     - I.e., if $K$ is the set of name mover heads, $C \setminus K$ performs worse than $M \setminus K$, because the latter contains backup name mover heads, while the former does not.

3. **Minimal**: the circuit does not contain nodes that are irrelevant to the task
   - Non-minimal circuits may not be mechanistically understandable, which defeats the purpose of this kind of circuit analysis.

If all three criteria are met, then the circuit is considered a reliable explanation for model behaviour.

### Towards minimality

- We've analysed most components of the circuit; now we ablate everything except those core components and check that model performance is preserved.

- Very large ablation!
  - Everything except for the output of each of our key attention heads (e.g., duplicate token heads or S-inhibition heads) at a single sequence position. 
    - E.g., for the DTHs, this is the `S2` token; and for SIHs, this is the `end` token.
  
  - Given that our core circuit has 26 heads in total (out of GPT-2 small's 144), and our sequences have length around 20 on average, this means we're ablating all but $(26/144)/20 \approx 1\%$ of our attention heads' output
  
    - Note that the number of possible paths through the model is reduced by ***much*** more than this

- How to ablate?
  - Zero ablation? Some non-obvious problems:
    - Heads might be "expecting" non-zero input, and setting the input to zero is essentially an arbitrary choice which takes it off distribution.
      - You can think of this as adding a bias term to this head, which might mess up subsequent computation and lead to noisy results.
  
  - Mean ablation?
    - Set a head's output to its average output over `ioi_dataset`.

    - Problem: taking the mean over this dataset might contain relevant information for solving the IOI task.
      - E.g., the `is_duplicated` flag that gets written to S2 will be present for all sequences, so averaging won't remove this information.
    
    - Solution (for this task): ablate with the mean of the ABC dataset rather than the IOI dataset.
      - Removes the problem of averages still containing information relevant for solving the IOI task.

- Complication: the sentences have different templates, and the positions of tokens like `S` and `IO` are not consistent across these templates
  - E.g.:
    ```python
    "Then, [B] and [A] had a long argument and after that [B] said to [A]"
    "After the lunch [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]"
    ```
   
  - We avoided this problem in previous exercises by choosing a very small set of sentences, where all the important tokens had the same indices.

  - Solution: take mean over each *template* and ablate, rather than whole dataset.

### Exercise: constructing the minimal circuit



In [None]:
# (layer, head) pairs of every head in the IOI circuit, grouped by head class.
# These are the heads whose outputs are kept (not mean-ablated) in the minimal circuit.
CIRCUIT = {
    "name mover": [(9, 9), (10, 0), (9, 6)],
    "backup name mover": [
        (10, 10),
        (10, 6),
        (10, 2),
        (10, 1),
        (11, 2),
        (9, 7),
        (9, 0),
        (11, 9),
    ],
    "negative name mover": [(10, 7), (11, 10)],
    "s2 inhibition": [(7, 3), (7, 9), (8, 6), (8, 10)],
    "induction": [(5, 5), (5, 8), (5, 9), (6, 9)],
    "duplicate token": [(0, 1), (0, 10), (3, 0)],
    "previous token": [(2, 2), (4, 11)],
}

# For each head class, the IOIDataset.word_idx key naming the single sequence position
# at which that class's output is kept; every other position gets mean-ablated.
# ("S1+1" presumably denotes the token right after S1 — confirm against IOIDataset.)
SEQ_POS_TO_KEEP = {
    "name mover": "end",
    "backup name mover": "end",
    "negative name mover": "end",
    "s2 inhibition": "end",
    "induction": "S2",
    "duplicate token": "S2",
    "previous token": "S1+1",
}

In [None]:
def compute_means_by_template(
    means_dataset: IOIDataset, model: HookedTransformer
) -> Float[Tensor, "layer batch seq head_idx d_head"]:
    """
    Mean z output of every head over `means_dataset`, averaged per template.

    Prompts sharing a template (the index groups in means_dataset.groups) all receive
    that group's mean. The result therefore has one (seq, head, d_head) slice per
    prompt, stacked per layer.
    """
    cfg = model.cfg

    # Single forward pass that caches only the per-head z activations
    _, z_cache = model.run_with_cache(
        means_dataset.toks.long(),
        return_type=None,
        names_filter=lambda name: name.endswith("z"),
    )

    means = t.zeros(
        size=(cfg.n_layers, len(means_dataset), means_dataset.max_len, cfg.n_heads, cfg.d_head),
        device=cfg.device,
    )

    for layer in range(cfg.n_layers):
        layer_z = z_cache[utils.get_act_name("z", layer)]  # [batch seq head d_head]
        for group in means_dataset.groups:
            # Average over the group's prompts; broadcast back onto each of them
            means[layer, group] = layer_z[group].mean(dim=0)

    return means


def get_heads_and_posns_to_keep(
    means_dataset: IOIDataset,
    model: HookedTransformer,
    circuit: Dict[str, List[Tuple[int, int]]],
    seq_pos_to_keep: Dict[str, str],
) -> Dict[int, Bool[Tensor, "batch seq head"]]:
    """
    Returns a dictionary mapping each layer to a boolean mask of shape
    (batch, seq, n_heads). The mask is True at the z outputs which *shouldn't* be
    mean-ablated.

    Each (layer, head) listed under a head type in `circuit` is kept at exactly one
    position per prompt: prompt i keeps position
    means_dataset.word_idx[seq_pos_to_keep[head_type]][i].

    Prompts may come from different templates, so the kept position is set prompt by
    prompt rather than marking every prompt's position in every row. Heads whose
    layer is outside range(model.cfg.n_layers) are ignored. Masks are created on
    the CPU.
    """
    batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads
    prompt_idx = t.arange(batch)

    heads_and_posns_to_keep = {
        layer: t.zeros(size=(batch, seq, n_heads), dtype=t.bool)
        for layer in range(model.cfg.n_layers)
    }

    for head_type, head_list in circuit.items():
        # One kept position per prompt for this head type (moved to the CPU, where the masks live)
        positions = t.as_tensor(means_dataset.word_idx[seq_pos_to_keep[head_type]]).long().cpu()

        for layer, head in head_list:
            if layer not in heads_and_posns_to_keep:
                continue
            # Pair prompt i with its *own* position (zip-style advanced indexing)
            heads_and_posns_to_keep[layer][prompt_idx, positions, head] = True

    return heads_and_posns_to_keep


def hook_fn_mask_z(
    z: Float[Tensor, "batch seq head d_head"],
    hook: HookPoint,
    heads_and_posns_to_keep: Dict[int, Bool[Tensor, "batch seq head"]],
    means: Float[Tensor, "layer batch seq head d_head"],
) -> Float[Tensor, "batch seq head d_head"]:
    """
    Replace each head's z output with its template mean, except where the keep-mask is True.

    heads_and_posns_to_keep
        Per-layer boolean masks (True = keep the real activation), as produced by
        get_heads_and_posns_to_keep.

    means
        Per-layer mean z values of the means dataset, computed separately for each
        group of prompts sharing a template; these are the replacement values.
    """
    layer = hook.layer()
    # Move the mask next to z and append a length-1 axis so it broadcasts over d_head
    keep = heads_and_posns_to_keep[layer].to(z.device)[..., None]
    return t.where(keep, z, means[layer])


def add_mean_ablation_hook(
    model: HookedTransformer,
    means_dataset: IOIDataset,
    circuit: Dict[str, List[Tuple[int, int]]] = CIRCUIT,
    seq_pos_to_keep: Dict[str, str] = SEQ_POS_TO_KEEP,
    is_permanent: bool = True,
) -> HookedTransformer:
    """
    Reset all of the model's hooks, then hook every attention z activation so it is
    mean-ablated outside the circuit.

    On later forward passes, each head's z output is replaced by its mean over
    `means_dataset`, taken per template group. The exceptions are the heads and
    sequence positions selected by `circuit` and `seq_pos_to_keep`. The hook is
    permanent unless `is_permanent=False`. Returns the same (mutated) model object.
    """
    # Remove everything first, permanent hooks included
    model.reset_hooks(including_permanent=True)

    # Replacement values (one forward pass, run with no hooks attached)...
    template_means = compute_means_by_template(means_dataset, model)
    # ...and where *not* to replace
    keep_masks = get_heads_and_posns_to_keep(means_dataset, model, circuit, seq_pos_to_keep)

    ablation_hook = partial(
        hook_fn_mask_z, heads_and_posns_to_keep=keep_masks, means=template_means
    )
    model.add_hook(lambda name: name.endswith("z"), ablation_hook, is_permanent=is_permanent)

    return model

In [None]:
import ioi_circuit_extraction


# Compare the reference ablation (ioi_circuit_extraction) with the implementation
# defined above. Our add_mean_ablation_hook resets every hook (permanent ones
# included) before adding its own, so the second run is unaffected by the first.
# Each run is labelled so their results can be told apart. Afterwards the model
# keeps *our* ablation hooks and ioi_logits_minimal holds our run's logits.
full_model_logit_diff = logits_to_ave_logit_diff_2(ioi_logits_original)

for implementation_name, add_hook_fn in [
    ("ioi_circuit_extraction", ioi_circuit_extraction.add_mean_ablation_hook),
    ("this notebook", add_mean_ablation_hook),
]:
    model = add_hook_fn(
        model, means_dataset=abc_dataset, circuit=CIRCUIT, seq_pos_to_keep=SEQ_POS_TO_KEEP
    )

    ioi_logits_minimal = model(ioi_dataset.toks)

    print(f"[{implementation_name}]")
    print(
        f"Average logit difference (IOI dataset, using entire model): {full_model_logit_diff:.4f}"
    )
    print(
        f"Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal):.4f}"
    )