In [None]:
import time
import spacy
import string
import numpy as np
import itertools
import pickle as pkl
from tqdm import tqdm

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from alibi.explainers import AnchorText
from alibi.datasets import fetch_movie_sentiment
from alibi.utils.download import spacy_model
from alibi.utils.lang_model import DistilbertBaseUncased, BertBaseUncased, RobertaBase

%load_ext autoreload
%autoreload 2

### Load movie review dataset

The `fetch_movie_sentiment` function returns a `Bunch` object containing the features, the targets and the target names for the dataset.

In [None]:
# Fetch the movie sentiment dataset and display the fields exposed by the returned object.
movies = fetch_movie_sentiment()
movies.keys()

In [None]:
# Unpack the features, the targets and the target names from the dataset object.
data, labels, target_names = movies.data, movies.target, movies.target_names

In [None]:
# Hold out 20% of the data for testing, then split 10% of the remaining
# training portion off as a validation set (both splits use a fixed seed).
train, test, train_labels, test_labels = train_test_split(
    data, labels, test_size=.2, random_state=42
)
train, val, train_labels, val_labels = train_test_split(
    train, train_labels, test_size=.1, random_state=42
)

# Convert each label collection to a NumPy array.
train_labels, test_labels, val_labels = (
    np.array(split_labels)
    for split_labels in (train_labels, test_labels, val_labels)
)

### Apply CountVectorizer to training set

In [None]:
# Bag-of-words vectorizer; min_df=1 keeps every term that occurs in at least one document.
vectorizer = CountVectorizer(min_df=1)
# Learn the vocabulary from the training texts only.
vectorizer.fit(train)

### Fit model

In [None]:
# Seed NumPy's global RNG before fitting.
np.random.seed(0)
clf = LogisticRegression(solver='liblinear')
# Fit on the vectorized training texts.
clf.fit(vectorizer.transform(train), train_labels)

### Define prediction function

In [None]:
def predict_fn(x):
    """Vectorize the given texts with the fitted vectorizer and return the classifier's predictions."""
    return clf.predict(vectorizer.transform(x))

### Make predictions on train, validation and test sets

In [None]:
# Predict labels for each split.
preds_train, preds_val, preds_test = (
    predict_fn(split) for split in (train, val, test)
)

# Report the accuracy of each split against its ground-truth labels.
for report_label, y_true, y_pred in (
    ('Train accuracy', train_labels, preds_train),
    ('Validation accuracy', val_labels, preds_val),
    ('Test accuracy', test_labels, preds_test),
):
    print(report_label, accuracy_score(y_true, y_pred))

### Load spaCy model

English multi-task CNN trained on OntoNotes, with GloVe vectors trained on Common Crawl. Assigns word vectors, context-specific token vectors, POS tags, dependency parse and named entities.

In [None]:
# Name of the spaCy pipeline to load.
model = 'en_core_web_md'
# NOTE(review): presumably downloads the pipeline if it is not installed yet —
# confirm against alibi.utils.download.spacy_model.
spacy_model(model=model)
nlp = spacy.load(model)

### Load transformers

In [None]:
# Language-model wrapper later passed to AnchorText as `language_model`.
lang_model = DistilbertBaseUncased()

### Random indices

In [None]:
# Pick 10 distinct positions in the test set (sampling without replacement);
# these are the instances explained in the benchmark.
# NOTE(review): this draws from NumPy's global RNG, which is not reseeded in
# this cell, so the selection depends on the RNG state left by earlier cells.
indices = np.random.choice(len(test), size=10, replace=False)
print(indices)

### Define Anchor

In [None]:
# Build an AnchorText explainer around the classifier's prediction function.
# NOTE(review): run_configuration constructs its own explainer for every
# explained instance, so this instance is not the one used by the benchmark.
explainer = AnchorText(nlp=nlp, language_model=lang_model, predictor=predict_fn)

In [None]:
def build_explanation(explanation, pred=None, alternative=None) -> str:
    """Format an anchor explanation as a human-readable text report.

    Parameters
    ----------
    explanation
        Object exposing ``anchor`` (an iterable of strings), ``precision``
        (a number) and ``raw['examples']`` (a sequence whose last element is a
        dict with ``'covered_true'`` and ``'covered_false'`` lists of strings).
    pred
        Name of the class predicted for the explained instance. When omitted,
        a module-level variable called ``pred`` is used if one exists.
    alternative
        Name of the other class. When omitted, a module-level variable called
        ``alternative`` is used if one exists.

    Returns
    -------
    str
        The report: anchor, precision, then the examples where the anchor
        applies and the model predicts ``pred`` / ``alternative``.

    Raises
    ------
    NameError
        If ``pred`` or ``alternative`` is neither passed nor defined at module
        level.
    """
    # Fall back to notebook-level variables so that callers relying on
    # globals keep working, but fail with an explicit message otherwise.
    scope = globals()
    if pred is None:
        if 'pred' not in scope:
            raise NameError("'pred' was not passed and is not defined at module level")
        pred = scope['pred']
    if alternative is None:
        if 'alternative' not in scope:
            raise NameError("'alternative' was not passed and is not defined at module level")
        alternative = scope['alternative']

    examples = explanation.raw['examples']
    has_examples = len(examples) > 0

    parts = [
        'Anchor: %s\n' % ' AND '.join(explanation.anchor),
        'Precision: %.2f\n' % explanation.precision,
    ]

    # examples covered by the anchor where the prediction is unchanged
    parts.append('\n\nExamples where anchor applies and model predicts {}:\n'.format(pred))
    if has_examples:
        parts.append('\n'.join(examples[-1]['covered_true']))

    # examples covered by the anchor where the prediction flips
    parts.append('\n\nExamples where anchor applies and model predicts {}:\n'.format(alternative))
    if has_examples:
        parts.append('\n'.join(examples[-1]['covered_false']))

    parts.append('\n\n\n')
    return ''.join(parts)

In [None]:
# Parameter grids for the benchmark. Every key maps to a list of candidate
# values; run_configuration evaluates the Cartesian product of these lists.

# Grid for the similarity-based sampling method.
config_similarity = dict(
    threshold=[0.95],
    sample_proba=[0.5],
    punctuation=[string.punctuation],
    top_n=[10, 100, 500],
    sampling_method=["similarity"],
)

# Grid for the language-model-based sampling method.
config_language_model = dict(
    threshold=[0.95],
    sample_proba=[0.5],
    filling_method=['parallel'],
    punctuation=[string.punctuation],
    top_n=[10, 100, 500],
    prec_mask_templates=[0.1, 0.5, 1.0],
    sampling_method=["language_model"],
)


In [None]:
# Class names, indexed by predicted label inside run_configuration.
# Same value as `target_names` extracted earlier from the dataset.
class_names = movies.target_names

### Benchmarking

In [None]:
def run_configuration(tests: dict, file_name: str) -> list:
    """Time `AnchorText.explain` over every combination of a parameter grid.

    For each combination of the values in ``tests``, the instances selected by
    the module-level ``indices`` are explained one by one and the mean
    wall-clock time of the ``explain`` call is recorded.

    Parameters
    ----------
    tests
        Parameter grid mapping each keyword argument name to the list of
        values to try; the Cartesian product of the lists is evaluated.
    file_name
        Path of the pickle file the results are written to.

    Returns
    -------
    list
        One dict per combination, holding the parameter values plus an
        ``'elapsed_time'`` entry with the mean explanation time in seconds.
    """
    param_names = list(tests.keys())
    # materialise the grid so that the progress bar can display a total
    combinations = list(itertools.product(*tests.values()))
    configs = []

    for comb in tqdm(combinations):
        # keyword arguments forwarded to `explain`; the timing is kept
        # separate so that it is not passed to the explainer
        explain_kwargs = dict(zip(param_names, comb))
        elapsed_time = 0.0

        for index in tqdm(indices):
            text = test[index]

            # reseed and rebuild the explainer so every run starts from the
            # same state
            np.random.seed(0)
            explainer = AnchorText(nlp=nlp, language_model=lang_model, predictor=predict_fn)

            # time only the explanation itself
            start = time.time()
            explainer.explain(text, **explain_kwargs)
            elapsed_time += time.time() - start

        # store the parameters together with the mean time per instance
        config = dict(explain_kwargs)
        config['elapsed_time'] = elapsed_time / len(indices)
        configs.append(config)

    # write the results once everything succeeded; the context manager closes
    # the file even if pickling fails
    with open(file_name, 'wb') as f:
        pkl.dump(configs, f)

    return configs


In [None]:
# Benchmark both sampling methods; each call also pickles its results to the given file.
configs_sim = run_configuration(config_similarity, "stats_similarity.pkl")
configs_lm = run_configuration(config_language_model, "stats_lm.pkl")