In [5]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from alibi.prototypes import ProtoSelect
from alibi.prototypes.protoselect import cv_protoselect_euclidean
from alibi.utils.kernel import EuclideanDistance
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load MNIST dataset

In [2]:
# load MNIST, flatten every image into a 1-D vector and divide pixel values by 255
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
X_train = X_train.reshape(len(X_train), -1) / 255.
X_test = X_test.reshape(len(X_test), -1) / 255.

# hold out 1000 training samples as a validation set
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=1000, random_state=0)

# select random samples from train and test; a seeded generator keeps the
# subsample (and therefore every downstream result) reproducible after a
# kernel restart
size = 1000
rng = np.random.default_rng(0)
train_idx = rng.choice(len(X_train), size=size, replace=False)
test_idx = rng.choice(len(X_test), size=size, replace=False)
X_train, Y_train = X_train[train_idx], Y_train[train_idx]
X_test, Y_test = X_test[test_idx], Y_test[test_idx]

In [102]:
# settings shared by the eps search and the later explain step
num_prototypes = 20
grid_size = 50
quantiles = (0.1, 0.9)

# get best eps by cv; valset is None, so the held-out (X_val, Y_val) split is
# not passed to the search
eps = cv_protoselect_euclidean(
    refset=(X_train, Y_train),
    protoset=(X_test,),
    valset=None,
    num_splits=5,
    num_prototypes=num_prototypes,
    quantiles=quantiles,
    grid_size=grid_size,
)

In [103]:
# build the ProtoSelect explainer with the selected eps and keep whatever
# `fit` returns as the explainer used in the following cells
explainer = ProtoSelect(
    kernel_distance=EuclideanDistance(),
    eps=eps,
).fit(X=X_train, X_labels=Y_train, Y=X_test)

In [104]:
# request `num_prototypes` prototypes from the fitted explainer
explanation = explainer.explain(num_prototypes=num_prototypes)

In [105]:
# 1-nearest-neighbour classifier trained only on the selected prototypes and
# their labels; its score on the sampled test set is used to gauge the
# prototypes.
# NOTE(review): X_test was also passed as `Y` to `explainer.fit`; if the
# prototypes are drawn from that set, this score includes the prototypes
# themselves -- consider scoring on X_val / Y_val instead. TODO confirm.
knn = KNeighborsClassifier(n_neighbors=1)
knn = knn.fit(X=explanation.data['prototypes'], y=explanation.data['prototypes_labels'])
knn.score(X_test, Y_test)

0.726

In [106]:
# display the labels stored alongside the returned prototypes
explanation.data['prototypes_labels']

array([0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9],
      dtype=int32)