In [1]:
import numpy as np
import torch

from lime.lime_text import LimeTextExplainer
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
roberta_pipe = pipeline(
    "sentiment-analysis",
    model="siebert/sentiment-roberta-large-english",
    tokenizer="siebert/sentiment-roberta-large-english",
    top_k=1,
    device=device
)

In [3]:
sample_text = "I really liked the Oppenheimer movie and found it truly entertaining and full of substance."
np.set_printoptions(suppress = True,
   formatter = {'float_kind':'{:f}'.format}, precision = 2)

In [4]:
def predict_prob(texts):
    preds = roberta_pipe(texts)
    preds = np.array([[label[0]['score'], 1 - label[0]['score']] if label[0]['label'] == 'NEGATIVE'
                      else [1 - label[0]['score'], label[0]['score']] for label in preds])
    return preds

In [5]:
explainer = LimeTextExplainer(class_names=['NEGATIVE', 'POSITIVE'])
exp = explainer.explain_instance(text_instance=sample_text,
                                 classifier_fn=predict_prob)

In [6]:
original_prediction = predict_prob(sample_text)
print(original_prediction)

[[0.001083 0.998917]]


In [7]:
print(np.array(exp.as_list()))

[['entertaining' '0.02324313634539364']
 ['liked' '0.021720990278725042']
 ['and' '0.018245235697928647']
 ['truly' '0.017765791161630226']
 ['substance' '0.014841506222376703']
 ['Oppenheimer' '-0.010981793436361903']
 ['of' '-0.009783568460385427']
 ['found' '0.006399394452398886']
 ['movie' '0.006374361410402638']
 ['it' '0.004353281434411829']]


In [8]:
modified_text = "I found the Oppenheimer movie very slow, boring and veering on being too scientific."

In [9]:
new_prediction = predict_prob(modified_text)
print(new_prediction)

[[0.999501 0.000499]]


In [10]:
exp = explainer.explain_instance(text_instance=modified_text,
                                 classifier_fn=predict_prob)
print(np.array(exp.as_list()))

[['boring' '-0.14367020927227456']
 ['slow' '-0.13916208713739314']
 ['too' '-0.08752756409615216']
 ['veering' '-0.05178532169530207']
 ['and' '0.020763893842961647']
 ['Oppenheimer' '-0.015812617921760013']
 ['scientific' '-0.009930259179443747']
 ['being' '-0.009781979229833508']
 ['very' '0.008015557326137967']
 ['on' '0.007369245549495149']]
