<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/minimal_circuit_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup
(No need to change anything)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-qhef8iim
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-qhef8iim
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 3f929b1d142b8f82bfbb8ae30e69bab7f76cadf3
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode
# depending on the renderer - this is bad for developing demos! Thus creating a debug mode:
# the "png" renderer is only chosen when running outside Colab with DEBUG_MODE switched on.
pio.renderers.default = "colab" if (IN_COLAB or not DEBUG_MODE) else "png"

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
# import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
# Globally switch off gradient tracking; this notebook only runs inference.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7edd2f1c5840>

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a heatmap on the diverging "RdBu" scale centred at 0.

    `tensor` is converted with `utils.to_numpy`; extra keyword arguments are
    forwarded to `px.imshow`, and `renderer` is handed to the figure's `show`.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot `tensor` (converted with `utils.to_numpy`) as the y-values of a line chart.

    Extra keyword arguments are forwarded to `px.line`; `renderer` is handed to
    the figure's `show`.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x` with the given axis and colour-bar labels.

    Both inputs are converted with `utils.to_numpy`. Extra keyword arguments are
    forwarded to `px.scatter`; `renderer` is handed to the figure's `show`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

# Load Model

Decide which model to use (eg. gpt2-small vs -medium)

In [None]:
# Load pretrained GPT-2 small as a HookedTransformer, with every
# weight-processing option listed below switched on.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Import functions from repo

In [None]:
# Fetch the ARENA 2.0 repository; `ioi_circuit_extraction` is imported from it below.
!git clone https://github.com/callummcdougall/ARENA_2.0.git

Cloning into 'ARENA_2.0'...
remote: Enumerating objects: 9063, done.[K
remote: Counting objects: 100% (9063/9063), done.[K
remote: Compressing objects: 100% (3540/3540), done.[K
remote: Total 9063 (delta 5508), reused 8890 (delta 5425), pack-reused 0[K
Receiving objects: 100% (9063/9063), 155.49 MiB | 14.22 MiB/s, done.
Resolving deltas: 100% (5508/5508), done.


In [None]:
cd ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification

/content/ARENA_2.0/chapter1_transformers/exercises/part3_indirect_object_identification


In [None]:
import ioi_circuit_extraction as ioi_circuit_extraction

# Generate dataset with multiple prompts

In [None]:
class Dataset:
    """Tokenized prompts plus per-prompt token positions for the ablation runs.

    Every prompt is a dict with a "text" entry and the keys "YY" and "ZZ".
    In this version each non-"text", non-"ZZ" value is expected to appear in
    the prompt as a space-prefixed token ("Ġ" + value).

    Attributes:
        prompts: the prompt dicts, as passed in.
        tokenizer: the tokenizer used for every lookup.
        N: number of prompts.
        max_len: largest number of token ids over all prompt texts.
        groups: list with one index array per template (a single template here).
        toks: int tensor of the padded token ids of all prompt texts.
        io_tokenIDs: first token id of " " + prompt["ZZ"] for every prompt.
        s_tokenIDs: first token id of " " + prompt["YY"] for every prompt.
        word_idx: dict mapping each target key (and "end") to a tensor that
            holds, per prompt, the position of that token in the prompt.
    """

    def __init__(self, prompts, tokenizer, S1_is_first=False):
        """Tokenize `prompts` with `tokenizer` and record the token positions.

        Args:
            prompts: non-empty list of prompt dicts (see class docstring).
            tokenizer: object providing `__call__`, `encode` and `tokenize`.
            S1_is_first: accepted for interface compatibility; not read here.

        Raises:
            ValueError: if `prompts` is empty or a target token does not occur
                in its tokenized prompt.
        """
        if len(prompts) == 0:
            raise ValueError("Dataset requires at least one prompt")

        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [prompt["text"] for prompt in self.prompts]
        self.max_len = max(len(self.tokenizer(text).input_ids) for text in texts)

        # Only one template is used, so every prompt gets template id 0.
        template_ids = [0 for _ in self.prompts]
        template_ids_ar = np.array(template_ids)
        self.groups = [
            np.where(template_ids_ar == template_id)[0]
            for template_id in set(template_ids)
        ]

        # Build the id tensor directly as integers instead of going through a
        # float tensor first.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["ZZ"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["YY"])[0] for prompt in self.prompts
        ]

        # Tokenize every prompt once, with the tokenizer that was passed in
        # (not a notebook-level global), and reuse the result for all lookups.
        tokenized = [self.tokenizer.tokenize(text) for text in texts]

        # word_idx: for every prompt, the token index of each target token and "end"
        self.word_idx = {}
        target_keys = [key for key in self.prompts[0].keys() if key not in ("text", "ZZ")]
        for targ in target_keys:
            positions = []
            for prompt, tokens in zip(self.prompts, tokenized):
                target_token = "Ġ" + prompt[targ]
                try:
                    positions.append(tokens.index(target_token))
                except ValueError:
                    raise ValueError(
                        f"target token {target_token!r} (key {targ!r}) not found "
                        f"in tokenized prompt {prompt['text']!r}"
                    ) from None
            self.word_idx[targ] = torch.tensor(positions)

        # "end" is the position of the last token of each (unpadded) prompt.
        self.word_idx["end"] = torch.tensor([len(tokens) - 1 for tokens in tokenized])

    def __len__(self):
        """Return the number of prompts."""
        return self.N

Replace io_tokens with correct answer (next, which is '5') and s_tokens with incorrect (current, which repeats)

In [None]:
def generate_prompts_list(i):
    """Return a one-element list holding the clean prompt for year `i`.

    The prompt dict stores `i` as 'YY', `i + 1` as 'ZZ', and a sentence that
    mentions `i` as the starting year under 'text'.
    """
    current_year = str(i)
    following_year = str(i + 1)
    return [
        {
            'YY': current_year,
            'ZZ': following_year,
            'text': f"The war lasted from the year {i} to the year",
        }
    ]

# Clean dataset: a single prompt whose starting year is 11.
# S1_is_first is passed along, but the Dataset constructor defined in this notebook does not read it.
prompts_list = generate_prompts_list(11)
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

In [None]:
def generate_prompts_list_corr(i):
    """Return a one-element list holding the fixed corrupted prompt.

    The argument `i` is not used: the prompt always uses year 1 for 'YY'.
    'ZZ' is a placeholder; it doesn't matter what it is when calculating the
    logit diff, as that only uses the clean dataset.
    """
    corrupted_prompt = dict(
        YY='1',
        ZZ='2',
        text="The war lasted from the year 1 to the year",
    )
    return [corrupted_prompt]

# if we use more than 1 prompt, the texts (by batch ind) must align in clean and corrupted

# Corrupted dataset; generate_prompts_list_corr ignores its argument and always returns the same prompt.
prompts_list_2 = generate_prompts_list_corr(11)
dataset_2 = Dataset(prompts_list_2, model.tokenizer, S1_is_first=True)

The logit diff is the logit of the correct token minus the logit of the incorrect token, both taken for the CLEAN prompt (so it doesn't matter what the designated 'answer' of the corrupted prompt is stated as; you will get the same result). Here, correct is S5, and incorrect is S4.

Because of this, it's possible to have logit diffs HIGHER than the "full circuit" because the correct token will still be at first place, but the logit scores assigned will just be bigger (perhaps incorrect is scored even lower in the non-full circuit with a higher logit diff score)?

In [None]:
# template = "[S1] [S2] [S3] [S4]"
# prompts_list = [{'S1': '1', 'S2': '2', 'S3': '3', 'S4': '4', 'S5': '5', 'text': '1 2 3 4'}]

# Ablation Expm Functions

In [None]:
# YY_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.YY_tokenIDs]

In [None]:
# from torch import Tensor

# def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
#     '''
#     Returns logit difference between the correct and incorrect answer.
#     If per_prompt=True, return the array of differences rather than the average.
#     '''

#     # Get the right logits; anything greater than YY
#     # range(logits.size(0)) for every input in the batch
#     # dataset.word_idx["end"]: at the last pos, so "what's the next prediction after end?"
#     # what's the logit of the io token (whose pos at an input seq is recorded in the dataset by dataset.YY_tokenIDs)
#     YY_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.YY_tokenIDs]

#     # int(prompts[input_ind]['YY'])  # YY in an input
#     # we want the first token > int(prompts[batch_ind]['YY'])
#     # search thru entire vocab space until find token > int(prompts[input_ind]['YY'])
#     # how do we convert index in vocab space to the token it represents?


#     greater_than_Y_idx = (io_logits > Y).nonzero(as_tuple=True)[0].item()
#     first_token_greater_than_Y = dataset.io_tokenIDs[greater_than_Y_idx]


#     # get the wrong logits; anything less than YY
#     s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), dataset.word_idx["end"], dataset.s_tokenIDs]
#     # Find logit difference
#     answer_logit_diff = io_logits - s_logits
#     return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [None]:
from torch import Tensor

def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], dataset: Dataset, per_prompt=False):
    '''
    Logit difference between the correct and the incorrect answer token.

    For every prompt, the logits are read at that prompt's "end" position
    (dataset.word_idx["end"]); the logit of the token in dataset.s_tokenIDs is
    subtracted from the logit of the token in dataset.io_tokenIDs.

    If per_prompt=True, the per-prompt differences are returned; otherwise
    their mean is returned.
    '''
    batch_rows = range(logits.size(0))
    end_positions = dataset.word_idx["end"]

    def _logit_at_end(token_ids):
        # One logit per prompt: (prompt row, its end position, its token id).
        return logits[batch_rows, end_positions, token_ids]

    answer_logit_diff = _logit_at_end(dataset.io_tokenIDs) - _logit_at_end(dataset.s_tokenIDs)
    if per_prompt:
        return answer_logit_diff
    return answer_logit_diff.mean()

In [None]:
def mean_ablate_by_lst(lst, model, print_output=True, clean_dataset=None, means_dataset=None):
    """Compare the logit difference of the full model with a mean-ablated one.

    Args:
        lst: heads to keep; the same list is used for both circuit entries
            (presumably (layer, head) pairs - confirm against
            ioi_circuit_extraction).
        model: model to run; its hooks (including permanent ones) are reset
            before the un-ablated run.
        print_output: if True, print the ablated score as a percentage of the
            full-model score.
        clean_dataset: dataset whose tokens are run and scored. Defaults to the
            notebook-level `dataset`, looked up at call time.
        means_dataset: dataset passed to `add_mean_ablation_hook` as
            `means_dataset`. Defaults to the notebook-level `dataset_2`,
            looked up at call time.

    Returns:
        The average logit difference of the model after
        `ioi_circuit_extraction.add_mean_ablation_hook` has been applied.

    Note:
        The hooks added by `add_mean_ablation_hook` are not removed here; they
        are only cleared by the reset at the start of the next call.
    """
    # Fall back to the notebook globals so existing two/three-argument calls keep working.
    if clean_dataset is None:
        clean_dataset = dataset
    if means_dataset is None:
        means_dataset = dataset_2

    CIRCUIT = {
        "number mover": lst,
        "number mover 2": lst,
    }

    SEQ_POS_TO_KEEP = {
        "number mover": "end",
        "number mover 2": "YY",
    }

    model.reset_hooks(including_permanent=True)  #must do this after running with mean ablation hook

    # A plain forward pass is enough: the activation cache was never used.
    logits_original = model(clean_dataset.toks)

    model = ioi_circuit_extraction.add_mean_ablation_hook(
        model,
        means_dataset=means_dataset,
        circuit=CIRCUIT,
        seq_pos_to_keep=SEQ_POS_TO_KEEP,
    )
    logits_minimal = model(clean_dataset.toks)

    orig_score = logits_to_ave_logit_diff_2(logits_original, clean_dataset)
    new_score = logits_to_ave_logit_diff_2(logits_minimal, clean_dataset)
    if print_output:
        print(f"Average logit difference (circuit / full) %: {100 * new_score / orig_score:.4f}")
    return new_score

# Ablate the model and compare with original

## Check how Greater-Than circuit performs here

See how greater-than circuit performs on the greater-than task; it should be similar to the paper. Else, either greater-than paper has issues (less likely) or this mean ablation code/setup was not generalized correctly (more likely).

In [None]:
# Heads kept in the circuit - presumably (layer, head) pairs of the greater-than circuit; verify against the paper.
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (7, 11), (9,1)]

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 96.1178


Did it get it right? Let's try an incomplete circuit as a sanity check.

In [None]:
# Sanity check: keep only three of the heads from the full list.
greater_than = [(0, 1), (0, 3), (9,1)]

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 87.0366


In [None]:
# Keep a single head only.
greater_than = [(0, 1)]

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 48.5145


These scores are inflated, so it's not accurate; it's just an error in the code.

In [None]:
# Empty circuit: no head is listed as kept.
greater_than = []

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 48.5081


We see that the corrupted dataset is bad; the difference of 1 to i+1 doesn't work. We need to fix that error (see v2 of this nb)

# try 1711 as corr, 1750 as clean

# get rid of "Ġ" +

In [None]:
# Inspect the token ids the tokenizer produces for the 4-digit year string.
model.tokenizer("1711")

{'input_ids': [1558, 1157], 'attention_mask': [1, 1]}

In [None]:
# Inspect how the same string is split into token strings.
model.tokenizer.tokenize("1711")

['17', '11']

B/c 11 is after 17 and doesn't have a space in front, get rid of "Ġ" +

In [None]:
class Dataset:
    """Tokenized prompts plus per-prompt token positions for the ablation runs.

    Every prompt is a dict with a "text" entry and the keys "YY" and "ZZ".
    In this version each non-"text", non-"ZZ" value is looked up in the
    tokenized prompt exactly as given, i.e. without a leading-space marker,
    because the year digits follow "17" without a space.

    Attributes:
        prompts: the prompt dicts, as passed in.
        tokenizer: the tokenizer used for every lookup.
        N: number of prompts.
        max_len: largest number of token ids over all prompt texts.
        groups: list with one index array per template (a single template here).
        toks: int tensor of the padded token ids of all prompt texts.
        io_tokenIDs: first token id of " " + prompt["ZZ"] for every prompt.
        s_tokenIDs: first token id of " " + prompt["YY"] for every prompt.
        word_idx: dict mapping each target key (and "end") to a tensor that
            holds, per prompt, the position of that token in the prompt.
    """

    def __init__(self, prompts, tokenizer, S1_is_first=False):
        """Tokenize `prompts` with `tokenizer` and record the token positions.

        Args:
            prompts: non-empty list of prompt dicts (see class docstring).
            tokenizer: object providing `__call__`, `encode` and `tokenize`.
            S1_is_first: accepted for interface compatibility; not read here.

        Raises:
            ValueError: if `prompts` is empty or a target token does not occur
                in its tokenized prompt.
        """
        if len(prompts) == 0:
            raise ValueError("Dataset requires at least one prompt")

        self.prompts = prompts
        self.tokenizer = tokenizer
        self.N = len(prompts)

        texts = [prompt["text"] for prompt in self.prompts]
        self.max_len = max(len(self.tokenizer(text).input_ids) for text in texts)

        # Only one template is used, so every prompt gets template id 0.
        template_ids = [0 for _ in self.prompts]
        template_ids_ar = np.array(template_ids)
        self.groups = [
            np.where(template_ids_ar == template_id)[0]
            for template_id in set(template_ids)
        ]

        # Build the id tensor directly as integers instead of going through a
        # float tensor first.
        self.toks = torch.tensor(
            self.tokenizer(texts, padding=True).input_ids, dtype=torch.int
        )
        # NOTE(review): the answer tokens are still encoded with a leading
        # space, although the position lookup below drops the space marker for
        # digits that follow "17" - confirm this is the intended token.
        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["ZZ"])[0] for prompt in self.prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["YY"])[0] for prompt in self.prompts
        ]

        # Tokenize every prompt once, with the tokenizer that was passed in
        # (not a notebook-level global), and reuse the result for all lookups.
        tokenized = [self.tokenizer.tokenize(text) for text in texts]

        # word_idx: for every prompt, the token index of each target token and "end"
        self.word_idx = {}
        target_keys = [key for key in self.prompts[0].keys() if key not in ("text", "ZZ")]
        for targ in target_keys:
            positions = []
            for prompt, tokens in zip(self.prompts, tokenized):
                # No "Ġ" prefix: the target digits directly follow "17".
                target_token = prompt[targ]
                try:
                    positions.append(tokens.index(target_token))
                except ValueError:
                    raise ValueError(
                        f"target token {target_token!r} (key {targ!r}) not found "
                        f"in tokenized prompt {prompt['text']!r}"
                    ) from None
            self.word_idx[targ] = torch.tensor(positions)

        # "end" is the position of the last token of each (unpadded) prompt.
        self.word_idx["end"] = torch.tensor([len(tokens) - 1 for tokens in tokenized])

    def __len__(self):
        """Return the number of prompts."""
        return self.N

In [None]:
def generate_prompts_list(i):
    """Return a one-element list holding the clean 17xx prompt for `i`.

    The prompt dict stores `i` as 'YY', `i + 1` as 'ZZ', and a sentence in
    which `i` is appended to "17" to form the starting year under 'text'.
    The digits are inserted as-is, with no zero-padding.
    """
    current_suffix = str(i)
    following_suffix = str(i + 1)
    return [
        {
            'YY': current_suffix,
            'ZZ': following_suffix,
            'text': f"The war lasted from the year 17{i} to the year 17",
        }
    ]

# prompts_list = generate_prompts_list(1711)
# Clean dataset for the 17xx prompts: only the two-digit suffix (11) is passed in.
prompts_list = generate_prompts_list(11)
dataset = Dataset(prompts_list, model.tokenizer, S1_is_first=True)

In [None]:
def generate_prompts_list_corr(i):
    """Return a one-element list holding the fixed corrupted 17xx prompt.

    The argument `i` is not used: the prompt always uses the year 1701, with
    '01' stored as 'YY'. 'ZZ' is a placeholder; it doesn't matter what it is
    when calculating the logit diff, as that only uses the clean dataset.
    """
    corrupted_prompt = dict(
        YY='01',
        ZZ='2',
        text="The war lasted from the year 1701 to the year 17",
    )
    return [corrupted_prompt]

# if we use more than 1 prompt, the texts (by batch ind) must align in clean and corrupted

# Corrupted dataset for the 17xx prompts; generate_prompts_list_corr ignores its argument.
prompts_list_2 = generate_prompts_list_corr(11)
dataset_2 = Dataset(prompts_list_2, model.tokenizer, S1_is_first=True)

In [None]:
# Empty circuit, now evaluated against the redefined 17xx datasets.
greater_than = []

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 38.1054


In [None]:
# Full head list again, now evaluated against the redefined 17xx datasets.
greater_than = [(0, 1), (0, 3), (0, 5), (5, 5), (6, 1), (6, 9), (7, 10), (7, 11), (9,1)]

new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

Average logit difference (circuit / full) %: 98.8186


In [None]:
import random

# Random baseline: draw as many heads as the circuit list above contains.
num_of_tuples = 9  # Number of tuples you want

# A dedicated, seeded generator makes the baseline reproducible on re-run
# without touching the global random state.
rng = random.Random(0)

# Sample distinct heads from the model's real layer/head ranges
# (presumably the tuples are (layer, head) - confirm against ioi_circuit_extraction).
# The previous 0-9 range could never draw indices 10 or 11, which the circuit list uses,
# and could draw the same head twice.
all_heads = [
    (layer, head)
    for layer in range(model.cfg.n_layers)
    for head in range(model.cfg.n_heads)
]
greater_than = rng.sample(all_heads, num_of_tuples)

print(greater_than)
new_score = mean_ablate_by_lst(greater_than, model, print_output=True)

[(7, 7), (1, 4), (1, 9), (0, 1), (8, 5), (6, 4), (7, 8), (1, 4), (9, 7)]
Average logit difference (circuit / full) %: 37.2735


So it's better than random.