In [22]:
%load_ext autoreload
%autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
import argparse
import numpy as np
import os
import pandas as pd
import scipy as sp
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import collections
import functools

# CD-T Imports
import math
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import itertools

from torch import nn

warnings.filterwarnings("ignore")

base_dir = os.path.split(os.getcwd())[0]
sys.path.append(base_dir)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.cdt_from_source_nodes import *
from pyfunctions.cdt_source_to_target import *
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import BertTokenizer, BertForSequenceClassification

In [3]:
# Turn off gradient tracking globally for every cell that runs after this one.
torch.autograd.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x717706b71820>

## Model Arguments

In [4]:
# Run configuration:
#   'model_type' selects the tokenizer (next cell) and the fine-tuned checkpoint directory,
#   'task' and 'field' are used to build output paths, and 'field' is also the label key.
args = {
    'model_type': 'bert', # bert, medical_bert, pubmed_bert, pubmed_bert_full, biobert, clinical_biobert
    'task': 'path',
    'field': 'PrimaryGleason'
}

# Device passed to model.to(...) and to the encoding / propagation helpers below.
device = 'cpu'

In [5]:
# Checkpoint identifier (Hugging Face hub id or local directory) for each supported model type.
_PRETRAINED_PATHS = {
    'bert': 'bert-base-uncased',
    'medical_bert': f"{base_dir}/models/pretrained/bert_pretrain_output_all_notes_150000/",
    'pubmed_bert': "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    'pubmed_bert_full': "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    'biobert': "dmis-lab/biobert-v1.1",
    'clinical_biobert': "emilyalsentzer/Bio_ClinicalBERT",
}

model_type = args['model_type']
if model_type not in _PRETRAINED_PATHS:
    # Fail here with a clear message instead of a NameError on `tokenizer` further down.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected one of {sorted(_PRETRAINED_PATHS)}"
    )

bert_path = _PRETRAINED_PATHS[model_type]
if model_type == 'bert':
    tokenizer = BertTokenizer.from_pretrained(bert_path)
elif model_type == 'medical_bert':
    # The locally pre-trained checkpoint must never be looked up on the hub.
    tokenizer = BertTokenizer.from_pretrained(bert_path, local_files_only=True)
else:
    tokenizer = AutoTokenizer.from_pretrained(bert_path)

## Load Data

You can customize the code here to read in your own data.

In [6]:
# Load the raw reports.
# args['field'] is one of: 'PrimaryGleason', 'SecondaryGleason', 'MarginStatusNone', 'SeminalVesicleNone'
path = "../data/prostate.json"
data = readJson(path)

# Clean the report text (the 'dev_test' split is cleaned separately), then fix the labels.
data = cleanSplit(data, stripChars)
data['dev_test'] = cleanReports(data['dev_test'], stripChars)
data = fixLabel(data)


def _synoptic_documents(split_name):
    """Return the lower-cased, synoptic-extracted document of every patient in data[split_name]."""
    documents = []
    for patient in data[split_name]:
        documents.append(extract_synoptic(patient['document'].lower(), tokenizer))
    return documents


train_documents = _synoptic_documents('train')
val_documents = _synoptic_documents('val')
test_documents = _synoptic_documents('test')
print(len(train_documents), len(val_documents), len(test_documents))

Token indices sequence length is longer than the specified maximum sequence length for this model (1345 > 512). Running this sequence through the model will result in indexing errors


2066 517 324


In [7]:
#data['train'][0]
train_documents[0]

'synoptic comment for prostate tumors null - type of tumor : small acinar adenocarcinoma. - location of tumor : left posterior mid gland ( slides d17 and d18 ). - estimated volume of tumor : 0. 3 cm3. - gleason score : 4 + 3. - estimated volume > gleason pattern 3 : 70 %. - involvement of capsule : tumor invades capsule in left posterior mid section slide d18. - extraprostatic extension : none. - margin status for tumor : - negative. - margin status for benign prostate glands : - no benign glands present at inked excision margins. - high - grade prostatic intraepithelial neoplasia ( hgpin ) : present extensively. - tumor involvement of seminal vesicle : none. - perineural infiltration : present ( slides d17 and d18 ). - lymph node status : - negative ; total number of nodes examined : 15 ( parts b and c ). - ajcc / uicc stage : pt2an0. null null null null specimen ( s ) received a : anterior prostatic fat b : lymph node right pelvic c : lymph node left pelvic d : prostate and bilateral

In [7]:
# Create datasets: pull the configured label field for every patient of each split.
def _split_labels(split_name):
    """Return the args['field'] label of each patient in data[split_name], in order."""
    return [patient['labels'][args['field']] for patient in data[split_name]]


train_labels = _split_labels('train')
val_labels = _split_labels('val')
test_labels = _split_labels('test')

# Filter each split's documents and labels together through exclude_labels.
train_documents, train_labels = exclude_labels(train_documents, train_labels)
val_documents, val_labels = exclude_labels(val_documents, val_labels)
test_documents, test_labels = exclude_labels(test_documents, test_labels)

# The encoder is fitted on the training labels only.
le = preprocessing.LabelEncoder()
le.fit(train_labels)

# Raw label (as a string) -> encoded id.
le_dict = {
    str(raw): encoded
    for raw, encoded in zip(le.classes_, le.transform(le.classes_))
}

# Labels that only occur in val/test get the next free id.
for label in val_labels + test_labels:
    le_dict.setdefault(str(label), len(le_dict))

# Encoded id -> raw label string.
inv_le_dict = {encoded: raw for raw, encoded in le_dict.items()}

In [8]:
# Pool the three splits into one corpus; documents and labels are concatenated in the same
# split order (train, val, test).
documents_full = train_documents + val_documents + test_documents
labels_full = train_labels + val_labels + test_labels

## Load Trained Models

In [9]:
# Load the fine-tuned classifier for the configured model type.
model_path = f"{base_dir}/PG_best_ckpts/{args['model_type']}" #{args['task']}/{args['model_type']}_{args['field']}"
# Despite the name, this is the path handed to from_pretrained below.
checkpoint_file = f"{model_path}/save_output"
# NOTE(review): config_file is not referenced by any cell visible in this notebook.
config_file = f"{model_path}/save_output/config.json"

# num_labels is the size of le_dict, which also counts labels that appear only in val/test.
model = BertForSequenceClassification.from_pretrained(checkpoint_file, num_labels=len(le_dict), output_hidden_states=True)

# Switch to evaluation mode, then move the model to the configured device.
model = model.eval()
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [26]:
print(next(model.parameters()).dtype)

torch.float32


In [27]:
#model
print(model.bert.encoder.layer[0])
print(type(model))

BertLayer(
  (attention): BertAttention(
    (self): BertSdpaSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
<class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'

In [28]:
!pip install torchsummary

python(11648) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [None]:
from torchsummary import summary
# NOTE(review): torchsummary's summary() appears to require an input_size argument, so this
# call likely raises TypeError as written (the cell has no execution count) — TODO confirm,
# then fix or remove.
summary(model, )

## Head to head direct influence

In [10]:
# read in the pre-calculated mean head response
# NOTE(review): the recorded run of this cell ended in FileNotFoundError, and the patching
# loop further down passes mean_acts=None, so `back` is not consumed by any visible cell.
path = f"{base_dir}/output/"

# pickle.load can execute arbitrary code: only load files produced by this project.
with open(os.path.join(path, f"{args['model_type']}_mean_acts_random_500.pkl"), 'rb') as handle:
    back = pickle.load(handle)


FileNotFoundError: [Errno 2] No such file or directory: '/home/shawnghu/ml/CD_Circuit/output/bert_mean_acts_random_500.pkl'

### Function description:

prop_classifier_model_hh_batched(encoding, model, source_list, target_nodes):

- encoding - Encoding given by tokenizer
- model - BERT model
- source_list - List of lists where each list consists of tuples (layer, position, head) indexing a particular attention head whose influence is to be calculated
- target_nodes - A single list of tuples (layer, position, head) containing attention heads on whom the influence is to be measured
- num_at_time (optional) - Number of source_lists to be processed in a batch
- n_layers - Number of layers
- att_list - Attention probabilities if precomputed

Output consists of two lists - out_decomps and target_decomps:
- out_decomps - Consists of a list of tuples (rel, irrel) reflecting the decomposition of the _output_
- target_decomps - A list containing 12 lists (one for each layer), where each inner list is of length len(source_list). For any layer l, each entry of target_decomps[l] is a tuple (rel, irrel) decomposition of the target nodes at that layer for the corresponding set of source nodes. rel and irrel have dimension (number of target nodes in layer l) x head_size, and the target nodes in a layer are ordered as they were provided.

In [17]:
def patch_hh_at_pos(encoding, model, target_nodes, pos=0, mean_acts=None, set_irrel_to_mean=False,
                    n_layers=12, n_heads=12):
    """Score every non-target attention head at one token position by its effect on the targets.

    Each candidate source head (layer, pos, head) is propagated on its own through
    prop_BERT_hh (via batch_run), and the resulting (rel, irrel) decomposition of the
    target heads is reduced to a single contribution score.

    Args:
        encoding: Tokenizer encoding of one document, forwarded to prop_BERT_hh.
        model: Model forwarded to prop_BERT_hh.
        target_nodes: Heads whose decomposition is measured. Any candidate that also
            appears in this list is excluded from the sources.
        pos: Token position of the candidate source heads.
        mean_acts: Forwarded unchanged to prop_BERT_hh.
        set_irrel_to_mean: Forwarded unchanged to prop_BERT_hh.
        n_layers: Number of encoder layers to enumerate (default 12, the previous fixed value).
        n_heads: Number of attention heads per layer (default 12, the previous fixed value).

    Returns:
        (source_node_list, h_ctbn_list): the candidate (layer, pos, head) tuples and, aligned
        by index, each candidate's score, i.e. the sum over layers that contain target
        nodes of 100 * mean|rel| / |mean|rel| + mean|irrel||.
    """
    # Candidate sources: every head of every layer at `pos` that is not itself a target.
    source_node_list = [
        node
        for node in itertools.product(range(n_layers), [pos], range(n_heads))
        if node not in target_nodes
    ]

    # `device` is the notebook-level setting; everything except the source node is fixed here.
    prop_fn = functools.partial(prop_BERT_hh, encoding, model, target_nodes=target_nodes, device=device,
                                mean_acts=mean_acts, set_irrel_to_mean=set_irrel_to_mean)
    _, target_decomps = batch_run(prop_fn, source_node_list)

    h_ctbn_list = []
    for i in range(len(source_node_list)):
        ctbn = 0
        for layer in range(n_layers):
            # Layers without any target node yield an empty (0 x head_size) decomposition.
            if target_decomps[layer][i][0].shape[0] != 0:
                rel_part = np.mean(abs(target_decomps[layer][i][0]))
                irrel_part = np.mean(abs(target_decomps[layer][i][1]))
                ctbn += rel_part / abs(rel_part + irrel_part) * 100
        h_ctbn_list.append(ctbn)

    return source_node_list, h_ctbn_list

In [12]:
# perform on one doc as an example
# The first pooled document and its label; `encoding` is what the patching loop below consumes.
text = documents_full[0]
label = labels_full[0]
encoding = get_encoding(text, tokenizer, device)

In [24]:
# perform one iteration measuring effect of the source nodes to target nodes as an example
# note that target nodes get updated in each iteration

# Target heads as (layer, position, head) tuples.
target_nodes = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]

# One entry per scanned position: the candidate source heads and their aligned scores.
all_source_hs = []
all_htbn = []
# Only positions 0 and 1 are scanned in this example.
for pos in tqdm.tqdm(range(2)):
    # Gradients are already disabled globally in an earlier cell; no_grad is kept as a local guard.
    with torch.no_grad():
        source_list, h_ctbn_list = patch_hh_at_pos(encoding, model, target_nodes, pos=pos, mean_acts=None, set_irrel_to_mean=False)
    # NOTE(review): `device` is 'cpu' in the config cell, so this presumably has no effect — confirm.
    torch.cuda.empty_cache()
    all_source_hs.append(source_list)
    all_htbn.append(h_ctbn_list)

  0%|          | 0/2 [00:00<?, ?it/s]

(0, 0, 0)


In [15]:
# Concatenate the per-position result lists; index i of both flat lists refers to the same head.
flat_ctbn = list(itertools.chain.from_iterable(all_htbn))
flat_source_h = list(itertools.chain.from_iterable(all_source_hs))


In [108]:
# Indices of the six highest contribution scores, in ascending order of score.
top_idx = sorted(range(len(flat_ctbn)), key=lambda i: flat_ctbn[i])[-6:]

In [109]:
# Show each selected source head next to its contribution score (lowest of the six first).
for i in top_idx:
    print(flat_source_h[i], flat_ctbn[i])

[(6, 82, 4)] 36.659424751996994
[(5, 82, 4)] 42.90880411863327
[(3, 82, 0)] 51.443591713905334
[(4, 82, 0)] 87.46089041233063
[(5, 82, 0)] 103.08798849582672
[(6, 82, 0)] 132.23715126514435


In [112]:
# Write the flattened source-head list to disk, then load the same file back into `back`.
path = f"{base_dir}/output/{args['task']}/{args['model_type']}_{args['field']}/h3"
os.makedirs(path, exist_ok=True)
heads_pkl = os.path.join(path, "flat_source_h.pkl")

with open(heads_pkl, 'wb') as handle:
    pickle.dump(flat_source_h, handle)

with open(heads_pkl, 'rb') as handle:
    back = pickle.load(handle)

## Examine the attended words by the identified heads

In [24]:
def collect_attended_tokens_hh(positives_heads, device, tokenizer, N=100, Z_thres=2, percentile=75, use_perc=False):
    """Accumulate the word-level attention that the given heads place on strongly attended words.

    For each randomly sampled document of the notebook-level `documents_full`, the attention
    rows selected by `positives_heads` are averaged, converted to word level with
    compute_word_intervals / combine_token_attn, and every word above the cutoff has its
    word-level attention added to a running total.

    Args:
        positives_heads: Sequence of (layer, position, head) tuples; each one selects the
            attention row of `head` in `layer` at query `position`.
        device: Forwarded to get_encoding and prop_BERT_hh.
        tokenizer: Forwarded to get_encoding and compute_word_intervals.
        N: Number of documents to sample; capped at the number of available documents.
        Z_thres: Z-score a word must exceed to be selected (used when use_perc is False).
        percentile: Percentile a word must exceed to be selected (used when use_perc is True).
        use_perc: Choose the percentile cutoff instead of the z-score cutoff.

    Returns:
        collections.defaultdict mapping each selected word to its summed word-level attention
        over all sampled documents (empty when `positives_heads` is empty).
    """
    collect = collections.defaultdict(int)

    n_selected_heads = len(positives_heads)
    if n_selected_heads == 0:
        # No rows to average: nothing can pass the cutoff, so skip the model runs entirely.
        return collect

    # random.sample raises when asked for more items than exist, so cap the request.
    n_docs = min(N, len(documents_full))
    index_lst = random.sample(range(0, len(documents_full)), n_docs)
    docs = [documents_full[i] for i in index_lst]

    for doc in docs:
        encoding = get_encoding(doc, tokenizer, device)

        # Empty source/target lists: the call is only used to obtain the attention probabilities.
        _, _, raw_att_probs_lst = prop_BERT_hh(encoding, model, [[]], [], device=device, output_att_prob=True)
        raw_att_probs = torch.stack(raw_att_probs_lst).cpu().numpy()

        # assumes attention rows have length 512 (padded sequence length) — TODO confirm
        avg_att_m = np.zeros((512))
        for level, pos, h in positives_heads:
            avg_att_m += raw_att_probs[level, h, pos, :]

        # Average over the heads passed in (previously divided by the unrelated global `positives`).
        avg_att_m /= n_selected_heads

        # convert to word level
        interval_dict, word_lst = compute_word_intervals(encoding, tokenizer)
        word_att_m = combine_token_attn(interval_dict, avg_att_m)

        if use_perc:
            perc_cutoff = np.percentile(word_att_m, percentile)
            positive_words = np.where(word_att_m > perc_cutoff)
        else:
            Z = (word_att_m - np.mean(word_att_m)) / np.std(word_att_m)
            positive_words = np.where(Z > Z_thres)

        for w_idx in positive_words[0]:
            # Weight each word by its attention rather than counting occurrences.
            collect[word_lst[w_idx]] += word_att_m[w_idx]

    return collect

In [65]:
def collect_attended_tokens_hh_rm_pos(positives_heads, device, tokenizer, N=100, Z_thres=2, percentile=75, use_perc=False):
    """Position-agnostic variant of the attended-word collection for a set of heads.

    The position stored in each head tuple is ignored. Instead, for every (layer, head) the
    attention row that contains that head's single largest attention weight is used. The
    selected rows are averaged, converted to word level with compute_word_intervals /
    combine_token_attn, and every word above the cutoff has its word-level attention added
    to a running total over the sampled documents of the notebook-level `documents_full`.

    Args:
        positives_heads: Sequence of (layer, position, head) tuples; `position` is unused.
        device: Forwarded to get_encoding and prop_BERT_hh.
        tokenizer: Forwarded to get_encoding and compute_word_intervals.
        N: Number of documents to sample; capped at the number of available documents.
        Z_thres: Z-score a word must exceed to be selected (used when use_perc is False).
        percentile: Percentile a word must exceed to be selected (used when use_perc is True).
        use_perc: Choose the percentile cutoff instead of the z-score cutoff.

    Returns:
        collections.defaultdict mapping each selected word to its summed word-level attention
        over all sampled documents (empty when `positives_heads` is empty).
    """
    collect = collections.defaultdict(int)

    n_selected_heads = len(positives_heads)
    if n_selected_heads == 0:
        # No rows to average: nothing can pass the cutoff, so skip the model runs entirely.
        return collect

    # random.sample raises when asked for more items than exist, so cap the request.
    n_docs = min(N, len(documents_full))
    index_lst = random.sample(range(0, len(documents_full)), n_docs)
    docs = [documents_full[i] for i in index_lst]

    for doc in docs:
        encoding = get_encoding(doc, tokenizer, device)

        # Empty source/target lists: the call is only used to obtain the attention probabilities.
        _, _, raw_att_probs_lst = prop_BERT_hh(encoding, model, [[]], [], device=device, output_att_prob=True)
        raw_att_probs = torch.stack(raw_att_probs_lst).cpu().numpy()

        # assumes attention rows have length 512 (padded sequence length) — TODO confirm
        avg_att_m = np.zeros((512))
        for level, _, h in positives_heads:
            att_m = raw_att_probs[level, h, :, :]
            # Row index of the global maximum of this head's attention matrix.
            max_row = np.unravel_index(np.argmax(att_m, axis=None), att_m.shape)[0]
            avg_att_m += att_m[max_row, :]

        # Average over the heads passed in (previously divided by the unrelated global `positives`).
        avg_att_m /= n_selected_heads

        # convert to word level
        interval_dict, word_lst = compute_word_intervals(encoding, tokenizer)
        word_att_m = combine_token_attn(interval_dict, avg_att_m)

        if use_perc:
            perc_cutoff = np.percentile(word_att_m, percentile)
            positive_words = np.where(word_att_m > perc_cutoff)
        else:
            Z = (word_att_m - np.mean(word_att_m)) / np.std(word_att_m)
            positive_words = np.where(Z > Z_thres)

        for w_idx in positive_words[0]:
            # Weight each word by its attention rather than counting occurrences.
            collect[word_lst[w_idx]] += word_att_m[w_idx]

    return collect

In [62]:
# Hand-listed 2-tuples, presumably (layer, head) pairs grouped by layer from 11 down to 0 —
# TODO confirm. NOTE(review): not referenced by any other cell visible in this notebook.
negatives = [(11, 2), (11, 5), (11, 6), (11, 9), (11, 10), (11, 11),
             (10, 3), (10, 4), (10, 5), (10, 6), (10, 9), (10, 10), (10, 11),
             (9, 1), (9, 2), (9, 3), (9, 4), (9, 5), (9, 6), (9, 8), (9, 9), (9, 10), (9, 11),
             (8, 1), (8, 2), (8, 3), (8, 4), (8, 5), (8, 6), (8, 7), (8, 8), (8, 9), (8, 10), (8, 11),
             (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 7), (7, 8), (7, 9), (7, 10),
             (6, 1), (6, 2), (6, 3), (6, 5), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11),
             (5, 1), (5, 2), (5, 3), (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), (5, 10), (5, 11),
             (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (4, 11),
             (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (3, 11),
             (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11),
             # NOTE(review): (1, 0) sits where (1, 6) would fall in the ascending run — confirm it is intended.
             (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 0), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11),
             (0, 0), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 8), (0, 10), (0, 11),
            ]

In [88]:
# the identified attn heads using CD-T
pos_specific_hs = [
            [i for i in range(12)],
            [i for i in range(512)],
            [i for i in range(12)]
        ]
all_heads = list(itertools.product(*pos_specific_hs))
random_heads = random.sample(all_heads, 6)
positives = random_heads

In [107]:
# Fixed list of (layer, position, head) nodes; rebinding `positives` here replaces whatever
# an earlier cell assigned to it.
positives = [(1, 169, 2), (2, 169, 2), (2, 169, 3), (4, 169, 8), (1, 411, 3)]

In [108]:
# Position-agnostic collection over 200 sampled documents with the percentile cutoff,
# returned as (word, summed attention) pairs ordered from highest to lowest.
_attention_by_word = collect_attended_tokens_hh_rm_pos(positives, device, tokenizer, N=200, use_perc=True)
positive_attended_token_freq = sorted(
    _attention_by_word.items(),
    key=lambda word_and_weight: word_and_weight[1],
    reverse=True,
)

In [61]:
# Earlier head sets, kept for reference; assign one of them to `positives` to re-run with it.
#h1 = [(10, 82, 0), (10, 61, 8), (10, 82, 7), (10, 176, 2), (10, 467, 1), (10, 91, 7)]
#h2 = [(7, 82, 11), (7, 82, 0), (7, 82, 6), (9, 82, 0), (9, 91, 7), (8, 82, 0)]
#h3 = [(6, 82, 4), (5, 82, 4), (3, 82, 0), (4, 82, 0), (5, 82, 0), (6, 82, 0)]
#positives = [(0, 82, 9), (0, 82, 1), (0, 82, 7), (1, 82, 6), (0, 82, 6), (2, 82, 0)]
# Position-specific collection over 500 sampled documents with the percentile cutoff.
positive_attended_token_freq = collect_attended_tokens_hh(positives, device, tokenizer, N=500, use_perc=True)
# The name is rebound from the word -> attention mapping to a list of (word, attention)
# pairs sorted from highest to lowest.
positive_attended_token_freq = sorted(positive_attended_token_freq.items(), key=lambda k_v: k_v[1], reverse=True)

In [109]:
# NOTE(review): json is imported here rather than in the import cell at the top.
import json
# Relative path: the file is written to the current working directory.
with open('pp15_h2.json', 'w') as fp:
    json.dump(positive_attended_token_freq, fp)

## Tests

In [42]:
# Same example document as in the head-to-head section: first pooled document and its label.
text = documents_full[0]
label = labels_full[0]
encoding = get_encoding(text, tokenizer, device)

In [96]:
"""
source_list_30 = [#list(itertools.product(range(12), range(512), range(12))), 
                  # list(itertools.product(range(12), range(70, 85), range(12))), 
                  # [(11, 0, i) for i in range(12)]
                  [(0, 0, 0)]] * 30
source_list_60 = [#list(itertools.product(range(12), range(512), range(12))), 
                  # list(itertools.product(range(12), range(70, 85), range(12))), 
                  # [(11, 0, i) for i in range(12)]
                  [(0, 0, 0)], []] * 30
"""
# NOTE(review): these targets are 2-tuples, while the head-to-head section uses
# (layer, position, head) 3-tuples — confirm which form prop_BERT_hh expects here.
target_nodes = [(11, 8), (11, 0), (11, 1), (11, 4), (11, 3), (11, 7)]
# Two source sets, each holding a single node.
source_list = [[(5, 7, 0)], [(5, 5, 0)]]

In [None]:
# No `device` is passed here, so prop_BERT_hh falls back to its own default (the recorded
# outputs below show tensors on cuda:0). The third return value is discarded.
out_decomps, target_decomps, _ = prop_BERT_hh(encoding, model, source_list, target_nodes)

In [103]:
out_decomps

[(tensor([-0.0171,  0.0302, -0.0149], device='cuda:0'),
  tensor([-3.2685,  6.1339, -3.2953], device='cuda:0')),
 (tensor([-0.0179,  0.0306, -0.0147], device='cuda:0'),
  tensor([-3.2677,  6.1335, -3.2955], device='cuda:0'))]

In [104]:
target_decomps[11][0][0].shape

torch.Size([2, 64])

In [110]:
target_decomps[0][0][0].shape

torch.Size([0, 64])