In [34]:
import numpy as np

In [35]:
class KMeans():
  def __init__(self, n_clusters, max_iter=100, e=0.0001):
    """
    n_clusters: K - Number of clusters
    max_iter: Maximum number of iterations
    e: Tolerance - Difference between successive distortion values to be considered converged
    """
    self.n_clusters = n_clusters
    self.max_iter = max_iter
    self.e = e

  def fit(self, x):
    assert len(x.shape) == 2, "Fit only takes 2D numpy arrays as input"

    np.random.seed(42)
    N, D = x.shape
    K = self.n_clusters

    # Initialize \mu_k as random datapoints from x
    cluster_centers = np.random.choice(N, size=K, replace=False)
    mu = x[cluster_centers, :]

    # Intialize the loss function (Distortion Function)
    J = np.inf

    for i in range(self.max_iter):
      # Compure r
      r = np.zeros(N)
      dist = np.zeros((N, K))

      for n in range(N):
        for k in range(K):
          dist[n, k] = np.inner(mu[k, :] - x[n, :], mu[k, :] - x[n, :])

      r = np.argmin(dist, axis=1)

      J_new = 0
      for n in range(N):
        J_new += dist[n, r[n]]

      # Computed the average distortion
      J_new /= N

      print(f"Iteration {i}: J = {J}")

      # If it is less the tolerance, it has reached convergence
      if np.absolute(J - J_new) <= self.e:
        return (mu, r, i)

      J = J_new

      # Update the cluster centers
      for k in range(K):
        k_idx_samples, = np.where(r == k)

        if len(k_idx_samples) > 0:
            mu[k] = x[k_idx_samples].mean(axis=0)
        else:
            mu[k] = x[np.random.choice(N)]

    print("Did not reach convergence")
    return (mu, r, self.max_iter)

In [36]:
import matplotlib.pyplot as plt
import os

def transform_image(image, code_vectors):
  """
  Quantize the image using the code_vectors.
  Return a new image from the image by replacing each RGB value in the image
  with the nearest code vectors (nearest by Euclidean Distance)
  """

  assert image.shape[2] == 3 and len(image.shape) == 3, "Image should be a 3-D array with size (?, ?, 3)"
  assert code_vectors.shape[1] == 3 and len(code_vectors.shape) == 2, "code_vectors should be a 2-D array with size (?, 3)"

  H, W, _ = image.shape
  K, _ = code_vectors.shape

  new_image = np.zeros(image.shape)

  for h in range(H):
    for w in range(W):
      dist = np.zeros(K)

      for k in range(K):
        dist[k] = np.inner(image[h, w] - code_vectors[k], image[h, w] - code_vectors[k])

      k = np.argmin(dist)
      new_image[h, w] = code_vectors[k]

  return new_image

In [37]:
def kmeans_image_compression(filename):
  os.makedirs("plots", exist_ok=True)
  os.makedirs("results", exist_ok=True)

  im = plt.imread(filename)
  N, M = im.shape[:2]
  im = im / 255 # Normalize to [0, 1]

  # Convert to RGB array
  data = im.reshape(N * M, 3)

  k_means = KMeans(n_clusters=16, max_iter=100, e=1e-6)
  centroids, _, i = k_means.fit(data)

  print(f"RGB centroids computed in {i} iterations")
  new_im = transform_image(im, centroids)

  assert new_im.shape == im.shape, "Shape of transformed image should be same as image"

  mse = np.sum((im - new_im)**2) / (N * M)
  print(f"Mean Squared Error per Pixel is {mse}")

  plt.imsave("plots/compressed_colorful_img.png", new_im)

  np.savez("results/k_means_compression.npz", im=im, centroids=centroids, step=i, new_image=new_im, pixel_error=mse)

In [38]:
kmeans_image_compression("baboon.tiff")

Iteration 0: J = inf
Iteration 1: J = 0.0213871498797225
Iteration 2: J = 0.013155467082854195
Iteration 3: J = 0.0113426576121396
Iteration 4: J = 0.010419410011150732
Iteration 5: J = 0.010187261407946915
Iteration 6: J = 0.010063018072320542
Iteration 7: J = 0.009990184912162458
Iteration 8: J = 0.00994627603746321
Iteration 9: J = 0.0099170297274798
Iteration 10: J = 0.009896312820468902
Iteration 11: J = 0.009880773669092046
Iteration 12: J = 0.00986803558355922
Iteration 13: J = 0.009856609203531763
Iteration 14: J = 0.009845835778608366
Iteration 15: J = 0.009835822537712455
Iteration 16: J = 0.00982683564042616
Iteration 17: J = 0.009818443777669687
Iteration 18: J = 0.009810533955649122
Iteration 19: J = 0.009802561625479546
Iteration 20: J = 0.009794453484327347
Iteration 21: J = 0.009786665740350076
Iteration 22: J = 0.00977904502488457
Iteration 23: J = 0.009771748463356551
Iteration 24: J = 0.00976440194913811
Iteration 25: J = 0.0097567890586577
Iteration 26: J = 0.009749