In [13]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import Button, Checkbox, IntSlider, IntText, Play, interactive, jslink
from sklearn.datasets import make_blobs

In [5]:
# Synthetic clustering data: four blobs with cluster_std=2, seeded so every run matches.
X, y = make_blobs(centers=4, cluster_std=2, random_state=42)


def _show(color: bool) -> None:
    """Scatter the raw data; when ``color`` is true, colour each point by its label ``y``."""
    _, axes = plt.subplots()
    colour_kwargs = {"c": y} if color else {}
    axes.scatter(X[:, 0], X[:, 1], marker=".", **colour_kwargs)
    axes.set(title=f"Raw data {len(X):,}")


# Checkbox toggles between plain points and points coloured by their label y.
interactive(
    _show,
    color=Checkbox(value=False, description="Show Cluster?"),
)

interactive(children=(Checkbox(value=False, description='Show Cluster?'), Output()), _dom_classes=('widget-int…

In [14]:
def kmeans_step(
    data: np.ndarray,
    current_centroids: np.ndarray,
) -> tuple[np.ndarray, float, np.ndarray]:
    """Run one Lloyd iteration of k-means.

    Parameters
    ----------
    data : (n_points, n_dims) array of observations.
    current_centroids : (k, n_dims) array of centroid positions.

    Returns
    -------
    new_centroids : (k, n_dims) array. Each row is the mean of the points assigned to
        that centroid, or its previous position if no point was assigned (empty cluster).
    movement : Frobenius norm of ``new_centroids - current_centroids``, i.e. the square
        root of the summed squared distance every centroid moved.
    clusters : (n_points,) int array giving, for each point, the index of its nearest
        *current* centroid.
    """
    # (k, n_points) squared Euclidean distances from every centroid to every point.
    sq_d = ((data[np.newaxis, :, :] - current_centroids[:, np.newaxis, :]) ** 2).sum(axis=-1)

    # Assign each point to its nearest centroid (ties go to the lower index).
    clusters = sq_d.argmin(axis=0)

    # Move each centroid to the mean of its members. An empty cluster keeps its old
    # position, because the mean of zero points would be NaN.
    updated = []
    for label, old_centroid in enumerate(current_centroids):
        members = data[clusters == label]
        updated.append(members.mean(axis=0) if len(members) else old_centroid)
    new_centroids = np.array(updated)

    movement = np.linalg.norm(new_centroids - current_centroids)

    return new_centroids, movement, clusters


def kmeans_train(
    k: int,
    data: np.ndarray,
    n_iter: int = 3,
    random_state: int = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """Run ``n_iter`` k-means iterations and record every intermediate state.

    Each feature is min-max scaled to [0, 1] before clustering so that no single
    dimension dominates the Euclidean distance. The recorded centroids are mapped back
    to the original units before returning.

    Parameters
    ----------
    k : number of clusters. The initial centroids are ``k`` distinct data points drawn
        without replacement, so ``k`` may not exceed the number of points (otherwise
        ``Generator.choice`` raises ``ValueError``).
    data : (n_points, n_dims) array of observations.
    n_iter : number of iterations to run and record.
    random_state : seed for the centroid initialisation.

    Returns
    -------
    res_clusters : (n_iter, n_points) integer array. Row ``i`` holds each point's label
        under the centroids stored in row ``i`` of ``cents_orig``.
    cents_orig : (n_iter, k, n_dims) array of centroid positions in original data
        units, recorded *before* the update performed in iteration ``i``.
    """
    rng = np.random.default_rng(random_state)

    data_min = data.min(axis=0)
    data_range = data.max(axis=0) - data_min
    # A constant feature has zero range. Scale it by 1 so it maps to 0 instead of
    # 0/0 = NaN, which would poison every distance and every centroid.
    data_range = np.where(data_range == 0, 1, data_range)
    x = (data - data_min) / data_range

    n_points, n_dims = x.shape
    res_centers = np.zeros(shape=(n_iter, k, n_dims))
    # Labels are centroid indices, so store them as integers rather than float64.
    res_clusters = np.zeros(shape=(n_iter, n_points), dtype=np.intp)

    centers = rng.choice(x, k, replace=False, axis=0)
    for i in range(n_iter):
        new_centers, _movement, clusters = kmeans_step(x, centers)
        # Store the centroids together with the assignment they produced, then advance.
        res_centers[i] = centers
        res_clusters[i] = clusters
        centers = new_centers
    cents_orig = res_centers * data_range + data_min
    return res_clusters, cents_orig


def kmeans_predict(
    data: np.ndarray,
    centroids: np.ndarray,
) -> np.ndarray:
    """Assign each point to its nearest centroid by squared Euclidean distance.

    Parameters
    ----------
    data : (n_points, n_dims) array of points.
    centroids : (k, n_dims) array of centroid positions.

    Returns
    -------
    (n_points,) integer array of centroid indices. Ties go to the lowest index.

    Raises
    ------
    ValueError
        If either input is not 2-D, or if their numbers of dimensions differ.
    """
    data = np.asarray(data)
    centroids = np.asarray(centroids)
    # Tuple unpacking also rejects non-2-D inputs with ValueError.
    _, n_dims = data.shape
    _, n_dims_centroids = centroids.shape
    if n_dims != n_dims_centroids:
        # An explicit raise, unlike `assert`, still fires under `python -O`.
        raise ValueError(
            f"data has {n_dims} dimensions but centroids have {n_dims_centroids}"
        )

    # (k, n_points) squared distances, built one centroid at a time. Broadcasting all
    # centroids at once would materialise a (k, n_points, n_dims) temporary, which is
    # very large when this is evaluated on a dense plotting grid.
    sq_d = np.stack([((data - c) ** 2).sum(axis=1) for c in centroids])

    # Nearest centroid per point.
    return sq_d.argmin(axis=0)


# Record T iterations of k-means with k clusters on the blob data.
T, k = 20, 4
clusts, cents = kmeans_train(k=k, data=X, n_iter=T)

In [27]:
# Background grid on which the decision regions are drawn. The grid is sized by a fixed
# number of cells along the longer axis rather than a fixed step in data units: a 0.01
# step gives (span / 0.01)**2 points, which is millions for data spanning tens of units,
# and every redraw then classifies all of them.
GRID_CELLS = 300  # plenty of resolution for imshow with interpolation="nearest"
x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
h = max(x_max - x_min, y_max - y_min) / GRID_CELLS
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# One colour per cluster label; label i is drawn with colors[i].
colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"]
cmap = mcolors.ListedColormap(colors)
assert k <= len(colors)


def simulate(t: int, boundary: bool) -> None:
    """Plot the recorded k-means state at iteration ``t``.

    Points are coloured by the cluster they were assigned to at step ``t``. The black
    markers are the centroids that produced that assignment. When ``boundary`` is true,
    the nearest-centroid region of every cell of the ``xx``/``yy`` grid is shaded
    underneath.
    """
    cls = clusts[t, ...]
    cs = cents[t, ...]

    # Pin the colour scale so that label i always maps to colors[i]. Without this,
    # scatter and imshow each normalise to their own min/max label, so the two layers
    # (and successive frames) disagree whenever a label is absent from one of them,
    # e.g. when a cluster is empty.
    label_scale = {"vmin": -0.5, "vmax": cmap.N - 0.5}

    _, ax = plt.subplots()

    if boundary:
        z = kmeans_predict(np.c_[xx.ravel(), yy.ravel()], cs).reshape(xx.shape)
        ax.imshow(
            z,
            origin="lower",
            extent=(xx.min(), xx.max(), yy.min(), yy.max()),
            cmap=cmap,
            alpha=0.3,
            interpolation="nearest",
            **label_scale,
        )
        ax.set_aspect("auto")

    ax.scatter(X[:, 0], X[:, 1], c=cls, marker=".", cmap=cmap, **label_scale)
    ax.scatter(cs[:, 0], cs[:, 1], color="black", marker="o")

    ax.set(title=f"Iteration {t}")

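If the initial centroids are badly placed, k-means still converges, but it can converge to a poor local optimum. For example, two centroids may end up splitting one blob while a single centroid covers two others.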

In [29]:
# A Play widget steps through the recorded iterations. The slider is linked to it
# client-side (jslink) so the reader can also scrub through the frames manually.
play, slider = (
    Play(value=0, min=0, max=T - 1, step=1, interval=500),
    IntSlider(value=0, min=0, max=T - 1, step=1),
)
jslink((play, "value"), (slider, "value"))
w = interactive(
    simulate,
    t=play,
    boundary=Checkbox(value=True, description="Show decision boundaries?"),
)
display(slider, w)

IntSlider(value=0, max=19)

interactive(children=(Play(value=0, description='t', interval=500, max=19), Checkbox(value=True, description='…