<a href="https://colab.research.google.com/github/adagio7/induction-heads/blob/main/induction_heads.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# Install required packages
%pip install transformer-lens
%pip install circuitsvis

Collecting transformer-lens
  Downloading transformer_lens-2.15.0-py3-none-any.whl.metadata (12 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer-lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer-lens)
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting fancy-einsum>=0.0.3 (from transformer-lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer-lens)
  Downloading jaxtyping-0.3.1-py3-none-any.whl.metadata (7.0 kB)
Collecting transformers-stream-generator<0.0.6,>=0.0.5 (from transformer-lens)
  Downloading transformers-stream-generator-0.0.5.tar.gz (13 kB)
  Preparing metadata (setup.py) ... done
Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.7.1->transformer-lens)
  Downloading dill-0.3.8-py

In [6]:
import torch
import circuitsvis as cv
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    ActivationCache,
    HookedTransformer,
)

# We are only really interested in model inference, not training
torch.set_grad_enabled(False)
device = utils.get_device()

In [7]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded pretrained model gpt2-small into HookedTransformer


In [8]:
prompt = "Tom went to the store. Tom"
tokens = model.to_tokens(prompt)

# `remove_batch_dim=True` removes the first (batch) dimension, as we have batch size = 1
# Note that it only applies to the cached activations; the returned logits keep their batch dimension
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
print(logits.shape)

torch.Size([1, 8, 50257])


### Circuit Visualizers

In [9]:
# Let's try to visualize the attention pattern of a particular layer
layer = 8 # Change me :D
attention_pattern = cache["pattern", layer, "attn"]
gpt2_str_tokens = model.to_str_tokens(prompt)

# We expect this to be [n_heads, query_pos, key_pos], i.e. one (seq_len x seq_len) attention pattern per head
print(f'{attention_pattern.shape=}')

# Note that `attention_patterns` is deprecated, and `attention_heads` should be used instead
# But to save space, we use the old version as it's more compact
print(f'Attention Pattern for {layer=}')
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

attention_pattern.shape=torch.Size([12, 8, 8])
Attention Pattern for layer=8


We use `circuitsvis` to visualize the attention pattern of the layer. To identify induction heads, we look for a head that pays noticeably higher attention to the previous instance of the same token.

From iterating through the layers, we find that Head 4.1 seems to behave like an induction head.

**N.B.**: We also note some other interesting patterns, such as Head 2.11 attending exclusively to the first token in the stream, Head 0.1 attending to the current token itself, and a number of heads (e.g. 4.11) that attend to the previous token.
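
Eyeballing the patterns layer by layer can be complemented with a simple score per head. The sketch below is an assumption of how such a score could look, not something the notebook runs: `induction_scores` and its `offset` parameter are hypothetical names. With `offset=0` it measures attention to the previous instance of the same token (the criterion used above), while `offset=1` measures attention to the token right after that previous instance, which is the more common definition of an induction head. It only relies on `pattern[head][query][key]` indexing, so it works on nested lists as well as on a `[n_heads, seq, seq]` tensor.

```python
from typing import Hashable, List, Sequence


def induction_scores(pattern, token_ids: Sequence[Hashable], offset: int = 1) -> List[float]:
    """Average, per head, the attention paid from each repeated token to the
    position `offset` steps after its most recent earlier occurrence."""
    # Collect (query_pos, key_pos) pairs for every token that has appeared before
    pairs = []
    last_seen = {}
    for query_pos, tok in enumerate(token_ids):
        if tok in last_seen:
            pairs.append((query_pos, last_seen[tok] + offset))
        last_seen[tok] = query_pos

    n_heads = len(pattern)
    if not pairs:
        return [0.0] * n_heads
    return [
        sum(float(pattern[head][q][k]) for q, k in pairs) / len(pairs)
        for head in range(n_heads)
    ]


# Toy data (made up for illustration): 2 heads, 4 tokens, the last token repeats the first
toy_tokens = [5, 7, 9, 5]
uniform_head = [[1.0, 0, 0, 0], [0.5, 0.5, 0, 0], [1 / 3, 1 / 3, 1 / 3, 0], [0.25] * 4]
induction_like_head = [[1.0, 0, 0, 0], [0, 1.0, 0, 0], [0, 0, 1.0, 0], [0, 1.0, 0, 0]]
toy_pattern = [uniform_head, induction_like_head]

scores = induction_scores(toy_pattern, toy_tokens)
print(scores)
```

In this notebook, a plausible call would be `induction_scores(attention_pattern, tokens[0].tolist())`; converting the token tensor to a list of ints matters because the function uses the tokens as dictionary keys.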

In [10]:
INDUCTION_LAYER = 4
INDUCTION_HEAD = 1

### Activation Patching

Now that we have a hypothesized induction head, how can we be sure that this particular head is *necessary* for induction? Could it instead be computing some intermediate attention that is only used downstream? Essentially, we want to establish that this particular head is necessary and sufficient for induction to occur.

We answer this via *activation patching*. We run the model on two prompts, one where induction occurs and one where it doesn't (as a control, we keep the prefix of the prompt the same and only change the final token), and then intervene on the activations of one run using those of the other. This technique is inspired by causal ablation studies, where we isolate the sufficient causes via interventions.
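
To make the idea concrete before applying it to GPT-2, here is a minimal toy sketch of activation patching. Everything in it is a made-up stand-in: the "model" is just a sum of two hypothetical "heads" (`h0`, `h1`), and `run` is an illustrative helper, not a TransformerLens function. The point is only the mechanics: cache the components on a clean input, then rerun on a corrupt input with one component overwritten by its clean value.

```python
from typing import Callable, Dict, Optional, Tuple

Head = Callable[[float], float]


def run(
    heads: Dict[str, Head],
    x: float,
    patch: Optional[Dict[str, float]] = None,
) -> Tuple[float, Dict[str, float]]:
    """Run every head on x, overriding any head named in `patch`,
    and return the summed output together with the per-head cache."""
    patch = patch or {}
    cache = {}
    for name, head in heads.items():
        # A patched head ignores the current input and reuses the supplied value
        cache[name] = patch[name] if name in patch else head(x)
    return sum(cache.values()), cache


# Toy "model": two heads whose outputs are summed (purely illustrative)
heads = {"h0": lambda x: 2.0 * x, "h1": lambda x: x * x}

clean_out, clean_toy_cache = run(heads, 3.0)   # clean input
corrupt_out, _ = run(heads, 1.0)               # corrupt input
# Corrupt run, but with h1's output replaced by its value from the clean run
patched_out, _ = run(heads, 1.0, patch={"h1": clean_toy_cache["h1"]})

print(clean_out, corrupt_out, patched_out)
```

If patching a single component moves the corrupt output most of the way back to the clean output, that component carries much of the behaviour that distinguishes the two inputs.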

In [11]:
# We introduce two prompts: a clean one where induction applies (the name repeats) and a corrupt one where it does not
clean_prompt = "Tom went to the store. Tom"
corrupt_prompt = "Tom went to the store. Sarah"

clean_tokens = model.to_tokens(clean_prompt)
corrupt_tokens = model.to_tokens(corrupt_prompt)

To do the activation patching, we first do a forward pass on the clean prompt (the positive case) and on the corrupt prompt, caching the activations of each.

We then take the clean output (`z`) of our hypothesized head and patch it into the corrupt pass.

Using this, we can then compute a metric to quantify the causal effect of the counterfactual.
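
One common way to express such a metric is to normalize the patched value against the clean and corrupt baselines. The sketch below is an assumption about how this could be done here, not the notebook's own metric: `normalized_patching_effect` is a hypothetical helper, and the numbers in the example are placeholders rather than measured logits.

```python
def normalized_patching_effect(clean: float, corrupt: float, patched: float) -> float:
    """Fraction of the clean-vs-corrupt gap recovered by the patch.

    0.0 means the patch changed nothing relative to the corrupt run;
    1.0 means the patch fully restored the clean run's value.
    """
    gap = clean - corrupt
    if gap == 0:
        raise ValueError("clean and corrupt metrics are identical; nothing to recover")
    return (patched - corrupt) / gap


# Placeholder values for illustration only (e.g. the logit of some target token)
effect = normalized_patching_effect(clean=10.0, corrupt=2.0, patched=6.0)
print(effect)
```

The inputs can be any scalar summary of the final-position logits, for example the logit of the token that induction would predict, measured in the clean, corrupt, and patched runs.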

In [12]:
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupt_cache = model.run_with_cache(corrupt_tokens)

In [15]:
def patch_head_output(corrupt_z: torch.Tensor, hook: HookPoint) -> torch.Tensor:
  """
  Overwrites the output of INDUCTION_HEAD at INDUCTION_LAYER in the corrupt run
  with the corresponding activation from clean_cache
  """
  # The hook receives the corrupt run's `z` activation: [batch, pos, head_index, d_head]
  # Look up the matching clean activation for the same layer
  clean_z = clean_cache[utils.get_act_name('z', INDUCTION_LAYER)]
  patched_z = corrupt_z.clone()

  # Patch only the hypothesized head, at every position
  patched_z[:, :, INDUCTION_HEAD, :] = clean_z[:, :, INDUCTION_HEAD, :]

  return patched_z


In [16]:
patched_logits = model.run_with_hooks(
    corrupt_tokens,
    fwd_hooks=[(
        utils.get_act_name("z", INDUCTION_LAYER),
        patch_head_output
    )]
)

In [20]:
logits = patched_logits[0, -1]

print(logits[model.to_single_token(" Sarah")])

tensor(7.0067)
