### Handwritten Digits Clustering

In this exercise, you will practice *K-means* clustering using [`scikit-learn`](https://scikit-learn.org), a well-known machine learning package for Python. We cluster the samples of a dataset of 8x8 pixel images of handwritten digits (10 clusters in total, one for each digit from `0` to `9`). Then, we will see how to assign a new sample to a cluster by comparing its distance to each of the centroids.

In [None]:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from utils import plot_images, plot_clusters, plot_centroids

### Step 1. Load Data
The handwritten digits dataset in the `scikit-learn` package contains 1797 samples of 10 digits (around 180 samples per class). Each 8x8 image is flattened into a vector of 64 pixel values. We use the [`load_digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) function to load the dataset.

In [None]:
X, y = load_digits(return_X_y=True)
print(np.shape(X))
plot_images(X)

#### Split Test and Train Sets
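Using the [`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html) function in `sklearn.model_selection`, you can shuffle the dataset randomly and then split it into train and test sets of your desired size. An integer `test_size` is the absolute number of test samples, while a float between 0 and 1 is the fraction of the dataset to hold out. Here, `test_size=20` keeps 20 samples for testing.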

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=20)

print("Train set: ")
print(np.shape(X_train))
plot_images(X_train)

print("Test set: ")
print(np.shape(X_test))
plot_images(X_test)
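
To make the two meanings of `test_size` concrete, here is a minimal sketch on a toy array. The names `X_demo` and `y_demo` and the fixed `random_state=0` are assumptions made for this illustration only; they are not part of the notebook above.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (hypothetical): 100 samples with 2 features each
X_demo = np.arange(200).reshape(100, 2)
y_demo = np.arange(100)

# Integer test_size: absolute number of samples held out for testing
_, X_te_int, _, _ = train_test_split(X_demo, y_demo, test_size=20, random_state=0)

# Float test_size: fraction of the dataset held out for testing
_, X_te_frac, _, _ = train_test_split(X_demo, y_demo, test_size=0.2, random_state=0)

print(X_te_int.shape, X_te_frac.shape)
```

With 100 samples, both calls hold out 20 samples. On the 1797-sample digits dataset, `test_size=0.2` would hold out far more samples than `test_size=20`.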

### Step 2. K-means Clustering
The [`KMeans`](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans) class in the `scikit-learn` package implements K-means clustering. Its constructor accepts, among others, the following parameters:

`KMeans(n_clusters, n_init, max_iter)`
* `n_clusters`: The number of clusters to form as well as the number of centroids to generate.
* `n_init`: Number of times the k-means algorithm will be run with different centroid seeds. The final result is the best output of the `n_init` consecutive runs in terms of inertia (the sum of squared distances from each sample to its closest centroid).
* `max_iter`: Maximum number of iterations of the k-means algorithm for a single run.

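Then, the [`fit(X=input)`](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.fit) function clusters the input into `n_clusters` groups. It returns the fitted estimator itself, whose `labels_` attribute holds the cluster index of each input sample and whose `cluster_centers_` attribute holds the centroids.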
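
To show what a single run does, the following is a minimal NumPy sketch of the classic Lloyd iteration (assign each sample to its nearest centroid, then move each centroid to the mean of its samples). It is an illustration only, not the `scikit-learn` implementation: the function name `lloyd_kmeans`, the random-sample initialization, and the toy data are all assumptions made here.

```python
import numpy as np

def lloyd_kmeans(X, n_clusters, max_iter=100, seed=0):
    """Hypothetical helper: one k-means run with random-sample initialization."""
    rng = np.random.default_rng(seed)
    # Pick n_clusters distinct samples as the initial centroids
    centroids = X[rng.choice(len(X), size=n_clusters, replace=False)].astype(float)
    for _ in range(max_iter):
        # Squared Euclidean distance of every sample to every centroid
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each centroid to the mean of its members (keep it if it has none)
        new_centroids = centroids.copy()
        for k in range(n_clusters):
            members = X[labels == k]
            if len(members) > 0:
                new_centroids[k] = members.mean(axis=0)
        if np.allclose(new_centroids, centroids):
            break  # centroids stopped moving
        centroids = new_centroids
    # Final assignment and inertia with respect to the returned centroids
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = dists.argmin(axis=1)
    inertia = dists[np.arange(len(X)), labels].sum()
    return centroids, labels, inertia

# Toy data: two well-separated blobs of 20 points each
rng = np.random.default_rng(1)
X_toy = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(5, 0.1, (20, 2))])
centroids, labels, inertia = lloyd_kmeans(X_toy, n_clusters=2)
print(labels)
```

In this sketch, `max_iter` caps the loop just as the `max_iter` parameter does for one run, and repeating the function with different seeds and keeping the lowest `inertia` would mimic the role of `n_init`.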

In [None]:
kmeans_obj = KMeans(n_clusters=10, n_init=50, max_iter=100)
clusters_train = kmeans_obj.fit(X_train)
plot_clusters(X_train, clusters_train.labels_)
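
Note that the cluster indices returned by K-means are arbitrary: cluster `3` does not necessarily contain the digit `3`. One way to interpret the clusters, sketched below, is to look up the most frequent true digit among the members of each cluster. This majority-vote mapping, the name `cluster_to_digit`, the smaller `n_init`, and the fixed `random_state` are assumptions for this illustration; the sketch fits on the full dataset so that it runs on its own.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)
kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(X)

# For each cluster, find the most frequent true digit among its members
cluster_to_digit = {}
for k in range(10):
    members = y[kmeans.labels_ == k]
    cluster_to_digit[k] = int(np.bincount(members, minlength=10).argmax())

print(cluster_to_digit)
```

Several clusters may map to the same digit, and some digits may not be the majority in any cluster, since K-means never sees the labels `y`.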

### Step 3. Test New Samples

The [`predict`](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans.predict) function assigns each unseen sample in `X` to the cluster whose centroid is closest to it.

In [None]:
clusters_test = kmeans_obj.predict(X_test)
plot_images(X_test)
print(clusters_test)
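
The assignment can also be reproduced by hand, by computing the Euclidean distance from each test sample to every centroid in `cluster_centers_` and taking the index of the smallest one. The sketch below is self-contained so it can run on its own; the variable names `dists` and `manual`, the smaller `n_init`, and the fixed `random_state` are assumptions for this illustration.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)
X_train, X_test, _, _ = train_test_split(X, y, test_size=20, random_state=0)
kmeans = KMeans(n_clusters=10, n_init=10, max_iter=100, random_state=0).fit(X_train)

# Distance matrix of shape (n_test_samples, n_clusters)
dists = np.linalg.norm(
    X_test[:, None, :] - kmeans.cluster_centers_[None, :, :], axis=2
)
# The closest centroid gives the cluster index of each test sample
manual = dists.argmin(axis=1)

# Compare with the built-in predict; the two should agree
print(np.array_equal(manual, kmeans.predict(X_test)))
```

The same distance matrix is also available through `kmeans.transform(X_test)`, which returns the distance of each sample to every cluster center.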

In [None]:
plot_centroids(clusters_train, clusters_test)