In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import numpy as np
import os
import pandas as pd
import scipy as sp
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import collections

# CD-T Imports
import math
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import itertools

from torch import nn

warnings.filterwarnings("ignore")

base_dir = os.path.split(os.getcwd())[0]
sys.path.append(base_dir)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.patch_hh import *
from pyfunctions.ioi_dataset import IOIDataset
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import GPT2Tokenizer, GPT2Model

In [3]:
torch.autograd.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x77cb334f2cf0>

## Load Model

Note: Unlike with the BERT model + medical dataset objective, it is not necessary to pretrain GPT-2 to perform the IOI dataset.
GPT-2-small is already capable of performing IOI; that's part of the point of the Mech Interp in the Wild paper.
We only need to examine how it does it.

In [4]:
device = 'cuda:0'

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

In [6]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [8]:
import inspect
# inspect.getclasstree(inspect.getmro(type(model)))

[(object, ()),
 [(torch.nn.modules.module.Module, (object,)),
  [(transformers.modeling_utils.PreTrainedModel,
    (torch.nn.modules.module.Module,
     transformers.modeling_utils.ModuleUtilsMixin,
     transformers.generation.utils.GenerationMixin,
     transformers.utils.hub.PushToHubMixin,
     transformers.integrations.peft.PeftAdapterMixin)),
   [(transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel,
     (transformers.modeling_utils.PreTrainedModel,)),
    [(transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel,
      (transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel,))]]],
  (transformers.generation.utils.GenerationMixin, (object,)),
  [(transformers.modeling_utils.PreTrainedModel,
    (torch.nn.modules.module.Module,
     transformers.modeling_utils.ModuleUtilsMixin,
     transformers.generation.utils.GenerationMixin,
     transformers.utils.hub.PushToHubMixin,
     transformers.integrations.peft.PeftAdapterMixin)),
   [(transformers.models.gpt2.modeling_gpt2.

In [35]:
text = "Replace me by any text you'd"
input = tokenizer(text, return_tensors='pt').input_ids
# print(encoded_input) # has 'input_idx' and 'attention_mask'
# output = model(input)
# print(output.last_hidden_state.shape)
gen_tokens = model.generate(input, pad_token_id=tokenizer.pad_token_id, output_scores=True)
print(gen_tokens)
gen_text = tokenizer.batch_decode(gen_tokens)
gen_text

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,  198,
           40, 1101,  407, 1654,  611,  345,  821, 3910]])




["Replace me by any text you'd like.\n\nI'm not sure if you're aware"]

In [21]:
# other exploratory stuff
# model.to_tokens(text) #turns out this is a utility of trasnformer_lens
# print(output.past_key_values[0][0].shape) # this has to do with key matrix stuff
#print(output.values())
#output.logits

AttributeError: 'BaseModelOutputWithPastAndCrossAttentions' object has no attribute 'logits'

In [21]:
!pip install datasets

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (3.3 kB)
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting tqdm>=4.66.3 (from datasets)
  Downloading tqdm-4.66.4-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.6/57.6 kB[0m [31m320.7 kB/s[0m eta [36m0:00:00[0m:--:--[0m
[?25hCollecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collect

In [22]:
from datasets import load_dataset
ioi_dataset = load_dataset("fahamu/ioi")
# i've decided against using this for the most part; it's better to use the raw IOIDataset 
# from the paper and from the related notebook using EasyTransformers, since these both provide many utilities for dealing with the data

Downloading readme: 100%|██████████| 3.46k/3.46k [00:00<00:00, 10.5MB/s]
Downloading data: 100%|██████████| 5.84M/5.84M [00:01<00:00, 4.59MB/s]
Downloading data: 100%|██████████| 756M/756M [01:27<00:00, 8.65MB/s] 
Generating train split: 100%|██████████| 26210000/26210000 [00:03<00:00, 7827513.55 examples/s]


In [37]:
!pip install transformer_lens
!pip install einops

python(5228) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




python(5229) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [5]:
# Model code adapted from Callum McDougall's notebook for ARENA on reproducing the IOI paper using TransformerLens.
# This makes some sense, since EasyTransformer, the repo/lib released by the IOI guys, was forked from TransformerLens.
# In fact, this makes the reproduction a little bit more faithful, since they most likely do certain things such as 
# "folding" LayerNorms to improve their interpretability results, and we are able to do the same by using TransformerLens.
# HuggingFace, by contrast, has the most impenetrable docs and tons of outdated APIs and etc.; even their source 
# code is impossible to traverse, and I gave up on it, thankfully quickly.

from transformer_lens import utils, HookedTransformer, ActivationCache
model = HookedTransformer.from_pretrained("gpt2-small",
                                          center_unembed=True,
                                          center_writing_weights=True,
                                          fold_ln=False,
                                          refactor_factored_attn_matrices=True)
                                          

Loaded pretrained model gpt2-small into HookedTransformer


### Verify forward pass
Because the print of the model is a little messy because of hooks, and for whatever reason the model doesn't expose all of its
intermediate activations, we do this to ensure that the model's architecture is what we think it is.
(There are some complications with the layernorm folding that transformer_lens, and therefore the IOI paper, use.)

In [None]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
encoding = get_encoding(text, model.tokenizer, "cpu")
embedding_output = model.embed(encoding.input_ids)
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
# actually, not finishing this until it's clear that it's needed

In [8]:
next(model.parameters()).dtype

torch.float32

In [13]:
print(model)
# print(model.config) # doesn't work on hookedtransformer, is a huggingface thing
# print(model.embed.dtype) same, but can use dtype trick
# print(type(model))
#model.state_dict().keys()#.blocks[0].mlp

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_re

In [23]:
!pip install torchsummary
!pip install torchinfo

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [7]:
import pdb
from torchinfo import summary

text = "After John and Mary went to the store, John gave a bottle of milk to"
encoding = get_encoding(text, model.tokenizer, "cpu")
# embedding_output = model.embed(encoding.input_ids)
input_shape = encoding.input_ids.shape
print(input_shape)
pdb.set_trace()
summary(model, input_shape, device='cpu')


torch.Size([1, 512])
--Return--
None
> [0;32m/var/folders/c9/mvmx0pt17m51nyh8wgjnqj900000gn/T/ipykernel_11966/1150960534.py[0m(9)[0;36m<module>[0;34m()[0m
[0;32m      6 [0;31m[0;31m# embedding_output = model.embed(encoding.input_ids)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m[0minput_shape[0m [0;34m=[0m [0mencoding[0m[0;34m.[0m[0minput_ids[0m[0;34m.[0m[0mshape[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m[0mprint[0m[0;34m([0m[0minput_shape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 9 [0;31m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     10 [0;31m[0msummary[0m[0;34m([0m[0mmodel[0m[0;34m,[0m [0minput_shape[0m[0;34m,[0m [0mdevice[0m[0;34m=[0m[0;34m'cpu'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
Moving model to device:  cpu
> [0;32m/Users/georgiazhou/miniconda3/lib/python3.12/site-packages/transformer_lens/HookedTransformer.py[0m(332)[0;36

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

In [33]:
# Same as in the notebook, example
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = "Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


## Generate dataset/Explore types

In [16]:
data = IOIDataset(N=500, prompt_type="ABBA", tokenizer=model.tokenizer)
#data.tokenized_prompts
data.ioi_prompts[0]
[x['TEMPLATE_IDX'] for x in data.ioi_prompts[0:10]]

[12, 4, 9, 6, 1, 13, 2, 8, 8, 4]

In [16]:
# test
pos_specific_hs = [
        [i for i in range(12)],
        [0],
        [i for i in range(12)]
    ]
all_heads = list(itertools.product(*pos_specific_hs))
target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)] # not meaningful in a GPT context
source_list = [[node] for node in all_heads if node not in target_nodes]

text = "After John and Mary went to the store, John gave a bottle of milk to"
encoding = get_encoding(text, model.tokenizer, device)
# encoding.input_ids.shape # 512-long vector, not sure why the tokens change from EOS to 0 at some point
# embedding = model.embed(encoding.input_ids)

out_decomps, target_decomps = prop_model_hh_batched(encoding, model, source_list, target_nodes,
                                                                   device=device,
                                                                   patched_values=None, mean_ablated=False)
                                                                   # patched_values=mean_act, mean_ablated=True)
                                                                

AttributeError: 'Attention' object has no attribute 'self'

## Generate mean activations

## Head to head direct influence

In [10]:
def patch_hh_at_pos(encoding, model, target_nodes, pos=0, mean_act=None, mean_ablated=False):
    pos_specific_hs = [
        [i for i in range(12)],
        [pos],
        [i for i in range(12)]
    ]
    all_heads = list(itertools.product(*pos_specific_hs))

    # patch one node at a time
    h_ctbn_list = []
    
    source_list = [[node] for node in all_heads if node not in target_nodes]
    out_decomps, target_decomps = prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes,
                                                                   device=device,
                                                                   patched_values=mean_act, mean_ablated=True)
    for i, _ in enumerate(source_list):
        ctbn = 0
        for l in range(12):
            if target_decomps[l][i][0].shape[0] != 0:
                rel_part = np.mean(abs(target_decomps[l][i][0]))
                irrel_part = np.mean(abs(target_decomps[l][i][1]))
                ctbn += rel_part / abs(rel_part + irrel_part) * 100
        h_ctbn_list.append(ctbn)
        
    return source_list, h_ctbn_list

In [13]:
# perform on one doc as an example
text = documents_full[0]
label = labels_full[0]
encoding = get_encoding(text, tokenizer, device)

In [106]:
# perform one iteration measuring effect of the source nodes to target nodes as an example
# note that target nodes get updated in each iteration

target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]

all_source_hs = []
all_htbn = []
for pos in tqdm.tqdm(range(512)):
    with torch.no_grad():
        source_list, h_ctbn_list = patch_hh_at_pos(encoding, model, target_nodes, pos=pos, mean_act=back, mean_ablated=True)
    torch.cuda.empty_cache()
    all_source_hs.append(source_list)
    all_htbn.append(h_ctbn_list)

100%|██████████| 512/512 [1:52:20<00:00, 13.17s/it]


In [107]:
flat_ctbn = [c for sublist in all_htbn for c in sublist]
flat_source_h = [c for sublist in all_source_hs for c in sublist]


In [108]:
top_idx = sorted(range(len(flat_ctbn)), key=lambda i: flat_ctbn[i])[-6:]

In [109]:
for i in top_idx:
    print(flat_source_h[i], flat_ctbn[i])

[(6, 82, 4)] 36.659424751996994
[(5, 82, 4)] 42.90880411863327
[(3, 82, 0)] 51.443591713905334
[(4, 82, 0)] 87.46089041233063
[(5, 82, 0)] 103.08798849582672
[(6, 82, 0)] 132.23715126514435


In [112]:
# save the identified heads
path = f"{base_dir}/output/{args['task']}/{args['model_type']}_{args['field']}/h3"
os.makedirs(path, exist_ok=True)

with open(os.path.join(path, f"flat_source_h.pkl"), 'wb') as handle:
    pickle.dump(flat_source_h, handle)
    
with open(os.path.join(path, f"flat_source_h.pkl"), 'rb') as handle:
    back = pickle.load(handle)

## Examine the attended words by the identified heads

In [24]:
def collect_attended_tokens_hh(positives_heads, device, tokenizer, N=100, Z_thres=2, percentile=75, use_perc=False):
    index_lst = random.sample(range(0, len(documents_full)), N)
    docs = [documents_full[i] for i in index_lst]
    
    collect = collections.defaultdict(int)
    for doc in docs:
        encoding = get_encoding(doc, tokenizer, device)
        
        _, _, raw_att_probs_lst = prop_classifier_model_hh(encoding, model, [[]], [], device=device, output_att_prob=True)
        raw_att_probs = torch.stack(raw_att_probs_lst).cpu().numpy()

        avg_att_m = np.zeros((512))
        for level, pos, h in positives_heads:
            att_m = raw_att_probs[level, h, pos, :]
            avg_att_m += att_m

        avg_att_m /= len(positives)
        
        # convert to word level
        interval_dict, word_lst = compute_word_intervals(encoding, tokenizer)
        word_att_m = combine_token_attn(interval_dict, avg_att_m)
        
        if use_perc:
            perc_cutoff = np.percentile(word_att_m, percentile)
            positive_words = np.where(word_att_m > perc_cutoff)
        else:
            Z = (word_att_m - np.mean(word_att_m)) / np.std(word_att_m)
            positive_words = np.where(Z > Z_thres)
        
        for w_idx in positive_words[0]:
            w = word_lst[w_idx]
            #collect[w] += 1
            collect[w] += word_att_m[w_idx]
            
    return collect

In [65]:
def collect_attended_tokens_hh_rm_pos(positives_heads, device, tokenizer, N=100, Z_thres=2, percentile=75, use_perc=False):
    index_lst = random.sample(range(0, len(documents_full)), N)
    docs = [documents_full[i] for i in index_lst]
    
    collect = collections.defaultdict(int)
    for doc in docs:
        encoding = get_encoding(doc, tokenizer, device)
        
        _, _, raw_att_probs_lst = prop_classifier_model_hh(encoding, model, [[]], [], device=device, output_att_prob=True)
        raw_att_probs = torch.stack(raw_att_probs_lst).cpu().numpy()

        avg_att_m = np.zeros((512))
        for level, _, h in positives_heads:
            att_m = raw_att_probs[level, h, :, :]
            #att_m = np.mean(att_m, axis=0)
            max_row = np.unravel_index(np.argmax(att_m, axis=None), att_m.shape)[0]
            avg_att_m += att_m[max_row, :]

        avg_att_m /= len(positives)
        
        # convert to word level
        interval_dict, word_lst = compute_word_intervals(encoding, tokenizer)
        word_att_m = combine_token_attn(interval_dict, avg_att_m)
        
        if use_perc:
            perc_cutoff = np.percentile(word_att_m, percentile)
            positive_words = np.where(word_att_m > perc_cutoff)
        else:
            Z = (word_att_m - np.mean(word_att_m)) / np.std(word_att_m)
            positive_words = np.where(Z > Z_thres)
        
        for w_idx in positive_words[0]:
            w = word_lst[w_idx]
            #collect[w] += 1
            collect[w] += word_att_m[w_idx]
            
    return collect

In [62]:
negatives = [(11, 2), (11, 5), (11, 6), (11, 9), (11, 10), (11, 11),
             (10, 3), (10, 4), (10, 5), (10, 6), (10, 9), (10, 10), (10, 11),
             (9, 1), (9, 2), (9, 3), (9, 4), (9, 5), (9, 6), (9, 8), (9, 9), (9, 10), (9, 11),
             (8, 1), (8, 2), (8, 3), (8, 4), (8, 5), (8, 6), (8, 7), (8, 8), (8, 9), (8, 10), (8, 11),
             (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 7), (7, 8), (7, 9), (7, 10),
             (6, 1), (6, 2), (6, 3), (6, 5), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11),
             (5, 1), (5, 2), (5, 3), (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), (5, 10), (5, 11),
             (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (4, 11),
             (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (3, 11),
             (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11),
             (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 0), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11),
             (0, 0), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 8), (0, 10), (0, 11),
            ]

In [88]:
# the identified attn heads using CD-T
pos_specific_hs = [
            [i for i in range(12)],
            [i for i in range(512)],
            [i for i in range(12)]
        ]
all_heads = list(itertools.product(*pos_specific_hs))
random_heads = random.sample(all_heads, 6)
positives = random_heads

In [107]:
positives = [(1, 169, 2), (2, 169, 2), (2, 169, 3), (4, 169, 8), (1, 411, 3)]

In [108]:
positive_attended_token_freq = collect_attended_tokens_hh_rm_pos(positives, device, tokenizer, N=200, use_perc=True)
positive_attended_token_freq = sorted(positive_attended_token_freq.items(), key=lambda k_v: k_v[1], reverse=True)

In [61]:
#h1 = [(10, 82, 0), (10, 61, 8), (10, 82, 7), (10, 176, 2), (10, 467, 1), (10, 91, 7)]
#h2 = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]
#h3 = [(6, 82, 4), (5, 82, 4), (3, 82, 0), (4, 82, 0), (5, 82, 0), (6, 82, 0)]
#positives = [(0, 82, 9), (0, 82, 1), (0, 82, 7), (1, 82, 6), (0, 82, 6), (2, 82, 0)]
positive_attended_token_freq = collect_attended_tokens_hh(positives, device, tokenizer, N=500, use_perc=True)
positive_attended_token_freq = sorted(positive_attended_token_freq.items(), key=lambda k_v: k_v[1], reverse=True)

In [109]:
import json
with open('pp15_h2.json', 'w') as fp:
    json.dump(positive_attended_token_freq, fp)

## Tests

In [42]:
text = documents_full[0]
label = labels_full[0]
encoding = get_encoding(text, tokenizer, device)

In [96]:
"""
source_list_30 = [#list(itertools.product(range(12), range(512), range(12))), 
                  # list(itertools.product(range(12), range(70, 85), range(12))), 
                  # [(11, 0, i) for i in range(12)]
                  [(0, 0, 0)]] * 30
source_list_60 = [#list(itertools.product(range(12), range(512), range(12))), 
                  # list(itertools.product(range(12), range(70, 85), range(12))), 
                  # [(11, 0, i) for i in range(12)]
                  [(0, 0, 0)], []] * 30
"""
target_nodes = [(11, 8), (11, 0), (11, 1), (11, 4), (11, 3), (11, 7)]
source_list = [[(5, 7, 0)], [(5, 5, 0)]]

In [None]:
source_list = [[(5, 7, 0)], [(5, 5, 0)]]
out_decomps, target_decomps = prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes, patched_values=back, mean_ablated=True)

In [None]:
out_decomps, target_decomps, _ = prop_classifier_model_hh(encoding, model, source_list, target_nodes)

In [103]:
out_decomps

[(tensor([-0.0171,  0.0302, -0.0149], device='cuda:0'),
  tensor([-3.2685,  6.1339, -3.2953], device='cuda:0')),
 (tensor([-0.0179,  0.0306, -0.0147], device='cuda:0'),
  tensor([-3.2677,  6.1335, -3.2955], device='cuda:0'))]

In [104]:
target_decomps[11][0][0].shape

torch.Size([2, 64])

In [110]:
target_decomps[0][0][0].shape

torch.Size([0, 64])

### Function description:

prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes):

- encoding - Encoding given by tokenizer
- model - BERT model
- source_list - List of lists where each list consists of tuples (layer, position, head) indexing a particular attention head whose influence is to be calculated
- target_nodes - A single list of tuples (layer, position, head) containing attention heads on whom the influence is to be measured
- num_at_time (optional) - Number of source_lists to be processed in a batch
- n_layers - Number of layers
- att_list - Attention probabilities if precomputed

Output consists of two lists - out_decomps and target_decomps:
- out_decomps - Consists of a list of tuples (rel, irrel) reflecting the decomposition of the _output_
- target_decomps - A list containining 12 (one for each layer) where each list is of length len(source_list). For any layer l, each entry of target_decomps[l] is a tuple (rel, irrel) decomposition of the target nodes at that layer for the corresponding set of source nodes. rel, irrel are of dimension #number of target nodes in layer l x head_size and the ordering of the target nodes in this layer is the same as provided 