In [45]:
#Image Compression
#By Anton Fedun
import time
import numpy as np
import random
import matplotlib.pyplot as plt
import scipy.misc
from PIL import Image
from scipy import ndimage

start_time = time.time()
fname = "nature" #works with 512x512 images
N = 6                #number of colors in compressed image
max_iters = 4      #the bigger - the most likely data fit


def load_data(fname):
    image = plt.imread(fname + '.jpg')
    num_px = image.shape[0]
    num_py = image.shape[1]
    print(image.shape)
    my_image = image.reshape(num_px * num_py, 3)
    X = my_image / 255
    return num_px, num_py, X

print("Loading data")
num_px, num_py, X = load_data(fname)        

def findClosestCentroids(X, centroids):
    m = X.shape[0]
    idx = np.zeros((m, 1))
    dist = 0
    for i in range(m):
        minim = np.inf
        for k in range(N):
            diff = X[i] - centroids[k]
            dist = np.dot(diff.T, diff)
            if minim > dist:
                minim = dist
                idx[i] = k
    return idx

def initCentroids(X):  
    return np.random.permutation(X)

def computeCentroids(X, idx):
    m = X.shape[0]
    n = X.shape[1]
    centroids = np.zeros((N, n))
    for k in range(N):
        summ = np.zeros(n)
        numk = len(idx[idx == k])
        for i in range(m):
            if k == idx[i]:
                summ = summ + X[i]
                
        centroids[k] = (summ / numk).T
    return centroids

def runKMeans(X, init_centroids):
    centroids = init_centroids
    for i in range(max_iters):
        print("Iteration ", i + 1, " of ", max_iters)
        idx = findClosestCentroids(X, centroids)
        centroids = computeCentroids(X, idx)
    return centroids, idx

def getImage(X, centroids):
    idx = findClosestCentroids(X, centroids)
    
    X_recov = centroids[idx.astype(int)].reshape(num_px, num_py, 3)
    return X_recov



print("Get initial centroids")
init_centroids = initCentroids(X)

print("Running K-Means")
centroids, idx = runKMeans(X, init_centroids)

print("Compressing image")
X_recov = getImage(X, centroids)
print(X_recov.shape)
plt.imsave(fname + '_compressed.jpg', X_recov)
print("Done!")
print(abs((start_time - time.time())))

Loading data
(512, 512, 3)
Get initial centroids
Running K-Means
Iteration  1  of  4
Iteration  2  of  4
Iteration  3  of  4
Iteration  4  of  4
Compressing image
(512, 512, 3)
Done!
58.06285643577576
