In this dashboard we look at the performance of an MNIST classifier. For convenience, we load a pre-trained model.

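An interactive heatmap displays the confusion matrix, with the diagonal (correct predictions) left blank so that only the mis-classification counts are colored. Click any off-diagonal cell of the heatmap to see the corresponding mis-labeled images.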

In [1]:
# Silence every logging call (all levels) so that library log messages do not
# clutter the page when this notebook is rendered as a Voila dashboard.
# Comment out the lines below when working with the notebook interactively.
import sys
import logging

logging.disable(sys.maxsize)

In [2]:
import numpy as np
import pandas as pd
from tensorflow import keras
import ipywidgets as w
import bqplot as bq
import bqplot.pyplot as plt

from IPython.display import display
from PIL import Image
from sklearn.metrics import confusion_matrix

In [3]:
def create_image_from_array(imarray):
    """
    Wrap a 2-D numpy array in a widget that shows it as an image.

    The array is converted with ``PIL.Image.fromarray`` and rendered inside an
    ``ipywidgets.Output`` widget, so the result can be placed in any widget
    container (HBox / VBox).

    Parameters
    ----------
    imarray : numpy.ndarray
        2-D pixel array. Presumably uint8 grayscale (as for MNIST); other
        dtypes map to PIL modes that may not render.

    Returns
    -------
    ipywidgets.Output
        Output widget containing the displayed image.
    """
    pil_image = Image.fromarray(imarray)
    output = w.Output()
    with output:
        display(pil_image)
    return output

In [4]:
num_classes = 10

# Only the test split is needed: the classifier is loaded pre-trained, so the
# training split is discarded.
_, (x_test, y_test) = keras.datasets.mnist.load_data()

# Append a trailing channel axis: (..., 28, 28) -> (..., 28, 28, 1).
x_test = x_test[..., np.newaxis]
# One-hot encode the integer labels into `num_classes` columns.
y_test = keras.utils.to_categorical(y_test, num_classes)

In [5]:
# Location of the saved pre-trained Keras model, resolved relative to the
# current working directory.
MODEL_PATH = "mnist_model"
model = keras.models.load_model(MODEL_PATH)

2022-01-24 15:35:06.594658: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [6]:
# Undo the one-hot encoding applied when the data was loaded, so the true
# labels are plain digit indices like the predictions. The ndim guard makes
# this cell idempotent: without it, re-running the cell would argmax the
# already 1-D labels into a single scalar and silently corrupt `results_df`.
if y_test.ndim > 1:
    y_test = np.argmax(y_test, axis=-1)

# NOTE(review): `x_test` is passed to the model as raw uint8 pixels (0-255).
# If `mnist_model` was trained on inputs scaled to [0, 1], predict on
# `x_test / 255` instead (keep `x_test` unscaled for image display) — confirm
# against the training code.
y_test_pred = np.argmax(model.predict(x_test), axis=-1)

# One row per test image with a default RangeIndex, so the index of any
# filtered subset is directly usable as a position into `x_test`.
results_df = pd.DataFrame({'actual': y_test, 'predicted': y_test_pred})

In [7]:
# Confusion matrix over every digit class (rows = actual, columns = predicted).
# Passing `labels` explicitly guarantees a num_classes x num_classes matrix
# whose row/column k is digit k, even if some class were absent from the
# labels and predictions — the heatmap click handler relies on that mapping.
# Cast to float because NaN cannot be stored in an integer array.
conf_mat = confusion_matrix(
    y_test, y_test_pred, labels=np.arange(num_classes)
).astype(float)
# Blank out the diagonal (correct predictions) so the heatmap's color scale
# reflects only the mis-classification counts.
np.fill_diagonal(conf_mat, np.nan)

In [8]:
heatmap_fig = plt.figure(layout=dict(width='600px', height='540px'),
                         title='Confusion Matrix')
plt.scales(scales={'color': bq.ColorScale(scheme='Oranges')})
# Rows are actual digits and columns predicted digits (the confusion_matrix
# layout); the color axis itself is hidden.
axes_options = {'color': {'visible': False},
                'column': {'label': 'Predicted'},
                'row': {'label': 'Actual'}}
# NOTE(review): with interactions={'click': 'select'} the front end writes the
# clicked [row, column] pair(s) into `conf_mat_heatmap.selected`. The
# "ValueError: Unsupported dtype object" in this notebook's output is presumably
# raised on those clicks: it looks like bqplot deserializing the list of pairs
# as an object-dtype array that it then cannot re-serialize. Confirm against
# the installed bqplot / ipywidgets versions.
conf_mat_heatmap = plt.gridheatmap(conf_mat, axes_options=axes_options,
                                   font_style={'font-size': '16px',
                                               'font-weight': 'bold'},
                                   interactions={'click': 'select'},
                                   selected_style={'stroke': 'limegreen',
                                                   'stroke-width': 3},
                                   display_format='.0f')
# no grid lines over the heatmap cells
plt.grids(heatmap_fig, 'none')

# Heading shown above the image panel; '{}' is filled with the selected
# actual/predicted pair (or left empty).
label_tmpl = '<div style="font-size: 16px">Mis-labeled Images<br>{}</div>'
images_label = w.HTML(label_tmpl.format(''))
# Placeholder container for the grid of mis-labeled images.
images_placeholder = w.Box()

In [9]:
def _clear_image_panel():
    """Reset the image-panel heading and remove any displayed images."""
    images_label.value = label_tmpl.format("")
    images_placeholder.children = []


def on_heatmap_cell_select(*args):
    """
    Display the mis-labeled test images for the selected confusion-matrix cell.

    Registered below as an observer of ``conf_mat_heatmap.selected``; the
    observer's change payload (``*args``) is not used. The first selected
    ``(row, column)`` pair is read as ``(actual, predicted)`` digit, matching
    the heatmap's axis labels. Test images with that actual/predicted
    combination are shown in a roughly square grid. A cleared selection or a
    diagonal cell (correct predictions) empties the panel.
    """
    selected = conf_mat_heatmap.selected
    # the selection may be cleared (None or empty); selected[0] would fail
    if selected is None or len(selected) == 0:
        _clear_image_panel()
        return

    # coerce to plain ints: the pair may arrive as a numpy row or a list
    actual, predicted = (int(v) for v in selected[0])
    if actual == predicted:
        _clear_image_panel()
        return

    images_label.value = label_tmpl.format(f"Actual: {actual} Predicted: {predicted}")
    mask = (results_df['actual'] == actual) & (results_df['predicted'] == predicted)
    ixs = results_df.index[mask]

    # square-ish grid: ceil(sqrt(n)) images per row; at least 1 so the range
    # step below is never zero when no image matches
    per_row = max(1, int(np.ceil(np.sqrt(len(ixs)))))
    img_rows = [
        w.HBox([create_image_from_array(x_test[ix].reshape(28, 28))
                for ix in ixs[start:start + per_row]])
        for start in range(0, len(ixs), per_row)
    ]
    images_placeholder.children = [w.VBox(img_rows)]


# NOTE(review): if clicking a cell raises "Unsupported dtype object" instead of
# running this callback, the selection is failing inside bqplot/ipywidgets
# before observers are notified — TODO confirm with the installed versions.
conf_mat_heatmap.observe(on_heatmap_cell_select, 'selected')

In [10]:
# Right-hand panel: heading plus image grid, pushed down to line up with the
# heatmap body.
images_panel = w.VBox(
    children=[images_label, images_placeholder],
    layout=w.Layout(margin='60px 0px 0px 0px'),
)
# Last expression of the cell: rendered as the dashboard.
w.HBox(children=[heatmap_fig, images_panel])

HBox(children=(Figure(axes=[ColorAxis(grid_lines='none', scale=ColorScale(scheme='Oranges'), visible=False), A…

ValueError: Unsupported dtype object

ValueError: Unsupported dtype object