In [146]:
%matplotlib inline

import imp
import keras.backend
import keras.models
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import pickle
import time
import keras
import re
import numpy as np

from string import punctuation
from keras.models import Model, load_model
from keras import optimizers
from gensim.models import Word2Vec
from gensim.models import KeyedVectors
from nltk.tokenize import word_tokenize
from nltk.tokenize import PunktSentenceTokenizer
from matplotlib import cm, transforms
from imblearn.metrics import classification_report_imbalanced
from sklearn.metrics import roc_curve, auc, confusion_matrix

import innvestigate
import innvestigate.applications
from innvestigate.utils.tests.networks import base as network_base

# Introduction

In this example, we are going to build a text classifer, inspired by experiments in [Arras et al. (2017a)][arras] and [Arras et al. (2017b)][arras2]. In particular, we are going to classify the relevance of epidemiological texts, and apply explanation methods provided by iNNvestigate to analyze how words in each article influence the articles's relevance prediction.

We apply various explanation methods implemented in iNNvestigate to explain decisions from a trained model. The figure below is explanations of a review that we expect to see: red indicates a high relevance score in favour of the prediction, while blue is the opposite.

![][sample]

[arras]: http://www.aclweb.org/anthology/W16-1601
[arras2]: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0181142
[sample]: https://i.imgur.com/IRQL5oh.png

## Data Preprocessing

In [149]:
np.random.seed(13353)

In [148]:
wv = KeyedVectors.load("self_trained_200", mmap="r")
vocabs = [i for i in wv.wv.vocab.keys()] 
total_vocabs = len(vocabs) 

  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


In [150]:
# Unknown vocabs are set to <UNK>.
encoder = dict(zip(['<UNK>'] + vocabs, range(len(vocabs) +1)))
decoder = dict(zip(encoder.values(), encoder.keys()))

print('We have %d vocabs.' % len(encoder))

We have 1534827 vocabs.


In [151]:
pretrained_embedding = wv.wv.vectors

# Unknown vocabs will have embedding weights of zero.
embedding = np.zeros((pretrained_embedding.shape[0]+1, pretrained_embedding.shape[1]))
embedding[1:, :] = pretrained_embedding

## Spiting Training, Testing, and Validation Set.

In [152]:
SPLIT_LABEL_MAPPING = {
    'training' : 1,
    'testing': 2,
    'validation': 3
}

MAX_SEQ_LENGTH = 200 # Calculated from where 63 is the shortest found text: min(epi_df.tokenized.apply(len))
EMBEDDING_DIM = embedding.shape[1]

In [153]:
epi_texts = pd.read_csv("with_label.csv")

In [154]:
epi_texts["extracted_text"] = (epi_texts["extracted_text"]
                               .apply(lambda x: 
                                          re.sub(r"([0-9a-zA-Z]+)\.([A-Za-z]+\s)",
                                                 r"\g<1>. \g<2>",
                                                 x)
                                      )
                              )

In [155]:
# sent_tokenizer = PunktSentenceTokenizer(" ".join(epi_texts["extracted_text"]))

In [156]:
with open("sent_tokenizer.p", "rb") as f:
    sent_tokenizer = pickle.load(f)

In [157]:
epi_texts['tokenized'] = (epi_texts["extracted_text"]
                          .apply(lambda x: sent_tokenizer.tokenize(x))
                          .apply(lambda x: [word_tokenize(sentence) for sentence in x])
                          .apply(lambda article: [token for sentence in article for token in sentence])
                         )

In [158]:
epi_texts["tokenized"] = (epi_texts["tokenized"]
                               .apply(lambda x: [i.lower() for i in x if i not in punctuation])
                              )

In [159]:
train, validate, test = np.split(epi_texts.sample(frac=1), [int(.6*len(epi_texts)), int(.8*len(epi_texts))], )

In [160]:
train["splitset_label"] = "training"

In [161]:
validate["splitset_label"] = "validation"

In [162]:
test["splitset_label"] = "testing"

In [163]:
validate["splitset_label"] = "validation"

In [164]:
epi_df = pd.concat([train, test, validate])

In [165]:
epi_df.head()

Unnamed: 0,extracted_text,label,tokenized,splitset_label
2857,ProMED-mail is a program of the International ...,False,"[promed-mail, is, a, program, of, the, interna...",training
3090,[Ref: S Rasool et al (2017): First Report of _...,False,"[ref, s, rasool, et, al, 2017, first, report, ...",training
433,ProMED-mail is a program of the International ...,False,"[promed-mail, is, a, program, of, the, interna...",training
1156,ProMED-mail is a program of the International ...,False,"[promed-mail, is, a, program, of, the, interna...",training
1780,"- Epidemiological situation 22 Jun 2018, DRC M...",False,"[epidemiological, situation, 22, jun, 2018, dr...",training


In [166]:
LABEL_IDX_TO_NAME = {
    0: 'irrelevant',
    1: 'relevant'
}

In [167]:
def prepare_own_dataset(data_set):
    filtered_indices = epi_df["splitset_label"] == data_set
    data_set_of_epi_df = epi_df[filtered_indices]
    
    xd = np.zeros((len(data_set_of_epi_df), MAX_SEQ_LENGTH, EMBEDDING_DIM))
    y = data_set_of_epi_df["label"].values.astype(int)
    
    articles = []
    for i, tokenized in enumerate(data_set_of_epi_df["tokenized"].values):
        tokenized = [token.lower() for token in tokenized]
        article = []
        for j, v in enumerate(tokenized[:MAX_SEQ_LENGTH]):
            if v in encoder:
                e_idx = encoder[v]
            else:
                e_idx = 0
            
            xd[i, j, :] = embedding[e_idx]
            article.append(e_idx)
        articles.append(article)
    
    return dict(x4d=np.expand_dims(xd, axis=1), 
                y=y,
                encoded_articles=articles)

DATASETS = dict()

for data_set in ['training', 'testing', 'validation']:
    DATASETS[data_set] = prepare_own_dataset(data_set)

In [168]:
# Find text with lowest amount of <UNK> tokens
ideces_of_lowest = []
length = 4
while len(ideces_of_lowest) < 1:
    for sample_idx in range(len(DATASETS['testing']['encoded_articles'])):
        text = ' '.join(map(lambda x: decoder[x], DATASETS['training']['encoded_articles'][sample_idx]))
        amount_of_unk = len(re.findall(r"<UNK>", text))
        if amount_of_unk == length:
            ideces_of_lowest.append(sample_idx)
    length += 1

In [169]:
print('Review(ID=%d): %s' %
      (sample_idx, ' '.join(map(lambda x: decoder[x], DATASETS['training']['encoded_articles'][590]))))

Review(ID=646): date fri 5 oct 2018 source outbreak news today edited http <UNK> the los angeles county department of public health lac dph is reporting an endemic flea-borne typhus outbreak in downtown los angeles between july and september 2018 health officials identified 9 cases of flea-borne typhus the cases have a history of living or working in downtown los angeles and 6 of them have reported experiencing homelessness or living in interim housing facilities in the area all cases were hospitalized and no deaths have occurred flea-borne typhus is endemic in lac with cases detected each year in recent years the average number of cases reported to lac dph has doubled to nearly 60 cases per year however geographic clusters of the size occurring in downtown los angeles are unusual most cases occur in the summer and fall months in lac the primary animals known to carry infected fleas include rats feral cats and opossums people with significant exposure to these animals are at risk of ac

# Model Construction

Our classifier is a convolutional neural network, which was inspired by the network used in https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0181142

In [170]:
NUM_CLASSES = 2

In [171]:
def build_network(input_shape, output_n, activation=None, dense_unit=256, dropout_rate=0.25):
    if activation:
        activation = "relu"

    net = {}
    net["in"] = network_base.input_layer(shape=input_shape)
    
    net["conv"] = keras.layers.Conv2D(filters=100, kernel_size=(1,2), 
                                      strides=(1, 1),
                                      activation='relu', 
                                      padding='valid')(net["in"])
    net["dropout"] = keras.layers.Dropout(dropout_rate)(net["conv"])
    net["pool"] = keras.layers.MaxPooling2D(pool_size=(1, input_shape[2]-1), strides=(1,1))(net["dropout"])
    
    net["out"] = network_base.dense_layer(keras.layers.Flatten()(net["pool"]), units=output_n, activation=activation)
    net["sm_out"] = network_base.softmax(net["out"])


    net.update({
        "input_shape": input_shape,

        "output_n": output_n,
    })
    return net

net = build_network((None, 1, MAX_SEQ_LENGTH, EMBEDDING_DIM), NUM_CLASSES)
model_without_softmax, model_with_softmax = Model(inputs=net['in'], outputs=net['out']), Model(inputs=net['in'], outputs=net['sm_out'])

In [189]:
def to_one_hot(y):
    return keras.utils.to_categorical(y, NUM_CLASSES)

def train_model(model,  batch_size=200, epochs=40):
    
    x_train = DATASETS['training']['x4d']
    y_train = to_one_hot(DATASETS['training']['y'])
    
    x_test = DATASETS['testing']['x4d']
    y_test = to_one_hot(DATASETS['testing']['y'])
    
    x_val = DATASETS['validation']['x4d']
    y_val = to_one_hot(DATASETS['validation']['y'])
    
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.Adam(),
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        verbose=1,
                        validation_data=(x_val, y_val),
                        shuffle=True,
                        class_weight="auto"
                       )
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])

In [190]:
train_model(model_with_softmax, batch_size=256, epochs=20)

Train on 1939 samples, validate on 646 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Test loss: 0.18741113787633548
Test accuracy: 0.955177743431221


In [191]:
model_without_softmax.set_weights(model_with_softmax.get_weights())

In [175]:
model_with_softmax.save("my_model_softmax")
model_without_softmax.save("my_model_without")

In [25]:
model_with_softmax = load_model('my_model')

In [192]:
x_test = DATASETS['testing']['x4d']
y_test = DATASETS['testing']['y']

soft_max_predicted = model_with_softmax.predict(x_test)
y_pred = [soft_max_predicted[i][1] for i in range(len(x_test))]

In [200]:
fpr, tpr, _ = roc_curve(y_test, arg_max_predicted)

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.plot([0, 1], [0, 1], linestyle='--')
plt.plot(fpr, tpr)

In [194]:
y_argmax = [soft_max_predicted[i].argmax() for i in range(len(x_test))]

In [197]:
print(classification_report_imbalanced(y_test, y_argmax))
print(confusion_matrix(y_test, y_argmax))
print('AUC score: {:3f}'.format(auc(fpr, tpr)))

# Model Analysis and Visualization

At this stage, we have a trained model and are ready to explain it via **iNNvestigate**'s analyzers.

In [141]:
# Specify methods that you would like to use to explain the model. 
# Please refer to iNNvestigate's documents for available methods.
methods = ['gradient', 'lrp.z', 'lrp.alpha_2_beta_1', 'pattern.attribution']
kwargs = [{}, {}, {}, {'pattern_type': 'relu'}]

In [142]:
# build an analyzer for each method
analyzers = []

for method, kws in zip(methods, kwargs):
    analyzer = innvestigate.create_analyzer(method, model_without_softmax, **kws)
    analyzer.fit(DATASETS['training']['x4d'], batch_size=256, verbose=1)
    analyzers.append(analyzer)



Epoch 1/1


In [143]:
# specify indices of reviews that we want to investigate
# test_sample_indices = [170, 321, 414]
test_sample_indices = np.argwhere(DATASETS['testing']['y']==1).squeeze()
test_sample_preds = [None]*len(test_sample_indices)

# a variable to store analysis results.
analysis = np.zeros([len(test_sample_indices), len(analyzers), 1, MAX_SEQ_LENGTH])

for i, ridx in enumerate(test_sample_indices):

    x, y = DATASETS['testing']['x4d'][ridx], DATASETS['testing']['y'][ridx]

    t_start = time.time()
    x = x.reshape((1, 1, MAX_SEQ_LENGTH, EMBEDDING_DIM))    

    presm = model_without_softmax.predict_on_batch(x)[0] #forward pass without softmax
    prob = model_with_softmax.predict_on_batch(x)[0] #forward pass with softmax
    y_hat = prob.argmax()
    test_sample_preds[i] = y_hat
    
    for aidx, analyzer in enumerate(analyzers):

        a = np.squeeze(analyzer.analyze(x))
        a = np.sum(a, axis=1)

        analysis[i, aidx] = a
    t_elapsed = time.time() - t_start
    print('Review %d (%.4fs)'% (ridx, t_elapsed))

Review 6 (1.8100s)
Review 74 (0.0150s)
Review 79 (0.0110s)
Review 106 (0.0100s)
Review 113 (0.0120s)
Review 142 (0.0140s)
Review 185 (0.0090s)
Review 219 (0.0120s)
Review 252 (0.0100s)
Review 270 (0.0090s)
Review 276 (0.0100s)
Review 296 (0.0100s)
Review 299 (0.0100s)
Review 304 (0.0120s)
Review 386 (0.0100s)
Review 389 (0.0100s)
Review 391 (0.0100s)
Review 398 (0.0080s)
Review 422 (0.0100s)
Review 423 (0.0110s)
Review 424 (0.0090s)
Review 479 (0.0110s)
Review 482 (0.0090s)
Review 547 (0.0090s)
Review 606 (0.0100s)
Review 617 (0.0100s)
Review 626 (0.0090s)
Review 638 (0.0090s)


## Visualization

To this point, we have all analysis results from iNNvestigate's analyzers, and we are now ready to visualize them in a insightful way. We will use relevance scores from explanation methods to highlight the words in each review. 

We will use  the *blue-white-red (bwr)* color map for this purpose. Hence, words that have a positive score to the prediction are be shaded in *red*, while  negative-contribution or zero-contribution words are then highlighted in *blue*, and *white*, respectively.


In [144]:
# This is a utility method visualizing the relevance scores of each word to the network's prediction. 
# one might skip understanding the function, and see its output first.
def plot_text_heatmap(words, scores, title="", width=10, height=0.2, verbose=0, max_word_per_line=20):
    fig = plt.figure(figsize=(width, height))
    
    ax = plt.gca()

    ax.set_title(title, loc='left')
    tokens = words
    if verbose > 0:
        print('len words : %d | len scores : %d' % (len(words), len(scores)))

    cmap = plt.cm.ScalarMappable(cmap=cm.bwr)
    cmap.set_clim(0, 1)
    
    canvas = ax.figure.canvas
    t = ax.transData

    normalized_scores = 0.5 * scores / np.max(np.abs(scores)) + 0.5
    
    if verbose > 1:
        print('Raw score')
        print(scores)
        print('Normalized score')
        print(normalized_scores)

    # make sure the heatmap doesn't overlap with the title
    loc_y = -0.2

    for i, token in enumerate(tokens):
        *rgb, _ = cmap.to_rgba(normalized_scores[i], bytes=True)
        color = '#%02x%02x%02x' % tuple(rgb)
        
        text = ax.text(0.0, loc_y, token,
                       bbox={
                           'facecolor': color,
                           'pad': 5.0,
                           'linewidth': 0.5,
                           'boxstyle': 'round,pad=0.5'
                       }, transform=t)

        text.draw(canvas.get_renderer())
        ex = text.get_window_extent()
        
        # create a new line if the line exceeds the length
        if (i+1) % max_word_per_line == 0:
            loc_y = loc_y -  2.5
            t = ax.transData
        else:
            t = transforms.offset_copy(text._transform, x=ex.width+15, units='dots')

    if verbose == 0:
        ax.axis('off')

In [196]:
# Traverse over the analysis results and visualize them.
for i, idx in enumerate(test_sample_indices):

    words = [decoder[t] for t in list(DATASETS['testing']['encoded_articles'][idx])]
    
    print('Review(id=%d): %s' % (idx, ' '.join(words)))
    y_true = DATASETS['testing']['y'][idx]
    y_pred = test_sample_preds[i]

    print("Pred class : %s %s" %
          (LABEL_IDX_TO_NAME[y_pred], '✓' if y_pred == y_true else '✗ (%s)' % LABEL_IDX_TO_NAME[y_true])
         )
                                
    for j, method in enumerate(methods):
        plot_text_heatmap(words, analysis[i, j].reshape(-1), title='Method: %s' % method, verbose=0)
        plt.show()