# Multi-Input Text Explanation: Textual Entailment with Facebook BART


This notebook demonstrates how to get explanations for the output of the Facebook BART model, which was trained on the MNLI dataset and is used here for textual entailment. We use an example from the SNLI dataset because MNLI is not supported in the environment required for SHAP.

BART: https://huggingface.co/facebook/bart-large-mnli


In [None]:
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import shap
from datasets import load_dataset

### Load model and tokenizer

In [None]:
# Both objects are loaded from the same checkpoint id
checkpoint = "facebook/bart-large-mnli"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# load dataset
dataset = load_dataset("snli")
# Integer label -> label name, used to print the example's true label
snli_label_map = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
example_ind = 6
# Index the row first and the columns second: dataset['train'][column][i]
# reads the entire column of the train split just to pick a single value.
example = dataset['train'][example_ind]
premise, hypothesis, label = example['premise'], example['hypothesis'], example['label']
print('Premise: ' + premise)
print('Hypothesis: ' + hypothesis)
true_label = snli_label_map[label]
print('The true label is: {true_label}'.format(true_label=true_label))

In [None]:
# Sanity-check the model on the selected premise/hypothesis pair
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
model_output = model(input_ids)
logits = model_output[0]
probs = logits.softmax(dim=1)

# Class index -> label name; the index is used to pick the matching probability
bart_label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
for class_idx in bart_label_map:
    percent = probs[0][class_idx] * 100
    print('{lab} probability: {prob:0.2f}%'.format(lab=bart_label_map[class_idx], prob=percent))


## Compute SHAP values

In [None]:
import scipy as sp
import torch

# wrapper function for model
# takes in masked string which is in the form: premise <separator token(s)> hypothesis
def f(x):
    """Score a batch of (possibly masked) input strings with the model.

    Parameters
    ----------
    x : iterable of str
        Each string is expected in the form
        ``premise <separator token(s)> hypothesis``.

    Returns
    -------
    numpy.ndarray
        One row per input string, holding the logit (log-odds) of the
        softmax probability of every model output. An empty batch yields
        an empty array.
    """
    outputs = []
    # Inference only: skip building the autograd graph for every forward pass.
    with torch.no_grad():
        for _x in x:
            encoding = torch.tensor([tokenizer.encode(_x)])
            output = model(encoding)[0].detach().cpu().numpy()
            outputs.append(output[0])
    outputs = np.array(outputs)
    if outputs.size == 0:
        # Nothing to normalise; max() below would raise on an empty array.
        return outputs
    # Shift by the row maximum before exponentiating so large logits cannot
    # overflow to inf (which would turn the softmax into NaN).
    shifted = outputs - outputs.max(axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    scores = exp_shifted / exp_shifted.sum(axis=-1, keepdims=True)
    val = sp.special.logit(scores)
    return val

In [None]:
# Build the explainer around the wrapper function, using the tokenizer as masker
bart_labels = [
    'contradiction',
    'neutral',
    'entailment',
]
explainer = shap.Explainer(f, tokenizer, output_names=bart_labels)

In [None]:
# Tokenize the premise/hypothesis pair, then decode it back into one concatenated string
pair_ids = tokenizer(premise, hypothesis)['input_ids']
# Drop the first and last ids (start/end tokens); the tokenizer adds them again on encoding
encoded = pair_ids[1:-1]
decoded = tokenizer.decode(encoded)
print(decoded)

In [None]:
# The explainer is called on a list of inputs, so the single string is wrapped in one
explainer_inputs = [decoded]
shap_values = explainer(explainer_inputs)
print(shap_values)

## Explanation Visualization

In [None]:
# Render the SHAP text plot for the explanation computed above
shap.plots.text(shap_values)

## Input Partition Tree - Dendrogram

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage
from matplotlib import pyplot as plt

In [None]:
# Work on a copy: the clustering array may share memory with the explanation
# object, so editing it in place could silently alter `shap_values` for later cells.
Z = np.array(shap_values[0].abs.clustering, copy=True)
Z[-1][2] = Z[-2][2] + 10 # last row's distance is extremely large, so make it a more reasonable value
print(Z)

In [None]:
# Labels for the dendrogram: the data entries stored on the first explanation
labels_arr = shap_values[0].data

# # clean labels of unusual characters (only for slow tokenizer, if use_fast=False)
# labels_arr = []
# for token in shap_values[0].data:
#     if token[0] == 'Ġ':
#         labels_arr.append(token[1:])
#     else:
#         labels_arr.append(token)
print(labels_arr)


In [None]:
# Widen the figure with the number of rows in the linkage matrix so labels stay legible
fig_width = len(Z) + 20
fig_height = 15
fig = plt.figure(figsize=(fig_width, fig_height))
dn = dendrogram(Z, labels=labels_arr)
plt.show()

## Benchmarking

In [None]:
# Settings for the perturbation benchmark; both values are passed to SequentialPerturbation below
sort_order = 'positive'  # NOTE(review): presumably orders tokens by positive attribution — confirm against shap docs
perturbation = 'keep'  # NOTE(review): presumably keeps (rather than removes) the ranked tokens — confirm against shap docs

In [None]:
from shap import benchmark

In [None]:
# Build the sequential-perturbation benchmark from the explainer's own model and masker
sper = benchmark.perturbation.SequentialPerturbation(
    explainer.model,
    explainer.masker,
    sort_order,
    perturbation,
)
# Score the attributions on the same single-input batch that was explained, then plot the curve
benchmark_inputs = [decoded]
xs, ys, auc = sper.model_score(shap_values, benchmark_inputs)
sper.plot(xs, ys, auc)