# The One Goal For Today

To understand how we can use k-means clustering for image segmentation and image compression.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

# K-means for image segmentation and compression

An image is a matrix of pixels. Each pixel can be represented as an array of RGB (red, green, blue) values. 

In most pictures, the RGB values in one pixel are likely to be similar to those in the pixels around it. For example, the picture below (from https://vancouver.citynews.ca/2016/11/22/dangers-parking-lot-crashes-ahead-holiday-rush/) has a lot of grey pixels.

![cars in a car park](https://www.citynews1130.com/wp-content/blogs.dir/sites/9/2016/11/22/parking.jpg)

If we can identify the common colors (the common RGB values) in a picture, we can replace each pixel's RGB values with the RGB values of its nearest centroid. This in turn allows us to:
* compress the image: store each pixel as the ID of its centroid, and store the centroid RGB values only once, in a dictionary
* segment the image into regions by color
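
As a toy illustration of both bullets, here is a made-up 2x2 image and a 2-color palette (hypothetical values, not taken from the parking-lot picture). The compressed form is one integer label per pixel plus a small palette of centroid colors. Indexing the palette with the labels rebuilds the image, and comparing the labels with a cluster ID gives a segmentation mask.

```python
import numpy as np

# Hypothetical palette: one RGB centroid per row (cluster 0 = grey, cluster 1 = red)
palette = np.array([[128, 128, 128],
                    [200, 30, 30]], dtype=np.uint8)

# Hypothetical 2x2 image stored as one centroid ID per pixel
labels = np.array([[0, 0],
                   [1, 0]])

# Decompress: look up each pixel's centroid color, giving shape (2, 2, 3)
rebuilt = palette[labels]

# Segment: a boolean mask of the pixels that belong to cluster 1
red_region = labels == 1

print(rebuilt.shape)  # (2, 2, 3)
print(red_region)
```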

To do this, we are going to treat the image itself as a dataset. Each row will correspond to a pixel, and the columns will be the RGB values.
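
Here is a minimal sketch of that reshaping on a made-up 2x3 image filled with random values (the toy_ names are hypothetical and chosen so that they do not clash with the notebook's own variables). `reshape(-1, 3)` turns a `(height, width, 3)` array into a `(height*width, 3)` dataset with one row per pixel in row-major order. Reshaping back with the original shape restores the image exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical tiny image: height 2, width 3, three channels (R, G, B)
toy_image = rng.integers(0, 256, size=(2, 3, 3), dtype=np.uint8)

# One row per pixel, one column per channel
toy_data = toy_image.reshape(-1, toy_image.shape[2])
print(toy_data.shape)  # (6, 3)

# Row i of toy_data is the pixel at (i // width, i % width), so row 4 is pixel (1, 1)
assert (toy_data[4] == toy_image[1, 1]).all()

# Reshaping back recovers the original image
restored = toy_data.reshape(toy_image.shape)
assert (restored == toy_image).all()
```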

## Let's do k-means on a picture!

The code below comes from day 21, except that update_clusters has a slightly more efficient implementation.

In [2]:
# Euclidean distance
def distance(a, b):
    subtracted = a-b
    return np.sqrt(np.dot(subtracted, subtracted))

# Calculate the distance from one data point to each centroid
def get_distances(item, centroids):
    return [distance(item, centroid) for centroid in centroids]

# Update cluster assignments given a set of centroids
# This version is slightly more efficient than the one from Monday
def update_clusters(data, centroids):
    # initialize clusters
    clusters = {}
    for i in range(len(centroids)):
        clusters[i] = []
    # for each data point
    for datum in data:
        # find the index of the centroid with the smallest distance to this data point, and add this data point to that centroid's cluster
        clusters[np.argmin(get_distances(datum, centroids))].append(datum)
    return clusters

# Update the centroids given the data
def update_centroids(clusters):
    # set centroids to empty list
    centroids = []
    # for each set of data points in a cluster around a single centroid
    for data_in_cluster in clusters.values():
        # new centroid is the mean of that cluster
        centroids.append(np.mean(data_in_cluster, axis=0))
    return centroids

# Measure the inertia
def inertia(data, centroids, clusters):
    sum = 0
    for i in clusters.keys():
        for datum in clusters[i]:
            # calculate the distance squared between each data point and its centroid
            sum += distance(datum, centroids[i])**2
    # average over the data
    return sum / len(data)
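
The functions above call distance once for every (pixel, centroid) pair in a Python loop, which is why each round slows down as k grows. As an alternative sketch (a swapped-in technique, not the version from day 21), NumPy broadcasting can compute the whole pixel-by-centroid distance matrix at once and return one integer label per pixel instead of a dict of lists. The function names below are hypothetical.

```python
import numpy as np

def assign_labels(data, centroids):
    # Work in float: uint8 pixel values would wrap around on subtraction
    data = np.asarray(data, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    # Pairwise differences, shape (n_points, k, n_features)
    diffs = data[:, None, :] - centroids[None, :, :]
    # Squared distances, shape (n_points, k); skipping the square root keeps the same argmin
    sq_dists = (diffs ** 2).sum(axis=2)
    # Index of the nearest centroid for each point
    return sq_dists.argmin(axis=1)

def centroids_from_labels(data, labels, k):
    # Mean of the points in each cluster (assumes no cluster is empty)
    data = np.asarray(data, dtype=float)
    return np.array([data[labels == j].mean(axis=0) for j in range(k)])

def inertia_from_labels(data, centroids, labels):
    # Mean squared distance from each point to its assigned centroid
    data = np.asarray(data, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    return ((data - centroids[labels]) ** 2).sum(axis=1).mean()

# Usage on four made-up 2-D points
pts = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
start = np.array([[0.0, 0.0], [10.0, 10.0]])
pt_labels = assign_labels(pts, start)                     # array([0, 0, 1, 1])
new_centroids = centroids_from_labels(pts, pt_labels, 2)  # [[0, 0.5], [10, 10.5]]
print(inertia_from_labels(pts, new_centroids, pt_labels))  # 0.25
```

The trade-off is memory. `diffs` holds n_points x k x 3 numbers, which is about 25 MB of float64 for this picture at k = 16.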

In [3]:
# Load the picture
from matplotlib.image import imread

image = imread('data/parking.jpg')
print(image.shape)

(216, 302, 3)


In [4]:
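# Reshape the picture into a dataset with one row per pixel and one column per channel (R, G, B)
# Convert to float: imread returns uint8 for JPEGs, and uint8 subtraction in distance() would wrap around
data = image.reshape(image.shape[0]*image.shape[1], image.shape[2]).astype(float)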
print(data.shape)

(65232, 3)


In [5]:
%%time

# We run k-means for k = 2, 6, 10, 14 (range(2, 17, 4)) to see what a good value for k is
# Watch how it takes longer and longer (both to converge and to do one round) as k increases
inertia_by_k = []

for k in range(2, 17, 4):
    print(k)
    # make some initial centroids
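    # make some initial centroids by picking k distinct colors from the image
    # (if two centroids started identical, argmin would always pick the first, leaving the other cluster empty)
    unique_colors = np.unique(data, axis=0)
    centroids = unique_colors[np.random.choice(len(unique_colors), size=k, replace=False)]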
    # get the clusters for these centroids
    clusters = update_clusters(data, centroids)
    # calculate the inertia for this clustering
    this_inertia = inertia(data, centroids, clusters)
    # initialize last_inertia so we go around at least once
    last_inertia = this_inertia + 1
    # stop when the inertia stops changing very much
    while abs(last_inertia - this_inertia) > 0.01:
        last_inertia = this_inertia
        # update the centroids
        centroids = update_centroids(clusters)
        # update the clusters
        clusters = update_clusters(data, centroids)
        # update the inertia
        this_inertia = inertia(data, centroids, clusters)
    inertia_by_k.append([k, this_inertia])


2
6


In [None]:
print(inertia_by_k)
inertia_by_k = np.array(inertia_by_k)
fig = plt.figure(figsize=(6,4))
ax1 = fig.add_subplot(111)
ax1.plot(inertia_by_k[:, 0], inertia_by_k[:, 1])
ax1.set_xlabel('k')
ax1.set_ylabel('Inertia')
ax1.set_title('Elbow Plot')
plt.show()

## Let's replace each pixel in the picture with its corresponding centroid, and then display it!

In [None]:
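def rebuild(centroids, image):
    # The clusters dict from update_clusters does not record which pixel each point came from,
    # so recompute each pixel's nearest centroid directly from the image (as floats, to avoid uint8 wraparound)
    pixels = image.reshape(-1, image.shape[2]).astype(float)
    centroids = np.asarray(centroids, dtype=float)
    # distance from every pixel to every centroid, shape (n_pixels, k)
    dists = np.linalg.norm(pixels[:, None, :] - centroids[None, :, :], axis=2)
    # index of the nearest centroid for each pixel, kept in the original pixel order
    labels = dists.argmin(axis=1)
    # replace each pixel by its centroid's RGB value and restore the (height, width, 3) shape
    return centroids[labels].reshape(image.shape)

In [None]:
plt.imshow(image)

In [None]:
# centroids holds the result for the last k run in the loop above
output = rebuild(centroids, image)
print(output.shape)
# round to whole 0-255 values so imshow treats them as 8-bit RGB
output = output.round().astype(np.uint8)
plt.imshow(output)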

*If you replace each pixel in the above image with an integer (the index of the corresponding centroid), and keep the centroid dictionary, how much smaller is the picture size than the original?*
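
Here is one way to set up that calculation. It is a sketch under our own assumptions, not the notebook's. The original is counted as raw pixels at one byte (uint8) per channel, not as the JPEG file, which is already compressed. Each centroid ID is packed into the smallest whole number of bits that can hold k distinct values. The dictionary stores k RGB triples at one byte per channel. The helper name is hypothetical.

```python
import math

def raw_and_compressed_bytes(height, width, channels, k):
    # Uncompressed original: one byte (uint8) per channel per pixel
    raw = height * width * channels
    # Smallest whole number of bits that can hold an ID in 0..k-1 (at least 1)
    bits_per_id = max(1, math.ceil(math.log2(k)))
    # Bit-packed IDs for every pixel, rounded up to whole bytes
    id_bytes = math.ceil(height * width * bits_per_id / 8)
    # Centroid lookup table: k colors, one byte per channel
    table_bytes = k * channels
    return raw, id_bytes + table_bytes

# The parking-lot picture is 216 x 302 x 3; k = 14 is the largest k the loop above tries
raw, compressed = raw_and_compressed_bytes(216, 302, 3, 14)
print(raw, compressed, round(raw / compressed, 2))
```

Real image formats also add headers, which this count ignores. It only compares raw pixel storage.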

# Resources

* https://github.com/hanyoseob/python-k-means