In [1]:
import sys
sys.path.append('../Automatic-Circuit-Discovery/')
sys.path.append('..')

from acdc.greaterthan.utils import get_all_greaterthan_things
from ACDCPPExperiment import ACDCPPExperiment
from transformer_lens import HookedTransformer

import numpy as np
import torch as t
import tqdm.notebook as tqdm
import json

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


# Model Setup

In [2]:
# Load GPT-2 small into TransformerLens on `device`, with LayerNorm folding and
# writing-weight / unembed centering all disabled.
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    device=device,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

# Enable the optional hook points (MLP inputs, split q/k/v inputs, per-head
# attention results such as `blocks.N.attn.hook_result`). These are presumably
# required by the edge-level pruning run later in this notebook.
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Dataset Setup

In [3]:
# Build the clean dataset (x_i) and the reference / corrupted dataset (x_i')
# for the Greater-Than task, plus the metric used to score the model on it.
N = 25       # number of examples requested from get_all_greaterthan_things
N_SHOW = 5   # how many decoded samples of each dataset to print

things = get_all_greaterthan_things(
    num_examples=N, metric_name="greaterthan", device=device
)
greaterthan_metric = things.validation_metric
toks_int_values = things.validation_data  # clean data x_i
toks_int_values_other = things.validation_patch_data  # corrupted data x_i'


def print_samples(title, toks, n=N_SHOW):
    """Print a blank line and `title`, then decode and print the first `n` rows of `toks`.

    Slicing instead of indexing means that when `toks` has fewer than `n` rows
    (e.g. after lowering `N`), fewer samples are printed rather than raising IndexError.
    """
    print(f"\n{title}")
    for row in toks[:n]:
        print(model.tokenizer.decode(row))


print_samples("Clean dataset samples", toks_int_values)
print_samples("Reference dataset samples", toks_int_values_other)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda


Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda

Clean dataset samples
The demonstrations lasted from the year 1267 to 12
The assaults lasted from the year 1644 to 16
The affair lasted from the year 1268 to 12
The stature lasted from the year 1653 to 16
The effort lasted from the year 1318 to 13

Reference dataset samples
The demonstrations lasted from the year 1201 to 12
The assaults lasted from the year 1601 to 16
The affair lasted from the year 1201 to 12
The stature lasted from the year 1601 to 16
The effort lasted from the year 1301 to 13


# Run Experiment

In [11]:
from ACDCPPExperiment import ACDCPPExperiment
import numpy as np

# A single pruning threshold keeps this run quick. For a full sweep, use instead:
#     THRESHOLDS = np.logspace(-4, 1, num=20, base=5)
THRESHOLDS = [0.008828]

# Remove any hooks left on the model by earlier cells before handing it to the experiment.
model.reset_hooks()
RUN_NAME = 'acdcpp_edges'
acdcpp_exp = ACDCPPExperiment(
    model=model,
    # data: clean x_i and corrupted x_i'
    clean_data=toks_int_values,
    corr_data=toks_int_values_other,
    # the same Greater-Than metric is used for both algorithms
    acdc_metric=greaterthan_metric,
    acdcpp_metric=greaterthan_metric,
    # run bookkeeping
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    save_graphs_after=0,
    # algorithm selection: only ACDC++ runs here (run_acdc=False), so the acdc_*
    # results below are presumably empty/unused — confirm against ACDCPPExperiment
    run_acdcpp=True,
    run_acdc=False,
    pruning_mode='edge',
    attr_absolute_val=True,
    no_pruned_nodes_attr=1,
)

(
    pruned_heads,
    num_passes,
    acdcpp_pruned_attrs,
    acdc_pruned_attrs,
    edges_after_acdcpp,
    edges_after_acdc,
) = acdcpp_exp.run()



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])



Edge pruning:   0%|          | 0/1034 [00:00<?, ?it/s][A
Edge pruning:  13%|█▎        | 139/1034 [00:00<00:00, 1374.79it/s][A
Edge pruning:  27%|██▋       | 277/1034 [00:00<00:00, 1141.52it/s][A
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 2501.39it/s][A

Edge pruning:   0%|          | 0/1034 [00:00<?, ?it/s][A
Edge pruning:   5%|▌         | 54/1034 [00:00<00:02, 473.19it/s][A
Edge pruning:  10%|▉         | 102/1034 [00:00<00:07, 122.09it/s][A
Edge pruning:  14%|█▍        | 143/1034 [00:00<00:05, 169.23it/s][A
Edge pruning:  17%|█▋        | 174/1034 [00:01<00:07, 111.45it/s][A
Edge pruning:  22%|██▏       | 228/1034 [00:01<00:04, 168.59it/s][A
Edge pruning:  25%|██▌       | 261/1034 [00:01<00:06, 121.48it/s][A
Edge pruning:  31%|███       | 316/1034 [00:02<00:04, 172.77it/s][A
Edge pruning:  34%|███▍      | 349/1034 [00:02<00:05, 135.18it/s][A
Edge pruning:  39%|███▉      | 403/1034 [00:02<00:03, 185.62it/s][A
Edge pruning:  42%|████▏     | 436/1034 [00:02<00:03

In [5]:
# Display the ACDC++ attribution scores returned by `acdcpp_exp.run()`.
# NOTE(review): the saved output below has execution count 5, older than the
# experiment cell (In [11]), and is keyed by -10 rather than the value in
# THRESHOLDS. It is probably stale; restart and run all cells to refresh it.
acdcpp_pruned_attrs

{-10: {'blocks.11.hook_resid_post[:]blocks.1.attn.hook_result[:, :, 8]': 9.101534669753164e-05,
  'blocks.11.hook_resid_post[:]blocks.2.attn.hook_result[:, :, 2]': -5.7828328863251954e-05,
  'blocks.11.hook_resid_post[:]blocks.5.attn.hook_result[:, :, 2]': 2.9632432415382937e-05,
  'blocks.11.hook_resid_post[:]blocks.10.attn.hook_result[:, :, 5]': -2.0922090698149987e-05,
  'blocks.11.hook_resid_post[:]blocks.11.attn.hook_result[:, :, 2]': 0.0011569790076464415,
  'blocks.11.hook_resid_post[:]blocks.4.attn.hook_result[:, :, 6]': -0.00021867756731808186,
  'blocks.11.hook_resid_post[:]blocks.10.attn.hook_result[:, :, 7]': -0.028109099715948105,
  'blocks.11.hook_resid_post[:]blocks.5.attn.hook_result[:, :, 8]': -0.00023876517661847174,
  'blocks.11.hook_resid_post[:]blocks.8.attn.hook_result[:, :, 7]': 5.55131264263764e-05,
  'blocks.11.hook_resid_post[:]blocks.6.attn.hook_result[:, :, 7]': -3.374364314367995e-05,
  'blocks.11.hook_resid_post[:]blocks.8.attn.hook_result[:, :, 0]': -0.00

In [7]:
import json

# Write the ACDC++ attribution scores, exactly as returned by acdcpp_exp.run(), to JSON.
scores_path = f'{RUN_NAME}_acdcpp_scores.json'
with open(scores_path, 'w') as scores_file:
    json.dump(acdcpp_pruned_attrs, scores_file)

# Save Data

In [None]:
def convert_to_torch_index(index_list):
    """Rewrite an index string such as '[:, :, 8]' as '[None, None, 8]'.

    Iterates over `index_list` (a string, or any iterable of strings), replaces
    every element equal to ':' with 'None', and concatenates the pieces.
    `None` presumably stands for a full slice, matching ACDC's TorchIndex
    convention — confirm against the consumer of the saved JSON.
    """
    pieces = []
    for piece in index_list:
        pieces.append('None' if piece == ':' else piece)
    return ''.join(pieces)

import re

# Edge keys can arrive already flattened to one string, e.g.
# 'blocks.11.hook_resid_post[:]blocks.1.attn.hook_result[:, :, 8]'
# (as in the displayed `acdcpp_pruned_attrs` output), i.e. <name><[index]><name><[index]>.
_EDGE_STRING_RE = re.compile(
    r'(?P<e1>[^\[\]]+)(?P<i1>\[[^\]]*\])(?P<e2>[^\[\]]+)(?P<i2>\[[^\]]*\])'
)


def _split_edge(edge):
    """Return (e1, i1, e2, i2) as strings for an edge key.

    Accepts either the tuple form ((e1, i1), (e2, i2)) or the flattened string form.
    Raises ValueError if a string key does not match the expected layout.
    """
    if isinstance(edge, str):
        match = _EDGE_STRING_RE.fullmatch(edge)
        if match is None:
            raise ValueError(f'Unrecognised edge key format: {edge!r}')
        return match['e1'], match['i1'], match['e2'], match['i2']
    (e1, i1), (e2, i2) = edge
    return e1, str(i1), e2, str(i2)


# JSON has no set type, so the first two entries of each threshold's pruned-heads
# record (presumably sets) are converted to lists. Build a copy instead of
# mutating `pruned_heads`, so the in-memory results stay untouched and re-runs are harmless.
serializable_pruned_heads = {}
for thresh, record in pruned_heads.items():
    record = list(record)
    record[0] = list(record[0])
    record[1] = list(record[1])
    serializable_pruned_heads[thresh] = record

# One row per pruned edge: [name1, index1, name2, index2, attribution score].
# Scores come from the ACDC++ run; the previous code read an undefined `pruned_attrs`.
# float() also coerces numpy/torch scalars, which json cannot encode directly.
cleaned_attrs = {}
for thresh, edge_attrs in acdcpp_pruned_attrs.items():
    rows = []
    for edge, attr in edge_attrs.items():
        e1, i1, e2, i2 = _split_edge(edge)
        rows.append([e1, convert_to_torch_index(i1), e2, convert_to_torch_index(i2), float(attr)])
    cleaned_attrs[thresh] = rows

for suffix, payload in (
    ('pruned_heads', serializable_pruned_heads),
    ('num_passes', num_passes),
    ('pruned_attrs', cleaned_attrs),
):
    with open(f'{RUN_NAME}_{suffix}.json', 'w') as f:
        json.dump(payload, f)