<a href="https://colab.research.google.com/github/alif-munim/mech-interp/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os, sys
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"

if IN_COLAB:
    # Install packages
    %pip install transformer_lens
    %pip install einops
    %pip install jaxtyping
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    if not os.path.exists(f"/content/{chapter}"):
        !wget https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/main.zip
        !unzip /content/main.zip 'ARENA_3.0-main/chapter1_transformer_interp/exercises/*'
        sys.path.append(f"/content/{repo}-main/{chapter}/exercises")
        os.remove("/content/main.zip")
        os.rename(f"{repo}-main/{chapter}", chapter)
        os.rmdir(f"{repo}-main")
        os.chdir(f"{chapter}/exercises")
else:
    chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
    sys.path.append(chapter_dir + f"{chapter}/exercises")

Collecting transformer_lens
  Downloading transformer_lens-2.4.0-py3-none-any.whl.metadata (12 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting fancy-einsum>=0.0.3 (from transformer_lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer_lens)
  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)
Collecting wandb>=0.13.5 (from transformer_lens)
  Downloading wandb-0.17.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting pyarrow>=15.0.0 (from datasets>=2.7.1->transformer_lens)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata 

In [2]:
import os; os.environ['ACCELERATE_DISABLE_RICH'] = "1"
import sys
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict, Callable
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
from rich.table import Table
from rich import print as rprint
import datasets
from torch.utils.data import DataLoader
import wandb
from pathlib import Path
import webbrowser

# Make sure exercises are in the path
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = (exercises_dir / "part1_transformer_from_scratch").resolve()
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
import part1_transformer_from_scratch.solutions as solutions
import part1_transformer_from_scratch.tests as tests

device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == '__main__'

reference_gpt2 = HookedTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


# Tokenization

In [3]:
# by default, the vocab is a dictionary ('word': <idx>)
vocab = reference_gpt2.tokenizer.vocab

# the .items() function returns a view object
# it reflects any changes made to the dictionary (like a window)
# but it needs to be converted into a list
list_vocab = list(vocab.items())

# we sort the list by using a lambda (anonymous) function
# we pass the tuple as the input to the key arg, and use the vocab index as the key
sorted_vocab = sorted(list_vocab, key=lambda n: n[1])

In [4]:
print(f"Total vocab size: {len(sorted_vocab)}")
print()
print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])

Total vocab size: 50257

[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]


In [5]:
print(sorted_vocab[-20:])

[('Revolution', 50237), ('Ġsnipers', 50238), ('Ġreverted', 50239), ('Ġconglomerate', 50240), ('Terry', 50241), ('794', 50242), ('Ġharsher', 50243), ('Ġdesolate', 50244), ('ĠHitman', 50245), ('Commission', 50246), ('Ġ(/', 50247), ('âĢ¦."', 50248), ('Compar', 50249), ('Ġamplification', 50250), ('ominated', 50251), ('Ġregress', 50252), ('ĠCollider', 50253), ('Ġinformants', 50254), ('Ġgazed', 50255), ('<|endoftext|>', 50256)]


In [6]:
# Spacing and capitalization can lead to different tokenization
print(reference_gpt2.to_str_tokens("Ralph"))
print(reference_gpt2.to_str_tokens(" Ralph"))
print(reference_gpt2.to_str_tokens(" ralph"))
print(reference_gpt2.to_str_tokens("ralph"))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


In [7]:
# Arithmetic is a mess because numbers are also split into smaller chunks.
print(reference_gpt2.to_str_tokens("56873+3184623=123456789-1000000000"))

['<|endoftext|>', '568', '73', '+', '318', '46', '23', '=', '123', '45', '67', '89', '-', '1', '000000', '000']


In [8]:
reference_text = "I am GPT-2 style transformer, and can't wait to take over the"
tokens = reference_gpt2.to_tokens(reference_text).to(device)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256,    40,   716,   402, 11571,    12,    17,  3918, 47385,    11,
           290,   460,   470,  4043,   284,  1011,   625,   262]],
       device='cuda:0')
torch.Size([1, 18])
['<|endoftext|>', 'I', ' am', ' G', 'PT', '-', '2', ' style', ' transformer', ',', ' and', ' can', "'t", ' wait', ' to', ' take', ' over', ' the']


In [9]:
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 18, 50257])


In [10]:
# softmax along the vocab dimension (shape doesn't change)
probs = logits.softmax(dim=-1)
print(probs.shape)

torch.Size([1, 18, 50257])


In [11]:
import pprint
most_likely_next_tokens = reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])
pprint.pp(list(zip(reference_gpt2.to_str_tokens(reference_text), most_likely_next_tokens)))

[('<|endoftext|>', '\n'),
 ('I', "'m"),
 (' am', ' a'),
 (' G', '.'),
 ('PT', ','),
 ('-', '1'),
 ('2', ','),
 (' style', ','),
 (' transformer', '.'),
 (',', ' I'),
 (' and', ' I'),
 (' can', "'t"),
 ("'t", ' wait'),
 (' wait', ' to'),
 (' to', ' get'),
 (' take', ' my'),
 (' over', ' the'),
 (' the', ' world')]


In [12]:
logits.shape

torch.Size([1, 18, 50257])

In [13]:
next_token = logits[0, -1].argmax(dim=-1)
next_char = reference_gpt2.to_string(next_token)
print(repr(next_char))

' world'


In [14]:
next_token.shape

torch.Size([])

In [15]:
tokens.shape

torch.Size([1, 18])

In [16]:
tokens

tensor([[50256,    40,   716,   402, 11571,    12,    17,  3918, 47385,    11,
           290,   460,   470,  4043,   284,  1011,   625,   262]],
       device='cuda:0')

In [17]:
print(f"Sequence so far: {reference_gpt2.to_string(tokens)}")

for i in range(20):
  print(f"{tokens.shape[-1]+1}th char = {next_char}")

  # add the batch and sequence dimensions for torch concat
  tokens = t.cat([tokens, next_token[None, None]], dim=-1)

  # get logits, select max prob token, convert to str
  logits = reference_gpt2(tokens)
  next_token = logits[0, -1].argmax(dim=-1)
  next_char = reference_gpt2.to_string(next_token)

Sequence so far: ["<|endoftext|>I am GPT-2 style transformer, and can't wait to take over the"]
19th char =  world
20th char = .
21th char =  I
22th char =  have
23th char =  been
24th char =  using
25th char =  this
26th char =  transformer
27th char =  for
28th char =  over
29th char =  a
30th char =  year
31th char =  now
32th char = ,
33th char =  and
34th char =  I
35th char =  am
36th char =  very
37th char =  happy
38th char =  with


# Architecture

Reference:
```
batch = 1
position = 35
d_model = 768
n_heads = 12
n_layers = 12
d_mlp = 3072 (= 4 * d_model)
d_head = 64 (= d_model / n_heads)
```

In [18]:
# print all activation shapes of ref model
for activation_name, activation in cache.items():
  # only print the first layer activations
  if '.0.' in activation_name or "blocks" not in activation_name:
    print(f"{activation_name:30} {tuple(activation.shape)}")

hook_embed                     (1, 18, 768)
hook_pos_embed                 (1, 18, 768)
blocks.0.hook_resid_pre        (1, 18, 768)
blocks.0.ln1.hook_scale        (1, 18, 1)
blocks.0.ln1.hook_normalized   (1, 18, 768)
blocks.0.attn.hook_q           (1, 18, 12, 64)
blocks.0.attn.hook_k           (1, 18, 12, 64)
blocks.0.attn.hook_v           (1, 18, 12, 64)
blocks.0.attn.hook_attn_scores (1, 12, 18, 18)
blocks.0.attn.hook_pattern     (1, 12, 18, 18)
blocks.0.attn.hook_z           (1, 18, 12, 64)
blocks.0.hook_attn_out         (1, 18, 768)
blocks.0.hook_resid_mid        (1, 18, 768)
blocks.0.ln2.hook_scale        (1, 18, 1)
blocks.0.ln2.hook_normalized   (1, 18, 768)
blocks.0.mlp.hook_pre          (1, 18, 3072)
blocks.0.mlp.hook_post         (1, 18, 3072)
blocks.0.hook_mlp_out          (1, 18, 768)
blocks.0.hook_resid_post       (1, 18, 768)
ln_final.hook_scale            (1, 18, 1)
ln_final.hook_normalized       (1, 18, 768)


In [19]:
# print all param shapes of reference model
for name, param in reference_gpt2.named_parameters():
  # only print first layer (18 for spacing indentation)
  if ".0." in name or "blocks" not in name:
    print(f"{name:18} {tuple(param.shape)}")

embed.W_E          (50257, 768)
pos_embed.W_pos    (1024, 768)
blocks.0.ln1.w     (768,)
blocks.0.ln1.b     (768,)
blocks.0.ln2.w     (768,)
blocks.0.ln2.b     (768,)
blocks.0.attn.W_Q  (12, 768, 64)
blocks.0.attn.W_O  (12, 64, 768)
blocks.0.attn.b_Q  (12, 64)
blocks.0.attn.b_O  (768,)
blocks.0.attn.W_K  (12, 768, 64)
blocks.0.attn.W_V  (12, 768, 64)
blocks.0.attn.b_K  (12, 64)
blocks.0.attn.b_V  (12, 64)
blocks.0.mlp.W_in  (768, 3072)
blocks.0.mlp.b_in  (3072,)
blocks.0.mlp.W_out (3072, 768)
blocks.0.mlp.b_out (768,)
ln_final.w         (768,)
ln_final.b         (768,)
unembed.W_U        (768, 50257)
unembed.b_U        (50257,)


In [20]:
# all model hyperparameters
print(reference_gpt2.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional

In [21]:
@dataclass
class Config:
  d_model: int = 768
  debug: bool = True
  layer_norm_eps: float = 1e-5
  d_vocab: int = 50257
  init_range: float = 0.02
  n_ctx: int = 1024
  d_head: int = 64
  d_mlp: int = 3072
  n_heads: int = 12
  n_layers: int = 12

cfg = Config()
pprint.pp(cfg)

Config(d_model=768,
       debug=True,
       layer_norm_eps=1e-05,
       d_vocab=50257,
       init_range=0.02,
       n_ctx=1024,
       d_head=64,
       d_mlp=3072,
       n_heads=12,
       n_layers=12)


In [29]:
def rand_float_test(cls, shape):
  cfg = Config(debug=True)
  layer = cls(cfg).to(device)
  random_input = t.randn(shape).to(device)
  print("Input shape:", random_input.shape)
  output = layer(random_input)
  if isinstance(output, tuple): output = output[0]
  print("Output shape:", output.shape, "\n")

def rand_int_test(cls, shape):
  cfg = Config(debug=True)
  layer = cls(cfg).to(device)
  random_input = t.randint(100, 1000, shape).to(device)
  print("Input shape:", random_input.shape)
  output = layer(random_input)
  if isinstance(output, tuple): output = output[0]
  print("Output shape:", output.shape, "\n")

def load_gpt2_test(cls, gpt2_layer, input):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    print("Input shape:", input.shape)
    output = layer(input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape)
    try: reference_output = gpt2_layer(input)
    except: reference_output = gpt2_layer(input, input, input)
    print("Reference output shape:", reference_output.shape, "\n")
    comparison = t.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct\n")

In [48]:
class LayerNorm(nn.Module):
  """
  After centering and normalization, each vector of length d_model should have mean 0 and variance 1
  """
  def __init__(self, cfg: Config):
    super().__init__()
    self.cfg = cfg
    self.w = nn.Parameter(t.ones(cfg.d_model))
    self.b = nn.Parameter(t.zeros(cfg.d_model))
    self.eps=1e-05

  def forward(self, residual: Float[Tensor, "batch psn d_model"]) -> Float[Tensor, "batch psn d_model"]:

    # always calculate over last dimensions
    # keepdim for reduction operations to preserve original dims
    residual_mean = residual.mean(dim=-1,keepdim=True)
    residual_var = (residual.var(dim=-1,keepdim=True) + self.eps).sqrt()
    y = (residual - residual_mean)/residual_var
    return y * self.w + self.b

rand_float_test(LayerNorm, [2,4,768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, cache["resid_post", 11])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 18, 768])
Output shape: torch.Size([1, 18, 768])
Reference output shape: torch.Size([1, 18, 768]) 

98.73% of the values are correct



In [120]:
x = t.randn((2,2,768))
print("ORIGINAL")
print(x)
print(x.shape)
print()
ln = LayerNorm(cfg)
norm_x = ln(x)
print("NORLMALIZED")
print(norm_x)
print(norm_x.shape)

ORIGINAL
tensor([[[ 0.4221,  0.4041,  0.2338,  ...,  2.1154,  1.0884, -1.3014],
         [-1.0093, -0.7415, -0.4963,  ...,  0.5777,  0.5499, -0.6779]],

        [[-0.1677, -1.5213,  1.3647,  ..., -1.1647, -0.2363, -0.2931],
         [ 0.1709, -0.7709,  0.3691,  ...,  0.3170, -0.6554,  1.1314]]])
torch.Size([2, 2, 768])

NORLMALIZED
tensor([[[ 0.4299,  0.4122,  0.2458,  ...,  2.0851,  1.0812, -1.2548],
         [-0.9790, -0.7099, -0.4636,  ...,  0.6155,  0.5875, -0.6460]],

        [[-0.1418, -1.4483,  1.3372,  ..., -1.1042, -0.2081, -0.2628],
         [ 0.1292, -0.8356,  0.3322,  ...,  0.2789, -0.7173,  1.1132]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 2, 768])


In [49]:
class Embed(nn.Module):
  def __init__(self, cfg: Config):
    super().__init__()
    self.cfg = cfg
    # lookup table is d_vocab -> d_model
    self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
    nn.init.normal_(self.W_E, std=self.cfg.init_range)

  def forward(self, tokens: Int[Tensor, "batch posn"]) -> Float[Tensor, "batch posn d_model"]:
    return self.W_E[tokens]

rand_int_test(Embed, [2,4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 38])
Output shape: torch.Size([1, 38, 768])
Reference output shape: torch.Size([1, 38, 768]) 

100.00% of the values are correct



In [88]:
x = t.randint(1, 10, (10,6))
print(x) # table for 10 words, mapped to 6d vectors

tensor([[2, 4, 3, 4, 4, 2],
        [7, 2, 6, 9, 9, 5],
        [1, 4, 8, 5, 4, 9],
        [9, 6, 8, 7, 7, 5],
        [8, 7, 2, 6, 7, 2],
        [9, 3, 7, 6, 5, 9],
        [9, 3, 6, 3, 5, 6],
        [3, 5, 3, 2, 1, 3],
        [2, 8, 1, 9, 9, 5],
        [4, 2, 5, 5, 6, 8]])


In [99]:
embeds = x[t.LongTensor(
    [[2,4,0],
    [1,0,5]]
)] # 2 sequences with (word 2, word 4, word 0) and (word 1, word 0, word 5)
print(embeds)
print()
print(embeds.shape) # batch_size, seq_len, d_model

tensor([[[1, 4, 8, 5, 4, 9],
         [8, 7, 2, 6, 7, 2],
         [2, 4, 3, 4, 4, 2]],

        [[7, 2, 6, 9, 9, 5],
         [2, 4, 3, 4, 4, 2],
         [9, 3, 7, 6, 5, 9]]])

torch.Size([2, 3, 6])


In [128]:
class PosEmbed(nn.Module):
  def __init__(self, cfg: Config):
    super().__init__()
    self.cfg = cfg
    # lookup table is seq_len -> d_model
    self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
    nn.init.normal_(self.W_pos, std=self.cfg.init_range)

  def forward(self, tokens: Int[Tensor, "batch posn"]) -> Float[Tensor, "batch posn d_model"]:
    batch, seq_len = tokens.shape
    # specify the new axis (batch) size, or number of repeats
    return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)

rand_int_test(PosEmbed, [2,4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 38])
Output shape: torch.Size([1, 38, 768])
Reference output shape: torch.Size([1, 38, 768]) 

100.00% of the values are correct



# Sampling