In [None]:
# Import some basic libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_context('paper')
# The following code is a modification of the code found here:
# https://stackoverflow.com/questions/35651932/plotting-img-with-matplotlib
from matplotlib.offsetbox import OffsetImage, AnnotationBbox


def imscatter(x, y, images, cmap=plt.cm.gray_r, ax=None, zoom=1):
    """Scatter plot that draws a small image at each (x, y) location instead of a marker."""
    if ax is None:
        ax = plt.gca()
    x, y = np.atleast_1d(x, y)
    artists = []
    for x0, y0, image in zip(x, y, images):
        # Wrap the image so it can be placed at data coordinates (x0, y0)
        im = OffsetImage(image, zoom=zoom, cmap=cmap, interpolation='nearest')
        ab = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False)
        artists.append(ax.add_artist(ab))
    # Image artists do not update the axis limits, so do it explicitly
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists

# Hands-on Activity 13.2 - Clustering High-dimensional Data

## Objectives

+ Combine principal component analysis with K-means to solve high-dimensional clustering problems

In this hands-on activity we are going to cluster the MNIST dataset.
We cannot apply K-means directly to it because of its high dimensionality (each image is a vector of 28 x 28 = 784 pixel intensities); if we did, we would get clusters of very poor quality.
Instead, we will first use PCA to reduce the dimensionality of MNIST to two dimensions and then apply K-means to the principal components.
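The two-step recipe (PCA first, then K-means on the principal components) can be chained with scikit-learn's ``make_pipeline``. This is only a sketch; the cells below run the two steps separately so that we can inspect the intermediate components. To keep it small and runnable without a download, it assumes scikit-learn's bundled ``load_digits`` dataset (8 x 8 images, i.e., 64 dimensions instead of 784), and the helper name ``pca_kmeans_labels`` is hypothetical.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.pipeline import make_pipeline


def pca_kmeans_labels(X, n_components=2, n_clusters=10, seed=0):
    """Hypothetical helper: project X onto its leading principal components,
    then run K-means in that low-dimensional space and return the cluster labels."""
    model = make_pipeline(
        PCA(n_components=n_components, whiten=True),
        KMeans(n_clusters=n_clusters, n_init=10, random_state=seed),
    )
    # fit_predict fits PCA, transforms X, then fits K-means and returns its labels
    return model.fit_predict(X)


# Bundled 8x8 digit images, already flattened to 64-dimensional vectors
X_digits, y_digits = load_digits(return_X_y=True)
digit_labels = pca_kmeans_labels(X_digits)
print(digit_labels.shape, np.unique(digit_labels).size)
```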

Note that, in contrast to the previous hands-on activity, we are going to work with the entire training set and not just one digit.
Since it contains all ten digits (0 to 9), we know that there should be 10 clusters.
Let's see if the process we follow identifies clusters that correspond to digits...
Here we go. First, download and load the MNIST data:

In [None]:
import tensorflow as tf
# x_train: 60,000 grayscale images of shape (28, 28); y_train: their digit labels (0-9)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

Perform PCA on the entire dataset, keeping two principal components:

In [None]:
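from sklearn.decomposition import PCA
# Flatten each 28x28 image into a 784-dimensional vector
vectorized_x_train = x_train.reshape((x_train.shape[0], 28 * 28))
# Keep two principal components; whiten=True rescales each to unit variance
pca = PCA(n_components=2, whiten=True).fit(vectorized_x_train)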
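What ``PCA(n_components=2, whiten=True)`` computes can be sketched in a few lines of NumPy. This is an illustrative reimplementation, not scikit-learn's code: it assumes a thin SVD of the centered data and divides each component by the data's standard deviation along that direction, which we believe matches scikit-learn's whitened ``transform`` and ``inverse_transform`` up to the sign of each component. The function names ``pca_whiten`` and ``pca_unwhiten`` are hypothetical.

```python
import numpy as np


def pca_whiten(X, n_components):
    """Minimal PCA-with-whitening sketch based on the SVD of the centered data.

    Returns the whitened scores, the principal directions (as rows), the
    standard deviation of the data along each direction, and the data mean."""
    mean = X.mean(axis=0)
    Xc = X - mean
    # Thin SVD: Xc = U @ diag(S) @ Vt; the rows of Vt are the principal directions
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    W = Vt[:n_components]
    # Standard deviation of the data along each kept direction
    std = S[:n_components] / np.sqrt(X.shape[0] - 1)
    # Project onto the directions, then rescale each component to unit variance
    Z = (Xc @ W.T) / std
    return Z, W, std, mean


def pca_unwhiten(Z, W, std, mean):
    """Map whitened scores back to the original space (undo the scaling and projection)."""
    return (Z * std) @ W + mean


rng = np.random.default_rng(0)
# Correlated toy data: 500 samples in 5 dimensions
X_toy = rng.normal(size=(500, 5)) @ rng.normal(size=(5, 5))
Z_toy, W_toy, std_toy, mean_toy = pca_whiten(X_toy, n_components=2)
print(Z_toy.mean(axis=0), Z_toy.std(axis=0, ddof=1))  # approximately [0, 0] and [1, 1]
```

Keeping all components makes the round trip exact, while keeping only two gives the best rank-2 approximation; that approximation is what ``pca.inverse_transform`` produces when we later turn cluster centers back into images.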

That's it. Let's now visualize the first 3,000 training images in the coordinates of the two principal components:

In [None]:
fig, ax = plt.subplots(dpi=150)
Z = pca.transform(vectorized_x_train[:3000])
imscatter(Z[:, 0], Z[:, 1], x_train[:3000], ax=ax, zoom=0.2)
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$');

You can see that PCA somewhat separates the digits.
It's not perfect (and you can do better with non-linear dimensionality reduction techniques), but it will do for now.
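The non-linear techniques mentioned above can replace PCA in the same recipe. As one illustration (a swapped-in technique, not part of this activity), the sketch below embeds images with scikit-learn's ``TSNE`` and then clusters the embedding with K-means. It assumes the bundled ``load_digits`` data and a 500-image subset to keep it fast; the perplexity and seed values are arbitrary choices.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

X_digits, y_digits = load_digits(return_X_y=True)
X_small = X_digits[:500]

# Non-linear 2D embedding of the 500 images
Z_tsne = TSNE(n_components=2, perplexity=30, init='pca', random_state=0).fit_transform(X_small)

# Cluster the embedded points the same way we cluster the PCA scores
labels_tsne = KMeans(n_clusters=10, n_init=10, random_state=0).fit_predict(Z_tsne)
print(Z_tsne.shape, np.unique(labels_tsne).size)
```

Note that ``TSNE`` only offers ``fit_transform``: there is no analogue of ``pca.transform`` for projecting new images, nor of ``pca.inverse_transform``, so the cluster-center images shown later could not be produced this way.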

Now it's time for K-means:

In [None]:
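from sklearn.cluster import KMeans
Z = pca.transform(vectorized_x_train)  # project the entire training set, not just 3,000 images
cl = KMeans(n_clusters=10, n_init=10, random_state=0).fit(Z)  # fixed n_init/seed for reproducibility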
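Under the hood, ``KMeans.fit`` runs Lloyd's algorithm (or an accelerated variant of it). The NumPy sketch below shows only the core iteration, as a simplified assumption of what happens: scikit-learn additionally uses k-means++ seeding by default and, through ``n_init``, keeps the best of several restarts, both of which this single-random-start version omits. The function name ``kmeans_lloyd`` is hypothetical.

```python
import numpy as np


def kmeans_lloyd(Z, k, n_iter=100, seed=0):
    """Minimal Lloyd's-algorithm sketch of K-means (single random start).

    Alternates two steps until the assignments stop changing:
    assign each point to its nearest center, then move each center to
    the mean of the points assigned to it."""
    rng = np.random.default_rng(seed)
    # Initialize the centers at k distinct data points chosen at random
    centers = Z[rng.choice(len(Z), size=k, replace=False)].astype(float)
    labels = None
    for _ in range(n_iter):
        # Squared Euclidean distance from every point to every center: shape (n, k)
        d2 = ((Z[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        new_labels = d2.argmin(axis=1)
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments are stable, so the centers would not move
        labels = new_labels
        for j in range(k):
            members = Z[labels == j]
            if len(members) > 0:  # keep the old center if a cluster becomes empty
                centers[j] = members.mean(axis=0)
    return labels, centers


rng = np.random.default_rng(1)
# Three well-separated 2D blobs of 50 points each
blobs = np.concatenate([rng.normal(loc=c, scale=0.3, size=(50, 2))
                        for c in [(0, 0), (5, 5), (0, 5)]])
toy_labels, toy_centers = kmeans_lloyd(blobs, k=3)
```

A single random start can settle in a poor local minimum (for example, two initial centers inside the same blob), which is why practical implementations restart several times and keep the solution with the lowest within-cluster sum of squares.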

Let's visualize the clusters:

In [None]:
fig, ax = plt.subplots(dpi=150)
colormaps = ['Greys_r', 'Blues_r', 'BrBG', 'BuGn', 'BuPu', 'Greens_r', 'Dark2', 'GnBu', 'Reds_r', 'Set1']
# Draw the first 3,000 images, one colormap per K-means cluster
for i, c in enumerate(colormaps):
    idx = cl.labels_[:3000] == i
    imscatter(Z[:3000][idx, 0], Z[:3000][idx, 1], x_train[:3000][idx], cmap=c, ax=ax, zoom=0.2)
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$');

This is nice! Observe that the clusters look very reasonable.
Again, they are not perfect, but they make sense.
Even where images end up in a cluster dominated by a different digit, the mistakes are not unreasonable. In fact, the results are quite impressive considering that the algorithm we have put together never sees the digit labels and does not know what digits are...
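Whether the clusters "make sense" can also be checked against the true labels, which the clustering never used. The sketch below shows one common way to do that, an assumption on our part rather than part of the activity: build a cluster-by-digit contingency table and compute the purity, i.e., the fraction of images that share the majority digit of their cluster. The helper name ``contingency_and_purity`` is hypothetical.

```python
import numpy as np


def contingency_and_purity(cluster_labels, true_labels, n_clusters=10, n_classes=10):
    """Hypothetical helper: count how many images of each digit land in each
    cluster, and compute the purity (fraction of images that belong to the
    majority digit of their cluster)."""
    table = np.zeros((n_clusters, n_classes), dtype=int)
    # table[c, d] = number of images in cluster c whose true digit is d
    np.add.at(table, (cluster_labels, true_labels), 1)
    purity = table.max(axis=1).sum() / table.sum()
    return table, purity


# Toy check: cluster 0 holds digits {1, 1, 7}, cluster 1 holds digits {0, 0}
table, purity = contingency_and_purity(np.array([0, 0, 0, 1, 1]),
                                       np.array([1, 1, 7, 0, 0]))
print(table[:2], purity)  # purity = (2 + 2) / 5 = 0.8
```

In the notebook it could be called as ``contingency_and_purity(cl.labels_, y_train[:len(cl.labels_)])``; slicing ``y_train`` keeps the call valid whether K-means was fit on the first 3,000 images or on the whole training set.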

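Okay. Now let's look at the cluster centers a bit more closely.
The centers live in the two-dimensional principal-component space, so we map them back to the 784-dimensional pixel space with ``pca.inverse_transform`` and display them as 28 x 28 images.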

In [None]:
# What do the cluster centers look like?
# Map the 2D centers back to pixel space and reshape each into a 28x28 image
center_images = pca.inverse_transform(cl.cluster_centers_).reshape((-1, 28, 28))
fig, axes = plt.subplots(1, 10, figsize=(10, 1.5), dpi=100)
for i, ax in enumerate(axes):
    ax.imshow(center_images[i], cmap=plt.cm.gray_r, interpolation='nearest')
    ax.set_title(f'cluster {i}', fontsize=6)
    ax.set_xticks([])
    ax.set_yticks([])
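Each center is an average in principal-component space, so its image need not resemble any actual digit. One way to compare, sketched below under that framing, is to find the training image whose principal components lie closest to each center; the helper name ``nearest_to_centers`` is hypothetical.

```python
import numpy as np


def nearest_to_centers(Z, centers):
    """Hypothetical helper: for each center, return the index of the closest
    row of Z (Euclidean distance in the same low-dimensional space)."""
    # Pairwise squared distances, shape (n_centers, n_points)
    d2 = ((centers[:, None, :] - Z[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)


points_toy = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
centers_toy = np.array([[0.9, 1.2], [4.0, 4.0]])
print(nearest_to_centers(points_toy, centers_toy))  # [1 2]
```

In the notebook, ``x_train[nearest_to_centers(Z, cl.cluster_centers_)]`` would give one real image per cluster, since the rows of ``Z`` correspond to the first rows of ``x_train``.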

### Questions

+ Which cluster centers are digits and which aren't? Pick one of the non-digit clusters and locate it in the principal-component plot. You can do this by looking at ``cl.cluster_centers_`` to identify the coordinates of its center. Why was it picked? Does its shape make sense now?

+ Repeat the analysis above but using 3 principal components (instead of 2). (Note that the 2D visualization of the principal components will not make much sense now, so take it with a grain of salt.) Pay special attention to the identified cluster centers as images. Are they better or worse than before?

+ Repeat the analysis with 5 principal components.

+ Repeat the analysis with 200 principal components.
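For the last three questions it may help to wrap the whole pipeline in a function of the number of components. The sketch below is one way to do it under stated assumptions: the helper name ``centers_in_pixel_space`` is hypothetical, it runs on scikit-learn's bundled 64-dimensional ``load_digits`` data so that it executes without downloading MNIST (hence 40 components stands in for 200), and it also reports ``explained_variance_ratio_``, the fraction of pixel variance kept by the chosen components. On MNIST you would pass ``vectorized_x_train`` and reshape the returned centers to 28 x 28.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans


def centers_in_pixel_space(X, n_components, n_clusters=10, seed=0):
    """Hypothetical helper: whitened PCA + K-means, then map the cluster
    centers back to pixel space so they can be displayed as images."""
    pca_k = PCA(n_components=n_components, whiten=True).fit(X)
    Z_k = pca_k.transform(X)
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit(Z_k)
    # Fraction of the total pixel variance captured by the kept components
    kept_variance = pca_k.explained_variance_ratio_.sum()
    return pca_k.inverse_transform(km.cluster_centers_), kept_variance


X_digits, _ = load_digits(return_X_y=True)  # 64-dimensional stand-in for MNIST
for k in [2, 3, 5, 40]:
    centers_px, kept = centers_in_pixel_space(X_digits, k)
    print(k, centers_px.shape, round(kept, 3))
```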