# K-means single iteration
Stough, DIP

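Demo of k-means clustering. We generate a small 2-D data set and then perform a
single iteration of the EM-style scheme: assign each point to its nearest center, then recompute each center as the mean of its assigned points.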

Finding objects in images:
- Color coherence
- Spatial coherence

In [36]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np

# For importing from alternative directory sources
import sys  
sys.path.insert(0, '../dip_utils')

from matrix_utils import (arr_info,
                          make_linmap)
from vis_utils import (vis_rgb_cube,
                       vis_hists,
                       vis_pair,
                       vis_surface)

from scipy.spatial.distance import cdist

K = 2               # number of clusters (and of random cluster colors / initial centers)
NUMPOINTS = 1_000   # total number of 2-D sample points; half are drawn for each blob

In [37]:
# Generate 2-D random data as two Gaussian blobs of NUMPOINTS//2 points each.
# X1: unit normal (mean 0, var 1) in both coordinates.
X1 = np.random.randn(NUMPOINTS // 2, 2)

# X2: unit normal scaled by 2 and shifted, giving mean (5, 3) and var 4.
X2 = np.random.randn(NUMPOINTS // 2, 2) * 2 + np.array([5, 3])

# Stack the two blobs row-wise into a single NUMPOINTS x 2 data matrix.
X = np.vstack((X1, X2))

### Look at the generated data

In [38]:
# Scatter the raw points: first column of X on the x-axis, second on the y-axis.
plt.figure()
plt.scatter(x=X[:, 0], y=X[:, 1])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x1d1cbf41908>

In [39]:
# For fun: try a few random palettes and keep the one whose colors differ most.
clusterColors = np.random.rand(K, 3)    # one random RGB triple per cluster (K x 3)
varsSoFar = clusterColors.var(axis=0)   # length-3 vector: per-channel variance across the K colors

for i in range(3 * K):
    tempColors = np.random.rand(K, 3)   # candidate palette of K random colors
    vartemp = tempColors.var(axis=0)
    # Accept the candidate only if it is more spread out in every RGB channel.
    if (vartemp > varsSoFar).all():
        clusterColors, varsSoFar = tempColors, vartemp

In [40]:
# K-means: initialization
# Use K distinct data points as the starting cluster centers.
# (Drawing indices with randint(0, len(X), size=(K,)) could generate repeats,
#  so we sample without replacement instead.)
whichinit = np.random.choice(X.shape[0], K, replace=False)
CC = np.array(X[whichinit])  # Cluster Centers: an independent K x 2 copy of the chosen rows

# Keep a snapshot of the starting centers; CC itself is overwritten by the update step.
CC_init = np.array(CC)

In [41]:
# Inspect the K initial cluster centers (one row per center, one column per coordinate).
CC

array([[6.58570957, 2.65056788],
       [5.19440058, 5.30489351]])

In [42]:
# Data in neutral gray, with the current cluster centers drawn on top in their cluster colors.
plt.figure(figsize=(4, 4))
plt.scatter(*X.T, c='gray', alpha=0.5)
plt.scatter(*CC.T, s=30, c=clusterColors)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x1d1cc316788>

In [43]:
# K-means: assignment step -- find, for every data point, the nearest center.
# cdist returns the pairwise distances between the rows of X (the points)
# and the rows of CC (the cluster centers), so D is NUMPOINTS x K.
D = cdist(X, CC, metric='euclidean')

# Column index of the smallest distance in each row: a length-NUMPOINTS vector
# holding the id of the closest center for each point.
whichCluster = D.argmin(axis=1)

In [44]:
# Peek at the first 10 rows of D: each row holds one point's distance to each of the K centers.
D[:10,:]

array([[8.41736571, 8.34468327],
       [6.73222898, 6.88836513],
       [8.44733644, 8.50686989],
       [7.57209471, 7.56758002],
       [7.12863689, 7.2459259 ],
       [6.3308682 , 6.28511512],
       [5.3577122 , 5.23548232],
       [7.92974429, 7.99907557],
       [6.00956935, 6.37432658],
       [7.42153421, 7.15924759]])

In [45]:
# The first data point, i.e. the point whose distances to the centers are in row 0 of D.
X[0,:]

array([-1.5056444 ,  0.33064622])

In [46]:
# Color every point by the cluster it is currently assigned to.
plt.figure(figsize=(4, 4))
plt.scatter(*X.T, c=clusterColors[whichCluster], alpha=0.5)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x1d1cc382f88>

In [47]:
# K-means: update step -- recompute each cluster center as the mean of the
# data points currently assigned to it.
for c in range(K):
    members = X[whichCluster == c, :]  # just the points that were closest to center c
    # Guard against an empty cluster: the mean of zero rows is NaN (numpy emits
    # a RuntimeWarning), which would poison this center and every later
    # distance computation. In that case keep the previous center.
    if len(members) > 0:
        CC[c, :] = members.mean(axis=0)

In [48]:
# Sanity check: mean of the points assigned to cluster 0 (the same expression the update loop uses for c == 0).
np.mean(X[whichCluster == 0, :],axis=0)

array([2.50390014, 0.54106784])

In [49]:
# Number of points currently assigned to cluster 1.
(whichCluster == 1).sum()

309

In [50]:
# Data in neutral gray again, now with the recomputed cluster centers on top.
plt.figure(figsize=(4, 4))
plt.scatter(x=X[:, 0], y=X[:, 1], c='gray', alpha=0.5)
plt.scatter(x=CC[:, 0], y=CC[:, 1], s=30, c=clusterColors)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x1d1cc3d9fc8>

In [51]:
# K-means: assignment step again, this time against the recomputed centers.
# D[i, j] is the Euclidean distance from point X[i] to center CC[j],
# so D is NUMPOINTS x K.
D = cdist(XA=X, XB=CC, metric='euclidean')

# For each point, the index of its nearest center (1-D, length NUMPOINTS).
whichCluster = np.argmin(D, 1)

In [52]:
# Points colored by their cluster membership after re-assignment.
plt.figure(figsize=(4, 4))
plt.scatter(x=X[:, 0], y=X[:, 1], c=np.take(clusterColors, whichCluster, axis=0), alpha=0.5)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.collections.PathCollection at 0x1d1cc42a888>

In [53]:
# Summary figure with three panels sharing both axes:
#   left   - the raw data,
#   middle - the data with the initial cluster centers,
#   right  - the data colored by cluster, with the recomputed centers.

# Per-point colors, looked up from each point's assigned cluster.
pointColors = clusterColors[whichCluster]
# For contrast, outline each center marker in the complement of its fill color.
clusterEdgeColors = 1 - clusterColors

f, ax = plt.subplots(nrows=1, ncols=3, figsize=(9, 3), sharex=True, sharey=True)

# Panel 0: original data.
ax[0].scatter(*X.T, c='gray', s=20)
ax[0].set_title('Original Data')

# Panel 1: faded data plus the centers saved at initialization.
ax[1].scatter(*X.T, c='gray', alpha=0.5, s=20)
ax[1].scatter(*CC_init.T, c=clusterColors, s=50)
ax[1].set_title('Initial Cluster Centers')

# Panel 2: cluster-colored data plus the updated centers.
ax[2].scatter(*X.T, c=pointColors, alpha=0.5, s=20)
ax[2].scatter(*CC.T, c=clusterColors, edgecolors=clusterEdgeColors, s=50)
ax[2].set_title('Recomputed Clusters')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Recomputed Clusters')