# eli5 text highlighting with a custom vectorizer

The [eli5 library](https://eli5.readthedocs.io/en/latest/) visualizes the weights and predictions of text classifiers (and other machine learning models). It supports many scikit-learn models out-of-box, but not pipelines with a custom vectorizer. Custom vectorizers are necessary, for example, when one wants to lemmatize words (using an external library such as libvoikko for the Finnish language).

This notebook shows how to get eli5 to show explanations as highlighted text even when using custom vectorizers.

In [1]:
import random
import re
import eli5
import numpy as np
import pandas as pd
from eli5.base import DocWeightedSpans
from eli5.lime import TextExplainer
from models.sif import SIFTransformer
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import classification_report
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MaxAbsScaler, LabelEncoder, FunctionTransformer
from sklearn.svm import LinearSVC
from voikko import libvoikko

## A custom vectorizer for lemmatizing Finnish

The following cell implements a custom scikit-learn Vectorizer that a) uses libvoikko for lemmatization and b) implements the eli5 interface required for text highlighting.

In [2]:
class VoikkoVectorizer(TfidfVectorizer):
    """TF-IDF vectorizer that lemmatizes tokens with libvoikko.

    Builds on scikit-learn's TfidfVectorizer and adds the
    ``get_doc_weighted_spans`` method that eli5 requires for highlighting
    text on a custom vectorizer.

    Parameters
    ----------
    voikko : default = None
        An instance of libvoikko.Voikko object. If None, tokens are only
        lower-cased and not lemmatized.

    The remaining parameters are forwarded unchanged to TfidfVectorizer.
    ``tokenizer``, ``analyzer`` and ``token_pattern`` are fixed by this
    class and are therefore not accepted.
    """

    def __init__(self, *, input='content', encoding='utf-8',
                 decode_error='strict', strip_accents=None, lowercase=True,
                 preprocessor=None, stop_words=None,
                 ngram_range=(1, 1), max_df=1.0, min_df=1,
                 max_features=None, vocabulary=None, binary=False,
                 dtype=np.float64, norm='l2', use_idf=True, smooth_idf=True,
                 sublinear_tf=False, voikko=None):
        # Has to be assigned first: the tokenizer used below reads it.
        self.voikko = voikko

        if stop_words:
            # Push the stop words through the document tokenizer so they are
            # compared in the same lemmatized, lower-cased form as the tokens.
            lemmatized_stop_words = self._simple_tokenizer(' '.join(stop_words))
            stop_words = set(lemmatized_stop_words)

        super().__init__(
            input=input,
            encoding=encoding,
            decode_error=decode_error,
            strip_accents=strip_accents,
            lowercase=lowercase,
            preprocessor=preprocessor,
            tokenizer=self._simple_tokenizer,
            analyzer='word',
            stop_words=stop_words,
            token_pattern=None,
            ngram_range=ngram_range,
            max_df=max_df,
            min_df=min_df,
            max_features=max_features,
            vocabulary=vocabulary,
            binary=binary,
            dtype=dtype,
            norm=norm,
            use_idf=use_idf,
            smooth_idf=smooth_idf,
            sublinear_tf=sublinear_tf,
        )

    def get_doc_weighted_spans(self, doc, feature_weights, feature_fn):
        """Locate the weighted features of ``doc`` for eli5's text highlighting.

        Adapted from eli5.sklearn.text.

        Returns a pair ``(found_features, doc_weighted_spans)``:
        ``found_features`` maps the ``(group, idx)`` key of every feature
        that occurs in the document to its weight, and the DocWeightedSpans
        object holds the preprocessed document together with
        ``(feature, spans, weight)`` triples. The spans are character
        offsets into the preprocessed document.
        """
        text = self.build_preprocessor()(self.decode(doc))
        weights_by_feature = _get_feature_weights_dict(feature_weights, feature_fn)

        weighted_spans = []
        matched = {}
        for span_list, name in self._span_analyzer(text):
            if name in weights_by_feature:
                weight, key = weights_by_feature[name]
                weighted_spans.append((name, span_list, weight))
                matched[key] = weight

        highlighted = DocWeightedSpans(
            document=text,
            spans=weighted_spans,
            preserve_density=self.analyzer.startswith('char'),
        )
        return matched, highlighted

    def lemmatize(self, text):
        """Lemmatize every whitespace-separated word of ``text``.

        The words are joined back with single spaces.
        """
        words = text.split()
        return ' '.join(map(self.lemmatize_token, words))

    def lemmatize_token(self, token):
        """Return the base form of ``token`` according to libvoikko.

        The token is returned unchanged when no Voikko instance was given
        or when the analysis yields nothing. Otherwise the BASEFORM of the
        first analysis is used, if it has one.
        """
        if self.voikko is None:
            return token

        analyses = self.voikko.analyze(token)
        if not analyses:
            return token

        return analyses[0].get('BASEFORM', token)

    def _span_tokenizer(self, doc):
        """Tokenize ``doc`` into lemmatized, lower-cased tokens with positions.

        Returns a list of ``((start, end), token)`` pairs where the offsets
        point at the matched word in ``doc``. Only runs of at least two
        word characters count as tokens.
        """
        return [
            (match.span(), self.lemmatize_token(match.group()).lower())
            for match in re.finditer(r'\b\w\w+\b', doc)
        ]

    def _simple_tokenizer(self, doc):
        """Same tokens as _span_tokenizer, without the position information."""
        return [lemma for _span, lemma in self._span_tokenizer(doc)]

    def _span_analyzer(self, doc):
        """Turn ``doc`` into ``(spans, feature)`` pairs for all configured n-grams."""
        assert self.analyzer == 'word'

        return self._span_word_ngrams(self._span_tokenizer(doc))

    def _span_word_ngrams(self, tokens):
        """Drop stop words and combine ``(span, token)`` pairs into n-grams.

        Returns a list of ``(spans, ngram)`` pairs, where ``spans`` lists the
        span of each token in the n-gram and ``ngram`` is those tokens
        joined with single spaces.
        """
        stop_words = self.stop_words
        if stop_words is not None:
            tokens = [(span, word) for span, word in tokens
                      if word not in stop_words]

        min_n, max_n = self.ngram_range
        if max_n == 1:
            # Unigrams only: just wrap every span in a one-element list.
            return [([span], word) for span, word in tokens]

        n_tokens = len(tokens)
        # An n-gram cannot be longer than the document itself.
        longest = min(max_n, n_tokens)
        ngrams = []
        for size in range(min_n, longest + 1):
            for start in range(n_tokens - size + 1):
                window = tokens[start:start + size]
                ngrams.append((
                    [span for span, _ in window],
                    ' '.join(word for _, word in window),
                ))

        return ngrams

def _get_feature_weights_dict(feature_weights,  # type: FeatureWeights
                              feature_fn        # type: Optional[Callable[[str], str]]
                              ):
    # type: (...) -> Dict[str, Tuple[float, Tuple[str, int]]]
    """Build a {feat_name: (weight, (group, idx))} lookup table.

    Walks the 'pos' group of ``feature_weights`` first and the 'neg' group
    second. When the same feature name turns up more than once, the entry
    seen last wins.

    Copied from eli5.sklearn.text.
    """
    weights_by_name = {}
    for group in ('pos', 'neg'):
        for idx, fw in enumerate(getattr(feature_weights, group)):
            for feat_name in _get_features(fw.feature, feature_fn):
                # (group, idx) is a unique feature identifier, e.g. ('pos', 2)
                weights_by_name[feat_name] = (fw.weight, (group, idx))
    return weights_by_name

def _get_features(feature, feature_fn=None):
    """Return the feature name(s) behind ``feature`` as a list.

    ``feature`` is either a single name or a list of dicts that each carry
    a 'name' key. If ``feature_fn`` is given, every name is passed through
    it and falsy results are dropped.

    Copied from eli5.sklearn.text.
    """
    if isinstance(feature, list):
        names = [entry['name'] for entry in feature]
    else:
        names = [feature]

    if not feature_fn:
        return names

    return [mapped for mapped in (feature_fn(name) for name in names) if mapped]

def replace_num_tokens(text):
    """Replace every word that contains at least one digit with "<num>"."""
    digit_word = re.compile(r'\b\w*\d\w*\b', re.IGNORECASE)
    return digit_word.sub('<num>', text)

## Loading data

Let's work on the [eduskunta-vkk](https://github.com/aajanki/eduskunta-vkk) dataset.

In [3]:
# Class labels that load_documents() removes from every data split.
# presumably the ministries with too few samples to learn from — TODO confirm
small_classes = [
    'ulkomaankauppa- ja kehitysministeri',
    'puolustusministeri',
    'pääministeri',
    'eurooppa-, kulttuuri- ja urheiluministeri'
]

# Three-letter abbreviation for each full class label (ministry title).
# NOTE(review): not referenced anywhere in the visible part of this notebook.
short_names = {
    'perhe- ja peruspalveluministeri': 'per',
    'maatalous- ja ympäristöministeri': 'maa',
    'sisäministeri': 'sis',
    'oikeus- ja työministeri': 'oik',
    'opetus- ja kulttuuriministeri': 'ope',
    'valtiovarainministeri': 'val',
    'liikenne- ja viestintäministeri': 'lii',
    'sosiaali- ja terveysministeri': 'sos',
    'elinkeinoministeri': 'eli',
    'ulkoministeri': 'ulk',
    'kunta- ja uudistusministeri': 'kun',
    'eurooppa-, kulttuuri- ja urheiluministeri': 'eur',
    'pääministeri': 'pää',
    'puolustusministeri': 'puo',
    'ulkomaankauppa- ja kehitysministeri': 'uke',
}

def load_documents(filename):
    """Read one data split from a CSV file.

    The 'ministry' column is renamed to 'class', rows whose class is listed
    in ``small_classes`` are dropped, and the index is reset. The previous
    index is kept as an 'index' column.
    """
    raw = pd.read_csv(filename, header=0)
    renamed = raw.rename(columns={'ministry': 'class'})
    keep = ~renamed['class'].isin(small_classes)
    return renamed[keep].reset_index()

def load_data():
    """Load the train, dev and test splits, in that order, as a 3-tuple."""
    return (
        load_documents('data/vkk/train.csv.bz2'),
        load_documents('data/vkk/dev.csv.bz2'),
        load_documents('data/vkk/test.csv.bz2'),
    )

In [4]:
# Load the three splits and print their sizes as a sanity check. The class
# count is taken from the training split only.
train, dev, test = load_data()
print(f'Number of classes: {len(train["class"].unique())}')
print(f'Number of train samples: {len(train)}')
print(f'Number of dev samples: {len(dev)}')
print(f'Number of test samples: {len(test)}')

Number of classes: 11
Number of train samples: 46130
Number of dev samples: 2826
Number of test samples: 2819


## Explaining a bag-of-words classifier

As the first example, let's train a basic scikit-learn bag-of-words classifier (a linear SVM, to be exact) and examine its weights and predictions.

In [5]:
# libvoikko instance for Finnish ('fi'); handed to VoikkoVectorizer for lemmatization.
voikko = libvoikko.Voikko('fi')

In [6]:
# Finnish stop words. VoikkoVectorizer lemmatizes this list in its constructor,
# so the entries are matched against lemmatized tokens.
stop_words_fi = [
    'ei', 'että', 'he', 'hän', 'ja', 'joissa', 'joka', 'jos', 'koska', 'kuin',
    'kuka', 'kun', 'me', 'mikä', 'minä', 'myös', 'ne', 'nuo', 'nämä', 'olla',
    'se', 'sinä', 'tai', 'te', 'tuo', 'tämä', 'vai',
]

# Encode the class names as integers. enc is reused later to map predictions
# back to class names.
enc = LabelEncoder()
y_encoded = enc.fit_transform(train['class'])

# Lemmatizing TF-IDF features over unigrams and bigrams. Words containing
# digits are replaced with "<num>" before tokenization. Per TfidfVectorizer's
# min_df/max_df semantics, terms found in fewer than 2 documents or in more
# than 10 % of the documents are ignored.
vec = VoikkoVectorizer(voikko=voikko, 
                       preprocessor=replace_num_tokens,
                       ngram_range=(1, 2),
                       min_df=2, max_df=0.1,
                       stop_words=stop_words_fi)

# One-vs-rest linear SVM with hinge loss.
clf = LinearSVC(C=0.1, loss='hinge', intercept_scaling=5.0,
                max_iter=100000, multi_class='ovr')
scaler = MaxAbsScaler()
pipe = make_pipeline(vec, scaler, clf)
# The trailing semicolon keeps the notebook from echoing the fitted pipeline.
pipe.fit(train['sentence'], y_encoded);

Checking that the vectorizer really lemmatizes Finnish words:

In [7]:
# Demo sentence with inflected word forms.
text = 'Ajoimme punaisella autolla aamulla.'

# lemmatize() splits on whitespace only, so the last word reaches libvoikko
# with the full stop attached ('aamulla.') and comes back unchanged in the
# recorded output below.
vec.lemmatize(text)

'ajaa punainen auto aamulla.'

Performance on the development set:

In [8]:
# The pipeline predicts integer-encoded labels. Map them back to class names
# so they can be compared with the string labels of the dev split.
y_dev_true = dev['class']
y_dev_pred = enc.inverse_transform(pipe.predict(dev['sentence']))

print(classification_report(y_dev_true, y_dev_pred))

                                  precision    recall  f1-score   support

              elinkeinoministeri       0.61      0.44      0.51        99
     kunta- ja uudistusministeri       0.80      0.58      0.67        48
 liikenne- ja viestintäministeri       0.81      0.79      0.80       216
maatalous- ja ympäristöministeri       0.78      0.83      0.80       456
         oikeus- ja työministeri       0.71      0.71      0.71       395
   opetus- ja kulttuuriministeri       0.83      0.82      0.83       349
 perhe- ja peruspalveluministeri       0.69      0.75      0.72       451
                   sisäministeri       0.71      0.80      0.76       346
   sosiaali- ja terveysministeri       0.74      0.64      0.69       232
                   ulkoministeri       0.85      0.75      0.80        63
           valtiovarainministeri       0.78      0.70      0.73       171

                        accuracy                           0.75      2826
                       macro avg    

### Examining the classifier and predictions

Eli5 can show the top features for each class. (Note how the feature names are lemmatized words.)

In [9]:
# Class names in the same order as the classifier's classes_ attribute.
target_names = enc.inverse_transform(clf.classes_)

# Show the 10 highest-weighted features per class. vec is passed so that eli5
# can label the weights with the vectorizer's feature names.
eli5.show_weights(clf, vec=vec, top=10, target_names=target_names)

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5,Unnamed: 9_level_5,Unnamed: 10_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6,Unnamed: 9_level_6,Unnamed: 10_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7,Unnamed: 9_level_7,Unnamed: 10_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8,Unnamed: 9_level_8,Unnamed: 10_level_8
Weight?,Feature,Unnamed: 2_level_9,Unnamed: 3_level_9,Unnamed: 4_level_9,Unnamed: 5_level_9,Unnamed: 6_level_9,Unnamed: 7_level_9,Unnamed: 8_level_9,Unnamed: 9_level_9,Unnamed: 10_level_9
Weight?,Feature,Unnamed: 2_level_10,Unnamed: 3_level_10,Unnamed: 4_level_10,Unnamed: 5_level_10,Unnamed: 6_level_10,Unnamed: 7_level_10,Unnamed: 8_level_10,Unnamed: 9_level_10,Unnamed: 10_level_10
+1.467,kaivostoiminta,,,,,,,,,
+1.407,esir,,,,,,,,,
+1.361,finpron,,,,,,,,,
+1.282,finnair,,,,,,,,,
+1.238,posti,,,,,,,,,
+1.194,kaivos,,,,,,,,,
+1.194,raskoneen,,,,,,,,,
+1.107,kaivoslaki,,,,,,,,,
+1.071,tehoreservi,,,,,,,,,
… 16884 more positive …,… 16884 more positive …,,,,,,,,,

Weight?,Feature
+1.467,kaivostoiminta
+1.407,esir
+1.361,finpron
+1.282,finnair
+1.238,posti
+1.194,kaivos
+1.194,raskoneen
+1.107,kaivoslaki
+1.071,tehoreservi
… 16884 more positive …,… 16884 more positive …

Weight?,Feature
+1.025,maistraatti
+0.916,väestötietojärjestelmä
+0.747,kuntalaki
+0.685,senaatti kiinteistö
+0.657,koulunkäyntiavustaja
+0.592,muuttopäivä
+0.566,senaatti
+0.543,vimpeli
+0.538,muuttoilmoitus
… 13501 more positive …,… 13501 more positive …

Weight?,Feature
+2.090,liikennevirasto
+2.012,liikenne
+1.607,rata
+1.436,maantie
+1.405,viestintävirasto
+1.386,liikenneturvallisuus
+1.377,ajokortti
+1.283,postilaki
+1.273,tasoristeys
+1.267,nopeusrajoitus

Weight?,Feature
+2.357,metsähallitus
+2.222,eläin
+1.975,ympäristöministeriö
+1.855,susi
+1.842,metsä
+1.695,maatalous
+1.664,luonnonvarakeskus
+1.559,metsästys
+1.533,laji
+1.499,viljelijä

Weight?,Feature
+2.398,oikeusministeriö
+1.973,velallinen
+1.883,palkkatuki
+1.813,työnhakija
+1.696,käräjäoikeus
+1.512,rikosseuraamuslaitos
+1.357,avioliitto
+1.273,toimisto
+1.177,työtön
+1.167,perheryhmäkoti

Weight?,Feature
+2.908,varhaiskasvatus
+2.422,opetus
+2.258,korkeakoulu
+2.229,ammattikorkeakoulu
+2.217,yliopisto
+1.813,perusopetus
+1.791,opiskelija
+1.786,oppilas
+1.534,varhaiskasvatuslaki
+1.406,koulutus

Weight?,Feature
+1.771,potilas
+1.766,toimeentulotuki
+1.743,terveydenhuolto
+1.736,thl
+1.718,kela
+1.665,iäkäs
+1.630,hoito
+1.589,lastensuojelu
+1.582,sosiaalihuolto
+1.531,tulkki

Weight?,Feature
+2.941,poliisi
+2.656,sisäministeriö
+2.638,maahanmuuttovirasto
+1.922,poliisihallitus
+1.675,vastaanottokeskus
+1.615,poliisilaitos
+1.582,poliisimies
+1.581,ulkomaalaislaki
+1.537,turvapaikanhakija
+1.274,turvapaikka

Weight?,Feature
+2.349,lääke
+1.861,kansaneläkelaitos
+1.443,työkyvyttömyyseläke
+1.384,apteekki
+1.351,aktiivimalli
+1.310,väkivalta uhka
+1.267,alkuomavastuu
+1.209,lääkekatto
+1.124,sairausvakuutuslaki
+1.113,työttömyyspäiväraha

Weight?,Feature
+1.641,turkki
+1.345,libya
+1.211,yk
+1.036,ihmisoikeus
+0.866,syyria
+0.849,päätöslauselma
+0.824,pakistan
+0.777,humanitaarinen apu
+0.773,ankara suurlähetystö
… 12083 more positive …,… 12083 more positive …

Weight?,Feature
+2.385,vero
+2.100,verohallinto
+1.913,verotus
+1.894,verovelvollinen
+1.170,verojärjestelmä
+1.162,veropohja
+1.160,verokanta
+1.120,kotitalousvähennys
+1.107,verovapaus
+1.067,tulorekisteri


Next, explain a few **incorrectly** predicted training samples.

Eli5 shows the relative contributions of words by highlighting an input document. Note how the highlighting shows the original inflected word forms even though the classifier actually uses the lemmatized words behind the scenes. This is because the `VoikkoVectorizer._span_tokenizer()` function returns the original word spans.

In [10]:
# Number of misclassified training samples to explain.
k = 5

y_train_pred = enc.inverse_transform(pipe.predict(train['sentence']))
correctly_predicted = y_train_pred == train['class']
# Index labels of the misclassified rows. random.sample raises ValueError if
# there are fewer than k of them. No random seed is set in the visible cells,
# so the selection differs between runs.
inds = random.sample(correctly_predicted[~correctly_predicted].index.values.tolist(), k=k)
for i in inds:
    print(f'i = {i}')
    # Explain the score of the TRUE class of each misclassified sample.
    display(eli5.show_prediction(clf,
                                 train['sentence'][i],
                                 vec=vec,
                                 target_names=target_names,
                                 targets=[train['class'][i]]))

i = 29959


Contribution?,Feature
0.1,Highlighted in text (sum)
-1.021,<BIAS>


i = 244


Contribution?,Feature
0.09,Highlighted in text (sum)
-1.08,<BIAS>


i = 15997


Contribution?,Feature
0.453,Highlighted in text (sum)
-0.981,<BIAS>


i = 26458


Contribution?,Feature
-0.024,Highlighted in text (sum)
-0.978,<BIAS>


i = 9758


Contribution?,Feature
-0.011,Highlighted in text (sum)
-0.981,<BIAS>


## Explaining word embedding based predictions

Now, let's try a bit more complicated case: a black-box text classifier where the features don't directly correspond to individual tokens. The Smoothed Inverse Frequency weighting model averages weighted word2vec embeddings. See the [SIF model implementation](models/sif.py) for more details.

In [11]:
# SIF transformer imported from models/sif.py (not shown here), configured
# with a word-frequency file and a word2vec model file.
sif = SIFTransformer(word_freq_filename='data/finnish_vocab/finnish_vocab.txt.gz',
                     word2vec_filename='data/fin-word2vec/fin-word2vec.bin')

In [12]:
# One alternation that matches any entry of stop_words_fi as a whole word,
# ignoring case.
stop_words_pat = re.compile('|'.join(r'\b' + re.escape(x) + r'\b' for x in stop_words_fi), re.IGNORECASE)

def preprocess_document(doc):
    """Strip the stop words from ``doc``, then replace number words with "<num>"."""
    without_stop_words = stop_words_pat.sub('', doc)
    return replace_num_tokens(without_stop_words)

def preprocess_tr(text):
    """Apply preprocess_document to a single string or to a collection of strings.

    A str yields a str, a pandas Series yields a Series (via ``map``), and
    any other iterable yields a list.
    """
    if isinstance(text, str):
        return preprocess_document(text)

    if isinstance(text, pd.Series):
        return text.map(preprocess_document)

    return [preprocess_document(item) for item in text]

# Wrap the text preprocessing as a pipeline step. validate=False because the
# input is raw text, not a numeric array.
preprocess_transformer = FunctionTransformer(preprocess_tr, validate=False)
# Multinomial logistic regression. C is picked by 5-fold cross-validation from
# 10 log-spaced values between 1e-2 and 1e2.
clf2 = LogisticRegressionCV(Cs=np.logspace(-2, 2, 10), cv=5, max_iter=1000, multi_class='multinomial')
pipe2 = make_pipeline(preprocess_transformer, sif, clf2)
# The trailing semicolon keeps the notebook from echoing the fitted pipeline.
pipe2.fit(train['sentence'], y_encoded);

Performance on the development set:

In [13]:
# Same evaluation as for the bag-of-words pipeline, now with pipe2. Integer
# predictions are mapped back to class names before building the report.
y_dev_true = dev['class']
y_dev_pred = enc.inverse_transform(pipe2.predict(dev['sentence']))

print(classification_report(y_dev_true, y_dev_pred))

                                  precision    recall  f1-score   support

              elinkeinoministeri       0.50      0.32      0.39        99
     kunta- ja uudistusministeri       0.50      0.15      0.23        48
 liikenne- ja viestintäministeri       0.61      0.61      0.61       216
maatalous- ja ympäristöministeri       0.67      0.75      0.70       456
         oikeus- ja työministeri       0.47      0.52      0.49       395
   opetus- ja kulttuuriministeri       0.78      0.71      0.74       349
 perhe- ja peruspalveluministeri       0.53      0.65      0.58       451
                   sisäministeri       0.58      0.65      0.61       346
   sosiaali- ja terveysministeri       0.54      0.27      0.36       232
                   ulkoministeri       0.63      0.51      0.56        63
           valtiovarainministeri       0.61      0.52      0.56       171

                        accuracy                           0.59      2826
                       macro avg    

Explanations for a few incorrectly predicted training samples.

TextExplainer from eli5 uses LIME to estimate the relative importance of words. Because we know that the SIF model uses only unigram tokens and removes stop words, we enforce the same constraints on the TextExplainer by constructing a CountVectorizer with these limitations.

In [14]:
# Number of misclassified training samples to explain.
k = 5

y_train_pred = enc.inverse_transform(pipe2.predict(train['sentence']))
correctly_predicted = y_train_pred == train['class']
# Index labels of the misclassified rows. random.sample raises ValueError if
# there are fewer than k of them.
inds = random.sample(correctly_predicted[~correctly_predicted].index.values.tolist(), k=k)
for i in inds:
    print(f'i = {i}')
    # A fresh explainer per sample, fitted on that one document against the
    # black-box predict_proba of the whole pipeline. The CountVectorizer drops
    # the same stop words as the pipeline's preprocessing step.
    te = TextExplainer(vec=CountVectorizer(stop_words=stop_words_fi))
    te.fit(train['sentence'][i], pipe2.predict_proba)
    # 0.1 is this notebook's own cut-off for flagging a poor local fit.
    # presumably a higher mean KL divergence means a less faithful explanation
    # — confirm against the eli5 TextExplainer documentation
    if te.metrics_['mean_KL_divergence'] > 0.1:
        print('WARNING: The explanation might not be reliable!')
        print(te.metrics_)
    # Explain the score of the TRUE class of each misclassified sample.
    display(te.show_prediction(target_names=target_names.tolist(), targets=[train['class'][i]]))

i = 4619
{'mean_KL_divergence': 0.1148570595680951, 'score': 0.8901819397343732}


Contribution?,Feature
-0.618,Highlighted in text (sum)
-1.005,<BIAS>


i = 28710


Contribution?,Feature
-0.992,Highlighted in text (sum)
-1.421,<BIAS>


i = 45496
{'mean_KL_divergence': 0.11741585302685531, 'score': 0.8924048681316638}


Contribution?,Feature
-0.329,Highlighted in text (sum)
-1.174,<BIAS>


i = 39740


Contribution?,Feature
-1.131,<BIAS>
-2.508,Highlighted in text (sum)


i = 5990
{'mean_KL_divergence': 0.11101152581906353, 'score': 0.9093756767520604}


Contribution?,Feature
-1.166,<BIAS>
-2.851,Highlighted in text (sum)
