## Applying Gaussian Mixture models to the NIST digits data

We have already seen the NIST handwritten digits data from the perspective of classification.

In classification, each image was assigned one of ten labels, one for each digit.

The question arises: are the handwritten digits sufficiently different from one another that we can
find ten clusters corresponding to the ten digits 0,..,9?

In [None]:
dataDir = "data"
# Make sure the outputDir subdirectory exists
outputDir = "output/GMM_KM"
import os
os.makedirs(outputDir, exist_ok=True)

In [None]:
from sklearn.datasets import load_digits
digits = load_digits()
digits.data.shape

Remember that each pixel in an image is a candidate feature, so there are 8\*8 = 64 candidate attributes (features).

Clustering in such high dimensional spaces can suffer from the curse of dimensionality.

We therefore check whether PCA can reduce the number of dimensions while retaining most of the information in the data.

We apply a PCA decomposition with the intention of keeping 99% of the variance.

In [None]:
from sklearn.decomposition import PCA
pca = PCA(0.99, whiten=True)
data = pca.fit_transform(digits.data)
data.shape

As can be seen, the transformed data has only 41 attributes (features), which is a significant reduction from 64. 
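
One way to see where this number comes from is to look at the cumulative explained variance of a full PCA. The sketch below is illustrative only: the names `pca_full`, `cumulative` and `n_keep` are not part of the notebook, and it assumes that scikit-learn keeps the smallest number of components whose cumulative explained variance ratio exceeds the requested fraction.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

X = load_digits().data

# Fit a PCA that keeps all 64 components so the full spectrum can be inspected
pca_full = PCA(svd_solver='full').fit(X)
cumulative = np.cumsum(pca_full.explained_variance_ratio_)

# Smallest number of components whose cumulative variance ratio exceeds 0.99
n_keep = int(np.searchsorted(cumulative, 0.99, side='right') + 1)
print(n_keep, cumulative[n_keep - 1])
```

Whitening rescales the retained components to unit variance but does not change how many are kept, so it is omitted from this sketch.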

In [None]:
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
import numpy as np
# Candidate numbers of mixture components: 50, 60, ..., 200
n_components = np.arange(50, 210, 10)
models = [GaussianMixture(n, covariance_type='full', random_state=0)
          for n in n_components]
# Fit each model to the PCA-transformed data and record its AIC (lower is better)
aics = [model.fit(data).aic(data) for model in models]
plt.plot(n_components, aics)
# Hide the y-axis tick values
frame = plt.gca()
frame.axes.get_yaxis().set_visible(False)
plt.savefig(outputDir + '/digitsPCAtransformedNumComponentsAIC.pdf')

As can be seen, the 'best' model according to the AIC appears to have 150 components. The AIC (Akaike Information Criterion) is a model selection metric that balances goodness of fit against the number of model parameters, so that adding more terms is weighed against the danger of overfitting the training data. We therefore fit a Gaussian Mixture model with 150 components and check that the EM algorithm converged in that case.
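
To make the AIC less of a black box, it can be reproduced by hand as `AIC = 2k - 2 ln(L)`, where `k` is the number of free parameters and `ln(L)` is the total log-likelihood of the data. The following is a minimal sketch on small synthetic data rather than the digits; `gmm_demo`, `n_params` and `aic_manual` are illustrative names, and the parameter count assumes `covariance_type='full'`.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic data (hypothetical): two well-separated blobs in 3 dimensions
rng = np.random.default_rng(0)
X_demo = np.vstack([rng.normal(0, 1, (200, 3)), rng.normal(5, 1, (200, 3))])

k, d = 2, X_demo.shape[1]
gmm_demo = GaussianMixture(k, covariance_type='full', random_state=0).fit(X_demo)

# Free parameters: k full covariance matrices, k mean vectors, k - 1 weights
n_params = k * d * (d + 1) // 2 + k * d + (k - 1)

# score() returns the mean log-likelihood per sample, so scale by the sample count
log_likelihood = gmm_demo.score(X_demo) * X_demo.shape[0]
aic_manual = 2 * n_params - 2 * log_likelihood

print(aic_manual, gmm_demo.aic(X_demo))
```

The first term grows with the number of components, which is why the AIC curve eventually turns upwards as components are added.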

In [None]:
gmm = GaussianMixture(150, covariance_type='full', random_state=0)
gmm.fit(data)
print(gmm.converged_)

Since it converged, we generate a sample of 100 points from the 150 _cluster density functions_ of handwritten digits. Note that these sample points correspond to (transformed) handwritten digits, but are extremely unlikely to be identical to any of the handwritten digits used to derive the cluster regions/probability functions.
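
Conceptually, drawing a point from a Gaussian mixture takes two steps: pick a component with probability given by its weight, then draw from that component's Gaussian. The sketch below illustrates this on small synthetic data. It is not scikit-learn's own implementation of `sample`, and the names `gmm_demo`, `component_ids` and `samples` are illustrative.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic data (hypothetical): two blobs in 2 dimensions
rng = np.random.default_rng(0)
X_demo = np.vstack([rng.normal(-3, 1, (150, 2)), rng.normal(3, 1, (150, 2))])
gmm_demo = GaussianMixture(2, covariance_type='full', random_state=0).fit(X_demo)

n_samples = 100
# Step 1: choose a component for each sample according to the mixture weights
weights = gmm_demo.weights_ / gmm_demo.weights_.sum()
component_ids = rng.choice(gmm_demo.n_components, size=n_samples, p=weights)

# Step 2: draw each sample from the Gaussian of its chosen component
samples = np.array([
    rng.multivariate_normal(gmm_demo.means_[c], gmm_demo.covariances_[c])
    for c in component_ids
])
print(samples.shape, component_ids.shape)
```

Here `component_ids` plays the role of the second value returned by `gmm.sample`, which records the component each point was drawn from.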

In [None]:
X_new, y_new = gmm.sample(100)
X_new.shape, y_new[:,None].shape

As you can see, `X_new` has 100 of the 41-attribute transformed handwritten digits. We apply the inverse transformation to return the data to 8\*8 pixel space and inspect the resulting images.

In [None]:
import clusSupport
digits_new = pca.inverse_transform(X_new)
clusSupport.plot_digits(digits_new)
plt.savefig(outputDir + '/digitsGenerated110Components.pdf')

As can be seen, these are plausible handwritten digits, most of which are recognisable. Each was sampled from one of the 150 Gaussian density functions, and was almost certainly not submitted as an image of an actual handwritten digit. Thus clustering can be used in this way to generate new instances, based on _clustered exemplars_.

One of the difficulties with classification is the need to label the instances in the training set. This is usually a laborious operation, with lots of human intervention. As we saw above, clustering can do a surprisingly good job even without labels. The natural question is: can clustering derive the labels itself?

To answer this question, we use KMeans to assign the handwritten digits to exactly 10 clusters, in the hope that these clusters might coincide with the 10 labels 0,1,..,9.

In [None]:
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(digits.data)
kmeans.cluster_centers_.shape

Now we take the centres of each of the 10 clusters, interpret them as pixel intensities and display them as images.

In [None]:
fig, ax = plt.subplots(2, 5, figsize=(8, 3))
centers = kmeans.cluster_centers_.reshape(10, 8, 8)
for axi, center in zip(ax.flat, centers):
    axi.set(xticks=[], yticks=[])
    axi.imshow(center, interpolation='nearest', cmap=plt.cm.binary)
plt.savefig(outputDir + '/digitsClusterCentres10.pdf')

Amazingly, most of the KMeans cluster centres look like recognisable digits. We assume that each cluster corresponds to a particular digit, most plausibly the commonest true label among the observations assigned to that cluster. For example, if the cluster centre looks like '4', most of the instances that KMeans assigned to that cluster should have been labeled as 4 in the data. Those that were not labeled as 4 but ended up in that cluster might have been assigned incorrectly by KMeans.

In [None]:
from scipy.stats import mode

# Give every member of a cluster the most common true label in that cluster
labels = np.zeros_like(clusters)
for i in range(10):
    mask = (clusters == i)
    labels[mask] = mode(digits.target[mask])[0]
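
The majority-vote mapping is easier to follow on a toy example. The sketch below uses made-up arrays (`toy_clusters` and `toy_targets` are hypothetical, not taken from the digits data) and `np.bincount` in place of `scipy.stats.mode`, which gives the same result for non-negative integer labels.

```python
import numpy as np

# Toy data (hypothetical): cluster ids from a clustering step and the true labels
toy_clusters = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
toy_targets = np.array([4, 4, 9, 1, 1, 1, 9, 9, 4])

toy_labels = np.zeros_like(toy_clusters)
for i in np.unique(toy_clusters):
    mask = (toy_clusters == i)
    # Most frequent true label among the members of cluster i
    toy_labels[mask] = np.bincount(toy_targets[mask]).argmax()

# Fraction of points whose cluster's majority label matches their own label
toy_accuracy = np.mean(toy_labels == toy_targets)
print(toy_labels, toy_accuracy)
```

Cluster 0 is relabeled 4, cluster 1 is relabeled 1 and cluster 2 is relabeled 9, so the two minority members count as errors.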

In [None]:
from sklearn.metrics import accuracy_score
accuracy_score(digits.target, labels)

Even though KMeans did not use any labels, approximately 80% of the handwritten digits ended up in a cluster whose majority label matches their own, which is far better than the roughly 10% that might be expected by chance with ten classes. Note that KMeans itself does not interpret cluster membership: mapping clusters to digits required the true labels and is outside the scope of a clustering algorithm.

In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix
mat = confusion_matrix(digits.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
            xticklabels=digits.target_names,
            yticklabels=digits.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label')
plt.savefig(outputDir + '/digitsClusterConfusionMatrix10.pdf')

The confusion matrix above shows where the misclassifications arise. Generally, the cluster assignments agree closely with the true labels, except for 1s and 8s, which is consistent with the cluster centre images above.
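
The same information can be summarised numerically as the fraction of each true digit that was recovered. The sketch below repeats the clustering so that it runs on its own. It sets `n_init=10` explicitly, which is an assumption made here to avoid version-dependent defaults, and `recall_per_digit` is an illustrative name.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score, confusion_matrix

digits = load_digits()
clusters = KMeans(n_clusters=10, n_init=10, random_state=0).fit_predict(digits.data)

# Map each cluster to the most common true label among its members
labels = np.zeros_like(clusters)
for i in range(10):
    mask = (clusters == i)
    labels[mask] = np.bincount(digits.target[mask]).argmax()

mat = confusion_matrix(digits.target, labels, labels=np.arange(10))

# Rows are true digits: the diagonal over the row sum is the fraction recovered
recall_per_digit = mat.diagonal() / mat.sum(axis=1)
for digit, recall in enumerate(recall_per_digit):
    print(digit, round(float(recall), 2))
```

Digits with a low value here are those whose images were mostly absorbed into clusters dominated by another digit.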