# K-Means étape par étape

In [1]:
from IPython.display import clear_output

In [2]:
import time

from scipy.spatial.distance import euclidean
import numpy as np

from sklearn import datasets
from sklearn.cluster import KMeans

from ipywidgets import IntSlider, HBox, VBox, Button, Output, Play, jslink, Layout
from ipywidgets import Label as widgets_label

from bqplot import (
    LogScale, LinearScale, OrdinalColorScale, ColorAxis,
    Axis, Scatter, Lines, CATEGORY10, Label, Figure, Tooltip
)

In [3]:
# Output widget that collects the status messages ("Members affected",
# "centroids updated", "inertia: ...") printed by the step callbacks below.
output = Output()

In [4]:
# Number of clusters the step-by-step K-means works with (independent of
# initial_k, the number of blobs the data was generated from).
current_k = 5

In [5]:
# Number of blob centres used to generate the synthetic dataset.
initial_k = 5

In [6]:
# Number of points to generate.
n_samples = 2_000

In [7]:
# Synthetic dataset: n_samples points drawn around initial_k centres, with a
# fixed random_state so every run produces the same points.
blobs = datasets.make_blobs(
    n_samples=n_samples,
    centers=initial_k,
    random_state=20,
)

In [8]:
# Point coordinates only: make_blobs returns (X, labels). The generated labels
# in blobs[1] are used below only as a template for np.empty_like.
data_tab = blobs[0]

In [9]:
# Axis ranges: the data extent padded by 2 units on each side.
x_sc = LinearScale(min=data_tab[:, 0].min() - 2, max=data_tab[:, 0].max() + 2)
y_sc = LinearScale(min=data_tab[:, 1].min() - 2, max=data_tab[:, 1].max() + 2)
# Colour scale for cluster labels 0..9. Use all 10 CATEGORY10 colours so that
# every label in the 10-value domain gets its own colour (a shorter palette
# would presumably be recycled, making distinct clusters share a colour).
c_sc = OrdinalColorScale(domain=list(range(10)), colors=CATEGORY10[:10])

In [10]:
# Horizontal axis, then the vertical axis on the left; both draw solid grid lines.
ax_x = Axis(label='x', scale=x_sc, grid_lines='solid')
ax_y = Axis(
    label='y',
    scale=y_sc,
    orientation='vertical',
    side='left',
    grid_lines='solid',
)

In [11]:
# Data points, coloured by cluster label through c_sc. The one-element colour
# array is the initial value before the first assignment step.
scat = Scatter(
    x=data_tab[:, 0],
    y=data_tab[:, 1],
    scales={'x': x_sc, 'y': y_sc, 'color': c_sc},
    color=[1],
)

In [12]:
# Centroid markers, drawn in black on the same x/y scales. The (0, 0) point
# is a placeholder that is overwritten once centroids are initialised.
scat_centroids = Scatter(
    x=[0],
    y=[0],
    scales={"x": x_sc, "y": y_sc},
    colors=["black"],
)

In [13]:
# Figure combining the data points and the centroids; the short
# animation_duration (milliseconds) keeps step transitions snappy.
fig = Figure(
    marks=[scat, scat_centroids],
    title='K-means',
    animation_duration=10,
    axes=[ax_x, ax_y],
)

In [14]:
# Module-level placeholder. on_next_step_clicked computes and returns its own
# local inertia and does not update this name.
inertia = 0.0

In [15]:
# Control buttons, built in the order Start, Reset, Next step. Start is not
# wired to any handler (its on_click line is commented out further down).
start_button, reset_button, next_button = (
    Button(description=label) for label in ("Start", "Reset", "Next step")
)

In [16]:
# Runs the K-means steps automatically until the inertia stops changing
# (handled by on_play_clicked).
play_button = Button(description="Play animation")

In [17]:
def affect_members(member, centroids):
    """Assign every point of ``data_tab`` to its nearest centroid (K-means assignment step).

    Only the first ``current_k`` rows of ``centroids`` are considered. The
    labels are written into ``member`` in place, pushed to the scatter plot's
    colour channel, and a status line is printed in the ``output`` widget.

    Parameters
    ----------
    member : numpy.ndarray
        Integer label array with one entry per row of ``data_tab``; overwritten.
    centroids : numpy.ndarray
        Centroid coordinates, indexed as ``centroids[j]``.

    Returns
    -------
    numpy.ndarray
        ``member`` itself, now holding the index of the closest centroid.
    """
    member_old = member.copy()

    # Squared distances from every point to every active centroid, with shape
    # (n_points, current_k). The argmin of squared distance equals the argmin of
    # distance, so no sqrt is needed. This replaces a per-point Python loop
    # that called ``euclidean`` twice per improving centroid.
    active = np.asarray(centroids)[:current_k]
    diff = data_tab[:, np.newaxis, :] - active[np.newaxis, :, :]
    sq_dist = (diff ** 2).sum(axis=2)
    # NOTE: on an exact distance tie argmin picks the lowest index, while the
    # old loop preferred centroid ``current_k - 1``. Exact ties are negligible
    # for float data.
    member[:] = np.argmin(sq_dist, axis=1)

    # NOTE(review): assigning the previous labels first presumably guarantees
    # a trait change notification even though ``member`` was mutated in place.
    scat.color = member_old

    scat.color = member

    with output:
        clear_output()
        print("Members affected")

    return member

In [18]:
def update_centroids(member, centroids):
    """Move each centroid to the mean of its assigned points (K-means update step).

    The first ``current_k`` rows of ``centroids`` are updated in place, the
    centroid markers are redrawn, and a status line is printed in ``output``.

    A cluster that received no points keeps its previous position. The mean of
    an empty selection is NaN, and a NaN centroid would make every later
    distance and inertia NaN. A NaN inertia never satisfies a ``<`` tolerance
    test, so a convergence loop could then never stop.

    Parameters
    ----------
    member : numpy.ndarray
        Cluster label of each row of ``data_tab``.
    centroids : numpy.ndarray
        Centroid coordinates; rows ``0 .. current_k - 1`` are overwritten.

    Returns
    -------
    numpy.ndarray
        ``centroids`` itself, holding the updated positions.
    """
    centroids_old = centroids.copy()
    labels = np.asarray(member)

    # Compute the new centroids with a boolean mask per cluster instead of a
    # Python-level scan over every point.
    for j in range(current_k):
        in_cluster = labels == j
        if in_cluster.any():
            centroids[j] = data_tab[in_cluster].mean(axis=0)
        # else: empty cluster, so centroids[j] keeps its previous value.

    # NOTE(review): assigning the old positions first presumably forces a
    # change notification, because ``centroids`` is mutated in place.
    scat_centroids.x = centroids_old[:, 0]
    scat_centroids.y = centroids_old[:, 1]

    scat_centroids.x = centroids[:, 0]
    scat_centroids.y = centroids[:, 1]

    with output:
        clear_output()
        print("centroids updated")

    return centroids

In [19]:
# Initial state: current_k distinct data points picked at random as centroids.
centroids_i = np.random.choice(n_samples, current_k, replace=False)
centroids = blobs[0][centroids_i]
cmpt = 0  # reset to 0 here and in init_or_reset; nothing visible increments it

# One cluster label per point (uninitialised until the first assignment step).
member = np.empty_like(blobs[1])

# True: the next step assigns points. False: the next step moves the centroids.
substep1 = True

# Draw the initial centroids (x is assigned before y, as before).
scat_centroids.x, scat_centroids.y = centroids[:, 0], centroids[:, 1]

def init_or_reset():
    """Restart the algorithm from a fresh random choice of centroids.

    Allocates a new (uninitialised) label array, picks ``current_k`` distinct
    data points as centroids, resets ``cmpt`` and ``substep1``, redraws the
    centroids, resets the point colours and logs a message to ``output``.
    """
    global centroids, cmpt, member, substep1

    member = np.empty_like(blobs[1])
    picked = np.random.choice(n_samples, current_k, replace=False)
    centroids = blobs[0][picked]
    cmpt, substep1 = 0, True

    scat_centroids.x = centroids[:, 0]
    scat_centroids.y = centroids[:, 1]
    scat.color = [0]

    with output:
        clear_output()
        print("centroids initialized")
        
# def on_start_clicked(b):
#     display(button, reset_button, output)
#     init_or_reset()
    
def on_next_step_clicked(b):
    """Advance K-means by one half-step and report the resulting inertia.

    Half-steps alternate between assigning points (``affect_members``) and
    recomputing centroids (``update_centroids``), driven by the ``substep1``
    flag.

    Parameters
    ----------
    b : object
        The widget that triggered the call (unused).

    Returns
    -------
    float
        Sum of squared distances between each point and its assigned centroid.
    """
    global member
    global centroids
    global substep1

    if substep1:
        member = affect_members(member, centroids)
        substep1 = False
    else:
        centroids = update_centroids(member, centroids)
        substep1 = True

    # Inertia: sum of squared point-to-centroid distances, vectorized instead
    # of one Python-level ``euclidean`` call per point. This runs every second
    # in play mode.
    residuals = data_tab - centroids[member]
    inertia = float(np.sum(residuals ** 2))

    with output:
        print(f"inertia: {inertia}")

    return inertia
            
def on_reset_clicked(b):
    """Button callback: restart K-means from new random centroids (``b`` is unused)."""
    return init_or_reset()

In [20]:
def on_play_clicked(b, max_steps=1000, tol=1e-4):
    """Run K-means half-steps once per second until the inertia stabilises.

    Parameters
    ----------
    b : object
        Widget that triggered the call; forwarded to ``on_next_step_clicked``.
    max_steps : int, optional
        Upper bound on the number of half-steps, so the loop always ends.
    tol : float, optional
        Stop once the inertia changes by less than this between two
        consecutive half-steps. The default equals the original ``10e-5``,
        which is 1e-4.

    NOTE(review): this loop sleeps inside the widget callback, so other widget
    events (e.g. Reset) are presumably not processed until it returns.
    """
    # Starting from +inf means the first comparison can never count as converged.
    old_inertia = float("inf")
    for _ in range(max_steps):
        time.sleep(1)
        new_inertia = on_next_step_clicked(b)
        if not np.isfinite(new_inertia):
            # A NaN inertia never satisfies the tolerance test below, so
            # without this guard the loop would spin forever.
            with output:
                print("animation stopped: inertia is not finite")
            return
        if abs(new_inertia - old_inertia) < tol:
            with output:
                print("animation finished")
            return
        old_inertia = new_inertia
    with output:
        print("animation stopped: step limit reached")

In [21]:
# Connect each button to its callback. The Start button stays unconnected:
# start_button.on_click(on_start_clicked)
for _button, _callback in (
    (next_button, on_next_step_clicked),
    (reset_button, on_reset_clicked),
    (play_button, on_play_clicked),
):
    _button.on_click(_callback)
del _button, _callback

In [22]:
# Assemble the UI: controls on top, then the plot, then the log output.
# This must remain the cell's last expression so Jupyter displays it.
VBox(
    [
        HBox([play_button]),
        HBox([next_button, reset_button]),
        fig,
        output,
    ],
    layout=Layout(height="800px"),
)

VBox(children=(HBox(children=(Button(description='Play animation', style=ButtonStyle()),)), HBox(children=(But…