# Explaining Distribution Shifts in Time using Random Forests and t-SNE on Fashion-MNIST

In this Python notebook, we use machine learning techniques to explain a distribution shift in a time series. Specifically, we aim to identify which data points are specific to the time before or after the distribution shift, as well as which data points are distributed independently of time.

To achieve this goal, we train a random forest (RF) model on the Fashion-MNIST dataset, which consists of images of fashion items. Fashion-MNIST is a classic benchmark dataset for image classification and dimensionality reduction: it poses a challenging high-dimensional classification problem with ten classes, which are not as easily separated as those of the classic MNIST digit dataset. Random forests are versatile and robust models that can handle a wide range of input data and make only weak assumptions about the data, which makes them a common choice in stream learning. They are also well suited to complex, high-dimensional inputs such as images, which makes them a good choice for this task.

To visualize the results, we use t-SNE to project the high-dimensional data onto a two-dimensional plot. Following ideas from discriminative dimensionality reduction, t-SNE is applied to an embedding space that is derived from the information extracted during random forest training. The plot shows how clusters change before and after the distribution shift, which provides insights into the nature of the shift: we can see which data points are specific to the time before or after the distribution shift and how the clusters change over time. This information can be used to improve the performance of the random forest model and to gain a better understanding of the data.

Furthermore, we illustrate the observed change in the distribution using counterfactual explanations, which are based on contrasting similar yet differently classified data points.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.sparse.linalg import eigs
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, ExtraTreesRegressor, ExtraTreesClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cluster import MiniBatchKMeans

## Exploring the Fashion-MNIST Dataset: Mean Images, Similarity Matrix, and t-SNE Visualization

In this part, we explore the Fashion-MNIST dataset using mean images, a similarity matrix, and t-SNE. We first compute the mean image of each class and use these mean images to compute the similarity between each pair of classes, which we visualize as a matrix using matplotlib. Finally, we use t-SNE to project the high-dimensional data onto a two-dimensional space, which allows us to visualize the dataset and gain insights into its structure.
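The mean images and their pairwise similarities can also be computed without explicit double loops. The following NumPy sketch is our own vectorized reformulation (the helper names `class_means` and `mean_similarity` are hypothetical, not part of the original notebook); it uses the same dot-product similarity and the same normalization as the similarity-matrix cell below, and is demonstrated on a tiny synthetic array instead of the real dataset.

```python
import numpy as np

def class_means(X, y):
    """Return the sorted class labels and the mean image of each class.

    X has shape (n_samples, height, width); y has shape (n_samples,)."""
    classes = np.unique(y)
    means = np.stack([X[y == c].mean(axis=0) for c in classes])
    return classes, means

def mean_similarity(means):
    """Dot products between flattened class means, each divided by the product
    of the corresponding row and column sums of the dot-product matrix."""
    flat = means.reshape(len(means), -1)
    gram = flat @ flat.T  # gram[i, j] = <mean_i, mean_j>
    return gram / (gram.sum(axis=0)[:, None] * gram.sum(axis=1)[None, :])

# Tiny synthetic example: three 2x2 "images" from two classes
X_toy = np.array([[[1, 0], [0, 0]],
                  [[3, 0], [0, 0]],
                  [[0, 0], [0, 2]]], dtype=float)
y_toy = np.array([0, 0, 1])
classes, means = class_means(X_toy, y_toy)
sim = mean_similarity(means)
```

On the notebook's data, `class_means(X, y)` would return the same means as the `Xmean` loop, and `mean_similarity` applied to them the same matrix as `scalar`.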

In [None]:
print('Downloading the dataset...')
mnist = fetch_openml('Fashion-MNIST', version=1, as_frame=False)
print('Download complete.')

labels = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Store the dataset in an appropriate format
X = np.array(mnist.data).reshape((-1, 28, 28))
y = np.array(mnist.target).astype(int)

In [None]:
# Create an empty numpy array to store the mean images for each label
Xmean = np.zeros((len(set(y)), 28, 28))

# Iterate over each label in the dataset
for i in set(y):
    # Compute the mean image for the current label
    Xmean[i] = X[y == i].mean(axis=0)
    
# Create a figure with two rows and five columns to plot class means
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(10, 3))

# Iterate over each subplot and plot the data
for i in range(2):
    for j in range(5):
        label = i * 5 + j
        axs[i, j].axis('off')
        axs[i, j].matshow(Xmean[label])
        axs[i, j].set_title('{} mean:'.format(labels[label]))

# Adjust the spacing between subplots
plt.subplots_adjust(wspace=0.3, hspace=0.7)
plt.show()

As can be seen, some classes are quite similar when compared at the level of their mean images; consider, for example, Pullover, Coat, and Shirt. This is an important difference from the MNIST dataset, where the digit classes are easily distinguished by their mean images.
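One way to put a number on this observation is a nearest-class-mean classifier: assign every image to the class whose mean image is closest in Euclidean distance and measure how often this is correct. This is a minimal sketch we added for illustration (`nearest_mean_accuracy` is a hypothetical helper); it evaluates on the same samples used to compute the means, so it is a resubstitution estimate rather than a proper test accuracy.

```python
import numpy as np

def nearest_mean_accuracy(X, y):
    """Fraction of samples whose closest class-mean image (squared Euclidean
    distance on flattened pixels) belongs to their own class."""
    flat = X.reshape(len(X), -1).astype(float)
    classes = np.unique(y)
    means = np.stack([flat[y == c].mean(axis=0) for c in classes])
    # ||x - m||^2 = ||x||^2 - 2 x.m + ||m||^2, avoiding a large 3-D temporary
    d = ((flat ** 2).sum(axis=1)[:, None]
         - 2 * flat @ means.T
         + (means ** 2).sum(axis=1)[None, :])
    predicted = classes[np.argmin(d, axis=1)]
    return (predicted == y).mean()

# Toy example: the last sample of class 1 lies closer to the mean of class 0
X_toy = np.array([[0.0], [1.0], [10.0], [2.0]])
y_toy = np.array([0, 0, 1, 1])
acc = nearest_mean_accuracy(X_toy, y_toy)
```

Calling `nearest_mean_accuracy(X, y)` on the full dataset, and the same function on MNIST, would give a rough comparison of how separable the two datasets are by their class means.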

We can observe a similar effect in the (pixel-wise) similarity matrix:

In [None]:
# Create an empty numpy array to store the dot product between the mean images
scalar = np.zeros((Xmean.shape[0], Xmean.shape[0]))

# Iterate over each pair of labels and compute the dot product between their mean images
for i in range(Xmean.shape[0]):
    for j in range(Xmean.shape[0]):
        scalar[i, j] = Xmean[i].flatten() @ Xmean[j].flatten()

# Normalize each entry by the product of the corresponding row and column sums
#  (note: this does not make the rows or columns sum to 1)
scalar = scalar / (scalar.sum(axis=0)[:, None] * scalar.sum(axis=1)[None, :])

# Display the similarity matrix using matplotlib
plt.matshow(scalar)
plt.xlabel('Class')
plt.xticks(range(10), labels, rotation=90)
plt.ylabel('Class')
plt.yticks(range(10), labels)
plt.title('Similarity matrix between mean images')

plt.show()

In [None]:
# Randomly select 5000 samples from the dataset
embedding_selection = np.random.choice(range(X.shape[0]), size=5000, replace=False)

# Fit the t-SNE model to the selected samples and obtain the low-dimensional embeddings
print('Fitting the t-SNE model...')
tsne = TSNE(init='random', learning_rate='auto')
X_embedded = tsne.fit_transform(X[embedding_selection].reshape(-1, 28*28))
print('Fitting complete.')

In [None]:
# Set the figure size to 10x10 inches
plt.figure(figsize=(10, 10))

# Specify the percentage of samples to display for each label
percentage = 30

# Iterate over each label in the dataset
for i in set(y):
    # Select the samples with the current label
    selection_label = y[embedding_selection] == i
    # Calculate the number of samples to display based on the specified percentage
    num_samples = int(np.ceil(np.sum(selection_label) * percentage / 100))
    # Select a random subset of the samples with the current label
    selection_random = np.random.choice(np.where(selection_label)[0], size=num_samples, replace=False)
    # Create a scatter plot of the selected samples, using the first four letters of the class name
    #  as the marker and alpha set to 0.5
    plt.scatter(X_embedded[selection_random, 0], X_embedded[selection_random, 1], marker="$%s$" % labels[i][:4], alpha=0.5, s=250)

# Display the plot
plt.title(f't-SNE visualization of Fashion-MNIST ({percentage}% of the 5000-sample subset)')
plt.show()

t-SNE is based on neighborhood relations: it converts pairwise distances into affinities that emphasize each point's nearest neighbors and tries to preserve these neighborhoods in the low-dimensional embedding. As can be seen in the scatter plot, there is a large overlap between the different classes when using a generic metric that is not adapted to the task, here the Euclidean metric on raw pixels, which fits our previous observation. This is again a major difference from the MNIST dataset, where the classes are mostly well separated in the t-SNE plot.
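Since t-SNE mainly preserves local neighborhoods, the class overlap it shows can also be quantified directly in pixel space by checking how often a sample's nearest Euclidean neighbors share its class. The following is a minimal NumPy sketch; the helper name `knn_label_agreement` and the choice of `k` are our own assumptions, not part of the original analysis.

```python
import numpy as np

def knn_label_agreement(X, y, k=5):
    """Average fraction of each sample's k nearest Euclidean neighbors
    (excluding the sample itself) that carry the same label."""
    X = X.reshape(len(X), -1).astype(float)
    sq_norms = (X ** 2).sum(axis=1)
    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
    dists = sq_norms[:, None] + sq_norms[None, :] - 2 * X @ X.T
    np.fill_diagonal(dists, np.inf)  # a sample is not its own neighbor
    neighbors = np.argsort(dists, axis=1)[:, :k]
    return (y[neighbors] == y[:, None]).mean()

# Two well-separated toy classes: every nearest neighbor agrees
X_sep = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
y_sep = np.array([0, 0, 1, 1])
agreement = knn_label_agreement(X_sep, y_sep, k=1)
```

Applied to the t-SNE subset, e.g. `knn_label_agreement(X[embedding_selection], y[embedding_selection], k=10)`, this builds a 5000 x 5000 distance matrix (about 200 MB in float64), so it is only meant for subsets of this size.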

## Preparing the data for the distribution shift analysis

In this part, we prepare the Fashion-MNIST dataset for the distribution shift analysis. We map each class in the dataset to a label indicating whether its items occur before the change point, after it, at both times, or never. We then remove the samples with label 0 (i.e., items that occur neither before nor after the change point) from the input features and labels. Finally, we randomly assign labels of 1 or 2 to the samples with label 3 (i.e., items that occur both before and after the change point), which simulates classes whose distribution does not depend on time. This yields an approximately balanced dataset in which labels 1 and 2 represent the samples observed before and after the change point, respectively.
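The mapping and random reassignment described above can be written as a vectorized lookup. This sketch reuses the notebook's `label_map` but runs on synthetic labels (1000 per class, an assumption for illustration); `simulate_drift_labels` is a hypothetical helper. It also shows why the result is roughly balanced: each time group receives two full classes plus about half of the two "both" classes.

```python
import numpy as np

# Same mapping as in the notebook: 0 - never, 1 - before, 2 - after, 3 - both
label_map = {0: 1, 1: 1, 2: 2, 3: 2, 4: 0, 5: 0, 6: 0, 7: 3, 8: 0, 9: 3}
lookup = np.array([label_map[c] for c in range(10)])

def simulate_drift_labels(y, rng):
    """Map class labels to time groups, drop the 'never' samples, and randomly
    split the 'both' group between before (1) and after (2)."""
    y_mapped = lookup[y]          # vectorized dictionary lookup
    keep = y_mapped != 0          # mask of samples that are kept
    y_time = y_mapped[keep].copy()
    both = y_time == 3
    y_time[both] = rng.choice([1, 2], size=both.sum())
    return keep, y_time

rng = np.random.default_rng(0)
y_demo = np.repeat(np.arange(10), 1000)  # synthetic labels, 1000 per class
keep, y_time = simulate_drift_labels(y_demo, rng)
```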

In [None]:
# Map each class to a label indicating whether it occurs before or after the change point, or both, or neither
#  0 - never, 1 - before, 2 - after, 3 - both
label_map = {
    0: 1,
    1: 1,
    2: 2,
    3: 2,
    4: 0,
    5: 0,
    6: 0,
    7: 3,
    8: 0,
    9: 3
}
y_mapped = np.array([label_map[c] for c in y])

# Remove samples with label 0 (i.e., classes that occur neither before nor after the change point)
#  from input features and labels
X_clean = X.copy().reshape(-1, 28*28)[y_mapped != 0]
y_clean = y.copy()[y_mapped != 0]
y_mapped_clean = y_mapped.copy()[y_mapped != 0]

# Randomly assign labels of 1 or 2 to samples with label 3
#  (i.e., classes that occur both before and after the change point)
label_3_idx = np.where(y_mapped_clean == 3)[0]
y_mixed = y_mapped_clean.copy()
y_mixed[label_3_idx] = np.random.choice([1, 2], size=len(label_3_idx))

## Computing the RF-Kernel Matrix and Visualizing Learned Similarities Between Classes

In this part, we train a random forest classifier on the mixed set of fashion items, in which the items that occur both before and after the distribution shift have been randomly assigned to either the before or the after group. We use the trained classifier to compute an RF-kernel-matrix for a random subset of the samples and display it with the samples grouped by their original class label.
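The RF kernel counts, for each pair of samples, the fraction of trees in which both samples end up in the same leaf. Writing Z for a one-hot leaf-membership matrix with one block of columns per tree (our notation), this is K = Z Z^T / T for T trees, which makes explicit that K is a positive semi-definite Gram matrix. The sketch below is an equivalent reformulation of the notebook's loop, demonstrated on a hand-made leaf-index matrix of the shape returned by `RandomForestClassifier.apply`; `rf_kernel_onehot` is a hypothetical name.

```python
import numpy as np

def rf_kernel_onehot(leaf_indices):
    """RF kernel K = Z @ Z.T / n_trees from a leaf-index matrix of shape
    (n_samples, n_trees). Entry (i, j) is the fraction of trees in which
    samples i and j land in the same leaf."""
    n_trees = leaf_indices.shape[1]
    blocks = []
    for leaves in leaf_indices.T:  # one column of leaf ids per tree
        # Relabel this tree's leaf ids to 0..n_leaves-1 and one-hot encode them
        _, inverse = np.unique(leaves, return_inverse=True)
        blocks.append(np.eye(inverse.max() + 1)[inverse])
    Z = np.hstack(blocks)  # shape (n_samples, total number of leaves)
    return Z @ Z.T / n_trees

# Three samples, two trees: samples 0 and 1 share a leaf in tree 0 only
leaves_toy = np.array([[3, 5],
                       [3, 6],
                       [4, 7]])
K_toy = rf_kernel_onehot(leaves_toy)
```

With the notebook's settings (`max_leaf_nodes=15` and scikit-learn's default of 100 trees), Z for the 5000-sample subset has at most 1500 columns, so this form stays small in memory.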

In [None]:
# Set a boolean flag to determine whether to skip model selection and use the default
#  configuration (min_samples_leaf=500, max_leaf_nodes=15)
skip_model_selection = True

# Split the dataset into a training set and a test set, with a 55/45 split
X_clean_train, X_clean_test, y_mixed_train, y_mixed_test = \
    train_test_split(X_clean, y_mixed, train_size=0.55)

# Initialize a random forest model with min_samples_leaf=500 and max_leaf_nodes=15
model = RandomForestClassifier(min_samples_leaf=500, max_leaf_nodes=15)
# Alternative: model = ExtraTreesClassifier(min_samples_leaf=50, max_leaf_nodes=25)

# Optional model selection on the mixed training set (group 3 randomly assigned to 1 or 2)
if not skip_model_selection:
    best_model, best_score = None, -np.inf
    start_nodes, end_nodes, step_size = 15, 50, 5
    for max_leaf_nodes in range(start_nodes, end_nodes+1, step_size):
        # Train a random forest model with the current value of max_leaf_nodes
        rf_model = RandomForestClassifier(min_samples_leaf=500, max_leaf_nodes=max_leaf_nodes)
        rf_model.fit(X_clean_train, y_mixed_train)
        # Evaluate the model on the test set and store the score
        test_score = rf_model.score(X_clean_test, y_mixed_test)
        print(f"Test set accuracy for max_leaf_nodes = {max_leaf_nodes}: {test_score:.3f}")
        # If the current model has a higher score than the previous best model, update the best model and best score
        if test_score > best_score:
            best_model = rf_model
            best_score = test_score
    print("----")
    print(f"Best test set accuracy: {best_score:.3f}")
    # Set the model to the best model found during model selection
    model = best_model

# Fit the model to the mixed set (group 3 is randomly assigned to 1 or 2)
print('Fitting Random Forest classifier...')
model.fit(X_clean, y_mixed);
print('Fitting complete.')

In [None]:
# Select a random subset of the clean input features and labels
subset_size = 5000
subset_indices = np.random.choice(range(X_clean.shape[0]), size=subset_size, replace=False)
X_subset, y_subset = X_clean[subset_indices], y_clean[subset_indices]

# Compute the RF-Kernel-Matrix for the subset using the trained random forest model
print('Computing Random Forest kernel matrix...')
rf_kernel_matrix = np.zeros((subset_size, subset_size))

leaf_indices = model.apply(X_subset.reshape(-1,28*28))
for leaf_vector in leaf_indices.T:
    # Compute the pairwise similarity between leaf indices using boolean array comparison
    rf_kernel_matrix += leaf_vector[:, None] == leaf_vector[None, :]

# Normalize the RF-Kernel-Matrix by the number of decision trees in the random forest
rf_kernel_matrix /= leaf_indices.shape[1]
print('Computing complete.')

In [None]:
# Sort the subset of input features and labels by their ground truth class
sorted_indices_by_class = np.argsort(y_subset)

# Display the RF-kernel matrix, with samples grouped by class
plt.set_cmap("viridis")
plt.matshow(rf_kernel_matrix[sorted_indices_by_class, :][:, sorted_indices_by_class])

# Get the unique labels and their counts in the sorted subset
unique_labels, label_counts = np.unique(y_subset, return_counts=True)

# Sort the unique labels and their counts by label
sort_indices = np.argsort(unique_labels)
unique_labels = unique_labels[sort_indices]
label_counts = label_counts[sort_indices]

# Calculate the midpoint of each group of samples
midpoints = np.cumsum(label_counts) - label_counts / 2 - 1

# Add x and y ticks with labels for each group
plt.xticks(midpoints, [labels[i] for i in unique_labels], rotation=90)
plt.xlabel('Fashion Piece Class')
plt.yticks(midpoints, [labels[i] for i in unique_labels])
plt.ylabel('Fashion Piece Class')

# Add a title
plt.title('Similarities from RF classifier sorted by class')
plt.show();

This plot visualizes the pairwise similarities between samples derived from the classification model; high similarity is encoded by bright colors. The samples are sorted by their ground-truth classes, which were mapped to labels indicating whether they occur before or after the change point, or both. The original classes are unknown to the classifier. Nevertheless, the model distinguishes well between both the drift behavior and the original classes. Note that even though the samples of Sneaker and Ankle boot were randomly assigned to either group 1 or 2, the model separates these classes from the others very clearly, resulting in a single block. This indicates that samples of these two classes are more similar to each other than to samples of the classes with a different drift profile.
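To check this block structure numerically rather than visually, one can average the kernel over every pair of classes, which gives the mean brightness of each block in the sorted plot. A minimal sketch with a hypothetical helper `class_block_means`; note that the diagonal blocks include each sample's self-similarity of 1, which slightly inflates them for small classes.

```python
import numpy as np

def class_block_means(K, y):
    """Mean kernel value between every pair of classes (rows/columns sorted by label)."""
    classes = np.unique(y)
    B = np.zeros((len(classes), len(classes)))
    for a, ca in enumerate(classes):
        for b, cb in enumerate(classes):
            # Sub-block of K with rows from class ca and columns from class cb
            B[a, b] = K[np.ix_(y == ca, y == cb)].mean()
    return classes, B

# Toy kernel with two classes that never share a leaf
K_demo = np.array([[1.0, 0.5, 0.0, 0.0],
                   [0.5, 1.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0, 0.2],
                   [0.0, 0.0, 0.2, 1.0]])
y_demo = np.array([0, 0, 1, 1])
classes, B = class_block_means(K_demo, y_demo)
```

In the notebook, `class_block_means(rf_kernel_matrix, y_subset)` would, for example, show whether the Sneaker/Ankle boot off-diagonal block is as bright as their diagonal blocks.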

## t-SNE Visualization of Random Forest Kernel Matrix Embeddings with Grouped Pieces

In the previous section, we computed the similarity matrix of the random forest classifier and plotted it. In this section, we compute a Euclidean embedding of the RF-kernel-matrix from its six leading eigenvectors, each scaled by the square root of the corresponding eigenvalue, apply t-SNE to this embedding, and visualize the result. The resulting plot exhibits a clear cluster structure corresponding to the different fashion classes. The color of each point encodes the model's predicted probability that the point occurs before the change point. This plot allows us to visually explore the relationship between the fashion classes and their temporal context.
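The embedding step follows from the eigendecomposition K = V diag(lambda) V^T: with Phi = V diag(lambda)^(1/2), the inner products Phi Phi^T reproduce K, and keeping only the leading eigenpairs gives a low-rank approximation. The notebook uses SciPy's `eigs` with `k=6`; because the RF kernel is symmetric, the sketch below uses `np.linalg.eigh` instead, which returns real eigenpairs directly (SciPy's `scipy.sparse.linalg.eigsh` is the corresponding partial solver). `kernel_embedding` is a hypothetical name, and a full `eigh` on a 5000 x 5000 matrix is noticeably slower than computing six eigenpairs.

```python
import numpy as np

def kernel_embedding(K, n_components):
    """Euclidean embedding Phi = V * sqrt(lambda) from the leading eigenpairs of a
    symmetric PSD kernel matrix, so that Phi @ Phi.T approximates K."""
    eigenvalues, eigenvectors = np.linalg.eigh(K)          # ascending order
    order = np.argsort(eigenvalues)[::-1][:n_components]   # largest first
    lam = np.clip(eigenvalues[order], 0, None)             # guard tiny negatives
    return eigenvectors[:, order] * np.sqrt(lam)[None, :]

# Rank-2 toy kernel: samples 0 and 1 always share a leaf, sample 2 never does
K_demo = np.array([[1.0, 1.0, 0.0],
                   [1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])
Phi = kernel_embedding(K_demo, 2)
```

Since eigenvector signs are arbitrary, only distances and inner products of the embedding are meaningful, which is all that t-SNE and k-means use.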

In [None]:
# Compute the eigenvalues and eigenvectors of the RF-kernel-matrix
print('Computing eigenvalues and eigenvectors...')
embedding_eigenvalues, embedding_eigenvectors = eigs(rf_kernel_matrix, k=6, which='LM')
print('Computing eigenvalues and eigenvectors complete.')

# Apply t-SNE to the six leading eigenvectors of the RF-Kernel-Matrix, scaled by the square roots of their eigenvalues
print('Fitting t-SNE model...')
tsne_model = TSNE(init='random', learning_rate='auto')
tsne_embedding = tsne_model.fit_transform(np.real((embedding_eigenvectors * (embedding_eigenvalues**0.5)[None,:])))
print('Fitting complete.')

In [None]:
# Specify the percentage of samples to display for each label
percentage = 30

# Set the colormap to 'coolwarm' and plot the samples using different markers and colors for each label
plt.figure(figsize=(12,10))
plt.set_cmap("coolwarm")
for label in set(y_subset):
    selection_label = y_subset == label
    # Calculate the number of samples to display based on the specified percentage
    num_samples = int(np.ceil(np.sum(selection_label) * percentage / 100))
    # Select a random subset of the samples with the current label
    selection_random = np.random.choice(np.where(selection_label)[0], size=num_samples, replace=False)
    if selection_label.sum() > 0:
        # Plot the samples with the current label, using a class-specific marker and the model's predicted
        #  probability of group 1 (before the change point) as the color; the color range is fixed to [0, 1]
        #  so that colors are comparable across classes and with the colorbar
        plt.scatter(tsne_embedding[selection_random][:,0], tsne_embedding[selection_random][:,1], marker=["o","*","^","v","","","","X","","P"][label],
                    c=model.predict_proba(X_subset[selection_random].reshape(-1,28*28))[:,0], vmin=0, vmax=1, alpha=0.5, s=40)

# Add legend (use full dataset for better estimation of mean value)
for label in range(6):
    plt.plot([], [], ["o","*","^","v","X","P"][label],
            label=labels[[0,1,2,3,7,9][label]],
            color=plt.get_cmap()(model.predict_proba(X[y == [0,1,2,3,7,9][label]].reshape(-1,28*28))[:,0].mean()))
        
# Display the plot
plt.legend()
plt.xlabel("t-SNE dimension 1")
plt.ylabel("t-SNE dimension 2")
plt.title(f't-SNE Embedding of Fashion-MNIST Data Colored by Random Forest Predictions ({percentage}% of the subset)')
plt.colorbar()
plt.show()

The t-SNE plot shows a clear cluster structure according to the fashion classes. Each cluster can be placed in its temporal context by its color, which is computed from the RF prediction; this color coding reveals the two time-dependent groups. Note that the structure obtained by the kernel analysis has two main advantages over solely evaluating the color-coded class probability. First, it helps us distinguish between different kinds of data points that would be indistinguishable when looking at the predicted class probabilities only. Second, it helps identify time-independent clusters even if the classifier does not predict each of their samples with a probability of exactly 0.5.

## Drift Explanations using Counterfactuals

So far, we have considered a global explanatory scheme based on dimensionality reduction. This is helpful when the global structure of the drift is of interest. However, especially for non-experts, an explanation based on individual samples may be easier to understand. In what follows, we demonstrate such explanations using counterfactuals, which show how certain samples would need to be modified to change their temporal assignment. By contrasting the resulting pairs of samples, we can grasp the features that matter for the drift.

To apply this explanatory scheme, we must first identify particularly relevant samples. We do this by applying a prototype-based clustering algorithm (here, k-means) to the Euclidean embedding of the metric induced by the random forest. In addition, we identify the data points that are strongly affected by the drift and therefore require an explanation.
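The drift-region rule used in the next cell can be isolated as a small function. A sample counts as drift-affected if 2|p - 0.5| > 1 - theta, where p is the predicted probability of group 1 (before the change point); equivalently, p > 1 - theta/2 or p < theta/2, so `theta_decision = 0.6` marks samples with p > 0.7 as "before" and p < 0.3 as "after". This is a sketch with a hypothetical function name that reproduces the notebook's formula.

```python
import numpy as np

def drift_regions(p_before, theta):
    """Return +1 (clearly before the drift), -1 (clearly after), or 0 (no clear
    temporal assignment) for each predicted probability of group 1."""
    has_drift = 2 * np.abs(p_before - 0.5) > 1 - theta
    direction = np.where(p_before > 0.5, 1, -1)
    return np.where(has_drift, direction, 0)

p = np.array([0.95, 0.05, 0.5, 0.65, 0.75, 0.25])
regions_demo = drift_regions(p, 0.6)
```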

In [None]:
theta_decision = 0.6

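# Compute the Euclidean embedding of the RF kernel (eigenvectors scaled by the square roots of their eigenvalues)
proj = np.real( (embedding_eigenvectors * embedding_eigenvalues**0.5) )

# Compute drift localization / drift regions:
#  +1 - clearly before the drift, -1 - clearly after the drift, 0 - no clear temporal assignment
y_pred = model.predict_proba(X_subset)[:,0]
has_drift = 2*np.abs(y_pred-0.5) > 1-theta_decision
before_drift = y_pred > 0.5
regions = (has_drift*(2*before_drift-1)).astype(int)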

# Cluster projected data to obtain characteristic samples
cluster = MiniBatchKMeans(n_clusters=6).fit_predict(proj)

As a quick sanity check, we examine how well the clusters (prototypes) found by k-means align with the fashion classes and with the (true) temporal classes (before, after, both).
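Both sanity-check matrices are contingency tables normalized per cluster: entry (i, j) is the fraction of cluster j that carries label i. A compact way to build such a table is to count with `np.add.at` and divide by the column sums. This is our own sketch (`column_normalized_contingency` is a hypothetical name), shown on toy labels.

```python
import numpy as np

def column_normalized_contingency(labels, clusters):
    """Rows: sorted label values, columns: sorted cluster ids.
    Each column sums to 1 (fraction of the cluster carrying each label)."""
    label_values, label_idx = np.unique(labels, return_inverse=True)
    cluster_values, cluster_idx = np.unique(clusters, return_inverse=True)
    counts = np.zeros((len(label_values), len(cluster_values)))
    np.add.at(counts, (label_idx, cluster_idx), 1)  # count (label, cluster) pairs
    return counts / counts.sum(axis=0, keepdims=True)

M = column_normalized_contingency(np.array([0, 0, 1, 1, 1]), np.array([5, 5, 5, 7, 7]))
```

With the notebook's variables, `column_normalized_contingency(y_subset, cluster)` corresponds to the first matrix, and passing `np.array([label_map[c] for c in y_subset])` as labels corresponds to the second one (rows for the time labels 1, 2, 3).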

In [None]:
# Sanity check using true label information (not available in practice)
confusion_mtx_true_label = np.zeros( (np.unique(y_subset).shape[0],np.unique(cluster).shape[0]) )
for i,true_label in enumerate(np.unique(y_subset)):
    for j,cluster_id in enumerate(np.unique(cluster)):
        confusion_mtx_true_label[i,j] = ((y_subset==true_label)*(cluster==cluster_id)).sum() / (cluster==cluster_id).sum()
plt.matshow(confusion_mtx_true_label)
plt.title("Confusion Matrix Prototype vs. True Label")
plt.xlabel("Prototype")
plt.ylabel("Label")
plt.yticks(range(confusion_mtx_true_label.shape[0]), [labels[label] for label in np.unique(y_subset)])
plt.show()

# Sanity check using drift information
confusion_mtx_time_label = np.zeros( (3,np.unique(cluster).shape[0]) )
for i,time_label in enumerate(range(1,4)):
    for j,cluster_id in enumerate(np.unique(cluster)):
        for true_label in [k for k,v in label_map.items() if v == time_label]:
            confusion_mtx_time_label[i,j] += ((y_subset==true_label)*(cluster==cluster_id)).sum() 
        confusion_mtx_time_label[i,j] /= (cluster==cluster_id).sum()
plt.matshow(confusion_mtx_time_label)
plt.title("Confusion Matrix Prototype vs. Timepoint")
plt.xlabel("Prototype")
plt.ylabel("Timepoint")
plt.yticks([0,1,2],["before drift","after drift","both"])
plt.show()

We determine the most relevant data points by selecting, for each prototype and each temporal region predicted by the model, the sample that is closest to the prototype. To measure proximity, we use both the Euclidean distance in pixel space and the distance induced by the random forest (i.e., the Euclidean distance in the RF-kernel embedding).
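The selection step can be isolated in a small helper: among the samples of a given region, return the one with the smallest squared Euclidean distance to the prototype in a given feature space (raw pixels or the RF embedding). The counterfactual of a characteristic sample is then the closest sample from the opposite region. This is a sketch; `nearest_in_region` is a hypothetical name, and unlike the notebook's cell it returns `None` for an empty region instead of letting `np.argmin` raise an error on an empty array.

```python
import numpy as np

def nearest_in_region(features, regions, prototype, region):
    """Index of the sample closest to `prototype` (squared Euclidean distance in
    `features`) among the samples whose region equals `region`, or None."""
    candidates = np.where(regions == region)[0]
    if len(candidates) == 0:
        return None  # no sample falls into this region
    d = ((features[candidates] - prototype) ** 2).sum(axis=1)
    return candidates[np.argmin(d)]

features = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [6.0, 5.0]])
regions_toy = np.array([1, 1, -1, -1])
prototype = np.array([0.9, 0.0])
own = nearest_in_region(features, regions_toy, prototype, 1)          # characteristic sample
counterfactual = nearest_in_region(features, regions_toy, prototype, -1)
```

In the notebook's terms, `features` would be `X_subset` with `X_proto` for the Euclidean variant, or `proj` with `proj_proto` for the RF variant.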

In [None]:
# Compute characteristic samples and counterfactuals for each cluster
counterfactuals = []
for cluster_id in np.unique(cluster):
    cf = dict()
    # Distribution of the cluster's samples over the regions -1 (after), 0 (no drift), +1 (before)
    cf["region_dist"] = np.eye(3)[regions[cluster == cluster_id]+1].mean(axis=0)
    # Dominant region of the cluster
    cf["region"] = [-1,0,1][np.argmax(cf["region_dist"])]
    cf["sample_euc"] = dict()
    cf["sample_rfk"] = dict()
    # Cluster prototypes in the RF embedding and in pixel space
    cf["proto_rfk"] = proj_proto = proj[cluster == cluster_id].mean(axis=0)
    cf["proto_euc"] = X_proto = X_subset[cluster == cluster_id].mean(axis=0)
    # For each region, store the index of the sample closest to the prototype
    #  (Euclidean distance in pixel space / distance in the RF embedding)
    for region in [-1,0,1]:
        sel = np.where(regions == region)[0]
        cf["sample_euc"][region] = sel[np.argmin( ((X_subset[sel] - X_proto)**2).sum(axis=1) )]
        cf["sample_rfk"][region] = sel[np.argmin( ((proj[sel] - proj_proto)**2).sum(axis=1) )]
    counterfactuals.append(cf)

In [None]:
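# Plot characteristic samples: one column per cluster; top row selected with the Euclidean distance,
#  bottom row with the RF-induced distance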
fig = plt.figure(figsize=(5*len(counterfactuals),5))
for cf,frames in zip(counterfactuals,list(fig.subplots(2,len(counterfactuals)).T)):
    for cf_type,frame in zip(["sample_euc","sample_rfk"],list(frames)):
        frame.imshow(X_subset[cf[cf_type][cf["region"]]].reshape(28,28))
        frame.set_xticks([])
        frame.set_yticks([])
plt.show()

In [None]:
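# Plot counterfactuals for the drifting clusters: the top row shows the characteristic sample from the
#  cluster's own region, the bottom row the closest sample from the opposite region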
drifting_counterfactuals = list(filter(lambda cf: cf["region"] != 0, counterfactuals))

for cf_type in ["sample_euc","sample_rfk"]:
    print(cf_type)
    fig = plt.figure(figsize=(5*len(drifting_counterfactuals),5))
    for cf,frames in zip(drifting_counterfactuals,list(fig.subplots(2,len(drifting_counterfactuals)).T)):
        for rel_region,frame in zip([1,-1],list(frames)):
            frame.imshow(X_subset[cf[cf_type][rel_region*cf["region"]]].reshape(28,28))
            frame.set_xticks([])
            frame.set_yticks([])
    plt.show()

As can be seen, before the drift we have T-shirts/tops and trousers, and after the drift we have pullovers and dresses. This matches the initial setting defined in `label_map`. Also, as expected, the samples selected with the Euclidean distance tend to match the overall shape of the fashion items in the prototype.