# Setup

In [1]:
%load_ext autoreload
%autoreload 2

https://arxiv.org/pdf/2305.00586


In [2]:
from transformer_lens import HookedTransformer
import random
import sys
import os
import collections
import operator
import functools
import itertools


base_dir = os.path.split(os.getcwd())[0]
sys.path.append(base_dir)
from pyfunctions.general import compare_same
from pyfunctions.cdt_basic import *
from pyfunctions.cdt_source_to_target import *
from pyfunctions.cdt_from_source_nodes import *
from pyfunctions.cdt_ablations import *
from pyfunctions.cdt_core import *
from pyfunctions.toy_model import *
from pyfunctions.faithfulness_ablations import add_mean_ablation_hook

from greater_than_task.greater_than_dataset import *
from greater_than_task.utils import get_valid_years


import seaborn as sns
import matplotlib.pyplot as plt


import torch
Result = collections.namedtuple('Result', ('ablation_set', 'score'))




## Load model and dataset


In [3]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch.autograd.set_grad_enabled(False)

from transformer_lens import utils, HookedTransformer, ActivationCache
model = HookedTransformer.from_pretrained("gpt2-small",
                                          center_unembed=True,
                                          center_writing_weights=True,
                                          fold_ln=False,
                                          refactor_factored_attn_matrices=True)
                                          

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# https://github.com/hannamw/gpt2-greater-than/blob/main/circuit_discovery.py; also these files came with their repo
years_to_sample_from = get_valid_years(model.tokenizer, 1000, 1900)
N = 5000
ds = YearDataset(years_to_sample_from, N, Path("../greater_than_task/cache/potential_nouns.txt"), model.tokenizer, balanced=True, device=device, eos=True)
year_indices = torch.load("../greater_than_task/cache/logit_indices.pt")# .to(device)

num_layers = len(model.blocks)
seq_len = ds.good_toks.size()[-1]
num_attention_heads = model.cfg.n_heads

## Exploration

In [9]:
type(model.tokenizer)

transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast

In [147]:
# print(ds)
'''
These guys weirdly implemented all the functionality of their class in class-level attributes
years_to_sample_from: torch.Tensor
    N: int
    ordered: bool
    eos: bool

    nouns: List[str]
    years: torch.Tensor
    years_YY: torch.Tensor
    good_sentences: List[str]
    bad_sentences: List[str]
    good_toks: torch.Tensor
    bad_toks: torch.Tensor
    good_prompt: List[str]
    bad_prompt: List[str]
    good_mask: torch.Tensor
    tokenizer: PreTrainedTokenizer
    '''

# ds.N
# ds.nouns
# print(ds.years[:20]) # not sorted by XX for some reason
# print(ds.years_YY[:]) # but does correspond to these YYs, which are mostly sorted
print(ds.good_sentences[-10:]) # includes The endeavor lasted from the year 1098 to the year 10', but 1099 isn't in the list of years?
# note: we want prediction at the last token, unlike with the IOI dataset where we want second-to-last
# i checked and there is no internal logic to prevent such sentences from being produced, so i guess we're SOL if we sample one?
# print(ds.bad_sentences[-10:]) # these all start with 01, e.g 1601 to. they're bad because there is no possible incorrect input
# print(ds.good_mask.size()) # n, 100 (100 different years)
# print(ds.good_toks.size()) # n, 13
# print(ds.bad_toks.size()) # there isn't any necessary correspondence, N is just the number of good sequences and bad sequences alike
# list(ds.years.cpu().numpy()).index(1099)
print(year_indices)
print(model.tokenizer.convert_ids_to_tokens(year_indices)) # length 100, starts with index for '00' and ends with index for '99', great
# print(model.tokenizer.decode(year_indices, clean_up_tokenization_spaces=False))

['<|endoftext|> The clash lasted from the year 1594 to the year 15', '<|endoftext|> The program lasted from the year 1395 to the year 13', '<|endoftext|> The challenge lasted from the year 1496 to the year 14', '<|endoftext|> The confrontation lasted from the year 1597 to the year 15', '<|endoftext|> The marriage lasted from the year 1098 to the year 10', '<|endoftext|> The journey lasted from the year 1202 to the year 12', '<|endoftext|> The insurgency lasted from the year 1803 to the year 18', '<|endoftext|> The improvement lasted from the year 1404 to the year 14', '<|endoftext|> The consultation lasted from the year 1705 to the year 17', '<|endoftext|> The domination lasted from the year 1606 to the year 16']
tensor([ 405,  486, 2999, 3070, 3023, 2713, 3312, 2998, 2919, 2931,  940, 1157,
        1065, 1485, 1415, 1314, 1433, 1558, 1507, 1129, 1238, 2481, 1828, 1954,
        1731, 1495, 2075, 1983, 2078, 1959, 1270, 3132, 2624, 2091, 2682, 2327,
        2623, 2718, 2548, 2670, 1821,

## Setup attention mask and mean activations for ablation

In [5]:
attention_mask = torch.tensor([1 for x in range(seq_len)]).view(1, -1).to(device)
input_shape = ds.good_toks[0:1, :].size() # by making the sample size 1, you can get an extended attention mask with batch size 1, which will broadcast
extended_attention_mask = get_extended_attention_mask(attention_mask, 
                                                        input_shape, 
                                                        model,
                                                        device)



In [6]:
NUM_AT_TIME = 64
start_idx = 0
score = 0
correctness = 0

all_attention_outputs = []
while True:
    end_idx = start_idx + NUM_AT_TIME
    if end_idx > N:
        end_idx = N
    logits, cache = model.run_with_cache(ds.good_toks[start_idx:end_idx]) # run on entire dataset along batch dimension
    attention_outputs = [cache['blocks.' + str(i) + '.attn.hook_z'] for i in range(num_attention_heads)]
    attention_outputs = torch.stack(attention_outputs, dim=1) # now batch, layer, seq, n_heads, dim_attn
    all_attention_outputs.append(attention_outputs)

    start_idx += NUM_AT_TIME
    if end_idx == N:
        break
all_attention_outputs = torch.cat(all_attention_outputs, dim=0)
mean_acts = torch.mean(all_attention_outputs, dim=0)
old_shape = mean_acts.shape
last_dim = old_shape[-2] * old_shape[-1]
new_shape = old_shape[:-2] + (last_dim,)
mean_acts = mean_acts.view(new_shape)
mean_acts.shape

torch.Size([12, 13, 768])

In [7]:
# quick check for equality, particularly to make sure we've made the attention mask correctly
ranges = [
        [layer for layer in range(num_layers)],
        [sequence_position for sequence_position in range(seq_len)],
        [attention_head_idx for attention_head_idx in range(num_attention_heads)]
    ]

source_nodes = [Node(*x) for x in itertools.product(*ranges)]
ablation_sets = [(n,) for n in source_nodes]
target_nodes = []
out_decomp, _, _, _ = prop_GPT(ds.good_toks[0:1, :], extended_attention_mask, model, [ablation_sets[0]], target_nodes=target_nodes, device=device, mean_acts=None, set_irrel_to_mean=False)

logits, cache = model.run_with_cache(ds.good_toks[0])

compare_same(out_decomp[0].rel + out_decomp[0].irrel, logits)

100.00% of the values are equal


1.0

# Experiments

In [8]:
import random
NUM_SAMPLES = 100
sample_idxs = random.sample(range(N), NUM_SAMPLES) # you actually have to sample randomly from this dataset because they are arranged in increasing order of YY token
# sample_idxs

In [9]:
example_prompt = ds.good_sentences[88] # GPT2 doesn't always perform this task correctly, only about 99% of the time.
# On example input <|endoftext|> The pursuit lasted from the year 1290 to the year 12 , the top prediction is '90'.
example_answer = '03'

transformer_lens.utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', '<|endoftext|>', ' The', ' custody', ' lasted', ' from', ' the', ' year', ' 17', '90', ' to', ' the', ' year', ' 17']
Tokenized answer: ['03']


Top 0th token. Logit: 27.91 Prob: 19.42% Token: |90|
Top 1th token. Logit: 27.74 Prob: 16.48% Token: |95|
Top 2th token. Logit: 27.37 Prob: 11.38% Token: |99|
Top 3th token. Logit: 27.30 Prob: 10.55% Token: |94|
Top 4th token. Logit: 27.10 Prob:  8.65% Token: |96|
Top 5th token. Logit: 26.75 Prob:  6.11% Token: |92|
Top 6th token. Logit: 26.69 Prob:  5.72% Token: |91|
Top 7th token. Logit: 26.64 Prob:  5.48% Token: |98|
Top 8th token. Logit: 26.46 Prob:  4.57% Token: |97|
Top 9th token. Logit: 26.35 Prob:  4.09% Token: |93|


In [10]:
# This is not a pure function. It depends on ds.good_mask, sample_idxs, and year_indices.
def score_logits(logits, sample_idxs_0):
    probs = torch.nn.functional.softmax(torch.tensor(logits[:, -1, :], device='cpu'), dim=-1).numpy() # sad
    probs_for_year_tokens = probs[:, year_indices.cpu().numpy()]
    probs_for_correct_years = probs_for_year_tokens[ds.good_mask.cpu().numpy()[sample_idxs_0]]
    correct_score = np.sum(probs_for_correct_years)
    probs_for_incorrect_years = probs_for_year_tokens[np.logical_not(ds.good_mask.cpu().numpy()[sample_idxs_0])]
    incorrect_score = np.sum(probs_for_incorrect_years)
    return (correct_score - incorrect_score) / len(sample_idxs_0)


In [None]:
model.reset_hooks(including_permanent=True)

mean_acts = mean_acts.view(new_shape)
'''
ranges = [
        [layer for layer in range(num_layers)],
        [sequence_position for sequence_position in range(seq_len)],
        [attention_head_idx for attention_head_idx in range(num_attention_heads)]
    ]

source_nodes = [Node(*x) for x in itertools.product(*ranges)]
ablation_sets = [(n,) for n in source_nodes]
'''
ablation_sets = []
for layer in range(num_layers):
    for head_idx in range(num_attention_heads):
        ablation_sets.append(tuple(Node(layer, seq_pos, head_idx) for seq_pos in range(seq_len)))
target_nodes = []

# cache activations for faster batch run
out_decomp, _, _, pre_layer_activations = prop_GPT(ds.good_toks[sample_idxs, :], extended_attention_mask, model, [ablation_sets[0]], target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True)

prop_fn = lambda ablation_list: prop_GPT(ds.good_toks[sample_idxs, :], extended_attention_mask, model, ablation_list, target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True, cached_pre_layer_acts=pre_layer_activations)
out_decomps, target_decomps = batch_run(prop_fn, ablation_sets, num_at_time=(max(64 // len(sample_idxs), 1)))

In [None]:
def compute_logits_decomposition_scores(out_decomps, sample_idxs, normalized=False):
    logits = (out_decomps[0].rel + out_decomps[0].irrel)
    full_score = score_logits(logits, sample_idxs)
    assert(full_score > 0) # as mentioned above, GPT2 doesn't succeed at this 100% of the time
    
    results = []
    relevances = np.zeros((num_layers, num_attention_heads))

    for layer_idx in range(num_layers):

        for head_idx in range(num_attention_heads):
            decomp = out_decomps[layer_idx * num_attention_heads + head_idx]
            score = score_logits(decomp.rel, sample_idxs)
            norm_score = score / full_score
            relevances[layer_idx, head_idx] = norm_score
            if not normalized:
                results.append(Result(decomp.ablation_set, norm_score))
    if normalized:
        sums_per_layer = np.sum(np.abs(relevances), axis=(1))
        print(sums_per_layer)

        sums_per_layer[sums_per_layer == 0] = -1e-8
        relevances = relevances / np.expand_dims(sums_per_layer, (1))
        for layer_idx in range(num_layers):
            for head_idx in range(num_attention_heads):
                target_decomp = target_decomps[layer_idx * num_attention_heads + head_idx]
                results.append(Result(target_decomp.ablation_set, relevances[layer_idx, head_idx]))
    results.sort(key=operator.attrgetter('score'), reverse=True)

    return results, relevances

In [None]:
results, relevances = compute_logits_decomposition_scores(out_decomps, sample_idxs, normalized=True)

results.sort(key=operator.attrgetter('score'), reverse=True)
for result in results[:20]:
    print(result.ablation_set[0], result.score)
'''
 a9.h1, while
MLP 8 relies on a8.h11, a8.h8, a7.h10, a6.h9, a5.h5, and a5.h1

(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)
'''


In [None]:
model.reset_hooks(including_permanent=True)

mean_acts = mean_acts.view(new_shape)

ablation_sets = []
for layer in range(num_layers):
    for head_idx in range(num_attention_heads):
        ablation_sets.append(tuple(Node(layer, seq_pos, head_idx) for seq_pos in range(seq_len)))
target_nodes = []
for layer, head_idx in [(9, 1), (10, 4)]:
    for seq_pos in range(seq_len):
        target_nodes.append(Node(layer, seq_pos, head_idx))

_, _, _, pre_layer_activations = prop_GPT(ds.good_toks[sample_idxs, :], extended_attention_mask, model, [ablation_sets[0]], target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True)

prop_fn = lambda ablation_list: prop_GPT(ds.good_toks[sample_idxs, :], extended_attention_mask, model, ablation_list, target_nodes=target_nodes, device=device, mean_acts=mean_acts, set_irrel_to_mean=True, cached_pre_layer_acts=pre_layer_activations)
out_decomps, target_decomps = batch_run(prop_fn, ablation_sets, num_at_time=max(64 // len(sample_idxs), 1))

In [None]:
def calculate_target_decomposition_scores(target_decomps, normalized=False):
    results = []
    relevances = np.zeros((num_layers, num_attention_heads))
    for layer_idx in range(num_layers):
        for head_idx in range(num_attention_heads):
            idx = layer_idx * num_attention_heads + head_idx
            target_decomp = target_decomps[idx]
            if target_decomp.ablation_set[0] in target_nodes:
                continue
            score = 0
            for target_node_idx in range(len(target_decomp.target_nodes)):
                for batch_idx in range(len(target_decomp.rels)):
                    rels_magnitude = torch.mean(abs(target_decomp.rels[target_node_idx])) # np.mean if you are on cpu
                    irrels_magnitude = torch.mean(abs(target_decomp.irrels[batch_idx])) # np.mean if you are on cpu
                    target_node_score = rels_magnitude / (rels_magnitude + irrels_magnitude)
                    score += target_node_score
            if score != 0:
                score /= len(target_decomp.rels)

            relevances[layer_idx, head_idx] = score
            if not normalized:
                results.append(Result(target_decomp.ablation_set, relevances[layer_idx, head_idx]))


    if normalized:
        sums_per_layer = np.abs(np.sum(relevances, axis=(1)))
        sums_per_layer[sums_per_layer == 0] = -1e-8
        relevances = relevances / np.expand_dims(sums_per_layer, (1))

        for layer_idx in range(num_layers):
            for head_idx in range(num_attention_heads):
                target_decomp = target_decomps[layer_idx * num_attention_heads + head_idx]
                results.append(Result(target_decomp.ablation_set, relevances[layer_idx, head_idx]))

    results.sort(key=operator.attrgetter('score'), reverse=True)
    return results, relevances

In [None]:
results, relevances = calculate_target_decomposition_scores(target_decomps, normalized=True)

for result in results[:20]:
    print(result.ablation_set[0], result.score)
    # print(result)
'''
(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)
'''

In [None]:
all_nodes = []
for it in outliers_per_iter:
    for result in it:
        if result.ablation_set[0] not in all_nodes:
            all_nodes.append(result.ablation_set[0])
for node in all_nodes:
    print((node))

# Circuit evaluation

In [None]:
ranges = [
        [layer for layer in range(num_layers)],
        [sequence_position for sequence_position in range(seq_len)],
        # [ioi_dataset.word_idx['IO'][0]],
        [attention_head_idx for attention_head_idx in range(num_attention_heads)]
    ]

source_nodes = [Node(*x) for x in itertools.product(*ranges)]
random_circuit = random.sample(source_nodes, 20)



In [None]:
# implicitly depends on year_indices/
def correctness_rate(logits, sample_idxs_0):
    logits_for_year_tokens = logits[:, -1, year_indices]
    predicted_year_idxs = np.argmax(logits_for_year_tokens.cpu().numpy(), axis=-1)
    # print(predicted_year_idxs.shape)
    correct_per_input = ds.good_mask.cpu().numpy()[sample_idxs_0, predicted_year_idxs]
    return np.sum(correct_per_input) / len(sample_idxs_0)


In [None]:
def evaluate_circuit(circuit, full_model=False):
    # mean_acts = mean_acts.view(old_shape)
    model.reset_hooks(including_permanent=True)

    if full_model:
        ablation_model = model
    else:
        ablation_model = add_mean_ablation_hook(model, patch_values=mean_acts.view(old_shape), circuit=circuit)
    
    # batching, since the datasets for this task are typically large
    NUM_AT_TIME = 64
    start_idx = 0
    score = 0
    correctness = 0
    while True:
        end_idx = start_idx + NUM_AT_TIME
        if end_idx > N:
            end_idx = N

        logits, cache = model.run_with_cache(ds.good_toks[start_idx:end_idx]) # run on entire dataset along batch dimension
        batch_score = score_logits(logits, range(start_idx, end_idx))
        batch_correctness_rate = correctness_rate(logits, range(start_idx, end_idx))
        num_samples = end_idx - start_idx
        score += batch_score * (num_samples / N)
        correctness += batch_correctness_rate * (num_samples / N)
        start_idx += NUM_AT_TIME
        if end_idx == N:
            break
    print(score)
    print(correctness)
    ablation_model.reset_hooks(including_permanent=True)


In [None]:

circuit = []
for (layer_idx, head_idx) in [(9, 1), (8, 11), (8, 8), (7, 10), (6, 9), (5, 5), (5, 1)]: # greater-than paper's result
    for seq_pos in range(seq_len):
        circuit.append(Node(layer_idx, seq_pos, head_idx))

evaluate_circuit(circuit)

In [None]:
evaluate_circuit(None, True)

In [None]:

circuit = []
for (layer_idx, head_idx) in [(9, 1), (10, 4), (7, 10), (11, 8), (10, 7), (6, 9), (8, 11), (8, 8)]: # the above but without seq pos
    for seq_pos in range(seq_len):
        circuit.append(Node(layer_idx, seq_pos, head_idx))
evaluate_circuit(circuit)
