In [30]:
import math
from lime.lime_text import LimeTextExplainer
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re


class LimeExplainer():
    """Explain a Hugging Face sequence classifier's predictions with LIME.

    LIME (non-bag-of-words mode) splits the input into word tokens, replaces
    random subsets of them with ``[MASK]`` and scores each perturbed string
    with the wrapped model via ``predictor``. For NLI models the premise and
    hypothesis are passed as one string separated by ``[SEP]``.
    """

    # Label names used by default when ``nli=True``. The order is assumed to
    # match the checkpoint's output indices -- TODO confirm against its config.
    NLI_CLASS_NAMES = ['contradiction', 'entailment', 'neutral']

    def __init__(self, model_name, nli=False, device=None, class_names=None):
        """Load the tokenizer and classification model.

        Args:
            model_name: Hugging Face hub id or local path of a sequence
                classification checkpoint.
            nli: If True, inputs look like ``"<premise> [SEP] <hypothesis>"``
                and are tokenized as a sentence pair.
            device: Torch device string. Defaults to ``'cuda'`` when CUDA is
                available, otherwise ``'cpu'``.
            class_names: Label names shown in LIME output. Defaults to
                ``NLI_CLASS_NAMES`` when ``nli`` is True, otherwise ``None``
                (LIME then uses the numeric label indices).
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(self.device)
        self.model.eval()  # inference only: make sure dropout is disabled
        self.baseline_token = '[MASK]'
        self.nli = nli
        if class_names is None and nli:
            class_names = list(self.NLI_CLASS_NAMES)
        self.class_names = class_names
        # Batch size used by ``predictor``; ``explain`` overwrites it.
        self.bs = 32

    def predictor(self, texts):
        """Return softmax class probabilities for a list of (masked) strings.

        This is the ``classifier_fn`` handed to LIME.

        Args:
            texts: Sequence of strings produced by LIME's perturbation.

        Returns:
            numpy array with one row of class probabilities per input text.

        Raises:
            ValueError: if ``nli`` is True and a text contains no ``[SEP]``.
        """
        # LIME sees "[SEP]" as the word "SEP" between two bracket separators,
        # so masking it yields "[[MASK]]"; restore the separator so the
        # premise/hypothesis split survives every perturbation.
        texts = [t.replace('[[MASK]]', '[SEP]') for t in texts]
        if self.nli:
            premises, hypotheses = [], []
            for t in texts:
                # Split on the first separator only, so a stray "[SEP]" in the
                # hypothesis does not break tuple unpacking.
                premise, sep, hypothesis = t.partition('[SEP]')
                if not sep:
                    raise ValueError(f"NLI input must contain '[SEP]': {t!r}")
                premises.append(premise)
                hypotheses.append(hypothesis)

        all_probs = []
        # No gradients are needed; avoid building an autograd graph per batch.
        with torch.no_grad():
            for start in range(0, len(texts), self.bs):
                batch_idx = slice(start, start + self.bs)
                if self.nli:
                    encoded = self.tokenizer(premises[batch_idx],
                                             text_pair=hypotheses[batch_idx],
                                             return_tensors="pt",
                                             padding=True)
                else:
                    encoded = self.tokenizer(texts[batch_idx],
                                             return_tensors="pt",
                                             padding=True)
                outputs = self.model(**encoded.to(self.device))
                all_probs.append(torch.softmax(outputs.logits, -1).cpu())
        return torch.cat(all_probs, 0).numpy()

    def explain(self, str_to_predict, batch_size=32, mask_n=5000, output_indices=None,
                num_features=20, top_labels=2):
        """Run LIME on a single input string.

        Args:
            str_to_predict: Text to explain (``"premise [SEP] hypothesis"`` for NLI).
            batch_size: Batch size used when scoring perturbed samples.
            mask_n: Number of perturbed samples LIME draws (``num_samples``).
            output_indices: Label index whose explanation is returned in
                ``explanations``. Defaults to the predicted label.
            num_features: Maximum number of tokens kept per explanation.
            top_labels: Number of highest-probability labels LIME explains.

        Returns:
            Tuple ``(exp, explanations, tokens, pred)``:
            - ``exp``: the LIME explanation object.
            - ``explanations``: dict mapping ``(token_index,)`` to weight for the
              chosen label.
            - ``tokens``: the word tokens those indices refer to, with ``'SEP'``
              shown as ``'[SEP]'``.
            - ``pred``: the predicted label index.
        """
        self.bs = batch_size

        explainer = LimeTextExplainer(class_names=self.class_names,
                                      bow=False,
                                      mask_string=self.baseline_token)
        # With top_labels set, LIME ignores ``labels``; when a specific label is
        # requested, explain every class so that label is guaranteed to exist.
        if output_indices is not None:
            top_labels = max(top_labels, self.model.config.num_labels)
        exp = explainer.explain_instance(str_to_predict,
                                         self.predictor,
                                         top_labels=top_labels,
                                         num_features=num_features,
                                         num_samples=mask_n)
        pred = int(exp.predict_proba.argmax())
        # Same word split LIME uses by default (r'\W+'). Drop the empty strings
        # re.split produces at leading/trailing separators so positions line up
        # with LIME's feature indices.
        tokens = [tok for tok in re.split(r'\W+', str_to_predict) if tok]
        tokens = [tok if tok != 'SEP' else '[SEP]' for tok in tokens]
        if output_indices is None:
            output_indices = pred
        explanations = {(k,): v for k, v in exp.as_map()[output_indices]}
        return exp, explanations, tokens, pred


In [31]:
# Load the NLI explainer (checkpoint name suggests SNLI fine-tuning); with
# nli=True inputs are written as "premise [SEP] hypothesis".
lime = LimeExplainer(model_name='textattack/bert-base-uncased-snli', nli=True)

In [32]:
# Explain one premise/hypothesis pair. LIME samples perturbations and no seed
# is set here, so the weights may vary slightly between runs.
exp, explanation, tokens, pred = lime.explain(
    str_to_predict="A soccer game with multiple males playing. [SEP] Some men are playing a sport."
)

In [34]:
# LIME weights for the predicted label (`pred`), keyed by (token_index,);
# look up the corresponding word with `tokens[i]`.
explanation

{(5,): 0.21130668071150888,
 (13,): 0.16711675887591976,
 (1,): 0.14968719404536854,
 (8,): -0.06588197184775399,
 (2,): 0.06479727033633015,
 (9,): -0.05534492112108581,
 (6,): 0.05176177471145916,
 (11,): -0.04457048769407898,
 (12,): 0.028630411137390396,
 (0,): -0.02518382129868726,
 (4,): -0.019109104346571495,
 (3,): 0.018757777425323866,
 (10,): -0.010720534914372328,
 (7,): -0.0046395738476809915}

In [None]:
# Rebind `lime` to a single-sentence (non-NLI) classifier; later cells that
# call `lime.explain` use this model.
lime = LimeExplainer(model_name='textattack/bert-base-uncased-SST-2', nli=False)

In [18]:
# Bind separate names so the NLI explanation held in `exp` (saved as
# 'exp_nli.html' further down) is not overwritten on a top-to-bottom re-run.
sst_exp, sst_explanation, sst_tokens, sst_pred = lime.explain(
    "This film doesn't care about cleverness, wit or any other kind of intelligent humor"
)

In [33]:
# Write the explanation currently bound to `exp` to an HTML file. The file name
# assumes `exp` still holds the NLI explanation; confirm no cell executed
# before this one has rebound `exp` to a different model's explanation.
exp.save_to_file('exp_nli.html')