# tf-explain Callback API Example

tf-explain supports two APIs: the Core API, which lets you interpret a model after it has been trained, and the Callback API, which lets you use callbacks to monitor the model while it trains.
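
To make the distinction concrete, here is a minimal, framework-free sketch of the callback pattern the Callback API builds on. After each epoch, the training loop calls `on_epoch_end(epoch, logs)` on every callback, the same hook Keras exposes. `ToyCallback` and `toy_fit` are hypothetical stand-ins and are not part of tf-explain or Keras. A Core API explainer, by contrast, would be run once, after `fit` returns.

```python
# Hypothetical names throughout: ToyCallback and toy_fit only mimic the shape of
# the Keras hook on_epoch_end(epoch, logs); they are not part of tf-explain.

class ToyCallback:
    """Records the training logs at the end of every epoch."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, logs=None):
        # A real tf-explain callback would compute an explanation here
        # and write it to TensorBoard; this one just stores the logs.
        self.history.append((epoch, dict(logs or {})))


def toy_fit(epochs, callbacks):
    """Stand-in training loop that halves a fake loss each epoch."""
    loss = 1.0
    for epoch in range(epochs):
        loss *= 0.5  # pretend one epoch of training happened
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={"loss": loss})
    return loss


monitor = ToyCallback()
toy_fit(epochs=3, callbacks=[monitor])
print(monitor.history)
# [(0, {'loss': 0.5}), (1, {'loss': 0.25}), (2, {'loss': 0.125})]
```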

This notebook walks you through using the Callback API to analyze a model trained on the Fashion-MNIST dataset.


In [1]:
import numpy as np
import tensorflow as tf
import tf_explain

Load the dataset:

In [2]:
NUM_CLASSES = 10

dataset = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = dataset.load_data()
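
`load_data()` returns `uint8` image arrays of shape `(60000, 28, 28)` for training and `(10000, 28, 28)` for testing, together with integer labels from 0 to 9. The `class_index` arguments used later refer to these labels, so a lookup table of the standard Fashion-MNIST class names helps when reading the results. `CLASS_NAMES` and `class_name` are helpers defined here for illustration; they are not part of Keras.

```python
# Fashion-MNIST label names, indexed by the integer label stored in the dataset.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Bag", "Sneaker", "Ankle boot",
]


def class_name(index):
    """Map an integer class index (e.g. a class_index passed to tf-explain) to its name."""
    return CLASS_NAMES[index]


print(class_name(0), "|", class_name(1))  # T-shirt/top | Trouser
```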


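Add a channel axis to the images and cast them to `float32`, since `Conv2D` expects inputs of shape `(height, width, channels)`: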

In [3]:
train_images = train_images[..., tf.newaxis].astype('float32')
test_images = test_images[..., tf.newaxis].astype('float32')
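
The indexing `[..., tf.newaxis]` keeps every existing axis and appends a trailing axis of size 1, turning each `(28, 28)` image into the `(28, 28, 1)` shape the model's input expects. Here is a NumPy-only sketch of the same step on a dummy batch. `np.newaxis` and `tf.newaxis` are both aliases for `None`.

```python
import numpy as np

# Dummy batch standing in for the uint8 images returned by load_data().
images = np.zeros((4, 28, 28), dtype=np.uint8)

# `...` keeps every existing axis; np.newaxis appends a size-1 channel axis.
with_channel = images[..., np.newaxis].astype("float32")

print(images.shape, "->", with_channel.shape, with_channel.dtype)
# (4, 28, 28) -> (4, 28, 28, 1) float32
```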

One-hot encode the labels:

In [4]:
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)
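
`to_categorical` turns each integer label into a length-10 vector containing a single 1. As a sketch, you can index the rows of an identity matrix with the labels, which should give the same `float32` result for 1-D integer labels. The labels below are made up for illustration.

```python
import numpy as np

NUM_CLASSES = 10
labels = np.array([9, 0, 0, 3, 1])  # made-up integer labels

# Row k of the identity matrix is the one-hot vector for class k,
# so fancy-indexing np.eye with the labels one-hot encodes them all at once.
one_hot = np.eye(NUM_CLASSES, dtype="float32")[labels]

print(one_hot.shape)               # (5, 10)
print(np.argmax(one_hot, axis=1))  # [9 0 0 3 1]
```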

Create a simple convolutional neural network:

In [5]:
def create_model(input_shape=(28, 28, 1), num_classes=10):
    img_input = tf.keras.Input(input_shape)

    x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(img_input)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', name='target_layer')(x)
    x = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(x)

    x = tf.keras.layers.Dropout(0.25)(x)
    x = tf.keras.layers.Flatten()(x)

    x = tf.keras.layers.Dense(128, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)

    x = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

    model = tf.keras.Model(img_input, x)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    
    return model
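
The layer shapes follow from Keras defaults: stride 1 and `'valid'` padding for `Conv2D`, and a pool stride equal to the pool size for `MaxPool2D`. The arithmetic below is a sketch, not output from this notebook. It indicates that `target_layer`, the layer Grad-CAM inspects later, produces a `(24, 24, 64)` activation map, and `model.summary()` should report the same parameter total. The helper names are hypothetical.

```python
def conv_valid(size, kernel):
    """Spatial size after a stride-1 'valid' convolution (Keras Conv2D default)."""
    return size - kernel + 1


def conv_params(kernel, in_ch, out_ch):
    """Kernel weights plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(n_in, n_out):
    """Weight matrix plus one bias per unit."""
    return n_in * n_out + n_out


size = 28
size = conv_valid(size, 3)          # first Conv2D: 26x26x32
size = conv_valid(size, 3)          # 'target_layer': 24x24x64
target_layer_shape = (size, size, 64)
size //= 2                          # MaxPool2D(2, 2): 12x12x64
flat = size * size * 64             # Flatten: 9216 features

# Dropout, pooling and Flatten have no trainable parameters.
params = (
    conv_params(3, 1, 32)
    + conv_params(3, 32, 64)
    + dense_params(flat, 128)
    + dense_params(128, 10)
)
print(target_layer_shape, flat, params)  # (24, 24, 64) 9216 1199882
```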

In [6]:
model = create_model()

Select a subset of the test data to examine with tf-explain. We take the first five images from each of two classes: class 0 and class 1.

In [7]:
# Take the first five test images of class 0 (T-shirt/top) and class 1 (Trouser).
# The callbacks receive (images, labels) tuples; labels are not needed, so pass None.

validation_class_zero = (np.array([
    el for el, label in zip(test_images, test_labels)
    if np.argmax(label) == 0
][0:5]), None)

validation_class_one = (np.array([
    el for el, label in zip(test_images, test_labels)
    if np.argmax(label) == 1
][0:5]), None)
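
The list comprehensions above loop over all 10,000 test images in Python. A vectorized alternative builds a boolean mask from the argmax of the one-hot labels. The sketch below uses a hypothetical helper, `first_n_of_class`, and synthetic data. Applied to `test_images` and `test_labels`, it should return the same five images per class.

```python
import numpy as np


def first_n_of_class(images, one_hot_labels, class_index, n=5):
    """Return (images, None) for the first n samples whose one-hot label is class_index."""
    mask = np.argmax(one_hot_labels, axis=1) == class_index
    return images[mask][:n], None


# Synthetic stand-in data: 20 tiny "images" whose pixel value equals their label.
int_labels = np.arange(20) % 10
fake_images = int_labels[:, None, None, None].astype("float32") * np.ones((20, 2, 2, 1))
fake_one_hot = np.eye(10, dtype="float32")[int_labels]

subset, no_labels = first_n_of_class(fake_images, fake_one_hot, class_index=1)
print(subset.shape)  # (2, 2, 2, 1): only two class-1 samples exist in this toy data
```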

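
Create the tf-explain callbacks. Each one computes its explanation for the selected images at the end of every epoch and writes it to a subdirectory of `logs/` for TensorBoard to display: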
In [8]:
callbacks = [
    # Grad-CAM on 'target_layer' for each subset, explaining the subset's own class
    tf_explain.callbacks.GradCAMCallback(validation_class_zero, class_index=0, layer_name='target_layer'),
    tf_explain.callbacks.GradCAMCallback(validation_class_one, class_index=1, layer_name='target_layer'),
    # Feature maps produced by 'target_layer'
    tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, layers_name=['target_layer']),
    # Gradient-based saliency maps for class 0
    tf_explain.callbacks.SmoothGradCallback(validation_class_zero, class_index=0, num_samples=15, noise=1.),
    tf_explain.callbacks.IntegratedGradientsCallback(validation_class_zero, class_index=0, n_steps=10),
    tf_explain.callbacks.VanillaGradientsCallback(validation_class_zero, class_index=0),
]
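
Each callback recomputes its explanation at the end of every epoch. The NumPy sketch below shows the standard formulations behind these methods. In place of the network it uses a toy function `f(x) = sum(w * x**2)` whose gradient is known in closed form:

- Vanilla gradients are simply that gradient.
- Grad-CAM weights each channel of a layer's activations by its spatially averaged gradient and keeps the positive part.
- SmoothGrad averages gradients over `num_samples` noisy copies of the input. The Gaussian noise standard deviation is assumed here to correspond to the callback's `noise` argument.
- Integrated Gradients averages gradients along a straight path from a baseline (zeros here) to the input, sampled at `n_steps` points.

The function names and the toy model are assumptions for illustration. tf-explain's own implementations may differ in details such as gradient filtering and normalization.

```python
import numpy as np


def grad_cam(feature_maps, grads):
    """Standard Grad-CAM heatmap from one layer's activations and their gradients.

    feature_maps, grads: arrays of shape (H, W, K), e.g. (24, 24, 64) for target_layer.
    """
    weights = grads.mean(axis=(0, 1))                           # per-channel weight
    cam = np.tensordot(feature_maps, weights, axes=([2], [0]))  # weighted sum -> (H, W)
    return np.maximum(cam, 0.0)                                 # keep positive evidence only


# Hypothetical toy "model" with a closed-form gradient: f(x) = sum(w * x**2).
w = np.array([1.0, -2.0, 3.0])


def f(x):
    return np.sum(w * x ** 2)


def grad_f(x):
    return 2.0 * w * x


def vanilla_gradients(x):
    """Raw gradient of the output with respect to the input."""
    return grad_f(x)


def smoothgrad(x, num_samples=15, noise=1.0, seed=0):
    """Average the gradient over noisy copies of the input."""
    rng = np.random.default_rng(seed)
    noisy = x + rng.normal(0.0, noise, size=(num_samples,) + x.shape)
    return np.mean([grad_f(xi) for xi in noisy], axis=0)


def integrated_gradients(x, baseline=None, n_steps=10):
    """Midpoint Riemann approximation of the path integral from baseline to x."""
    baseline = np.zeros_like(x) if baseline is None else baseline
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    path_grads = [grad_f(baseline + a * (x - baseline)) for a in alphas]
    return (x - baseline) * np.mean(path_grads, axis=0)


x = np.array([1.0, 2.0, -1.0])
ig = integrated_gradients(x)
print(ig, ig.sum(), f(x))  # the attributions sum to f(x) - f(0)
```

For a quadratic `f`, the path gradients are linear in the interpolation step, so the midpoint rule is exact here. The Integrated Gradients attributions therefore sum to `f(x) - f(baseline)`, which is the completeness property the method is designed to satisfy.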
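
Start TensorBoard, pointing it at the `logs` directory, which is where the callbacks write their output by default: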
In [9]:
%load_ext tensorboard
%tensorboard --logdir logs


In [10]:
# Start training
model.fit(train_images, train_labels, epochs=5, callbacks=callbacks)

Train on 60000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x1b9c5936208>