# Setup
(No need to read)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-faier9zj
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-faier9zj
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 51771ff42e313eb64ffd95841c0923c0c0865efd
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting accelerate>=0.23.0 (from transformer-lens==0.0.0)
  Downloading accelerate-0.24.0-py3-none-any.whl (260 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.0/261.0 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━

In [2]:
# Plotly figures render in either Colab or VSCode/Jupyter depending on the
# selected renderer, never both, which is awkward when developing demos.
import plotly.io as pio

# Use the "colab" renderer on Colab and whenever DEBUG_MODE is off; fall back
# to static "png" output only for local debugging sessions.
pio.renderers.default = "colab" if (IN_COLAB or not DEBUG_MODE) else "png"

In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [5]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [6]:
# Load the pretrained "gpt2-medium" checkpoint into a HookedTransformer.
# The prompt tests in the cells below all run against this `model` object.
model = HookedTransformer.from_pretrained("gpt2-medium")

Downloading (…)lve/main/config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-medium into HookedTransformer


# Fibonacci

In [7]:
# Comma-separated Fibonacci prefix; the next term of the sequence is 8.
example_prompt, example_answer = "0, 1, 1, 2, 3, 5,", " 8"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '0', ',', ' 1', ',', ' 1', ',', ' 2', ',', ' 3', ',', ' 5', ',']
Tokenized answer: [' 8']


Top 0th token. Logit: 16.76 Prob: 28.96% Token: | 7|
Top 1th token. Logit: 16.61 Prob: 24.91% Token: | 6|
Top 2th token. Logit: 15.75 Prob: 10.55% Token: | 8|
Top 3th token. Logit: 15.66 Prob:  9.62% Token: | 10|
Top 4th token. Logit: 15.33 Prob:  6.94% Token: | 9|
Top 5th token. Logit: 14.43 Prob:  2.81% Token: | 5|
Top 6th token. Logit: 14.02 Prob:  1.86% Token: | 11|
Top 7th token. Logit: 13.46 Prob:  1.07% Token: | 12|
Top 8th token. Logit: 13.27 Prob:  0.88% Token: | 13|
Top 9th token. Logit: 13.22 Prob:  0.84% Token: | 1|


In [8]:
# Same Fibonacci prefix as before, but separated by spaces only (no commas).
example_prompt, example_answer = "0 1 1 2 3 5", " 8"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '0', ' 1', ' 1', ' 2', ' 3', ' 5']
Tokenized answer: [' 8']


Top 0th token. Logit: 18.96 Prob: 50.68% Token: | 6|
Top 1th token. Logit: 17.41 Prob: 10.74% Token: | 7|
Top 2th token. Logit: 17.20 Prob:  8.69% Token: | 8|
Top 3th token. Logit: 16.82 Prob:  5.96% Token: | 10|
Top 4th token. Logit: 16.25 Prob:  3.37% Token: | 0|
Top 5th token. Logit: 16.25 Prob:  3.36% Token: | 5|
Top 6th token. Logit: 16.05 Prob:  2.75% Token: | 9|
Top 7th token. Logit: 15.91 Prob:  2.39% Token: | 11|
Top 8th token. Logit: 15.46 Prob:  1.53% Token: | 4|
Top 9th token. Logit: 14.80 Prob:  0.79% Token: | 12|


In [9]:
# Longer space-separated Fibonacci prefix (up to 21); the next term is 34.
example_prompt, example_answer = "0 1 1 2 3 5 8 13 21", " 34"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '0', ' 1', ' 1', ' 2', ' 3', ' 5', ' 8', ' 13', ' 21']
Tokenized answer: [' 34']


Top 0th token. Logit: 20.99 Prob: 16.89% Token: | 48|
Top 1th token. Logit: 20.88 Prob: 15.10% Token: | 34|
Top 2th token. Logit: 20.48 Prob: 10.18% Token: | 24|
Top 3th token. Logit: 20.09 Prob:  6.86% Token: | 30|
Top 4th token. Logit: 20.06 Prob:  6.66% Token: | 50|
Top 5th token. Logit: 20.00 Prob:  6.28% Token: | 49|
Top 6th token. Logit: 19.86 Prob:  5.44% Token: | 33|
Top 7th token. Logit: 19.56 Prob:  4.03% Token: | Next|
Top 8th token. Logit: 19.19 Prob:  2.81% Token: | 25|
Top 9th token. Logit: 18.95 Prob:  2.19% Token: | 35|


# 2 4 6 8

In [10]:
# Even numbers increasing in steps of 2; the next term is 14.
example_prompt, example_answer = "2 4 6 8 10 12", " 14"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '2', ' 4', ' 6', ' 8', ' 10', ' 12']
Tokenized answer: [' 14']


Top 0th token. Logit: 16.38 Prob: 44.89% Token: | 14|
Top 1th token. Logit: 15.45 Prob: 17.68% Token: | 13|
Top 2th token. Logit: 14.83 Prob:  9.55% Token: | 16|
Top 3th token. Logit: 14.32 Prob:  5.73% Token: | 15|
Top 4th token. Logit: 13.99 Prob:  4.10% Token: |
|
Top 5th token. Logit: 13.13 Prob:  1.74% Token: | 12|
Top 6th token. Logit: 13.11 Prob:  1.70% Token: | 20|
Top 7th token. Logit: 13.09 Prob:  1.68% Token: | 18|
Top 8th token. Logit: 12.44 Prob:  0.88% Token: | 19|
Top 9th token. Logit: 12.22 Prob:  0.70% Token: | 1|


# Arithmetic: plus 1

In [11]:
# Zero-shot addition written with explicit "+" and "=": 1 + 1 should give 2.
example_prompt, example_answer = "1 + 1 =", " 2"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '1', ' +', ' 1', ' =']
Tokenized answer: [' 2']


Top 0th token. Logit: 15.83 Prob: 36.51% Token: | 2|
Top 1th token. Logit: 14.89 Prob: 14.18% Token: | 3|
Top 2th token. Logit: 14.57 Prob: 10.32% Token: | 5|
Top 3th token. Logit: 14.47 Prob:  9.34% Token: | 4|
Top 4th token. Logit: 13.72 Prob:  4.43% Token: | 6|
Top 5th token. Logit: 13.50 Prob:  3.54% Token: | 1|
Top 6th token. Logit: 13.23 Prob:  2.71% Token: | 7|
Top 7th token. Logit: 13.03 Prob:  2.20% Token: | 10|
Top 8th token. Logit: 12.97 Prob:  2.08% Token: | 8|
Top 9th token. Logit: 12.69 Prob:  1.58% Token: | 0|


In [13]:
# Zero-shot addition with a different operand: 5 + 1 should give 6.
example_prompt, example_answer = "5 + 1 =", " 6"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '5', ' +', ' 1', ' =']
Tokenized answer: [' 6']


Top 0th token. Logit: 15.12 Prob: 16.57% Token: | 7|
Top 1th token. Logit: 15.11 Prob: 16.33% Token: | 6|
Top 2th token. Logit: 14.88 Prob: 12.96% Token: | 10|
Top 3th token. Logit: 14.45 Prob:  8.44% Token: | 8|
Top 4th token. Logit: 14.14 Prob:  6.20% Token: | 9|
Top 5th token. Logit: 14.02 Prob:  5.50% Token: | 5|
Top 6th token. Logit: 13.63 Prob:  3.72% Token: | 11|
Top 7th token. Logit: 13.58 Prob:  3.54% Token: | 12|
Top 8th token. Logit: 13.22 Prob:  2.48% Token: | 15|
Top 9th token. Logit: 12.93 Prob:  1.86% Token: | 13|


In [12]:
# Zero-shot addition with a two-digit operand: 10 + 1 should give 11.
example_prompt, example_answer = "10 + 1 =", " 11"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '10', ' +', ' 1', ' =']
Tokenized answer: [' 11']


Top 0th token. Logit: 13.91 Prob:  8.64% Token: | 10|
Top 1th token. Logit: 13.80 Prob:  7.77% Token: | 20|
Top 2th token. Logit: 13.67 Prob:  6.78% Token: | 12|
Top 3th token. Logit: 13.66 Prob:  6.76% Token: | 11|
Top 4th token. Logit: 13.54 Prob:  5.98% Token: | 15|
Top 5th token. Logit: 13.13 Prob:  3.95% Token: | 13|
Top 6th token. Logit: 12.87 Prob:  3.04% Token: | 30|
Top 7th token. Logit: 12.81 Prob:  2.89% Token: | 14|
Top 8th token. Logit: 12.73 Prob:  2.67% Token: | 25|
Top 9th token. Logit: 12.60 Prob:  2.33% Token: | $|


In [14]:
# Few-shot variant: three worked "+ 1" examples precede the query "5 + 1 =".
example_prompt, example_answer = "1 + 1 = 2. 10 + 1 = 11. 3 + 1 = 4. 5 + 1 =", " 6"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '1', ' +', ' 1', ' =', ' 2', '.', ' 10', ' +', ' 1', ' =', ' 11', '.', ' 3', ' +', ' 1', ' =', ' 4', '.', ' 5', ' +', ' 1', ' =']
Tokenized answer: [' 6']


Top 0th token. Logit: 18.35 Prob: 43.05% Token: | 7|
Top 1th token. Logit: 17.90 Prob: 27.55% Token: | 6|
Top 2th token. Logit: 17.17 Prob: 13.16% Token: | 8|
Top 3th token. Logit: 16.21 Prob:  5.06% Token: | 5|
Top 4th token. Logit: 16.06 Prob:  4.36% Token: | 9|
Top 5th token. Logit: 15.16 Prob:  1.78% Token: | 10|
Top 6th token. Logit: 14.33 Prob:  0.77% Token: | 4|
Top 7th token. Logit: 13.93 Prob:  0.52% Token: | 11|
Top 8th token. Logit: 13.81 Prob:  0.46% Token: | 12|
Top 9th token. Logit: 13.72 Prob:  0.42% Token: | 1|


In [15]:
# Same few-shot pairs (1->2, 10->11, 3->4), but the "+ 1 =" notation is replaced
# by the word "table", so the rule is only given by the in-context examples.
# Query: 5, expected answer: 6.
example_prompt, example_answer = "1 table 2. 10 table 11. 3 table 4. 5 table", " 6"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '1', ' table', ' 2', '.', ' 10', ' table', ' 11', '.', ' 3', ' table', ' 4', '.', ' 5', ' table']
Tokenized answer: [' 6']


Top 0th token. Logit: 17.03 Prob: 51.56% Token: | 6|
Top 1th token. Logit: 15.45 Prob: 10.62% Token: | 5|
Top 2th token. Logit: 15.26 Prob:  8.75% Token: | 7|
Top 3th token. Logit: 14.23 Prob:  3.13% Token: | 8|
Top 4th token. Logit: 14.10 Prob:  2.75% Token: | 9|
Top 5th token. Logit: 14.07 Prob:  2.67% Token: | 12|
Top 6th token. Logit: 14.07 Prob:  2.66% Token: | 11|
Top 7th token. Logit: 13.84 Prob:  2.12% Token: | 4|
Top 8th token. Logit: 13.66 Prob:  1.77% Token: | 10|
Top 9th token. Logit: 13.56 Prob:  1.60% Token: | 13|


In [17]:
# "table" few-shot pairs (1->2, 10->11, 3->4) with a two-digit query:
# 15, expected answer 16.
example_prompt, example_answer = "1 table 2. 10 table 11. 3 table 4. 15 table", " 16"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '1', ' table', ' 2', '.', ' 10', ' table', ' 11', '.', ' 3', ' table', ' 4', '.', ' 15', ' table']
Tokenized answer: [' 16']


Top 0th token. Logit: 16.46 Prob: 30.21% Token: | 16|
Top 1th token. Logit: 15.19 Prob:  8.42% Token: | 12|
Top 2th token. Logit: 15.11 Prob:  7.82% Token: | 6|
Top 3th token. Logit: 14.98 Prob:  6.88% Token: | 17|
Top 4th token. Logit: 14.65 Prob:  4.91% Token: | 18|
Top 5th token. Logit: 14.48 Prob:  4.17% Token: | 4|
Top 6th token. Logit: 14.42 Prob:  3.91% Token: | 13|
Top 7th token. Logit: 14.42 Prob:  3.91% Token: | 14|
Top 8th token. Logit: 14.36 Prob:  3.68% Token: | 2|
Top 9th token. Logit: 14.10 Prob:  2.83% Token: | 15|


In [19]:
# "table" few-shot pairs (1->2, 10->11, 3->4) with a three-digit query:
# 101, expected answer 102.
example_prompt, example_answer = "1 table 2. 10 table 11. 3 table 4. 101 table", " 102"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '1', ' table', ' 2', '.', ' 10', ' table', ' 11', '.', ' 3', ' table', ' 4', '.', ' 101', ' table']
Tokenized answer: [' 102']


Top 0th token. Logit: 15.68 Prob: 22.84% Token: | 102|
Top 1th token. Logit: 15.23 Prob: 14.58% Token: | 12|
Top 2th token. Logit: 14.67 Prob:  8.29% Token: | 10|
Top 3th token. Logit: 14.64 Prob:  8.04% Token: | 2|
Top 4th token. Logit: 14.09 Prob:  4.66% Token: | 11|
Top 5th token. Logit: 13.56 Prob:  2.75% Token: | 13|
Top 6th token. Logit: 13.46 Prob:  2.48% Token: | 112|
Top 7th token. Logit: 13.22 Prob:  1.95% Token: |
|
Top 8th token. Logit: 13.10 Prob:  1.73% Token: | 6|
Top 9th token. Logit: 13.08 Prob:  1.70% Token: | 101|


# in-context greater-than

In [20]:
# In-context pairs where the second number is merely larger than the first
# (1->4, 10->100, 3->6) rather than exactly one more. Query: 101, scored
# against the answer 102.
example_prompt, example_answer = "1 table 4. 10 table 100. 3 table 6. 101 table", " 102"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '1', ' table', ' 4', '.', ' 10', ' table', ' 100', '.', ' 3', ' table', ' 6', '.', ' 101', ' table']
Tokenized answer: [' 102']


Top 0th token. Logit: 15.96 Prob: 14.18% Token: | 200|
Top 1th token. Logit: 15.53 Prob:  9.25% Token: | 101|
Top 2th token. Logit: 15.36 Prob:  7.76% Token: | 100|
Top 3th token. Logit: 15.11 Prob:  6.02% Token: | 102|
Top 4th token. Logit: 15.02 Prob:  5.53% Token: | 300|
Top 5th token. Logit: 14.89 Prob:  4.85% Token: | 110|
Top 6th token. Logit: 14.67 Prob:  3.90% Token: | 150|
Top 7th token. Logit: 14.17 Prob:  2.37% Token: | 201|
Top 8th token. Logit: 14.10 Prob:  2.20% Token: | 500|
Top 9th token. Logit: 13.98 Prob:  1.95% Token: | 103|


In [21]:
# Same "larger than" in-context pairs (1->4, 10->100, 3->6); the query is now 1.
# NOTE(review): the answer " 102" is unchanged from the "101 table" cell even
# though the query changed - confirm this is the intended target here.
example_prompt, example_answer = "1 table 4. 10 table 100. 3 table 6. 1 table", " 102"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '1', ' table', ' 4', '.', ' 10', ' table', ' 100', '.', ' 3', ' table', ' 6', '.', ' 1', ' table']
Tokenized answer: [' 102']


Top 0th token. Logit: 15.31 Prob:  9.78% Token: | 4|
Top 1th token. Logit: 15.02 Prob:  7.29% Token: | 10|
Top 2th token. Logit: 14.93 Prob:  6.69% Token: | 6|
Top 3th token. Logit: 14.80 Prob:  5.90% Token: | 2|
Top 4th token. Logit: 14.80 Prob:  5.89% Token: | 5|
Top 5th token. Logit: 14.63 Prob:  4.94% Token: | 8|
Top 6th token. Logit: 14.58 Prob:  4.73% Token: | 3|
Top 7th token. Logit: 14.53 Prob:  4.48% Token: | 1|
Top 8th token. Logit: 14.25 Prob:  3.39% Token: | 12|
Top 9th token. Logit: 14.21 Prob:  3.26% Token: | 7|


In [22]:
# Same "larger than" in-context pairs (1->4, 10->100, 3->6); the query is now 4.
# NOTE(review): the answer " 102" is unchanged from the "101 table" cell even
# though the query changed - confirm this is the intended target here.
example_prompt, example_answer = "1 table 4. 10 table 100. 3 table 6. 4 table", " 102"
utils.test_prompt(
    example_prompt,
    example_answer,
    model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', '1', ' table', ' 4', '.', ' 10', ' table', ' 100', '.', ' 3', ' table', ' 6', '.', ' 4', ' table']
Tokenized answer: [' 102']


Top 0th token. Logit: 16.10 Prob:  7.74% Token: | 10|
Top 1th token. Logit: 15.93 Prob:  6.53% Token: | 6|
Top 2th token. Logit: 15.88 Prob:  6.20% Token: | 8|
Top 3th token. Logit: 15.81 Prob:  5.77% Token: | 12|
Top 4th token. Logit: 15.51 Prob:  4.28% Token: | 7|
Top 5th token. Logit: 15.47 Prob:  4.13% Token: | 4|
Top 6th token. Logit: 15.39 Prob:  3.79% Token: | 9|
Top 7th token. Logit: 15.35 Prob:  3.66% Token: | 11|
Top 8th token. Logit: 15.21 Prob:  3.17% Token: | 5|
Top 9th token. Logit: 15.13 Prob:  2.92% Token: | 100|
