In [124]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [125]:
import argparse
import numpy as np
import os
import pandas as pd
import scipy as sp
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import collections

# ACD Imports
import math
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

from torch import nn

warnings.filterwarnings("ignore")

base_dir = os.path.split(os.getcwd())[0]
sys.path.append(base_dir)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from pyfunctions.general import extractListFromDic, readJson
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.cd import *
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import BertTokenizer, BertForSequenceClassification

In [126]:
torch.autograd.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f3795492fd0>

## Model Arguments

In [127]:
args = {
    'model_type': 'bert', # bert, medical_bert, pubmed_bert, biobert, clinical_biobert
    'task': 'path',
    'field': 'PrimaryGleason'
}

device = 'cuda:0'

## Load Data

In [128]:
if args['model_type'] == 'bert':
    bert_path = 'bert-base-uncased'
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
elif args['model_type'] == 'medical_bert':
    bert_path = f"{base_dir}/models/pretrained/bert_pretrain_output_all_notes_150000/"
    tokenizer = BertTokenizer.from_pretrained(bert_path, local_files_only=True)
elif args['model_type'] == 'pubmed_bert':
    bert_path = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
elif args['model_type'] == 'pubmed_bert_full':
    bert_path = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
elif args['model_type'] == 'biobert':
    bert_path = "dmis-lab/biobert-v1.1"
    tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
elif args['model_type'] == 'clinical_biobert':
    bert_path = "emilyalsentzer/Bio_ClinicalBERT"
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

In [129]:
# Read in data
#field = 'PrimaryGleason' # out of PrimaryGleason, SecondaryGleason', 'MarginStatusNone', 'SeminalVesicleNone'
path = f"../data/prostate.json"
data = readJson(path)

# Clean reports
data = cleanSplit(data, stripChars)
data['dev_test'] = cleanReports(data['dev_test'], stripChars)
data = fixLabel(data)

train_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['train']]
val_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['val']]
test_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['test']]
print(len(train_documents), len(val_documents),len(test_documents))

Token indices sequence length is longer than the specified maximum sequence length for this model (1345 > 512). Running this sequence through the model will result in indexing errors


2066 517 324


In [130]:
# Create datasets
train_labels = [patient['labels'][args['field']] for patient in data['train']]
val_labels = [patient['labels'][args['field']] for patient in data['val']]
test_labels = [patient['labels'][args['field']] for patient in data['test']]

train_documents, train_labels = exclude_labels(train_documents, train_labels)
val_documents, val_labels = exclude_labels(val_documents, val_labels)
test_documents, test_labels = exclude_labels(test_documents, test_labels)

le = preprocessing.LabelEncoder()
le.fit(train_labels)

# Map raw label to processed label
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))
le_dict = {str(key):le_dict[key] for key in le_dict}

for label in val_labels + test_labels:
    if str(label) not in le_dict:
        le_dict[str(label)] = len(le_dict)

# Map processed label back to raw label
inv_le_dict = {v: k for k, v in le_dict.items()}

In [131]:
documents_full = train_documents + val_documents + test_documents
labels_full = train_labels + val_labels + test_labels

In [132]:
type(tokenizer)

transformers.models.bert.tokenization_bert.BertTokenizer

## Load Trained Models

In [133]:
#load finetuned model
model_path = f"{base_dir}/models/{args['task']}/{args['model_type']}_{args['field']}"
checkpoint_file = f"{model_path}/save_output"
config_file = f"{model_path}/save_output/config.json"

model = BertForSequenceClassification.from_pretrained(checkpoint_file, num_labels=len(le_dict), output_hidden_states=True)

model = model.eval()
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [11]:
'''
encoding = tokenizer.encode_plus(train_documents[0], 
                                         add_special_tokens=True, 
                                         max_length=512,
                                         truncation=True, 
                                         padding = "max_length", 
                                         return_attention_mask=True, 
                                         pad_to_max_length=True,
                                         return_tensors="pt")

list(encoding.keys())
'''

'\nencoding = tokenizer.encode_plus(train_documents[0], \n                                         add_special_tokens=True, \n                                         max_length=512,\n                                         truncation=True, \n                                         padding = "max_length", \n                                         return_attention_mask=True, \n                                         pad_to_max_length=True,\n                                         return_tensors="pt")\n\nlist(encoding.keys())\n'

## Path Patching Code

In [11]:
'''
def patch_context(rel, irrel, patched_entries, sa_module):
    rel = reshape_separate_attention_heads(rel, sa_module)
    irrel = reshape_separate_attention_heads(irrel, sa_module)
    
    for entry in patched_entries:
        pos = entry[1]
        att_head = entry[2]

        rel[:, pos, att_head, :] = rel[:, pos, att_head, :] + irrel[:, pos, att_head, :]
        irrel[:, pos, att_head, :] = 0
        
        # irrel[:, pos, att_head, :] = rel[:, pos, att_head, :] + irrel[:, pos, att_head, :]
        # rel[:, pos, att_head, :] = 0

    
    rel = reshape_concatenate_attention_heads(rel, sa_module)
    irrel = reshape_concatenate_attention_heads(irrel, sa_module)
    
    return rel, irrel

def prop_self_attention_patched(rel, irrel, attention_mask, 
                                head_mask, patched_entries, 
                                sa_module, att_probs = None):
    if att_probs is not None:
        att_probs = att_probs
    else:
        att_probs = get_attention_probs(rel + irrel, attention_mask, head_mask, sa_module)
    
    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)
    
    rel_context = mul_att(att_probs, rel_value, sa_module)
    irrel_context = mul_att(att_probs, irrel_value, sa_module)
    
    rel_context, irrel_context = patch_context(rel_context, irrel_context, patched_entries, sa_module)
    
    return rel_context, irrel_context

def prop_attention_patched(rel, irrel, attention_mask, 
                           head_mask, patched_entries, a_module, 
                           att_probs = None):
    
    rel_context, irrel_context = prop_self_attention_patched(rel, irrel, 
                                                             attention_mask, 
                                                             head_mask, 
                                                             patched_entries,
                                                             a_module.self, att_probs)
    
    # if len(patched_entries):
    #     print(rel_context[0, 0, :])
    #     print(irrel_context[0, 0, :])
    
    output_module = a_module.output
    
    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)
    rel_tot = rel_dense + rel
    irrel_tot = irrel_dense + irrel
    
    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)
    
#     print('AttRes: ', torch.norm(rel_tot[:, 0]), torch.norm(irrel_tot[:, 0]))
    
#     rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)
    
#     print('AttOut: ', torch.norm(rel_out[:, 0]), torch.norm(irrel_out[:, 0]))

    
    return rel_out, irrel_out

def prop_layer_patched(rel, irrel, attention_mask, head_mask, patched_entries, layer_module, att_probs = None):
    rel_a, irrel_a = prop_attention_patched(rel, irrel, attention_mask, head_mask, patched_entries, layer_module.attention, att_probs)
    
    # print('Attention: ', torch.norm(rel_a[:, 0]), torch.norm(irrel_a[:, 0]))
    
    i_module = layer_module.intermediate
    rel_id, irrel_id = prop_linear(rel_a, irrel_a, i_module.dense)
    rel_iact, irrel_iact = prop_act(rel_id, irrel_id, i_module.intermediate_act_fn)
    
    # print('Intermediate: ', torch.norm(rel_iact[:, 0]), torch.norm(irrel_iact[:, 0]))
    
    o_module = layer_module.output
    rel_od, irrel_od = prop_linear(rel_iact, irrel_iact, o_module.dense)
    
    # print('Output: ', torch.norm(rel_od[:, 0]), torch.norm(irrel_od[:, 0]))
    
    rel_tot = rel_od + rel_a
    irrel_tot = irrel_od + irrel_a
    
    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, o_module.LayerNorm)
    
    # print('LayerNorm: ', torch.norm(rel_out[:, 0]), torch.norm(irrel_out[:, 0]))
    
    # import pdb; pdb.set_trace()
    
    return rel_out, irrel_out
'''

In [12]:
'''
def prop_classifier_model_patched(encoding, model, patched_entries, att_list = None):
    embedding_output = get_embeddings_bert(encoding, model.bert)
    input_shape = encoding['input_ids'].size()
    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], 
                                                          input_shape = input_shape, 
                                                          bert_model = model.bert)
    
    head_mask = [None] * model.bert.config.num_hidden_layers
    encoder_module = model.bert.encoder
    
    sh = list(embedding_output.shape)
    
    rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
    irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
    
    irrel[:] = embedding_output[:]
    
    
    for i, layer_module in enumerate(encoder_module.layer):
        layer_patched_entries = [p_entry for p_entry in patched_entries if p_entry[0] == i]
        layer_head_mask = head_mask[i]
        
        rel_n, irrel_n = prop_layer_patched(rel, irrel, extended_attention_mask, layer_head_mask, layer_patched_entries, layer_module, att_probs = None)
        # print(torch.norm(rel_n[:, 0]), torch.norm(irrel_n[:, 0]))
        normalize_rel_irrel(rel_n, irrel_n)
        rel, irrel = rel_n, irrel_n
        # if i == 11:
        # print(torch.norm(rel[:, 0]), torch.norm(irrel[:, 0]))
    
    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)
    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)
    
    return rel_out, irrel_out
'''

In [23]:
text = documents_full[0]
label = labels_full[0]
encoding = get_encoding(text, tokenizer, device)

In [45]:
patched_entries_1 = [(i, i, i) for i in range(12)]
patched_entries_2 = [(11, 0, i) for i in range(12)]
patched_entries_3 = []

In [46]:
# Get raw output
raw_logit = ft_model(**encoding, output_hidden_states = False)

In [47]:
raw_logit

SequenceClassifierOutput(loss=None, logits=tensor([[-2.4799, -1.4704,  8.0667, -2.2106, -2.3697]], device='cuda:1'), hidden_states=None, attentions=None)

In [None]:
rel_2, irrel_2 = prop_classifier_model_patched(encoding, ft_model, patched_entries_2)
rel_3, irrel_3 = prop_classifier_model_patched(encoding, ft_model, patched_entries_3)

In [53]:
rel_2

tensor([[ 5.0000, -0.9562,  5.1698, -1.4072, -1.6844]], device='cuda:1')

## Head-head Patching

In [134]:
def reshape_separate_attention_heads(context_layer, sa_module):
    new_shape = context_layer.size()[:-1] + (sa_module.num_attention_heads, sa_module.attention_head_size)
    context_layer = context_layer.view(new_shape)
    return context_layer

def reshape_concatenate_attention_heads(context_layer, sa_module):
    new_shape = context_layer.size()[:-2] + (sa_module.all_head_size,)
    context_layer = context_layer.view(*new_shape)
    return context_layer

def patch_context_hh(rel, irrel, source_node_list, target_nodes, level, sa_module):
    rel = reshape_separate_attention_heads(rel, sa_module)
    irrel = reshape_separate_attention_heads(irrel, sa_module)
    
    target_nodes_at_level = [node for node in target_nodes if node[0] == level]
    target_decomps = []
    
    for s_ind, sn_list in enumerate(source_node_list):
        out_shape = (len(target_nodes_at_level), sa_module.attention_head_size)
        
        rel_st = torch.zeros(out_shape, dtype = rel.dtype, device = device)
        irrel_st = torch.zeros(out_shape, dtype = rel.dtype, device = device)
        
        for t_ind, t in enumerate(target_nodes_at_level):
            if t[0] == level:
                t_pos = t[1]
                t_head = t[2]

                rel_st[t_ind, :] = rel[s_ind, t_pos, t_head, :]
                irrel_st[t_ind, :] = irrel[s_ind, t_pos, t_head, :]
        
        target_decomps.append((rel_st.detach().cpu().numpy(), irrel_st.detach().cpu().numpy()))
        
        for entry in sn_list:
            if entry[0] == level:
                pos = entry[1]
                att_head = entry[2]

                rel[s_ind, pos, att_head, :] = rel[s_ind, pos, att_head, :] + irrel[s_ind, pos, att_head, :]
                irrel[s_ind, pos, att_head, :] = 0

    
    rel = reshape_concatenate_attention_heads(rel, sa_module)
    irrel = reshape_concatenate_attention_heads(irrel, sa_module)
    
    return rel, irrel, target_decomps

def patch_context_hh_mean_ablated(rel, irrel, source_node_list, target_nodes, level, layer_patched_values, sa_module):
    rel = reshape_separate_attention_heads(rel, sa_module)
    irrel = reshape_separate_attention_heads(irrel, sa_module)
    
    target_nodes_at_level = [node for node in target_nodes if node[0] == level]
    target_decomps = []
    
    if layer_patched_values is not None:
        layer_patched_values = layer_patched_values[None, :, :, :]

    for s_ind, sn_list in enumerate(source_node_list):
        out_shape = (len(target_nodes_at_level), sa_module.attention_head_size)
        
        rel_st = torch.zeros(out_shape, dtype = rel.dtype, device = device)
        irrel_st = torch.zeros(out_shape, dtype = rel.dtype, device = device)

        for t_ind, t in enumerate(target_nodes_at_level):
            if t[0] == level:
                t_pos = t[1]
                t_head = t[2]
                rel_st[t_ind, :] = rel[s_ind, t_pos, t_head, :]
                irrel_st[t_ind, :] = irrel[s_ind, t_pos, t_head, :]

        
        target_decomps.append((rel_st.detach().cpu().numpy(), irrel_st.detach().cpu().numpy()))
        
        for entry in sn_list:
            if entry[0] == level:
                pos = entry[1]
                att_head = entry[2]

                #rel[s_ind, pos, att_head, :] = rel[s_ind, pos, att_head, :] + irrel[s_ind, pos, att_head, :]
                #irrel[s_ind, pos, att_head, :] = 0
                
                rel[s_ind, pos, att_head, :] = irrel[s_ind, pos, att_head, :] + rel[s_ind, pos, att_head, :] - torch.Tensor(layer_patched_values[:, pos, att_head, :]).to(device)
                irrel[s_ind, pos, att_head, :] = torch.Tensor(layer_patched_values[:, pos, att_head, :]).to(device)

    
    rel = reshape_concatenate_attention_heads(rel, sa_module)
    irrel = reshape_concatenate_attention_heads(irrel, sa_module)
    
    return rel, irrel, target_decomps

def prop_self_attention_hh(rel, irrel, attention_mask, 
                           head_mask, source_node_list, target_nodes, 
                           level, sa_module, att_probs = None, output_att_prob=False):
    
    if att_probs is not None:
        att_probs = att_probs
    else:
        att_probs = get_attention_probs(rel[0].unsqueeze(0) + irrel[0].unsqueeze(0), attention_mask, head_mask, sa_module)

    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)

    rel_context = mul_att(att_probs, rel_value, sa_module)

    irrel_context = mul_att(att_probs, irrel_value, sa_module)
    
    #rel_context, irrel_context, target_decomps = patch_context_hh(rel_context, irrel_context, source_node_list, target_nodes, level, sa_module)
    
    if output_att_prob:
        return rel_context, irrel_context, att_probs
    else:
        return rel_context, irrel_context, None
    
    #return rel_context, irrel_context, target_decomps

def prop_attention_hh(rel, irrel, attention_mask, 
                      head_mask, source_node_list, target_nodes, level,
                      layer_patched_values,
                      a_module, att_probs = None, output_att_prob=False, mean_ablated=False):
    
    rel_context, irrel_context, returned_att_probs = prop_self_attention_hh(rel, irrel, 
                                                                        attention_mask, 
                                                                        head_mask, 
                                                                        source_node_list,
                                                                        target_nodes,
                                                                        level,
                                                                        a_module.self, att_probs,
                                                                        output_att_prob=output_att_prob)
    normalize_rel_irrel(rel_context, irrel_context)
    
    output_module = a_module.output
    
    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)
    
    normalize_rel_irrel(rel_dense, irrel_dense)
    
    rel_tot = rel_dense + rel
    irrel_tot = irrel_dense + irrel
    
    normalize_rel_irrel(rel_tot, irrel_tot)
    
    if not mean_ablated:
        rel_tot, irrel_tot, target_decomps = patch_context_hh(rel_tot, irrel_tot, source_node_list, target_nodes, level, a_module.self)
    else:
        rel_tot, irrel_tot, target_decomps = patch_context_hh_mean_ablated(rel_tot, irrel_tot, source_node_list, target_nodes, level,
                                                                            layer_patched_values, a_module.self)
    
    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)

    normalize_rel_irrel(rel_out, irrel_out)
    
    return rel_out, irrel_out, target_decomps, returned_att_probs

def prop_layer_hh(rel, irrel, attention_mask, head_mask, 
                  source_node_list, target_nodes, level, layer_patched_values,
                  layer_module, att_probs = None, output_att_prob=False, mean_ablated=False):
    
    rel_a, irrel_a, target_decomps, returned_att_probs = prop_attention_hh(rel, irrel, attention_mask, 
                                                                           head_mask, source_node_list, 
                                                                           target_nodes, level, layer_patched_values,
                                                                           layer_module.attention,
                                                                           att_probs, output_att_prob, mean_ablated=mean_ablated)

    i_module = layer_module.intermediate
    rel_id, irrel_id = prop_linear(rel_a, irrel_a, i_module.dense)
    normalize_rel_irrel(rel_id, irrel_id)
    
    rel_iact, irrel_iact = prop_act(rel_id, irrel_id, i_module.intermediate_act_fn)
    
    o_module = layer_module.output
    rel_od, irrel_od = prop_linear(rel_iact, irrel_iact, o_module.dense)
    normalize_rel_irrel(rel_od, irrel_od)
    
    rel_tot = rel_od + rel_a
    irrel_tot = irrel_od + irrel_a
    normalize_rel_irrel(rel_tot, irrel_tot)

    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, o_module.LayerNorm)
    
    
    return rel_out, irrel_out, target_decomps, returned_att_probs

In [135]:
def prop_classifier_model_hh(encoding, model, source_node_list, target_nodes, 
                             patched_values=None, att_list = None, output_att_prob=False, mean_ablated=False):
    embedding_output = get_embeddings_bert(encoding, model.bert)
    input_shape = encoding['input_ids'].size()
    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], 
                                                          input_shape = input_shape, 
                                                          bert_model = model.bert,
                                                          device = device)
    
    head_mask = [None] * model.bert.config.num_hidden_layers
    encoder_module = model.bert.encoder
    
    sh = list(embedding_output.shape)
    sh[0] = len(source_node_list)
    
    rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
    irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
    
    irrel[:] = embedding_output[:]
    
    target_decomps = []
    att_probs_lst = []
    for i, layer_module in enumerate(encoder_module.layer):
        layer_head_mask = head_mask[i]
        att_probs = None
        
        if patched_values is not None:
            layer_patched_values = patched_values[i] #[512, 12, 64]
        else:
            layer_patched_values = None
            
        rel_n, irrel_n, layer_target_decomps, returned_att_probs = prop_layer_hh(rel, irrel, extended_attention_mask, 
                                                                                 layer_head_mask, source_node_list, 
                                                                                 target_nodes, i, 
                                                                                 layer_patched_values,
                                                                                 layer_module, att_probs, output_att_prob,
                                                                                 mean_ablated=mean_ablated)
        target_decomps.append(layer_target_decomps)
        normalize_rel_irrel(rel_n, irrel_n)
        rel, irrel = rel_n, irrel_n
        
        if output_att_prob:
            att_probs_lst.append(returned_att_probs.squeeze(0))
    
    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)
    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)
    
    out_decomps = []

    for i, sn_list in enumerate(source_node_list):
        rel_vec = rel_out[i, :].detach().cpu().numpy()
        irrel_vec = irrel_out[i, :].detach().cpu().numpy()
        
        out_decomps.append((rel_vec, irrel_vec))
    
    return out_decomps, target_decomps, att_probs_lst

def prop_classifier_model_hh_batched(encoding, model, source_node_list, target_nodes, patched_values=None, 
                                     num_at_time = 64, n_layers = 12, att_list = None, output_att_prob=False, mean_ablated=False):
    out_decomps = []
    target_decomps = [[] for i in range(n_layers)]
    
    n_source_lists = len(source_node_list)
    n_batches = int((n_source_lists + (num_at_time - 1)) / num_at_time)

    for b_no in range(n_batches):
        b_st = b_no * num_at_time
        b_end = min(b_st + num_at_time, n_source_lists)
        layer_out_decomps, layer_target_decomps, att_probs_lst = prop_classifier_model_hh(encoding, model, 
                                                                           source_node_list[b_st: b_end],
                                                                           target_nodes, patched_values,
                                                                           att_list=att_list,
                                                                           output_att_prob=output_att_prob,
                                                                           mean_ablated=mean_ablated)
        out_decomps = out_decomps + layer_out_decomps
        target_decomps = [target_decomps[i] + layer_target_decomps[i] for i in range(n_layers)]
    
    return out_decomps, target_decomps

In [136]:
# codes for second pass: ablate the target nodes

In [137]:
def prop_self_attention_patched(rel, irrel, attention_mask, 
                                head_mask, patched_entries, layer_patched_values,
                                sa_module, att_probs = None):
    
    if att_probs is not None:
        att_probs = att_probs
    else:
        att_probs = get_attention_probs(rel + irrel, attention_mask, head_mask, sa_module)
    
    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)
    
    rel_context = mul_att(att_probs, rel_value, sa_module)
    irrel_context = mul_att(att_probs, irrel_value, sa_module)
    
    #rel_context, irrel_context = patch_context(rel_context, irrel_context, patched_entries, layer_patched_values, sa_module)
    
    return rel_context, irrel_context
    
def patch_context_baseline(rel, irrel, patched_entries, layer_patched_values, sa_module):
    rel = reshape_separate_attention_heads(rel, sa_module)
    irrel = reshape_separate_attention_heads(irrel, sa_module)

    for i, entry in enumerate(patched_entries):
        pos = entry[1]
        att_head = entry[2]
        
        saved_rel = torch.Tensor(layer_patched_values[0][i])
        saved_irrel = torch.Tensor(layer_patched_values[1][i])
        
        rel[:, pos, att_head, :] = saved_rel
        irrel[:, pos, att_head, :] = saved_irrel
        
    rel = reshape_concatenate_attention_heads(rel, sa_module)
    irrel = reshape_concatenate_attention_heads(irrel, sa_module)
    return rel, irrel

def prop_attention_patched_baseline(rel, irrel, attention_mask, 
                           head_mask, patched_entries, layer_patched_values, a_module, 
                           att_probs = None):

    
    rel_context, irrel_context = prop_self_attention_patched(rel, irrel, 
                                                             attention_mask, 
                                                             head_mask, 
                                                             patched_entries,
                                                             layer_patched_values,
                                                             a_module.self, att_probs)

    output_module = a_module.output
    
    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)
    rel_tot = rel_dense + rel
    irrel_tot = irrel_dense + irrel
        
    rel_tot, irrel_tot = patch_context_baseline(rel_tot, irrel_tot, patched_entries, layer_patched_values, a_module.self)
    
    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)
    
    return rel_out, irrel_out

def prop_layer_patched(rel, irrel, attention_mask, head_mask, patched_entries, layer_patched_values, layer_module, att_probs = None):
    
    rel_a, irrel_a = prop_attention_patched_baseline(rel, irrel, attention_mask, head_mask,
                                                     patched_entries, layer_patched_values,
                                                     layer_module.attention, att_probs)
    
    i_module = layer_module.intermediate
    rel_id, irrel_id = prop_linear(rel_a, irrel_a, i_module.dense)
    rel_iact, irrel_iact = prop_act(rel_id, irrel_id, i_module.intermediate_act_fn)
    
    o_module = layer_module.output
    rel_od, irrel_od = prop_linear(rel_iact, irrel_iact, o_module.dense)
    
    rel_tot = rel_od + rel_a
    irrel_tot = irrel_od + irrel_a
    
    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, o_module.LayerNorm)
    
    # import pdb; pdb.set_trace()
    
    return rel_out, irrel_out

In [138]:
def ablate_target_nodes(encoding, model, patched_entries, patched_values=None, att_list = None):
    
    embedding_output = get_embeddings_bert(encoding, model.bert)
    input_shape = encoding['input_ids'].size()
    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], 
                                                          input_shape = input_shape, 
                                                          bert_model = model.bert,
                                                          device = device)
    
    head_mask = [None] * model.bert.config.num_hidden_layers
    encoder_module = model.bert.encoder
    
    sh = list(embedding_output.shape)
    
    rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
    irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
    
    irrel[:] = embedding_output[:]

    for i, layer_module in enumerate(encoder_module.layer):
        layer_patched_entries = [p_entry for p_entry in patched_entries if p_entry[0] == i]
        layer_head_mask = head_mask[i]
        att_probs = None
        
        if patched_values is not None:
            layer_patched_values = patched_values[i]
        else:
            layer_patched_values = None
        
        rel_n, irrel_n = prop_layer_patched(rel, irrel, extended_attention_mask,
                                            layer_head_mask, layer_patched_entries,
                                            layer_patched_values,
                                            layer_module, att_probs)
        normalize_rel_irrel(rel_n, irrel_n)
        rel, irrel = rel_n, irrel_n
    
    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)
    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)
    
    return rel_out, irrel_out

In [139]:
import pickle

path = f"{base_dir}/output/{args['task']}/{args['model_type']}_{args['field']}/h_to_logits"
os.makedirs(path, exist_ok=True)

with open(os.path.join(path, f"mean_head_out_res_500.pkl"), 'rb') as handle:
    back = pickle.load(handle)

In [140]:
import itertools

def patch_hh_at_pos_baseline(encoding, label_idx, model, target_nodes, pos=0, mean_act=None, mean_ablated=False):
    
    raw_logit = model(**encoding)[0][0][label_idx]
    
    pos_specific_hs = [
        [i for i in range(12)],
        [pos],
        [i for i in range(12)]
    ]
    all_heads = list(itertools.product(*pos_specific_hs))

    # patch one node at a time
    h_ctbn_list = []
    
    source_list = [[node] for node in all_heads if node not in target_nodes]
    out_decomps, target_decomps = prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes,
                                                                  patched_values=mean_act, mean_ablated=True)
    
    for i, _ in enumerate(source_list):
        tmp = []
        for l in range(12):
            if target_decomps[l][i][0].shape[0] != 0:
                tmp.append(target_decomps[l][i])
            else:
                tmp.append([])
        
        rel_out, irrel_out = ablate_target_nodes(encoding, model, target_nodes, tmp, att_list = None)
        logit_diff = (rel_out[0][label_idx] + irrel_out[0][label_idx]) - raw_logit
        h_ctbn_list.append(logit_diff / abs(raw_logit) * 100)
        
    return source_list, h_ctbn_list

In [141]:
text = documents_full[0]
label = labels_full[0]
encoding = get_encoding(text, tokenizer, device)
label_idx = le_dict[label]

In [146]:
#target_nodes = [(11, 0, 1), (11, 0, 7), (11, 0, 5), (11, 0, 3), (11, 0, 0), (11, 0, 8)]
#target_nodes = [(8, 132, 1), (8, 275, 0), (6, 397, 1), (8, 66, 6), (8, 380, 8), (1, 195, 0)]
#target_nodes = [(1, 169, 2), (2, 169, 2), (2, 169, 3), (4, 169, 8), (1, 411, 3), (2, 169, 1)]
#### backup^^^

target_nodes = [(11, 506, 6), (11, 506, 7), (11, 506, 8), (11, 506, 9), (11, 506, 10), (11, 506, 11)]

all_source_hs = []
all_htbn = []
for pos in tqdm.tqdm(range(512)):
    with torch.no_grad():
        source_list, h_ctbn_list = patch_hh_at_pos_baseline(encoding, label_idx, model, target_nodes,
                                                            pos=pos, mean_act=back, mean_ablated=True)
    torch.cuda.empty_cache()
    all_source_hs.append(source_list)
    all_htbn.append(h_ctbn_list)

100%|██████████| 512/512 [3:37:26<00:00, 25.48s/it]  


In [None]:
h_ctbn_list

In [147]:
flat_ctbn = [c for sublist in all_htbn for c in sublist]
flat_source_h = [c for sublist in all_source_hs for c in sublist]


In [148]:
top_idx = sorted(range(len(flat_ctbn)), key=lambda i: flat_ctbn[i])[-6:]

In [149]:
for i in top_idx:
    print(flat_source_h[i], flat_ctbn[i])

[(11, 511, 6)] tensor(-1.5471e-05, device='cuda:0')
[(11, 511, 7)] tensor(-1.5471e-05, device='cuda:0')
[(11, 511, 8)] tensor(-1.5471e-05, device='cuda:0')
[(11, 511, 9)] tensor(-1.5471e-05, device='cuda:0')
[(11, 511, 10)] tensor(-1.5471e-05, device='cuda:0')
[(11, 511, 11)] tensor(-1.5471e-05, device='cuda:0')


In [109]:
# mean-ablated
import pickle

path = f"{base_dir}/output/{args['task']}/{args['model_type']}_{args['field']}/h2"
os.makedirs(path, exist_ok=True)

with open(os.path.join(path, f"flat_source_h_baseline.pkl"), 'wb') as handle:
    pickle.dump(flat_source_h, handle)
    
with open(os.path.join(path, f"flat_source_h_baseline.pkl"), 'rb') as handle:
    back = pickle.load(handle)

In [110]:
back[i]

[(2, 169, 1)]

In [117]:
def collect_attended_tokens(positives_heads, N=100, Z_thres=2):
    index_lst = random.sample(range(0, len(documents_full)), N)
    docs = [documents_full[i] for i in index_lst]
    
    collect = collections.defaultdict(int)
    for doc in docs:
        encoding = get_encoding(doc, tokenizer, device)
        
        _, _, raw_att_probs_lst = prop_classifier_model_hh(encoding, model, [[]], [], output_att_prob=True)
        raw_att_probs = torch.stack(raw_att_probs_lst).cpu().numpy()

        avg_att_m = np.zeros((512))
        for level, pos, h in positives_heads:
            att_m = raw_att_probs[level, h, pos, :]
            avg_att_m += att_m

        avg_att_m /= len(positives)
        
        # convert to word level
        interval_dict, word_lst = compute_word_intervals(encoding)
        word_att_m = combine_token_attn(interval_dict, avg_att_m)
        
        Z = (word_att_m - np.mean(word_att_m)) / np.std(word_att_m)

        positive_words = np.where(Z > Z_thres)
        
        for w_idx in positive_words[0]:
            w = word_lst[w_idx]
            #collect[w] += 1
            collect[w] += word_att_m[w_idx]
            
    return collect


def combine_token_attn(interval_dict, avg_att_m):
    word_cnt = len(interval_dict)
    new_att_m = np.zeros(word_cnt)
    for i in range(word_cnt):
        t_idx_lst = interval_dict[i+1]
        if len(t_idx_lst) == 1:
            new_att_m[i] = avg_att_m[t_idx_lst[0]]
        else:
            new_att_m[i] = np.sum(avg_att_m[t_idx_lst[0]:t_idx_lst[-1]+1])
    return new_att_m


def compute_word_intervals(encoding):
    word_cnt = 0
    interval_dict = collections.defaultdict(list)
    
    pretok_sent = ""
    for i in range(512):
        tok = tokenizer.decode(encoding['input_ids'][:, i])
        if tok.startswith("##"):
            interval_dict[word_cnt].append(i)
            pretok_sent += tok[2:]
        else:
            word_cnt += 1
            interval_dict[word_cnt].append(i)
            pretok_sent += " " + tok
    pretok_sent = pretok_sent[1:]
    word_lst = pretok_sent.split(" ")
    
    assert(len(interval_dict) == len(word_lst))
    
    return interval_dict, word_lst

In [122]:
#h1 = [(11, 0, 1), (11, 0, 7), (11, 0, 5), (11, 0, 3), (11, 0, 0), (11, 0, 8)]
#h2 = [(8, 132, 1), (8, 275, 0), (6, 397, 1), (8, 66, 6), (8, 380, 8), (1, 195, 0)]
h3 = [(1, 169, 2), (2, 169, 2), (2, 169, 3), (4, 169, 8), (1, 411, 3), (2, 169, 1)]
positives = h3
positive_attended_token_freq = collect_attended_tokens(positives, N=500, Z_thres=3)
positive_attended_token_freq = sorted(positive_attended_token_freq.items(), key=lambda k_v: k_v[1], reverse=True)

In [123]:
import json
with open('result_h3_baseline.json', 'w') as fp:
    json.dump(positive_attended_token_freq, fp)

### Function description:

prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes):

- encoding - Encoding given by tokenizer
- model - BERT model
- source_list - List of lists where each list consists of tuples (layer, position, head) indexing a particular attention head whose influence is to be calculated
- target_nodes - A single list of tuples (layer, position, head) containing attention heads on whom the influence is to be measured
- num_at_time (optional) - Number of source_lists to be processed in a batch
- n_layers - Number of layers
- att_list - Attention probabilities if precomputed

Output consists of two lists - out_decomps and target_decomps:
- out_decomps - Consists of a list of tuples (rel, irrel) reflecting the decomposition of the _output_
- target_decomps - A list containining 12 (one for each layer) where each list is of length len(source_list). For any layer l, each entry of target_decomps[l] is a tuple (rel, irrel) decomposition of the target nodes at that layer for the corresponding set of source nodes. rel, irrel are of dimension #number of target nodes in layer l x head_size and the ordering of the target nodes in this layer is the same as provided 

In [324]:
import time

In [325]:
st = time.time()
out_decomps, target_decomps = prop_classifier_model_hh_batched(encoding, ft_model, source_list_30, target_nodes)
end = time.time()
print(end - st)

st = time.time()
out_decomps, target_decomps = prop_classifier_model_hh_batched(encoding, ft_model, source_list_60, target_nodes)
end = time.time()
print(end - st)

4.329473972320557
4.332589626312256
8.348850011825562
8.351091861724854


In [None]:
target_decomps[11]

# Appendix

In [None]:
def patch_context_dot_w_embed(embed, rel, irrel, patched_entries, sa_module):
    rel = reshape_separate_attention_heads(rel, sa_module)
    irrel = reshape_separate_attention_heads(irrel, sa_module)
    
    for entry in patched_entries:
        pos = entry[1]
        att_head = entry[2]

        rel[:, pos, att_head, :] = rel[:, pos, att_head, :] + irrel[:, pos, att_head, :]
        irrel[:, pos, att_head, :] = 0
        #rel[:, pos, att_head, :] = 0
        #irrel[:, pos, att_head, :] = rel[:, pos, att_head, :] + irrel[:, pos, att_head, :]

        
    
    rel = reshape_concatenate_attention_heads(rel, sa_module)
    irrel = reshape_concatenate_attention_heads(irrel, sa_module)
    
    return rel, irrel

def prop_self_attention_patched_dot_w_embed(embed, rel, irrel, attention_mask, 
                                head_mask, patched_entries, 
                                sa_module, att_probs = None, output_att_prob=False):
    if att_probs is not None:
        att_probs = att_probs
    else:
        att_probs = get_attention_probs(rel + irrel, attention_mask, head_mask, sa_module)
    
    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)
    
    rel_context = mul_att(att_probs, rel_value, sa_module)
    irrel_context = mul_att(att_probs, irrel_value, sa_module)
    
    rel_context, irrel_context = patch_context(embed, rel_context, irrel_context, patched_entries, sa_module)
    
    if output_att_prob:
        return rel_context, irrel_context, att_probs
    else:
        return rel_context, irrel_context, None
    
def prop_attention_patched_dot_w_embed(embed, rel, irrel, attention_mask, 
                           head_mask, patched_entries, a_module, 
                           att_probs = None,
                           output_att_prob=False):
    
    rel_context, irrel_context, returned_att_probs = prop_self_attention_patched(rel, irrel, 
                                                             attention_mask, 
                                                             head_mask, 
                                                             patched_entries,
                                                             a_module.self, att_probs, output_att_prob)
    
    # if len(patched_entries):
    #     print(rel_context[0, 0, :])
    #     print(irrel_context[0, 0, :])
    
    output_module = a_module.output
    
    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)
    rel_tot = rel_dense + rel
    irrel_tot = irrel_dense + irrel
    
    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)
    
    return rel_out, irrel_out, returned_att_probs

def prop_layer_patched_dot_w_embed(embed, rel, irrel, attention_mask, head_mask, patched_entries, layer_module, att_probs = None, output_att_prob=False):
    rel_a, irrel_a, returned_att_probs = prop_attention_patched_dot_w_embed(embed, rel, irrel, attention_mask, head_mask, patched_entries, layer_module.attention, att_probs, output_att_prob)
    
    i_module = layer_module.intermediate
    rel_id, irrel_id = prop_linear(rel_a, irrel_a, i_module.dense)
    rel_iact, irrel_iact = prop_act(rel_id, irrel_id, i_module.intermediate_act_fn)
    
    o_module = layer_module.output
    rel_od, irrel_od = prop_linear(rel_iact, irrel_iact, o_module.dense)
    
    rel_tot = rel_od + rel_a
    irrel_tot = irrel_od + irrel_a
    
    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, o_module.LayerNorm)
    
    # import pdb; pdb.set_trace()
    
    return rel_out, irrel_out, returned_att_probs

def prop_classifier_model_patched_dot_w_embed(encoding, model, patched_entries, att_list = None, output_att_prob=False):
    # patched_entries: attention heads to patch. format: [(level, pos, head)]
    # level: 0-11, pos: 0-511, head: 0-11
    # rel_out: the contribution of the patched_entries
    # irrel_out: the contribution of everything else
    
    embedding_output = get_embeddings_bert(encoding, model.bert)
    input_shape = encoding['input_ids'].size()
    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], 
                                                          input_shape = input_shape, 
                                                          bert_model = model.bert)
    
    head_mask = [None] * model.bert.config.num_hidden_layers
    encoder_module = model.bert.encoder
    
    sh = list(embedding_output.shape)
    
    rel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
    irrel = torch.zeros(sh, dtype = embedding_output.dtype, device = device)
    
    #rel[:] = embedding_output[:]
    irrel[:] = embedding_output[:]

    att_probs_lst = []
    for i, layer_module in enumerate(encoder_module.layer):
        layer_patched_entries = [p_entry for p_entry in patched_entries if p_entry[0] == i]
        layer_head_mask = head_mask[i]
        att_probs = None
        rel_n, irrel_n, returned_att_probs = prop_layer_patched_dot_w_embed(embedding_output, rel, irrel, extended_attention_mask,
                                                                layer_head_mask, layer_patched_entries,
                                                                layer_module, att_probs, output_att_prob)
        normalize_rel_irrel(rel_n, irrel_n)
        rel, irrel = rel_n, irrel_n
        
        if output_att_prob:
            att_probs_lst.append(returned_att_probs.squeeze(0))
    
    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)
    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)
    
    return rel_out, irrel_out, att_probs_lst

In [None]:
text = documents_full[0]
label = labels_full[0]
encoding = get_encoding(text, tokenizer).to(device)

In [None]:
rel, irrel, _ = prop_classifier_model_patched_dot_w_embed(encoding, model, [(11, 0, 0)])