In [None]:
from time import sleep

import numpy as np
import pandas as pd

from ipywidgets import *

from bqplot import OrdinalColorScale, CATEGORY10
import bqplot.pyplot as plt

In [None]:
# --- input controls ---------------------------------------------------------
# Number of random points to generate and number of clusters to fit.
n_slider = IntSlider(description='points', min=20, max=200, step=10, value=100)
k_slider = IntSlider(description='K', min=2, max=10, value=3)

# Button that kicks off a clustering run.
go_btn = Button(
    description='GO',
    button_style='success',
    layout=Layout(width='50px'),
)

# Label showing the current iteration number; starts out with no number.
iter_label_tmpl = 'Iterations: {}'
iter_label = Label(value=iter_label_tmpl.format(''))

# --- figure and marks -------------------------------------------------------
fig = plt.figure(title='K-means Clustering', animation_duration=1000)
fig.layout.height = '600px'
fig.layout.width = '700px'

# Register an ordinal colour scale with pyplot before the marks are created.
plt.scales(scales={'color': OrdinalColorScale(colors=CATEGORY10)})

# Axis labels for both scatters; the colour axis is hidden.
axes_options = dict(
    x=dict(label='X1'),
    y=dict(label='X2'),
    color=dict(visible=False),
)

# Data points (filled, black outline), initially empty.
points_scat = plt.scatter(
    [], [],
    color=[],
    stroke='black',
    axes_options=axes_options,
)

# Cluster centroids (large unfilled markers), initially empty.
centroid_scat = plt.scatter(
    [], [],
    color=[],
    stroke_width=4,
    fill=False,
    default_size=300,
    axes_options=axes_options,
)

def clear():
    """Reset both scatter marks and the iteration label to their empty state.

    ``x``, ``y`` and ``color`` are all emptied inside a single ``hold_sync``
    block per mark, so a mark is never left with coordinate arrays of
    different lengths (stale ``y`` values from the previous run).
    """
    for scat in (points_scat, centroid_scat):
        with scat.hold_sync():
            scat.x = []
            scat.y = []
            scat.color = []

    iter_label.value = iter_label_tmpl.format('')
    
def _initial_clusters(n, K):
    """Return a random starting label in ``range(K)`` for each of ``n`` points.

    Labels are dealt round-robin and then shuffled, so every one of the K
    clusters receives at least one point whenever ``n >= K``.
    """
    return np.random.permutation(np.arange(n) % K)


def _nearest_centroid_labels(X, centroids):
    """Return, for each row of ``X``, the label of its closest centroid.

    ``centroids`` is a DataFrame with the same columns as ``X`` whose index
    holds the cluster labels. The result is a NumPy array of labels taken
    from that index, one per row of ``X``.
    """
    points = X.values
    centers = centroids.values
    # (n, 1, d) - (1, k, d) -> (n, k) matrix of Euclidean distances.
    distances = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
    # argmin yields a *position* within ``centroids``; translate it back to
    # the cluster label so labels stay correct if a cluster has disappeared
    # and the index is no longer 0..k-1.
    return centroids.index.values[distances.argmin(axis=1)]


def start_animation():
    """Run one animated K-means clustering on freshly generated random points.

    Reads the number of points and K from the sliders, draws the points,
    then alternates between computing centroids and reassigning points to
    the nearest centroid until the assignment stops changing. The function
    sleeps 1.5 seconds between iterations. The GO button is disabled for
    the duration of the run and is always re-enabled afterwards, even if
    an error occurs.
    """
    go_btn.disabled = True
    try:
        clear()

        # get the values of parameters from sliders
        n = n_slider.value
        K = k_slider.value

        # features (randomly generated)
        X = pd.DataFrame(np.random.rand(n, 2), columns=['X1', 'X2'])

        # random initial assignment in which every cluster is non-empty
        clusters = _initial_clusters(n, K)

        with points_scat.hold_sync():
            points_scat.x = X['X1']
            points_scat.y = X['X2']
            points_scat.color = clusters

        i = 0
        while True:
            iter_label.value = iter_label_tmpl.format(i + 1)

            # color code the points by their clusters
            points_scat.color = clusters

            # compute cluster centroids
            centroids = X.groupby(clusters).mean()
            with centroid_scat.hold_sync():
                centroid_scat.x = centroids['X1']
                centroid_scat.y = centroids['X2']
                centroid_scat.color = list(centroids.index)

            # reassign clusters to points based on the closest centroid
            # NOTE: a cluster can still lose all of its points during
            # iteration, leaving fewer than K clusters; its label is simply
            # dropped and the remaining labels are kept unchanged.
            new_clusters = _nearest_centroid_labels(X, centroids)

            # if the new clusters are same as old clusters stop
            if np.array_equal(new_clusters, clusters):
                break

            # otherwise adopt the new assignment and animate the next step
            clusters = new_clusters
            i += 1
            sleep(1.5)
    finally:
        # never leave the button stuck in the disabled state
        go_btn.disabled = False

# Start a clustering run whenever GO is clicked; the button argument passed
# to the callback is not needed.
go_btn.on_click(lambda _btn: start_animation())

# Sliders, button and iteration label stacked vertically, pushed down by a
# 60px top margin.
controls_layout = VBox(children=[n_slider, k_slider, go_btn, iter_label])
controls_layout.layout.margin = '60px 0px 0px 0px'

# Figure on the left, controls on the right. This is the cell's final
# expression, so the notebook displays it.
HBox(children=[VBox(children=[fig]), controls_layout])