# K-means iteration
Stough, DIP

Here we do k-means clustering on an image, to get
representative colors for the image. 

In [35]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np

# For importing from alternative directory sources
import sys  
sys.path.insert(0, '../dip_utils')

from matrix_utils import (arr_info,
                          make_linmap)
from vis_utils import (vis_rgb_cube,
                       vis_hists,
                       vis_pair,
                       vis_surface)

from scipy.spatial.distance import cdist

# --- Configuration -----------------------------------------------------------
K = 16          # number of clusters, i.e. colors kept in the reconstruction
MAXITER = 20    # upper bound on k-means passes
NUMPOINTS = 100 # not referenced by the visible code cells; kept as-is
SEED = 0        # seed for NumPy's global RNG; use None for a different result on every run

# The notebook draws random numbers (palette, initial centers, plotted subset).
# Seeding once here makes "Restart Kernel -> Run All" reproduce the same results.
np.random.seed(SEED)

In [36]:
# Load the image and convert it to floating point.
I = plt.imread('../dip_pics/bellagio.jpg').astype(float)
# Flatten to one row per pixel, with the first three channels as the columns.
X = np.column_stack([I[..., ch].ravel() for ch in range(3)])

In [37]:
# Show the source image. I holds floats on a 0-255 scale, and imshow expects
# float RGB data in [0, 1], hence the division by 255.
# The trailing ';' keeps the returned object's repr out of the cell output.
plt.figure()
plt.imshow(I/255)
plt.title('Original Image');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x2d31b3868c8>

In [38]:
# Peek at the first ten pixels (rows) of the data matrix.
X[:10]

array([[118.,  57.,   0.],
       [117.,  55.,   8.],
       [125.,  64.,  10.],
       [130.,  70.,   8.],
       [119.,  58.,   3.],
       [118.,  57.,  10.],
       [112.,  52.,   2.],
       [115.,  57.,  11.],
       [104.,  52.,   4.],
       [126.,  61.,   7.]])

In [39]:
# For fun: pick a random palette for the clusters, favoring well-spread colors.
clusterColors = np.random.rand(K, 3)    # initial palette: K random RGB rows in [0, 1)
varsSoFar = clusterColors.var(axis=0)   # per-channel variance: 3 values, one per column

# Try 3*K more random palettes; adopt a candidate only when its variance beats
# the best so far in every one of the three channels at once.
for attempt in range(3*K):
    candidate = np.random.rand(K, 3)
    candidateVars = candidate.var(axis=0)
    if (candidateVars > varsSoFar).all():
        clusterColors, varsSoFar = candidate, candidateVars

In [40]:
# Display the selected palette: one RGB row per cluster, values in [0, 1).
clusterColors

array([[0.74643787, 0.83693395, 0.58762399],
       [0.88874007, 0.27330553, 0.13750022],
       [0.12507434, 0.24678442, 0.8125241 ],
       [0.8386965 , 0.6907052 , 0.38506042],
       [0.64390606, 0.98813635, 0.19241079],
       [0.74246131, 0.25433831, 0.03998556],
       [0.78800876, 0.04163534, 0.62301089],
       [0.26347326, 0.89919763, 0.92599247],
       [0.05658958, 0.95707786, 0.29369148],
       [0.27903652, 0.52557469, 0.9377231 ],
       [0.45627356, 0.09604034, 0.31410873],
       [0.08502557, 0.53111974, 0.54974435],
       [0.40558427, 0.16552639, 0.302825  ],
       [0.2927781 , 0.31564111, 0.64914525],
       [0.16930597, 0.35896736, 0.80885688],
       [0.07435364, 0.02949479, 0.98129898]])

&nbsp;

### Pick some initial cluster centers.

In [41]:
# K-means initialization: use K randomly chosen pixels as the starting centers.
# Row indices are drawn without replacement so no row of X is picked twice
# (a randint-style draw such as random.randint(0, len(X), size=(K,)) could repeat).
whichinit = np.random.choice(X.shape[0], size=K, replace=False)
CC = X[whichinit].copy() # Cluster Centers, as an array independent of X

In [42]:
# Keep a snapshot of the starting centers: the k-means loop overwrites CC in
# place, and the initial positions are plotted later for comparison.
CC_init = np.copy(CC)
CC

array([[229., 194., 138.],
       [132.,  72.,  18.],
       [ 34.,  37.,  16.],
       [153.,  78.,  10.],
       [126.,  98.,  74.],
       [149.,  95.,  48.],
       [102.,  53.,  12.],
       [111., 107., 108.],
       [113.,  45.,   0.],
       [153., 132., 101.],
       [ 77.,  29.,   6.],
       [ 65.,  28.,   1.],
       [133.,  58.,  27.],
       [220., 162.,  80.],
       [ 82.,  28.,   2.],
       [ 23.,  62.,  35.]])

&nbsp;

### The main Expectation-Maximization loop

Each pass assigns every point to its nearest cluster center, and then
recomputes each center as the mean of the points assigned to it.

In [43]:
# K-means main loop. Each pass does two steps:
#   1. assignment: label every point (row of X) with its nearest center (row of CC);
#   2. update: move each center to the mean of the points carrying its label.
# NOTE: CC is updated in place, so re-running this cell continues from the
# current centers instead of restarting from the initial ones.
prevCluster = None # labels from the previous pass, used for the convergence check
for i in range(MAXITER):
    # All pairwise distances between rows of X and rows of CC.
    D = cdist(X, CC, 'euclidean')
    # D is len(X) x K

    whichCluster = np.argmin(D, axis=1) # one entry per point: index of the closest center

    # Converged: no label changed since the last pass, so recomputing the means
    # would reproduce CC exactly. Stopping here skips the remaining cdist passes.
    if prevCluster is not None and np.array_equal(whichCluster, prevCluster):
        break
    prevCluster = whichCluster

    # K-means: recompute the cluster centers as the mean of the data in each cluster
    for c in range(K):
        members = whichCluster == c # boolean mask, computed once per cluster
        if members.any(): # a cluster with no points keeps its previous center
            CC[c,:] = X[members, :].mean(axis=0) # average of just those that were closest to c.

In [44]:
# Scattering every pixel of a big image (100Ks of points) is really slow, so
# plot a random subset. The sample is capped at len(X) because
# np.random.choice(..., replace=False) raises ValueError when asked for more
# items than exist (i.e. for images with fewer than 500*K pixels).
numShown = min(500*K, len(X))
rands = np.sort(np.random.choice(len(X), size=numShown, replace=False))


f, ax = plt.subplots(1,3, figsize=(9,3), sharex=True, sharey=True)

# Panel 0: the sampled pixels plotted by columns 0 and 1 of X (red vs. green).
ax[0].scatter(X[rands,0], X[rands,1], c='gray', s=20)
ax[0].set_title('Original Data')
ax[0].set_xlabel('Red')
ax[0].set_ylabel('Green')


# Panel 1: the same points, with the initial centers drawn in their own color.
ax[1].scatter(X[rands,0], X[rands,1], c='gray', alpha=.5, s=20)
ax[1].scatter(CC_init[:,0], CC_init[:,1], c=CC_init/255, s=50)
ax[1].set_title('Initial Cluster Centers')
ax[1].set_xlabel('Red')


# Panel 2: every point painted with the color of the center it was assigned to.
pointColors = CC[whichCluster[rands], :]
# Marker outlines use the complement of the random clusterColors palette
# (not of the center colors themselves) to set the centers apart from the points.
clusterEdgeColors = 1 - clusterColors

ax[2].scatter(X[rands,0], X[rands,1], c=pointColors/255, alpha=.5, s=20)
ax[2].scatter(CC[:,0], CC[:,1], c=CC/255, edgecolors=clusterEdgeColors, s=50)
ax[2].set_title('Recomputed Clusters')
ax[2].set_xlabel('Red')

plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [45]:
# Side-by-side comparison: source image vs. its K-color version.
f, ax = plt.subplots(1, 2, figsize=(8, 3), sharex=True, sharey=True)
axOrig, axRecon = ax

axOrig.imshow(I/255)
axOrig.set_title('Original Image')

# Reconstructed image: replace each pixel by the center of its cluster, then
# fold the flat list of pixels back into the original image shape.
Ir = CC[whichCluster].reshape(I.shape)
axRecon.imshow(Ir/255) # float data, so scale to [0, 1] for display
axRecon.set_title('{} color reconstruction'.format(K))

plt.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [46]:
# Visualize the colors of the original image with the vis_utils helper
# (presumably plots the pixels inside the RGB cube - confirm in vis_utils).
vis_rgb_cube(I)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [47]:
# Same visualization for the reconstructed image Ir, for comparison with the
# original (helper behavior assumed from its name - confirm in vis_utils).
vis_rgb_cube(Ir)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [48]:
# Encode each pixel's (R, G, B) triple as one number in base 256, written in
# Horner form of R*256**2 + G*256 + B, so distinct colors can be counted in 1-D.
origvals = ((I[..., 0]*256 + I[..., 1])*256 + I[..., 2]).ravel()

In [49]:
# Number of distinct color codes in the original image.
np.unique(origvals).size

223982

In [50]:
# Same base-256 code (Horner form of R*256**2 + G*256 + B) for the reconstruction.
# Ir holds cluster means, which need not be integers, so here the code acts as
# a per-color fingerprint rather than a strict base-256 encoding.
newvals = ((Ir[..., 0]*256 + Ir[..., 1])*256 + Ir[..., 2]).ravel()

In [51]:
# Number of distinct color codes left after the reconstruction.
np.unique(newvals).size

16