In [32]:
import os
import sys

from PIL import Image
import numpy as np

Image Compression in Python using K-Means
https://rickwierenga.com/blog/machine%20learning/image-compressor-in-Python.html

In [33]:
# IMAGE ADDRESS
TEST_IMAGE1 = "test_image1.jpg"
TEST_IMAGE2 = "test_image2.jpg"
TEST_IMAGE3 = "test_image3.jpg"
TEST_IMAGE4 = "test_image4.jpg"
TEST_IMAGE5 = "test_image5.jpg"

Helper functions for K-means

In [34]:
def initialize_K_centroids(X, K):
    """ Choose K points from X at random """
    m = len(X)

    # random sampling without replacement from X to form K centroids initially (randomly chosen)
    return X[np.random.choice(m, K, replace=False), :]

In [35]:
def find_closest_centroids(X, centroids):
    m = len(X)

    c = np.zeros(m)
    for i in range(m):
        # find distances
        distances = np.linalg.norm(X[i] - centroids, axis=1)

        # assign closest cluster to c[i]
        c[i] = np.argmin(distances)

    # list of the INDEX closest to an element in X
    return c

In [36]:
def compute_means(X, idx, K):
    _, n = X.shape
    centroids = np.zeros((K, n))

    # go through all clusters / centroids
    for k in range(K):
        examples = X[np.where(idx == k)]

        # compute mean for each centroid
        mean = [np.mean(column) for column in examples.T]
        centroids[k] = mean
    return centroids

In [37]:
def find_k_means(X, K, max_iters=10):
    # initialise centroids
    centroids = initialize_K_centroids(X, K)

    # keep track of state of all centroids
    previous_centroids = centroids

    # run to max number of re-clustering
    for _ in range(max_iters):
        idx = find_closest_centroids(X, centroids)
        centroids = compute_means(X, idx, K)

        # if repeating centroids, we stop
        if (centroids == previous_centroids).all():
            # the centroids aren't moving anymore.
            return centroids
        
        else:
            previous_centroids = centroids

    return centroids, idx

Loading and modifying image

In [38]:
def load_image(path):
    """ Load image from path. Return a numpy array """
    image = Image.open(path)

    # scale pixels to between 0 and 1
    return np.asarray(image) / 255

In [39]:
IMAGE_TO_USE = TEST_IMAGE5
image = load_image(IMAGE_TO_USE)
w, h, d = image.shape
print('Image found with width: {}, height: {}, depth: {}'.format(w, h, d))

Image found with width: 183, height: 275, depth: 3


In [40]:
# reshape the grid into a linear scale, as each pixel simply represents a colour
X = image.reshape((w * h, d))

# number of colours desired in output image - large enough for colour
K = 5

In [41]:
colors, _ = find_k_means(X, K, max_iters=20)

In [None]:
# the find_kmeans() function returns array that is ONE ITERATION behind
idx = find_closest_centroids(X, colors)

In [None]:
# convert colours to uint8 array
idx = np.array(idx, dtype=np.uint8)

# reshaping back to original dimensions
X_reconstructed = np.array(colors[idx, :] * 255, dtype=np.uint8).reshape((w, h, d))
compressed_image = Image.fromarray(X_reconstructed)

In [None]:
compressed_image.save("test_image_output1.jpg")

In [None]:
compressed_image.show()