In [1]:
from IPython import get_ipython
ipython = get_ipython()
if ipython is not None:
    # Reload edited project modules automatically before cells run.
    # run_line_magic replaces InteractiveShell.magic(), which is deprecated and
    # emits a warning on recent IPython versions.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

import os
import sys
sys.path.append('../Automatic-Circuit-Discovery/')
sys.path.append('..')
import torch
import re

import acdc
from utils.prune_utils import get_3_caches, split_layers_and_heads
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

# Prefer the GPU whenever torch can see one.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f'Device: {device}')

  ipython.magic("%load_ext autoreload")
  ipython.magic("%autoreload 2")


Device: cuda


# Model Setup

In [2]:
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    device=device,
    # Turn off the weight-processing options (writing-weight / unembed centering,
    # LayerNorm folding).
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
)
# Enable the optional hook points; later cells read e.g. hook_{q,k,v}_input and
# per-head attention results.
for enable_hooks in (
    model.set_use_hook_mlp_in,
    model.set_use_split_qkv_input,
    model.set_use_attn_result,
):
    enable_hooks(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Dataset Setup

In [3]:
from ioi_dataset import IOIDataset, format_prompt, make_table
N = 25  # number of IOI prompts
clean_dataset = IOIDataset(
    prompt_type='mixed',
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=device
)
# Corrupted ("ABC") prompts derived from the clean ones via the flip spec string;
# see IOIDataset.gen_flipped_prompts for the exact name-replacement semantics.
corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')

# Side-by-side view: clean prompt, its subject / indirect-object answer tokens,
# and the matching corrupted prompt (which must come from corr_dataset, not clean_dataset).
make_table(
  colnames = ["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
  cols = [
    map(format_prompt, clean_dataset.sentences),
    model.to_string(clean_dataset.s_tokenIDs).split(),
    model.to_string(clean_dataset.io_tokenIDs).split(),
    map(format_prompt, corr_dataset.sentences),
  ],
  title = "Sentences from IOI vs ABC distribution",
)

# Metric Setup

In [4]:
def ave_logit_diff(
    logits: Float[Tensor, 'batch seq d_vocab'],
    ioi_dataset: IOIDataset,
    per_prompt: bool = False
):
    '''
    Logit difference between the indirect object (correct answer) and the
    subject (incorrect answer), read at each prompt's word_idx['end'] position.

    Args:
        logits: model logits with one row per prompt of ioi_dataset.
        ioi_dataset: supplies word_idx['end'], io_tokenIDs and s_tokenIDs per prompt.
        per_prompt: if True, return the per-prompt differences instead of their mean.

    Returns:
        The per-prompt differences (presumably shape (batch,)) if per_prompt,
        otherwise their mean as a scalar tensor.
    '''
    batch_idx = range(logits.size(0))
    end_pos = ioi_dataset.word_idx['end']
    # Logit of the indirect object (correct answer) at each prompt's 'end' position
    io_logits = logits[batch_idx, end_pos, ioi_dataset.io_tokenIDs]
    # Logit of the subject (incorrect answer) at the same position
    s_logits = logits[batch_idx, end_pos, ioi_dataset.s_tokenIDs]
    logit_diff = io_logits - s_logits
    return logit_diff if per_prompt else logit_diff.mean()

# Logit differences of the unpatched clean and corrupted runs; ioi_metric below
# captures these two values as its default normalization baselines.
# NOTE(review): the corrupted baseline is scored with corr_dataset's own IO/S tokens,
# while ioi_metric scores patched logits with clean_dataset's tokens — confirm this
# is the intended normalization.
with t.no_grad():
    clean_logits, corrupt_logits = (
        model(toks) for toks in (clean_dataset.toks, corr_dataset.toks)
    )
    clean_logit_diff, corrupt_logit_diff = (
        ave_logit_diff(run_logits, run_dataset).item()
        for run_logits, run_dataset in (
            (clean_logits, clean_dataset),
            (corrupt_logits, corr_dataset),
        )
    )

def ioi_metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    ioi_dataset: IOIDataset = clean_dataset
):
    '''
    Normalized logit difference: 0 at the corrupted baseline, 1 at the clean baseline.

    The defaults are the module-level baselines captured when this function is defined.
    '''
    recovered = ave_logit_diff(logits, ioi_dataset) - corrupted_logit_diff
    full_range = clean_logit_diff - corrupted_logit_diff
    return recovered / full_range

def negative_abs_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    '''Negated magnitude of ioi_metric (using its default baselines); never positive.'''
    score = ioi_metric(logits)
    return -abs(score)

# Sanity check of the normalization: by construction the clean run should score
# ~1.0 and the corrupted run ~0.0.
with t.no_grad():
    clean_metric = ioi_metric(
        clean_logits,
        corrupted_logit_diff=corrupt_logit_diff,
        clean_logit_diff=clean_logit_diff,
        ioi_dataset=clean_dataset,
    )
    corrupt_metric = ioi_metric(
        corrupt_logits,
        corrupted_logit_diff=corrupt_logit_diff,
        clean_logit_diff=clean_logit_diff,
        ioi_dataset=corr_dataset,
    )

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

Clean direction: 2.805178165435791, Corrupt direction: 1.7939791679382324
Clean metric: 1.0, Corrupt metric: 0.0


# Run Experiment

In [7]:
from ACDCPPExperiment import ACDCPPExperiment
import numpy as np
# A single threshold keeps the run fast; for a sweep use something like
# np.logspace(-5, 1, num=20, base=5).
THRESHOLDS = [0.077153]
RUN_NAME = 'acdcpp_edges'

model.reset_hooks()

acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=clean_dataset.toks,
    corr_data=corr_dataset.toks,
    # ACDC is given the negated |metric|; ACDC++ attribution gets the raw metric.
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    pruning_mode='edge',
    run_acdcpp=True,
    run_acdc=False,  # ACDC pass disabled for this run
    attr_absolute_val=True,
    save_graphs_after=0,
    no_pruned_nodes_attr=1,
    verbose=False,
)

(
    pruned_heads,
    num_passes,
    acdcpp_pruned_attrs,
    acdc_pruned_attrs,
    edges_after_acdcpp,
    edges_after_acdc,
) = acdcpp_exp.run()



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])



Edge pruning:   0%|          | 0/1034 [00:00<?, ?it/s][A
Edge pruning:  13%|█▎        | 138/1034 [00:00<00:00, 1370.05it/s][A
Edge pruning:  27%|██▋       | 276/1034 [00:00<00:00, 1129.88it/s][A
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 2502.47it/s][A

Edge pruning:   0%|          | 0/1034 [00:00<?, ?it/s][A
Edge pruning:   5%|▌         | 54/1034 [00:00<00:02, 469.05it/s][A
Edge pruning:  10%|▉         | 101/1034 [00:00<00:07, 122.10it/s][A
Edge pruning:  14%|█▍        | 143/1034 [00:00<00:05, 171.52it/s][A
Edge pruning:  17%|█▋        | 175/1034 [00:01<00:07, 111.84it/s][A
Edge pruning:  22%|██▏       | 229/1034 [00:01<00:04, 168.80it/s][A
Edge pruning:  25%|██▌       | 262/1034 [00:01<00:06, 123.92it/s][A
Edge pruning:  31%|███       | 316/1034 [00:02<00:04, 175.82it/s][A
Edge pruning:  34%|███▍      | 349/1034 [00:02<00:05, 136.47it/s][A
Edge pruning:  39%|███▉      | 403/1034 [00:02<00:03, 187.69it/s][A
Edge pruning:  42%|████▏     | 437/1034 [00:02<00:03

In [None]:
import json

# acdcpp_pruned_attrs maps threshold -> {(parent_node, child_node): attribution}.
# json.dump raises TypeError on tuple keys, so each edge is flattened to a string
# key first, using the same child-then-parent format as the main save cell.
acdcpp_scores = {
    thresh: {
        f'{child.hook_point_name}{child.replace_parens(child.index)}'
        f'{parent.hook_point_name}{parent.replace_parens(parent.index)}': attr
        for (parent, child), attr in edge_attrs.items()
    }
    for thresh, edge_attrs in acdcpp_pruned_attrs.items()
}
with open(f'{RUN_NAME}_acdcpp_scores.json', 'w') as f:
    json.dump(acdcpp_scores, f)

In [None]:
import json

# Persist the ACDC attributions and pass counts.
# NOTE(review): the save cell further down rewrites both of these files, so this
# cell looks redundant — confirm before relying on its output.
for out_path, payload in (
    (f'{RUN_NAME}_acdc_pruned_attrs.json', acdc_pruned_attrs),
    (f'{RUN_NAME}_num_passes.json', num_passes),
):
    with open(out_path, 'w') as f:
        json.dump(payload, f)

In [7]:
# Distribution (min, quartiles, 90th percentile, max) of the ACDC++ attribution scores.
# Index by a threshold that was actually run: the previously hard-coded key 0.5 is not
# in THRESHOLDS and raises KeyError (assumes the dict is keyed by the THRESHOLDS values).
np.quantile(list(acdcpp_pruned_attrs[THRESHOLDS[0]].values()), [0, 0.25, 0.5, 0.75, 0.9, 1])

array([0.00000000e+00, 8.23954397e-05, 2.85615824e-04, 9.29769565e-04,
       3.00903956e-03, 1.31640387e+00])

In [None]:
import json


def edge_key(parent, child) -> str:
    '''Flatten a (parent, child) edge into its JSON key: child hook + index, then parent hook + index.'''
    return (
        f'{child.hook_point_name}{child.replace_parens(child.index)}'
        f'{parent.hook_point_name}{parent.replace_parens(parent.index)}'
    )


def merge_with_saved(new: dict, saved: dict) -> dict:
    '''
    Merge freshly computed per-threshold results into results loaded from JSON.

    Returns a new dict. (dict.update() returns None, so `x = x.update(old)` would
    have replaced every result with None.) JSON stores threshold keys as strings,
    so int/float keys of `new` are stringified first (as json.dump does) so a
    repeated threshold collides instead of yielding duplicate JSON keys; on a
    collision the newly computed value wins.
    '''
    return {**saved, **{str(thresh): value for thresh, value in new.items()}}


def save_json(suffix: str, obj) -> None:
    '''Write obj as JSON to f"{RUN_NAME}_{suffix}.json".'''
    with open(f'{RUN_NAME}_{suffix}.json', 'w') as f:
        json.dump(obj, f)


def load_json(suffix: str):
    '''Read f"{RUN_NAME}_{suffix}.json" written by a previous run.'''
    with open(f'{RUN_NAME}_{suffix}.json', 'r') as f:
        return json.load(f)


# json.dump cannot encode sets, so convert both pruned-head collections to lists
# (presumably sets — confirm against ACDCPPExperiment.run).
for thresh in pruned_heads.keys():
    pruned_heads[thresh][0] = list(pruned_heads[thresh][0])
    pruned_heads[thresh][1] = list(pruned_heads[thresh][1])

# JSON object keys must be strings, so flatten each (parent, child) edge key.
cleaned_pp_attrs = {
    thresh: {edge_key(parent, child): attr for (parent, child), attr in edge_attrs.items()}
    for thresh, edge_attrs in acdcpp_pruned_attrs.items()
}

# Convert the edge collections to lists. Each dict is walked over its own keys, so a
# threshold present in only one of them cannot raise KeyError.
edges_after_acdcpp = {thresh: list(edges) for thresh, edges in edges_after_acdcpp.items()}
edges_after_acdc = {thresh: list(edges) for thresh, edges in edges_after_acdc.items()}

APPEND = False  # True: merge with the JSON files written by a previous run

if APPEND:
    # Load everything first so a missing file fails before any result is touched.
    saved = {
        suffix: load_json(suffix)
        for suffix in (
            'pruned_heads',
            'acdcpp_pruned_attrs',
            'acdc_pruned_attrs',
            'num_passes',
            'edges_after_acdcpp',
            'edges_after_acdc',
        )
    }
    pruned_heads = merge_with_saved(pruned_heads, saved['pruned_heads'])
    cleaned_pp_attrs = merge_with_saved(cleaned_pp_attrs, saved['acdcpp_pruned_attrs'])
    acdc_pruned_attrs = merge_with_saved(acdc_pruned_attrs, saved['acdc_pruned_attrs'])
    num_passes = merge_with_saved(num_passes, saved['num_passes'])
    edges_after_acdcpp = merge_with_saved(edges_after_acdcpp, saved['edges_after_acdcpp'])
    edges_after_acdc = merge_with_saved(edges_after_acdc, saved['edges_after_acdc'])

save_json('pruned_heads', pruned_heads)
save_json('num_passes', num_passes)
save_json('acdcpp_pruned_attrs', cleaned_pp_attrs)
save_json('acdc_pruned_attrs', acdc_pruned_attrs)
save_json('edges_after_acdcpp', edges_after_acdcpp)
save_json('edges_after_acdc', edges_after_acdc)

In [None]:
def negative_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    '''Negated ioi_metric (default clean/corrupt baselines), passed as `metric` to get_3_caches.'''
    # This name was referenced below but never defined in the notebook (NameError).
    return -ioi_metric(logits)


# get the 2 fwd and 1 bwd caches; cache "normalized" and "result" of attn layers
clean_cache, corrupted_cache, clean_grad_cache = get_3_caches(
    model,
    clean_dataset.toks,
    corr_dataset.toks,
    metric=negative_ioi_metric,
    mode = "edge",
)

In [None]:
# Attention-head outputs for the clean and corrupted runs, later indexed as [layer, head].
clean_head_act, corr_head_act = (
    split_layers_and_heads(run_cache.stack_head_results(), model=model)
    for run_cache in (clean_cache, corrupted_cache)
)

In [None]:
# Metric gradients w.r.t. every head's q/k/v input, laid out as
# [qkv, layer, head, batch, seq, d] so they index like clean_head_act[layer, head].
stacked_grad_act = torch.zeros(3, model.cfg.n_layers, model.cfg.n_heads, *clean_head_act.shape[-3:])

for (letter_idx, letter), layer_idx in itertools.product(enumerate("qkv"), range(model.cfg.n_layers)):
    # The cached hook_{q,k,v}_input gradient is (batch, seq, n_heads, d); move heads first.
    stacked_grad_act[letter_idx, layer_idx] = einops.rearrange(
        clean_grad_cache[f"blocks.{layer_idx}.hook_{letter}_input"],
        "batch seq n_heads d -> n_heads batch seq d",
    )

In [None]:
# Attribution-patching score for each head -> head edge:
#   sum over (batch, seq, d) of
#   grad(metric w.r.t. downstream head's q/k/v input) * (clean - corrupted upstream head output)
results = {}

for upstream_layer_idx in range(model.cfg.n_layers):
    for upstream_head_idx in range(model.cfg.n_heads):
        # The activation difference depends only on the upstream head, so compute it
        # (and copy it to CPU) once rather than once per downstream q/k/v input.
        upstream_act_diff = (
            clean_head_act[upstream_layer_idx, upstream_head_idx]
            - corr_head_act[upstream_layer_idx, upstream_head_idx]
        ).cpu()
        for downstream_letter_idx, downstream_letter in enumerate("qkv"):
            # Only heads in strictly later layers are paired with this upstream head.
            for downstream_layer_idx in range(upstream_layer_idx + 1, model.cfg.n_layers):
                for downstream_head_idx in range(model.cfg.n_heads):
                    downstream_grad = stacked_grad_act[
                        downstream_letter_idx, downstream_layer_idx, downstream_head_idx
                    ].cpu()
                    results[
                        (
                            upstream_layer_idx,
                            upstream_head_idx,
                            downstream_letter,
                            downstream_layer_idx,
                            downstream_head_idx,
                        )
                    ] = (downstream_grad * upstream_act_diff).sum()

In [None]:
# Rank edges by attribution magnitude, largest first (ties keep insertion order).
sorted_results = list(results.items())
sorted_results.sort(key=lambda edge_and_score: edge_and_score[1].abs(), reverse=True)

In [None]:
# Print the highest-|score| edges as "layer:head -> layer:head (q/k/v input) score".
# Slicing avoids an IndexError when fewer than TOP_K edges exist, and the q/k/v
# letter distinguishes edges into different inputs of the same downstream head.
TOP_K = 10
print(f"Top {TOP_K} most important edges:")
for (up_layer, up_head, down_letter, down_layer, down_head), score in sorted_results[:TOP_K]:
    print(f"{up_layer}:{up_head} -> {down_layer}:{down_head} ({down_letter}) score={score.item():.4f}")