In [None]:
# Set the variable below to pick the task: True -> 6-class run, False -> 2-class run.
multiclass_classification = True

# Settings shared by both tasks.
fold_index = 0
k_folds = 5

if multiclass_classification:
  num_classes = 6
  experiment_name = "adrenal_NLP_biobert_multiclass_new_"
else:
  num_classes = 2
  experiment_name = "adrenal_NLP_biobert_binary_reps_UMAP"

# The fold index is appended so that each fold gets its own weights file.
experiment_name = experiment_name + str(fold_index)
if not multiclass_classification:
  # Only the binary configuration echoes its experiment name.
  print(experiment_name)
saved_weights_path = 'source_folder/' + 'saved_weights_' + experiment_name + '.pt'
  

In [None]:
# Fetch the Transformer-Explainability repository into the current working directory.
!git clone https://github.com/hila-chefer/Transformer-Explainability.git

import os
# Make the cloned repository the working directory: the requirements file below is
# read relative to it, and every relative path used later in the notebook resolves here too.
os.chdir(f'./Transformer-Explainability')

# Install the repository's requirements plus captum and transformers (versions are not pinned).
!pip install -r requirements.txt
!pip install captum
!pip install transformers

In [None]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import pandas as pd
import numpy as np

from sklearn import preprocessing
from keras.preprocessing.sequence import pad_sequences

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold

In [None]:
# Binary task: read the CSV and keep only the text (renamed 'sentence') and the label.
if not multiclass_classification:
  df = pd.read_csv("binary_data.csv").assign(sentence=lambda frame: frame['text'])
  df = df[['sentence', 'label']]

  # Show which columns contain missing values, then drop rows without a sentence.
  print(df.isna().any())
  df = df.dropna(subset=['sentence'])
  print(df['label'].unique())

  # Number of rows that remain.
  print('Number of training sentences: {:,}\n'.format(len(df)))

  # Ten random rows as a sanity check.
  print(df.sample(10))


In [None]:
# Multiclass task: read the CSV and turn the label strings into integer class ids.
if multiclass_classification:
  df = pd.read_csv("multiclass_data.csv")
  print(df.head())

  train_df = df[['sample_text', 'sample_label']]
  print(len(train_df))

  label_column = train_df['sample_label']
  print(label_column.unique())
  print(label_column.value_counts())

  text_list = list(train_df['sample_text'])
  label_list = list(label_column)

  # LabelEncoder assigns one integer code per distinct label (integer ids, not one-hot
  # vectors). `le` stays in scope because later cells call le.inverse_transform.
  le = preprocessing.LabelEncoder()
  labels = le.fit(label_list).transform(label_column)

  # Same two-column layout ('sentence', 'label') as the binary branch produces.
  new_df = pd.DataFrame({'sentence': text_list, 'label': labels})
  df = new_df

In [None]:
# Pull the texts and their labels out of the frame as two parallel arrays.
sentences = df['sentence'].values
labels = df['label'].values

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Tokenizer of the same checkpoint name that the classification models below are loaded from.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

In [None]:
# Encode every sentence as a list of vocabulary ids, with the tokenizer's
# special tokens included (add_special_tokens=True).
input_ids = [tokenizer.encode(text, add_special_tokens=True) for text in sentences]

In [None]:
#pad tokens
# Bring every id sequence to exactly MAX_LEN entries: longer ones are cut at the end
# (truncating="post"), shorter ones are filled at the end with id 0 (padding="post", value=0).
MAX_LEN = 128
input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long", 
                          value=0, truncating="post", padding="post")


In [None]:
# One mask row per sequence: 1 where the id is positive, 0 where it is 0 (the padding value).
attention_masks = [[int(token_id > 0) for token_id in row] for row in input_ids]

## Training & Validation Split


In [None]:
# Hold out a stratified test set, then pick one train/validation fold (fold_index)
# out of a k-fold split of the remaining samples.
test_size_fraction = 0.10

# Split ids, masks and labels in ONE call so that row i of every output belongs to the
# same sample. Splitting the masks in a separate, non-stratified call selects different
# rows than the stratified split of the ids.
(trainval_inputs, test_inputs,
 trainval_masks, test_masks,
 trainval_labels, test_labels) = train_test_split(
    input_ids, attention_masks, labels,
    random_state=2018, test_size=test_size_fraction, stratify=labels)

skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=2018)
print(skf)

# Not used by this cell; retained in case other cells refer to them.
X = input_ids
y = labels

# The splitter is given the non-test pool, so the indices it yields are positions
# inside trainval_* and must not be applied to the full input_ids / labels arrays.
folds = {}
for i, (train_index, valid_index) in enumerate(skf.split(trainval_inputs, trainval_labels)):
    folds[i] = (train_index, valid_index)

train_index, valid_index = folds[fold_index]
print(fold_index)


def get_attn_masks(attention_masks, index_list):
  """Return the rows of `attention_masks` found at the positions in `index_list`, as a list."""
  return [attention_masks[idx] for idx in index_list]

train_inputs, train_masks, train_labels = trainval_inputs[train_index], get_attn_masks(trainval_masks, train_index), trainval_labels[train_index]
validation_inputs, validation_masks, validation_labels = trainval_inputs[valid_index], get_attn_masks(trainval_masks, valid_index), trainval_labels[valid_index]


In [None]:
# Convert every split to torch tensors, one split at a time (ids, masks, labels).
import torch

train_inputs, train_masks, train_labels = (
    torch.tensor(train_inputs), torch.tensor(train_masks), torch.tensor(train_labels))

validation_inputs, validation_masks, validation_labels = (
    torch.tensor(validation_inputs), torch.tensor(validation_masks), torch.tensor(validation_labels))

test_inputs, test_masks, test_labels = (
    torch.tensor(test_inputs), torch.tensor(test_masks), torch.tensor(test_labels))

In [None]:
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

train_batch_size = 16
test_batch_size = 16


def _make_loader(inputs, masks, labels, sampler_cls, batch_size):
  """Wrap three aligned tensors in a dataset and return (dataset, sampler, loader)."""
  dataset = TensorDataset(inputs, masks, labels)
  sampler = sampler_cls(dataset)
  return dataset, sampler, DataLoader(dataset, sampler=sampler, batch_size=batch_size)


# Training batches are drawn in random order.
train_data, train_sampler, train_dataloader = _make_loader(
    train_inputs, train_masks, train_labels, RandomSampler, train_batch_size)

# Validation and test batches keep the dataset order; validation reuses the training batch size.
validation_data, validation_sampler, validation_dataloader = _make_loader(
    validation_inputs, validation_masks, validation_labels, SequentialSampler, train_batch_size)

test_data, test_sampler, test_dataloader = _make_loader(
    test_inputs, test_masks, test_labels, SequentialSampler, test_batch_size)

In [None]:
from transformers import BertForSequenceClassification, AdamW, BertConfig

# Load the BioBERT checkpoint with a classification head of num_classes outputs.
# Attention weights and all hidden states are requested in the model outputs.
model = BertForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1", 
    num_labels = num_classes, 
    output_attentions = True, 
    output_hidden_states = True 
)

# Move the model to the device selected earlier, so the notebook also runs
# on machines without CUDA (a bare .cuda() raises there).
model.to(device)

In [None]:
# AdamW over all model parameters with learning rate 2e-5 and epsilon 1e-8.
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)


In [None]:
from transformers import get_linear_schedule_with_warmup

# Both tasks train for the same number of epochs.
epochs = 4

# The training loop steps the scheduler once per batch, so the schedule spans
# (batches per epoch) x (epochs) steps.
total_steps = epochs * len(train_dataloader)
print(total_steps)

# Linear schedule with zero warmup steps (default value in run_glue.py).
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

In [None]:
import numpy as np

def flat_accuracy(preds, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    preds:  2-D array of per-class scores, one row per example.
    labels: array of integer class ids; it is flattened before the comparison.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    true_classes = labels.flatten()
    n_correct = np.sum(predicted_classes == true_classes)
    return n_correct / len(true_classes)

In [None]:
import time
import datetime

def format_time(elapsed):
    """Format a duration in seconds as an h:mm:ss string, rounded to whole seconds."""
    whole_seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)


In [None]:
import random

# training code based on https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128


seed_val = 42

# Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators.
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

train_loss_values = []       # mean training loss of each epoch
valid_loss_values = []       # never filled: this loop computes no validation loss

valid_accuracy_values = []   # validation accuracy of each epoch
max_valid_accuracy = -1      # best validation accuracy so far; -1 so the first epoch always saves

# For each epoch: one pass over the training set, then one pass over the validation set.
for epoch_i in range(0, epochs):

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()
    total_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):

        # Progress line every 40 batches.
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # Same device handling as the validation pass below.
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        model.zero_grad()

        # Labels are passed in, so the first output is the loss.
        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        loss = outputs[0]
        total_loss += loss.item()

        loss.backward()

        # Clip the gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

    avg_train_loss = total_loss / len(train_dataloader)
    train_loss_values.append(avg_train_loss)

    print("")
    print("  Average training loss: {0:.5f}".format(avg_train_loss))
    print("  Training epoch took: {:}".format(format_time(time.time() - t0)))

    print("")
    print("Running Validation...")

    t0 = time.time()
    model.eval()

    eval_accuracy = 0
    nb_eval_steps, nb_eval_examples = 0, 0

    for batch in validation_dataloader:

        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            # No labels are passed, so the first output is the logits.
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask)
        logits = outputs[0].detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        # Weight each batch by its size: a plain mean of per-batch accuracies
        # over-weights a shorter final batch.
        batch_examples = len(label_ids)
        eval_accuracy += flat_accuracy(logits, label_ids) * batch_examples
        nb_eval_examples += batch_examples
        nb_eval_steps += 1

    valid_accuracy = eval_accuracy / nb_eval_examples
    valid_accuracy_values.append(valid_accuracy)

    # Report the final accuracy for this validation run.
    print("  Accuracy: {0:.5f}".format(valid_accuracy))
    print("  Validation took: {:}".format(format_time(time.time() - t0)))

    # Checkpoint only when this epoch beats the best accuracy seen so far, and record
    # the new best; without the update every epoch would overwrite the saved weights.
    if valid_accuracy > max_valid_accuracy:
        max_valid_accuracy = valid_accuracy
        torch.save(model.state_dict(), saved_weights_path)
print("")
print("Training complete!")

In [None]:
import matplotlib.pyplot as plt
% matplotlib inline

import seaborn as sns

sns.set(style='darkgrid')
sns.set(font_scale=1.5)
plt.rcParams["figure.figsize"] = (12,6)
plt.plot(train_loss_values, 'b-o')
plt.title(experiment_name + " - Training loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")

out_dir = "source_folder/plots/"
out_confus = out_dir + experiment_name + '_train_loss' + '.png'
plt.savefig(out_confus, dpi=300,facecolor='w')

plt.show()

In [None]:
import matplotlib.pyplot as plt
% matplotlib inline
import seaborn as sns

sns.set(style='darkgrid')
sns.set(font_scale=1.5)
plt.rcParams["figure.figsize"] = (12,6)
plt.plot(valid_accuracy_values, 'b-o')
plt.title(experiment_name + " - Validation Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")

out_confus = out_dir + experiment_name + '_valid_accuracy' + '.png'
plt.savefig(out_confus, dpi=300,facecolor='w')

plt.show()

In [None]:
from transformers import BertForSequenceClassification, AdamW, BertConfig

#load best model to predict
# Rebuild the architecture, then overwrite its weights with the checkpoint written during
# training. map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
test_model = BertForSequenceClassification.from_pretrained("dmis-lab/biobert-base-cased-v1.1", num_labels = num_classes, output_hidden_states=True)
test_model.load_state_dict(torch.load(saved_weights_path, map_location=device))
test_model.eval()
# Same device that the test loop moves its batches to.
test_model.to(device)

In [None]:
#dump biobert embeddings to pickle file for visualisation
# NOTE(review): this cell reads `new_test_y` and `new_preds`, which are assigned in the
# classification-report cell further down, and `reps`, which is not assigned anywhere in
# the visible notebook. Run the report cell first and confirm where `reps` comes from.
# `le` only exists when multiclass_classification is True.

# Decode each row of test ids back to text, skipping special tokens.
text_list = [tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(i, skip_special_tokens=True)) for i in test_inputs]
print(len(text_list))
print(len(new_test_y))

# One row per test sample: position, decoded text and the predicted class name.
hover_data = pd.DataFrame({'index':np.arange(len(new_test_y)),'text': text_list,'label': le.inverse_transform(new_preds)})
#saving the data in a pickle file to plot a umap
import pickle

s = {}

s['representations']=reps #np_encoded_text
s['labels']=new_preds
s['hover_data']=hover_data

# The context manager closes the file even if pickling raises.
with open('adrenal_NLP_biobert_multiclass_trained_predictions.pkl', 'wb') as outfile:
  pickle.dump(s, outfile)

In [None]:
#test_accuracy
# Run the reloaded checkpoint over the test set, keeping every batch's logits and labels
# for the reports in the following cells.
t0 = time.time()
eval_accuracy = 0
nb_eval_steps = 0
nb_eval_examples = 0
logits_stack = []
labels_stack = []
for step, batch in enumerate(test_dataloader):

  batch = tuple(t.to(device) for t in batch)
  b_input_ids, b_input_mask, b_labels = batch
  with torch.no_grad():
    # No labels are passed, so the first output is the logits.
    outputs = test_model(b_input_ids,
                      token_type_ids=None,
                      attention_mask=b_input_mask)
  logits = outputs[0].detach().cpu().numpy()
  label_ids = b_labels.to('cpu').numpy()
  logits_stack.append(logits)
  labels_stack.append(label_ids)

  # Weight each batch by its size: a plain mean of per-batch accuracies
  # over-weights a shorter final batch.
  batch_examples = len(label_ids)
  eval_accuracy += flat_accuracy(logits, label_ids) * batch_examples
  nb_eval_examples += batch_examples
  nb_eval_steps += 1

#calculate test accuracy
test_accuracy = eval_accuracy/nb_eval_examples
print("  Accuracy: {0:.5f}".format(test_accuracy))
print("  Test took: {:}".format(format_time(time.time() - t0)))

In [None]:
from sklearn.metrics import classification_report

# Join the per-batch logits into one array with a row per test example,
# then take the highest-scoring class of each row as the prediction.
test_preds = np.concatenate(logits_stack, axis=0)
new_preds = test_preds.argmax(axis=1).flatten()
new_test_y = test_labels.detach().cpu()

print(classification_report(new_test_y, new_preds))

# confusion matrix (true labels as rows, predictions as columns)
pd.crosstab(new_test_y, new_preds)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score



# The trailing slash is required: the cells below build file names as
# out_dir + experiment_name + suffix, so without it the files would be written
# next to the folder as "source_folder<experiment_name>...png".
out_dir = "source_folder/"

# One row per test sample: true label and predicted class id.
master_sheet = pd.DataFrame()
master_sheet['label'] = new_test_y
master_sheet['Prediction'] = new_preds



In [None]:
# Report-level confusion matrix. The two tasks differ only in their tick labels
# and in the annotation font size.
y_true = master_sheet['label']
y_pred = master_sheet['Prediction']

if multiclass_classification:
  tick_labels = le.inverse_transform([0, 1, 2, 3, 4, 5])
  annot_size = 18
else:
  tick_labels = ["Normal", "Abnormal"]
  annot_size = 16

conf = confusion_matrix(y_true, y_pred)
df_cm = pd.DataFrame(conf, index=tick_labels, columns=tick_labels)

fig, ax = plt.subplots()
sns.heatmap(df_cm, cmap='Blues', annot=True, fmt='g', annot_kws={"size": annot_size})
# Put the predicted-class axis on top of the matrix.
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top')
plt.xlabel('Automated diagnosis')
plt.ylabel('Reference standard diagnosis')
plt.title(experiment_name + ": Report level confusion matrix - (n = {})".format(master_sheet.shape[0]))
out_confus = out_dir + experiment_name + '_confusion_report_level' + '.png'
plt.savefig(out_confus, dpi=300,facecolor='w')
plt.show()


In [None]:
#plot AUROC
y_true = master_sheet['label']
y_pred = master_sheet['Prediction']

if multiclass_classification:
  # roc_curve only accepts binary targets, so no single ROC curve exists for the multiclass task.
  print("Skipping ROC curve: it is only drawn for the binary task.")
else:
  # Rank samples by the softmax probability of class 1 ("Abnormal"), computed from the
  # test logits. Hard 0/1 predictions would give a curve with a single operating point.
  y_score = torch.softmax(torch.tensor(test_preds), dim=1)[:, 1].numpy()

  fpr_adrenal, tpr_adrenal, _ = roc_curve(y_true, y_score)
  auc_adrenal = auc(fpr_adrenal, tpr_adrenal)
  plt.plot(fpr_adrenal, tpr_adrenal, lw=2, label='Abnormal vs. Normal (AUC = %0.2f)' % auc_adrenal)
  plt.legend()
  plt.ylabel('Sensitivity')
  plt.xlabel('1 - Specificity')
  plt.title(experiment_name + ": Analysis of Reports, per report (n = {})".format(master_sheet.shape[0]))
  out_aucpi = out_dir + experiment_name + '_AUC_per_report' + '.png'
  plt.savefig(out_aucpi, dpi=300)
  plt.show()


In [None]:
#plot AUPRC
y_true = master_sheet['label']
y_pred = master_sheet['Prediction']

if multiclass_classification:
  # precision_recall_curve only accepts binary targets.
  print("Skipping precision-recall curve: it is only drawn for the binary task.")
else:
  # Rank samples by the softmax probability of class 1 ("Abnormal"), computed from the
  # test logits, instead of by the hard 0/1 predictions.
  y_score = torch.softmax(torch.tensor(test_preds), dim=1)[:, 1].numpy()

  precision_adrenal, recall_adrenal, _ = precision_recall_curve(y_true, y_score)
  ap_adrenal = average_precision_score(y_true, y_score)

  # Recall goes on the x axis and precision on the y axis, matching the axis labels below.
  plt.plot(recall_adrenal, precision_adrenal, lw=2, label='Abnormal vs. Normal(Avg. Precision = %0.2f)' % ap_adrenal)
  plt.legend()
  plt.ylabel('Precision')
  plt.xlabel('Recall')
  plt.title(experiment_name + ": Analysis of Report data, per image (n = {})".format(master_sheet.shape[0]))
  out_auprc = out_dir + experiment_name + '_PR_per_report' + '.png'
  plt.savefig(out_auprc, dpi=300)
  plt.show()


In [None]:
#calculate CI by using the folds as the samples
#for AUROC, AUPRC, Kappa, etc.

import numpy as np
import scipy
from scipy import stats

def mean_confidence_interval(data, confidence=0.95):
    """Return (mean, lower, upper) of a two-sided Student-t confidence interval for the mean.

    data:       sequence of numeric samples.
    confidence: coverage of the interval (default 0.95).
    """
    samples = np.array(data) * 1.0
    dof = len(samples) - 1
    center = np.mean(samples)
    t_crit = scipy.stats.t.ppf((1 + confidence) / 2., dof)
    half_width = scipy.stats.sem(samples) * t_crit
    return center, center - half_width, center + half_width


# Hard-coded per-class F1 scores and the matching per-class supports.
f1_scores = [0.56, 0.80, 0.52, 0.70, 0.54, 0.05]
supports = [51, 796, 93, 444, 304, 17]

sum_supports = sum(supports)
# Each class contributes its F1 weighted by its share of all samples.
kappa_value_array1 = [
    score * support / sum_supports
    for score, support in zip(f1_scores, supports)
]
print("weighted avg F1 = ", sum(kappa_value_array1))


In [None]:
from sklearn.metrics import roc_auc_score
from math import sqrt

def roc_auc_ci(y_true, y_score, positive=1):
    """Return a (lower, upper) interval of +/- 1.96 standard errors around the ROC AUC.

    The standard error follows the Hanley & McNeil formula (Q1 = A/(2-A),
    Q2 = 2A^2/(1+A)); the bounds are clipped to [0, 1].

    y_true:   array-like of true labels; must support element-wise == and !=.
    y_score:  scores passed on to roc_auc_score.
    positive: label value counted as the positive class (default 1).
    """
    auc_value = roc_auc_score(y_true, y_score)
    n_pos = sum(y_true == positive)
    n_neg = sum(y_true != positive)

    auc_sq = auc_value**2
    q1 = auc_value / (2 - auc_value)
    q2 = 2*auc_sq / (1 + auc_value)
    variance = (auc_value*(1 - auc_value) + (n_pos - 1)*(q1 - auc_sq) + (n_neg - 1)*(q2 - auc_sq)) / (n_pos*n_neg)
    half_width = 1.96*sqrt(variance)

    # Clip the interval to the valid AUC range.
    lower = max(auc_value - half_width, 0)
    upper = min(auc_value + half_width, 1)
    return (lower, upper)

In [None]:
import numpy as np
import pandas as pd

np.random.seed(2018)

from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.metrics import auc
import matplotlib
import matplotlib.pyplot as plt

def roc_curve_and_score(y_test, pred_proba):
    """Return (fpr, tpr, roc_auc) for the flattened labels and scores.

    Both arguments must provide .ravel(); they are flattened before being
    handed to roc_curve and roc_auc_score.
    """
    flat_truth = y_test.ravel()
    flat_scores = pred_proba.ravel()
    fpr, tpr, _ = roc_curve(flat_truth, flat_scores)
    return fpr, tpr, roc_auc_score(flat_truth, flat_scores)


# Overlay one ROC curve per cross-validation fold, each read from that fold's CSV.
plt.figure(figsize=(8, 6))
matplotlib.rcParams.update({'font.size': 14})
plt.grid()

tprs = []   # each fold's TPR interpolated onto mean_fpr
aucs = []   # each fold's ROC AUC
mean_fpr = np.linspace(0, 1, 100)

colors = ['pink', 'green', 'red', 'blue', 'orange', 'cyan', 'purple', 'yellow']
for i in range(k_folds):
  # Fold-local names: the notebook-wide df / y_true / y_pred of the current run are
  # left untouched, so later cells that read them do not pick up a fold CSV by accident.
  fold_df = pd.read_csv('adrenal_binary_biobert_fold_' + str(i) + '.csv')
  fold_true = fold_df['label']
  fold_pred = fold_df['Prediction']
  fpr, tpr, roc_auc = roc_curve_and_score(fold_true, fold_pred)
  conf_inter = roc_auc_ci(fold_true, fold_pred)
  # Cycle through the palette so any number of folds gets a colour.
  plt.plot(fpr, tpr, color=colors[i % len(colors)], lw=2, label='ROC AUC={0:.3f} {z} (fold {x})'.format(roc_auc, z = conf_inter, x = i+1), alpha=0.7)
  interp_tpr = np.interp(mean_fpr, fpr, tpr)
  interp_tpr[0] = 0.0
  tprs.append(interp_tpr)
  aucs.append(roc_auc)
plt.legend(loc="lower right")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1 - Specificity')
plt.ylabel('Sensitivity')

plt.title('Mean AUROC')
out_auroc_cumul = out_dir + experiment_name + 'multiAUROC' + '.png'
plt.savefig(out_auroc_cumul)
plt.show()


In [None]:
#calculate kappa values (unweighted, linear- and quadratic-weighted)
if multiclass_classification:
  from sklearn.metrics import cohen_kappa_score
  # Read this run's labels and predictions straight from master_sheet rather than from
  # whatever y_true / y_pred the most recently executed cell happened to leave behind.
  kappa_true = master_sheet['label']
  kappa_pred = master_sheet['Prediction']
  kappa_unweighted = cohen_kappa_score(kappa_true, kappa_pred)
  kappa_linear = cohen_kappa_score(kappa_true, kappa_pred, weights="linear")
  kappa_quadratic = cohen_kappa_score(kappa_true, kappa_pred, weights="quadratic")
  print(kappa_unweighted, kappa_linear, kappa_quadratic)

  

In [None]:
#regex for multiclass
# Keyword baseline: every sentence gets the label of the first class, in priority
# order, for which one of the class phrases occurs in it as a substring.

#mapping to encoder
# NOTE(review): presumably mirrors the integer codes assigned by `le` - confirm.
label_dict = {
    'absent': 0,
    'mass': 1,
    'metastasis': 2,
    'normal': 3,
    'thickening': 4,
    'unknown': 5,
}

phrases_dict = {
    'absent': ['absent'],
    'mass': ['mass', 'nodule'],
    'metastasis': ['metastasis'],
    'normal': ['normal'],
    'thickening': ['thickening'],
    'unknown': ['unknown'],
}

#decreasing order of criticality
priority = ['mass', 'metastasis', 'thickening', 'absent', 'normal', 'unknown']


def _rule_based_label(text):
  """Return the code of the first priority class with a phrase in `text`, or 5 if none matches."""
  for class_name in priority:
    if any(phrase in text for phrase in phrases_dict[class_name]):
      return label_dict[class_name]
  return 5


#to store predictions
predictions = [_rule_based_label(text) for text in sentences]

print(predictions[:10])
predictions.count(5)

from sklearn.metrics import classification_report
print(classification_report(labels, predictions))

# Feature Extraction for UMAP

In [None]:
# Second copy of the classifier, used for feature extraction.
model_feature_extraction = BertForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1", # BioBERT base, cased, v1.1 checkpoint.
    num_labels = num_classes, # Number of output labels: 2 for the binary task, 6 for multiclass.
    output_attentions = True, # Whether the model returns attentions weights.
    output_hidden_states = True # Whether the model returns all hidden-states.
)

# Move THIS model (not the training `model`) to the device selected earlier.
model_feature_extraction.to(device)

In [None]:
# Overwrite the freshly loaded weights with the checkpoint stored at saved_weights_path.
model_feature_extraction.load_state_dict(torch.load(saved_weights_path))

# Transformer Explainability

### Code based on https://github.com/hila-chefer/Transformer-Explainability.git

In [None]:
# Switch to the cloned repository so its BERT_explainability package can be imported.
# NOTE(review): absolute path - assumes the repository was cloned under /content.
os.chdir('/content/Transformer-Explainability')

In [None]:
import numpy as np
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
# From here on the name BertForSequenceClassification refers to the BERT_explainability
# class; it replaces the transformers class of the same name imported in earlier cells.
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
# The next two imports repeat the ones above.
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from transformers import AutoTokenizer

import matplotlib
# Print the installed matplotlib version.
print(matplotlib.__version__)
from captum.attr import (
    visualization
)
import torch 

In [None]:
#load biobert model
# NOTE(review): this overwrites the notebook-wide saved_weights_path with a fixed
# fold-2 file name and reads it from a hard-coded Drive folder - confirm that this is
# the checkpoint you intend to explain.
saved_weights_path = "saved_weights_adrenal_multiclass_biobert_fold_2.pt"
vis_weights_file = "/content/drive/My Drive/Fracture/" + saved_weights_path

vis_model = BertForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1",
    num_labels=num_classes,
)
vis_model.load_state_dict(torch.load(vis_weights_file))
vis_model.eval()
vis_model.cuda()

In [None]:
# Build the explanation generator around the model to be explained.
explanations = Generator(vis_model)

# Display name of each class id: decoded through the label encoder for the
# multiclass task, fixed names for the binary task.
if not multiclass_classification:
  classifications = ["NEGATIVE", "POSITIVE"]
else:
  classifications = le.inverse_transform(list(range(num_classes)))
print(classifications)

In [None]:
# Collects the objects returned by visualization.visualize_text, one per explained sample.
visualisations = []

In [None]:
def simpler_format(tokens, expl):
  """Collapse sub-word relevance scores to one score per whole word.

  tokens: token strings of one sequence; pieces that continue a word start with '##'.
  expl:   one relevance score per token (same length as `tokens`). It is not modified.

  Returns (words, scores): the space-separated words of the decoded sequence and, for
  each, the maximum score over its pieces. The first and last entries are dropped
  (presumably the [CLS] / [SEP] markers - confirm against the caller).
  """
  #trying to merge weights for split tokens
  tokens2 = tokenizer.convert_tokens_to_string(tokens).split(' ')

  # Work on a copy so the caller's `expl` keeps its original values.
  expl2 = list(expl)

  # Walk right-to-left so a run of '##' pieces hands its maximum back to the piece that
  # starts the word. Index 0 is skipped: it has no predecessor, and i-1 would wrap
  # around to the last element.
  for i in range(len(tokens)-1, 0, -1):
    if tokens[i].startswith('##'):
      expl2[i-1] = max(expl2[i], expl2[i-1])

  # Keep one score per word start, then drop the first and last entries.
  expl2 = [expl2[i] for i in range(len(expl2)) if not tokens[i].startswith('##')]
  expl2 = expl2[1:-1]
  tokens2 = tokens2[1:-1]
  assert len(tokens2) == len(expl2)
  return tokens2, expl2
  

In [None]:
# Decode each row of test ids back into a text string.
test_sents = []
for token_ids in test_inputs:
  row_tokens = tokenizer.convert_ids_to_tokens(token_ids.numpy())
  test_sents.append(tokenizer.convert_tokens_to_string(row_tokens))
print(len(test_sents))

In [None]:
     
# Explain every test sentence with LRP and collect one rendered visualisation per sentence.
visualisations = []
preds_vis_model = []

# Put the inputs on whatever device the explained model lives on.
vis_device = next(vis_model.parameters()).device

for i in range(len(test_sents)):
  text_batch = test_sents[i]
  encoding = tokenizer(text_batch, return_tensors='pt')
  # Loop-local names, so the dataset-level `input_ids` array of the notebook is not overwritten.
  vis_input_ids = encoding['input_ids'].to(vis_device)
  vis_attention_mask = encoding['attention_mask'].to(vis_device)

  # reference label of this sample
  true_class = test_labels[i]

  # generate an explanation for the input
  expl = explanations.generate_LRP(input_ids=vis_input_ids, attention_mask=vis_attention_mask, start_layer=0)[0]
  # normalize scores to [0, 1]
  expl = (expl - expl.min()) / (expl.max() - expl.min())

  # get the model classification
  output = torch.nn.functional.softmax(vis_model(input_ids=vis_input_ids, attention_mask=vis_attention_mask)[0], dim=-1)
  classification = output.argmax(dim=-1).item()
  # get class name
  class_name = classifications[classification]

  preds_vis_model.append(classification)

  # Scores are negated when the predicted class is named "POSITIVE"
  # (the earlier variant of this check tested for "NEGATIVE" instead).
  if class_name == "POSITIVE":
    expl *= (-1)

  tokens = tokenizer.convert_ids_to_tokens(vis_input_ids.flatten())
  print([(tokens[k], expl[k].item()) for k in range(len(tokens))])
  tokens2, expl2 = simpler_format(tokens, expl)

  vis_data_records = [visualization.VisualizationDataRecord(
                                  expl2,
                                  output[0][classification],
                                  classification,
                                  true_class,
                                  1,
                                  1,       
                                  tokens2,
                                  1)]
  # Call visualize_text once per sample and keep what it returns; calling it a second
  # time just to obtain the return value repeats the work for every sample.
  visualisations.append(visualization.visualize_text(vis_data_records))

In [None]:
from IPython.display import display, HTML, Image

# Show one stored visualisation and write its HTML to "<filename>.html".
i = 73  # position of the sample in `visualisations`
v = visualisations[i]
filename = "0_2___" + str(i)

display(HTML(v.data))
with open(filename + '.html', 'w') as f:
  f.write(v.data)