In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# Uncomment the %cd below when running on a remote JupyterHub whose working directory is the server root; otherwise the working directory should already be correct.
# %cd CD_Circuit/

import numpy as np
import os
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import collections

import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import itertools
import operator
import functools

from transformer_lens import utils, HookedTransformer, ActivationCache

warnings.filterwarnings("ignore")

# Put the parent of the current working directory on the import path; the
# `pyfunctions` imports that follow are presumably resolved through it.
base_dir = os.path.dirname(os.getcwd())
sys.path.append(base_dir)

from pyfunctions.cdt_basic import *
from pyfunctions.cdt_source_to_target import *
from pyfunctions.ioi_dataset import IOIDataset
from pyfunctions.wrappers import Node, AblationSet
from pyfunctions.faithfulness_ablations import logits_to_ave_logit_diff_2, add_mean_ablation_hook


# Pairs an ablation set (a tuple of nodes) with the relevance score computed for it.
Result = collections.namedtuple("Result", ["ablation_set", "score"])


KeyboardInterrupt: 

## Load Model


In [None]:
# Inference only: switch autograd off globally before anything is run through the model.
torch.autograd.set_grad_enabled(False)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# The model-loading code is adapted from Callum McDougall's ARENA notebook on reproducing
# the IOI paper with TransformerLens. EasyTransformer, the library released with the IOI
# paper, is closely related to TransformerLens, so loading the model this way keeps the
# reproduction close to the original setup (including weight preprocessing options).
# NOTE(review): the original remark mentions "folding" LayerNorms as the IOI authors likely
# did, yet fold_ln=False is passed here - confirm that leaving LayerNorm unfolded is intended.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=False,
    refactor_factored_attn_matrices=True,
)
                                          

## Generate mean activations / Example usage of the IOI dataset

This is not as simple as it sounds; for the IOI paper, for each individual input following a template, they ablate using the mean activations of the "ABC" dataset, generated over sentences following the same template.

For those familiar with the IOI dataset code: our code does not take advantage of the IOI dataset's sequence position labels. It fundamentally cannot, because our method is semi-automated and may find that unlabeled positions are relevant, so it has no way to incorporate knowledge of those labels. As a result, circuit analysis needs to be done on a per-template basis.

In [None]:
# Build a dataset whose prompts all follow one randomly chosen template.
# nb_templates is 2 because of logic internal to IOIDataset: the names can appear in
# ABBA or ABAB order, and each order counts as a separate template.
ioi_dataset = IOIDataset(
    prompt_type="mixed",
    N=50,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    nb_templates=2,
)

# P_ABC from the IOI paper, used for mean ablation: IO, S and S1 are each replaced by a
# random name. Note that constructing a dataset with prompt_type="ABC" (or similar) is
# NOT the same thing.
abc_dataset = ioi_dataset.gen_flipped_prompts(("IO", "RAND"))
abc_dataset = abc_dataset.gen_flipped_prompts(("S", "RAND"))
abc_dataset = abc_dataset.gen_flipped_prompts(("S1", "RAND"))

# One forward pass over the whole ABC dataset (batched along dim 0), caching activations.
logits, cache = model.run_with_cache(abc_dataset.toks)

# We patch at what TransformerLens calls the "z" activation of each attention block: it is
# the natural per-head attention output when the standard attention implementation has no
# separate head dimension downstream.
attention_outputs = torch.stack(
    [cache[f'blocks.{layer_idx}.attn.hook_z'] for layer_idx in range(12)],
    dim=1,
)  # now batch, layer, seq, n_heads, dim_attn
mean_acts = attention_outputs.mean(dim=0)

# Other attention implementations have no separate head dimension, so the last two
# dimensions (n_heads, dim_attn) are merged into one. `old_shape` is kept so the per-head
# layout can be restored later with mean_acts.view(old_shape).
old_shape = mean_acts.shape
last_dim = old_shape[-2] * old_shape[-1]
new_shape = old_shape[:-2] + (last_dim,)
mean_acts = mean_acts.view(new_shape)
mean_acts.shape

In [None]:
# Setup boilerplate: the forward-pass code is shared between models of different
# architectures and therefore expects a tokenizer encoding plus an "extended" attention
# mask. There are no real subtleties here; the first sentence of the dataset is encoded
# and its mask is expanded with the project helper.
text = ioi_dataset.sentences[0]
encoding = model.tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    max_length=512,
    truncation=True,
    padding="longest",
    return_attention_mask=True,
    return_tensors="pt",
).to(device)
encoding_idxs = encoding.input_ids
attention_mask = encoding.attention_mask
input_shape = encoding_idxs.size()
extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape, model, device)

## Analysis

These cells define the two basic operations of our method: decomposing the contribution directly to the logits, and decomposing the contribution to given target nodes.
If you want to perform a specific analysis that requires some degree of human intervention or heuristic pruning, these cells are the place to start.

In [6]:

# Candidate source nodes: every (layer, sequence position, attention head) triple.
ranges = [
    list(range(12)),
    list(range(input_shape[1])),
    list(range(12)),
]

source_nodes = [Node(*coords) for coords in itertools.product(*ranges)]
# One ablation set per node: nodes are ablated one at a time.
ablation_sets = [(source_node,) for source_node in source_nodes]

# No target nodes: relevance is decomposed directly with respect to the logits.
target_nodes = []

# Warm-up call on a single ablation set; the pre-layer activations it returns are cached
# and reused by the batched run below to make it faster.
out_decomp, _, _, pre_layer_activations = prop_GPT(
    encoding_idxs[0:1, :],
    extended_attention_mask,
    model,
    [ablation_sets[0]],
    target_nodes=target_nodes,
    device=device,
    mean_acts=mean_acts,
    set_irrel_to_mean=True,
)


def prop_fn(ablation_list):
    """Run prop_GPT for a batch of ablation sets, reusing the cached pre-layer activations."""
    return prop_GPT(
        encoding_idxs[0:1, :],
        extended_attention_mask,
        model,
        ablation_list,
        target_nodes=target_nodes,
        device=device,
        mean_acts=mean_acts,
        set_irrel_to_mean=True,
        cached_pre_layer_acts=pre_layer_activations,
    )


out_decomps, target_decomps = batch_run(prop_fn, ablation_sets)

running input 0
running input 1600


In [1]:
def compute_logits_decomposition_scores(out_decomps, dataset=None):
    """Rank ablation sets by the relevant part of the IO-vs-S logit difference.

    For each decomposition the score is ``rel[IO] - rel[S]``, read from batch row 0 at
    sequence position -2 (presumably the position whose next-token prediction is the
    IO name - confirm against the dataset layout).

    Args:
        out_decomps: sequence of output decompositions; each needs ``rel``, ``irrel``
            and ``ablation_set`` attributes, with ``rel``/``irrel`` indexable as
            (1, seq_len, d_vocab).
        dataset: object exposing ``io_tokenIDs`` and ``s_tokenIDs``; only entry 0 of
            each is used. Defaults to the notebook-level ``ioi_dataset``.

    Returns:
        A list of ``Result(ablation_set, score)`` sorted by score, highest first.
        An empty input yields an empty list.
    """
    if len(out_decomps) == 0:
        return []
    if dataset is None:
        dataset = ioi_dataset  # fall back to the dataset defined at notebook level

    # Token ids do not depend on the decomposition, so look them up once.
    io_token = dataset.io_tokenIDs[0]
    s_token = dataset.s_tokenIDs[0]

    # rel + irrel reconstructs the full logits: 1, seq_len, 50257=d_vocab
    full_logits = out_decomps[0].rel + out_decomps[0].irrel
    full_score = np.abs(full_logits[0, -2, io_token] - full_logits[0, -2, s_token])
    # GPT2 succeeds at this 99%+ of the time but not always. Because of the abs(), this
    # only fails when the two logits are exactly tied, not when the model prefers S.
    assert full_score > 0

    results = [
        Result(decomp.ablation_set, decomp.rel[0, -2, io_token] - decomp.rel[0, -2, s_token])
        for decomp in out_decomps
    ]
    results.sort(key=operator.attrgetter('score'), reverse=True)
    return results

In [2]:
# Rank every single-node ablation by its contribution to the logit difference.
# `out_decomps` must already exist in this kernel (it is produced by the batched run in
# the analysis cell).
results = compute_logits_decomposition_scores(out_decomps)

# Show the ten highest-scoring entries.
for result in itertools.islice(results, 10):
    print(result)

NameError: name 'out_decomps' is not defined

In [20]:
# Book-keeping for the manual, step-by-step search: `outliers_per_iter` receives the
# outliers kept at each step, and `results_per_iter` starts out holding the current ranking.
outliers_per_iter = []
results_per_iter = [results]

In [77]:
# Now, find maximally relevant source nodes to target nodes

# Hard-coded heuristic: the two highest-scoring results are treated as outliers.
outliers = results[:2]
outliers_per_iter.append(outliers)
# Assumes only one node was ablated at a time, so each ablation set holds exactly one node.
target_nodes = [outlier.ablation_set[0] for outlier in outliers]
print(target_nodes)

# Candidate source nodes: (layer, sequence position, attention head) triples.
ranges = [
    list(range(12)),
    list(range(16)),  # alternative: restrict to [ioi_dataset.word_idx['IO'][0]]
    list(range(12)),
]

source_nodes = [Node(*coords) for coords in itertools.product(*ranges)]
ablation_sets = [(source_node,) for source_node in source_nodes]

# Warm-up call; only the pre-layer activations are kept, to be reused by the batched run.
_, _, _, pre_layer_activations = prop_GPT(
    encoding_idxs[0:1, :],
    extended_attention_mask,
    model,
    [ablation_sets[0]],
    target_nodes=target_nodes,
    device=device,
    mean_acts=mean_acts,
    set_irrel_to_mean=True,
)


def prop_fn(ablation_list):
    """Run prop_GPT for a batch of ablation sets against the current target nodes."""
    return prop_GPT(
        encoding_idxs[0:1, :],
        extended_attention_mask,
        model,
        ablation_list,
        target_nodes=target_nodes,
        device=device,
        mean_acts=mean_acts,
        set_irrel_to_mean=True,
        cached_pre_layer_acts=pre_layer_activations,
    )


out_decomps, target_decomps = batch_run(prop_fn, ablation_sets)

[Node(layer_idx=1, sequence_idx=3, attn_head_idx=7), Node(layer_idx=1, sequence_idx=3, attn_head_idx=10)]
Running inputs 0 to 64 (of 2304)
Running inputs 64 to 128 (of 2304)
Running inputs 128 to 192 (of 2304)
Running inputs 192 to 256 (of 2304)
Running inputs 256 to 320 (of 2304)
Running inputs 320 to 384 (of 2304)
Running inputs 384 to 448 (of 2304)
Running inputs 448 to 512 (of 2304)
Running inputs 512 to 576 (of 2304)
Running inputs 576 to 640 (of 2304)
Running inputs 640 to 704 (of 2304)
Running inputs 704 to 768 (of 2304)
Running inputs 768 to 832 (of 2304)
Running inputs 832 to 896 (of 2304)
Running inputs 896 to 960 (of 2304)
Running inputs 960 to 1024 (of 2304)
Running inputs 1024 to 1088 (of 2304)
Running inputs 1088 to 1152 (of 2304)
Running inputs 1152 to 1216 (of 2304)
Running inputs 1216 to 1280 (of 2304)
Running inputs 1280 to 1344 (of 2304)
Running inputs 1344 to 1408 (of 2304)
Running inputs 1408 to 1472 (of 2304)
Running inputs 1472 to 1536 (of 2304)
Running inputs 15

In [14]:
def calculate_target_decomposition_scores(target_decomps, method="l1", mean_acts=None, attn_cache=None,
                                          num_layers=12, seq_len=None, num_attention_heads=12):
    """Score every source node by its relevance to the target nodes of a decomposition run.

    ``target_decomps`` must be ordered like ``itertools.product(layers, positions, heads)``
    (layer-major, head-minor), i.e. the order in which the notebook builds its ablation
    sets, with exactly one decomposition per (layer, position, head) triple.

    Args:
        target_decomps: sequence of target decompositions; each needs ``ablation_set``,
            ``target_nodes``, ``rels`` and (for ``'l1'``) ``irrels`` attributes.
        method: ``'l1'`` scores each target by mean|rel| / (mean|rel| + mean|irrel|);
            ``'dot'`` scores it by the dot product between the relevant part and the
            target's mean-centred "z" activation. ``'dot'`` only looks at batch row 0,
            so it is only implemented for a single datapoint.
        mean_acts: mean activations indexable as [layer, seq, head]; required for ``'dot'``.
        attn_cache: activation cache holding ``blocks.<layer>.attn.hook_z``; required
            for ``'dot'``.
        num_layers: number of layers covered by ``target_decomps``.
        seq_len: number of sequence positions covered; inferred from
            ``len(target_decomps)`` when left as ``None``.
        num_attention_heads: number of attention heads per layer.

    Returns:
        A list of ``Result(ablation_set, score)`` sorted by score, highest first. Scores
        are normalised by the sum of raw scores within the same layer. Source nodes that
        are themselves target nodes keep a raw score of 0.

    Raises:
        ValueError: for an unknown ``method``, for ``'dot'`` without ``mean_acts`` or
            ``attn_cache``, or when ``len(target_decomps)`` does not match the
            layer/position/head grid.
    """
    if method not in ("l1", "dot"):
        raise ValueError("Unknown target decomposition scoring method %r (expected 'l1' or 'dot')" % (method,))
    if method == "dot" and (mean_acts is None or attn_cache is None):
        raise ValueError("method='dot' requires both mean_acts and attn_cache")

    nodes_per_position = num_layers * num_attention_heads
    if seq_len is None:
        seq_len = len(target_decomps) // nodes_per_position
    expected_count = nodes_per_position * seq_len
    if len(target_decomps) != expected_count:
        raise ValueError(
            "Expected %d target decompositions for a %d x %d x %d (layer, position, head) grid, got %d"
            % (expected_count, num_layers, seq_len, num_attention_heads, len(target_decomps))
        )

    relevances = np.zeros((num_layers, seq_len, num_attention_heads))
    # The flat (C-order) index into `relevances` equals
    # layer * seq_len * heads + position * heads + head, matching the input ordering.
    for flat_idx, target_decomp in enumerate(target_decomps):
        # A node is not scored against a target set it belongs to; its raw score stays 0.
        if target_decomp.ablation_set[0] in target_decomp.target_nodes:
            continue
        score = 0
        for i, target_node in enumerate(target_decomp.target_nodes):
            if method == 'l1':
                rels_magnitude = torch.mean(abs(target_decomp.rels[i]))  # np.mean if you are on cpu
                irrels_magnitude = torch.mean(abs(target_decomp.irrels[i]))  # np.mean if you are on cpu
                score += rels_magnitude / (rels_magnitude + irrels_magnitude)
            else:  # method == 'dot'
                target_mean_act = mean_acts[target_node.layer_idx, target_node.sequence_idx, target_node.attn_head_idx]
                hook_name = 'blocks.' + str(target_node.layer_idx) + '.attn.hook_z'
                target_act = attn_cache[hook_name][0][target_node.sequence_idx][target_node.attn_head_idx]
                target_rel = target_act - target_mean_act
                rel = target_decomp.rels[i][0]
                score += torch.dot(rel, target_rel)
        # float() accepts both the plain 0 and a 0-d tensor (on any device).
        relevances.flat[flat_idx] = float(score)

    sums_per_layer = np.sum(relevances, axis=(1, 2))
    # Avoid dividing by zero for layers in which no node received any score.
    sums_per_layer[sums_per_layer == 0] = -1e-8
    normalized_relevances = relevances / np.expand_dims(sums_per_layer, (1, 2))

    results = [
        Result(target_decomp.ablation_set, normalized_score)
        for target_decomp, normalized_score in zip(target_decomps, normalized_relevances.ravel())
    ]
    results.sort(key=operator.attrgetter('score'), reverse=True)
    return results

In [83]:
# Inspect the outliers collected so far (one list per search step).
outliers_per_iter

[[Result(ablation_set=(Node(layer_idx=8, sequence_idx=14, attn_head_idx=6),), score=0.03813974573802492),
  Result(ablation_set=(Node(layer_idx=8, sequence_idx=11, attn_head_idx=6),), score=0.034177570085989574)],
 [Result(ablation_set=(Node(layer_idx=8, sequence_idx=14, attn_head_idx=6),), score=0.03813974573802492),
  Result(ablation_set=(Node(layer_idx=8, sequence_idx=11, attn_head_idx=6),), score=0.034177570085989574)],
 [Result(ablation_set=(Node(layer_idx=9, sequence_idx=14, attn_head_idx=9),), score=0.87486565),
  Result(ablation_set=(Node(layer_idx=9, sequence_idx=14, attn_head_idx=6),), score=0.37573943)],
 [Result(ablation_set=(Node(layer_idx=8, sequence_idx=14, attn_head_idx=6),), score=0.03813974573802492),
  Result(ablation_set=(Node(layer_idx=8, sequence_idx=11, attn_head_idx=6),), score=0.034177570085989574)],
 [Result(ablation_set=(Node(layer_idx=5, sequence_idx=10, attn_head_idx=5),), score=0.02986675545961833),
  Result(ablation_set=(Node(layer_idx=7, sequence_idx=11,

In [81]:
# Show the ten highest-scoring entries of the current ranking.
for result in itertools.islice(results, 10):
    print(result)

Result(ablation_set=(Node(layer_idx=0, sequence_idx=2, attn_head_idx=1),), score=0.08746183454945368)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=2, attn_head_idx=4),), score=0.07542276779756175)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=3, attn_head_idx=6),), score=0.07055409286121249)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=2, attn_head_idx=5),), score=0.06713339385727347)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=2, attn_head_idx=3),), score=0.05414199889043761)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=3, attn_head_idx=7),), score=0.049619719961781564)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=2, attn_head_idx=6),), score=0.04932286453511382)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=2, attn_head_idx=10),), score=0.046369913283579374)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=3, attn_head_idx=0),), score=0.042463772022659954)
Result(ablation_set=(Node(layer_idx=0, sequence_idx=3, attn_head_idx=4),), sco

In [86]:
# Collect the distinct nodes found across all search steps, in order of first appearance.
all_nodes = []
for result in itertools.chain.from_iterable(outliers_per_iter):
    found_node = result.ablation_set[0]
    if found_node in all_nodes:
        continue
    all_nodes.append(found_node)
        

## Automatic search

This is just a bunch of the above cells put into a neat cell that automatically finds some sort of circuit without any manual intervention.
As explained above, the code is not designed to take advantage of the IOI dataset's sequence position labels, so circuit analysis needs to be done on a per-template basis; here a template is hardcoded.

In [52]:
from pyfunctions.ioi_dataset import (
    ABC_TEMPLATES, BAC_TEMPLATES, BABA_TEMPLATES, BABA_LONG_TEMPLATES,
    BABA_LATE_IOS, BABA_EARLY_IOS, ABBA_TEMPLATES, ABBA_LATE_IOS, ABBA_EARLY_IOS,
)

# Start from a clean model: remove every hook, permanent ones included.
model.reset_hooks(including_permanent=True)

NUM_SAMPLES = 1
NUM_OUTLIERS_TO_KEEP_PER_ITER = 2
# The search is run for one hard-coded template.
template = ABBA_EARLY_IOS[0]
ioi_dataset = IOIDataset(prompt_type=[template], N=50, tokenizer=model.tokenizer, prepend_bos=False)

# P_ABC from the IOI paper, used for mean ablation: IO, S and S1 are each replaced by a
# random name. Constructing a dataset with prompt_type="ABC" (or similar) is NOT the same.
abc_dataset = ioi_dataset.gen_flipped_prompts(("IO", "RAND"))
abc_dataset = abc_dataset.gen_flipped_prompts(("S", "RAND"))
abc_dataset = abc_dataset.gen_flipped_prompts(("S1", "RAND"))

# Both runs cover the entire dataset along the batch dimension.
ioi_logits, ioi_cache = model.run_with_cache(ioi_dataset.toks)
logits, cache = model.run_with_cache(abc_dataset.toks)

# Mean per-head "z" activation over the ABC dataset, then the (n_heads, dim_attn) axes are
# merged into one; `old_shape` is kept so the per-head layout can be restored later.
attention_outputs = torch.stack(
    [cache[f'blocks.{layer_idx}.attn.hook_z'] for layer_idx in range(12)],
    dim=1,
)  # now batch, layer, seq, n_heads, dim_attn
mean_acts = attention_outputs.mean(dim=0)
old_shape = mean_acts.shape
last_dim = old_shape[-2] * old_shape[-1]
new_shape = old_shape[:-2] + (last_dim,)
mean_acts = mean_acts.view(new_shape)

# Encode the first sentence only to obtain an attention mask of the right shape.
text = ioi_dataset.sentences[0]
encoding = model.tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    max_length=512,
    truncation=True,
    padding="longest",
    return_attention_mask=True,
    return_tensors="pt",
).to(device)
input_shape = encoding.input_ids.size()
extended_attention_mask = get_extended_attention_mask(encoding.attention_mask, input_shape, model, device)

seq_len = ioi_dataset.toks.shape[1]
print('sequence length: %d ' % seq_len)
# Calculate relevance to logits
# Candidate source nodes: every (layer, sequence position, attention head) triple.
ranges = [
    list(range(12)),
    list(range(seq_len)),
    list(range(12)),
]

source_nodes = [Node(*coords) for coords in itertools.product(*ranges)]
ablation_sets = [(source_node,) for source_node in source_nodes]
# No target nodes yet: the first pass decomposes relevance with respect to the logits.
target_nodes = []

# Warm-up call on a single ablation set; the pre-layer activations it returns are cached
# and reused by the batched runs to make them faster.
out_decomp, _, _, pre_layer_activations = prop_GPT(
    ioi_dataset.toks[0:NUM_SAMPLES, :],
    extended_attention_mask,
    model,
    [ablation_sets[0]],
    target_nodes=target_nodes,
    device=device,
    mean_acts=mean_acts,
    set_irrel_to_mean=True,
)


def prop_fn(ablation_list):
    """Run prop_GPT for a batch of ablation sets, reusing the cached pre-layer activations."""
    return prop_GPT(
        ioi_dataset.toks[0:NUM_SAMPLES, :],
        extended_attention_mask,
        model,
        ablation_list,
        target_nodes=target_nodes,
        device=device,
        mean_acts=mean_acts,
        set_irrel_to_mean=True,
        cached_pre_layer_acts=pre_layer_activations,
    )


out_decomps, _ = batch_run(prop_fn, ablation_sets, num_at_time=(64 // NUM_SAMPLES))
results = compute_logits_decomposition_scores(out_decomps)


# This loop implements a simple heuristic: keep a hard-coded number of top outliers from
# each iteration and use them as the target nodes of the next one.
# It terminates only once all kept nodes are in the first layer, so it keeps looking for
# nodes even when they are not necessarily important.
# Other heuristics can be applied, such as filtering nodes by how their relevance scores
# compare to others in the same layer or iteration, or early stopping based on the
# circuit's performance.

# The search covers all sequence positions. This result is less stable than one
# augmented by some amount of manual analysis.
# The candidate nodes and the propagation function are identical for every iteration, so
# they are built once here instead of being rebuilt inside the loop.
ranges = [
    list(range(12)),
    list(range(seq_len)),
    list(range(12)),
]
source_nodes = [Node(*coords) for coords in itertools.product(*ranges)]
ablation_sets = [(source_node,) for source_node in source_nodes]


def prop_fn(ablation_list):
    """Run prop_GPT for a batch of ablation sets against the current `target_nodes`.

    `target_nodes` is looked up when the function is called, so each iteration of the
    loop below propagates towards the targets it has just selected.
    """
    return prop_GPT(
        ioi_dataset.toks[0:NUM_SAMPLES, :],
        extended_attention_mask,
        model,
        ablation_list,
        target_nodes=target_nodes,
        device=device,
        mean_acts=mean_acts,
        set_irrel_to_mean=True,
        cached_pre_layer_acts=pre_layer_activations,
    )


outliers_per_iter = []
while True:
    outliers = results[:NUM_OUTLIERS_TO_KEEP_PER_ITER]
    outliers_per_iter.append(outliers)
    target_nodes = [r.ablation_set[0] for r in outliers]
    print(target_nodes)
    # Stop once every kept node sits in layer 0.
    if all(target_node.layer_idx == 0 for target_node in target_nodes):
        break

    _, target_decomps = batch_run(prop_fn, ablation_sets, num_at_time=(64 // NUM_SAMPLES))

    # The 'dot' scorer indexes mean activations per head, hence the view back to old_shape.
    results = calculate_target_decomposition_scores(
        target_decomps,
        method="dot",
        mean_acts=mean_acts.view(old_shape),
        attn_cache=ioi_cache,
    )


sequence length: 16 
running input 0
running input 1600
[Node(layer_idx=9, sequence_idx=14, attn_head_idx=9), Node(layer_idx=10, sequence_idx=14, attn_head_idx=10)]
running input 0
running input 1600
[Node(layer_idx=9, sequence_idx=14, attn_head_idx=6), Node(layer_idx=9, sequence_idx=2, attn_head_idx=6)]
running input 0
running input 1600
[Node(layer_idx=0, sequence_idx=1, attn_head_idx=1), Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)]


In [53]:
# Collect the distinct nodes found across all search iterations, in order of first
# appearance, then list them.
all_nodes = []
for result in itertools.chain.from_iterable(outliers_per_iter):
    found_node = result.ablation_set[0]
    if found_node in all_nodes:
        continue
    all_nodes.append(found_node)
for node in all_nodes:
    print((node))

Node(layer_idx=9, sequence_idx=14, attn_head_idx=9)
Node(layer_idx=10, sequence_idx=14, attn_head_idx=10)
Node(layer_idx=9, sequence_idx=14, attn_head_idx=6)
Node(layer_idx=9, sequence_idx=2, attn_head_idx=6)
Node(layer_idx=0, sequence_idx=1, attn_head_idx=1)
Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)


# Circuit evaluation

Most of the actual evaluation code is implemented in the IOI repo; we just make calls to convenient functions.

In [None]:
'''
ANTI-CIRCUIT for BABA_TEMPLATES[0]: achieves a -1.90 logits difference, i.e, -0.58 faithfulness
Node(layer_idx=9, sequence_idx=14, attn_head_idx=9)
Node(layer_idx=10, sequence_idx=14, attn_head_idx=6)
Node(layer_idx=9, sequence_idx=11, attn_head_idx=6)
Node(layer_idx=9, sequence_idx=11, attn_head_idx=9)
Node(layer_idx=8, sequence_idx=3, attn_head_idx=10)
Node(layer_idx=8, sequence_idx=11, attn_head_idx=10)
Node(layer_idx=7, sequence_idx=3, attn_head_idx=9)
Node(layer_idx=7, sequence_idx=3, attn_head_idx=3)
Node(layer_idx=6, sequence_idx=3, attn_head_idx=4)
Node(layer_idx=6, sequence_idx=3, attn_head_idx=1)
Node(layer_idx=5, sequence_idx=3, attn_head_idx=10)
Node(layer_idx=4, sequence_idx=3, attn_head_idx=3)
Node(layer_idx=4, sequence_idx=3, attn_head_idx=11)
Node(layer_idx=4, sequence_idx=3, attn_head_idx=4)
Node(layer_idx=3, sequence_idx=3, attn_head_idx=6)
Node(layer_idx=3, sequence_idx=3, attn_head_idx=7)
Node(layer_idx=2, sequence_idx=3, attn_head_idx=2)
Node(layer_idx=2, sequence_idx=3, attn_head_idx=9)
Node(layer_idx=1, sequence_idx=2, attn_head_idx=6)
Node(layer_idx=1, sequence_idx=2, attn_head_idx=7)
Node(layer_idx=0, sequence_idx=2, attn_head_idx=4)
Node(layer_idx=0, sequence_idx=2, attn_head_idx=1)
'''

In [72]:
circuit = [Node(layer_idx=8, sequence_idx=14, attn_head_idx=6),
           Node(layer_idx=8, sequence_idx=11, attn_head_idx=6),
           Node(layer_idx=9, sequence_idx=14, attn_head_idx=9),
           Node(layer_idx=9, sequence_idx=14, attn_head_idx=6),
           Node(layer_idx=5, sequence_idx=10, attn_head_idx=5),
           Node(layer_idx=7, sequence_idx=11, attn_head_idx=9),
           Node(layer_idx=6, sequence_idx=10, attn_head_idx=9),
           Node(layer_idx=6, sequence_idx=11, attn_head_idx=0),
           Node(layer_idx=5, sequence_idx=10, attn_head_idx=9),
           Node(layer_idx=3, sequence_idx=10, attn_head_idx=0),
           Node(layer_idx=4, sequence_idx=5, attn_head_idx=11),
           Node(layer_idx=3, sequence_idx=5, attn_head_idx=7),
           Node(layer_idx=3, sequence_idx=3, attn_head_idx=6),
           Node(layer_idx=2, sequence_idx=3, attn_head_idx=2),
           Node(layer_idx=2, sequence_idx=3, attn_head_idx=9),
           Node(layer_idx=1, sequence_idx=3, attn_head_idx=7),
           Node(layer_idx=1, sequence_idx=3, attn_head_idx=10),
           Node(layer_idx=0, sequence_idx=2, attn_head_idx=1),
           Node(layer_idx=0, sequence_idx=2, attn_head_idx=4)]

In [58]:
# Random baseline circuit of 20 source nodes. A private, seeded generator is used so the
# same nodes are drawn on every run without disturbing the global `random` state.
random_circuit = random.Random(0).sample(source_nodes, 20)

In [54]:
# The template must match the one used in the search above; otherwise the sequence
# positions of the circuit's nodes cannot be validly interpreted.
test_ioi_dataset = IOIDataset(prompt_type=[template], N=10, tokenizer=model.tokenizer, prepend_bos=False)
test_abc_dataset = (
    test_ioi_dataset.gen_flipped_prompts(("IO", "RAND"))
    .gen_flipped_prompts(("S", "RAND"))
    .gen_flipped_prompts(("S1", "RAND"))
)

# Evaluate the nodes found by the automatic search (this replaces any hand-written
# `circuit`). A copy is taken so that pruning `circuit` in place later on cannot
# silently modify `all_nodes`.
circuit = list(all_nodes)

model.reset_hooks(including_permanent=True)
model = add_mean_ablation_hook(model, means_dataset=test_abc_dataset, circuit=circuit) #, circuit=random_circuit)
# model = add_mean_ablation_hook(model, means_dataset=test_abc_dataset)
logits, cache = model.run_with_cache(test_ioi_dataset.toks) # run on entire dataset along batch dimension
ave_logit_diff = logits_to_ave_logit_diff_2(logits, test_ioi_dataset)
print(ave_logit_diff)

tensor(-1.4686, device='cuda:0')


In [19]:
# note: for the following circuit:
'''
Node(layer_idx=9, sequence_idx=14, attn_head_idx=9)
Node(layer_idx=9, sequence_idx=14, attn_head_idx=6)
Node(layer_idx=10, sequence_idx=14, attn_head_idx=0)
Node(layer_idx=8, sequence_idx=1, attn_head_idx=11)
Node(layer_idx=8, sequence_idx=1, attn_head_idx=10)
Node(layer_idx=8, sequence_idx=1, attn_head_idx=2)
Node(layer_idx=7, sequence_idx=1, attn_head_idx=1)
Node(layer_idx=7, sequence_idx=1, attn_head_idx=4)
Node(layer_idx=6, sequence_idx=1, attn_head_idx=4)
Node(layer_idx=6, sequence_idx=1, attn_head_idx=0)
Node(layer_idx=5, sequence_idx=1, attn_head_idx=10)
Node(layer_idx=5, sequence_idx=1, attn_head_idx=2)
Node(layer_idx=5, sequence_idx=1, attn_head_idx=3)
Node(layer_idx=5, sequence_idx=1, attn_head_idx=6)
Node(layer_idx=5, sequence_idx=1, attn_head_idx=9)
Node(layer_idx=4, sequence_idx=1, attn_head_idx=3)
Node(layer_idx=4, sequence_idx=1, attn_head_idx=10)
Node(layer_idx=4, sequence_idx=1, attn_head_idx=9)
Node(layer_idx=1, sequence_idx=1, attn_head_idx=3)
Node(layer_idx=1, sequence_idx=1, attn_head_idx=10)
Node(layer_idx=1, sequence_idx=1, attn_head_idx=4)
Node(layer_idx=0, sequence_idx=1, attn_head_idx=3)
Node(layer_idx=0, sequence_idx=1, attn_head_idx=4)
Node(layer_idx=0, sequence_idx=1, attn_head_idx=5)
'''
# removing just one node, (8, 1, 11), raises the score from -2.1718 to -0.4423.
# this node is not identified by the IOI paper.

# Pruning heuristic

In [None]:
# Prune nodes by greedy search to form a better circuit

# Nodes that are never considered for removal.
# NOTE(review): these positions (sequence index 14) look template-specific - confirm when
# switching templates.
NAME_MOVER_HEADS = [Node(9, 14, 9), Node(10, 14, 0), Node(9, 14, 6)]


def evaluate_circuit_logit_diff(candidate_circuit):
    """Return the average logit difference on the test set for `candidate_circuit`.

    All hooks are reset, the mean-ablation hook for `candidate_circuit` is installed
    (using `test_abc_dataset` for the means), and the model is run on the whole test
    dataset. The model is left with that hook installed afterwards.
    """
    global model  # add_mean_ablation_hook returns the model; keep the notebook name bound to it
    model.reset_hooks(including_permanent=True)
    model = add_mean_ablation_hook(model, means_dataset=test_abc_dataset, circuit=candidate_circuit)
    candidate_logits, _ = model.run_with_cache(test_ioi_dataset.toks)  # run on entire dataset along batch dimension
    return logits_to_ave_logit_diff_2(candidate_logits, test_ioi_dataset).cpu().numpy().item()


old_circuit = circuit.copy()  # untouched snapshot of the circuit before pruning
circuit = list(circuit)  # prune a private list so aliases of the original are not modified

# Baseline: the score of the unpruned circuit, measured rather than hard-coded so it stays
# correct when the circuit or the test dataset changes.
best_score = evaluate_circuit_logit_diff(circuit)
print('starting score %f' % best_score)

while True:
    # Each pass tries removing every removable node and keeps the single best removal.
    node_to_remove = None
    for node in circuit:
        if node in NAME_MOVER_HEADS:
            continue
        new_circuit = circuit.copy()
        new_circuit.remove(node)
        ave_logit_diff = evaluate_circuit_logit_diff(new_circuit)
        if ave_logit_diff > best_score:
            best_score = ave_logit_diff
            node_to_remove = node
            print('tentatively improved score to %f ' % best_score, ' by removing node ', node_to_remove)
    if node_to_remove is None:
        # then we can't improve any further so the algorithm terminates
        break
    print("removing ", node_to_remove, " to achieve score of %f" % best_score)
    circuit.remove(node_to_remove)
# The model still carries the hook of the last candidate evaluated; reset hooks before reuse.
print('Done')

In [47]:
# Evaluate the current `circuit` on the test dataset, starting from a model without hooks.
# (This cell previously passed `nodes`, a name that is never defined in the notebook and
# therefore fails on a fresh kernel.)
model.reset_hooks(including_permanent=True)
model = add_mean_ablation_hook(model, means_dataset=test_abc_dataset, circuit=circuit)
logits, cache = model.run_with_cache(test_ioi_dataset.toks) # run on entire dataset along batch dimension
ave_logit_diff = logits_to_ave_logit_diff_2(logits, test_ioi_dataset)
print(ave_logit_diff)

tensor(3.5994, device='cuda:0')
