In [1]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
import os

os.environ['HF_HOME'] = 'E:/Models/hf_cache'
os.environ['HUGGINGFACE_HUB_CACHE'] = 'E:/Models/hf_cache'

## Introduction

In [3]:
import functools
import sys
from pathlib import Path
from typing import Callable

import warnings
warnings.filterwarnings('ignore')

import circuitsvis as cv
import einops
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import eindex
from IPython.display import display
from jaxtyping import Float, Int
from torch import Tensor
from tqdm import tqdm
from transformer_lens import (
    ActivationCache,
    FactoredMatrix,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)
from transformer_lens.hook_points import HookPoint

device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)

t.cuda.empty_cache()

In [4]:
import bitsandbytes
import accelerate

In [4]:
model = HookedTransformer.from_pretrained(
        "google/gemma-2b-it",
        torch_dtype=t.bfloat16,
        device="cuda",
        fold_ln=False,
        fold_value_biases=False,
        center_writing_weights=False,
        center_unembed=False,
    )
model.eval()
print("Model loaded successfully in 4-bit.")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2b-it into HookedTransformer
Model loaded successfully in 4-bit.


In [5]:
config = model.cfg
print(f"Number of layers = {config.n_layers}\nNumber of heads per layer = {config.n_heads}\nMaximum context window = {config.n_ctx}")

Number of layers = 18
Number of heads per layer = 8
Maximum context window = 8192


In [6]:
gemini2b : HookedTransformer = model

In [7]:
text = """HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly.

For this demo notebook we'll look at Gemini-2b-it, a 2B parameter instruction-tuned model. To try the model the model out, let's find the loss on this paragraph!"""

loss = gemini2b(text, return_type = "loss")
print("Model loss:", loss)

Model loss: tensor(5.6875, device='cuda:0', dtype=torch.bfloat16, grad_fn=<DivBackward0>)


In [8]:
print(gemini2b.to_str_tokens("gemma2"))
print(gemini2b.to_str_tokens(["gemma2", "gemma"]))
print(gemini2b.to_tokens("gemma2"))
print(gemini2b.to_string([2, 20237, 534, 235284]))

['<bos>', 'gem', 'ma', '2']
[['<bos>', 'gem', 'ma', '2'], ['<bos>', 'gem', 'ma']]
tensor([[     2,  20237,    534, 235284]], device='cuda:0')
<bos>gemma2


In [9]:
logits: Tensor = gemini2b(text, return_type = "logits")
prediction = logits.argmax(dim=-1).squeeze()[:-1]

true_tokens = gemini2b.to_tokens(text).squeeze()[1:]
is_correct = prediction == true_tokens

print(f"Model Accuracy = {is_correct.sum()}/{len(true_tokens)}")
print(f"Correct tokens = {gemini2b.to_str_tokens(prediction[is_correct])}")

Model Accuracy = 33/107
Correct tokens = [' with', '0', ' models', '.', ' can', 'ed', 'Transformer', '.', 'from', '_', 'pretrained', '(', '_', 'NAME', ')', '`.', ' model', 'ed', 'Transformer', ' to', '-', '.', '\n\n', 'll', ' at', ',', ' a', 'B', ' model', ',', "'", 's', ' the']


In [10]:
print(gemini2b.to_str_tokens("HookedTransformer", prepend_bos=False))

['Hook', 'ed', 'Transformer']


In [11]:
problem_text = "John found that the average of 15 numbers is 40. If 10 is added to each number then the mean of the numbers is? A) 50 B) 45 C) 65 D) 78 E) 64\nAnswer:"

tokens = gemini2b.to_tokens(problem_text)

g_logits, g_cache = gemini2b.run_with_cache(tokens, remove_batch_dim=True)

In [12]:
print(type(g_logits), type(g_cache))

<class 'torch.Tensor'> <class 'transformer_lens.ActivationCache.ActivationCache'>


In [13]:
attn_patterns_from_shorthand = g_cache["pattern", 0]
attn_patterns_from_full_name = g_cache["blocks.0.attn.hook_pattern"]

t.testing.assert_close(attn_patterns_from_shorthand, attn_patterns_from_full_name)

In [14]:
layer_0_pattern_from_cache = attn_patterns_from_shorthand

q, k = g_cache["q", 0], g_cache["k", 0]

seq, nhead, headsize = q.shape

layer0_attn_scores = einops.einsum(q, k, "seqQ n h, seqK n h -> n seqQ seqK") # dot_product(Q , K) 

mask = t.triu(t.ones((seq, seq), dtype=t.bool), diagonal=1).to(device) # masking

layer0_attn_scores.masked_fill_(mask, -1e9) # scaling

layer0_patter_from_q_and_k = (layer0_attn_scores / headsize**0.5).softmax(-1)

In [15]:
for layer in range(config.n_layers):
    print(type(g_cache))
    attention_pattern = g_cache["pattern", layer]
    print(attention_pattern.shape)
    g_str_tokens = gemini2b.to_str_tokens(problem_text)

    print(f"Layer {layer} Head Attention Patterns:")
    display(
        cv.attention.attention_heads(
            tokens=g_str_tokens,
            attention=attention_pattern,
            attention_head_names=[f"L{layer}H{i}" for i in range(12)],
        )
    )

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 0 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 1 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 2 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 3 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 4 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 5 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 6 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 7 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 8 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 9 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 10 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 11 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 12 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 13 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 14 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 15 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 16 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 61, 61])
Layer 17 Head Attention Patterns:


## Induction heads on 2-layer attention-only toy model

In [17]:
cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True,  # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b",
    seed=398,
    use_attn_result=True,
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer",
)

In [18]:
from huggingface_hub import hf_hub_download

REPO_ID = "callummcdougall/attn_only_2L_half"
FILENAME = "attn_only_2L_half.pth"

weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


attn_only_2L_half.pth:   0%|          | 0.00/184M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [19]:
toy_model = HookedTransformer(cfg)
pretrained_weights = t.load(weights_path, map_location=device, weights_only=True)
toy_model.load_state_dict(pretrained_weights)

tokenizer_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

<All keys matched successfully>

In [20]:
toy_text = "We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this."

logits, cache = toy_model.run_with_cache(toy_text, remove_batch_dim=True)

In [21]:
for layer in range(toy_model.cfg.n_layers):
    print(type(cache))
    attention_pattern = cache["pattern", layer]
    print(attention_pattern.shape)
    t_str_tokens = toy_model.to_str_tokens(toy_text)

    print(f"Layer {layer} Head Attention Patterns:")
    display(
        cv.attention.attention_heads(
            tokens=t_str_tokens,
            attention=attention_pattern,
            attention_head_names=[f"L{layer}H{i}" for i in range(12)],
        )
    )

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 62, 62])
Layer 0 Head Attention Patterns:


<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([12, 62, 62])
Layer 1 Head Attention Patterns:


### Summary for gemma-2b

In [35]:
def current_attn_detector(cache: ActivationCache) -> list[str]:
    """
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be current-token heads
    """
    attn_heads = []
    for layer in range(gemini2b.cfg.n_layers):
        for head in range(gemini2b.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of diagonal elements
            score = attention_pattern.diagonal().mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads


def prev_attn_detector(cache: ActivationCache) -> list[str]:
    """
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be prev-token heads
    """
    attn_heads = []
    for layer in range(gemini2b.cfg.n_layers):
        for head in range(gemini2b.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of sub-diagonal elements
            score = attention_pattern.diagonal(-1).mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads


def first_attn_detector(cache: ActivationCache) -> list[str]:
    """
    Returns a list e.g. ["0.2", "1.4", "1.9"] of "layer.head" which you judge to be first-token heads
    """
    attn_heads = []
    for layer in range(gemini2b.cfg.n_layers):
        for head in range(gemini2b.cfg.n_heads):
            attention_pattern = cache["pattern", layer][head]
            # take avg of 0th elements
            score = attention_pattern[:, 0].mean()
            if score > 0.4:
                attn_heads.append(f"{layer}.{head}")
    return attn_heads

print("Heads attending to current token  = ", ", ".join(current_attn_detector(g_cache)))
print("Heads attending to previous token = ", ", ".join(prev_attn_detector(g_cache)))
print("Heads attending to first token    = ", ", ".join(first_attn_detector(g_cache)))

Heads attending to current token  =  0.5, 1.2, 2.2, 3.6, 4.3, 5.0, 6.3, 7.5, 8.1, 9.4, 11.6, 12.0, 13.3, 14.2, 17.7
Heads attending to previous token =  0.1, 0.7, 1.5, 10.4, 13.6
Heads attending to first token    =  0.2, 1.0, 1.1, 1.6, 1.7, 2.0, 2.1, 2.3, 2.4, 3.4, 3.5, 3.7, 4.0, 4.2, 4.7, 5.1, 5.4, 5.5, 5.6, 5.7, 6.0, 6.1, 6.5, 6.6, 6.7, 7.1, 7.2, 7.3, 8.0, 8.2, 8.4, 8.5, 8.6, 9.2, 10.0, 10.1, 10.2, 10.3, 10.5, 10.6, 10.7, 11.1, 11.2, 11.3, 11.4, 11.5, 12.1, 12.2, 12.3, 12.4, 12.5, 12.6, 12.7, 13.2, 13.5, 14.0, 14.1, 14.3, 14.4, 14.5, 14.6, 14.7, 15.0, 15.1, 15.2, 15.3, 15.4, 15.5, 15.6, 15.7, 16.0, 16.2, 16.3, 16.5, 16.7, 17.6
