In [1]:
import json
import os

import nlpaug.augmenter.word as naw
import pandas as pd
import tensorflow as tf

from quantus_nlp.data.pre_processing import remove_punctuation
from quantus_nlp.models.bert import pre_process_model
from quantus_nlp.xai import explain_lime
from quantus_nlp.xai.metric import relative_input_stability

In [2]:
# Sample headline used throughout the notebook; the raw string is kept under its own name
# and `text` holds the version returned by remove_punctuation.
raw_text = 'CHARLOTTE, N.C. (Sports Network) - Carolina Panthers  running back Stephen Davis will miss the remainder of the  season after being placed on injured reserve Saturday.'

text = remove_punctuation(raw_text)
text

'CHARLOTTE N C Sports Network Carolina Panthers running back Stephen Davis will miss the remainder of the season after being placed on injured reserve Saturday '

In [3]:
# Root of the quantus-nlp checkout. Set the QUANTUS_NLP_ROOT environment variable to run the
# notebook on another machine; the default is the path the notebook was originally written with.
PROJECT_ROOT = os.environ.get(
    "QUANTUS_NLP_ROOT",
    "/Users/artemsereda/Documents/PycharmProjects/quantus-nlp",
)

pm = pre_process_model()
transformer = tf.saved_model.load(os.path.join(PROJECT_ROOT, "model", "encoder"))

# Keep the raw file bytes under their own name so `metadata` only ever refers to the parsed JSON.
metadata_bytes = tf.io.read_file(os.path.join(PROJECT_ROOT, "dataset", "metadata.json")).numpy()
metadata = json.loads(metadata_bytes)

## LIME

In [4]:
# Run LIME on the cleaned sentence. The two returned values are unpacked into `t` and `s`,
# which the following cell tabulates as the 'tokens' and 'salience' columns.
class_names = metadata["class_names"]
t, s = explain_lime(pre_process_model=pm, transformer=transformer, class_names=class_names, example=text)

In [5]:
# Tabulate the LIME result: one row per token with its salience score.
df = pd.DataFrame({'tokens': t, 'salience': s})

# The sorted view gets its own name. Binding it to `pd` (as before) replaced the pandas module
# alias with a DataFrame, which broke every later `pd.` call, including re-running this cell.
df_sorted = df.sort_values(by='salience', ascending=False)

# Display the frame in original token order (display `df_sorted` for the ranking by salience).
df

Unnamed: 0,tokens,salience
0,CHARLOTTE,-0.006918
1,N,0.020776
2,C,0.070118
3,Sports,0.086014
4,Network,0.024156
5,Carolina,0.066801
6,Panthers,0.194583
7,running,0.018657
8,back,0.018698
9,Stephen,0.027212


### Relative Input Stability, as defined in [arXiv:2203.06877](https://arxiv.org/pdf/2203.06877.pdf)

In [6]:
# Build a perturbed copy of `text` with nlpaug's spelling augmenter, then strip punctuation the
# same way as for the original sentence.
# NOTE(review): no random seed is set in this notebook, so the perturbed sentence (and the
# stability score computed from it) will presumably differ between runs — confirm.
aug = naw.SpellingAug()
augmented_batch = aug.augment(text, n=1)
augmented_text = remove_punctuation(augmented_batch[0])
augmented_text

'CHARLOTTE N C Sports Netwok Carolina Panthers running back Sthepen Davis will miss tthe remainder OK then seaon after being placed on injured reserv Satday'

In [7]:
# Relative Input Stability between the original sentence and its perturbed copy.
# Arguments shared by both LIME runs.
lime_args = (pm, transformer, metadata['class_names'])

# Only the second value returned by explain_lime is kept for each sentence.
_, e = explain_lime(*lime_args, text)
_, es = explain_lime(*lime_args, augmented_text)

# `input_word_ids` of each sentence (first row of the batch), cut to at most as many entries
# as there are values in the corresponding explanation.
x = pm([text])['input_word_ids'].numpy()[0][:len(e)]
xs = pm([augmented_text])['input_word_ids'].numpy()[0][:len(es)]

ris = relative_input_stability(x, xs, e, es)
ris

0.8010517222747249