# MNIST handwritten digits clustering and anomaly detection

In this notebook, we'll use unsupervised learning (clustering and anomaly detection) to analyze MNIST digits using scikit-learn.

First, the needed imports. 

In [None]:
%matplotlib inline

import numpy as np
from sklearn import datasets, __version__
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor
from sklearn.metrics import adjusted_rand_score

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

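# distutils was removed from the standard library in Python 3.12,
# so the version check uses the packaging library instead.
from packaging.version import Version
assert Version(__version__) >= Version("0.20"), "Version >= 0.20 of sklearn is required."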

Then we load the MNIST data. The first time this cell runs, the data is downloaded, which can take a while.

To speed up the computations, let's use only 10000 digits in this notebook.

In [None]:
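mnist = datasets.fetch_openml('mnist_784')

# Newer scikit-learn versions may return pandas objects here; converting to
# NumPy arrays keeps the array indexing used in the cells below working.
X = np.asarray(mnist['data'][:10000])
y = np.asarray(mnist['target'][:10000])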
print()
print('MNIST data loaded:')
print('X:', X.shape)
print('y:', y.shape)

## Clustering

### k-means

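K-means partitions the samples into *k* clusters by minimizing the within-cluster sum of squared distances to the cluster centroids. It alternates two steps until the assignments stop changing: assign each sample to its nearest centroid, then recompute each centroid as the mean of its assigned samples. The number of clusters must be given as a parameter.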
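
The iterative two-step procedure can be sketched in plain NumPy. This is only an illustration under simple assumptions (random initialisation from the data, a fixed iteration cap); the function name `kmeans_sketch` and the toy data are hypothetical, and scikit-learn's `KMeans` uses a more careful initialisation and implementation.

```python
import numpy as np

def kmeans_sketch(X, k, n_iter=100, seed=0):
    """Plain two-step iteration; X is an (n_samples, n_features) float array."""
    rng = np.random.default_rng(seed)
    # Initialise the centroids with k distinct samples chosen at random.
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Step 1 (assignment): index of the nearest centroid for every sample.
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 2 (update): move each centroid to the mean of its samples;
        # an empty cluster keeps its previous centroid.
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Toy data (an assumption for illustration): two well-separated 2-D blobs.
rng = np.random.default_rng(42)
X_toy = np.vstack([rng.normal(0.0, 0.5, size=(50, 2)),
                   rng.normal(5.0, 0.5, size=(50, 2))])
toy_labels, toy_centroids = kmeans_sketch(X_toy, k=2)
print(toy_centroids)
```

The returned centroids have the same number of features as the input samples, which is why they can be displayed as images when the input is MNIST.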

In [None]:
%%time

n_clusters=10
kmeans = KMeans(n_clusters=n_clusters)
kmeans.fit(X)

The sizes of the clusters:

In [None]:
plt.hist(kmeans.labels_, bins=range(n_clusters+1), rwidth=0.5)
plt.xticks(0.5+np.arange(n_clusters), np.arange(n_clusters))
plt.title('Cluster sizes');

The k-means centroids are vectors in the same space as the original data, so we can take a look at them:

In [None]:
plt.figure(figsize=(n_clusters, 1))

for i in range(n_clusters):
    plt.subplot(1,n_clusters,i+1)
    plt.axis('off')
    plt.imshow(kmeans.cluster_centers_[i,:].reshape(28,28), cmap="gray")
    plt.title(str(i))

Let's also draw some digits from each cluster:

In [None]:
n_img_per_row = 32  # number of example digits shown per cluster
img = np.zeros((28 * n_clusters, 28 * n_img_per_row))

for i in range(n_clusters):
    ix = 28 * i
    X_cluster = X[kmeans.labels_==i,:]
    # Show the first members of the cluster; the rest of the row stays black
    # if the cluster has fewer than n_img_per_row members.
    for j in range(min(n_img_per_row, len(X_cluster))):
        iy = 28 * j
        img[ix:ix + 28, iy:iy + 28] = X_cluster[j,:].reshape(28,28)

plt.figure(figsize=(12, 12))
plt.imshow(img, cmap='gray')
plt.title('Some MNIST digits from each k-means cluster')
plt.xticks([])
plt.yticks([])
plt.ylabel('clusters');

#### Evaluation

Since we know the correct labels for MNIST digits, we can evaluate the quality of the clustering.

In [None]:
print("Adjusted Rand index: %.3f"
      % adjusted_rand_score(y, kmeans.labels_))
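
The adjusted Rand index compares groupings rather than label values, so it does not matter that the cluster ids are arbitrary or that they differ in type from the true labels. A small sketch with made-up label lists (all three lists are illustrative assumptions):

```python
from sklearn.metrics import adjusted_rand_score

true_labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
# Same grouping as true_labels, only the cluster ids are permuted.
renamed = [2, 2, 2, 0, 0, 0, 1, 1, 1]
# One sample of each true class moved into a different cluster.
noisy = [0, 0, 1, 1, 1, 2, 2, 2, 0]

ari_renamed = adjusted_rand_score(true_labels, renamed)
ari_noisy = adjusted_rand_score(true_labels, noisy)
print("renamed clusters: %.3f" % ari_renamed)
print("noisy clusters:   %.3f" % ari_noisy)
```

A perfect match scores 1 regardless of how the clusters are numbered, while a random assignment is expected to score close to 0.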

### Hierarchical clustering

Hierarchical clustering is a family of clustering algorithms that build nested clusters by merging or splitting them successively.

The `linkage` criterion determines the metric used for the merge strategy:
* `ward` minimizes the sum of squared differences within all clusters
* `complete` linkage minimizes the maximum distance between observations of pairs of clusters
* `average` linkage minimizes the average of the distances between all observations of pairs of clusters
* `single` linkage minimizes the distance between the closest observations of pairs of clusters
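
To get a feeling for how much the linkage criterion matters, the four options listed above can be compared on a small dataset. The sketch below uses scikit-learn's bundled 8x8 digits (`load_digits`) as a quick stand-in for MNIST; the subset size of 500 and the variable names are arbitrary choices for illustration.

```python
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_digits
from sklearn.metrics import adjusted_rand_score

# Small bundled 8x8 digits set, used here only as a quick stand-in for MNIST.
X_small, y_small = load_digits(return_X_y=True)
X_small, y_small = X_small[:500], y_small[:500]

linkage_scores = {}
for name in ("ward", "complete", "average", "single"):
    model = AgglomerativeClustering(n_clusters=10, linkage=name)
    labels = model.fit_predict(X_small)
    # Score each clustering against the known digit labels.
    linkage_scores[name] = adjusted_rand_score(y_small, labels)
    print("%-8s ARI: %.3f" % (name, linkage_scores[name]))
```

The scores depend on the data, so a ranking obtained on this small set is not guaranteed to carry over to the MNIST subset used in this notebook.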

In [None]:
%%time

n_clusters=10
linkage = "ward"
hclust = AgglomerativeClustering(n_clusters=n_clusters, linkage=linkage)
hclust.fit(X)

The sizes of the clusters:

In [None]:
plt.hist(hclust.labels_, bins=range(n_clusters+1), rwidth=0.5)
plt.xticks(0.5+np.arange(n_clusters), np.arange(n_clusters))
plt.title('Cluster sizes');

Some digits from each cluster:

In [None]:
n_img_per_row = 32  # number of example digits shown per cluster
img = np.zeros((28 * n_clusters, 28 * n_img_per_row))

for i in range(n_clusters):
    ix = 28 * i
    X_cluster = X[hclust.labels_==i,:]
    # Show the first members of the cluster; the rest of the row stays black
    # if the cluster has fewer than n_img_per_row members.
    for j in range(min(n_img_per_row, len(X_cluster))):
        iy = 28 * j
        img[ix:ix + 28, iy:iy + 28] = X_cluster[j,:].reshape(28,28)
            
plt.figure(figsize=(12, 12))
plt.imshow(img, cmap='gray')
plt.title('Some MNIST digits from hierarchical clustering with {} linkage'.format(linkage))
plt.xticks([])
plt.yticks([])
plt.ylabel('clusters');

#### Evaluation

In [None]:
print("Adjusted Rand index: %.3f"
      % adjusted_rand_score(y, hclust.labels_))

## Anomaly detection

### Isolation forest

In [None]:
%%time

# The behaviour parameter was removed in later scikit-learn releases,
# so it is no longer passed here.
isofor = IsolationForest(contamination=0.01)
predictions = isofor.fit(X).predict(X)
print('Number of anomalies:', np.sum(predictions==-1))
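
The `contamination` parameter sets the expected fraction of anomalies, and with it the threshold on the anomaly scores. The following sketch illustrates this on synthetic 2-D data; the data, the `random_state` and the variable names are assumptions made for the example, not part of the MNIST analysis.

```python
import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.default_rng(0)
# 990 points from a standard Gaussian plus 10 points scattered far away.
inliers = rng.normal(0.0, 1.0, size=(990, 2))
outliers = rng.uniform(8.0, 12.0, size=(10, 2))
X_demo = np.vstack([inliers, outliers])

forest = IsolationForest(contamination=0.01, random_state=0)
demo_pred = forest.fit_predict(X_demo)      # -1 = anomaly, 1 = normal
demo_scores = forest.score_samples(X_demo)  # lower = more anomalous

n_flagged = int(np.sum(demo_pred == -1))
print("Flagged:", n_flagged, "of", len(X_demo))
print("Most anomalous indices:", np.argsort(demo_scores)[:10])
```

With `contamination=0.01`, roughly 1% of the samples are flagged whether or not they are true outliers, which is why the notebook reports about 100 anomalies for 10000 digits.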

In [None]:
n_img_per_row = 32  # number of example digits shown per row
img = np.zeros((28 * 2, 28 * n_img_per_row))
anolabels = [-1, 1]  # -1 = anomaly, 1 = normal

for i in range(2):
    ix = 28 * i
    X_ano = X[predictions==anolabels[i], :]
    # Show the first samples of each group; the rest of the row stays black
    # if the group has fewer than n_img_per_row members.
    for j in range(min(n_img_per_row, len(X_ano))):
        iy = 28 * j
        img[ix:ix + 28, iy:iy + 28] = X_ano[j,:].reshape(28,28)
            
plt.figure(figsize=(12, 12))
plt.imshow(img, cmap='gray')
plt.title('Examples of anomalies (upper row) and normal data (lower row)')
plt.xticks([])
plt.yticks([]);

### Local outlier factor

In [None]:
%%time

lof = LocalOutlierFactor(contamination=0.01)
predictions = lof.fit_predict(X)
print('Number of anomalies:', np.sum(predictions==-1))
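
Local outlier factor compares the local density around each sample with the density around its nearest neighbours; samples in much sparser regions than their neighbours get a high outlier factor. A minimal sketch on synthetic data (the data, `n_neighbors=20` and the variable names are assumptions chosen for illustration):

```python
import numpy as np
from sklearn.neighbors import LocalOutlierFactor

rng = np.random.default_rng(1)
# One dense blob of 200 points plus two clearly isolated points.
dense = rng.normal(0.0, 0.3, size=(200, 2))
isolated = np.array([[4.0, 4.0], [-4.0, 3.5]])
X_lof = np.vstack([dense, isolated])

lof_demo = LocalOutlierFactor(n_neighbors=20, contamination=0.01)
lof_pred = lof_demo.fit_predict(X_lof)  # -1 = anomaly, 1 = normal
# Close to -1 for inliers, much lower for outliers.
lof_scores = lof_demo.negative_outlier_factor_

print("Flagged indices:", np.where(lof_pred == -1)[0])
print("Lowest scores:", np.sort(lof_scores)[:3])
```

Unlike the isolation forest, this estimator is used here through `fit_predict` only: in its default mode it scores the training samples and does not offer a separate `predict` for new data.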

In [None]:
n_img_per_row = 32  # number of example digits shown per row
img = np.zeros((28 * 2, 28 * n_img_per_row))
anolabels = [-1, 1]  # -1 = anomaly, 1 = normal

for i in range(2):
    ix = 28 * i
    X_ano = X[predictions==anolabels[i], :]
    # Show the first samples of each group; the rest of the row stays black
    # if the group has fewer than n_img_per_row members.
    for j in range(min(n_img_per_row, len(X_ano))):
        iy = 28 * j
        img[ix:ix + 28, iy:iy + 28] = X_ano[j,:].reshape(28,28)
            
plt.figure(figsize=(12, 12))
plt.imshow(img, cmap='gray')
plt.title('Examples of anomalies (upper row) and normal data (lower row)')
plt.xticks([])
plt.yticks([]);