In [45]:
import numpy as np 
# Seed NumPy's global RNG so that a fresh "Restart & Run All" regenerates the
# same sample (and the same np.random-based choices made later on).
SEED = 42
np.random.seed(SEED)

# Synthetic 2-D data set: three Gaussian blobs stacked into one (250, 2) array.
#   150 points, std 0.75, centred at ( 1.0,  0.0)
#    50 points, std 0.25, centred at (-0.5,  0.5)
#    50 points, std 0.50, centred at (-0.5, -0.5)
points = np.vstack((
    np.random.randn(150, 2) * 0.75 + np.array([1, 0]),
    np.random.randn(50, 2) * 0.25 + np.array([-0.5, 0.5]),
    np.random.randn(50, 2) * 0.5 + np.array([-0.5, -0.5]),
))

In [46]:
def initialize_centroids(points, k):
    """Pick k rows of ``points`` at random to use as the initial centroids.

    Args:
        points: array of shape (N, d); it is not modified (a copy is shuffled).
        k: number of centroids to draw, with 1 <= k <= N.

    Returns:
        Array of shape (k, d): the first k rows of a shuffled copy of
        ``points``, i.e. k rows taken at k different positions of the input.

    Raises:
        ValueError: if k is outside [1, N]. Without this check a negative k
            would slice ``[:k]`` and return N - |k| rows, and k > N would
            silently return fewer than k centroids.
    """
    n_points = len(points)
    if not 1 <= k <= n_points:
        raise ValueError(
            "k must satisfy 1 <= k <= number of points ({}), got {}".format(n_points, k)
        )
    shuffled = points.copy()
    # Uses the global NumPy RNG; shuffles rows (axis 0) in place on the copy.
    np.random.shuffle(shuffled)
    return shuffled[:k]

In [47]:
# points: N * d
# centroids: k * d
def closest_centroid(points, centroids):
    """Return, for every point, the index of its nearest centroid.

    ``points`` is (N, d) and ``centroids`` is (k, d); the result is a length-N
    integer array whose i-th entry is the row index in ``centroids`` with the
    smallest Euclidean distance to ``points[i]`` (first index wins on ties).
    """
    # Broadcast (1, N, d) against (k, 1, d) to get every point/centroid offset.
    offsets = points[None, ...] - centroids[:, None]
    squared = offsets ** 2
    # dist_table[j, i] is the distance from centroid j to point i.
    dist_table = np.sqrt(squared.sum(axis=2))
    return dist_table.argmin(axis=0)

In [48]:
def move_centroids(points, centroids, closest):
    """Recompute every centroid as the mean of the points assigned to it.

    Args:
        points: array of shape (N, d).
        centroids: array of shape (k, d) holding the current centroids.
        closest: length-N array; ``closest[i]`` is the index of the centroid
            that ``points[i]`` is assigned to.

    Returns:
        Array of shape (k, d) with the updated centroids. A centroid that has
        no assigned points keeps its previous position: taking the mean of an
        empty selection would yield NaN (with a RuntimeWarning) and that NaN
        would then contaminate every later distance computation.
    """
    updated_centroids = []
    for k in range(centroids.shape[0]):
        members = points[closest == k, :]
        if members.shape[0] == 0:
            # Empty cluster: leave the centroid where it was.
            updated_centroids.append(centroids[k, :])
        else:
            updated_centroids.append(members.mean(axis=0))
    return np.vstack(updated_centroids)

In [49]:
def WWSM(points, centroids, closest):
    """Score a clustering by summing each cluster's mean point-to-centroid distance.

    For every centroid the Euclidean distances from its assigned points to the
    centroid are averaged, and these per-cluster averages are added up.

    NOTE(review): this is not the usual within-cluster sum of *squared*
    distances (inertia) -- confirm this is the intended metric.

    Args:
        points: array of shape (N, d).
        centroids: array of shape (k, d).
        closest: length-N array of centroid indices, one per point.

    Returns:
        The summed per-cluster mean distance. Clusters with no assigned points
        contribute nothing (the mean of an empty selection would otherwise
        turn the whole score into NaN).
    """
    total = 0.0
    for k in range(centroids.shape[0]):
        members = points[closest == k, :]
        if members.shape[0] == 0:
            # Empty cluster: nothing to measure.
            continue
        distances = np.sqrt(((members - centroids[k, :]) ** 2).sum(axis=-1))
        total += distances.mean(axis=0)
    return total

In [80]:
def fit(points, k, eps, max_iter=None, verbose=True):
    """Run k-means on ``points`` until the cluster assignments stabilise.

    Each iteration assigns every point to its nearest centroid and then moves
    the centroids. The loop stops once the fraction of points whose assignment
    changed since the previous iteration is at most ``eps``.

    Args:
        points: array of shape (N, d).
        k: number of clusters.
        eps: tolerated fraction of points that may still change cluster when
            stopping; ``eps=0`` stops only when no assignment changes.
        max_iter: optional cap on the number of iterations; ``None`` (default)
            means no cap.
        verbose: when True (default), print the WWSM score after every
            iteration that does not end the loop.

    Returns:
        Tuple ``(centroids, score)``: the final (k, d) centroid array and the
        WWSM score computed from those centroids and the last assignment.
    """
    centroids = initialize_centroids(points, k)
    prev = None  # assignment from the previous iteration, none yet
    n_iter = 0
    while True:
        closest = closest_centroid(points, centroids)
        centroids = move_centroids(points, centroids, closest)
        n_iter += 1

        # Before a previous assignment exists, treat every point as changed.
        changed_fraction = 1.0 if prev is None else np.mean(prev != closest)
        # "<=" (not a strict comparison against 1 - eps) so that eps == 0, or
        # an eps too small to change 1 - eps in floating point, still ends.
        if changed_fraction <= eps:
            break
        if max_iter is not None and n_iter >= max_iter:
            break

        prev = closest
        if verbose:
            print(WWSM(points, centroids, closest))
    return centroids, WWSM(points, centroids, closest)

In [84]:
# Cluster the sample into k=10 groups. With eps=1e-10 and 250 points the loop
# in fit only stops once every point keeps its cluster between two iterations.
# NOTE(review): `points` is generated from three Gaussian blobs, yet k=10 is
# requested here -- confirm that over-segmenting is intended.
fit(points, 10, 1e-10)

4.105544557908495
4.004051902113315
3.8573030988485084
3.7313833607868596
3.64874697852526
3.589107350238372
3.593052073264886
3.593689825457206
3.5920748957768023
3.591375824760367
3.5882545520887064
3.585247262758899
3.5815308096746707


(array([[-0.8073919 , -0.60212976],
        [ 0.59248482,  0.136477  ],
        [ 1.07327096, -1.4094164 ],
        [ 0.82138418,  1.09118977],
        [-0.33945092,  0.15242423],
        [ 1.79716692,  0.61942885],
        [-0.67029524,  0.54505877],
        [ 1.24657722, -0.4243448 ],
        [ 2.19533952, -0.4427906 ],
        [ 0.19350729, -0.66008546]]),
 3.5815308096746707)