<h2> Interpreting BERT </h2>

The aim of this notebook is to understand how BERT models make decisions. To do this, we train a BERT classifier for sentiment analysis on the airline tweets dataset and then use interpret-text, an open-source library, to help us understand the trained model.

Training on a CPU takes a very long time, so a GPU is strongly recommended.
For the dashboard to load, make sure Jupyter widgets are enabled in your Jupyter environment.
<br>

Use the following JupyterHub image for this notebook:
<br>
tensorflow-gpu-3.6-CUDA10.1


Run the following command and restart the kernel to make sure widgets are enabled:
<br>
<b>jupyter nbextension enable --py --sys-prefix widgetsnbextension</b>

In [None]:
import pandas as pd
import numpy as np
import scrapbook as sb
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from interpret_text.experimental.common.utils_bert import Language, Tokenizer, BERTSequenceClassifier
from interpret_text.experimental.common.timer import Timer
import re
from sklearn.base import MultiOutputMixin  # sklearn.linear_model.base no longer exists in recent scikit-learn
from sklearn.linear_model import LinearRegression
from interpret_text.experimental.unified_information import UnifiedInformationExplainer
import json

With the dependencies imported, we load the data into a DataFrame, preprocess the tweet text, and split the data into train and test sets.

In [None]:
tweets = pd.read_csv('Tweets.csv')

# No separate shuffle is needed: train_test_split (below) shuffles the rows by default
features = tweets.iloc[:, 10].values  # column 10: tweet text
labels = tweets.iloc[:, 1].values     # column 1: sentiment label
processed_features = []

for feature in features:
    # Replace special (non-word) characters with spaces
    processed_feature = re.sub(r'\W', ' ', str(feature))
    # Remove single characters surrounded by whitespace
    processed_feature = re.sub(r'\s+[a-zA-Z]\s+', ' ', processed_feature)
    # Remove a single character at the start of the text
    processed_feature = re.sub(r'^[a-zA-Z]\s+', '', processed_feature)
    # Collapse multiple spaces into a single space
    processed_feature = re.sub(r'\s+', ' ', processed_feature)
    # Remove a prefixed 'b'
    processed_feature = re.sub(r'^b\s+', '', processed_feature)
    # Convert to lowercase
    processed_feature = processed_feature.lower()
    processed_features.append(processed_feature)

X_train, X_test, y_train, y_test = train_test_split(processed_features, labels, test_size=0.2, random_state=0)
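Sentences you supply yourself later (for example, a custom tweet to explain) should go through the same cleaning as the training data. The sketch below wraps the cleaning steps in a reusable function; `clean_tweet` is a hypothetical helper name, not part of interpret-text, and it assumes the leading-character step is meant to use the anchored `^` pattern rather than an escaped `\^`, which would only match a literal caret.

```python
import re


def clean_tweet(text):
    """Hypothetical helper: apply the notebook's tweet-cleaning steps to one string."""
    # Replace special (non-word) characters with spaces
    cleaned = re.sub(r'\W', ' ', str(text))
    # Remove single characters surrounded by whitespace
    cleaned = re.sub(r'\s+[a-zA-Z]\s+', ' ', cleaned)
    # Remove a single character at the very start of the text
    cleaned = re.sub(r'^[a-zA-Z]\s+', '', cleaned)
    # Collapse runs of whitespace into a single space
    cleaned = re.sub(r'\s+', ' ', cleaned)
    # Remove a prefixed 'b'
    cleaned = re.sub(r'^b\s+', '', cleaned)
    # Convert to lowercase
    return cleaned.lower()


print(repr(clean_tweet("I can't believe it")))  # 'can believe it'
```

As the example shows, the single-character step also removes the "t" of a contraction such as "can't".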

In [None]:
label_encoder = LabelEncoder()
labels_train = label_encoder.fit_transform(y_train)
labels_test = label_encoder.transform(y_test)
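`LabelEncoder` assigns integer ids to the classes in sorted order (its `classes_` attribute holds the sorted unique labels), so the ids follow the alphabetical order of the sentiment names. For intuition, a plain-Python sketch of the same mapping; the function names are hypothetical:

```python
def fit_label_mapping(labels):
    """Return the sorted classes and a label -> id dict, mimicking LabelEncoder.fit."""
    classes = sorted(set(labels))
    return classes, {label: i for i, label in enumerate(classes)}


def encode(labels, mapping):
    # Like LabelEncoder.transform: label strings -> integer ids (unseen labels raise KeyError)
    return [mapping[label] for label in labels]


def decode(ids, classes):
    # Like LabelEncoder.inverse_transform: integer ids -> label strings
    return [classes[i] for i in ids]


classes, mapping = fit_label_mapping(["positive", "negative", "neutral", "negative"])
print(classes)                                   # ['negative', 'neutral', 'positive']
print(encode(["neutral", "positive"], mapping))  # [1, 2]
print(decode([0, 2], classes))                   # ['negative', 'positive']
```

Because the mapping only knows the classes seen when it is fitted, the encoder is fitted on the training labels and then reused to transform the test labels.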

Next, we define some hyperparameters for the model before setting it up for training.

In [None]:
TEST_DATA_FRACTION = 1
NUM_EPOCHS = 1

# Select the first GPU only when CUDA is available; set_device fails on CPU-only machines
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    BATCH_SIZE = 1
else:
    BATCH_SIZE = 8

DATA_FOLDER = "./temp"
BERT_CACHE_DIR = "./temp"
LANGUAGE = Language.ENGLISH
TO_LOWER = True
MAX_LEN = 50
BATCH_SIZE_PRED = 512
TRAIN_SIZE = 0.6
LABEL_COL = "genre"
TEXT_COL = "sentence1"
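The batch size directly affects training time, because one epoch processes the training set in `ceil(n_train / batch_size)` batches. With a float `test_size` such as 0.2, scikit-learn's `train_test_split` rounds the test-set size up and gives the remainder to the training set. A small sketch of that arithmetic; the helper names are hypothetical and the row count in the example is made up:

```python
import math


def split_sizes(n_samples, test_size=0.2):
    """Train/test sizes for a float test_size: test is rounded up, train gets the rest."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def steps_per_epoch(n_train, batch_size):
    """Number of batches in one pass over the training set."""
    return math.ceil(n_train / batch_size)


n_train, n_test = split_sizes(1000)  # hypothetical dataset of 1,000 rows
print(n_train, n_test)               # 800 200
print(steps_per_epoch(n_train, 1))   # 800
print(steps_per_epoch(n_train, 8))   # 100
```

With `BATCH_SIZE = 1`, each epoch therefore runs one batch per training tweet.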

We tokenize the text with the BERT tokenizer provided by the interpret-text library, then convert the tokens into inputs of at most `MAX_LEN` tokens with matching input masks.

In [None]:
tokenizer = Tokenizer(Language.ENGLISH, to_lower=TO_LOWER, cache_dir=BERT_CACHE_DIR)
tokens_train = tokenizer.tokenize(X_train)
tokens_test = tokenizer.tokenize(X_test)

tokens_train, mask_train, _ = tokenizer.preprocess_classification_tokens(tokens_train, MAX_LEN)
tokens_test, mask_test, _ = tokenizer.preprocess_classification_tokens(tokens_test, MAX_LEN)
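As we understand it, `preprocess_classification_tokens` prepares each tokenized tweet for BERT by adding the `[CLS]` and `[SEP]` markers, converting tokens to vocabulary ids, truncating or padding to `MAX_LEN`, and returning an input mask that marks real tokens with 1 and padding with 0. The sketch below is an illustrative re-implementation of that idea, not interpret-text's code; `to_bert_inputs` and the tiny vocabulary are hypothetical.

```python
def to_bert_inputs(tokens, vocab, max_len):
    """Illustrative sketch: build BERT-style token ids and an input mask of length max_len."""
    # Truncate so that [CLS] and [SEP] still fit within max_len
    tokens = ["[CLS]"] + tokens[: max_len - 2] + ["[SEP]"]
    # Map tokens to ids, falling back to [UNK] for out-of-vocabulary tokens
    ids = [vocab.get(token, vocab["[UNK]"]) for token in tokens]
    mask = [1] * len(ids)  # 1 marks a real token
    padding = max_len - len(ids)
    ids += [vocab["[PAD]"]] * padding  # pad the ids up to max_len
    mask += [0] * padding              # 0 marks padding the model should ignore
    return ids, mask


# Hypothetical toy vocabulary
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "great": 4, "flight": 5}
print(to_bert_inputs(["great", "flight"], vocab, 6))  # ([2, 4, 5, 3, 0, 0], [1, 1, 1, 1, 0, 0])
```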

Next, we initialize the classifier with `num_labels=3`, one for each sentiment class (negative, neutral and positive).

In [None]:
classifier = BERTSequenceClassifier(language=LANGUAGE, num_labels=3, cache_dir=BERT_CACHE_DIR)

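We now train the model, keeping track of the elapsed time. The training cell is commented out because the notebook loads a previously saved model further down; uncomment it to train from scratch.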

In [None]:
# with Timer() as t:
#     classifier.fit(token_ids=tokens_train,
#                     input_mask=mask_train,
#                     labels=labels_train,    
#                     num_epochs=NUM_EPOCHS,
#                     batch_size=BATCH_SIZE,    
#                     verbose=True)    
# print("[Training time: {:.3f} hrs]".format(t.interval / 3600))
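The `Timer` used above is a context manager whose `interval` attribute holds the elapsed time in seconds once the `with` block exits (the cell divides it by 3600 to report hours). A minimal sketch of that pattern with `time.perf_counter`; `ElapsedTimer` is a hypothetical name, not interpret-text's implementation:

```python
import time


class ElapsedTimer:
    """Hypothetical stand-in for a Timer context manager: records seconds in `interval`."""

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.interval = time.perf_counter() - self._start
        return False  # do not suppress exceptions raised inside the block


with ElapsedTimer() as t:
    time.sleep(0.01)  # stand-in for classifier.fit(...)
print("[Training time: {:.3f} hrs]".format(t.interval / 3600))
```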

With training complete, the next steps are to make predictions on the test data and set up the explainer.

If you trained the model in this session, uncomment the `torch.save` line to save it; if a saved model already exists from an earlier run, `torch.load` reloads it so that training can be skipped.

In [None]:
#torch.save(classifier,'saved_model.pth')
classifier = torch.load('saved_model.pth')
preds = classifier.predict(token_ids=tokens_test, 
                           input_mask=mask_test, 
                           batch_size=BATCH_SIZE_PRED)

report = classification_report(labels_test, preds, target_names=label_encoder.classes_, output_dict=True) 
accuracy = accuracy_score(labels_test, preds)
print("accuracy: {}".format(accuracy))
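`accuracy_score` is the fraction of test tweets whose predicted class matches the true class, and `classification_report(..., output_dict=True)` adds per-class precision, recall, F1 and support. For intuition, a plain-Python sketch of those definitions; the function names are hypothetical, and the scikit-learn functions remain the ones to use in practice:

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_scores(y_true, y_pred, classes):
    """Precision, recall, F1 and support for each class (0.0 when undefined)."""
    pairs = list(zip(y_true, y_pred))
    scores = {}
    for c in classes:
        tp = sum(t == c and p == c for t, p in pairs)  # correctly predicted as c
        fp = sum(t != c and p == c for t, p in pairs)  # wrongly predicted as c
        fn = sum(t == c and p != c for t, p in pairs)  # c predicted as another class
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[c] = {"precision": precision, "recall": recall, "f1-score": f1, "support": tp + fn}
    return scores


y_true = [0, 0, 1, 2, 2]
y_pred = [0, 1, 1, 2, 0]
print(accuracy(y_true, y_pred))  # 0.6
print(per_class_scores(y_true, y_pred, [0, 1, 2]))
```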


We use the Unified Information Explainer for this model, so we need to initialize it with our training data.

In [None]:
print(json.dumps(report, indent=4, sort_keys=True))

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")

classifier.model.to(device)
for param in classifier.model.parameters():
    param.requires_grad = False
classifier.model.eval()

# target_layer selects the layer of the model that the explainer interprets
interpreter_unified = UnifiedInformationExplainer(model=classifier.model, 
                                 train_dataset=list(X_train), 
                                 device=device, 
                                 target_layer=14, 
                                 classes=label_encoder.classes_)

We can now take any of the test samples, look at the model's prediction for it, and explore that prediction in the dashboard.
Samples can be selected by index, or you can supply a sentence and label of your own.

In [None]:
idx = 1010
text = X_test[idx]
true_label = y_test[idx]
predicted_label = label_encoder.inverse_transform([preds[idx]])
print(text, true_label, predicted_label)

explanation_unified = interpreter_unified.explain_local(text, true_label)
from interpret_text.experimental.widget import ExplanationDashboard
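Misclassified test tweets are often the most informative ones to inspect in the dashboard. A small sketch for finding their indices so that one of them can be used as `idx`; the helper name is hypothetical, and in the notebook it would be called with the encoded true labels `labels_test` and the predictions `preds`:

```python
def misclassified_indices(y_true, y_pred):
    """Indices where the prediction differs from the true label."""
    return [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t != p]


# Toy example; in the notebook: misclassified_indices(labels_test, preds)
print(misclassified_indices([0, 1, 2, 1], [0, 2, 2, 0]))  # [1, 3]
```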

In [None]:
ExplanationDashboard(explanation_unified)