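## K-Means
Finds K clusters in unlabeled data by alternating two steps: assign each point to its nearest centroid, then move each centroid to the mean of the points assigned to it.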

In [None]:
# data = [(x_0, y_0), ..., (x_n, y_n)]   -- the 2-D input points
# labels = [c_0, ..., c_n]              -- c_i is the cluster index assigned to point i

In [4]:
import random
def initialize_centroids(data, k):
    """Draw k random starting centroids inside the bounding box of ``data``.

    Each centroid is a two-element list [x, y]. For each centroid,
    random.random() is drawn for x first and then for y, scaled into
    [min, max) of that coordinate over all points. With empty ``data`` the
    bounds stay at +/-inf, so the produced coordinates are not finite.
    """
    # Running bounds of the point cloud, folded one point at a time.
    lo_x = lo_y = float('inf')
    hi_x = hi_y = float('-inf')
    for point in data:
        px, py = point[0], point[1]
        lo_x, hi_x = min(px, lo_x), max(px, hi_x)
        lo_y, hi_y = min(py, lo_y), max(py, hi_y)

    width = hi_x - lo_x
    height = hi_y - lo_y
    return [[lo_x + width * random.random(), lo_y + height * random.random()]
            for _ in range(k)]

In [7]:
def get_dist(point0, point1):
    """Return the Euclidean distance between two 2-D points.

    Each point is indexed as point[0] (x) and point[1] (y).
    """
    # Differences must be taken *between* the two points; subtracting a
    # coordinate from itself always yields 0 and collapses every distance.
    dx = point0[0] - point1[0]
    dy = point0[1] - point1[1]
    return (dx ** 2 + dy ** 2) ** 0.5

In [9]:
# c_i:=argmin||x_i-u_j||
def get_labels(data, centroids):
    labels = []
    for point in data:
        min_dist = float('inf')
        label = None
        for i, centroid in enumerate(centroids):
            new_dist = get_dist(point, centroid)
            if min_dist > new_dist:
                min_dist = new_dist
                label = i
        labels.append(label)
    return labels

In [11]:
def update_centroids(points, labels, k, old_centroids=None):
    """Recompute each centroid as the mean of the points assigned to it.

    This is the K-means update step u_j := mean{x_i : c_i = j}.

    Args:
        points: sequence of 2-D points indexed as point[0], point[1].
        labels: cluster index in range(k) for each point, in the same order
            as ``points``.
        k: number of clusters.
        old_centroids: optional previous centroids (length k). When given,
            a cluster that received no points keeps its previous position
            instead of triggering a division by zero.

    Returns:
        A list of k (x, y) tuples.

    Raises:
        ZeroDivisionError: a cluster received no points and
            ``old_centroids`` was not supplied.
    """
    sums = [[0, 0] for _ in range(k)]
    counts = [0] * k

    # Accumulate per-cluster coordinate sums and member counts.
    for point, label in zip(points, labels):
        sums[label][0] += point[0]
        sums[label][1] += point[1]
        counts[label] += 1

    new_centroids = []
    for i, ((sx, sy), n) in enumerate(zip(sums, counts)):
        if n:
            new_centroids.append((sx / n, sy / n))
        elif old_centroids is not None:
            # Empty cluster: a mean is undefined, so leave it where it was.
            new_centroids.append((old_centroids[i][0], old_centroids[i][1]))
        else:
            raise ZeroDivisionError(
                f"cluster {i} has no assigned points; pass old_centroids "
                "to keep empty clusters at their previous position"
            )
    return new_centroids
        

In [12]:
def should_stop(old_centroids, centroids, threshold=1e-5):
    """Report whether the centroids have effectively stopped moving.

    Old and new centroids are paired by position (zip stops at the shorter
    list). The get_dist of every pair is added up, and the result is True
    when that total movement is strictly below ``threshold``.
    """
    moved = 0
    pairs = zip(old_centroids, centroids)
    for before, after in pairs:
        moved = moved + get_dist(before, after)
    return moved < threshold

In [13]:
# init cent
# c_i,miu_j
def main(data, k):
    centroids = initialize_centroids(data, k)
    while True:
        old_centroids = centroids
        labels = get_labels(data, centroids)
        centroids = update_centroids(data, labels, k)

        if should_stop(old_centroids, centroids):
            break

    return labels