In [1]:
import spacy
import string
import numpy as np
import itertools
from tqdm import tqdm

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from alibi.explainers import AnchorText
from alibi.datasets import fetch_movie_sentiment
from alibi.utils.download import spacy_model
from alibi.utils.lang_model import DistilbertBaseUncased, BertBaseUncased, RobertaBase

%load_ext autoreload
%autoreload 2



### Load movie review dataset

The `fetch_movie_sentiment` function returns a `Bunch` object containing the features, the targets and the target names for the dataset.

In [2]:
# Fetch the dataset as a Bunch holding the raw review texts (`data`), their
# labels (`target`) and the label names (`target_names`); show its keys.
movies = fetch_movie_sentiment()
movies.keys()

dict_keys(['data', 'target', 'target_names'])

In [3]:
# Unpack the Bunch: raw texts, labels and the human-readable label names.
data, labels, target_names = movies.data, movies.target, movies.target_names

In [4]:
# Hold out 20% for testing, then 10% of the remaining training part for validation.
train, test, train_labels, test_labels = train_test_split(
    data, labels, test_size=.2, random_state=42
)
train, val, train_labels, val_labels = train_test_split(
    train, train_labels, test_size=.1, random_state=42
)

# Store every label split as a NumPy array.
train_labels, test_labels, val_labels = (
    np.array(split_labels) for split_labels in (train_labels, test_labels, val_labels)
)

### Apply CountVectorizer to training set

In [5]:
# Learn the vocabulary from the training texts only (min_df=1 is the default,
# which is why the repr below shows no arguments).
vectorizer = CountVectorizer(min_df=1).fit(train)
vectorizer

CountVectorizer()

### Fit model

In [6]:
# Seed the global NumPy RNG before fitting; presumably so liblinear, which gets
# no explicit random_state here, is reproducible (TODO confirm).
np.random.seed(0)
clf = LogisticRegression(solver='liblinear').fit(vectorizer.transform(train), train_labels)
clf

LogisticRegression(solver='liblinear')

### Define prediction function

In [7]:
def predict_fn(x):
    """Classify raw review texts with the fitted bag-of-words model.

    Parameters
    ----------
    x : sequence of str
        Raw texts (e.g. a data split or a one-element list ``[text]``); they are
        vectorized with ``vectorizer`` and classified with ``clf``.

    Returns
    -------
    The output of ``clf.predict``: one predicted label per input text, usable as
    an index into the list of class names.
    """
    return clf.predict(vectorizer.transform(x))

### Make predictions on train, validation and test sets

In [8]:
# Predict every split first, then report the accuracy of each one.
preds_train, preds_val, preds_test = (predict_fn(split) for split in (train, val, test))

for split_name, y_true, y_pred in (
    ('Train', train_labels, preds_train),
    ('Validation', val_labels, preds_val),
    ('Test', test_labels, preds_test),
):
    print(split_name + ' accuracy', accuracy_score(y_true, y_pred))

Train accuracy 0.9801624284382905
Validation accuracy 0.7544910179640718
Test accuracy 0.7589841878294202


### Load spaCy model

English multi-task CNN trained on OntoNotes, with GloVe vectors trained on Common Crawl. Assigns word vectors, context-specific token vectors, POS tags, dependency parse and named entities.

In [9]:
# Name of the spaCy pipeline to use.
model = 'en_core_web_md'
# Helper from alibi.utils.download; presumably fetches the pipeline if it is
# not installed yet so that spacy.load below succeeds (TODO confirm).
spacy_model(model=model)
# NOTE(review): `nlp` is not referenced anywhere else in this notebook; the
# language-model explainers below do not receive it.
nlp = spacy.load(model)

### Load transformer language models

In [10]:
# Instantiate each transformer language model once; the keys are the names
# listed under "model_name" in the test grid.
models = dict(
    DistilbertBaseUncased=DistilbertBaseUncased(),
    BertBaseUncased=BertBaseUncased(),
    RobertaBase=RobertaBase(),
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


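### Qualitative tests

For each example text, run `AnchorText` with every combination of the settings in `tests` and append the resulting explanations to `qualitative_tests.txt`.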

In [11]:
# Short example reviews to explain.
texts = [
    "This is a good movie.",
    "The movie was excellent.",
    "The best story ever.",
    "Worst movie ever.",
    "Everyone played terrible.",
    "The book was horrible.",
]

# Grid of explainer settings: every combination of these values is tried for
# every text. The key order fixes the order in which combinations are produced.
tests = dict(
    model_name=["DistilbertBaseUncased", "BertBaseUncased", "RobertaBase"],
    filling_method=["parallel"],  # 'autoregressive' is another option, left out of this grid
    sample_proba=[1.0, 0.5],
    mask_templates=[0.1, 0.2, 0.5],
    temperature=[0.1, 1.0, 10., 100.],
    punctuation=["", string.punctuation],
    stopwords=[[], ["in", "a", "the", "this", "those", "but"]],
)

class_names = movies.target_names

In [12]:
def build_explanation(explanation, config: "dict | None" = None) -> str:
    """Render an AnchorText explanation and its run configuration as plain text.

    Parameters
    ----------
    explanation
        Object returned by ``AnchorText.explain``. Only ``anchor``,
        ``precision`` and ``raw['examples']`` are read.
    config
        Run configuration. Must contain at least ``'model_name'``, ``'text'``,
        ``'pred'`` and ``'alternative'``; the whole dict is echoed in the header.

    Returns
    -------
    str
        A multi-line report ending in three newlines, meant to be appended to
        a log file.

    Raises
    ------
    KeyError
        If a required ``config`` key is missing (including when ``config`` is
        omitted).
    """
    # Avoid a shared mutable default; an omitted config still fails on the
    # first required key, as before.
    if config is None:
        config = {}

    s = config['model_name'] + '\n=======================\n'
    s += 'Config: ' + str(config) + '\n\n'
    s += 'Text: ' + config['text'] + '\n\n'

    if len(explanation.anchor) == 0:
        s += 'WARNING !!! EMPTY ANCHOR!\n'

    s += 'Anchor: %s\n' % (' AND '.join(explanation.anchor))
    s += 'Precision: %.2f\n' % explanation.precision

    # Only the last entry of raw['examples'] is reported.
    examples = explanation.raw['examples']
    last = examples[-1] if len(examples) else None

    # Anchor-covered samples listed under the original prediction.
    s += '\n\nExamples where anchor applies and model predicts %s:\n' % config['pred']
    if last is not None:
        s += '\n'.join(last['covered_true'])

    # Anchor-covered samples listed under the alternative class.
    s += '\n\nExamples where anchor applies and model predicts %s:\n' % config['alternative']
    if last is not None:
        s += '\n'.join(last['covered_false'])

    s += '\n\n\n'
    return s

In [None]:
output_file = "qualitative_tests.txt"

# Start from an empty report so re-running this cell does not append
# duplicates to the results of a previous run.
open(output_file, "w").close()

# Number of configurations per text, so tqdm can show a real progress bar.
n_combinations = int(np.prod([len(values) for values in tests.values()]))

for text in texts:
    # Predict once; labels are binary, so the alternative class is 1 - prediction.
    pred_idx = predict_fn([text])[0]
    pred = class_names[pred_idx]
    alternative = class_names[1 - pred_idx]

    # Every combination of the values in the test grid.
    combinations = itertools.product(*tests.values())

    for comb in tqdm(combinations, total=n_combinations):
        config = dict(zip(tests.keys(), comb))
        config['pred'] = pred
        config['alternative'] = alternative
        config['text'] = text

        # Reseed before building the explainer so each configuration is reproducible.
        np.random.seed(0)
        explainer = AnchorText(language_model=models[config['model_name']], predictor=predict_fn)

        # Every grid parameter is forwarded. NOTE(review): temperature was varied in
        # `tests` but previously never passed; confirm explain() accepts it.
        explanation = explainer.explain(
            text,
            threshold=0.95,
            sampling_method="language_model",
            top_n=20,
            filling_method=config['filling_method'],
            sample_proba=config['sample_proba'],
            mask_templates=config['mask_templates'],
            temperature=config['temperature'],
            stopwords=config['stopwords'],
            punctuation=config['punctuation'],
        )

        # Convert the explanation to a text report.
        exp = build_explanation(explanation, config=config)

        # Append after every configuration so partial results survive an interrupted run.
        with open(output_file, "a") as fout:
            fout.write(exp)

288it [33:00,  6.88s/it]
272it [28:13, 17.37s/it]