In [None]:
import numpy as np
import matplotlib.pyplot as plt

# K-means for image compression; choosing k

An image is a matrix of pixels. Each pixel can be represented as an array of RGB (red, green, blue) values. 
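As a minimal illustration, here is a tiny made-up 2×2 image (not the cat picture) stored as a NumPy array of shape (height, width, 3). Indexing a row and a column returns that pixel's RGB triple.

```python
import numpy as np

# A made-up 2x2 RGB image: height 2, width 2, 3 color channels (uint8, 0-255)
tiny = np.array([[[255, 255, 255], [128, 128, 128]],
                 [[0, 0, 0],       [200, 30, 30]]], dtype=np.uint8)

print(tiny.shape)  # (2, 2, 3)
print(tiny[1, 1])  # the bottom-right pixel's [R, G, B] values
```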

In most pictures, the RGB values in one pixel are likely to be similar to those in the pixels around it. For example, the picture below (from Wikimedia) has a lot of white pixels; even in the cat's fur, the greys and blacks tend to blend together.

![a cat](https://upload.wikimedia.org/wikipedia/commons/4/4d/Cat_March_2010-1.jpg)

If we can identify the common colors (the common RGB values) in a picture, for example as the centroids found by k-means, then we can replace each pixel's RGB values with the RGB values of its centroid. This in turn allows us to:
* compress the image: store each pixel's centroid ID instead of its RGB values, and store the centroid RGB values once, in a dictionary that maps each ID to its color
* segment the image into regions by color
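A minimal sketch of the encode/decode idea, assuming we already have a (k, 3) array of centroid colors and one cluster ID per pixel. The names `palette` and `labels` and all of the numbers are made up for illustration; here a NumPy array indexed by centroid ID plays the role of the dictionary.

```python
import numpy as np

# Hypothetical palette: k = 3 centroid colors, one RGB row per centroid ID
palette = np.array([[250, 250, 250],   # ID 0: near-white
                    [120, 120, 120],   # ID 1: grey
                    [10, 10, 10]],     # ID 2: near-black
                   dtype=np.uint8)

# Hypothetical cluster ID for each of 6 pixels (the "compressed" image)
labels = np.array([0, 0, 1, 2, 1, 0])

# Decoding: fancy indexing looks up each pixel's centroid color by its ID
decoded = palette[labels]
print(decoded.shape)  # (6, 3)
```

Only `labels` and the small `palette` need to be stored; every pixel can be rebuilt from them, at the cost of losing the colors' fine detail.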

To do this, we are going to treat the image itself as a dataset. Each row will correspond to a pixel, and the columns will be the RGB values.
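The reshape works because NumPy's `reshape` reads elements in row-major (C) order by default, so `reshape(-1, 3)` lists the pixels row by row and reshaping back to the original shape undoes it exactly. A quick check on a random made-up image:

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up 4x5 RGB image with random 0-255 values
img = rng.integers(0, 256, size=(4, 5, 3)).astype(np.uint8)

pixels = img.reshape(-1, 3)           # one row per pixel; columns are R, G, B
print(pixels.shape)                   # (20, 3)

# Row-major order: the pixel at (row r, column c) lands in row r * width + c
r, c = 2, 3
assert (pixels[r * img.shape[1] + c] == img[r, c]).all()

restored = pixels.reshape(img.shape)  # back to (height, width, 3)
assert (restored == img).all()
```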

## Let's do k-means on a picture!

There are __three bugs__ in the code directly below. How quickly can you find them?

In [None]:
# Most of this comes from day 22

# Let's define a distance metric; which one is this??
def distance(a, b):
    subtracted = a-b
    return np.sqrt(np.dot(subtracted.T, subtracted))

# Let's define a function to calculate the distance from one data point to each centroid
def get_distances(item, centroids):
    distances = [distance(item, centroid) for centroid in centroids]
    return distances

# Let's define a function to update cluster assignments given a set of centroids
def update_clusters(data, centroids):
    return [np.argmin(get_distances(item, centroids)) for item in data]

# Let's define a function to update the centroids
def update_centroids(data, clusters):
    with_clusters = np.hstack((data, np.array([clusters]).T))
    indices = np.argsort(with_clusters[:, -1])
    with_clusters_sorted = with_clusters[indices]
    by_cluster = np.array_split(with_clusters_sorted, np.where(np.diff(with_clusters_sorted[:, -1])!=0)[0]+1)
    return np.array([np.mean(cluster[:, :-1], axis=0) for cluster in by_cluster])

# Let's define a function to measure the inertia
def inertia(data, centroids, clusters):
    sum_squares = 0
    for i in range(len(data)):
        sum_squares += distance(data[i], centroids[clusters[i]])**2
    return sum_squares / len(data)

In [None]:
# We have to load the picture
from matplotlib.image import imread

image = imread('data/wikimedia_cat.jpg')
print(image.shape)

In [None]:
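# We have to reshape it: one row per pixel, one column per color channel.
# Casting to float avoids uint8 wrap-around (e.g. 10 - 20 == 246) when subtracting pixel values.
data = image.reshape(image.shape[0]*image.shape[1], image.shape[2]).astype(float)
print(data.shape)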

In [None]:
%%time

# We do k-means clustering for even k from 2 to 16 (range(2, 17, 2)) to see what is a good value for k
# Watch how it takes longer and longer (both to converge, and to do one round) as k increases
inertia_by_k = []

for k in range(2, 17, 2):
    print(k)
    centroids = np.array([data[x] for x in np.random.choice(np.arange(len(data)), size=k, replace=False)])
    clusters = update_clusters(data, centroids)
    this_inertia = inertia(data, centroids, clusters)
    last_inertia = this_inertia + 1
    while abs(last_inertia - this_inertia) > 0.01:
        last_inertia = this_inertia
        centroids = update_centroids(data, clusters)
        clusters = update_clusters(data, centroids)
        this_inertia = inertia(data, centroids, clusters)
        print(this_inertia)
    inertia_by_k.append([k, this_inertia])


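The per-pixel Python loop in `update_clusters` is the main reason each round is slow. Below is a sketch of a vectorized alternative using NumPy broadcasting; the name `update_clusters_vectorized` and the `chunk_size` parameter are hypothetical, not part of the original notebook. It computes all pixel-to-centroid squared distances for a chunk of pixels at once (chunking keeps the (n, k, 3) intermediate array small) and casts to float first so uint8 pixel values cannot wrap around on subtraction.

```python
import numpy as np

def update_clusters_vectorized(data, centroids, chunk_size=100_000):
    """Return the index of the nearest centroid (Euclidean distance) for each row of `data`.

    Hypothetical helper, not part of the original notebook.
    """
    data = np.asarray(data, dtype=float)            # float avoids uint8 wrap-around
    centroids = np.asarray(centroids, dtype=float)
    labels = np.empty(len(data), dtype=int)
    for start in range(0, len(data), chunk_size):
        block = data[start:start + chunk_size]                # (n, 3)
        diff = block[:, None, :] - centroids[None, :, :]      # broadcasts to (n, k, 3)
        sq_dist = (diff ** 2).sum(axis=2)                     # (n, k); same argmin as the square-rooted distance
        labels[start:start + chunk_size] = sq_dist.argmin(axis=1)
    return labels
```

In the loop above, `clusters = update_clusters_vectorized(data, centroids)` could stand in for `update_clusters(data, centroids)`; it returns a NumPy array instead of a list, which `update_centroids` and `inertia` index the same way.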
In [None]:
inertia_by_k = np.array(inertia_by_k)
print(inertia_by_k)
fig = plt.figure(figsize=(6,4))
ax1 = fig.add_subplot(111)
ax1.plot(inertia_by_k[:, 0], inertia_by_k[:, 1])
ax1.set_xlabel('k')
ax1.set_ylabel('Inertia')
ax1.set_title('Elbow Plot')
plt.show()
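One rough way to read the elbow plot (a heuristic, not a rule) is to look at how much inertia each step up in k removes, and to pick the k after which further increases buy little. A sketch, assuming `inertia_by_k` is the (n, 2) array of [k, inertia] rows built above, sorted by k; the helper name `inertia_drops` is hypothetical.

```python
import numpy as np

def inertia_drops(inertia_by_k):
    """Return rows [k, fractional drop in inertia relative to the previous k].

    Hypothetical helper; expects an (n, 2) array of [k, inertia] rows sorted by k.
    """
    inertia_by_k = np.asarray(inertia_by_k, dtype=float)
    ks = inertia_by_k[1:, 0]
    prev, curr = inertia_by_k[:-1, 1], inertia_by_k[1:, 1]
    frac_drop = (prev - curr) / prev   # e.g. 0.25 means 25% less inertia than at the previous k
    return np.column_stack((ks, frac_drop))

# Usage with the array computed in the elbow-plot cell:
# print(inertia_drops(inertia_by_k))
```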

## Let's replace each pixel in the picture with its corresponding centroid, and then display it!

In [None]:
def rebuild(centroids, clusters, original):
    # for each item in clusters, take the RGB values of the corresponding centroid
    res = np.array([centroids[x] for x in clusters])
    print(res.shape, original.shape)
    # reshape from (pixels, 3) back to the original (height, width, 3)
    res = res.reshape(original.shape)
    return res

In [None]:
plt.imshow(image)

In [None]:
output = rebuild(centroids, clusters, image)
print(output.shape)
output = output.round().astype(int)
plt.imshow(output)  # already rounded to integer RGB values in 0-255

In [None]:
# Load the big version of the image
image = imread('data/wikimedia_cat_large.jpg')
print(image.shape)
plt.imshow(image)

# We have to reshape it (casting to float, as before, to avoid uint8 wrap-around)
data = image.reshape(image.shape[0]*image.shape[1], 3).astype(float)
print(data.shape)

In [None]:
%%time

# Do k-means with k=14; why 14?
k = 14

centroids = np.array([data[x] for x in np.random.choice(np.arange(len(data)), size=k, replace=False)])
clusters = update_clusters(data, centroids)
this_inertia = inertia(data, centroids, clusters)
last_inertia = this_inertia + 1
while abs(last_inertia - this_inertia) > 0.01:
    last_inertia = this_inertia
    centroids = update_centroids(data, clusters)
    clusters = update_clusters(data, centroids)
    this_inertia = inertia(data, centroids, clusters)
    print(this_inertia)

In [None]:
# Let's display the output!
output = rebuild(centroids, clusters, image)
print(output.shape)
output = output.round().astype(int)
plt.imshow(output)  # already rounded to integer RGB values in 0-255

# Question

If you replace each pixel in the image above with an integer (the index of its centroid) and keep the centroid dictionary, how much smaller is the picture than the original?

# K-means++

In addition to choosing the distance metric and choosing $k$, we can change how we initialize k-means. So far we've tried:
* random: what if we get an unlucky random draw? Try several random initializations and keep the best one (the one with the lowest inertia)
* get someone to label a small subset of the data to get an idea of the clusters: what if the subset they label is not representative?
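The "several random initializations" idea can be sketched as below. This is a self-contained toy version with its own compact k-means; the names `kmeans_once` and `best_of_n_restarts` are hypothetical, not the notebook's functions. Inertia here is the plain sum of squared distances; dividing by the number of points, as the notebook's `inertia` does, would not change which run wins.

```python
import numpy as np

def kmeans_once(data, k, rng, max_iter=100):
    """One k-means run from k randomly chosen data points; returns (centroids, labels, inertia)."""
    centroids = data[rng.choice(len(data), size=k, replace=False)]
    for _ in range(max_iter):
        sq_dist = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = sq_dist.argmin(axis=1)
        # Recompute each centroid; keep the old one if its cluster came out empty
        new_centroids = np.array([data[labels == j].mean(axis=0) if np.any(labels == j)
                                  else centroids[j] for j in range(k)])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    # Final assignment and inertia for the final centroids
    sq_dist = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = sq_dist.argmin(axis=1)
    inertia = sq_dist[np.arange(len(data)), labels].sum()
    return centroids, labels, inertia

def best_of_n_restarts(data, k, n_restarts=5, seed=0):
    """Run kmeans_once n_restarts times and keep the lowest-inertia result."""
    rng = np.random.default_rng(seed)
    data = np.asarray(data, dtype=float)
    runs = [kmeans_once(data, k, rng) for _ in range(n_restarts)]
    return min(runs, key=lambda run: run[2])
```

Each run builds an (N, k, 3) distance array, so on a large image this would likely need the chunking trick or a smaller sample of pixels.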

Now we will look at a third way. It goes like this:
1. Choose a single random data point as the first centroid, $\vec{c_1}$.
2. Repeat until $k$ centroids have been chosen:
   1. For each data point $\vec{x_i}$, calculate $D(\vec{x_i})$, the distance between $\vec{x_i}$ and its nearest previously chosen centroid.
   2. Pick the next centroid at random, choosing $\vec{x_i}$ with probability $p(\vec{x_i}) = \frac{D(\vec{x_i})}{\sum_{j=1}^N D(\vec{x_j})}$. (In Python, you can get the index of this next centroid with `index = np.random.choice(N, p=p)`, where `p` is the array of these $N$ probabilities.) You can also square the distances in the numerator and denominator, using $D(\vec{x_i})^2$, to spread the distances out further; the original k-means++ algorithm uses the squared version.

This means that data points far from all of the centroids chosen so far are more likely to be picked as the next centroid.
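A minimal sketch of these seeding steps; the name `kmeans_pp_init` and its `square` flag are illustrative, not from the original notebook. With `square=True` it uses squared distances (the standard k-means++ rule); with `square=False` it uses the plain distances from the formula above. It only picks the starting centroids; the usual k-means updates then run from there.

```python
import numpy as np

def kmeans_pp_init(data, k, square=True, seed=None):
    """Pick k starting centroids from the rows of `data`, k-means++ style (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    data = np.asarray(data, dtype=float)
    n = len(data)
    # Step 1: the first centroid is a uniformly random data point
    centroids = [data[rng.integers(n)]]
    # Step 2: keep going until there are k centroids
    while len(centroids) < k:
        chosen = np.array(centroids)                               # (m, 3)
        diffs = data[:, None, :] - chosen[None, :, :]              # (n, m, 3)
        nearest = np.sqrt((diffs ** 2).sum(axis=2)).min(axis=1)    # D(x_i) for every point
        weights = nearest ** 2 if square else nearest
        total = weights.sum()
        if total == 0:
            raise ValueError("fewer distinct points than k; cannot pick more centroids")
        p = weights / total                                        # p(x_i), sums to 1
        centroids.append(data[rng.choice(n, p=p)])
    return np.array(centroids)
```

Points that coincide with an already chosen centroid get probability 0, so the same point is never picked twice; the result could replace the random `centroids = ...` line before the k-means loop.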


# Resources

* https://github.com/hanyoseob/python-k-means