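# Explainable AI

Compare an LSTM classifier and a BERT classifier for suicide-related text on the test set (`./data/test.csv`), then use SHAP to explain a few of each model's false positives and false negatives.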

In [1]:
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report
from transformers import BertTokenizer
import shap
from load_model import *

In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Load Dataset

In [3]:
# Test split: `text` holds the raw posts, `suicide` the class label (0 / 1 in the reports below).
df = pd.read_csv('./data/test.csv')
X, Y = df['text'], df['suicide']

## Data Pre-processing

In [4]:
# Materialise the texts as a plain Python list for the tokenizer, and the labels as a tensor.
sentences = list(X)
sen_labels = torch.tensor(Y.values)

In [5]:
max_length = 512
tokenizer = BertTokenizer.from_pretrained('./pretrained/bert-base-uncased')

# padding=True pads to the longest text in the set; truncation caps every row at max_length.
tokenized = tokenizer(
    sentences,
    padding=True,
    truncation=True,
    max_length=max_length,
    return_tensors='pt',
)
sen_ids, attention_mask = tokenized['input_ids'], tokenized['attention_mask']

print(sen_ids.size())

torch.Size([46413, 512])


## Load Saved Model

### LSTM

In [6]:
# Saved LSTM classifier checkpoint. load_LSTMClassifier is presumably provided by
# `from load_model import *` (it is not defined in this notebook).
lstm_path = './model/LSTM_classifier.pt'
lstm_model = load_LSTMClassifier(lstm_path, device)

### BERT

In [7]:
# Saved BERT classifier checkpoint, plus the base pretrained directory (also used
# for the tokenizer). load_BERTClassifier is presumably provided by `from load_model import *`.
bert_model_path = './model/Bert_classifier.pth'
pretrained_path = './pretrained/bert-base-uncased'

# NOTE(review): the "classifier.* newly initialized" warning printed by this cell is only
# harmless if load_BERTClassifier afterwards loads the trained weights from
# bert_model_path — confirm in load_model (the 0.97 accuracy below suggests it does).
bert_model = load_BERTClassifier(bert_model_path, pretrained_path, device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./pretrained/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Predict

In [8]:
def predict(lstm_model, bert_model, ids, masks, device=None):
    """Run both classifiers over every tokenised sample, one sample at a time.

    Args:
        lstm_model: LSTM classifier exposing ``init_hidden(batch_size)`` and
            called as ``lstm_model(input_ids, hidden) -> (scores, hidden)``.
        bert_model: sequence-classification model whose output supports
            ``output['logits']``.
        ids: tensor of token ids, one row per sample.
        masks: attention-mask tensor aligned row-for-row with ``ids``.
        device: device to run inference on. Defaults to the device holding
            ``bert_model``'s parameters (CPU if it has none), so the function
            no longer hard-requires CUDA.

    Returns:
        ``(lstm_pred_list, bert_pred_list)``: one nested list per sample.
        LSTM entries are the raw model outputs; BERT entries are softmax
        probabilities over the logits.
    """
    if device is None:
        first_param = next(bert_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    lstm_pred_list = []
    bert_pred_list = []

    # Disable dropout etc. for deterministic inference.
    lstm_model.eval()
    bert_model.eval()

    with torch.no_grad():
        for i in range(len(ids)):
            # Fresh hidden state per sample: each sentence is classified independently.
            hidden = lstm_model.init_hidden(1)

            input_ids = ids[i].unsqueeze(dim=0).to(device)
            sample_mask = masks[i].unsqueeze(dim=0).to(device)

            lstm_pred, hidden = lstm_model(input_ids, hidden)
            bert_out = bert_model(input_ids, attention_mask=sample_mask, token_type_ids=None)

            lstm_pred_list.append(lstm_pred.cpu().numpy().tolist())

            bert_probs = torch.nn.functional.softmax(bert_out['logits'].cpu(), dim=1)
            bert_pred_list.append(bert_probs.numpy().tolist())

    return lstm_pred_list, bert_pred_list

In [9]:
# Per-sample scores for the whole test set (slow: one forward pass per sentence).
lstm_preds, bert_preds = predict(lstm_model, bert_model, ids=sen_ids, masks=attention_mask)

### Get Predicted Labels

In [10]:
def _to_labels(scores, n_classes=2):
    """Collapse per-sample class scores into predicted class indices.

    ``predict`` returns one nested list of scores per sample; these are
    flattened to ``(n_samples, n_classes)`` and arg-maxed along the class axis.
    A 1-D input is assumed to already be labels and is returned unchanged, so
    re-running this cell (which overwrites the scores with labels) is a no-op
    instead of a reshape error.
    """
    scores = np.asarray(scores)
    if scores.ndim == 1:
        return scores
    return scores.reshape(-1, n_classes).argmax(1)


lstm_preds = _to_labels(lstm_preds)
bert_preds = _to_labels(bert_preds)

In [11]:
# Print a per-class precision/recall/F1 table for each model, separated by a rule.
reports = [
    ('LSTM report:', lstm_preds),
    ('BERT report:', bert_preds),
]
for idx, (title, preds) in enumerate(reports):
    if idx:
        print('---------------')
    print(title)
    print(classification_report(sen_labels, preds))

LSTM report:
              precision    recall  f1-score   support

           0       0.96      0.93      0.95     23187
           1       0.94      0.96      0.95     23226

    accuracy                           0.95     46413
   macro avg       0.95      0.95      0.95     46413
weighted avg       0.95      0.95      0.95     46413

---------------
BERT report:
              precision    recall  f1-score   support

           0       0.97      0.98      0.97     23187
           1       0.98      0.97      0.97     23226

    accuracy                           0.97     46413
   macro avg       0.97      0.97      0.97     46413
weighted avg       0.97      0.97      0.97     46413



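## Get Wrong Samples

Collect each model's misclassified test sentences. False positives are sentences predicted as class 1 (`suicide`) whose true label differs; false negatives are the remaining errors.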

In [12]:
# Misclassified sentences per model: *_fp = predicted 1 but label differs, *_fn = other errors.
lstm_fp, lstm_fn = [], []
bert_fp, bert_fn = [], []
model_buckets = (
    (lstm_preds, lstm_fp, lstm_fn),
    (bert_preds, bert_fp, bert_fn),
)

for i, true_label in enumerate(sen_labels):
    for model_preds, fp_bucket, fn_bucket in model_buckets:
        predicted = model_preds[i]
        if predicted == true_label:
            continue
        target = fp_bucket if predicted == 1 else fn_bucket
        target.append(sentences[i])
    

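## Explain Models

SHAP's text explainer perturbs the raw input strings, so each model is wrapped in a function that maps a batch of strings to per-class scores.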

In [13]:
def _encode_for_shap(x):
    """Tokenise a batch of raw strings the way the classifiers expect.

    Every text is padded/truncated to exactly ``max_length`` tokens (the limit
    set in the tokenisation cell), and both tensors are moved to ``device``.

    Args:
        x: iterable of strings (SHAP passes masked copies of the input text).

    Returns:
        ``(input_ids, attention_mask)`` tensors on ``device``.
    """
    encoded = tokenizer(
        [str(v) for v in x],
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
    )
    return encoded['input_ids'].to(device), encoded['attention_mask'].to(device)


def lstm_f(x):
    """SHAP prediction function for the LSTM: raw strings -> raw model outputs (numpy)."""
    input_ids, _ = _encode_for_shap(x)
    # Ensure inference mode even if the prediction cell was not run first.
    lstm_model.eval()
    # no_grad: SHAP calls this many times; building autograd graphs only wastes GPU memory.
    with torch.no_grad():
        hidden = lstm_model.init_hidden(len(input_ids))
        outputs, _ = lstm_model(input_ids, hidden)
    return outputs.cpu().numpy()


def bert_f(x):
    """SHAP prediction function for BERT: raw strings -> softmax class probabilities (numpy)."""
    input_ids, attention_mask = _encode_for_shap(x)
    bert_model.eval()
    with torch.no_grad():
        logits = bert_model(input_ids, attention_mask=attention_mask)['logits']
    return torch.nn.functional.softmax(logits.cpu(), dim=1).numpy()

### Create Explainer

In [14]:
# One SHAP explainer per model; the tokenizer is passed as the masker and
# output_names label the two score columns.
lstm_explainer, bert_explainer = (
    shap.Explainer(model_fn, tokenizer, output_names=['non-suicide', 'suicide'])
    for model_fn in (lstm_f, bert_f)
)

## Explain LSTM False-Positive Samples

In [15]:
# SHAP values for the first two LSTM false positives, rendered as highlighted text.
lstm_fp_examples = lstm_fp[0:2]
sv = lstm_explainer(lstm_fp_examples)
shap.plots.text(sv)

## Explain LSTM False-Negative Samples

In [16]:
# SHAP values for the first two LSTM false negatives, rendered as highlighted text.
lstm_fn_examples = lstm_fn[0:2]
sv = lstm_explainer(lstm_fn_examples)
shap.plots.text(sv)

## Explain BERT False-Positive Samples

In [17]:
# SHAP values for the first two BERT false positives, rendered as highlighted text.
bert_fp_examples = bert_fp[0:2]
sv = bert_explainer(bert_fp_examples)
shap.plots.text(sv)

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 3it [00:11, 11.53s/it]               


## Explain BERT False-Negative Samples

In [18]:
# SHAP values for the first two BERT false negatives, rendered as highlighted text.
bert_fn_examples = bert_fn[0:2]
sv = bert_explainer(bert_fn_examples)
shap.plots.text(sv)

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  50%|█████     | 1/2 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 3it [00:20, 10.13s/it]               
