# Interpreting predictions on ImageNet

In [None]:
# general imports
import warnings; warnings.filterwarnings("ignore", category=FutureWarning)
import tensorflow as tf; tf.logging.set_verbosity(tf.logging.ERROR)  # suppress deprecation messages
import os
import json
import numpy as np
import keras_applications
from tensorflow import keras
from ipywidgets import interact
from matplotlib import pyplot as plt

from depiction.core import DataType, Task

In [None]:
# plotting
plt.rcParams['figure.figsize'] = [20, 10]

In [None]:
# general utils
def image_preprocessing(image_path, preprocess_input, target_size):
    """
    Load an image file and turn it into a preprocessed single-image batch.

    Args:
        image_path (str): location of the image on disk.
        preprocess_input (function): callable applied to the batched array.
        target_size (tuple): target size handed to the Keras image loader.

    Returns:
        np.ndarray: the result of ``preprocess_input`` on the batch.
    """
    keras_image = keras.preprocessing.image
    pixels = keras_image.img_to_array(
        keras_image.load_img(image_path, target_size=target_size)
    )
    # prepend a batch axis so the array holds exactly one image
    batch = pixels[np.newaxis, ...]
    return preprocess_input(batch)


def get_imagenet_labels():
    """
    Get the ImageNet labels ordered by class index.

    The class index file is resolved through ``keras.utils.get_file``
    using ``keras_applications.imagenet_utils.CLASS_INDEX_PATH``.

    Returns:
        list: labels, where position ``i`` holds the label of class ``i``.
    """
    index_path = keras.utils.get_file(
        'imagenet_class_index.json',
        keras_applications.imagenet_utils.CLASS_INDEX_PATH
    )
    with open(index_path) as handle:
        class_index = json.load(handle)
    # keys are stringified class indices; each value is a pair whose
    # second element is the label
    names = [None] * len(class_index)
    for key, entry in class_index.items():
        _, name = entry
        names[int(key)] = name
    return names


def show_image(x, title=None):
    """
    Show an image without axis ticks.

    Args:
        x (np.ndarray): a 4D-array representing a batch with a
            single image.
        title (str): optional title.

    Returns:
        matplotlib.image.AxesImage: the object returned by ``plt.imshow``.
    """
    # squeeze drops the singleton batch axis before plotting
    axes_image = plt.imshow(x.squeeze())
    axes = axes_image.axes
    # Pass only the tick locations: a second positional argument is read
    # as ``minor`` or as ``labels`` depending on the matplotlib version.
    axes.set_xticks([])
    axes.set_yticks([])
    if title is not None:
        axes.set_title(title)
    return axes_image

## Instantiate a model to interpret

In [None]:
from depiction.models.keras import KerasApplicationModel

# wrap a MobileNetV2 (built with its default arguments) as an
# image classification model
model = KerasApplicationModel(
    keras.applications.MobileNetV2(),
    Task.CLASSIFICATION,
    DataType.IMAGE,
)

## Get data

In [None]:
# get labels
labels = get_imagenet_labels()
# fetch the example images and preprocess them for MobileNetV2;
# each batch is stored under its filename without the extension
examples = {}
for filename, url in [
    ('elephant.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/f/f9/Zoorashia_elephant.jpg/120px-Zoorashia_elephant.jpg'),
    ('dog.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Welsh_Springer_Spaniel.jpg/400px-Welsh_Springer_Spaniel.jpg'),
    ('cat.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Six_weeks_old_cat_%28aka%29.jpg/400px-Six_weeks_old_cat_%28aka%29.jpg'),
    # no trailing dot, so the cached file keeps a proper .jpg extension
    ('cat-and-dog.jpg', 'https://upload.wikimedia.org/wikipedia/commons/9/97/Greyhound_and_cat.jpg'),
    ('plush.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/5/51/Plush_bunny_with_headphones.jpg/320px-Plush_bunny_with_headphones.jpg')
]:
    filepath = keras.utils.get_file(filename, url)
    examples[filename.split('.')[0]] = image_preprocessing(
        filepath,
        keras.applications.mobilenet_v2.preprocess_input,
        target_size=(224, 224)
    )
# browse the loaded examples; the trailing semicolon suppresses the cell output
interact(lambda key: show_image(examples[key], title=key), key=examples.keys());

In [None]:
# pick an example: the key must be one of the names stored in `examples`
# (the example filenames without their extension)
image = examples['elephant']

## LIME

In [None]:
from depiction.interpreters.u_wash import UWasher

# build a LIME interpreter around the model, passing the ImageNet labels
# as class names
interpreter = UWasher('lime', model, class_names=labels)

In [None]:
# run the LIME interpreter on the selected example and keep the result
explanation = interpreter.interpret(image)

## Anchors

In [None]:
from depiction.interpreters.u_wash import UWasher

# build an Anchors interpreter around the same model; this rebinds
# `interpreter`, replacing the LIME one
# NOTE(review): unlike the LIME cell, no class_names are passed here —
# confirm this is intended
interpreter = UWasher('anchors', model)

In [None]:
# run the Anchors interpreter on the selected example; this rebinds
# `explanation`, replacing the LIME result
explanation = interpreter.interpret(image)