<a href="https://colab.research.google.com/github/adagio7/induction-heads/blob/main/induction_heads.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required packages
%pip install transformer-lens
%pip install circuitsvis

In [3]:
import torch
import circuitsvis as cv
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer

# We are only really interested in model inference, not training
torch.set_grad_enabled(False)
device = utils.get_device()

In [None]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

In [34]:
prompt = "Tom went to the store. Tom"
tokens = model.to_tokens(prompt)

# `remove_batch_dim=True` drops the leading batch dimension (we have batch size = 1),
# but only from the activations stored in `cache`, not from the returned `logits`.
# That is why `logits` below still has shape [batch, position, d_vocab].
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
print(logits.shape)

torch.Size([1, 8, 50257])


#### Circuit Visualizers

In [33]:
# Let's try to visualize the attention pattern of a particular layer
layer = 8 # Change me :D
attention_pattern = cache["pattern", layer, "attn"]
gpt2_str_tokens = model.to_str_tokens(prompt)

# We expect this to be [n_heads, query_pos, key_pos]: row q of each head gives how much the query (destination) position q attends to every key (source) position
print(f'{attention_pattern.shape=}')

# Note that `attention_patterns` is deprecated and `attention_heads` should be used instead,
# but we use the old version here because its output is more compact
print(f'Attention Pattern for {layer=}')
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

attention_pattern.shape=torch.Size([12, 8, 8])
Attention Pattern for layer=8
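The printed shape `[12, 8, 8]` is `[n_heads, query_pos, key_pos]`. Each row of a head's pattern is a softmax over the key positions that query is allowed to see. Because GPT-2 is causal, a query only sees itself and earlier positions, so every row sums to 1 and everything above the diagonal is 0. The pure-Python sketch below reproduces that masking and normalisation for a single head. The scores are made up for illustration (in the model they come from scaled query-key dot products), and `causal_attention_pattern` is a hypothetical helper, not a TransformerLens function.

```python
import math


def causal_attention_pattern(scores):
    """Turn a square scores[query][key] matrix into a causal attention pattern.

    Keys to the right of the query (future positions) are masked out, then each
    row is softmax-normalised over the visible keys so it sums to 1.
    """
    n = len(scores)
    pattern = []
    for q in range(n):
        # Causal mask: query q may only attend to keys 0..q.
        visible = scores[q][: q + 1]
        m = max(visible)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        # Masked (future) positions get exactly zero attention.
        row = [e / total for e in exps] + [0.0] * (n - q - 1)
        pattern.append(row)
    return pattern


# Toy 3-position example with made-up scores (not taken from GPT-2).
toy_scores = [
    [0.0, 5.0, 5.0],
    [1.0, 1.0, 9.0],
    [0.0, 0.0, 0.0],
]
pattern = causal_attention_pattern(toy_scores)
for row in pattern:
    print([round(p, 3) for p in row])
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.333, 0.333, 0.333]
```

Note how the large scores above the diagonal (5.0 and 9.0) have no effect: the mask removes them before the softmax, which is why the first query can only attend to itself.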


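We use circuitsvis to visualize the attention pattern of each layer. To identify induction heads, we look for a head that, at the second `Tom`, attends to the token that followed the first `Tom` (here `went`): an induction head finds the earlier occurrence of the current token and attends to whatever came right after it, which lets the model predict that the sequence repeats. A head that instead attends to the earlier `Tom` itself behaves more like a duplicate-token head.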
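Eyeballing the plots can be backed up with a number. A minimal sketch, assuming the standard definition of an induction head (from a repeated token, attend to the position right after that token's earlier occurrence), is to average the attention each repeated query puts on those positions. `induction_score` and the toy tokens and pattern are hypothetical; matching is by exact string, so a variant of a token with a leading space would not count as a repeat.

```python
def induction_score(pattern, tokens):
    """Mean attention from repeated tokens to the token after their earlier occurrence.

    pattern: one head's attention pattern as a nested list, pattern[query][key].
    tokens:  the string tokens for the same positions.
    Returns None when no token in the sequence repeats.
    """
    per_query = []
    for q, tok in enumerate(tokens):
        # Earlier positions j holding the same token; the induction target is j + 1,
        # which is always <= q, so it is visible under the causal mask.
        targets = [j + 1 for j in range(q) if tokens[j] == tok]
        if targets:
            per_query.append(sum(pattern[q][t] for t in targets))
    if not per_query:
        return None
    return sum(per_query) / len(per_query)


# Hypothetical toy example: the second "A" should look at "B", the token after the first "A".
toy_tokens = ["<bos>", "A", "B", "C", "A"]
toy_pattern = [
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0, 0.0],
    [0.2, 0.4, 0.4, 0.0, 0.0],
    [0.1, 0.2, 0.3, 0.4, 0.0],
    [0.1, 0.1, 0.7, 0.05, 0.05],  # second "A" attends mostly to "B"
]
print(induction_score(toy_pattern, toy_tokens))  # 0.7
```

With the notebook's cache (built with `remove_batch_dim=True`), one head's pattern could be passed as `cache["pattern", layer][head].tolist()` together with `gpt2_str_tokens`, looping over `range(model.cfg.n_layers)` and `range(model.cfg.n_heads)` to rank every head by this score instead of inspecting each layer by hand.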

From iterating through the layers, we find that Head 4.1 seems to behave like an induction head.

**N.B.**: We also note some other interesting patterns, such as Head 2.11 attending exclusively to the first token in the stream, Head 0.1 attending to the current token itself, and several heads (e.g. 4.11) attending to the previous token.
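These positional head types can also be spotted numerically. A minimal sketch, assuming a single head's pattern as a nested list (for example `cache["pattern", layer][head].tolist()` from the visualization cell): average the attention each query puts on position 0, on itself, and on the position just before it. `positional_scores` is a hypothetical helper, not part of TransformerLens or circuitsvis, and any threshold for calling a head "first-token" or "previous-token" is a judgment call.

```python
def positional_scores(pattern):
    """Average attention mass on three simple targets, over all query positions.

    pattern: one head's attention pattern as a nested list, pattern[query][key],
    with at least 2 positions.
    - "first":    attention to position 0 (the first token in the stream)
    - "self":     attention to the query's own position (the diagonal)
    - "previous": attention to the immediately preceding position (sub-diagonal)
    """
    n = len(pattern)
    first = sum(pattern[q][0] for q in range(n)) / n
    self_attn = sum(pattern[q][q] for q in range(n)) / n
    # Position 0 has no previous token, so average over queries 1..n-1 only.
    previous = sum(pattern[q][q - 1] for q in range(1, n)) / (n - 1)
    return {"first": first, "self": self_attn, "previous": previous}


# Hypothetical idealised heads over 4 positions.
prev_token_head = [
    [1.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
]
first_token_head = [[1.0, 0.0, 0.0, 0.0] for _ in range(4)]

print(positional_scores(prev_token_head))   # {'first': 0.5, 'self': 0.25, 'previous': 1.0}
print(positional_scores(first_token_head))
```

The categories overlap at the start of the sequence (for query 0, "first" and "self" are the same key; for query 1, "previous" is position 0), so short prompts like this one can inflate the scores. Comparing the three numbers against each other is more informative than reading any one in isolation.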