# MNIST Example

This notebook illustrates the use of the proposed in(n)vestigation methods on the MNIST dataset.

# Imports

In [1]:
import warnings
warnings.simplefilter('ignore')

In [2]:
%matplotlib inline  

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import imp
import os

import keras
import keras.backend
import keras.models
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Dropout, Activation, Input
from keras.optimizers import RMSprop, Adam

import innvestigate
import innvestigate.utils as iutils
import innvestigate.utils.tests.networks.base
import innvestigate.utils.visualizations as ivis


# Project-local helper modules are loaded directly from their file paths in the parent
# directory, under the module names "utils" and "utils_mnist" respectively.
# NOTE(review): `imp` is deprecated (removed in Python 3.12); importlib is the replacement.
eutils, mnistutils = [
    imp.load_source(module_name, module_path)
    for module_name, module_path in [("utils", "../utils.py"),
                                     ("utils_mnist", "../utils_mnist.py")]
]

Using TensorFlow backend.


# Data

Load MNIST data.

In [3]:
# Load data.
# `image_data_format` is a function and must be called. Comparing the function object
# itself to a string is always False, so channels-first backends were never detected.
channels_first = keras.backend.image_data_format() == "channels_first"
# Presumably (x_train, y_train, x_test, y_test); the preprocessing cell below treats
# indices 0 and 2 as images and 1 and 3 as labels. TODO confirm in utils_mnist.
data = mnistutils.fetch_data(channels_first)
num_classes = len(np.unique(data[1]))

# Test samples for illustrations: the first `num_classes` test images with their labels.
# These are not necessarily one image per class.
images = [(data[2][i].copy(), data[3][i]) for i in range(num_classes)]
label_to_class_name = [str(i) for i in range(num_classes)]

60000 train samples
10000 test samples


Preprocess data.

In [4]:
# Parameter: forwarded to `mnistutils.preprocess` as its second argument.
zero_mean = False

# Only the image arrays (indices 0 and 2) are preprocessed. The label arrays
# (indices 1 and 3) are passed through unchanged.
# TODO(original author): "change this!!". The intended change was not recorded.
data_preprocessed = (
    mnistutils.preprocess(data[0], zero_mean),  # training images
    data[1],                                    # training labels
    mnistutils.preprocess(data[2], zero_mean),  # test images
    data[3],                                    # test labels
)

# Model

Create & train a Multilayer Perceptron with two fully connected layers.

In [5]:
# Parameters: network nonlinearity and training schedule.
activation_type = "relu"  # also reused as the pattern type when fitting the analyzers
batch_size = 64
epochs = 5

# Create & train model.
# `create_model` returns a pair. `model` is the network handed to the analyzers;
# `modelp` is the network that actually gets trained. In the analysis cell their
# outputs are read as pre-softmax scores and as probabilities, respectively.
model, modelp = mnistutils.create_model(channels_first, activation_type, num_classes)
mnistutils.train_model(modelp, data_preprocessed,
                       batch_size=batch_size, epochs=epochs)
# Copy the trained parameter values into `model` so both networks agree.
model.set_weights(modelp.get_weights())

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Test loss: 0.08198731753372121
Test accuracy: 0.9764


# Analysis

Use the methods specified below to in(n)vestigate and visualize the learned classifier on a few example images.

Each analysis method is passed to the framework as a tuple containing its name, optional parameters, a postprocessing function and a title for the final visualization. The available methods are grouped by their principal approach into three categories: gradient-based, pattern-based and relevance-based investigation methods. For a full list of methods, please refer to `innvestigate/innvestigate/analyzer/__init__.py` or to the list below (available upon first release).

In [None]:
# Methods we use and some properties.
# Each entry is a 4-tuple:
#   (analyzer name passed to innvestigate.create_analyzer,
#    keyword arguments forwarded to that analyzer,
#    postprocessing function applied to the analysis result,
#    column title: a string, or a tuple of strings for a multi-part title)
methods = [
    # NAME                    OPT. PARAMS               POSTPROC FXN        TITLE

    # Show input.
    ("input",                 {},                       mnistutils.image,   "Input"),

    # Function
    ("gradient",              {},                       mnistutils.graymap, "Gradient"),
    ("smoothgrad",            {"noise_scale": 50},      mnistutils.graymap, "SmoothGrad"),
    ("integrated_gradients",  {},                       mnistutils.graymap, ("Integrated", "Gradients")),

    # Signal
    ("deconvnet",             {},                       mnistutils.bk_proj, "Deconvnet"),
    ("guided_backprop",       {},                       mnistutils.bk_proj, ("Guided", "Backprop")),
    ("pattern.net",           {},                       mnistutils.bk_proj, "PatternNet"),

    # Interaction
    # The title must be a single tuple element. Written as two loose strings, it made
    # this a 5-tuple and the column label (method[3]) dropped "Attribution".
    ("pattern.attribution",   {},                       mnistutils.heatmap, ("Pattern", "Attribution")),
    ("lrp.z",                 {},                       mnistutils.heatmap, "LRP"),
]

# Every entry must have exactly the four fields described above.
assert all(len(method) == 4 for method in methods), "malformed entry in `methods`"

In [None]:
# Create analyzers.
# One analyzer per entry of `methods`, all built on `model`, then fitted on the
# preprocessed training images. Judging by the recorded "Epoch 1/1" output, only some
# analyzers actually train (presumably the pattern-based ones). TODO confirm.
pattern_type = activation_type
analyzers = []
for method in methods:
    analyzer = innvestigate.create_analyzer(method[0], model, **method[1])
    analyzer.fit(data_preprocessed[0],
                 pattern_type=pattern_type,
                 batch_size=256,
                 verbose=1)
    analyzers.append(analyzer)

# Create analysis.
# analysis[i, j] holds the postprocessed output of analyzer j on example image i,
# stored as a 28x28x3 array.
analysis = np.zeros([len(images), len(analyzers), 28, 28, 3])
# One row-label tuple per image:
# (true label, max of `model` output, max of `modelp` output, predicted label).
text = []
for i, (image, y) in enumerate(images):
    # Add a leading batch axis so a single image can be fed to the models.
    image = image[None, :, :, :]
    x = mnistutils.preprocess(image, zero_mean)

    # Predict label. `presm` (by its name, pre-softmax scores) comes from `model`;
    # `prob` comes from `modelp`, and its argmax is the predicted class.
    presm = model.predict_on_batch(x)[0]
    prob = modelp.predict_on_batch(x)[0]
    y_hat = prob.argmax()
    text.append(("%s" % label_to_class_name[y],
                 "%.2f" % presm.max(),
                 "%.2f" % prob.max(),
                 "%s" % label_to_class_name[y_hat]))

    for aidx, (method, analyzer) in enumerate(zip(methods, analyzers)):
        is_input_analyzer = method[0] == "input"
        if is_input_analyzer:
            # The "input" method shows the raw image: it gets the unpreprocessed
            # batch and skips the generic postprocessing step.
            a = analyzer.analyze(image)
        else:
            a = mnistutils.postprocess(analyzer.analyze(x))
        # Method-specific visualization (image / graymap / bk_proj / heatmap).
        a = method[2](a)
        analysis[i, aidx] = a[0]

Epoch 1/1
Epoch 1/1

In [None]:
# Plot the analysis.
# grid[i][j] is the visualization produced by analyzer j for example image i.
grid = [[analysis[i, j] for j in range(analysis.shape[1])]
        for i in range(analysis.shape[0])]
row_labels = text
# Titles are either a plain string or a tuple of words. `''.join` on a tuple glued the
# words together ("IntegratedGradients"), so tuple titles are space-joined instead.
# Strings are used as-is, because joining a string would split it into characters.
col_labels = [method[3] if isinstance(method[3], str) else " ".join(method[3])
              for method in methods]

eutils.plot_image_grid(grid, row_labels, col_labels,
                       file_name=None,
                       row_label_offset=0,
                       col_label_offset=0,
                       is_fontsize_adaptive=True,
                       usetex=False,
                       dpi=224)