In [1]:
##############################################
# Program to illustrate the K-means algorithm
# Author: Sumohana Channappayya
##############################################
# Import all required libraries
import numpy as np
from matplotlib import pyplot as plt

In [2]:
# Squared Euclidean distance from one data point to every centroid.
def distance_vector(cur_y,D,k,p):
    """Return the squared Euclidean distance from ``cur_y`` to each of ``k`` centroids.

    Parameters
    ----------
    cur_y : indexable
        The point being assigned; only its first ``p`` coordinates are read.
    D : 2-D array indexed as ``D[dim, centroid]``
        Current centroids stored one per column.
    k : int
        Number of centroids (columns of ``D``) to compare against.
    p : int
        Number of coordinates (rows of ``D``) to include in each distance.

    Returns
    -------
    list
        ``k`` squared distances in centroid order (no square root is taken,
        which does not change which centroid is nearest).
    """
    def _squared_gap(col):
        # Running total that starts at 0 and adds one coordinate at a time.
        total = 0
        for dim in range(p):
            total = total + np.square(cur_y[dim] - D[dim, col])
        return total

    return [_squared_gap(col) for col in range(k)]

In [3]:
# Convergence measure: how far the centroids moved between two iterations.
def err(update,D,k,p):
    """Return the Euclidean (Frobenius) norm of the change between two centroid sets.

    Parameters
    ----------
    update : 2-D array-like indexed as ``update[centroid, dim]``
        Newly computed centroids, one per row.
    D : 2-D array-like indexed as ``D[dim, centroid]``
        Previous centroids, one per column (the transposed layout of ``update``).
    k : int
        Number of centroids to compare.
    p : int
        Number of coordinates per centroid to compare.

    Returns
    -------
    numpy.float64
        ``sqrt(sum over i < k, j < p of (update[i, j] - D[j, i])**2)``;
        0.0 when ``k`` or ``p`` is 0.
    """
    # Restrict both operands to k x p, and transpose D so that both use the
    # (centroid, dim) layout before subtracting.
    diff = np.asarray(update)[:k, :p] - np.asarray(D)[:p, :k].T
    return np.sqrt(np.sum(np.square(diff)))

In [4]:
# Generate a synthetic data set: k Gaussian clusters in p dimensions.
# Change k and p below to try other cases; the rest of the notebook is general.
NUM_PTS = 2000   # points drawn per cluster
SEED = 0         # fixed seed so "Restart & Run All" reproduces the same data and results

# Choose the number of clusters K and the data dimension p
k = 6
p = 2

np.random.seed(SEED)

# Generate clusters. Each cluster mean is an integer vector with entries in
# [-(k // 4), k // 4]; each covariance is diagonal with variances drawn from U(0, 0.2).
# NOTE: that grid is small (for k = 6 it has only 3**p points), so two clusters can
# end up sharing the same mean.
clusters = []
exp_cent = []
for i in range(k):
    # Integer bounds: randint's high end is exclusive, hence the + 1.
    mean = np.random.randint(low=-(k // 4), high=k // 4 + 1, size=(p,))
    cov = np.diag(np.random.uniform(low=0, high=0.2, size=(p,)))
    clusters.append(np.random.multivariate_normal(mean, cov, NUM_PTS).T)
    exp_cent.append(mean)
# Concatenate once at the end (concatenating inside the loop copies Y every time).
# The empty (p, 0) array keeps Y well-formed even when k == 0.
Y = np.concatenate([np.empty((p, 0), float)] + clusters, axis=1)

# Display the data
# plt.plot(Y[0,:], Y[1,:],'x')
# plt.axis('equal')
# plt.grid(True)
# plt.xlabel('y1')
# plt.ylabel('y2')
# plt.title('K-means demo')
# plt.show()

In [5]:
# Set up the K-means iteration: stopping tolerance, starting error and
# k random initial centroids.

# Stopping tolerance for the loop condition `error > epsilon`.
epsilon = 0.001
# Start with an error well above epsilon so the stopping test is not met up front.
error = 10000

# Draw the k initial centroids from a p-dimensional Gaussian with mean (1, ..., 1)
# and identity covariance. multivariate_normal returns shape (k, p); the transpose
# stores one centroid per column, so D has shape (p, k).
mean = np.full((p,), 1.0)
cov = np.eye(p)
D = np.random.multivariate_normal(mean, cov, k).T

# Plot centroids
# plt.plot(Y[0,:], Y[1,:],'x')
# plt.axis('equal')
# plt.grid(True)
# plt.xlabel('y1')
# plt.ylabel('y2')
# plt.plot(D[0,:], D[1,:],'r+')
# plt.show()

In [6]:
# Lloyd's K-means iteration for general k (clusters) and p (dimension):
#   1. assign every point to its nearest centroid,
#   2. move each centroid to the mean of its assigned points,
#   3. stop once the total centroid movement `error` is no longer above epsilon.
# NOTE: there are far more efficient (vectorized) ways to do this; the loops are
# kept for illustration.
MAX_ITER = 300   # safety cap so a non-converging run cannot loop forever

# Reset the loop state here so re-running this cell on its own does not silently
# skip the loop (a stale, already-small `error` would end it immediately).
count = 0
error = np.inf
num_samples = Y.shape[1]
while error > epsilon and count < MAX_ITER:
    # Per-cluster point counts and coordinate sums for this pass
    num = np.zeros((k,), float)
    update = np.zeros((k, p), float)

    # Assignment step: add each point to the running sum of its nearest centroid
    for idx in range(num_samples):
        cur_y = Y[:, idx]
        dist_vector = distance_vector(cur_y, D, k, p)
        # np.argmin returns the first minimum, so ties go to the lowest-numbered centroid
        min_idx = int(np.argmin(dist_vector))
        update[min_idx] += cur_y[:p]
        num[min_idx] += 1

    # Update step: each centroid becomes the mean of its assigned points.
    # A cluster that received no points keeps its previous centroid instead of
    # being left at the zero vector (which is arbitrary and inflates the error).
    for i in range(k):
        if num[i] > 0:
            update[i] /= num[i]
        else:
            update[i] = D[:p, i]

    # Total movement of the centroids in this pass (the convergence measure)
    error = err(update, D, k, p)
    # Store the new centroids one per column, updating D in place
    D[:p, :k] = update.T

    count += 1

    #Plot updated centroids

    # plt.plot(Y[0,:], Y[1,:],'x')
    # plt.axis('equal')
    # plt.grid(True)
    # plt.plot(D[0,:], D[1,:],'r+')
    # plt.pause(1)
    # plt.ion()
    # plt.show()

    # Print error

    print ('Iteration: ', count, 'Error: ', error)

if error > epsilon:
    print('Warning: stopped after', count, 'iterations without reaching epsilon')

# Print centroids. K-means cluster labels are arbitrary, so D[:, i] need not
# correspond to exp_cent[i]; compare the two lists by proximity.

print ('Centroids:')
for i in range(k):
    print(D[:,i])
print ('Expected centroids :')
for i in range(k):
    print(exp_cent[i])

Iteration:  1 Error:  1.3073613324034719
Iteration:  2 Error:  0.5467382363700163
Iteration:  3 Error:  0.2955620364471244
Iteration:  4 Error:  0.28971826833851716
Iteration:  5 Error:  0.32955862563101773
Iteration:  6 Error:  0.27165989004713736
Iteration:  7 Error:  0.3109570920412029
Iteration:  8 Error:  0.23440245230802823
Iteration:  9 Error:  0.1423457787651052
Iteration:  10 Error:  0.09070722871506874
Iteration:  11 Error:  0.05754341263833031
Iteration:  12 Error:  0.0369723093042566
Iteration:  13 Error:  0.023924949972887716
Iteration:  14 Error:  0.01647259920848211
Iteration:  15 Error:  0.011345997723481854
Iteration:  16 Error:  0.008742603600588298
Iteration:  17 Error:  0.004957052867431524
Iteration:  18 Error:  0.004338833779338352
Iteration:  19 Error:  0.004103818289891901
Iteration:  20 Error:  0.002353030682559319
Iteration:  21 Error:  0.0017583569853842787
Iteration:  22 Error:  0.0004115932326733517
Centroids:
[-0.69031998  1.01939044]
[ 0.93277896 -0.15341