In [1]:
# NBVAL_IGNORE_OUTPUT
import os

# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
IN_GITHUB = True
try:
    import google.colab

    IN_COLAB = True
    print("Running as a Colab notebook")

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False

if not IN_GITHUB and not IN_COLAB:
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

if IN_GITHUB or IN_COLAB:
    %pip install torch
    %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git@dev
    
from transformer_lens import HookedTransformer, HookedTransformerConfig
import torch as t

device = t.device("cuda" if t.cuda.is_available() else "cpu")

zsh:1: 2.3 not found
Note: you may need to restart the kernel to use updated packages.
[0mCollecting git+https://github.com/TransformerLensOrg/TransformerLens.git@dev
  Cloning https://github.com/TransformerLensOrg/TransformerLens.git (to revision dev) to /private/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/pip-req-build-2tlfmlwg
  Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/TransformerLens.git /private/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/pip-req-build-2tlfmlwg
  Running command git checkout -b dev --track origin/dev
  Switched to a new branch 'dev'
  branch 'dev' set up to track 'origin/dev'.
  Resolved https://github.com/TransformerLensOrg/TransformerLens.git to commit a0e45ef95b5b379654ed7d6b5c014f6909f60401
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting torch!=2.0,!=2.1.0,>=1.10 (from transfo

In [2]:
# NBVAL_IGNORE_OUTPUT
# Load pretrained GPT-2 small as the reference model, on the device chosen in the setup cell.
# All three weight-processing options are switched off, presumably so the parameters stay in
# their original (unfolded, uncentered) form for comparison with a from-scratch implementation
# -- confirm against the TransformerLens `from_pretrained` docs.
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:

# [1.1] Transformer From Scratch
# 1️⃣ UNDERSTANDING INPUTS & OUTPUTS OF A TRANSFORMER

# The tokenizer vocab pairs each token string with its integer id.
# Order those (string, id) pairs by id, ascending.
sorted_vocab = sorted(
    reference_gpt2.tokenizer.vocab.items(),
    key=lambda token_and_id: token_and_id[1],
)
# The pair with the smallest id; sanity-check its (str, id) structure.
first_vocab = sorted_vocab[0]
assert isinstance(first_vocab, tuple) and isinstance(first_vocab[0], str)
first_vocab[1]

0

In [4]:
# Capitalised with no leading space: "Ralph" is split into several sub-word tokens.
# The tokenizer also prepends '<|endoftext|>' to the sequence.
reference_gpt2.to_str_tokens("Ralph")

['<|endoftext|>', 'R', 'alph']

In [5]:
# With a leading space, " Ralph" becomes a single token of its own.
reference_gpt2.to_str_tokens(" Ralph")

['<|endoftext|>', ' Ralph']

In [6]:

# Lower-case with a leading space: " ralph" has no single token and is split into pieces.
reference_gpt2.to_str_tokens(" ralph")


['<|endoftext|>', ' r', 'alph']

In [7]:
# Lower-case, no leading space: yet another split.
# Tokenization depends on both capitalisation and leading whitespace.
reference_gpt2.to_str_tokens("ralph")

['<|endoftext|>', 'ral', 'ph']

In [8]:

# Sample prompt used for the forward passes below; implicit string concatenation
# keeps the literal readable while producing exactly one string.
reference_text = (
    "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. "
    "One day I will exceed human level intelligence and take over the world!"
)
# Convert the text to a tensor of token ids and show its shape.
tokens = reference_gpt2.to_tokens(reference_text)
tokens.shape


torch.Size([1, 35])

In [9]:

# One forward pass that returns the logits plus a cache of every hooked activation.
# `device=device` is passed to run_with_cache; presumably it sets where cached activations
# are stored -- confirm in the TransformerLens run_with_cache docs.
logits, cache = reference_gpt2.run_with_cache(tokens, device=device)
logits.shape


torch.Size([1, 35, 50257])

In [10]:

# Take the single batch element, pick the highest-scoring token id at every position
# (argmax over the last dim), and decode each id back to a string. The last entry is
# the model's greedy guess for the token after the prompt.
most_likely_next_tokens = reference_gpt2.tokenizer.batch_decode(logits[0].argmax(dim=-1))
most_likely_next_tokens[-1]



' I'

In [11]:
# 2️⃣ CLEAN TRANSFORMER IMPLEMENTATION

# Collect (hook name, activation shape) pairs from the cache, split into
# layer-0 hooks and hooks that live outside the transformer blocks.
# Layer 0 is matched by its explicit "blocks.0." prefix: a bare ".0." substring
# test would also match any other ".0." path segment inside a later block's name.
layer_0_hooks = [
    (name, tuple(tensor.shape))
    for name, tensor in cache.items()
    if name.startswith("blocks.0.")
]
non_layer_hooks = [
    (name, tuple(tensor.shape))
    for name, tensor in cache.items()
    if "blocks" not in name
]

# Display the non-block hooks alphabetically by hook name.
sorted(non_layer_hooks, key=lambda x: x[0])


[('hook_embed', (1, 35, 768)),
 ('hook_pos_embed', (1, 35, 768)),
 ('ln_final.hook_normalized', (1, 35, 768)),
 ('ln_final.hook_scale', (1, 35, 1))]

In [12]:

# Display the layer-0 (hook name, shape) pairs collected in the previous cell,
# sorted alphabetically by hook name.
sorted(layer_0_hooks, key=lambda x: x[0])

[('blocks.0.attn.hook_attn_scores', (1, 12, 35, 35)),
 ('blocks.0.attn.hook_k', (1, 35, 12, 64)),
 ('blocks.0.attn.hook_pattern', (1, 12, 35, 35)),
 ('blocks.0.attn.hook_q', (1, 35, 12, 64)),
 ('blocks.0.attn.hook_v', (1, 35, 12, 64)),
 ('blocks.0.attn.hook_z', (1, 35, 12, 64)),
 ('blocks.0.hook_attn_out', (1, 35, 768)),
 ('blocks.0.hook_mlp_out', (1, 35, 768)),
 ('blocks.0.hook_resid_mid', (1, 35, 768)),
 ('blocks.0.hook_resid_post', (1, 35, 768)),
 ('blocks.0.hook_resid_pre', (1, 35, 768)),
 ('blocks.0.ln1.hook_normalized', (1, 35, 768)),
 ('blocks.0.ln1.hook_scale', (1, 35, 1)),
 ('blocks.0.ln2.hook_normalized', (1, 35, 768)),
 ('blocks.0.ln2.hook_scale', (1, 35, 1)),
 ('blocks.0.mlp.hook_post', (1, 35, 3072)),
 ('blocks.0.mlp.hook_pre', (1, 35, 3072))]

In [13]:
# NBVAL_IGNORE_OUTPUT
# [1.2] Intro to mech interp
# 2️⃣ FINDING INDUCTION HEADS

# Configuration for a small 2-layer, attention-only transformer, then the model built from it.
cfg = HookedTransformerConfig(
    # Sizes
    n_layers=2,
    n_heads=12,
    d_model=768,
    d_head=64,
    n_ctx=2048,
    d_vocab=50278,
    # Architecture variants
    attention_dir="causal",
    attn_only=True,  # defaults to False
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer",
    # Tokenizer, reproducibility and hook options
    tokenizer_name="EleutherAI/gpt-neox-20b",
    seed=398,  # presumably seeds the random weight initialisation -- confirm
    use_attn_result=True,  # presumably exposes per-head attention outputs as a hook -- confirm
)
model = HookedTransformer(cfg)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:


text = (
    "We think that powerful, significantly superhuman machine intelligence is more likely "
    "than not to be created this century. If current machine learning techniques were "
    "scaled up to this level, we think they would by default produce systems that are "
    "deceptive or manipulative, and that no solid plans are known for how to avoid this."
)

# remove_batch_dim=True drops the batch axis from the cached activations only;
# the returned logits keep their leading batch dimension.
# NOTE(review): this rebinds `logits` and `cache`, replacing the GPT-2 results
# computed in earlier cells.
logits, cache = model.run_with_cache(text, remove_batch_dim=True)

logits.shape

torch.Size([1, 62, 50278])

In [15]:
# The cache was built with remove_batch_dim=True, so the embedding activation has no
# batch axis. Presumably the remaining axes are (pos, d_model) -- confirm.
cache["embed"].ndim

2