In [1]:
'''Implement k-means clustering algorithm.
'''

import numpy as np

def k_means_cluster(data, C, max_num_iteration=10):
    '''
    args:
        data: np.array of shape [N, E]
        c: number of targeted clusters
    returns:
        Cluster ID np.array of shape [N] and type int

    1) init
    2) put all the points into the closest center (by squared Euclidian Distance) and form clusters
    3) update the center of each clusters
    4) Go back to 2) unless the clusters don't change or we hit the number of total iterations.
    '''
    N = data.shape[0]
    E = data.shape[1]
    if C > N:
        return np.arange(N)
    # 1) random init
    centers = data[:C, :]  # C, E

    for step in range(max_num_iteration):
        # 2) form clusters
        square_distances = ((data.reshape(N, 1, E) - centers.reshape(1, C, E)) ** 2).sum(2)  # N, C
        cluster_ids = square_distances.argmin(1, keepdims=1)  # N, 1

        # 3) update centers
        new_centers = centers.copy()
        for i in range(C):
            cluster_mask = cluster_ids == i
            # TODO: handle the case where no points are in the cluster
            new_centers[i] = (data * cluster_mask).sum(0) / cluster_mask.sum()

        # print(f"Iteration {step}")
        # print(f"cluster_ids={cluster_ids}")
        # print(f"centers={centers}")
        # print()
        
        if np.allclose(centers, new_centers):
            break
        else:
            centers = new_centers
    
    print(f"Cluster Result after {step + 1} iterations: {cluster_ids.reshape(-1)}")

    return cluster_ids





# k_means_cluster(np.array([
#     [0, 0], [4, 4], [4, 3], [1, 1], [0, 3]
# ], dtype=float), 3)  # expect [0, 1, 2, 2, 0]

# k_means_cluster(np.array([
#     [0, 0], [0, 3], [4, 4], [4, 3], [1, 1] 
# ], dtype=float), 3)  # expect [0, 1, 2, 2, 0]

# k_means_cluster(np.random.randn(10, 2), 3)

# test_data = [
#     np.random.randn(10, 2),
#     np.random.randn(10, 2) + np.array([[10, 10.0]]),
# ]
# k_means_cluster(np.concatenate(test_data), 2)

# test_data = [
#     np.random.randn(10, 2),
#     np.random.randn(10, 2) + np.array([[5, 10.0]]),
#     np.random.randn(10, 2) + np.array([[0, 20.0]]),
# ]
# k_means_cluster(np.concatenate(test_data), 3)

test_data = [
    np.random.randn(100, 2),
    np.random.randn(100, 2) + np.array([[1, 2.0]]),
    np.random.randn(100, 2) + np.array([[0, 3.0]]),
]
k_means_cluster(np.concatenate(test_data), 80, max_num_iteration=100)

Cluster Result after 26 iterations: [ 0  1 45  3  4  5 16  7  8  9 10 11 22 13 14 77 36 17 18 68 20 21 38 23
 28 22 26 27 10 69 30 56 32 33 34 35 36 37 38 76  1 41 33 43 44 76 46 47
 48 49 50 51 24 53 54 55 56 57 58 59 69 35 62 63  5 65 66 67 68  1 70 71
 72 73 74 75 76 77 78 79 47 76 56 26 19 29 42 54 15 30 64 76 77 57  7 35
 70 28 49 37  2 52 52 16 16 52  2 16 25 28 39 52 24 52 42 52 16 52 28 57
 24 31 76 36 52 28 24 16 42 52 16 16 31 30 52  2 16 30 24 30 36 24 16 24
 40 28 25 12 25 16 16 24 60 24 24 52 27 36 52 12 52 39 12 39 36 36 39 24
 24 27 16 30 24  2 16 39 30 24 12 25 60 60 30 16 12 16 31 42 60 24 24 36
 25 16 52 12 24 36 52  2 56 52 22 40 52 40 61 40 40 69  6  1 40 69 40 60
 12 40 40 69 12 24  6 31 61 40 12  6 60 16 35 16 60 60 25 40 60 40 25 29
 25 16 60 24 52 22  6 29 40 12 60 40 60 40 60 52 40 52 35  6 52 60 61 60
 24 60 12 60 69  6 40 40 40 36 29 16 60 60 12 40 52 29 29 40 29 68 16 60
 60 35 40  6 60 69 25 36 52 60 22 29]


array([[ 0],
       [ 1],
       [45],
       [ 3],
       [ 4],
       [ 5],
       [16],
       [ 7],
       [ 8],
       [ 9],
       [10],
       [11],
       [22],
       [13],
       [14],
       [77],
       [36],
       [17],
       [18],
       [68],
       [20],
       [21],
       [38],
       [23],
       [28],
       [22],
       [26],
       [27],
       [10],
       [69],
       [30],
       [56],
       [32],
       [33],
       [34],
       [35],
       [36],
       [37],
       [38],
       [76],
       [ 1],
       [41],
       [33],
       [43],
       [44],
       [76],
       [46],
       [47],
       [48],
       [49],
       [50],
       [51],
       [24],
       [53],
       [54],
       [55],
       [56],
       [57],
       [58],
       [59],
       [69],
       [35],
       [62],
       [63],
       [ 5],
       [65],
       [66],
       [67],
       [68],
       [ 1],
       [70],
       [71],
       [72],
       [73],
       [74],
       [75],
       [76],