In [1]:
import numpy as np
import pandas as pd
import math

In [33]:
def generate_dataset(sigma: float):
    samples1 = np.random.multivariate_normal(\
        np.array([-1, -1]), sigma * np.array([[2, 0.5], [0.5, 1]]), 100)
    samples2 = np.random.multivariate_normal(\
        np.array([1, -1]), sigma * np.array([[1, -0.5], [-0.5, 2]]), 100)
    samples3 = np.random.multivariate_normal(\
        np.array([0, 1]), sigma * np.array([[1, 0], [0, 2]]), 100)
    dataset = np.concatenate((samples1, samples2, samples3))
    labels = ['a'] * 100 + ['b'] * 100 + ['c'] * 100
    return (dataset, labels)

In [63]:
def km_clustering(X: np.array, K: int) -> (np.array, np.array):
    # randomly select k centroids
    dbounds = [(np.min(X[:, i]), np.max(X[:, i])) for i in range(X.shape[1])]
    print('Bounds: ' + str(dbounds))
    centers = np.array([[np.random.random_sample() * (dbounds[j][1] - dbounds[j][0]) + dbounds[j][0] \
                for j in range(len(dbounds))] for i in range(K)])
    print('Centers: ' + str(centers))
    converged = False
    Y = np.zeros(X.shape[0])
    while not converged:
        change = False
        # assign each x to a centroid
        for i in range(len(X)):
            prev_y = Y[i]
            x = X[i]
            distances = np.linalg.norm(centers - x, axis=1)
            closest_center = np.argmin(distances)
            Y[i] = closest_center
            if Y[i] != prev_y:
                change = True
        # update each centroid
        for k in range(len(centers)):
            cluster_points = [X[j] for j in np.where(Y == k)[0]]
            centers[k] = np.mean(cluster_points)
        print('Centers: ' + str(centers))
        # if no centroid assignment changed, we've converged
        if not change:
            converged = True
            print('Converged')
    return (centers, Y)

In [64]:
X, Y = generate_dataset(0.5)
centroids, classes = km_clustering(X, 3)
print(centroids)

Bounds: [(-3.807723229680784, 2.347077117056819), (-3.3501783528858677, 3.620042721979201)]
Centers: [[-3.36665674  1.95101026]
 [-1.530027   -0.89918082]
 [-2.54025088 -0.10031124]]
Centers: [[ 0.84739587  0.84739587]
 [-0.23724215 -0.23724215]
 [-0.39647157 -0.39647157]]
Centers: [[ 0.7558784   0.7558784 ]
 [-0.02096139 -0.02096139]
 [-0.96877614 -0.96877614]]
Centers: [[ 0.78174741  0.78174741]
 [-0.09379344 -0.09379344]
 [-1.14854369 -1.14854369]]
Centers: [[ 0.77166064  0.77166064]
 [-0.15430662 -0.15430662]
 [-1.27086695 -1.27086695]]
Centers: [[ 0.7663419   0.7663419 ]
 [-0.18921669 -0.18921669]
 [-1.35058354 -1.35058354]]
Centers: [[ 0.74568961  0.74568961]
 [-0.21428731 -0.21428731]
 [-1.37846645 -1.37846645]]
Centers: [[ 0.74568961  0.74568961]
 [-0.21811489 -0.21811489]
 [-1.38792509 -1.38792509]]
Centers: [[ 0.74568961  0.74568961]
 [-0.22207552 -0.22207552]
 [-1.39725203 -1.39725203]]
Centers: [[ 0.74038361  0.74038361]
 [-0.22539688 -0.22539688]
 [-1.39725203 -1.39725203]