In [None]:
%pip install git+https://github.com/redwoodresearch/Easy-Transformer.git
%pip install einops datasets transformers fancy_einsum

In [30]:
from copy import deepcopy
import torch

from tqdm import tqdm
import pandas as pd
import torch
import torch as t
from easy_transformer.EasyTransformer import (
    EasyTransformer,
)
from time import ctime
from functools import partial

import numpy as np
from tqdm import tqdm
import pandas as pd

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import random
import einops
from IPython import get_ipython
from copy import deepcopy
from easy_transformer.ioi_dataset import (
    IOIDataset,
)
from easy_transformer.ioi_utils import (
    path_patching,
    max_2d,
    CLASS_COLORS,
    show_pp,
    show_attention_patterns,
    scatter_attention_and_contribution,
)
from random import randint as ri
from easy_transformer.ioi_circuit_extraction import (
    do_circuit_extraction,
    get_heads_circuit,
    CIRCUIT,
)
from easy_transformer.ioi_utils import logit_diff, probs
from easy_transformer.ioi_utils import get_top_tokens_and_probs as g


import torch.nn.functional as F
from typing import List

In [106]:
# Load the pretrained GPT-2 checkpoint into an EasyTransformer (recorded output below
# shows it was placed on CPU in the original run).
model = EasyTransformer.from_pretrained("gpt2")
# The ablation code further down matches hooks named "attn.hook_result" and configures
# head_circuit="result"; presumably this flag is what exposes those per-head results — confirm.
model.set_use_attn_result(True)

Moving model to device:  cpu
Finished loading pretrained model gpt2 into EasyTransformer!


## ICL task functions

In [91]:
class icl_dataset:
    """Plain container pairing a batch of tokenized ICL prompts with their targets.

    Attributes:
        input: `input.to(device)` -- the only field that is relocated.
        labels: stored exactly as passed in (not moved to `device` here).
        N: stored as passed; callers in this notebook pass the number of prompts.
        max_len: stored as passed; callers pass the tokenized sequence length.
        device: the device `input` was moved to.
    """

    def __init__(self, input, labels, N, max_len, device):
        # `input` only needs to expose `.to(device)`; do the move first so a failure
        # surfaces before any attribute is set.
        relocated = input.to(device)
        self.input = relocated
        self.device = device
        self.labels = labels
        self.max_len = max_len
        self.N = N

def generate_data(model, device, x_initial, y_initial, icl_length, n, offset=0):
    """Build `n` in-context linear-regression prompts and wrap them in an `icl_dataset`.

    Each prompt lists `icl_length - 1` solved examples of ``output = j * x + y`` for
    ``j = 0, 1, ...`` followed by one unsolved query line (``"Input: j, Output:"``).
    The coefficients start at ``x_initial + offset`` / ``y_initial + offset``; after an
    even-indexed prompt `x` grows by one, after an odd-indexed prompt `y` grows by one.

    Args:
        model: object exposing `to_tokens(list_of_str)`, used to tokenize the prompts.
        device: device the tokens and labels are moved to.
        x_initial, y_initial: starting slope and intercept (before `offset`).
        icl_length: number of "Input:" lines per prompt, including the final query.
        n: number of prompts.
        offset: added to both starting coefficients.

    Returns:
        icl_dataset whose `labels` is a double tensor with a trailing singleton
        dimension holding the expected answer to each prompt's final query.
    """
    slope = x_initial + offset
    intercept = y_initial + offset
    query_idx = icl_length - 1

    prompts, correct_answers = [], []
    for prompt_idx in range(n):
        # Solved demonstrations for inputs 0 .. icl_length - 2.
        demos = "".join(
            f"Input: {j}, Output: {j * slope + intercept}\n" for j in range(query_idx)
        )
        if icl_length > 0:
            # The last line is left open; its value is the regression target.
            prompts.append(demos + f"Input: {query_idx}, Output:")
            correct_answers.append(query_idx * slope + intercept)
        else:
            # No lines at all: an empty prompt and no recorded answer.
            prompts.append(demos)

        # Alternate which coefficient changes between consecutive prompts.
        if prompt_idx % 2 == 0:
            slope += 1
        else:
            intercept += 1

    data_tokens = model.to_tokens(prompts).to(device)
    answers = torch.tensor(correct_answers)
    labels = answers.to(torch.double).unsqueeze(-1).to(device)
    return icl_dataset(
        input=data_tokens,
        labels=labels,
        N=n,
        max_len=data_tokens.shape[1],
        device=device,
    )

def validation_metric(model, dataset, device='cpu', return_one_element = False):
    """Mean-squared error between the model's greedy numeric prediction and the labels.

    The model is run on `dataset.input`; for every sequence the single most likely
    token at the last position is decoded to text and parsed as an integer.

    Args:
        model: callable as `model(tokens, return_type="logits")` and exposing
            `to_str_tokens(indices)`, which must yield one string per example.
        dataset: object with `input` (token batch) and `labels` (targets; this
            notebook builds them as a double tensor with a trailing singleton dim).
        device: device on which the loss is computed. Both the predictions and the
            labels are moved there, so a dataset living on another device (e.g. CUDA
            while `device` is left at 'cpu') no longer triggers a device mismatch.
        return_one_element: if True use ``reduction='sum'``, otherwise ``'mean'``.

    Returns:
        Scalar tensor with the (summed or averaged) squared error.

    Raises:
        ValueError: if a predicted token does not parse as an integer.
    """
    logits = model(dataset.input, return_type="logits")

    # Logits at the final position of every sequence (indexing [:, -1, :] drops the sequence axis).
    last_token_logits = logits[:, -1, :]

    # Greedy decoding: only the single best token (k=1) per example is used.
    _, top1_indices = torch.topk(last_token_logits, 1, dim=1)

    predicted_strs = model.to_str_tokens(top1_indices)
    # int() tolerates surrounding whitespace such as a leading space in the token text.
    predictions = (
        torch.tensor([int(pred) for pred in predicted_strs])
        .to(torch.double)
        .unsqueeze(-1)
        .to(device)
    )

    # Keep both operands of the loss on the same device.
    labels = dataset.labels.to(device)

    reduction = 'sum' if return_one_element else 'mean'
    return F.mse_loss(predictions, labels, reduction=reduction)

In [94]:
def test_validation_metric(device, model, x_initial, y_initial, icl_length, n, offset=10):
    """Generate an ICL dataset and return the model's MSE on it.

    Args:
        device: device used both for the generated dataset and for the metric.
        model: model passed through to `generate_data` and `validation_metric`.
        x_initial, y_initial, icl_length, n: forwarded to `generate_data`.
        offset: coefficient offset forwarded to `generate_data`; defaults to 10,
            the value that used to be hard-coded here.

    Returns:
        The MSE tensor produced by `validation_metric`.
    """
    dataset = generate_data(model, device, x_initial, y_initial, icl_length, n, offset)
    # The value is returned (and shown as the cell output); the print that used to
    # follow the return statement could never execute and has been removed.
    return validation_metric(model, dataset, device=device)


# Uses the default offset of 10; the recorded result for this call is 61.4
# (3.2 is the MSE for the offset=0 prompts evaluated later in the notebook).
test_validation_metric('cpu', model, 2, 1, 12, 10)

tensor(61.4000, dtype=torch.float64)

# Ablating functions

In [101]:
def list_diff(l1, l2):
    """Return the elements of `l1` that are absent from `l2`, as a list.

    Entries of `l2` are coerced with `int()` before comparison. Duplicates in `l1`
    collapse and the order of the result is whatever set iteration yields.
    """
    excluded = {int(item) for item in l2}
    return list(set(l1) - excluded)
def turn_keep_into_rmv(to_keep, max_len):
    """Translate a "keep" specification into a "remove" specification.

    For every key, each entry of `to_keep[key]` is mapped as follows:
      * an entry equal to ``[]`` (nothing kept) becomes ``list(range(max_len))``,
        i.e. every position is removed;
      * any other entry becomes ``[]``, i.e. nothing is removed.

    Note this is all-or-nothing per entry: the positions listed in a non-empty
    entry are not inspected.

    Returns:
        A new dict with the same keys; every inner list is freshly created.
    """
    return {
        key: [list(range(max_len)) if kept == [] else [] for kept in entries]
        for key, entries in to_keep.items()
    }
def process_heads_and_mlps(
    heads_to_remove=None,  # {(layer, head): List[List[int]]}
    mlps_to_remove=None,  # {layer: List[List[int]]}
    heads_to_keep=None,  # same layout as heads_to_remove
    mlps_to_keep=None,  # same layout as mlps_to_remove
    ioi_dataset=None,
    model=None,
):
    """Normalise keep/remove specifications into "remove" dictionaries.

    Exactly one of `heads_to_remove` / `heads_to_keep` must be given, and likewise
    exactly one of `mlps_to_remove` / `mlps_to_keep` (enforced with asserts).

    * A "remove" spec is returned as a shallow copy.
    * A "keep" spec is first completed: every head `(layer, head)` (resp. MLP
      `layer`) of the model that is missing from it gets `ioi_dataset.max_len`
      empty lists. The completed spec is then converted with `turn_keep_into_rmv`.

    Args:
        ioi_dataset: object exposing `max_len`.
        model: object exposing `cfg.n_layers` and `cfg.n_heads`.

    Returns:
        (heads, mlps): dicts keyed by `(layer, head)` and by `layer` respectively.
    """
    # Each component type must be specified in exactly one of the two styles.
    assert (heads_to_remove is None) != (heads_to_keep is None)
    assert (mlps_to_keep is None) != (mlps_to_remove is None)

    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads
    dataset_length = ioi_dataset.max_len

    if mlps_to_remove is None:
        # Keep-style spec: unlisted layers are treated as "keep nothing".
        kept_mlps = mlps_to_keep.copy()
        for layer in range(n_layers):
            if layer not in mlps_to_keep:
                kept_mlps[layer] = [[] for _ in range(dataset_length)]
        # TODO check that this is still right for the max_len of maybe shortened datasets
        mlps = turn_keep_into_rmv(kept_mlps, ioi_dataset.max_len)
    else:
        mlps = mlps_to_remove.copy()

    if heads_to_remove is None:
        # Same completion as for the MLPs, over every (layer, head) pair.
        kept_heads = heads_to_keep.copy()
        all_heads = ((layer, head) for layer in range(n_layers) for head in range(n_heads))
        for head_key in all_heads:
            if head_key not in heads_to_keep:
                kept_heads[head_key] = [[] for _ in range(dataset_length)]
        heads = turn_keep_into_rmv(kept_heads, ioi_dataset.max_len)
    else:
        heads = heads_to_remove.copy()

    return heads, mlps

def get_circuit_replacement_hook(
    heads_to_remove=None,
    mlps_to_remove=None,
    heads_to_keep=None,
    mlps_to_keep=None,
    heads_to_remove2=None,
    mlps_to_remove2=None,
    heads_to_keep2=None,
    mlps_to_keep2=None,
    ioi_dataset=None,
    model=None,
):
    """Build the hook function that overwrites selected positions of an activation.

    The first specification (`heads_*` / `mlps_*`) selects, per example, which
    positions of the hooked activation `z` are overwritten. The optional second
    specification (`*2` arguments) selects which positions of the replacement
    activation `act` are read; it is only processed when `heads_to_remove2` or
    `heads_to_keep2` is given, otherwise the first specification is used for both.

    Args:
        ioi_dataset: object exposing `N` (number of examples looped over in the
            hook) and `max_len` (used by `process_heads_and_mlps`).
        model: forwarded to `process_heads_and_mlps`.

    Returns:
        (hook_fn, heads, mlps) where `heads` / `mlps` are the processed "remove"
        dicts of the first specification.
    """
    heads, mlps = process_heads_and_mlps(
        heads_to_remove=heads_to_remove,
        mlps_to_remove=mlps_to_remove,
        heads_to_keep=heads_to_keep,
        mlps_to_keep=mlps_to_keep,
        ioi_dataset=ioi_dataset,
        model=model,
    )

    second_spec_given = not (heads_to_remove2 is None and heads_to_keep2 is None)
    if second_spec_given:
        heads2, mlps2 = process_heads_and_mlps(
            heads_to_remove=heads_to_remove2,
            mlps_to_remove=mlps_to_remove2,
            heads_to_keep=heads_to_keep2,
            mlps_to_keep=mlps_to_keep2,
            ioi_dataset=ioi_dataset,
            model=model,
        )
    else:
        # Read from the same positions that are being overwritten.
        heads2, mlps2 = heads, mlps

    n_examples = ioi_dataset.N

    def circuit_replmt_hook(z, act, hook):
        # z is indexed as [example, positions, :] below; assumes `act` has a matching
        # layout -- TODO confirm against EasyAblation.
        hook_name = hook.name
        # The layer number is the second dot-separated field of the hook name.
        layer = int(hook_name.split(".")[1])

        if "mlp" in hook_name and layer in mlps:
            # TODO can this per-example loop be vectorized?
            for example in range(n_examples):
                z[example, mlps[layer][example], :] = act[example, mlps2[layer][example], :]

        if "attn.hook_result" in hook_name:
            # The head index is provided through the hook's context.
            head_key = (layer, hook.ctx["idx"])
            if head_key in heads:
                for example in range(n_examples):
                    z[example, heads[head_key][example], :] = act[
                        example, heads2[head_key][example], :
                    ]

        return z

    return circuit_replmt_hook, heads, mlps

def do_circuit_extraction(
    heads_to_remove=None,  # {(2,3) : List[List[int]]: dimensions dataset_size * datapoint_length
    mlps_to_remove=None,  # {2: List[List[int]]: dimensions dataset_size * datapoint_length
    heads_to_keep=None,  # as above for heads
    mlps_to_keep=None,  # as above for mlps
    ioi_dataset=None,
    mean_dataset=None,
    model=None,
    metric=None,
    excluded=[],  # (layer, head) tuples to skip; never mutated here
    return_hooks=False,
    hooks_dict=False,
):
    """
    Build ablation hooks that replace the activations of every component outside a circuit.

    `..._to_remove` arguments list the indices that are ablated away; `..._to_keep`
    arguments list the indices that are NOT ablated away. Exactly one of each pair
    must be given (checked inside `process_heads_and_mlps`).

    Args:
        ioi_dataset: dataset whose `.input` is handed to `ExperimentMetric`; its `N`
            and `max_len` are read by the helper functions.
        mean_dataset: dataset whose `.input.long()` is passed to `AblationConfig` as
            `mean_dataset`. If None, `ioi_dataset` is used.
        model: model the hooks are created for. `model.reset_hooks()` is called
            after the `EasyAblation` object is constructed.
        metric: forwarded to `ExperimentMetric` (may be None).
        excluded: `(layer, head)` pairs for which no head hook is created. Only the
            head loop consults it; MLP hooks are created for every key of `mlps`.
        return_hooks: if True, return the hooks instead of attaching them.
        hooks_dict: with `return_hooks`, return the dict keyed by `(layer, head)` /
            `(layer, None)` rather than a list of its values.

    Returns:
        * `return_hooks` and `hooks_dict`: dict of hooks.
        * `return_hooks` only: list of hooks.
        * otherwise: `(model, abl)` -- the same `model` object that was passed in,
          now with the hooks added, and the `EasyAblation` instance.
    """

    # Resolve keep/remove arguments into the replacement function plus the processed specs.
    ablation, heads, mlps = get_circuit_replacement_hook(
        heads_to_remove=heads_to_remove,  # {(2,3) : List[List[int]]: dimensions dataset_size * datapoint_length
        mlps_to_remove=mlps_to_remove,  # {2: List[List[int]]: dimensions dataset_size * datapoint_length
        heads_to_keep=heads_to_keep,  # as above for heads
        mlps_to_keep=mlps_to_keep,  # as above for mlps
        ioi_dataset=ioi_dataset,
        model=model,
    )

    # NOTE: rebinds the `metric` parameter to the wrapper object.
    metric = ExperimentMetric(
        metric=metric, dataset=ioi_dataset.input, relative_metric=False
    )  # TODO make dummy metric

    if mean_dataset is None:
        mean_dataset = ioi_dataset

    config = AblationConfig(
        abl_type="custom",
        abl_fn=ablation,
        mean_dataset=mean_dataset.input.long(),
        target_module="attn_head",
        head_circuit="result",
        cache_means=True,  # circuit extraction *has* to cache means
        verbose=True,
    )
    abl = EasyAblation(
        model,
        config,
        metric,
        semantic_indices=None,  # ioi_dataset.sem_tok_idx,
    )
    # Start from a clean slate: hooks already on the model are reset before new ones are built.
    model.reset_hooks()

    hooks = {}

    heads_keys = list(heads.keys())
    # sort in lexicographic order of (layer, head)
    heads_keys.sort(key=lambda x: (x[0], x[1]))

    for (
        layer,
        head,
    ) in heads_keys:
        if (layer, head) in excluded:
            continue
        assert (layer, head) not in hooks, ((layer, head), "already in hooks")
        hooks[(layer, head)] = abl.get_hook(layer, head)
    # MLP hooks are keyed (layer, None); `excluded` is not checked for them.
    for layer in mlps.keys():
        hooks[(layer, None)] = abl.get_hook(layer, head=None, target_module="mlp")

    if return_hooks:
        if hooks_dict:
            return hooks
        else:
            return list(hooks.values())

    else:
        # Each stored hook is unpacked into the positional arguments of add_hook.
        for hook in hooks.values():
            model.add_hook(*hook)
        return model, abl

def get_heads_circuit(ioi_dataset, excluded=[], mlp0=False, circuit=CIRCUIT):
    """Build the `heads_to_keep` mapping for every head of `circuit`.

    Args:
        ioi_dataset: dataset exposing `input` with a `shape`; `shape[1]` gives the
            number of indices listed for each kept head.
        excluded: items to leave out. Each must be a tuple (compared against heads)
            or a key of `circuit` (a whole class of heads); anything else fails the
            assert.
        mlp0: unsupported; a truthy value raises `ValueError`.
        circuit: mapping from class name to a list of `(layer, head)` tuples.

    Returns:
        dict mapping each retained head to ``list(range(ioi_dataset.input.shape[1]))``.

    Raises:
        ValueError: if `mlp0` is truthy (raised after the mapping has been built).
    """
    for item in excluded:
        assert isinstance(item, tuple) or item in circuit.keys(), item

    heads_to_keep = {}
    for class_name, class_heads in circuit.items():
        if class_name in excluded:
            continue
        for head in class_heads:
            if head not in excluded:
                # Fresh list per head so later edits to one entry cannot leak into another.
                heads_to_keep[head] = list(range(ioi_dataset.input.shape[1]))

    if mlp0:
        raise ValueError("Arthur moved this to get_mlps_circuit")
    return heads_to_keep

# Check MSE difference (full model vs. circuit)

In [107]:
# Candidate circuit for the ICL regression task: class name -> list of (layer, head).
circuit = {
  'operating': [(8, 1), (6, 9), (5, 0)]
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N = 10

# Base prompts (offset=0): the task on which the circuit is evaluated.
base_dataset = generate_data(model=model, device=device, x_initial=2, y_initial=1, icl_length=12, n=N, offset=0)
# `device` is passed explicitly: validation_metric defaults to 'cpu', which would place
# the predictions on a different device than the dataset labels whenever CUDA is used.
benchmark_mse = validation_metric(model, base_dataset, device=device)

# Shifted prompts (offset=10): used as the mean/patch dataset for the ablated components.
patch_dataset = generate_data(model=model, device=device, x_initial=2, y_initial=1, icl_length=12, n=N, offset=10)
benchmark_patch_mse = validation_metric(model, patch_dataset, device=device)

# Add hooks that knock out every head except the ones in `circuit`.
# NOTE: do_circuit_extraction returns the very same `model` object with hooks attached,
# so `ablated_model is model` until reset_hooks() is called.
model.reset_hooks()
ablated_model, _ = do_circuit_extraction(
    model=model,
    heads_to_keep=get_heads_circuit(ioi_dataset=base_dataset, circuit=circuit),
    mlps_to_remove={},
    ioi_dataset=base_dataset,
    mean_dataset=patch_dataset
)

circuit_mse = validation_metric(ablated_model, base_dataset, device=device)
print(
    f"The circuit has mse {circuit_mse} over {N} examples. Full model mse was {benchmark_mse}"
)

The circuit has mse 15570.5 over 10 examples. Full model mse was 3.2


# Try model

In [None]:
# Few-shot "Buy"/"Sell" prompt; the last line is left open after "//" so the
# next-token prediction at the final position is the model's answer.
# NOTE(review): this cell has no recorded execution (In [None]).
reference_text = """Google decreased revenue by 50% this quarter. // Sell
Apple increased revenue by 50% this quarter. // Buy
OpenAI increased revenue by 50% this quarter. // Buy
Tesla decreased revenue by 50% this quarter. // Sell
Audi increased revenue by 50% this quarter. // Buy
Microsoft decreased revenue by 50% this quarter. // Sell
Ford decreased decreased revenue by 40% this quarter. //"""

# Inspect the tokenization: ids, shape and string form.
tokens = ablated_model.to_tokens(reference_text)
print(tokens)
print(tokens.shape)
print(ablated_model.to_str_tokens(tokens))


# Run the model with hooks attached; `cache` is not used further in this cell.
logits, cache = ablated_model.run_with_cache(tokens)
# Pair each input token with the greedy (argmax) next-token prediction at that position.
list(zip(ablated_model.to_str_tokens(reference_text), ablated_model.tokenizer.batch_decode(logits.argmax(dim=-1)[0])))

In [22]:
# Full-model MSE on the base prompts (recorded result: 3.2).
# `ablated_model` is the same object as `model` (do_circuit_extraction returns its
# `model` argument after adding hooks), so the ablation hooks must be cleared first;
# otherwise, when cells run top to bottom, this measures the ablated circuit instead.
model.reset_hooks()
validation_metric(model, base_dataset)

tensor(3.2000, dtype=torch.float64)

In [19]:
# Reset the hooks on `model`; since `ablated_model` is the same object, this also
# undoes the circuit ablation for it.
model.reset_hooks()

In [21]:
# MSE of the model on the offset=10 prompts (recorded result: 61.4).
validation_metric(model, patch_dataset)

tensor(61.4000, dtype=torch.float64)

In [23]:
# Reset hooks again before the IOI experiment below (`ablated_model` and `model`
# refer to the same object).
ablated_model.reset_hooks()

# Now try IOI

In [36]:
# IOI circuit heads grouped by functional class, as (layer, head) tuples.
# Defined at the top of the cell so that RELEVANT_TOKENS below is always derived from
# this definition; previously the first run used the CIRCUIT imported from
# easy_transformer.ioi_circuit_extraction and later runs used this one.
CIRCUIT = {
    "name mover": [
        (9, 9),  # by importance
        (10, 0),
        (9, 6),
        (10, 10),
        (10, 6),
        (10, 2),
        (10, 1),
        (11, 2),
        (9, 7),
        (9, 0),
        (11, 9),
    ],
    "negative": [(10, 7), (11, 10)],
    "s2 inhibition": [(7, 3), (7, 9), (8, 6), (8, 10)],
    "induction": [(5, 5), (5, 8), (5, 9), (6, 9)],
    "duplicate token": [
        (0, 1),
        (0, 10),
        (3, 0),
        # (7, 1),
    ],  # unclear exactly what (7,1) does
    "previous token": [
        (2, 2),
        # (2, 9),
        (4, 11),
        # (4, 3),
        # (4, 7),
        # (5, 6),
        # (3, 3),
        # (3, 7),
        # (3, 6),
    ],
}

# Token position label associated with each head class.
RELEVANT_TOKENS = {}
for head in CIRCUIT["name mover"] + CIRCUIT["negative"] + CIRCUIT["s2 inhibition"]:
    RELEVANT_TOKENS[head] = ["end"]

for head in CIRCUIT["induction"]:
    RELEVANT_TOKENS[head] = ["S2"]

for head in CIRCUIT["duplicate token"]:
    RELEVANT_TOKENS[head] = ["S2"]

for head in CIRCUIT["previous token"]:
    RELEVANT_TOKENS[head] = ["S+1"]

# Number of IOI prompts; also reported in the summary line below.
IOI_N_EXAMPLES = 50

# IOIDataset is already imported in the imports cell, so it is not re-imported here.
ioi_dataset = IOIDataset(
    prompt_type="mixed",
    N=IOI_N_EXAMPLES,
    tokenizer=model.tokenizer,
    prepend_bos=False,
)  # TODO make this a seeded dataset

# we make the ABC dataset in order to knockout other model components
abc_dataset = (  # TODO seeded
    ioi_dataset.gen_flipped_prompts(("IO", "RAND"))
    .gen_flipped_prompts(("S", "RAND"))
    .gen_flipped_prompts(("S1", "RAND"))
)

# we then add hooks to the model to knockout all the heads except the circuit
# NOTE(review): the notebook-level do_circuit_extraction / get_heads_circuit defined
# above read `.input` from the dataset; the recorded output of this cell predates those
# definitions (execution count 36 vs 101) -- confirm IOIDataset exposes `.input`, or
# call the library versions here.
model.reset_hooks()
model, _ = do_circuit_extraction(
    model=model,
    heads_to_keep=get_heads_circuit(ioi_dataset=ioi_dataset, circuit=CIRCUIT),
    mlps_to_remove={},
    ioi_dataset=ioi_dataset,
    mean_dataset=abc_dataset,
)

circuit_logit_diff = logit_diff(model, ioi_dataset)
print(
    f"The circuit gets average logit difference {circuit_logit_diff.item()} over {IOI_N_EXAMPLES} examples"
)



The circuit gets average logit difference 2.97226619720459 over 50 examples


In [38]:
# Re-evaluate the logit difference on the IOI prompts; no reset_hooks() call sits between
# the extraction cell above and this one, so `model` presumably still carries those hooks.
logit_diff(model, ioi_dataset)

tensor(2.9723)

In [39]:
# Same few-shot "Buy"/"Sell" prompt as in the earlier "Try model" section, this time run
# through `model` after the IOI circuit extraction (hooks presumably still attached --
# confirm no reset_hooks() ran in between).
reference_text = """Google decreased revenue by 50% this quarter. // Sell
Apple increased revenue by 50% this quarter. // Buy
OpenAI increased revenue by 50% this quarter. // Buy
Tesla decreased revenue by 50% this quarter. // Sell
Audi increased revenue by 50% this quarter. // Buy
Microsoft decreased revenue by 50% this quarter. // Sell
Ford decreased decreased revenue by 40% this quarter. //"""

# Inspect the tokenization: ids, shape and string form.
tokens = model.to_tokens(reference_text)
print(tokens)
print(tokens.shape)
print(model.to_str_tokens(tokens))


# `cache` is not used further in this cell.
logits, cache = model.run_with_cache(tokens)
# Pair each input token with the greedy (argmax) next-token prediction at that position.
list(zip(model.to_str_tokens(reference_text), model.tokenizer.batch_decode(logits.argmax(dim=-1)[0])))

tensor([[50256, 11708, 11832,  6426,   416,  2026,     4,   428,  3860,    13,
          3373, 25688,   198, 16108,  3220,  6426,   416,  2026,     4,   428,
          3860,    13,  3373, 11763,   198, 11505, 20185,  3220,  6426,   416,
          2026,     4,   428,  3860,    13,  3373, 11763,   198, 41351, 11832,
          6426,   416,  2026,     4,   428,  3860,    13,  3373, 25688,   198,
         16353,    72,  3220,  6426,   416,  2026,     4,   428,  3860,    13,
          3373, 11763,   198, 15905, 11832,  6426,   416,  2026,     4,   428,
          3860,    13,  3373, 25688,   198, 37308, 11832, 11832,  6426,   416,
          2319,     4,   428,  3860,    13,  3373]])
torch.Size([1, 86])
['<|endoftext|>', 'Google', ' decreased', ' revenue', ' by', ' 50', '%', ' this', ' quarter', '.', ' //', ' Sell', '\n', 'Apple', ' increased', ' revenue', ' by', ' 50', '%', ' this', ' quarter', '.', ' //', ' Buy', '\n', 'Open', 'AI', ' increased', ' revenue', ' by', ' 50', '%', ' this', ' qua

[('<|endoftext|>', '\n'),
 ('Google', ' has'),
 (' decreased', ' its'),
 (' revenue', ' by'),
 (' by', ' $'),
 (' 50', '%'),
 ('%', ' in'),
 (' this', ' year'),
 (' quarter', ','),
 ('.', '\n'),
 (' //', ' Google'),
 (' Sell', ' your'),
 ('\n', '\n'),
 ('Apple', "'s"),
 (' increased', ' revenue'),
 (' revenue', ' by'),
 (' by', ' 50'),
 (' 50', '%'),
 ('%', ' this'),
 (' this', ' quarter'),
 (' quarter', '.'),
 ('.', ' //'),
 (' //', ' Sell'),
 (' Buy', '\n'),
 ('\n', 'Google'),
 ('Open', ' source'),
 ('AI', ' increased'),
 (' increased', ' revenue'),
 (' revenue', ' by'),
 (' by', ' 50'),
 (' 50', '%'),
 ('%', ' this'),
 (' this', ' quarter'),
 (' quarter', '.'),
 ('.', ' //'),
 (' //', ' Sell'),
 (' Buy', '\n'),
 ('\n', 'Google'),
 ('Tesla', ' increased'),
 (' decreased', ' revenue'),
 (' revenue', ' by'),
 (' by', ' 50'),
 (' 50', '%'),
 ('%', ' this'),
 (' this', ' quarter'),
 (' quarter', '.'),
 ('.', ' //'),
 (' //', ' Buy'),
 (' Sell', '\n'),
 ('\n', '\n'),
 ('Aud', 'ience'),
 (