In [None]:
import argparse
import numpy as np
import os
import pandas as pd
import scipy as sp
import sys
import torch
import torch.nn.functional as F
import warnings
import random
import json
import pickle
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, paired_distances
from sklearn.metrics import accuracy_score, confusion_matrix
from scipy.stats import spearmanr
import seaborn as sns
import collections
from tqdm import tqdm

# NOTE(review): this silences *every* warning (including pandas/seaborn deprecation notices) for the whole kernel.
warnings.filterwarnings("ignore")

# The project root is the parent of the notebook's working directory; make it and the
# turing example/source trees importable for the project-local imports below.
base_dir = os.path.dirname(os.getcwd())
for extra_path in (
    base_dir,
    f"{base_dir}/turing/examples-raw/gluesst_finetune/",
    f"{base_dir}/turing/src/",
):
    sys.path.append(extra_path)

from argparse import Namespace
from methods.bag_of_ngrams.processing import cleanReports, cleanSplit, stripChars
from methods.interpretations.utils import compute_input_type_attention
from methods.interpretations.integrated_gradients.utils import forward_with_softmax, summarize_attributions
from pyfunctions.general import extractListFromDic, readJson
from pyfunctions.pathology import extract_synoptic, fixLabelProstateGleason, fixProstateLabels, fixLabel, exclude_labels
from pyfunctions.feature_anaysis_utils import center, calculate_geometry, compute_RSA, get_projection, low_rank_approx, rank_1_approx, get_acc
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import AutoTokenizer, AutoModel
from transformers import BertTokenizer, BertForSequenceClassification
from turing.pathology.path_utils import evaluate, extract_features, load_tnlr_base, load_tnlr_tokenizer, path_dataset

# Fine-tuned feature space is highly sparsified: PC explained-variance ratios

In [None]:
models = ['bert', 'tnlr', 'pubmed_bert', 'biobert', 'clinical_biobert']
field = 'PrimaryGleason'  # alternatives: 'SecondaryGleason', 'MarginStatusNone', 'SeminalVesicleNone'

# Load one cached feature matrix per fine-tuned model for this field
# (per the filename, 1000 examples of layer-12 outputs from the best fine-tuned checkpoint).
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this project.
features = []
for model_name in models:
    pkl_path = os.path.join(f"{base_dir}/output/rsa/{model_name}/{field}", "1000_cls_logits_l12_ft_best.pkl")
    with open(pkl_path, 'rb') as fh:
        features.append(pickle.load(fh))

center_f = center(features)

In [None]:
# Cumulative explained-variance ratio of the leading principal components of each model's centered features.
n_pcs = 5
pca_seed = 0  # PCA's 'auto' solver may choose randomized SVD; a fixed seed keeps the ratios reproducible

pc_exp_ratios = {m: [] for m in models}

for model_name, cf in zip(models, center_f):
    # Only the fitted explained_variance_ratio_ is needed, so fit() suffices (no projection required).
    pca = PCA(n_components=n_pcs, random_state=pca_seed).fit(cf)
    pc_exp_ratios[model_name].extend(pca.explained_variance_ratio_.cumsum())

In [None]:
# One tidy frame per model: cumulative explained-variance ratio of its leading PCs.
# Built in a single DataFrame call per model -- DataFrame.append was removed in pandas 2.0
# and growing a frame row by row is quadratic anyway.
n_plot_pcs = 5
dt = {}
for m in models:
    cum_ratios = list(pc_exp_ratios[m][:n_plot_pcs])
    dt[m] = pd.DataFrame({
        'PC Index': range(len(cum_ratios)),
        'Explained Variance Ratio - Cumulative Sum': cum_ratios,
    })

In [None]:
# Attach a human-readable model name to each per-model frame, then stack them (in this order) for seaborn.
display_names = {
    'bert': 'BERT',
    'tnlr': 'TNLR',
    'pubmed_bert': 'PubMed BERT',
    'biobert': 'BioBERT',
    'clinical_biobert': 'Clinical BioBERT',
}
for key, pretty_name in display_names.items():
    dt[key]['model'] = pretty_name

data = pd.concat([dt[key] for key in display_names])

In [None]:
plt.style.use(os.path.join(base_dir, "theme_bw.mplstyle"))

fig, ax = plt.subplots(figsize=(10, 7.5))

# Remove the top/right frame lines and keep ticks only on the bottom/left axes (less chartjunk).
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()

sns.lineplot(x='PC Index', y='Explained Variance Ratio - Cumulative Sum', data=data,
             hue='model', marker='o', ci=None, ax=ax)

ax.set_xticks([0, 1, 2, 3, 4])
ax.set_xlabel('PC Index')
ax.set_ylabel('Cumulative Sum')
ax.set_ylim((0.6, 1))

# Human-readable field name for the title, e.g. 'PrimaryGleason' -> 'Primary Gleason'.
field_title = ''.join(f" {ch}" if ch.isupper() and pos else ch for pos, ch in enumerate(field))
ax.set_title(f"PC Explained Variance Ratio - {field_title}")

# The 'model' column already holds display names, so reuse seaborn's hue labels rather than
# overriding them with a hand-typed list that can drift out of order.
ax.legend(title='Model')

# To save, call fig.savefig(...) before plt.show(); afterwards the figure may already be closed.
plt.show()

# Probing (F1): classify after removing principal components

In [None]:
# Reload the cached final-layer features, this time for the field probed below.
models = ['bert', 'tnlr', 'pubmed_bert', 'biobert', 'clinical_biobert']
field = 'SeminalVesicleNone'  # alternatives: 'PrimaryGleason', 'SecondaryGleason', 'MarginStatusNone'

# NOTE: pickle.load can execute arbitrary code -- only load files produced by this project.
features = []
for model_name in models:
    feature_dir = f"{base_dir}/output/rsa/{model_name}/{field}"
    with open(os.path.join(feature_dir, "1000_cls_logits_l12_ft_best.pkl"), 'rb') as fh:
        features.append(pickle.load(fh))

center_f = center(features)

In [None]:
# model params (mirrors the fine-tuning script's arguments)
args = {
    'model_type': 'clinical_biobert', # tnlr, bert, pubmed_bert, biobert, clinical_biobert
    'task_name': 'sst-2',
    'do_train': False,
    'do_eval': True,
    'evaluate_during_training': True,
    'max_seq_length': 512,
    'do_lower_case': True,
    'per_gpu_train_batch_size': 8,
    'per_gpu_eval_batch_size': 8,
    'gradient_accumulation_steps': 1,
    'learning_rate': 7e-6,
    'weight_decay': 0.0,
    'adam_epsilon': 1e-8,
    'max_grad_norm': 1,
    # Previously listed twice (3.0 then 25); the later value always won, so only 25 is kept.
    'num_train_epochs': 25,
    'max_steps': -1,
    'warmup_ratio': 0.2,
    'logging_steps': 50,
    'eval_all_checkpoints': True,
    'no_cuda': False,
    'seed': 42,
    'metric_for_choose_best_checkpoint': None,
    'fp16': False,
    'fp16_opt_level': 'O1',
    'local_rank': -1,
    'n_gpu': 1,
    'device': 'cuda',
    'run': 0
}

kwargs = Namespace(**args)

# Hugging Face hub checkpoints for the BERT-family models; TNLR uses a local vocab file instead.
hf_checkpoints = {
    'bert': 'bert-base-uncased',
    'pubmed_bert': "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    'biobert': "dmis-lab/biobert-v1.1",
    'clinical_biobert': "emilyalsentzer/Bio_ClinicalBERT",
}

model_type = args['model_type']
if model_type == 'tnlr':
    vocab_file = f'{base_dir}/turing/src/tnlr/tokenizer/tnlr-uncased-vocab.txt'
    tokenizer = load_tnlr_tokenizer(vocab_file)
elif model_type == 'bert':
    bert_path = hf_checkpoints[model_type]
    tokenizer = BertTokenizer.from_pretrained(bert_path)
elif model_type in hf_checkpoints:
    bert_path = hf_checkpoints[model_type]
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
else:
    # Fail here rather than later with an undefined `tokenizer`.
    raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(hf_checkpoints) + ['tnlr']}")

In [None]:
# Read the prostate report data.
path = f"../data/prostate.json"
data = readJson(path)

# Clean report text in every split (the 'dev_test' split is cleaned separately), then fix labels.
data = cleanSplit(data, stripChars)
data['dev_test'] = cleanReports(data['dev_test'], stripChars)
data = fixLabel(data)


def _docs_and_labels(split_name):
    """Return (documents passed through extract_synoptic, labels[field]) for one split of `data`."""
    patients = data[split_name]
    docs = [extract_synoptic(patient['document'].lower(), tokenizer) for patient in patients]
    labs = [patient['labels'][field] for patient in patients]
    return docs, labs


train_documents, train_labels = _docs_and_labels('train')
val_documents, val_labels = _docs_and_labels('val')
test_documents, test_labels = _docs_and_labels('test')

# Gleason fields only: filter via exclude_labels (per the original note, this drops '2' and 'null').
if field in ('PrimaryGleason', 'SecondaryGleason'):
    train_documents, train_labels = exclude_labels(train_documents, train_labels)
    val_documents, val_labels = exclude_labels(val_documents, val_labels)
    test_documents, test_labels = exclude_labels(test_documents, test_labels)

le = preprocessing.LabelEncoder().fit(train_labels)

# String class name -> encoded id for training classes; classes seen only in val/test get fresh ids.
le_dict = {str(cls): code for cls, code in zip(le.classes_, le.transform(le.classes_))}
for unseen in val_labels + test_labels:
    unseen_key = str(unseen)
    if unseen_key not in le_dict:
        le_dict[unseen_key] = len(le_dict)

# Encoded id -> raw (string) label.
inv_le_dict = {code: name for name, code in le_dict.items()}

documents_full = train_documents + val_documents + test_documents
labels_full = train_labels + val_labels + test_labels

# Re-split the pooled data with the original split proportions (computed before the labels are rebound).
p_test = len(test_labels) / len(labels_full)
p_val = len(val_labels) / (len(train_labels) + len(val_labels))

train_docs, test_docs, train_labels, test_labels = train_test_split(
    documents_full, labels_full, test_size=p_test, random_state=args['run'])

train_docs, val_docs, train_labels, val_labels = train_test_split(
    train_docs, train_labels, test_size=p_val, random_state=args['run'])

In [None]:
# Load the fine-tuned classifier for the selected model type / field, with hidden states enabled.
# NOTE(review): the directory suffix used to be a literal `{0}`; it now follows args['run'] (currently 0,
# so the path is unchanged), assuming fine-tuning outputs are keyed by run index -- confirm.
model_path = f"{base_dir}/output/fine_tuning/{args['model_type']}_{args['run']}/{field}"
checkpoint_file = f"{model_path}/save_output"
config_file = f"{checkpoint_file}/config.json"

if args['model_type'] == 'tnlr':
    model = load_tnlr_base(checkpoint_file, config_file, model_type='tnlrv3_classification', num_labels=len(le_dict))
    model.config.update({'output_hidden_states': True})
else:
    model = BertForSequenceClassification.from_pretrained(checkpoint_file, num_labels=len(le_dict), output_hidden_states=True)

In [None]:
# Per-model F1 curves from the PC-removal sweep. Results already gathered in this kernel for other
# models are kept, so the sweep can be run once per model_type without commenting this cell out.
if 'pc_rm_result' not in globals():
    pc_rm_result = {}
for m in models:
    pc_rm_result.setdefault(m, {'f1': []})

In [None]:
# F1 vs. PC removal: for k = d-1 down to 0, rebuild the centered features with low_rank_approx(k, cf),
# push them through the fine-tuned pooler + classifier head, and record F1.
# NOTE(review): assumes the cached features correspond row-for-row to train_labels[:1000] -- confirm.
true_labels = train_labels[:1000]

device = args['device']
gpu_index = args.get('gpu_index', 1)  # the sweep has always pinned GPU 1; override via args['gpu_index']
inv_model_dict = {m: i for i, m in enumerate(models)}

cf = center_f[inv_model_dict[args['model_type']]]
d = cf.shape[1]

# Start this model's curve from scratch so re-running the cell does not append a duplicate sweep.
pc_rm_result[args['model_type']]['f1'] = []

with torch.cuda.device(gpu_index):
    # Loop-invariant: switch to eval mode and move the model to the device once, not on every k.
    model = model.eval()
    model.to(device)
    model_dtype = next(model.parameters()).dtype

    for k in tqdm(range(d - 1, -1, -1)):
        W_k = low_rank_approx(k, cf)
        # Add a length-1 sequence axis (presumably the pooler reads the first token) and match the model's dtype.
        W_k = torch.from_numpy(W_k[:, np.newaxis, :]).to(device=device, dtype=model_dtype)

        with torch.no_grad():
            a1 = model.bert.pooler(W_k)
            a2 = model.dropout(a1)
            logits = model.classifier(a2)

        f1, _, _ = get_acc(logits, true_labels)
        pc_rm_result[args['model_type']]['f1'].append(f1)

In [None]:
# after gathering results for all models:
# One tidy frame per model with its F1 curve. Position l in the curve is the l-th step of the sweep
# (step 0 was computed with k = d-1). Frames are built in one call -- DataFrame.append was removed in pandas 2.0.
incomplete = [m for m in models if len(pc_rm_result[m]['f1']) < d]
if incomplete:
    raise RuntimeError(f"PC-removal sweep has not been run (fully) for: {incomplete}")

dt = {}
for m in models:
    f1_curve = list(pc_rm_result[m]['f1'][:d])
    dt[m] = pd.DataFrame({'PC Index': range(len(f1_curve)), 'F1': f1_curve})

In [None]:
# Attach a human-readable model name to each per-model F1 frame, then stack them (in this order) for seaborn.
display_names = {
    'bert': 'BERT',
    'tnlr': 'TNLR',
    'pubmed_bert': 'PubMed BERT',
    'biobert': 'BioBERT',
    'clinical_biobert': 'Clinical BioBERT',
}
for key, pretty_name in display_names.items():
    dt[key]['model'] = pretty_name

data = pd.concat([dt[key] for key in display_names])

In [None]:
plt.style.use(os.path.join(base_dir, "theme_bw.mplstyle"))

fig, ax = plt.subplots(figsize=(10, 7.5))

# Remove the top/right frame lines and keep ticks only on the bottom/left axes (less chartjunk).
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()

sns.lineplot(x='PC Index', y='F1', data=data, hue='model', ci=None, ax=ax)

ax.set_xlabel('First Bottom k PC(s) Used')
ax.set_ylabel('F1')

# Title follows the field actually probed (it was hard-coded to Primary Gleason),
# e.g. 'SeminalVesicleNone' -> 'Seminal Vesicle None'.
field_title = ''.join(f" {ch}" if ch.isupper() and pos else ch for pos, ch in enumerate(field))
ax.set_title(f"PC Probing - {field_title}")

# The 'model' column already holds display names, so reuse seaborn's hue labels.
ax.legend(title='Model')

# To save, call fig.savefig(...) before plt.show(); afterwards the figure may already be closed.
plt.show()

## Close-up of the last sweep positions (PC index ≥ 720)

In [None]:
# Per-model F1 frames restricted to the tail of the sweep, for the close-up plot.
# Built in one call per model -- DataFrame.append was removed in pandas 2.0.
closeup_start = 720  # first sweep position shown; keep in sync with the close-up plot's xlim

dt = {}
for m in models:
    f1_curve = pc_rm_result[m]['f1']
    positions = list(range(closeup_start, d))
    dt[m] = pd.DataFrame({'PC Index': positions, 'F1': [f1_curve[pos] for pos in positions]})

In [None]:
# Attach a human-readable model name to each close-up frame, then stack them (in this order) for seaborn.
display_names = {
    'bert': 'BERT',
    'tnlr': 'TNLR',
    'pubmed_bert': 'PubMed BERT',
    'biobert': 'BioBERT',
    'clinical_biobert': 'Clinical BioBERT',
}
for key, pretty_name in display_names.items():
    dt[key]['model'] = pretty_name

data = pd.concat([dt[key] for key in display_names])

In [None]:
plt.style.use(os.path.join(base_dir, "theme_bw.mplstyle"))

fig, ax = plt.subplots(figsize=(10, 7.5))

# Remove the top/right frame lines and keep ticks only on the bottom/left axes (less chartjunk).
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()

sns.lineplot(x='PC Index', y='F1', data=data, hue='model', marker='o', ci=None, ax=ax)

ax.set_xlabel('First Bottom k PC(s) Used')
ax.set_ylabel('F1')
ax.set_xlim((719, 770))  # matches the close-up frames, which start at sweep position 720

# Title follows the field actually probed (it was hard-coded to Primary Gleason).
field_title = ''.join(f" {ch}" if ch.isupper() and pos else ch for pos, ch in enumerate(field))
ax.set_title(f"PC Probing (Close-up) - {field_title}")

# The 'model' column already holds display names, so reuse seaborn's hue labels.
ax.legend(title='Model')

# To save, call fig.savefig(...) before plt.show(); afterwards the figure may already be closed.
plt.show()