# Explaining Distribution Shifts in Time using LIME and Some Standard Classifiers

In this notebook, we use MNIST to build a simple dataset of geometric shapes that contains a distribution shift in time. We then train standard classifiers to predict whether an image was observed before or after the shift, and explain those predictions using LIME.

__THE LIME PACKAGE NEEDS TO BE INSTALLED TO RUN THIS EXAMPLE__

In [None]:
import os,sys
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import gray2rgb, rgb2gray, label2rgb # lime_image works on RGB images, so we convert between grayscale and RGB

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import Normalizer
from sklearn.decomposition import PCA

import lime
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm

## Loading the Dataset

You can either cache the dataset (size ~ 900MB) or load it anew each time you start the notebook.

In [None]:
from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784')

In [None]:
#adapt image format for lime_image
X_vec = np.array(mnist.data).reshape((-1, 28, 28))
y_vec = np.array(mnist.target).astype(int)

## Creating our Artificial Dataset

For this notebook, we derive three classes from the MNIST digit 1: vertical lines (the original ones), horizontal lines (the ones with their image axes swapped) and plus signs (a vertical and a horizontal line overlaid). Each class therefore has a very simple geometric feature. This makes the data well suited for testing how well detected drift can be explained: we construct the stream so that one class is present only before the drift, one only after it, and the third throughout.
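
The construction described above can be tried on a toy array first. The following is a minimal sketch, assuming a hand-made 5x5 "digit" instead of a real MNIST image; the names `vertical`, `horizontal` and `plus` are illustrative only. Note that swapping the two image axes transposes the image, which for a straight line has the same effect as a rotation.

```python
import numpy as np

# Toy 5x5 "digit 1": a vertical bar in the middle column (stand-in for an MNIST image).
vertical = np.zeros((1, 5, 5), dtype=np.uint8)
vertical[0, :, 2] = 255

# Swapping the two image axes transposes each image, turning the bar horizontal.
horizontal = np.swapaxes(vertical, 1, 2)

# The pixel-wise maximum overlays both bars into a plus sign.
plus = np.maximum(vertical, horizontal)

print(plus[0])
```

Taking the pixel-wise maximum is equivalent to stacking both images along a new axis and reducing with `max`, which is how the notebook cell below combines them.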

In [None]:
X_vec_1 = X_vec[y_vec == 1]

X_vec_1_turned = np.swapaxes(X_vec_1, 1,2)[np.random.permutation(X_vec_1.shape[0])] # horizontal line
X_vec_1 = X_vec_1[np.random.permutation(X_vec_1.shape[0])] # vertical line

#we combine the two lines to get a "+" (pixel-wise maximum of a vertical and a horizontal line)
X_vec_plus = np.maximum(X_vec_1[np.random.permutation(X_vec_1.shape[0])],
                        X_vec_1_turned[np.random.permutation(X_vec_1_turned.shape[0])])

y = np.array(X_vec_1.shape[0]*[0]+X_vec_1_turned.shape[0]*[1]+X_vec_plus.shape[0]*[2])
X = np.vstack( (X_vec_1, X_vec_1_turned, X_vec_plus) )

fig, ax1 = plt.subplots(1,1)
ax1.imshow(X_vec_1.mean(axis=0), interpolation = 'none')
ax1.set_title("Average of digit 1 - A vertical line")
plt.show()

fig, ax1 = plt.subplots(1,1)
ax1.imshow(X_vec_1_turned.mean(axis=0), interpolation = 'none')
ax1.set_title("Average of digit 1 rotated - A horizontal line")
plt.show()

fig, ax1 = plt.subplots(1,1)
ax1.imshow(X_vec_plus.mean(axis=0), interpolation = 'none')
ax1.set_title("Both averages combined - a plus sign")
plt.show()

## Training with the Classifier

Before we start looking at the quality of our explanations, it is useful to briefly test whether the classifier performs well on our data at all. (As it is a simple dataset with clear features, good performance can be expected but we want to make sure.)
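
To judge the accuracy printed by the next cell, it helps to know what a trivial classifier would achieve. The sketch below computes the majority-class baseline; `majority_baseline` and `y_toy` are hypothetical names used for illustration, and the toy labels only mirror the fact that our three classes are equally large.

```python
import numpy as np

def majority_baseline(labels):
    """Accuracy obtained by always predicting the most frequent class."""
    _, counts = np.unique(labels, return_counts=True)
    return counts.max() / counts.sum()

# Three equally sized classes, as in the dataset built above (toy size).
y_toy = np.repeat([0, 1, 2], 100)
print(majority_baseline(y_toy))
```

With three balanced classes the baseline is one third, so a score well above that suggests the classifier has picked up the geometric features.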

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1,28*28), y, test_size=0.50)

print(ExtraTreesClassifier(max_depth=3).fit(X_train,y_train).score(X_test,y_test))

Now that we know the classifier works, we create a stream from our data on which we can actually perform drift detection. The change over time is simple: of the two line classes, only vertical lines (ones) appear during the first n time steps, and only horizontal lines (rotated ones) appear after the shift. The plus signs we created through overlap are part of the data both before and after the shift. Note that a plus sign contains a vertical and a horizontal line and thus carries characteristics of both periods.
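
The sampling scheme can be checked on a small label vector before applying it to the images. This is a minimal sketch under toy assumptions (ten samples per class, six time steps per period); `y_toy`, `n_toy`, `period` and `classes` are illustrative names and not part of the notebook's state.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy class labels: 0 = vertical, 1 = horizontal, 2 = plus (ten samples each).
y_toy = np.repeat([0, 1, 2], 10)

n_toy = 6  # samples per period; a small stand-in for n
before = rng.choice(np.where(y_toy != 1)[0], n_toy)  # no horizontal lines before the shift
after = rng.choice(np.where(y_toy != 0)[0], n_toy)   # no vertical lines after the shift
sel_toy = np.hstack((before, after))

period = np.array(n_toy * [0] + n_toy * [1])  # 0 = before, 1 = after
classes = y_toy[sel_toy]

print(classes[period == 0], classes[period == 1])
```

As with `np.random.choice` in the cell below, indices are drawn with replacement, so the same image can occur more than once in the stream.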

In [None]:
## Create Streams

n = 2500
sel = np.hstack((np.random.choice(np.where(y!=1)[0], n),np.random.choice(np.where(y!=0)[0], n)))
stream_X = X[sel]
stream_y = np.array( n*[0]+n*[1] ) #our stream_y now contains info on whether the image is from before or after
stream_c = y[sel] #stream_c contains information about the original classes

With our data stream constructed, we can now define a pipeline that preprocesses the data and applies the classifier to the stream. We then fit this model to our data.
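
The preprocessing steps exist because lime_image works on RGB images and hands batches of RGB images to the classifier function, while the classifier itself needs flat feature vectors. The sketch below traces the array shapes through that round trip using NumPy only. The channel stacking and the weighted channel sum are stand-ins for `gray2rgb` and `rgb2gray`; the luminance weights are assumed to match the formula scikit-image documents, and the sketch skips any intensity rescaling the real functions may apply.

```python
import numpy as np

# Two toy grayscale images, 28x28, values 0..255 (stand-ins for stream images).
gray = np.zeros((2, 28, 28), dtype=np.uint8)
gray[:, :, 14] = 255

# Replicating the single channel three times mimics what gray2rgb does.
rgb = np.stack([gray, gray, gray], axis=-1)      # shape (2, 28, 28, 3)

# A weighted channel sum stands in for rgb2gray (weights are an assumption).
luma = np.array([0.2125, 0.7154, 0.0721])
gray_again = rgb @ luma                          # shape (2, 28, 28)

# Flattening yields one feature vector per image for the classifier.
flat = gray_again.reshape(len(gray_again), -1)   # shape (2, 784)

print(rgb.shape, gray_again.shape, flat.shape)
```

Because the three channels are identical and the weights sum to one, the grayscale image is recovered unchanged up to floating-point rounding.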

In [None]:
## Example using the ExtraTreesClassifier on the dataset

#making sure our streams have the right types
stream_X = np.stack([gray2rgb(iimg) for iimg in stream_X],0).astype(np.uint8)
stream_y = stream_y.astype(np.uint8)

#helper functions for preprocessing
class step_wrap(object):
    def __init__(self, step_func):
        self._step_func=step_func
    def fit(self,*args):
        return self
    def transform(self,X):
        return self._step_func(X)


def flatten_step_(images):
    flats = []
    for img in images:
        flats.append(img.ravel())
    return flats

def makegray_step_(images):
    grays = []
    for img in images:
        grays.append(rgb2gray(img))
        
    return grays

#wrapping up the preprocessing functions

makegray_step = step_wrap(makegray_step_)
flatten_step = step_wrap(flatten_step_)

#defining our pipeline (alternative steps are left commented out)
simple_rf_pipeline = Pipeline([
    ('Make Gray', makegray_step),
    ('Flatten Image', flatten_step),
    #('Normalize', Normalizer()),
    #('PCA', PCA(5)),
    #('RF', RandomForestClassifier())
    ('ET', ExtraTreesClassifier(max_depth=8))
    #('DT', DecisionTreeClassifier(max_depth=8))
    #('MLP', MLPClassifier(max_iter=500))
])


#fitting the pipeline to our data
simple_rf_pipeline.fit(stream_X, stream_y)


%load_ext autoreload
%autoreload 2

## Explaining the Results

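To explain the drift, we use the LimeImageExplainer together with a segmentation algorithm that splits each image into the superpixels that Lime switches on and off.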

In [None]:
explainer = lime_image.LimeImageExplainer(verbose = False)
segmenter = SegmentationAlgorithm('slic', kernel_size=1, max_dist=200, ratio=0.2)

In this notebook, we apply the explanation algorithm to only a few randomly chosen example images from our stream, one per class. This illustrates which image regions weigh positively and which weigh negatively when classifying whether a specific image belongs to the period before or after the shift.
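
To make the notion of region weights concrete, the following sketch reproduces the core idea behind such an explanation on a toy problem: switch superpixels on and off at random, query the model on the perturbed images, and fit a locally weighted linear model whose coefficients serve as region weights. It is a simplified illustration, not the lime package's implementation. The `black_box` function, the 4x4 image, the distance measure and the kernel width are all assumptions made for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 4x4 image split into four 2x2 "superpixels" with labels 0..3.
segments = np.repeat(np.repeat(np.arange(4).reshape(2, 2), 2, axis=0), 2, axis=1)
image = np.ones((4, 4))

def black_box(batch):
    # Hypothetical model score: rises with segment 0, falls slightly with segment 3.
    return np.array([img[segments == 0].mean() - 0.25 * img[segments == 3].mean()
                     for img in batch])

# Sample random on/off patterns over the superpixels and blank the "off" ones.
n_samples = 500
masks = rng.integers(0, 2, size=(n_samples, 4))
masks[0] = 1  # keep the unperturbed image as the first sample
perturbed = np.array([image * m[segments] for m in masks])
scores = black_box(perturbed)

# Locality kernel: samples with more segments switched on count more.
distance = 1.0 - masks.mean(axis=1)
weights = np.exp(-(distance ** 2) / 0.25 ** 2)

# Weighted least squares with an intercept column.
A = np.hstack([masks, np.ones((n_samples, 1))])
sw = np.sqrt(weights)[:, None]
coef, *_ = np.linalg.lstsq(A * sw, scores * sw[:, 0], rcond=None)
segment_weights = coef[:4]

print(np.round(segment_weights, 3))
```

Segments with a positive coefficient push the score up and segments with a negative coefficient push it down, which is how the coloured regions in the plots below should be read.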

In [None]:
#Iterating over all three types of images
found = 0
for s in [0, 1, 2]:
    
    legend_s = {0 : "vertical line", 1 : "horizontal line", 2 : "combination/plus sign"}

    example = np.random.permutation(np.where(stream_c == s)[0])[0]
    print("Example number ", example)

    #This is where we derive the explanation for the specific example image
    explanation = explainer.explain_instance(stream_X[example], 
                                             classifier_fn = simple_rf_pipeline.predict_proba, 
                                             top_labels=10, hide_color=0, num_samples=10000, segmentation_fn=segmenter)



    fig, (ax1, ax2) = plt.subplots(1,2, figsize = (8, 4))
    fig.suptitle('Positive/Negative Regions for {} before and after the shift'.format(legend_s[s]))

    #Regions that speak for label 0, i.e. for the image stemming from before the drift
    temp1, mask1 = explanation.get_image_and_mask(0, positive_only=True, num_features=1000, hide_rest=False, min_weight = 0.01)
    ax1.imshow(label2rgb(3-mask1,temp1, bg_label = 0), interpolation = 'nearest')
    #Regions that speak for label 1, i.e. for the image stemming from after the drift
    temp2, mask2 = explanation.get_image_and_mask(1, positive_only=True, num_features=1000, hide_rest=False, min_weight = 0.01)
    ax2.imshow(label2rgb(3-mask2,temp2, bg_label = 0), interpolation = 'nearest')

    if mask1.sum() > 0 and mask2.sum() > 0:      
        found += 1

        ax1.set_xticks([])
        ax1.set_yticks([])
        ax2.set_xticks([])
        ax2.set_yticks([])

    plt.show()

The Lime explanations show that images belonging clearly to one of the two periods (vertical or horizontal lines) are classified very confidently. The plus signs, which occur with equal frequency before and after the shift, are more revealing: the regions they share with vertical lines and the regions they share with horizontal lines weigh a little differently depending on whether the explanation is computed for the period before or after the shift. This indicates that the classifier indeed relies on the vertical and horizontal line areas in particular.
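
The highlighted areas in the figures are those segments whose weight for the given label is positive and exceeds `min_weight`. The sketch below imitates that thresholding on a toy segmentation; `positive_mask` is a hypothetical helper written for this illustration, and the weights are made up rather than taken from an actual explanation.

```python
import numpy as np

def positive_mask(segments, weights, min_weight=0.01):
    """Mark every pixel whose segment weight is positive and above min_weight."""
    mask = np.zeros(segments.shape, dtype=int)
    for seg_id, w in weights.items():
        if w > 0 and w > min_weight:
            mask[segments == seg_id] = 1
    return mask

# Toy segmentation with four 2x2 segments.
segments = np.array([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [2, 2, 3, 3],
                     [2, 2, 3, 3]])

# Hypothetical segment weights for one label.
weights = {0: 0.40, 1: 0.005, 2: -0.30, 3: 0.02}

mask = positive_mask(segments, weights)
print(mask)
```

Here segment 1 is dropped because its weight is below the threshold and segment 2 because its weight is negative, so a region can be missing from a plot either because it barely matters or because it speaks against the label.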