# Setup

In [1]:
%load_ext autoreload
%autoreload 2

https://arxiv.org/pdf/2305.00586


In [1]:
from transformer_lens import HookedTransformer
import random
import sys
import os
import collections
import operator
import functools
import itertools


base_dir = os.path.split(os.getcwd())[0]
sys.path.append(base_dir)
from pyfunctions.general import compare_same
from pyfunctions.cdt_basic import *
from pyfunctions.cdt_source_to_target import *
from pyfunctions.cdt_from_source_nodes import *
from pyfunctions.toy_model import *
from greater_than_task.greater_than_dataset import *
from greater_than_task.utils import get_valid_years

import seaborn as sns
import matplotlib.pyplot as plt


import torch
Result = collections.namedtuple('Result', ('ablation_set', 'score'))




## Load model and dataset


In [2]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch.autograd.set_grad_enabled(False)

from transformer_lens import utils, HookedTransformer, ActivationCache
model = HookedTransformer.from_pretrained("gpt2-small",
                                          center_unembed=True,
                                          center_writing_weights=True,
                                          fold_ln=False,
                                          refactor_factored_attn_matrices=True)
                                          

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Loaded pretrained model gpt2-small into HookedTransformer


In [275]:
# https://github.com/hannamw/gpt2-greater-than/blob/main/circuit_discovery.py; also these files came with their repo
years_to_sample_from = get_valid_years(model.tokenizer, 1000, 1900)
N = 5000
ds = YearDataset(years_to_sample_from, N, Path("../greater_than_task/cache/potential_nouns.txt"), model.tokenizer, balanced=True, device=device, eos=True)
year_indices = torch.load("../greater_than_task/cache/logit_indices.pt")# .to(device)

num_layers = len(model.blocks)
seq_len = ds.good_toks.size()[-1]
num_attention_heads = model.cfg.n_heads

## Exploration

In [9]:
type(model.tokenizer)

transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast

In [147]:
# print(ds)
'''
These guys weirdly implemented all the functionality of their class in class-level attributes
years_to_sample_from: torch.Tensor
    N: int
    ordered: bool
    eos: bool

    nouns: List[str]
    years: torch.Tensor
    years_YY: torch.Tensor
    good_sentences: List[str]
    bad_sentences: List[str]
    good_toks: torch.Tensor
    bad_toks: torch.Tensor
    good_prompt: List[str]
    bad_prompt: List[str]
    good_mask: torch.Tensor
    tokenizer: PreTrainedTokenizer
    '''

# ds.N
# ds.nouns
# print(ds.years[:20]) # not sorted by XX for some reason
# print(ds.years_YY[:]) # but does correspond to these YYs, which are mostly sorted
print(ds.good_sentences[-10:]) # includes The endeavor lasted from the year 1098 to the year 10', but 1099 isn't in the list of years?
# note: we want prediction at the last token, unlike with the IOI dataset where we want second-to-last
# i checked and there is no internal logic to prevent such sentences from being produced, so i guess we're SOL if we sample one?
# print(ds.bad_sentences[-10:]) # these all start with 01, e.g 1601 to. they're bad because there is no possible incorrect input
# print(ds.good_mask.size()) # n, 100 (100 different years)
# print(ds.good_toks.size()) # n, 13
# print(ds.bad_toks.size()) # there isn't any necessary correspondence, N is just the number of good sequences and bad sequences alike
# list(ds.years.cpu().numpy()).index(1099)
print(year_indices)
print(model.tokenizer.convert_ids_to_tokens(year_indices)) # length 100, starts with index for '00' and ends with index for '99', great
# print(model.tokenizer.decode(year_indices, clean_up_tokenization_spaces=False))

['<|endoftext|> The clash lasted from the year 1594 to the year 15', '<|endoftext|> The program lasted from the year 1395 to the year 13', '<|endoftext|> The challenge lasted from the year 1496 to the year 14', '<|endoftext|> The confrontation lasted from the year 1597 to the year 15', '<|endoftext|> The marriage lasted from the year 1098 to the year 10', '<|endoftext|> The journey lasted from the year 1202 to the year 12', '<|endoftext|> The insurgency lasted from the year 1803 to the year 18', '<|endoftext|> The improvement lasted from the year 1404 to the year 14', '<|endoftext|> The consultation lasted from the year 1705 to the year 17', '<|endoftext|> The domination lasted from the year 1606 to the year 16']
tensor([ 405,  486, 2999, 3070, 3023, 2713, 3312, 2998, 2919, 2931,  940, 1157,
        1065, 1485, 1415, 1314, 1433, 1558, 1507, 1129, 1238, 2481, 1828, 1954,
        1731, 1495, 2075, 1983, 2078, 1959, 1270, 3132, 2624, 2091, 2682, 2327,
        2623, 2718, 2548, 2670, 1821,

In [29]:
ds.good_mask[1]

tensor([False, False, False, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
       device='cuda:0')

In [30]:
ds.good_sentences[1]

'<|endoftext|> The attempts lasted from the year 1603 to the year 16'

In [41]:
ds.good_mask[0]

tensor([False, False, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
       device='cuda:0')

The task specific objective here is something like (sum of probabilities)

## Setup attention mask and mean activations for ablation

In [10]:
attention_mask = torch.tensor([1 for x in range(seq_len)]).view(1, -1).to(device)
input_shape = ds.good_toks[0:1, :].size() # by making the sample size 1, you can get an extended attention mask with batch size 1, which will broadcast
extended_attention_mask = get_extended_attention_mask(attention_mask, 
                                                        input_shape, 
                                                        model,
                                                        device)



In [20]:
del logits
del cache
import gc
gc.collect()
model.cfg.use_attn_result = False

In [21]:
logits, cache = model.run_with_cache(ds.good_toks) # run on entire dataset along batch dimension

attention_outputs = [cache['blocks.' + str(i) + '.attn.hook_z'] for i in range(num_attention_heads)]
attention_outputs = torch.stack(attention_outputs, dim=1) # now batch, layer, seq, n_heads, dim_attn
print(attention_outputs.shape)
mean_acts = torch.mean(attention_outputs, dim=0)
old_shape = mean_acts.shape
last_dim = old_shape[-2] * old_shape[-1]
new_shape = old_shape[:-2] + (last_dim,)
mean_acts = mean_acts.view(new_shape)
mean_acts.shape

torch.Size([490, 12, 13, 12, 64])


torch.Size([12, 13, 768])

In [23]:
# quick check for equality, particularly to make sure we've made the attention mask correctly
ranges = [
        [layer for layer in range(num_layers)],
        [sequence_position for sequence_position in range(seq_len)],
        [attention_head_idx for attention_head_idx in range(num_attention_heads)]
    ]

source_nodes = [Node(*x) for x in itertools.product(*ranges)]
ablation_sets = [(n,) for n in source_nodes]
target_nodes = []
out_decomp, _, _, _ = prop_GPT(ds.good_toks[0:1, :], extended_attention_mask, model, [ablation_sets[0]], target_nodes=target_nodes, device=device, mean_acts=None, set_irrel_to_mean=False)

logits, cache = model.run_with_cache(ds.good_toks[0])

compare_same(out_decomp[0].rel + out_decomp[0].irrel, logits)

# Loose experiments

In [294]:
import random
NUM_SAMPLES = 100
sample_idxs = random.sample(range(N), NUM_SAMPLES) # you actually have to sample randomly from this dataset because they are arranged in increasing order of YY token
# sample_idxs

In [None]:
print(sample_idxs)


[172, 392, 73, 394, 157, 273, 369, 200, 402, 373, 202, 127, 163, 365, 186, 326, 124, 438, 227, 129]


In [62]:
print (ds.good_sentences[88])

<|endoftext|> The pursuit lasted from the year 1290 to the year 12


In [61]:
example_prompt = ds.good_sentences[88] # GPT2 doesn't always perform this task correctly, only about 99% of the time.
# On example input <|endoftext|> The pursuit lasted from the year 1290 to the year 12 , the top prediction is '90'.
example_answer = '03'

transformer_lens.utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', '<|endoftext|>', ' The', ' pursuit', ' lasted', ' from', ' the', ' year', ' 12', '90', ' to', ' the', ' year', ' 12']
Tokenized answer: ['03']


Top 0th token. Logit: 25.19 Prob: 22.77% Token: |90|
Top 1th token. Logit: 24.12 Prob:  7.78% Token: |99|
Top 2th token. Logit: 23.77 Prob:  5.51% Token: |94|
Top 3th token. Logit: 23.73 Prob:  5.31% Token: |95|
Top 4th token. Logit: 23.63 Prob:  4.81% Token: |92|
Top 5th token. Logit: 23.30 Prob:  3.44% Token: |60|
Top 6th token. Logit: 23.22 Prob:  3.18% Token: |98|
Top 7th token. Logit: 23.18 Prob:  3.06% Token: |96|
Top 8th token. Logit: 23.14 Prob:  2.94% Token: |50|
Top 9th token. Logit: 23.10 Prob:  2.82% Token: |91|


In [288]:
# This is not a pure function. It depends on ds.good_mask, sample_idxs, and year_indices.
def score_logits(logits, sample_idxs_0):
    probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad
    probs_for_year_tokens = probs[:, year_indices.cpu().numpy()]
    probs_for_correct_years = probs_for_year_tokens[ds.good_mask.cpu().numpy()[sample_idxs_0]]
    correct_score = np.sum(probs_for_correct_years)
    probs_for_incorrect_years = probs_for_year_tokens[np.logical_not(ds.good_mask.cpu().numpy()[sample_idxs_0])]
    incorrect_score = np.sum(probs_for_incorrect_years)
    return (correct_score - incorrect_score) / len(sample_idxs_0)


In [295]:
model.reset_hooks(including_permanent=True)

mean_acts = mean_acts.view(new_shape)
'''
ranges = [
        [layer for layer in range(num_layers)],
        [sequence_position for sequence_position in range(seq_len)],
        [attention_head_idx for attention_head_idx in range(num_attention_heads)]
    ]

source_nodes = [Node(*x) for x in itertools.product(*ranges)]
ablation_sets = [(n,) for n in source_nodes]
'''
ablation_sets = []
for layer in range(num_layers):
    for head_idx in range(num_attention_heads):
        ablation_sets.append(tuple(Node(layer, seq_pos, head_idx) for seq_pos in range(seq_len)))
target_nodes = []

# cache activations for faster batch run
out_decomp, _, _, pre_layer_activations = prop_GPT(ds.good_toks[sample_idxs, :], extended_attention_mask, model, [ablation_sets[0]], target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True)

prop_fn = lambda ablation_list: prop_GPT(ds.good_toks[sample_idxs, :], extended_attention_mask, model, ablation_list, target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True, cached_pre_layer_acts=pre_layer_activations)
out_decomps, target_decomps = batch_run(prop_fn, ablation_sets, num_at_time=(max(64 // len(sample_idxs), 1)))

running input 0


In [291]:
def compute_logits_decomposition_scores(out_decomps, sample_idxs, normalized=False):
    logits = (out_decomps[0].rel + out_decomps[0].irrel) # 1, seq_len, 50257=d_vocab
    full_score = score_logits(logits, sample_idxs)
    assert(full_score > 0) # this needs to be replaced with a check higher in the pipeline; GPT2 succeeds at this like 99%+ of the time but not always
    
    results = []
    relevances = np.zeros((num_layers, num_attention_heads))

    for layer_idx in range(num_layers):

        for head_idx in range(num_attention_heads):
            decomp = out_decomps[layer_idx * num_attention_heads + head_idx]
            score = score_logits(decomp.rel, sample_idxs)
            norm_score = score / full_score
            relevances[layer_idx, head_idx] = norm_score
            if not normalized:
                results.append(Result(decomp.ablation_set, norm_score))
    if normalized:
        sums_per_layer = np.sum(np.abs(relevances), axis=(1))
        print(sums_per_layer)

        sums_per_layer[sums_per_layer == 0] = -1e-8
        relevances = relevances / np.expand_dims(sums_per_layer, (1))
        for layer_idx in range(num_layers):
            for head_idx in range(num_attention_heads):
                target_decomp = target_decomps[layer_idx * num_attention_heads + head_idx]
                results.append(Result(target_decomp.ablation_set, relevances[layer_idx, head_idx]))
    results.sort(key=operator.attrgetter('score'), reverse=True)

    return results, relevances

In [240]:

def compute_logits_decomposition_scores(out_decomps, sample_idxs, normalized=False):
    logits = (out_decomps[0].rel + out_decomps[0].irrel) # 1, seq_len, 50257=d_vocab
    full_score = score_logits(logits, sample_idxs)
    assert(full_score > 0) # this needs to be replaced with a check higher in the pipeline; GPT2 succeeds at this like 99%+ of the time but not always
    
    results = []
    relevances = np.zeros((num_layers, seq_len, num_attention_heads))

    for layer_idx in range(num_layers):
        for seq_pos in range(seq_len):
            for head_idx in range(num_attention_heads):
                decomp = out_decomps[layer_idx * seq_len * num_attention_heads + seq_pos * num_attention_heads + head_idx]
                score = score_logits(decomp.rel, sample_idxs)
                norm_score = score / full_score
                relevances[layer_idx, seq_pos, head_idx] = norm_score
                if not normalized:
                    results.append(Result(decomp.ablation_set, norm_score))
    if normalized:
        sums_per_layer = np.sum(relevances, axis=(1, 2))
        print(sums_per_layer)
        sums_per_layer = np.abs(sums_per_layer)

        sums_per_layer[sums_per_layer == 0] = -1e-8
        relevances = relevances / np.expand_dims(sums_per_layer, (1, 2))
        for layer_idx in range(num_layers):
            for seq_pos in range(seq_len):
                for head_idx in range(num_attention_heads):
                    target_decomp = target_decomps[layer_idx * seq_len * num_attention_heads + seq_pos * num_attention_heads + head_idx]
                    results.append(Result(target_decomp.ablation_set, relevances[layer_idx, seq_pos, head_idx]))
    results.sort(key=operator.attrgetter('score'), reverse=True)

    return results, relevances



In [296]:
# results = compute_logits_decomposition_scores(out_decomps)
results, relevances = compute_logits_decomposition_scores(out_decomps, sample_idxs, normalized=True)

results.sort(key=operator.attrgetter('score'), reverse=True)
for result in results[:20]:
    # print(result)
    print(result.ablation_set[0], result.score)
'''
 a9.h1, while
MLP 8 relies on a8.h11, a8.h8, a7.h10, a6.h9, a5.h5, and a5.h1

(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)
'''


[0.03981848 0.01361028 0.00921484 0.00663896 0.00468989 0.00379394
 0.00245998 0.00181219 0.00163173 0.00281442 0.00125149 0.00022332]
Node(layer_idx=9, sequence_idx=0, attn_head_idx=1) 0.8089691831315646
Node(layer_idx=10, sequence_idx=0, attn_head_idx=4) 0.5720321180711452
Node(layer_idx=7, sequence_idx=0, attn_head_idx=10) 0.4250676382511792
Node(layer_idx=11, sequence_idx=0, attn_head_idx=8) 0.3279695854853505
Node(layer_idx=10, sequence_idx=0, attn_head_idx=7) 0.2508779920846017
Node(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.22523116602251275
Node(layer_idx=6, sequence_idx=0, attn_head_idx=9) 0.20666289500310217
Node(layer_idx=8, sequence_idx=0, attn_head_idx=11) 0.18331776048103574
Node(layer_idx=8, sequence_idx=0, attn_head_idx=10) 0.14820980889046276
Node(layer_idx=4, sequence_idx=0, attn_head_idx=3) 0.14483214622315416
Node(layer_idx=6, sequence_idx=0, attn_head_idx=7) 0.1270143676209398
Node(layer_idx=2, sequence_idx=0, attn_head_idx=1) 0.1268682623768458
Node(layer_idx

'\n a9.h1, while\nMLP 8 relies on a8.h11, a8.h8, a7.h10, a6.h9, a5.h5, and a5.h1\n\n(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)\n'

In [88]:
print(relevances[9, 12, 1])

-0.013245150161018487


In [97]:
print(relevances[10, 12, 4]) # 0.000484417607102171

0.006119984588483005


In [170]:
len(sample_idxs)

20

In [225]:
model.reset_hooks(including_permanent=True)

mean_acts = mean_acts.view(new_shape)
'''
target_nodes = [Node(9, 12, 1), Node(10, 12, 4)] # (10, 12, 7), (7, 12, 10)
ranges = [
        [layer for layer in range(num_layers)],
        [sequence_position for sequence_position in range(seq_len)],
        # [ioi_dataset.word_idx['IO'][0]],
        [attention_head_idx for attention_head_idx in range(num_attention_heads)]
    ]
source_nodes = [Node(*x) for x in itertools.product(*ranges)]
ablation_sets = [(n,) for n in source_nodes]
'''
ablation_sets = []
for layer in range(num_layers):
    for head_idx in range(num_attention_heads):
        ablation_sets.append(tuple(Node(layer, seq_pos, head_idx) for seq_pos in range(seq_len)))
target_nodes = []
for layer, head_idx in [(9, 1), (10, 4)]:
    for seq_pos in range(seq_len):
        target_nodes.append(Node(layer, seq_pos, head_idx))

_, _, _, pre_layer_activations = prop_GPT(ds.good_toks[sample_idxs, :], extended_attention_mask, model, [ablation_sets[0]], target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True)

prop_fn = lambda ablation_list: prop_GPT(ds.good_toks[sample_idxs, :], extended_attention_mask, model, ablation_list, target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True, cached_pre_layer_acts=pre_layer_activations)
out_decomps, target_decomps = batch_run(prop_fn, ablation_sets, num_at_time=max(64 // len(sample_idxs), 1))

running input 0


In [231]:
def calculate_target_decomposition_scores(target_decomps, normalized=False):
    results = []
    relevances = np.zeros((num_layers, num_attention_heads))
    for layer_idx in range(num_layers):
        for head_idx in range(num_attention_heads):
            idx = layer_idx * num_attention_heads + head_idx
            target_decomp = target_decomps[idx]
            if target_decomp.ablation_set[0] in target_nodes:
                continue
            score = 0
            for target_node_idx in range(len(target_decomp.target_nodes)):
                for batch_idx in range(len(target_decomp.rels)):
                    rels_magnitude = torch.mean(abs(target_decomp.rels[target_node_idx])) # np.mean if you are on cpu
                    irrels_magnitude = torch.mean(abs(target_decomp.irrels[batch_idx])) # np.mean if you are on cpu
                    target_node_score = rels_magnitude / (rels_magnitude + irrels_magnitude)
                    score += target_node_score
            if score != 0:
                score /= len(target_decomp.rels)

            relevances[layer_idx, head_idx] = score
            if not normalized:
                results.append(Result(target_decomp.ablation_set, relevances[layer_idx, head_idx]))


    if normalized:
        sums_per_layer = np.abs(np.sum(relevances, axis=(1)))
        sums_per_layer[sums_per_layer == 0] = -1e-8
        relevances = relevances / np.expand_dims(sums_per_layer, (1))

        for layer_idx in range(num_layers):
            for head_idx in range(num_attention_heads):
                target_decomp = target_decomps[layer_idx * num_attention_heads + head_idx]
                results.append(Result(target_decomp.ablation_set, relevances[layer_idx, head_idx]))

    results.sort(key=operator.attrgetter('score'), reverse=True)
    return results, relevances

In [None]:
def calculate_target_decomposition_scores(target_decomps, normalized=False):
    results = []
    relevances = np.zeros((num_layers, seq_len, num_attention_heads))
    for layer_idx in range(num_layers):
        for seq_pos in range(seq_len):
            for head_idx in range(num_attention_heads):
                idx = layer_idx * num_layers * seq_len + seq_pos * num_attention_heads + head_idx
                target_decomp = target_decomps[idx]
                if target_decomp.ablation_set[0] in target_nodes:
                    continue
                score = 0
                for target_node_idx in range(len(target_decomp.target_nodes)):
                    for batch_idx in range(len(target_decomp.rels)):
                        rels_magnitude = torch.mean(abs(target_decomp.rels[target_node_idx])) # np.mean if you are on cpu
                        irrels_magnitude = torch.mean(abs(target_decomp.irrels[batch_idx])) # np.mean if you are on cpu
                        target_node_score = rels_magnitude / (rels_magnitude + irrels_magnitude)
                        score += target_node_score
                if score != 0:
                    score /= len(target_decomp.rels)

                relevances[layer_idx, seq_pos, head_idx] = score
                if not normalized:
                    results.append(Result(target_decomp.ablation_set, relevances[layer_idx, seq_pos, head_idx]))


    if normalized:
        sums_per_layer = np.abs(np.sum(relevances, axis=(1, 2)))
        sums_per_layer[sums_per_layer == 0] = -1e-8
        relevances = relevances / np.expand_dims(sums_per_layer, (1, 2))

        for layer_idx in range(num_layers):
            for seq_pos in range(seq_len):
                for head_idx in range(num_attention_heads):
                    target_decomp = target_decomps[layer_idx * seq_len * num_attention_heads + seq_pos * num_attention_heads + head_idx]
                    results.append(Result(target_decomp.ablation_set, relevances[layer_idx, seq_pos, head_idx]))

    results.sort(key=operator.attrgetter('score'), reverse=True)
    return results, relevances

In [178]:
target_decomps[0].rels[0][0]

tensor([-0.0368, -0.0941,  0.1855,  0.2428, -0.0738,  0.0607, -0.1385,  0.1011,
         0.0750,  0.0955, -0.0027,  0.0664,  0.0500,  0.0146,  0.0697, -0.0249,
         0.0653,  0.1335,  0.1746, -0.1930,  0.1157, -0.0148,  0.2589,  0.1349,
        -0.0696,  0.0200,  0.0364,  0.0313, -0.0468, -0.0105, -0.0036,  0.1675,
        -0.1756,  0.0926, -0.1959, -0.0925, -0.0743,  0.1034,  0.0553,  0.1374,
        -0.0344, -0.1161,  0.0424, -0.2551,  0.0880,  0.0200, -0.0320, -0.2025,
        -0.1221, -0.1395,  0.0507, -0.1609,  0.2095, -0.0270, -0.0257, -0.0935,
        -0.0396,  0.0354, -0.0641,  0.0662,  0.0389,  0.0927, -0.0686,  0.1923],
       device='cuda:0')

In [234]:
results, relevances = calculate_target_decomposition_scores(target_decomps, normalized=True)

for result in results[:20]:
    print(result.ablation_set[0], result.score)
    # print(result)
'''
(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)
'''

Node(layer_idx=9, sequence_idx=0, attn_head_idx=3) 0.1447705091308364
Node(layer_idx=7, sequence_idx=0, attn_head_idx=8) 0.13397999674087258
Node(layer_idx=7, sequence_idx=0, attn_head_idx=5) 0.1245598699413583
Node(layer_idx=6, sequence_idx=0, attn_head_idx=7) 0.11800533643794124
Node(layer_idx=5, sequence_idx=0, attn_head_idx=10) 0.11756238262186239
Node(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.11658563322680085
Node(layer_idx=9, sequence_idx=0, attn_head_idx=8) 0.11190624527839693
Node(layer_idx=9, sequence_idx=0, attn_head_idx=0) 0.10658123669559956
Node(layer_idx=9, sequence_idx=0, attn_head_idx=5) 0.10474013448306205
Node(layer_idx=9, sequence_idx=0, attn_head_idx=10) 0.10472372792549454
Node(layer_idx=6, sequence_idx=0, attn_head_idx=4) 0.10395374556525286
Node(layer_idx=8, sequence_idx=0, attn_head_idx=5) 0.10361383371558026
Node(layer_idx=4, sequence_idx=0, attn_head_idx=3) 0.10250919493076233
Node(layer_idx=9, sequence_idx=0, attn_head_idx=7) 0.10090647101533862
Node(l

'\n(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)\n'

In [139]:
print(results[0])

Result(ablation_set=(Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),), score=0.0010632349292983398)


In [53]:
all_nodes = []
for it in outliers_per_iter:
    for result in it:
        if result.ablation_set[0] not in all_nodes:
            all_nodes.append(result.ablation_set[0])
for node in all_nodes:
    print((node))

Node(layer_idx=9, sequence_idx=14, attn_head_idx=9)
Node(layer_idx=10, sequence_idx=14, attn_head_idx=10)
Node(layer_idx=9, sequence_idx=14, attn_head_idx=6)
Node(layer_idx=9, sequence_idx=2, attn_head_idx=6)
Node(layer_idx=0, sequence_idx=1, attn_head_idx=1)
Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)


In [13]:
print(ioi_dataset.sentences[0])
print(test_ioi_dataset.sentences[0])

Then, Vanessa and Paul went to the house. Vanessa gave a basketball to Paul
Then, Jessica and Lindsay went to the school. Jessica gave a snack to Lindsay


# Circuit evaluation

In [169]:
# del out_decomps
# del target_decomps
# del logits
# del cache # pretty sure it's this one
print(torch.cuda.memory_allocated(0)/1024/1024)
print(torch.cuda.memory_reserved(0)/1024/1024)

import gc
gc.collect()

torch.cuda.empty_cache()
print(torch.cuda.memory_allocated(0)/1024/1024)
print(torch.cuda.memory_reserved(0)/1024/1024)


7698.91748046875
10790.0
1087.55322265625
4184.0


In [40]:
ranges = [
        [layer for layer in range(num_layers)],
        [sequence_position for sequence_position in range(seq_len)],
        # [ioi_dataset.word_idx['IO'][0]],
        [attention_head_idx for attention_head_idx in range(num_attention_heads)]
    ]

source_nodes = [Node(*x) for x in itertools.product(*ranges)]
random_circuit = random.sample(source_nodes, 20)

# sample_idxs = random.sample(range(N), NUM_SAMPLES)

In [63]:
# implicitly depends on year_indices/
def correctness_rate(logits, sample_idxs_0):
    logits_for_year_tokens = logits[:, -1, year_indices]
    predicted_year_idxs = np.argmax(logits_for_year_tokens.cpu().numpy(), axis=-1)
    # print(predicted_year_idxs.shape)
    correct_per_input = ds.good_mask.cpu().numpy()[sample_idxs_0, predicted_year_idxs]
    return np.sum(correct_per_input) / len(sample_idxs_0)
    '''
    probs_for_year_tokens = probs[:, year_indices.cpu().numpy()]
    probs_for_correct_years = probs_for_year_tokens[ds.good_mask.cpu().numpy()[sample_idxs_0]]
    correct_score = np.sum(probs_for_correct_years)
    probs_for_incorrect_years = probs_for_year_tokens[np.logical_not(ds.good_mask.cpu().numpy()[sample_idxs_0])]
    incorrect_score = np.sum(probs_for_incorrect_years)
    return (correct_score - incorrect_score) / len(sample_idxs_0)
    '''

In [None]:

circuit = []
for (layer_idx, head_idx) in [(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)]: # greater-than paper's result
    for seq_pos in range(seq_len):
        circuit.append(Node(layer_idx, seq_pos, head_idx))

'''
# simply results from first iter
circuit = [Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),
    # Node(layer_idx=10, sequence_idx=12, attn_head_idx=4),
    # Node(layer_idx=10, sequence_idx=12, attn_head_idx=7),
    Node(layer_idx=7, sequence_idx=12, attn_head_idx=10),
]
# 711, 965
'''
'''
circuit = []
for (layer_idx, head_idx) in [(9, 1), (7, 10)]: # the above but without seq pos
    for seq_pos in range(seq_len):
        circuit.append(Node(layer_idx, seq_pos, head_idx))
'''
evaluate_circuit(circuit)

  probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad


0.7689858919143681
0.9896000000000005


In [None]:
evaluate_circuit(None, True)

  probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad


0.8166431448936461
0.9920000000000004


In [218]:


'''
circuit = [Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),
    Node(layer_idx=10, sequence_idx=12, attn_head_idx=4),
Node(layer_idx=9, sequence_idx=9, attn_head_idx=9),
Node(layer_idx=9, sequence_idx=11, attn_head_idx=1),
Node(layer_idx=9, sequence_idx=11, attn_head_idx=9),
# Node(layer_idx=9, sequence_idx=11, attn_head_idx=6),
# Node(layer_idx=9, sequence_idx=4, attn_head_idx=2),
# Node(layer_idx=5, sequence_idx=3, attn_head_idx=10),
# Node(layer_idx=9, sequence_idx=9, attn_head_idx=1),
# Node(layer_idx=7, sequence_idx=2, attn_head_idx=1),
# Node(layer_idx=7, sequence_idx=2, attn_head_idx=3),
# Node(layer_idx=7, sequence_idx=3, attn_head_idx=8),
# Node(layer_idx=9, sequence_idx=9, attn_head_idx=4),
# Node(layer_idx=7, sequence_idx=2, attn_head_idx=5),
# Node(layer_idx=9, sequence_idx=8, attn_head_idx=10),
# Node(layer_idx=7, sequence_idx=2, attn_head_idx=4),
# Node(layer_idx=6, sequence_idx=3, attn_head_idx=7),
# Node(layer_idx=7, sequence_idx=2, attn_head_idx=11),
# Node(layer_idx=9, sequence_idx=5, attn_head_idx=3),
# Node(layer_idx=6, sequence_idx=2, attn_head_idx=7),
]
'''
circuit = []
for (layer_idx, head_idx) in [(9, 1), (10, 4), (9, 9), (9, 6), (9, 2), (5, 10), (7, 1), (7, 3), (7, 8)]: # the above but without seq pos
    for seq_pos in range(seq_len):
        circuit.append(Node(layer_idx, seq_pos, head_idx))
evaluate_circuit(circuit)
'''
circuit = [Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),
    Node(layer_idx=10, sequence_idx=12, attn_head_idx=4),
Node(layer_idx=9, sequence_idx=11, attn_head_idx=1),
Node(layer_idx=9, sequence_idx=9, attn_head_idx=9),
Node(layer_idx=9, sequence_idx=9, attn_head_idx=1),
Node(layer_idx=9, sequence_idx=11, attn_head_idx=9),
Node(layer_idx=9, sequence_idx=11, attn_head_idx=6),
Node(layer_idx=5, sequence_idx=3, attn_head_idx=10),
Node(layer_idx=7, sequence_idx=3, attn_head_idx=8),
Node(layer_idx=9, sequence_idx=5, attn_head_idx=3),
Node(layer_idx=6, sequence_idx=3, attn_head_idx=7),
Node(layer_idx=4, sequence_idx=3, attn_head_idx=3),
Node(layer_idx=9, sequence_idx=10, attn_head_idx=1),
Node(layer_idx=9, sequence_idx=7, attn_head_idx=10),
Node(layer_idx=8, sequence_idx=12, attn_head_idx=11),
# Node(layer_idx=7, sequence_idx=4, attn_head_idx=5),
# Node(layer_idx=9, sequence_idx=4, attn_head_idx=3),
# Node(layer_idx=8, sequence_idx=9, attn_head_idx=3),
# Node(layer_idx=7, sequence_idx=4, attn_head_idx=8),
# Node(layer_idx=5, sequence_idx=4, attn_head_idx=10),
# Node(layer_idx=9, sequence_idx=8, attn_head_idx=10),
]
'''

  probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad


0.6743128757087552
0.9183673469387756


'\ncircuit = [Node(layer_idx=9, sequence_idx=12, attn_head_idx=1),\n    Node(layer_idx=10, sequence_idx=12, attn_head_idx=4),\nNode(layer_idx=9, sequence_idx=11, attn_head_idx=1),\nNode(layer_idx=9, sequence_idx=9, attn_head_idx=9),\nNode(layer_idx=9, sequence_idx=9, attn_head_idx=1),\nNode(layer_idx=9, sequence_idx=11, attn_head_idx=9),\nNode(layer_idx=9, sequence_idx=11, attn_head_idx=6),\nNode(layer_idx=5, sequence_idx=3, attn_head_idx=10),\nNode(layer_idx=7, sequence_idx=3, attn_head_idx=8),\nNode(layer_idx=9, sequence_idx=5, attn_head_idx=3),\nNode(layer_idx=6, sequence_idx=3, attn_head_idx=7),\nNode(layer_idx=4, sequence_idx=3, attn_head_idx=3),\nNode(layer_idx=9, sequence_idx=10, attn_head_idx=1),\nNode(layer_idx=9, sequence_idx=7, attn_head_idx=10),\nNode(layer_idx=8, sequence_idx=12, attn_head_idx=11),\n# Node(layer_idx=7, sequence_idx=4, attn_head_idx=5),\n# Node(layer_idx=9, sequence_idx=4, attn_head_idx=3),\n# Node(layer_idx=8, sequence_idx=9, attn_head_idx=3),\n# Node(laye

In [293]:

'''
Node(layer_idx=9, sequence_idx=0, attn_head_idx=1) 0.4083923634530539
Node(layer_idx=7, sequence_idx=0, attn_head_idx=10) 0.2221098391988148
Node(layer_idx=10, sequence_idx=0, attn_head_idx=4) 0.2202389078105817
Node(layer_idx=10, sequence_idx=0, attn_head_idx=7) 0.15423218716630077
Node(layer_idx=6, sequence_idx=0, attn_head_idx=9) 0.1332779102978367
Node(layer_idx=11, sequence_idx=0, attn_head_idx=8) 0.12993389925707383
Node(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.12727556319370611
Node(layer_idx=8, sequence_idx=0, attn_head_idx=11) 0.12419588242524643
Node(layer_idx=4, sequence_idx=0, attn_head_idx=3) 0.11703012724861953
Node(layer_idx=6, sequence_idx=0, attn_head_idx=7) 0.10890546106985093
Node(layer_idx=5, sequence_idx=0, attn_head_idx=10) 0.10690893744165206
'''
circuit = []
for (layer_idx, head_idx) in [(9, 1), (10, 4), (7, 10), (11, 8), (10, 7), (6, 9), (8, 11), (8, 8)]: # the above but without seq pos
    for seq_pos in range(seq_len):
        circuit.append(Node(layer_idx, seq_pos, head_idx))
evaluate_circuit(circuit)
'''
Node(layer_idx=9, sequence_idx=0, attn_head_idx=1) 0.7919426329190301
Node(layer_idx=10, sequence_idx=0, attn_head_idx=4) 0.5688594302281663
Node(layer_idx=7, sequence_idx=0, attn_head_idx=10) 0.3724930226405632
Node(layer_idx=10, sequence_idx=0, attn_head_idx=7) 0.2592619470224411
Node(layer_idx=11, sequence_idx=0, attn_head_idx=8) 0.2262736787177263
Node(layer_idx=8, sequence_idx=0, attn_head_idx=10) 0.21704383205004027
Node(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.2017522938898915
Node(layer_idx=6, sequence_idx=0, attn_head_idx=9) 0.17671132314414148
'''


  probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad


0.7616304731369021
0.9806000000000002


'\nNode(layer_idx=9, sequence_idx=0, attn_head_idx=1) 0.7919426329190301\nNode(layer_idx=10, sequence_idx=0, attn_head_idx=4) 0.5688594302281663\nNode(layer_idx=7, sequence_idx=0, attn_head_idx=10) 0.3724930226405632\nNode(layer_idx=10, sequence_idx=0, attn_head_idx=7) 0.2592619470224411\nNode(layer_idx=11, sequence_idx=0, attn_head_idx=8) 0.2262736787177263\nNode(layer_idx=8, sequence_idx=0, attn_head_idx=10) 0.21704383205004027\nNode(layer_idx=8, sequence_idx=0, attn_head_idx=8) 0.2017522938898915\nNode(layer_idx=6, sequence_idx=0, attn_head_idx=9) 0.17671132314414148\n'

In [273]:
from pyfunctions.faithfulness_ablations import add_mean_ablation_hook
def evaluate_circuit(circuit, full_model=False):
    # mean_acts = mean_acts.view(old_shape)
    model.reset_hooks(including_permanent=True)
    # current findings:
    # full model: 0.817, 0.989 correctness
    # ablate all attention layers entirely: 0.515, 0.891
    # random circuit of 20 "head, seq_pos": 0.532, 0.891
    # our "four head, seq_pos" circuit: 0.711, 0.955
    # their circuit: 0.765, 0.985
    if full_model:
        ablation_model = model
    else:
        ablation_model = add_mean_ablation_hook(model, patch_values=mean_acts.view(old_shape), circuit=circuit)
    
    # batching
    NUM_AT_TIME = 64
    start_idx = 0
    score = 0
    correctness = 0
    while True:
        end_idx = start_idx + NUM_AT_TIME
        if end_idx > N:
            end_idx = N

        logits, cache = model.run_with_cache(ds.good_toks[start_idx:end_idx]) # run on entire dataset along batch dimension
        batch_score = score_logits(logits, range(start_idx, end_idx))
        batch_correctness_rate = correctness_rate(logits, range(start_idx, end_idx))
        num_samples = end_idx - start_idx
        score += batch_score * (num_samples / N)
        correctness += batch_correctness_rate * (num_samples / N)
        start_idx += NUM_AT_TIME
        if end_idx == N:
            break
    print(score)
    print(correctness)
    ablation_model.reset_hooks(including_permanent=True)


In [None]:
# speculative: try to generate a better circuit by greedy search

NAME_MOVER_HEADS = [Node(9, 14, 9), Node(10, 14, 0), Node(9, 14, 6)]
old_circuit = circuit.copy()
best_score = -1.4686 # 
while True:
    node_to_remove = None
    for idx, node in enumerate(circuit):
        if node in NAME_MOVER_HEADS:
            continue
        new_circuit = circuit.copy()
        new_circuit.remove(node)
        # print(new_circuit)
        model.reset_hooks(including_permanent=True)
        model = add_mean_ablation_hook(model, means_dataset=test_abc_dataset, circuit=new_circuit)
        logits, cache = model.run_with_cache(test_ioi_dataset.toks) # run on entire dataset along batch dimension
        ave_logit_diff = logits_to_ave_logit_diff_2(logits, test_ioi_dataset).cpu().numpy().item()
        if ave_logit_diff > best_score:
            best_score = ave_logit_diff
            node_to_remove = node
            print('tentatively improved score to %f ' % best_score, ' by removing node ', node_to_remove)
    if node_to_remove is None: 
        # then we can't improve any further so the algorithm terminates
        break
    print("removing ", node_to_remove, " to achieve score of %f" % best_score)
    circuit.remove(node_to_remove)
print('Done')

tentatively improved score to -1.255778   by removing node  Node(layer_idx=10, sequence_idx=14, attn_head_idx=10)
removing  Node(layer_idx=10, sequence_idx=14, attn_head_idx=10)  to achieve score of -1.255778
tentatively improved score to -1.220982   by removing node  Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)
removing  Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)  to achieve score of -1.220982
Done


## Without sequence positions