In [None]:
import numpy as np
import matplotlib.pyplot as plt

# The One Goal For Today

To understand how k-means clustering works.


# Clustering

Clustering algorithms separate a data set into groups, or clusters, whose members are near each other (similar to each other) according to a distance metric we choose or define.

When we fit a regression model, we know the dependent variable - the label, the answer. When we cluster, we don't. Clustering is therefore an *unsupervised* method, although we as data scientists can make all kinds of decisions to influence it.


# K-means clustering

K-means clustering is one clustering algorithm. It divides the data into $k$ clusters. Each cluster has a *centroid*, or central point (or "mean"). 

## What we want to optimize

The points in the dataset are assigned to their closest centroid. In other words, we want to minimize the distances between data points and the centroids they are assigned to, across the whole data set. We want to minimize the *inertia*:
$\frac{1}{N} \sum_{j=1}^N d(\vec{x}_j, \vec{m}_{\vec{x}_j})^2$, where $N$ is the number of data points, $\vec{m}_{\vec{x}_j}$ is the centroid of the cluster that $\vec{x}_j$ is currently assigned to, and $d$ is your chosen distance metric.
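
As a minimal sketch of the formula, assuming Euclidean distance: the arrays `points`, `centers`, and `labels` below are hypothetical values made up so the arithmetic is easy to check by hand.

```python
import numpy as np

# Hypothetical toy values, chosen only to make the arithmetic easy to follow
points = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]])
centers = np.array([[0.0, 1.0], [4.0, 1.0]])
labels = np.array([0, 0, 1, 1])  # index of the centroid each point is assigned to

# Squared Euclidean distance from each point to its assigned centroid
squared = np.sum((points - centers[labels]) ** 2, axis=1)

# Inertia as defined above: the mean of those squared distances
example_inertia = squared.mean()
print(example_inertia)  # 1.0
```

Every point sits exactly one unit from its centroid, so each squared distance is 1 and so is their mean.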

*What does this look like?*

## The algorithm

To make this computationally efficient, we calculate an approximate solution by iteration:
1. Pick initial centroids.
2. For step in range(max_steps):
   1. Assign each point to its closest centroid.
   2. Recompute each centroid as the mean of the members of its cluster. If the centroids don't change, stop.
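
The steps above can be sketched compactly with NumPy. This is an illustration only, not the implementation built later in this notebook: the name `kmeans_sketch`, the `toy` array, and the choice of starting centroids are assumptions, and the sketch assumes Euclidean distance and that no cluster ever becomes empty.

```python
import numpy as np

def kmeans_sketch(points, initial_centroids, max_steps=100):
    """Illustrative loop only; assumes Euclidean distance and no empty clusters."""
    centroids = np.asarray(initial_centroids, dtype=float)
    for step in range(max_steps):
        # Step 1: distance from every point to every centroid, shape (N, k)
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # Step 2: each new centroid is the mean of the points assigned to it
        new_centroids = np.array([points[labels == i].mean(axis=0)
                                  for i in range(len(centroids))])
        # Stop as soon as the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return centroids, labels

# Hypothetical data: two well-separated pairs of points
toy = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
final_centroids, final_labels = kmeans_sketch(toy, toy[[0, 2]])
print(final_centroids)
print(final_labels)
```

Starting from the first and third points, the loop settles after two passes with one centroid in the middle of each pair.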

*Is this guaranteed to converge?*

## The 'hyperparameters'

Things a data scientist can do to influence the k-means clustering (by __looking at the data__):
* choose distance metric
* choose $k$
* choose starting points, or subset of data from which starting points should come
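
To see why the choice of distance metric matters, here is a small sketch comparing two metrics on one point. The functions `euclidean` and `manhattan`, and the values of `point` and `candidates`, are hypothetical and exist only for this illustration.

```python
import numpy as np

def euclidean(a, b):
    # Straight-line distance
    return np.sqrt(np.sum((a - b) ** 2))

def manhattan(a, b):
    # Hypothetical alternative metric: sum of absolute coordinate differences
    return np.sum(np.abs(a - b))

point = np.array([0.0, 0.0])
candidates = [np.array([3.0, 3.0]), np.array([5.0, 0.0])]  # two candidate centroids

# Index of the closest candidate under each metric
nearest_euclidean = int(np.argmin([euclidean(point, c) for c in candidates]))
nearest_manhattan = int(np.argmin([manhattan(point, c) for c in candidates]))
print(nearest_euclidean, nearest_manhattan)  # 0 1
```

The same point is assigned to a different centroid depending on the metric. Note also that using the mean as the new centroid is tied to squared Euclidean distance; with other metrics a different notion of center (for example the median, for Manhattan distance) is usually more appropriate.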

In [None]:
# Let's define a distance metric; which one is this??
def distance(a, b):
    # take the difference of the two data points
    subtracted = a-b
    # return the square root of the sum of the squared elements = Euclidean distance
    return np.sqrt(np.dot(subtracted.T, subtracted))

In [None]:
# Let's define a function to calculate the distance from one data point to each centroid
def get_distances(item, centroids):
    # gets the distance between this item and each centroid
    distances = [distance(item, centroid) for centroid in centroids]
    # returns the distances
    return distances

# Let's define a function to update cluster assignments given a set of centroids
def update_clusters(data, centroids):
    # list of data points, list of centroids
    # map a centroid index to a list of data points
    clusters = {}
    # map a data point index to a centroid index
    mappings = {}
    # for the index of each centroid
    for i, x in enumerate(centroids):
        # set its cluster members to the empty list
        clusters[i] = []
    # for each data point
    for j, datum in enumerate(data):
        # figure out the distances between that data point and each centroid ***in order***
        distances = get_distances(datum, centroids)
        # find the index of the smallest distance = the index of the centroid with the smallest distance to this data point
        min_cluster_index = np.argmin(distances)
        # add this data point to that centroid's cluster
        clusters[min_cluster_index].append(datum)
        # set the centroid for this data point to be that centroid's index
        mappings[j] = min_cluster_index
    return clusters, mappings

In [None]:
# Let's define a function to update the centroids; clusters is a dictionary mapping each centroid index to the list of data points in its cluster
def update_centroids(clusters):
    # set centroids to empty list
    centroids = []
    # for each set of data points in a cluster around a single centroid
    for data_in_cluster in clusters.values():
        # new centroid is the mean of that cluster
        centroids.append(np.mean(data_in_cluster, axis=0))
    return centroids

In [None]:
# Let's define a function to measure the inertia
def inertia(data, centroids, clusters):
    # running total of squared distances (named total so we don't shadow the built-in sum)
    total = 0
    for i in clusters.keys():
        # get the centroid
        centroid = centroids[i]
        # get each data point
        for datum in clusters[i]:
            # calculate the distance squared
            total += distance(datum, centroid)**2
    # average over the data
    return total / len(data)

## Let's try it on some toy data!

In [None]:
# Let's get some toy data
data = np.array([[1, 1], [2, 0.5], [1.5, 2], [3, 1.5], [3.5, 1.75], [4, 3.6], [4.25, 4], [5, 3.5]])

# Let's look at the data
plt.scatter(data[:, 0], data[:, 1])
plt.show()

In [None]:
def plot_clusters(data, mappings, centroids):
    plt.scatter(data[:, 0], data[:, 1], c=list(mappings.values()))
    print(centroids)
    for i, centroid in enumerate(centroids):
        plt.scatter(centroid[0], centroid[1], marker=i)
    plt.show()

In [None]:
# Let's pick k = 2
k = 2

# Let's pick k points to be centroids, at random
centroid_ids = np.random.choice(np.arange(len(data)), size=k, replace=False)
centroids = [data[x] for x in centroid_ids]
print("initial centroids")
print(centroids)

# Assign every data point to its closest initial centroid
clusters, mappings = update_clusters(data, centroids)
print(clusters)
plot_clusters(data, mappings, centroids)

In [None]:
# Let's loop over updating the centroids and plotting; press Enter to take a step, type 'stop' to finish
while input() != 'stop':
    centroids = update_centroids(clusters)
    print(centroids)
    clusters, mappings = update_clusters(data, centroids)
    print(clusters)
    print(inertia(data, centroids, clusters))
    plot_clusters(data, mappings, centroids)

## Choosing k

K-means clustering can be frustrating, especially with high-dimensional data, because you have to choose a value for k. How can you do it, if you can't visualize all the data?

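You can inspect an elbow plot of inertia against k, starting with a small k and increasing, and look for the "elbow": the value of k beyond which adding more clusters no longer reduces the inertia by much.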
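
A rough sketch of how the numbers behind an elbow plot could be produced, using NumPy only. The helper `fit_kmeans`, the copy of the toy data in `toy_data`, the random seed, and the choice of 10 random restarts per k are all assumptions made for this illustration; the helper assumes Euclidean distance.

```python
import numpy as np

def fit_kmeans(points, k, rng, max_steps=100):
    """Minimal k-means (hypothetical helper); returns the final inertia."""
    # Start from k distinct data points chosen at random
    centroids = points[rng.choice(len(points), size=k, replace=False)]
    for _ in range(max_steps):
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # Keep the old centroid if a cluster happens to end up empty
        new_centroids = np.array([
            points[labels == i].mean(axis=0) if np.any(labels == i) else centroids[i]
            for i in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    # Inertia: mean squared distance from each point to its closest centroid
    dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    return np.mean(np.min(dists, axis=1) ** 2)

# Same values as the toy data used earlier in this notebook
toy_data = np.array([[1, 1], [2, 0.5], [1.5, 2], [3, 1.5],
                     [3.5, 1.75], [4, 3.6], [4.25, 4], [5, 3.5]])
rng = np.random.default_rng(0)  # fixed seed so the run is repeatable

ks = list(range(1, len(toy_data) + 1))
# For each k, keep the best (lowest) inertia over several random starts
inertias = [min(fit_kmeans(toy_data, k, rng) for _ in range(10)) for k in ks]

for k, value in zip(ks, inertias):
    print(k, round(float(value), 3))

# To draw the elbow plot: plt.plot(ks, inertias, marker="o"); plt.show()
```

When k equals the number of data points, every point is its own centroid and the inertia is zero, so the curve always ends at the bottom; the interesting part is where it bends.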

Even if you use this method, it's still important to __look at your data__.

## Resources
* For a list of lots of clustering algorithms, see https://scikit-learn.org/stable/modules/clustering.html