<a href="https://colab.research.google.com/github/adagio7/induction-heads/blob/main/induction_heads.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install required packages
%pip install transformer-lens
%pip install circuitsvis



In [2]:
import torch
import circuitsvis as cv
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    ActivationCache,
    HookedTransformer,
)

# We are only really interested in model inference, not training
torch.set_grad_enabled(False)
device = utils.get_device()

In [3]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
prompt = "Tom went to the store. Tom"
tokens = model.to_tokens(prompt)

# `remove_batch_dim=True` strips the leading batch dimension (batch size = 1) from the cached activations only
# The returned logits keep their batch dimension, hence the [1, 8, 50257] shape printed below
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
print(logits.shape)

torch.Size([1, 8, 50257])


### Circuit Visualizers

In [59]:
# Let's try to visualize the attention pattern of a particular layer
layer = 5 # Change me :D
attention_pattern = cache["pattern", layer, "attn"]
gpt2_str_tokens = model.to_str_tokens(prompt)

# We expect this to be [n_heads, len(prompt_tokens), len(prompt_tokens)], where the latter two dimensions are the query (destination) and key (source) positions
print(f'{attention_pattern.shape=}')

# Note that `attention_patterns` is deprecated, and `attention_heads` should be used instead
# To save space, we use the old version here as it is more compact
print(f'Attention Pattern for {layer=}')
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

attention_pattern.shape=torch.Size([12, 8, 8])
Attention Pattern for layer=5


We use `circuitsvis` to visualize the attention pattern of the layer. To identify the induction heads, we look for heads where a repeated token attends strongly back to the earlier occurrence of that token (strictly speaking, an induction head attends to the token that came right after the earlier occurrence).

From iterating through the layers, we find that Heads 5.5 and 5.8 seem to correspond to induction heads.

**N.B.**: We also note some other pretty interesting patterns, such as Head 2.11 exclusively attending to the first token in the stream, Head 0.1 attending to the current token itself, and a bunch of heads (e.g. 4.11) that attend to the previous token.
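
The visual search described above can be complemented by a rough numeric score. The sketch below is a hypothetical helper (`repeat_attention_score` is not part of `transformer_lens` or `circuitsvis`) that works on a plain nested list, such as one head's pattern obtained via `attention_pattern[head].tolist()`. It assumes token strings are compared exactly, so leading spaces matter, and the toy numbers are made up for illustration.

```python
def repeat_attention_score(pattern, tokens, offset=1):
    """Average attention from each repeated token to (previous occurrence + offset).

    pattern: square list of lists, pattern[q][k] = attention from query q to key k.
    tokens:  list of token strings, same length as pattern.
    offset:  0 scores attention to the earlier copy itself,
             1 scores attention to the token that followed it.
    """
    scores = []
    last_seen = {}  # token -> most recent position where it appeared
    for q, tok in enumerate(tokens):
        if tok in last_seen:
            k = last_seen[tok] + offset
            if k <= q:  # causal attention: keys never lie after the query
                scores.append(pattern[q][k])
        last_seen[tok] = q
    return sum(scores) / len(scores) if scores else 0.0


# Toy 4-token example (made-up numbers): "A B C A"
toy_tokens = ["A", "B", "C", "A"]
toy_pattern = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.2, 0.3, 0.5, 0.0],
    [0.1, 0.8, 0.05, 0.05],  # the second "A" looks mostly at "B"
]
print(repeat_attention_score(toy_pattern, toy_tokens, offset=1))
print(repeat_attention_score(toy_pattern, toy_tokens, offset=0))
```

A head that scores high with `offset=1` behaves like an induction head on that prompt, while a high score with `offset=0` indicates a head that simply looks at the earlier copy of the token.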

In [30]:
INDUCTION_LAYER = 5
INDUCTION_HEAD = 8

### Activation Patching

Now that we have our hypothesized induction heads, how can we be sure that this particular head is *necessary* for induction? Could it be that it's calculating some indirect attention that is only used downstream? Essentially, we want to test whether this particular head is necessary and sufficient for induction to occur.

We answer this via *activation patching*: we run the model on two prompts, one where induction does occur and another where it doesn't (as a control, we keep the prefix of the prompt the same and change only the final token). This technique is inspired by causal ablation studies, where we isolate the sufficient causes via interventions.

In [31]:
# We introduce two prompts: a positive case that shows the inductive behaviour and a negative case that does not
clean_prompt = "Tom went to the store. Tom"
corrupt_prompt = "Tom went to the store. Sarah"

clean_tokens = model.to_tokens(clean_prompt)
corrupt_tokens = model.to_tokens(corrupt_prompt)

To do the activation patching, we first do a forward pass on the positive (clean) case and cache its activations.

We then take the output (`z`) of our hypothesized head from that clean run and patch it into the forward pass on the corrupt prompt.
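
The two steps above can be illustrated on a toy stand-in before running them on GPT-2. This is only a sketch: `run_toy_model` is a made-up two-"head" function, not a transformer, and its `patch` argument plays the role that a forward hook plays in the real model.

```python
def run_toy_model(x, patch=None):
    """Toy 'model': two heads whose outputs are summed into one logit.

    patch: optional (head_index, value) that overwrites one head's output,
           playing the role of the forward hook.
    """
    head_outputs = [2.0 * x, -1.0 * x + 1.0]  # head 0, head 1
    cache = list(head_outputs)                # unpatched outputs, as a cache would store them
    if patch is not None:
        head_index, value = patch
        head_outputs[head_index] = value
    return sum(head_outputs), cache


clean_out, clean_cache_toy = run_toy_model(3.0)
corrupt_out, _ = run_toy_model(-1.0)

# Re-run the corrupt input, but splice in head 0's clean output.
patched_out, _ = run_toy_model(-1.0, patch=(0, clean_cache_toy[0]))

print(clean_out, corrupt_out, patched_out)
```

In this toy the patched output (8.0) overshoots the clean one (4.0), because head 1 still sees the corrupt input. A patched run does not have to land between the corrupt and clean results.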


In [32]:
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupt_logits, corrupt_cache = model.run_with_cache(corrupt_tokens)

In [33]:
def patch_head_output(corrupt_z: torch.Tensor, hook: HookPoint):
  """
  Overwrites the output of INDUCTION_HEAD at INDUCTION_LAYER with its value from clean_cache

  `corrupt_z` is the activation at the hooked point, shape [batch, pos, n_heads, d_head]
  """
  # Look up the same activation in the clean run
  clean_z = clean_cache[utils.get_act_name('z', INDUCTION_LAYER)]

  # Clone so that we do not modify the hooked activation in place
  patched_z = corrupt_z.clone()

  # Patch the head
  patched_z[:, :, INDUCTION_HEAD, :] = clean_z[:, :, INDUCTION_HEAD, :]

  return patched_z


In [34]:
# Run the corrupt prompt but with the activation patch
patched_logits = model.run_with_hooks(
    corrupt_tokens,
    fwd_hooks=[(
        utils.get_act_name("z", INDUCTION_LAYER),
        patch_head_output
    )]
)

Using this, we can then pick a metric to quantify the causal effect of the counterfactual.

In [35]:
# We are only interested in the final generated token
target_pos = -1

In [52]:
# Extract the prediction logits for the clean, corrupt and patched baseline
# Do this by extracting logits from [batch, n_ctx, n_vocab]
clean_last_logits = clean_logits[0, target_pos, :]
corrupt_last_logits = corrupt_logits[0, target_pos, :]
patched_last_logits = patched_logits[0, target_pos, :]

# Logits should only be for the last position, so the shape should be (vocab_size,)
assert clean_last_logits.shape == (50257,)
assert corrupt_last_logits.shape == (50257,)
assert patched_last_logits.shape == (50257,)

In [53]:
def get_logit_diff(logits, correct: int, distractor: int):
  """
  Returns the logit difference between the `correct` index and the `distractor` index
  """
  return logits[correct] - logits[distractor]

def get_argmax_token(logits):
  """
  Returns the string corresponding to the max logit
  """
  return model.to_string(torch.argmax(logits))

In [54]:
# Note that the GPT-2 tokenizer is space sensitive, so we include the leading space
correct_token = model.to_single_token(" went")
wrong_token = model.to_single_token(" was")

In [55]:
logit_diff_clean = get_logit_diff(clean_last_logits, correct_token, wrong_token)
logit_diff_corrupt = get_logit_diff(corrupt_last_logits, correct_token, wrong_token)
logit_diff_patched = get_logit_diff(patched_last_logits, correct_token, wrong_token)

print("Logit diff (clean):", logit_diff_clean.item())
print("Logit diff (corrupt):", logit_diff_corrupt.item())
print("Logit diff (patched):", logit_diff_patched.item())

Logit diff (clean): 0.07222270965576172
Logit diff (corrupt): -0.44322967529296875
Logit diff (patched): 0.2803611755371094
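
One way to read these three numbers together is to express the patched result as a fraction of the gap between the corrupt and clean runs. The helper below is a sketch under that assumption. The name `normalized_patching_effect` is introduced here for illustration and is not a `transformer_lens` function.

```python
def normalized_patching_effect(clean, corrupt, patched):
    """Fraction of the clean-vs-corrupt gap recovered by the patch.

    0.0 means the patch changed nothing; 1.0 means it fully restored
    the clean behaviour; values above 1.0 mean it overshot.
    """
    gap = clean - corrupt
    if gap == 0:
        raise ValueError("clean and corrupt metrics are identical")
    return (patched - corrupt) / gap


# Logit differences printed above
effect = normalized_patching_effect(
    clean=0.07222270965576172,
    corrupt=-0.44322967529296875,
    patched=0.2803611755371094,
)
print(f"{effect:.2f}")
```

For the values above this comes out at roughly 1.4, meaning the patch more than closes the gap between the corrupt and clean runs.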


In [56]:
print("Most likely token (clean):", get_argmax_token(clean_last_logits))
print("Most likely token (corrupt):", get_argmax_token(corrupt_last_logits))
print("Most likely token (patched):", get_argmax_token(patched_last_logits))

Most likely token (clean):  went
Most likely token (corrupt):  was
Most likely token (patched):  went


The activation patch succeeds: patching a single head moves the logit difference from -0.44 (corrupt) to 0.28 (patched) and flips the most likely token from " was" back to " went".

However, the logit difference for the clean prompt is oddly small (0.07), smaller even than the patched one, which might suggest that the logit distribution is more uniform than expected. Even so, " went" remains the most likely token for the clean prompt.
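
To see how weak a preference a small logit difference expresses, it helps to convert it into a probability ratio. Under softmax the shared normaliser cancels, so p(correct) / p(distractor) = exp(logit difference). The sketch below applies this to the three logit differences reported earlier. `probability_ratio` is a name introduced here for illustration.

```python
import math


def probability_ratio(logit_diff):
    """Ratio p(correct) / p(distractor) implied by a logit difference.

    Under softmax, the shared normaliser cancels, leaving exp(logit_diff).
    """
    return math.exp(logit_diff)


for name, diff in [
    ("clean", 0.07222270965576172),
    ("corrupt", -0.44322967529296875),
    ("patched", 0.2803611755371094),
]:
    print(f"{name}: {probability_ratio(diff):.3f}")
```

A ratio close to 1 means the model barely prefers " went" over " was". This compares only those two tokens and says nothing about the rest of the vocabulary.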