In [1]:
%cd /home/aoq559/dev/transformer/eap/edge-attribution-patching

from IPython import get_ipython

# Enable autoreload so edits to local modules are picked up without restarting the
# kernel. Guarded so this cell is a no-op when not running under IPython.
ipython = get_ipython()
if ipython is not None:
    # `InteractiveShell.magic("%...")` is deprecated (it emits the warning shown in
    # this cell's output); `run_line_magic(name, line)` is the supported replacement.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")
import torch as t
from torch import Tensor
import einops
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer
import numpy as np
from eapp.eap_wrapper import EAP
from jaxtyping import Float
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

/home/aoq559/dev/transformer/eap/edge-attribution-patching


  ipython.magic("%load_ext autoreload")
  ipython.magic("%autoreload 2")


In [2]:
MODEL_NAME = 'EleutherAI/pythia-12b-deduped-v0'

# Load the checkpoint without TransformerLens weight processing (no centring of
# writing/unembedding weights, no LayerNorm folding), sharded over 7 devices in fp16.
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
    n_devices=7,
    move_to_device=True,
    dtype='float16',
)
# Smaller alternative used for quick iteration: 'gpt-neo-125M' with the same
# keyword arguments but n_devices=5.

# Expose the extra hook points (MLP inputs, split q/k/v inputs, per-head attention
# results); presumably required by the EAP wrapper's "mlp"/"head" nodes.
for enable_hook in (
    model.set_use_hook_mlp_in,
    model.set_use_split_qkv_input,
    model.set_use_attn_result,
):
    enable_hook(True)
model.tokenizer.padding_side = "left"

# NOTE: this standalone tokenizer keeps its default padding side; the dataset below
# is built with `model.tokenizer` (left padding), not with this object.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"Using tokenizer {tokenizer}")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-12b-deduped-v0 into HookedTransformer
Using tokenizer GPT2TokenizerFast(name_or_path='EleutherAI/gpt-neo-125m', vocab_size=50257, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}


In [3]:
# load data
import yaml
import pickle
import os
class DotDict(dict):
    """A dict whose keys can also be read, assigned and deleted as attributes.

    Reading an absent key as an attribute returns None (like ``dict.get``)
    instead of raising AttributeError; deleting an absent key raises KeyError.
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails (e.g. not a dict method).
        return dict.get(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)
yaml_file_path = "./conf/config.yaml"
with open(yaml_file_path, "r") as config_file:
    args = DotDict(yaml.safe_load(config_file))

# The data file name encodes the experiment configuration:
#   <data_dir>/<model>/intervention_<n_shots>_shots_max_<max_n>_<representation>[_further_templates][_acdc].pkl
file_name = ''.join([
    args.data_dir,
    '/', str(args.model),
    '/intervention_', str(args.n_shots),
    '_shots_max_', str(args.max_n),
    '_', args.representation,
    '_further_templates' if args.extended_templates else '',
    '_acdc' if args.acdc_data else '',
    '.pkl',
])

# NOTE: pickle.load can execute arbitrary code; only load files from a trusted location.
with open(file_name, 'rb') as data_file:
    intervention_list = pickle.load(data_file)
print("Loaded data from", file_name)

# Debug runs keep only the first two interventions.
if args.debug_run:
    intervention_list = intervention_list[:2]

Loaded data from /shared-network/shared/2024_ml_master/data/EleutherAI/pythia-12b-deduped-v0/intervention_1_shots_max_20_arabic_further_templates_acdc.pkl


In [4]:
import intervention_dataset
import random

# Seed every RNG that dataset construction / shuffle() might draw from, so the prompt
# order is reproducible across kernel restarts (the RNG actually used by
# InterventionDataset.shuffle is not visible here).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
t.manual_seed(SEED)

intervention_data = intervention_dataset.InterventionDataset(intervention_list, device, model.tokenizer)
intervention_data.create_intervention_dataset()
intervention_data.shuffle()

In [5]:
def ave_logit_difference(
    logits: Float[Tensor, 'batch seq d_vocab'],
    intervention_dataset,
    per_prompt: bool = False
):
    """Final-position logit of the base answer minus that of the alternative answer.

    For each row i of ``logits``, the last-position logits of token
    ``intervention_dataset.res_base_toks[i]`` and
    ``intervention_dataset.pred_res_alt_toks[i]`` are gathered and subtracted
    (base - alt).

    Returns:
        The per-row differences if ``per_prompt`` is True, otherwise their mean.

    NOTE(review): answer tokens are always taken from the *first* ``logits.size(0)``
    entries of the dataset. If this is called on a later mini-batch (e.g. by EAP with
    batch_size=3), rows would be paired with the wrong answers. Confirm how the EAP
    wrapper batches before trusting per-batch values.
    """
    n_rows = logits.size(0)
    rows = range(n_rows)
    alt_answer_logit = logits[rows, -1, intervention_dataset.pred_res_alt_toks[:n_rows]]
    base_answer_logit = logits[rows, -1, intervention_dataset.res_base_toks[:n_rows]]
    per_row_diff = base_answer_logit - alt_answer_logit
    if per_prompt:
        return per_row_diff
    return per_row_diff.mean()

def logits_in_batches(model, tokens, attn_mask, bsize):
    """Run ``model`` over ``tokens`` in mini-batches without gradients; return all logits on CPU.

    Args:
        model: module called as ``model(input=..., attention_mask=...)`` that exposes
            ``cfg.device`` (e.g. a HookedTransformer); it is put into eval mode.
        tokens: token ids, one prompt per row along dim 0.
        attn_mask: attention mask aligned row-for-row with ``tokens``.
        bsize: number of prompts per forward pass (must be >= 1).

    Returns:
        The per-batch logits concatenated along dim 0, on CPU.

    Raises:
        ValueError: if ``bsize`` < 1 or ``tokens`` and ``attn_mask`` have a different
            number of rows.
    """
    if bsize < 1:
        raise ValueError(f"bsize must be >= 1, got {bsize}")
    n_prompts = tokens.size(0)
    if attn_mask.size(0) != n_prompts:
        raise ValueError(
            f"tokens has {n_prompts} rows but attn_mask has {attn_mask.size(0)}"
        )

    model.eval()
    batch_logits = []
    with t.no_grad():
        for start in range(0, n_prompts, bsize):
            # Slice into fresh names: rebinding `attn_mask` here would make every later
            # batch slice the previous batch's mask instead of the full one.
            batch_tokens = tokens[start:start + bsize].to(model.cfg.device)
            batch_mask = attn_mask[start:start + bsize].to(model.cfg.device)
            logits = model(input=batch_tokens, attention_mask=batch_mask)
            # Move each batch off the device right away to keep device memory bounded.
            batch_logits.append(logits.detach().cpu())
    return t.cat(batch_logits, dim=0)

# Baseline forward passes over the whole dataset: "clean" runs use alt_string_toks,
# "corrupt" runs use base_string_toks.
# NOTE(review): each token set is paired with the *other* set's attention mask
# (alt toks + base mask, base toks + alt mask), whereas the commented-out cell further
# down pairs them by name. Confirm this crossing is intentional (e.g. compensating for
# swapped attribute names in InterventionDataset).
LOGIT_BATCH_SIZE = 9
clean_logits = logits_in_batches(
    model, intervention_data.alt_string_toks, intervention_data.base_attention_mask, LOGIT_BATCH_SIZE
)
corrupt_logits = logits_in_batches(
    model, intervention_data.base_string_toks, intervention_data.alt_attention_mask, LOGIT_BATCH_SIZE
)
clean_logit_diff, corrupt_logit_diff = (
    ave_logit_difference(run_logits, intervention_data, per_prompt=False).item()
    for run_logits in (clean_logits, corrupt_logits)
)
print(clean_logit_diff)
print(corrupt_logit_diff)

def metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    intervention_dataset: "intervention_dataset.InterventionDataset" = intervention_data,
    per_prompt: bool = False
):
    """Normalised logit-difference metric used as the EAP objective.

    Computes ``(patched - corrupted) / (clean - corrupted)``, where ``patched`` is
    ``ave_logit_difference(logits, ...)``. It is therefore 0 when the patched logit
    difference equals the corrupted baseline and 1 when it equals the clean baseline.

    The defaults are bound when this cell runs, so re-run it after recomputing the
    baseline logit differences or rebuilding ``intervention_data``.

    Raises:
        ValueError: if the clean and corrupted baselines are equal (the normalisation
            would divide by zero and silently produce inf/nan).
    """
    denominator = clean_logit_diff - corrupted_logit_diff
    if denominator == 0:
        raise ValueError(
            "clean_logit_diff equals corrupted_logit_diff; the normalised metric is undefined"
        )
    patched_logit_diff = ave_logit_difference(logits, intervention_dataset, per_prompt)
    return (patched_logit_diff - corrupted_logit_diff) / denominator

# Sanity check: by construction the metric is 1 on the clean logits and 0 on the corrupt ones.
baseline_kwargs = dict(
    corrupted_logit_diff=corrupt_logit_diff,
    clean_logit_diff=clean_logit_diff,
    intervention_dataset=intervention_data,
    per_prompt=False,
)
with t.no_grad():
    clean_metric = metric(clean_logits, **baseline_kwargs)
    corrupt_metric = metric(corrupt_logits, **baseline_kwargs)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

-0.78173828125
2.037109375
Clean direction: -0.78173828125, Corrupt direction: 2.037109375
Clean metric: 1.0, Corrupt metric: -0.0


In [6]:
# def ave_logit_difference(
#     logits: Float[Tensor, 'batch seq d_vocab'],
#     intervention_dataset,
#     per_prompt: bool = False
# ):
#     batch_size = logits.size(0)
#     clean_logits = logits[range(batch_size), -1, intervention_dataset.res_base_toks[:batch_size]]
#     corrupt_logits = logits[range(batch_size), -1, intervention_dataset.pred_res_alt_toks[:batch_size]]
#     logit_diff = corrupt_logits - clean_logits
#     return logit_diff if per_prompt else logit_diff.mean()
    
    
# with t.no_grad():
#     clean_logits = model(intervention_data.alt_string_toks, 
#                          attention_mask=intervention_data.alt_attention_mask)
#     corrupt_logits = model(intervention_data.base_string_toks,
#                            attention_mask=intervention_data.base_attention_mask)
#     clean_logit_diff = ave_logit_difference(clean_logits, intervention_data, per_prompt=False).item()
#     corrupt_logit_diff = ave_logit_difference(corrupt_logits, intervention_data, per_prompt=False).item()
# print(clean_logit_diff)
# print(corrupt_logit_diff)

# def metric(
#     logits: Float[Tensor, "batch seq_len d_vocab"],
#     corrupted_logit_diff: float = corrupt_logit_diff,
#     clean_logit_diff: float = clean_logit_diff,
#     intervention_dataset: intervention_data = intervention_data,
#     per_prompt: bool = False
#  ):
#     patched_logit_diff = ave_logit_difference(logits, intervention_dataset, per_prompt)
#     return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)


# #Get clean and corrupt logit differences
# with t.no_grad():
#     print(f"clean_logits metric shape {clean_logits.shape}")
#     clean_metric = metric(clean_logits, corrupt_logit_diff, clean_logit_diff, intervention_data, per_prompt = False)
#     corrupt_metric = metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, intervention_data, per_prompt = False)

# print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
# print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

In [7]:
# Edge attribution patching between MLP and attention-head nodes.
model.reset_hooks()

eap_options = dict(
    upstream_nodes=["mlp", "head"],
    downstream_nodes=["mlp", "head"],
    # NOTE(review): `metric` -> `ave_logit_difference` indexes answer tokens with
    # `[:logits.size(0)]`; if EAP calls the metric once per mini-batch of 3, every batch
    # after the first is scored against the first batch's answers. Confirm in eapp.
    batch_size=3,
    # NOTE(review): masks are crossed relative to their names, as in the baseline cell;
    # confirm against the EAP wrapper's signature.
    alt_attention_mask=intervention_data.base_attention_mask,
    base_attention_mask=intervention_data.alt_attention_mask,
)
graph = EAP(
    model,
    intervention_data.base_string_toks,
    intervention_data.alt_string_toks,
    metric,
    **eap_options,
)

top_edges = graph.top_edges(n=20, abs_scores=True)
for src_node, dst_node, score in top_edges:
    print(f'{src_node} -> [{round(score, 3)}] -> {dst_node}')

graph.show()

Saving activations requires 0.0141 GB of memory per token


  0%|          | 0/36 [00:00<?, ?it/s]

  0%|          | 0/36 [00:30<?, ?it/s]

hookname: blocks.35.hook_mlp_in
result shape: torch.Size([1475, 1])
result_over_positions shape: torch.Size([24, 1475, 1])
earlier_upstream_nodes_slice shape: slice(0, 1475, None)
hook_slice shape: slice(4355, 4356, None)
++++++++++++++++++++++++++++++++++++++
torch.Size([1476, 4356])





IndexError: too many indices for tensor of dimension 2