# Interpreting predictions on MNIST

In [None]:
# general imports
import warnings; warnings.filterwarnings("ignore", category=FutureWarning)
import tensorflow as tf; tf.logging.set_verbosity(tf.logging.ERROR)  # suppress deprecation messages
import tempfile
import numpy as np
from matplotlib import pyplot as plt

from depiction.core import Task, DataType

In [None]:
# plotting defaults: grayscale colormap for every image shown below
# (the return value is bound to `_` and discarded)
_ = plt.gray()
# large figures so that individual pixels remain visible
plt.rcParams.update({'figure.figsize': [20, 10]})

In [None]:
# general variables
# NOTE: get a valid cache (but only once!). Re-executing this cell must not
# replace an existing cache directory with a brand-new temporary one, or
# anything cached there would no longer be found. Reuse any CACHE_DIR that
# is already defined in this namespace; define CACHE_DIR as a path before
# running this cell to use a persistent cache instead of a temporary one.
CACHE_DIR = globals().get('CACHE_DIR')
if CACHE_DIR is None:
    CACHE_DIR = tempfile.mkdtemp()

In [None]:
# general utils
def transform(x):
    """
    Rescale pixel values to the [-0.5, 0.5] range and append a channel axis.

    Works on a single image as well as on a batch of images: the channel
    dimension is always added as the new last axis.

    Args:
        x (np.ndarray): pixel values in [0, 255], e.g. a 2D-array
            representing an image or a 3D-array representing a batch
            of images.

    Returns:
        np.ndarray: a float32 array with one more (trailing) dimension
            than x.
    """
    scaled = x.astype('float32') / 255
    centered = scaled - 0.5
    return np.expand_dims(centered, axis=-1)


def transform_sample(x):
    """
    Turn a single raw image into a batch holding one model-ready image.

    Applies `transform` (rescaling plus channel axis) and then prepends a
    batch axis of size one.

    Args:
        x (np.ndarray): a 2D-array representing an image.

    Returns:
        np.ndarray: a 4D-array of shape (1, height, width, 1) representing
            a batch with the single transformed image.
    """
    image = transform(x)
    # leading axis of size one = batch dimension
    return image[np.newaxis, ...]


def inverse_transform(x):
    """
    Undo the rescaling applied to a batch with a single image.

    All singleton dimensions (batch and channel) are dropped and the
    values are mapped from the [-0.5, 0.5] range back to [0, 255].

    Args:
        x (np.ndarray): a 4D-array representing a batch with a
            single image.

    Returns:
        np.ndarray: a 2D-array representing the image. Values are floats;
            they are not cast back to the original integer dtype.
    """
    image = x.squeeze()
    return 255 * (image + 0.5)


def show_image(x, title=None):
    """
    Show an image with the axis ticks hidden.

    Args:
        x (np.ndarray): an array holding a single image, e.g. a 4D-array
            representing a batch with a single image; all singleton
            dimensions are squeezed away before plotting.
        title (str): optional title.

    Returns:
        the object returned by `plt.imshow` (a matplotlib AxesImage).
    """
    axes_image = plt.imshow(x.squeeze())
    # Pass only the (empty) tick locations. The meaning and acceptance of a
    # second positional argument to set_xticks/set_yticks (`minor` vs.
    # `labels`) has changed across matplotlib releases, so the former
    # `set_xticks([], [])` form is not portable.
    axes_image.axes.set_xticks([])
    axes_image.axes.set_yticks([])
    if title is not None:
        axes_image.axes.set_title(title)
    return axes_image

## Instantiate a model to interpret

In [None]:
from depiction.models.keras import KerasModel
from tensorflow.keras.models import load_model
from depiction.models.base.utils import get_model_file

# Wrap a pretrained Keras CNN for MNIST in depiction's KerasModel so the
# interpreters below can query it. The .h5 file is obtained via
# get_model_file from `origin`, presumably cached under CACHE_DIR.
depiction_model = KerasModel(
    load_model(
        get_model_file(
            filename='mninst_cnn.h5',
            origin='https://ibm.box.com/shared/static/v3070m2y62qw4mpwl04pee75n0zg681g.h5',
            cache_dir=CACHE_DIR
        )
    ),
    task=Task.CLASSIFICATION, data_type=DataType.IMAGE
)
# NOTE(review): `_model` is a private attribute of KerasModel, presumably the
# wrapped Keras model; it is used here only to print the architecture.
depiction_model._model.summary()

## Get data

In [None]:
from tensorflow.keras.datasets.mnist import load_data
# Raw MNIST splits. Images are kept untransformed here; `transform` /
# `transform_sample` are applied before they are fed to the model.
(x_train, y_train), (x_test, y_test) = load_data()
print('x_train shape:', x_train.shape, 'y_train shape:', y_train.shape)
print('x_test shape:', x_test.shape, 'y_test shape:', y_test.shape)

In [None]:
# Pick one test image, turn it into a single-image batch (`example`, reused
# by every interpreter below) and show it with its true label and the class
# with the highest model output.
index = 42
example = transform_sample(x_test[index])
label = y_test[index]
show_image(
    example,
    title=(
        f'True={label} '
        f'Predicted={np.argmax(depiction_model.predict(example))}'
    )
)

## How is our model doing?

In [None]:
from tensorflow.keras.utils import to_categorical
# Evaluate on the whole test set: images go through `transform`, labels are
# one-hot encoded.
score = depiction_model._model.evaluate(
    transform(x_test), to_categorical(y_test), verbose=0
)
# NOTE(review): score[1] is reported as accuracy; this assumes the model was
# compiled with accuracy as its first metric — confirm.
print(f'Test accuracy: {score[1]}')

## LIME

In [None]:
from depiction.interpreters.u_wash import UWasher

# LIME interpreter for the wrapped model, provided through depiction's UWasher.
lime_interpreter = UWasher('lime', depiction_model)

In [None]:
# Explain the model's prediction for the selected example.
lime_interpreter.interpret(example)

## CEM

In [None]:
# some utilities
def show_cem_explanation(explanation, mode):
    """
    Print the prediction stored in a CEM explanation and plot its image.

    Nothing is printed or plotted when the explanation has no
    '<mode>_pred' entry.

    Args:
        explanation (dict): CEM explanation.
        mode (str): CEM mode, PP or PN; used as the key of the explanation
            image and as the prefix of the prediction key.
    """
    prediction_key = f'{mode}_pred'
    if prediction_key not in explanation:
        return
    print(f'{mode} prediction: {explanation[prediction_key]}')
    show_image(
        explanation[mode],
        title=f'{mode} explanation for example provided.'
    )

In [None]:
# setting some parameters
shape = example.shape
kappa = 0.  # minimum difference needed between the prediction probability for the perturbed instance on the
# class predicted by the original instance and the max probability on the other classes
# in order for the first loss term to be minimized
beta = .1  # weight of the L1 loss term
gamma = 100  # weight of the optional auto-encoder loss term
c_init = 1.  # initial weight c of the loss term encouraging predictions for the perturbed instance compared to the original instance to be explained
c_steps = 10  # updates for c
max_iterations = 10  # iterations per value of c
# Feature range for the perturbed instance. It must be expressed in the space
# the model sees, i.e. after `transform`: x_train holds raw pixel values, so
# its own min/max would let perturbations leave the model's input range.
feature_range = (
    transform(x_train.min()).item(),
    transform(x_train.max()).item(),
)
clip = (-1000., 1000.)  # gradient clipping
lr = 1e-2  # initial learning rate
no_info_val = -1.  # picking value close to background

In [None]:
# NOTE: CEM supports the usage of an autoencoder to constrain the
# perturbations to a learned (latent) data manifold. The model loaded here is
# passed to CEM as `ae_model`; its loss term is weighted by `gamma`.
ae = load_model(
    get_model_file(
        filename='mninst_ae.h5',
        origin=
        'https://ibm.box.com/shared/static/psogbwnx1cz0s8w6z2fdswj25yd7icpi.h5',  # noqa
        cache_dir=CACHE_DIR
    )
)
ae.summary()

In [None]:
from depiction.interpreters.alibi import CEM

# CEM interpreter in pertinent-negative mode, configured with the
# hyperparameters defined above and the autoencoder `ae`.
cem_pn_interpreter = CEM(
    depiction_model,
    'PN',  # pertinent negative
    shape,
    kappa=kappa,
    beta=beta,
    feature_range=feature_range,
    gamma=gamma,
    ae_model=ae,
    max_iterations=max_iterations,
    c_init=c_init,
    c_steps=c_steps,
    learning_rate_init=lr,
    clip=clip,
    no_info_val=no_info_val
)

In [None]:
# Compute the pertinent negative for the example and display it (if found).
cem_pn_explanation = cem_pn_interpreter.interpret(example)
show_cem_explanation(cem_pn_explanation, 'PN')

In [None]:
from depiction.interpreters.alibi import CEM

# CEM interpreter in pertinent-positive mode; same hyperparameters and
# autoencoder as the pertinent-negative interpreter.
cem_pp_interpreter = CEM(
    depiction_model,
    'PP',  # pertinent positive
    shape,
    kappa=kappa,
    beta=beta,
    feature_range=feature_range,
    gamma=gamma,
    ae_model=ae,
    max_iterations=max_iterations,
    c_init=c_init,
    c_steps=c_steps,
    learning_rate_init=lr,
    clip=clip,
    no_info_val=no_info_val
)

In [None]:
# Compute the pertinent positive for the example and display it (if found).
cem_pp_explanation = cem_pp_interpreter.interpret(example)
show_cem_explanation(cem_pp_explanation, 'PP')

## Counterfactual explanations

In [None]:
def show_counterfactual_explanation(explanation):
    """
    Print the class predicted for a counterfactual and plot its image.

    Args:
        explanation (dict): counterfactual explanation; its 'cf' entry is
            read for the 'class', 'proba' and 'X' fields.
    """
    counterfactual = explanation['cf']
    predicted_class = counterfactual['class']
    # 'proba' is indexed as [0][class]: presumably a batch holding one row of
    # class probabilities — verify against the interpreter's output.
    probability = counterfactual['proba'][0][predicted_class]
    print(f'Counterfactual prediction: {predicted_class} with probability {probability}')
    show_image(counterfactual['X'])

In [None]:
# setting some parameters
shape = example.shape
target_proba = 1.0  # target probability for the counterfactual class
tol = 0.1  # tolerance for counterfactuals
max_iter = 10
lam_init = 1e-1
max_lam_steps = 10
learning_rate_init = 0.1
# Feature range for the counterfactual, expressed in the model's input space
# (i.e. after `transform`): x_train holds raw pixel values, so its own
# min/max would not match the range of `example`.
feature_range = (
    transform(x_train.min()).item(),
    transform(x_train.max()).item(),
)

In [None]:
from depiction.interpreters.alibi import Counterfactual

# Counterfactual interpreter searching for an instance classified as any
# class other than the original prediction.
counterfactual_interpreter = Counterfactual(
    depiction_model,
    shape=shape, target_proba=target_proba, tol=tol,
    target_class='other',  # any class other than the one originally predicted
    max_iter=max_iter, lam_init=lam_init,
    max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,
    feature_range=feature_range
)

In [None]:
# Compute and display the counterfactual towards any other class.
counterfactual_explanation = counterfactual_interpreter.interpret(example)
show_counterfactual_explanation(counterfactual_explanation)

In [None]:
from depiction.interpreters.alibi import Counterfactual

# Counterfactual interpreter targeting one specific class (label 1).
counterfactual_interpreter = Counterfactual(
    depiction_model,
    shape=shape, target_proba=target_proba, tol=tol,
    target_class=1,  # focusing on a specific class
    max_iter=max_iter, lam_init=lam_init,
    max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,
    feature_range=feature_range
)

In [None]:
# Compute and display the counterfactual towards class 1.
counterfactual_explanation = counterfactual_interpreter.interpret(example)
show_counterfactual_explanation(counterfactual_explanation)