In [1]:
#setup

try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os, sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

    CHAPTER = r"chapter1_transformers"
    CHAPTER_DIR = r"./" if CHAPTER in os.listdir() else os.getcwd().split(CHAPTER)[0]
    EXERCISES_DIR = CHAPTER_DIR + f"{CHAPTER}/exercises"
    sys.path.append(EXERCISES_DIR)

import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = (exercises_dir / "part2_intro_to_mech_interp").resolve()
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow, hist, plot_comp_scores, plot_logit_attribution, plot_loss_difference
from part1_transformer_from_scratch.solutions import get_log_probs
import part2_intro_to_mech_interp.tests as tests

# Saves computation time, since we don't need it for the contents of this notebook
t.set_grad_enabled(False)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
Collecting jaxtyping
  Downloading jaxtyping-0.2.24-py3-none-any.whl (38 kB)
Collecting typeguard<3,>=2.13.3 (from jaxtyping)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, jaxtyping
Successfully installed jaxtyping-0.2.24 typeguard-2.13.3
Collecting transformer_lens
  Downloading transformer_lens-1.10.0-py3-none-any.whl (119 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.1/119.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m14.8 MB/s[0m eta 

In [2]:
# my configured toy transformer model using Neel Nanda's TransformerLens library.
# I will use this model to test my various attention head detectors.
# These detectors will output the different types of heads, such as induction heads

cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True, # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b",
    seed=398,
    use_attn_result=True,
    normalization_type=None, # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer"
)

#Load weights
weights_dir = (section_dir / "attn_only_2L_half.pth").resolve()

if not weights_dir.exists():
    url = "https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu"
    output = str(weights_dir)
    gdown.download(url, output)

model = HookedTransformer(cfg)
pretrained_weights = t.load(weights_dir, map_location=device)
model.load_state_dict(pretrained_weights)

Downloading...
From: https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu
To: /content/chapter1_transformers/exercises/part2_intro_to_mech_interp/attn_only_2L_half.pth
100%|██████████| 184M/184M [00:01<00:00, 107MB/s]


tokenizer_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/457k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

<All keys matched successfully>

In [13]:
# this function will allow me to visualise the attention patters
# so I can visually verify attention head detectors

text = "We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this."

def visualise_attention_heads(model, text_or_tokens):
    # text is required as attention patterns are dependant on text duing inference time
    logits, cache = model.run_with_cache(text_or_tokens, remove_batch_dim=True)

    for layer in range(model.cfg.n_layers):
        attention_pattern = cache["pattern", layer, "attn"]
        print(attention_pattern.shape)
        str_tokens = model.to_str_tokens(text_or_tokens)

        print(f"Layer {layer} Head Attention Patterns:")
        display(cv.attention.attention_patterns(
            tokens=str_tokens,
            attention=attention_pattern,
            attention_head_names=[f"L{layer}H{i}" for i in range(12)],
        ))

visualise_attention_heads(model, text)

torch.Size([12, 62, 62])
Layer 0 Head Attention Patterns:


torch.Size([12, 62, 62])
Layer 1 Head Attention Patterns:


In [4]:
def current_attn_detector(cache: ActivationCache, layers) -> List[str]:
    result = []
    for i in range(layers):
        heads = cache["pattern", i, "attn"]
        for j, head in enumerate(heads):
            mean_prob = head.diagonal().mean()
            #print(f"layer: {i}, head: {j}", "mean:", mean_prob)
            if mean_prob > 0.35:
                result.append(f"Layer {i}, head {j}")
    return result


def prev_attn_detector(cache: ActivationCache, layers) -> List[str]:
    result = []
    for i in range(layers):
        heads = cache["pattern", i, "attn"]
        for j, head in enumerate(heads):
            mean_prob = head.diagonal(-1).mean()
            #print(f"layer: {i}, head: {j}", "mean:", mean_prob)
            if mean_prob > 0.3:
                result.append(f"Layer {i}, head {j}")
    return result


def first_attn_detector(cache: ActivationCache, layers) -> List[str]:
    result = []
    for i in range(layers):
        heads = cache["pattern", i, "attn"]
        for j, head in enumerate(heads):
            mean_prob = head[:,0].mean()
            #print(f"layer: {i}, head: {j}", "mean:", mean_prob)
            if mean_prob > 0.8:
                result.append(f"Layer {i}, head {j}")
    return result


def categorise_attention_heads(model, text):
    logits, cache = model.run_with_cache(text, remove_batch_dim=True)
    layers = model.cfg.n_layers
    print("Heads attending to current token  = ", " | ".join(current_attn_detector(cache, layers)))
    print("\nHeads attending to previous token = ", " | ".join(prev_attn_detector(cache, layers)))
    print("\nHeads attending to first token    = ", " | ".join(first_attn_detector(cache, layers)))

categorise_attention_heads(model, text)

Heads attending to current token  =  Layer 0, head 9 | Layer 0, head 11 | Layer 1, head 6

Heads attending to previous token =  Layer 0, head 7 | Layer 0, head 9

Heads attending to first token    =  Layer 0, head 3 | Layer 1, head 4 | Layer 1, head 10


In [14]:
# indiction heads are difficult to quantify so I will use repeated text
# in the following section to make spotting induction heads easier

def generate_repeated_text(model: HookedTransformer, text_length: int, batch: int = 1):
    bos = model.to_tokens('')
    random_tokens = t.randint(model.cfg.d_vocab,(batch, text_length),device='cuda:0')
    repeated_tokens = t.cat((bos, random_tokens, random_tokens),-1)
    repeated_text = model.to_string(repeated_tokens)
    return repeated_tokens, repeated_text


repeated_tokens, repeated_text = generate_repeated_text(model, 50)
print(repeated_text)
visualise_attention_heads(model, repeated_tokens)


['<|endoftext|> percol cere annmsg exhib highest gatheredPlay VirMINulu�585 bombTouchhammer artic ein farther Ax flowed abuse}}{{\\Mc wed corporate millilit Citizens CDCRoss origHOSTış chaptersellarcbHistoryquitAbbreviations superficial genome pres mortgage�manualenzymeOLD CNTimeterestival percol cere annmsg exhib highest gatheredPlay VirMINulu�585 bombTouchhammer artic ein farther Ax flowed abuse}}{{\\Mc wed corporate millilit Citizens CDCRoss origHOSTış chaptersellarcbHistoryquitAbbreviations superficial genome pres mortgage�manualenzymeOLD CNTimeterestival']
torch.Size([12, 101, 101])
Layer 0 Head Attention Patterns:


torch.Size([12, 101, 101])
Layer 1 Head Attention Patterns:


In [19]:
def induction_attn_detector(cache: ActivationCache, layers) -> List[str]:
    result = []
    for i in range(layers):
        heads = cache["pattern", i, "attn"]
        for j, head in enumerate(heads):
            random_seq_len = int(len(head)//2)
            mean_prob = head.diagonal(-random_seq_len+1).mean()
            # print(f"layer: {i}, head: {j}", "mean:", mean_prob)
            if mean_prob > 0.6:
                result.append(f"Layer {i}, head {j}")
    return result

def categorise_indiction_heads(model, tokens):
    logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
    layers = model.cfg.n_layers
    print("\nInduction heads   = ", " | ".join(induction_attn_detector(cache, layers)))

categorise_indiction_heads(model, repeated_tokens)


Induction heads   =  Layer 1, head 4 | Layer 1, head 10
