In [1]:
%cd /home/aoq559/dev/transformer/eap/edge-attribution-patching

from IPython import get_ipython
ipython = get_ipython()
if ipython is not None:
    ipython.magic("%load_ext autoreload")
    ipython.magic("%autoreload 2")

import torch

import torch as t
from torch import Tensor
import einops

from transformer_lens import HookedTransformer
from transformers import AutoTokenizer
import numpy as np

from eapp.eap_wrapper import EAP

from jaxtyping import Float

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

/home/aoq559/dev/transformer/eap/edge-attribution-patching


  ipython.magic("%load_ext autoreload")
  ipython.magic("%autoreload 2")


Device: cuda


# Model Setup

In [2]:
# model = HookedTransformer.from_pretrained(
#     'EleutherAI/pythia-12b-deduped-v0',
#     center_writing_weights=False,
#     center_unembed=False,
#     fold_ln=False,
#     device=device,
#     n_devices=7,
#     move_to_device=True,
#     dtype='float16'
# )
model = HookedTransformer.from_pretrained(
    'gpt-neo-125M',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
    n_devices=5,
    move_to_device=True,
    dtype='float16'
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125m')
print(f"Using tokenizer {tokenizer}")

Loaded pretrained model gpt-neo-125M into HookedTransformer
Using tokenizer GPT2TokenizerFast(name_or_path='EleutherAI/gpt-neo-125m', vocab_size=50257, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}


In [3]:
# load data
import yaml
import pickle
import os
class DotDict(dict):
    """ Dot notation access to dictionary attributes """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
yaml_file_path = "./conf/config.yaml"
with open(yaml_file_path, "r") as f:
    args = DotDict(yaml.safe_load(f))

file_name = args.data_dir
file_name += '/' + str(args.model)
file_name += '/intervention_' + str(args.n_shots) + '_shots_max_' + str(args.max_n) + '_' + args.representation
file_name += '_further_templates' if args.extended_templates else ''
file_name += '.pkl'
print(file_name)

with open(file_name, 'rb') as f:
    intervention_list = pickle.load(f)
print("Loaded data from", file_name)
if args.debug_run:
    intervention_list = intervention_list[:2]

/shared-network/shared/2024_ml_master/data/EleutherAI/gpt-neo-125m/intervention_1_shots_max_20_arabic_further_templates.pkl


Loaded data from /shared-network/shared/2024_ml_master/data/EleutherAI/gpt-neo-125m/intervention_1_shots_max_20_arabic_further_templates.pkl


# Dataset Setup

In [4]:
from ioi_dataset import IOIDataset, format_prompt, make_table
N = 25
clean_dataset = IOIDataset(
    prompt_type='mixed',
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=device
)
corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')

make_table(
  colnames = ["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
  cols = [
    map(format_prompt, clean_dataset.sentences),
    model.to_string(clean_dataset.s_tokenIDs).split(),
    model.to_string(clean_dataset.io_tokenIDs).split(),
    map(format_prompt, clean_dataset.sentences),
  ],
  title = "Sentences from IOI vs ABC distribution",
)

# Metric Setup

In [5]:
intervention = intervention_list[0]
print(intervention.res_base_tok)
intervention.__dict__

[1315]


{'op3_pos': 18,
 'operator_word': None,
 'operands_alt': '2 7 6',
 'operands_base': '2 7 6',
 'operator_pos': None,
 'op2_pos': 16,
 'op1_pos': 14,
 'res_alt_tok': [1315],
 'res_base_tok': [1315],
 'res_string': None,
 'res_base_string': '15',
 'res_alt_string': '15',
 'device': 'cpu',
 'multitoken': False,
 'is_llama': False,
 'is_opt': False,
 'is_bloom': False,
 'is_mistral': False,
 'is_persimmon': False,
 'representation': 'arabic',
 'extended_templates': True,
 'template_id': '-',
 'n_vars': 2,
 'base_string': 'The result of 2 + 7 + 6 =',
 'alt_string': 'The result of 2 + 7 + 6 =',
 'few_shots': 'The result of 7 + 5 + 3 = 15. ',
 'few_shots_t2': ' ',
 'equation': '({x}+{y} + {z})',
 'enc': GPT2Tokenizer(name_or_path='EleutherAI/gpt-neo-125m', vocab_size=50257, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=Tr

In [6]:
intervention = intervention_list[0]
with torch.no_grad():
    clean_logits = model(intervention.base_string_tok).cpu().numpy()
    print(clean_logits)
    corrupt_logits = model(intervention.alt_string_tok).cpu().numpy()
    clean_logits_argmax = np.argmax(clean_logits, axis=2)[0, -1]
    print(clean_logits_argmax)
    clean_logit = clean_logits[:, -1, intervention.res_base_tok[0]]
    corrupt_logit = corrupt_logits[:, -1, intervention.pred_res_alt_tok[0]]
    logit_diff = corrupt_logit - clean_logit
    print(clean_logit)
    print(corrupt_logit)

    
    # clean_logit_argmax = torch.argmax(clean_logits, dim=2)
    # corrupt_logit_argmax = torch.argmax(corrupt_logits, dim=2)
    # print(clean_logit_argmax)
    # next_word_index = clean_logit_argmax[0][-1]
    # next_word = model.tokenizer.convert_ids_to_tokens(next_word_index.item())
    # print(next_word)
    
    # clean_prediction = torch.argmax(clean_logits, dim=2)
    # print(clean_prediction)
    # next_word_index = clean_prediction[0][-1]
    # next_word = tokenizer.convert_ids_to_tokens(next_word_index.item())
    # print(next_word_index)
    # print(next_word)
    # print(torch.argmax(clean_logits[:, -1, :]))
    # print(corrupt_logits.shape)

[[[ -7.465  -6.3    -8.26  ... -13.19   -9.43   -6.61 ]
  [ -7.74  -10.9   -15.34  ... -20.4   -18.86  -14.484]
  [-12.34  -10.46  -13.24  ... -14.28  -12.445 -10.54 ]
  ...
  [-13.74  -12.58  -14.51  ... -22.11  -15.08  -10.7  ]
  [-11.086 -12.234 -14.95  ... -23.89  -17.08  -11.43 ]
  [-11.68  -11.29  -13.55  ... -18.84  -12.17   -7.54 ]]]
1315
[-0.77]
[0.03464]


In [7]:
import intervention_dataset
intervention_data = intervention_dataset.InterventionDataset(intervention_list, device, model.tokenizer)
intervention_data.create_intervention_dataset()

In [8]:
for batch_idx, batch in enumerate(intervention_data.base_string_toks):
    batch = torch.vstack(batch)
    print(batch_idx, batch)
    break

0 tensor([[  17, 1635,  362, 1635,  513,  796, 1105,   13,  362, 1635,  362, 1635,
          513,  796],
        [  17, 1635,  362, 1635,  362,  796,  807,   13,  362, 1635,  362, 1635,
          362,  796],
        [  17, 1635,  362, 1635,  604,  796, 1467,   13,  362, 1635,  604, 1635,
          362,  796],
        [  18, 1635,  513, 1635,  362,  796, 1248,   13,  362, 1635,  513, 1635,
          513,  796],
        [  17, 1635,  604, 1635,  362,  796, 1467,   13,  604, 1635,  362, 1635,
          362,  796],
        [  19, 1635,  362, 1635,  362,  796, 1467,   13,  362, 1635,  604, 1635,
          362,  796],
        [  18, 1635,  362, 1635,  513,  796, 1248,   13,  513, 1635,  513, 1635,
          362,  796],
        [  18, 1635,  513, 1635,  362,  796, 1248,   13,  513, 1635,  513, 1635,
          362,  796],
        [  19, 1635,  362, 1635,  362,  796, 1467,   13,  362, 1635,  362, 1635,
          604,  796]], device='cuda:0')


In [9]:
example_tensor = t.tensor([[  17, 1635,  362, 1635,  513,  796, 1105,   13,  362, 1635,  362, 1635,
          513,  796],
        [  17, 1635,  362, 1635,  362,  796,  807,   13,  362, 1635,  362, 1635,
          362,  796]])
example_tensor = example_tensor.to(device)

In [49]:
clean_logits_example = model(example_tensor)
print(clean_logits_example.shape)
print(ave_logit_difference(clean_logits_example, intervention_data, False))
clean_logit_diff_example = ave_logit_difference(clean_logits_example, intervention_data, per_prompt=False).item()
print(f"clean_logit_diff_example: {clean_logit_diff_example}")

torch.Size([2, 14, 50257])
tensor(-1.7578, device='cuda:5', dtype=torch.float16, grad_fn=<MeanBackward0>)
clean_logit_diff_example: -1.7578125


In [12]:
def ave_logit_difference(
    logits: Float[Tensor, 'batch seq d_vocab'],
    intervention_dataset,
    per_prompt: bool = False
):
    batch_size = logits.size(0)
    #print(f"batch_size: {batch_size}")
    #print(f"intervention_dataset.res_base_toks[:batch_size]: {intervention_dataset.res_base_toks[:batch_size]}")
    clean_logits = logits[range(batch_size), -1, intervention_dataset.res_base_toks[:batch_size]]
    corrupt_logits = logits[range(batch_size), -1, intervention_dataset.pred_res_alt_toks[:batch_size]]
    logit_diff = corrupt_logits - clean_logits
    return logit_diff if per_prompt else logit_diff.mean()

def process_batches_and_compute_logit_difference(model, intervention_dataset, tokens, per_prompt=True):
    logit_differences = []
    logits = []

    # Iterate over each group of batches
    for batch_idx, base_string_tok in enumerate(tokens):
        with t.no_grad():
            tok = t.vstack(base_string_tok)
            logit = model(tok)
            logit_diff = ave_logit_difference(logit, intervention_dataset, per_prompt)
            logit_differences.append(logit_diff)
            logits.append(logit)
    return logits, logit_differences

with t.no_grad():
    clean_logits, clean_logit_diff = process_batches_and_compute_logit_difference(model, intervention_data,
                                                                intervention_data.base_string_toks, per_prompt=True)
    corrupt_logits, corrupt_logit_diff = process_batches_and_compute_logit_difference(model, intervention_data,
                                                                intervention_data.alt_string_toks, per_prompt=True)
    print(len(clean_logits[1]))
    print(f"clean_logits 1 shape {clean_logits[1].shape}")
    print(f"clean_logits_1 diff {clean_logit_diff[1]}")

    #clean_logits = t.cat([t.flatten(diff) for diff in clean_logits])
    #corrupt_logits = t.cat([t.flatten(diff) for diff in corrupt_logits])
    #clean_logit_diff = t.mean(t.flatten(clean_logits)).item()
    #corrupt_logit_diff = t.mean(t.flatten(corrupt_logits)).item()
    print(clean_logit_diff)
    print(corrupt_logit_diff)

    def batched_metric(
        logits: Float[Tensor, "batch seq_len d_vocab"],
        corrupted_logit_diff = corrupt_logit_diff,
        clean_logit_diff = clean_logit_diff,
        intervention_dataset: intervention_data = intervention_data,
        per_prompt: bool = True
    ):
        metrics = []
        for batch_idx, logit in enumerate(logits):
            patched_logit_diff = ave_logit_difference(logit, intervention_dataset, per_prompt)
            metrics.append((patched_logit_diff - corrupted_logit_diff[batch_idx]) / (clean_logit_diff[batch_idx] - corrupted_logit_diff[batch_idx]))
        
        metrics = t.cat([t.flatten(metric) for metric in metrics])
        mean_metric = t.mean(metrics)#.detach().cpu().item()

        return mean_metric

    with t.no_grad():
        clean_metrics = batched_metric(clean_logits, corrupt_logit_diff, clean_logit_diff, intervention_data, per_prompt = True)
        corrupt_metrics = batched_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, intervention_data, per_prompt = True)
        print(f"clean metrics: {clean_metrics}")
        print(f"corrupt metrics: {corrupt_metrics}")
    
    
# with t.no_grad():
#     clean_logits = model(intervention_data.alt_string_toks)
#     corrupt_logits = model(intervention_data.base_string_toks)
#     clean_logit_diff = ave_logit_difference(clean_logits, intervention_data, per_prompt=False).item()
#     corrupt_logit_diff = ave_logit_difference(corrupt_logits, intervention_data, per_prompt=False).item()
# print(clean_logit_diff)
# print(corrupt_logit_diff)

def metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    intervention_dataset: intervention_data = intervention_data,
    per_prompt: bool = False
 ):
    patched_logit_diff = ave_logit_difference(logits, intervention_dataset, per_prompt)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)


# Get clean and corrupt logit differences
# with t.no_grad():
#     print(f"clean_logits metric shape {clean_logits.shape}")
#     clean_metric = metric(clean_logits, corrupt_logit_diff, clean_logit_diff, intervention_data, per_prompt = True)
#     corrupt_metric = metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, intervention_data)

# print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
# print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

22
clean_logits 1 shape torch.Size([22, 20, 50257])
clean_logits_1 diff tensor([-1.7666, -1.4717, -1.6963, -1.8682, -6.6523,  5.2578, -1.9863, -1.8691,
        -0.3564, -0.7568, -0.3042,  1.1621, -0.2988, -0.0527, -3.8613, -5.2930,
        -4.7852, -6.2031, -4.1953, -5.2344, -5.2930, -4.2227], device='cuda:5',
       dtype=torch.float16)
[tensor([-1.5996, -1.9082, -2.8613, -1.4355, -4.3906, -3.3340, -3.3262, -3.9297,
        -2.6582], device='cuda:5', dtype=torch.float16), tensor([-1.7666, -1.4717, -1.6963, -1.8682, -6.6523,  5.2578, -1.9863, -1.8691,
        -0.3564, -0.7568, -0.3042,  1.1621, -0.2988, -0.0527, -3.8613, -5.2930,
        -4.7852, -6.2031, -4.1953, -5.2344, -5.2930, -4.2227], device='cuda:5',
       dtype=torch.float16), tensor([-2.4043, -0.1768,  2.7695, -0.7246, -4.4258,  1.8838, -7.4297, -4.7969,
        -0.8564], device='cuda:5', dtype=torch.float16)]
[tensor([3.8750, 3.5566, 3.0293, 4.8672, 3.4883, 3.0293, 5.3125, 5.3125, 2.8633],
       device='cuda:5', dtype=torc

In [14]:
def ave_logit_diff(
    logits: Float[Tensor, 'batch seq d_vocab'],
    ioi_dataset: IOIDataset,
    per_prompt: bool = False
):
    '''
        Return average logit difference between correct and incorrect answers
    '''
    # Get logits for indirect objects
    batch_size = logits.size(0)
    # logits[batch_size, last_position, ID of IO]
    io_logits = logits[range(batch_size), ioi_dataset.word_idx['end'][:batch_size], ioi_dataset.io_tokenIDs[:batch_size]]
    s_logits = logits[range(batch_size), ioi_dataset.word_idx['end'][:batch_size], ioi_dataset.s_tokenIDs[:batch_size]]
    # Get logits for subject
    logit_diff = io_logits - s_logits
    return logit_diff if per_prompt else logit_diff.mean()

with t.no_grad():
    clean_logits = model(clean_dataset.toks)
    corrupt_logits = model(corr_dataset.toks)
    clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset).item() # logit difference for clean run between correct and incorrect answer
    corrupt_logit_diff = ave_logit_diff(corrupt_logits, corr_dataset).item() # logit difference for corrupt run between correct and incorrect answer

def ioi_metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    ioi_dataset: IOIDataset = clean_dataset
 ):
    patched_logit_diff = ave_logit_diff(logits, ioi_dataset)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

def negative_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -ioi_metric(logits)
    
# Get clean and corrupt logit differences
with t.no_grad():
    clean_metric = ioi_metric(clean_logits, corrupt_logit_diff, clean_logit_diff, clean_dataset)
    corrupt_metric = ioi_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, corr_dataset)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

Clean direction: 2.982421875, Corrupt direction: 2.83984375
Clean metric: 1.0, Corrupt metric: 0.0


# Run Experiment

In [15]:
model.reset_hooks()

graph = EAP(
    model,
    clean_dataset.toks,
    corr_dataset.toks,
    ioi_metric,
    upstream_nodes=["mlp", "head"],
    downstream_nodes=["mlp", "head"],
    batch_size=25
)

top_edges = graph.top_edges(n=10, abs_scores=True)
for from_edge, to_edge, score in top_edges:
    print(f'{from_edge} -> [{round(score, 3)}] -> {to_edge}')

Saving activations requires 0.0002 GB of memory per token


  0%|          | 0/1 [00:00<?, ?it/s]

metric <function ioi_metric at 0x7f645f1ec4c0>
value: 1.0
model config: HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': ['global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local'],
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float16,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt-neo-125M',
 'n_ctx': 2048,
 'n_devices': 5,
 'n_heads': 12,
 'n_key_

100%|██████████| 1/1 [00:04<00:00,  4.58s/it]


mlp.0 -> [0.143] -> head.9.4.v
mlp.0 -> [0.141] -> mlp.2
head.0.2 -> [-0.098] -> mlp.0
mlp.0 -> [-0.098] -> mlp.5
head.0.2 -> [0.086] -> head.9.4.v
head.1.11 -> [0.086] -> head.9.4.k
mlp.0 -> [-0.081] -> head.9.4.k
head.0.3 -> [-0.072] -> mlp.0
head.0.2 -> [-0.067] -> mlp.2
head.0.5 -> [0.061] -> head.9.4.v


In [None]:
graph.show()

In [36]:
for i in intervention_data.base_string_toks:
    base_tokens = t.vstack(i)
for i in base_tokens:
    i.to(device)
base_tokens

tensor([[ 464, 1255,  286,  357, 1315,  532, 1478, 1267, 1635,  860,  796,  860,
           13,  383, 1255,  286,  357, 1315,  532, 1478, 1267, 1635, 1105,  796],
        [ 464, 1255,  286,  357,  807,  532,  767, 1267, 1635,  807,  796,  807,
           13,  383, 1255,  286,  357,  678,  532, 1248, 1267, 1635,  362,  796],
        [ 464, 1255,  286,  357, 1248,  532, 1596, 1267, 1635,  718,  796,  718,
           13,  383, 1255,  286,  357,  718,  532,  642, 1267, 1635,  513,  796],
        [ 464, 1255,  286,  357,  513,  532,  362, 1267, 1635, 1467,  796, 1467,
           13,  383, 1255,  286,  357, 1160,  532,  678, 1267, 1635, 1467,  796],
        [ 464, 1255,  286,  357, 1160,  532,  678, 1267, 1635,  718,  796,  718,
           13,  383, 1255,  286,  357,  718,  532,  642, 1267, 1635, 1478,  796],
        [ 464, 1255,  286,  357,  860,  532,  807, 1267, 1635,  860,  796,  860,
           13,  383, 1255,  286,  357, 1511,  532, 1105, 1267, 1635,  362,  796],
        [ 464, 1255,  

In [27]:
model.reset_hooks()

graph = EAP(
    model,
    intervention_data.alt_string_toks[0],
    intervention_data.base_string_toks[0],
    batched_metric,
    upstream_nodes=["mlp", "head"],
    downstream_nodes=["mlp", "head"],
    batch_size=9
)

top_edges = graph.top_edges(n=10, abs_scores=True)
for from_edge, to_edge, score in top_edges:
    print(f'{from_edge} -> [{round(score, 3)}] -> {to_edge}')

Saving activations requires 0.0002 GB of memory per token


AttributeError: 'list' object has no attribute 'shape'

In [None]:
top_edges

In [None]:
graph.show()