# Interpreting predictions on MNIST

In [8]:
# imports
import tempfile
import numpy as np
from matplotlib import pyplot as plt
from depiction.core import Task, DataType

In [18]:
# plotting: make 'gray' the default colormap (plt.gray() is shorthand for this)
plt.set_cmap('gray')

In [27]:
# general variables
# NOTE: create the cache directory only once. Any CACHE_DIR that already
# exists is looked up instead of being reset to None. Re-running this cell
# therefore keeps the same directory, and whatever was already stored in it,
# rather than creating a new temporary directory on every run.
CACHE_DIR = globals().get('CACHE_DIR')
if CACHE_DIR is None:
    CACHE_DIR = tempfile.mkdtemp()

In [63]:
# a bunch of useful functions
def transform(x):
    """
    Map 0-255 pixel values into the -0.5..0.5 range and append a channel axis.

    Args:
        x (np.ndarray): a 2D-array representing an image.

    Returns:
        np.ndarray: a float32 3D-array holding the transformed image, with a
            trailing channel dimension of size 1.
    """
    scaled = x.astype(np.float32) / 255 - 0.5
    # trailing axis of size 1 acts as the (single) channel dimension
    return scaled[..., np.newaxis]


def transform_sample(x):
    """
    Transform a single image and add a leading batch dimension.

    Args:
        x (np.ndarray): a 2D-array representing an image, i.e. the same
            input accepted by ``transform`` (which adds the channel axis).

    Returns:
        np.ndarray: a 4D-array representing a batch with a single image,
            shaped ``(1, height, width, 1)``.
    """
    return np.expand_dims(transform(x), axis=0)


def inverse_transform(x):
    """
    Undo the scaling applied by ``transform``, mapping values back to 0-255.

    All singleton dimensions are squeezed away. A batch with a single image
    shaped ``(1, height, width, 1)`` therefore comes back as a 2D array.

    Args:
        x (np.ndarray): a 4D-array representing a batch with a
            single image.

    Returns:
        np.ndarray: a 2D-array representing the image. The values are
            floats; they are not rounded or cast back to an integer dtype.
    """
    return (x.squeeze() + 0.5) * 255


def show_image(x, title=None):
    """
    Show an image on the current matplotlib axes, without ticks.

    Args:
        x (np.ndarray): an image array. Singleton dimensions are squeezed
            away before plotting, so a 4D-array representing a batch with a
            single image works, as does a single image.
        title (str): optional title.

    Returns:
        matplotlib.image.AxesImage: the image artist created by
            ``plt.imshow``.
    """
    axes_image = plt.imshow(x.squeeze())
    # Hide the ticks by passing only an empty tick list. The second
    # positional argument of set_xticks/set_yticks has meant different
    # things across matplotlib versions (``minor`` vs ``labels``).
    axes_image.axes.set_xticks([])
    axes_image.axes.set_yticks([])
    if title is not None:
        axes_image.axes.set_title(title)
    return axes_image

## Instantiate a model to interpret

In [38]:
from depiction.models.keras import KerasModel
from tensorflow.keras.models import load_model
from depiction.models.base.utils import get_model_file

# Keep a public handle on the loaded Keras model so its summary can be
# printed without reaching into KerasModel's private ``_model`` attribute.
# NOTE(review): 'mninst' looks like a typo for 'mnist'. It is kept because it
# is the filename passed to get_model_file for the cached download.
keras_model = load_model(
    get_model_file(
        filename='mninst_cnn.h5',
        origin='https://ibm.box.com/shared/static/v3070m2y62qw4mpwl04pee75n0zg681g.h5',
        cache_dir=CACHE_DIR
    )
)
depiction_model = KerasModel(
    keras_model,
    task=Task.CLASSIFICATION, data_type=DataType.IMAGE
)
keras_model.summary()

## Get data

In [49]:
from tensorflow.keras.datasets.mnist import load_data
(x_train, y_train), (x_test, y_test) = load_data()
# report the array shapes of both splits
print(f'x_train shape: {x_train.shape} y_train shape: {y_train.shape}')
print(f'x_test shape: {x_test.shape} y_test shape: {y_test.shape}')

In [67]:
# pick one training sample and show it with its true and predicted class
index = 42
example = transform_sample(x_train[index])
label = y_train[index]
predicted_class = np.argmax(depiction_model.predict(example))
show_image(example, title=f'True={label} Predicted={predicted_class}')

## LIME

In [68]:
from depiction.interpreters.u_wash import UWasher

# wrap the model in a UWasher interpreter, selecting the 'lime' explainer by name
interpreter = UWasher('lime', depiction_model)

In [76]:
# handle example shapes

In [72]:
# explain the prediction for `example`
# NOTE(review): `example` is a batch holding one image (built by
# transform_sample). Confirm the interpreter expects that shape rather than
# a single unbatched image (cf. the "handle example shapes" cell).
interpreter.interpret(example)

## Anchors

In [0]:
from depiction.interpreters.u_wash import UWasher

# rebind `interpreter` to a UWasher that uses the 'anchors' explainer
interpreter = UWasher('anchors', depiction_model)

In [0]:
# explain the same `example` with the anchors-based interpreter
# NOTE(review): as with LIME, `example` carries a batch dimension; confirm
# this is the shape the interpreter expects.
interpreter.interpret(example)

## CEM

## Counterfactual explanations