In [2]:
%load_ext autoreload
%autoreload 2

In [31]:
# no need to use these
#import sys
#sys.path.append("/accounts/campus/aliyahhsu/.local/bin") #/accounts/campus/aliyahhsu/.local/bin

#import sys
#print(sys.path)

In [None]:
# prepare the env
#!python3.12 -m pip install transformer_lens
#!python3.12 -m pip install --upgrade setuptools
#!python3.12 -m pip install --upgrade packaging
#!python3.12 -m pip install wandb==0.17.2

In [3]:
import argparse
import numpy as np
import os
import sys
import pandas as pd
import scipy as sp
import torch
import torch.nn.functional as F
import warnings
import random
import collections

# CD-T Imports
import math
import tqdm
import itertools

from torch import nn

warnings.filterwarnings("ignore")

base_dir = os.path.split(os.getcwd())[0]
sys.path.append(base_dir)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from pyfunctions.general import extractListFromDic, readJson, combine_token_attn, compute_word_intervals
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.cdt_basic import *
from pyfunctions.cdt_source_to_target import *
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import BertTokenizer, BertForSequenceClassification

In [4]:
torch.autograd.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f857af2b6e0>

# Load data

## Model Args Setup

In [6]:
args = {
    'model_type': 'bert', # bert, medical_bert, pubmed_bert, biobert, clinical_biobert
    'field': 'PrimaryGleason'
}

device = 'cuda:0'

In [7]:
if args['model_type'] == 'bert':
    bert_path = 'bert-base-uncased'
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
elif args['model_type'] == 'medical_bert':
    bert_path = f"{base_dir}/models/pretrained/bert_pretrain_output_all_notes_150000/"
    tokenizer = BertTokenizer.from_pretrained(bert_path, local_files_only=True)
elif args['model_type'] == 'pubmed_bert':
    bert_path = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
elif args['model_type'] == 'pubmed_bert_full':
    bert_path = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
elif args['model_type'] == 'biobert':
    bert_path = "dmis-lab/biobert-v1.1"
    tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
elif args['model_type'] == 'clinical_biobert':
    bert_path = "emilyalsentzer/Bio_ClinicalBERT"
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Load data

In [None]:
# Read in data
#field = 'PrimaryGleason' # out of PrimaryGleason, SecondaryGleason', 'MarginStatusNone', 'SeminalVesicleNone'
path = os.path.join(base_dir, "data/prostate.json")
data = readJson(path)

# Clean reports
data = cleanSplit(data, stripChars)
data['dev_test'] = cleanReports(data['dev_test'], stripChars)
data = fixLabel(data)

train_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['train']]
val_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['val']]
test_documents = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in data['test']]
print(len(train_documents), len(val_documents),len(test_documents))

In [13]:
# Create datasets
train_labels = [patient['labels'][args['field']] for patient in data['train']]
val_labels = [patient['labels'][args['field']] for patient in data['val']]
test_labels = [patient['labels'][args['field']] for patient in data['test']]

train_documents, train_labels = exclude_labels(train_documents, train_labels)
val_documents, val_labels = exclude_labels(val_documents, val_labels)
test_documents, test_labels = exclude_labels(test_documents, test_labels)

le = preprocessing.LabelEncoder()
le.fit(train_labels)

# Map raw label to processed label
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))
le_dict = {str(key):le_dict[key] for key in le_dict}

for label in val_labels + test_labels:
    if str(label) not in le_dict:
        le_dict[str(label)] = len(le_dict)

# Map processed label back to raw label
inv_le_dict = {v: k for k, v in le_dict.items()}

In [14]:
documents_full = train_documents + val_documents + test_documents
labels_full = train_labels + val_labels + test_labels

## Load model

In [18]:
#load finetuned model
model_path = os.path.join(base_dir, "models/path/bert_PrimaryGleason") # /{args['model_type']}_{args['field']}"
checkpoint_file = os.path.join(model_path, "save_output")
# config_file = f"{model_path}/save_output/config.json"

model = BertForSequenceClassification.from_pretrained(checkpoint_file, num_labels=len(le_dict), output_hidden_states=True)

model = model.eval()
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

## Sanity checks

In [145]:
source_list = [[(0, 0, x) for x in range(12)], [(5, 78, 6)]] #(layer, pos, attn_head)
target_nodes = [(i+7, 6, 11) for i in range(3)]

text = documents_full[120]
encoding = get_encoding(text, tokenizer, device)
out_decomps, target_decomps, _ = prop_BERT_hh(encoding, model, source_list, target_nodes, device, mean_acts=None, output_att_prob=False, set_irrel_to_mean=False)

In [142]:
out_decomps

[(array([[-0.2216931 ,  0.37654954, -0.23705788]], dtype=float32),
  array([[-2.815079 ,  5.86899  , -3.3938763]], dtype=float32)),
 (array([[-0.02794807,  0.03191791, -0.00936984]], dtype=float32),
  array([[-3.0092185,  6.213666 , -3.6212041]], dtype=float32))]

In [146]:
out_decomps

[(array([[-0.36410302,  0.5577252 , -0.29722145]], dtype=float32),
  array([[-2.672668 ,  5.687814 , -3.3337135]], dtype=float32)),
 (array([[-0.03016923,  0.04365933, -0.01757346]], dtype=float32),
  array([[-3.0069966,  6.201925 , -3.6130009]], dtype=float32))]

In [147]:
logits = model(encoding['input_ids']).logits
logits

tensor([[-3.8459,  5.7531, -2.3518]], device='cuda:0')

In [151]:
# make sure rel + irrel = total logits

for i in random.sample(range(len(documents_full)), 20):
    text = documents_full[i]
    encoding = get_encoding(text, tokenizer, device)
    logits = model(encoding['input_ids']).logits
    
    out_decomps, target_decomps, _ = prop_BERT_hh(encoding, model, source_list, target_nodes, device, mean_acts=None, output_att_prob=False, set_irrel_to_mean=False)
    
    for x in out_decomps:
        #print(np.abs((x[0] + x[1]) - logits.detach().cpu().numpy())[0])
        try:
            assert(np.mean(np.abs((x[0] + x[1]) - logits.detach().cpu().numpy())[0]) <= 1e-04)
        except:
            print(np.abs((x[0] + x[1]) - logits.detach().cpu().numpy())[0])
            print(i)

[0.0165     0.17245507 0.24120665]
736
[0.0164957  0.17246962 0.24122047]
736
[0.0141449  0.14537406 0.16366768]
1286
[0.01414299 0.14536524 0.163661  ]
1286
[0.02013254 0.54891944 0.5833664 ]
2384
[0.0201335 0.5489185 0.583364 ]
2384
[0.02800894 0.05061698 0.11377645]
1562
[0.02800798 0.0506196  0.11377835]
1562
[0.46180725 0.5767827  1.0922997 ]
2472
[0.4618168  0.57678556 1.0923133 ]
2472
[0.30555463 0.33365822 0.6787486 ]
368
[0.30554295 0.33365917 0.67874384]
368
[0.03361273 0.18153238 0.20721531]
2754
[0.03361177 0.18152499 0.20720553]
2754


In [121]:
target_decomps

[[(array([], shape=(0, 64), dtype=float32),
   array([], shape=(0, 64), dtype=float32)),
  (array([], shape=(0, 64), dtype=float32),
   array([], shape=(0, 64), dtype=float32)),
  (array([], shape=(0, 64), dtype=float32),
   array([], shape=(0, 64), dtype=float32)),
  (array([], shape=(0, 64), dtype=float32),
   array([], shape=(0, 64), dtype=float32)),
  (array([], shape=(0, 64), dtype=float32),
   array([], shape=(0, 64), dtype=float32)),
  (array([], shape=(0, 64), dtype=float32),
   array([], shape=(0, 64), dtype=float32)),
  (array([], shape=(0, 64), dtype=float32),
   array([], shape=(0, 64), dtype=float32)),
  (array([[ 0.02713919, -0.0622553 ,  0.00735879, -0.02987053, -0.08705581,
            0.11297631, -0.03065166,  0.02380989, -0.08406242, -0.01910041,
            0.06665847, -0.01237809,  0.11193677, -0.08449186, -0.0306458 ,
           -0.09281306, -0.0190668 ,  0.00445568, -0.0405965 , -0.1555212 ,
           -0.01881754,  0.06485477, -0.03708712,  0.05026134, -0.0062646

tensor([[ 6.1244, -1.8527, -3.9428]], device='cuda:0')