In [1]:
import numpy as np
from sklearn.neighbors import NearestNeighbors
from dataset import gaussian_blobs
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [2]:
def kmeans_pp_init(X, n_clusters, p=2, seed=None):
    """Choose initial cluster centers with the k-means++ seeding scheme.

    The first center is drawn uniformly from the rows of ``X``. Every further
    center is drawn with probability proportional to the squared Minkowski
    distance between a sample and its closest already-chosen center.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Samples to pick the centers from.
    n_clusters : int
        Number of centers to return.
    p : float, default=2
        Order of the Minkowski distance (1 = Manhattan, 2 = Euclidean).
    seed : None, int or anything accepted by ``np.random.default_rng``
        Seed of the random generator; ``None`` gives a non-reproducible draw.

    Returns
    -------
    numpy.ndarray of shape (n_clusters, n_features)
        The selected centers, in the order they were drawn.

    Raises
    ------
    ValueError
        If ``X`` is not 2-D, if ``n_clusters`` is negative, if centers are
        requested from an empty ``X``, or if the distances are not finite.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
    # Keep the feature count in its own name: reusing `p` here would silently
    # replace the Minkowski order requested by the caller.
    n_samples, n_features = X.shape
    if n_clusters < 0:
        raise ValueError("n_clusters must be non-negative")
    if n_clusters > 0 and n_samples == 0:
        raise ValueError("cannot draw centers from an empty X")

    rng = np.random.default_rng(seed=seed)
    indices = np.arange(n_samples)

    centers_array = np.zeros((n_clusters, n_features))
    uniform_weights = np.ones(n_samples) / n_samples
    weights = uniform_weights
    # Distance from every sample to its closest center chosen so far.
    closest_dist = np.full(n_samples, np.inf)

    for i in range(n_clusters):
        center_ind = rng.choice(indices, p=weights)
        centers_array[i] = X[center_ind]

        if i == n_clusters - 1:
            # No further draw follows, so the weights need no update.
            break

        # Only the newest center can shrink a sample's closest distance.
        dist_to_new = np.linalg.norm(X - centers_array[i], ord=p, axis=1)
        np.minimum(closest_dist, dist_to_new, out=closest_dist)

        squared = closest_dist ** 2
        total = squared.sum()
        if not np.isfinite(total):
            raise ValueError("distances are not finite; check X for NaN or inf values")
        # If every sample coincides with a chosen center the D^2 weights are
        # all zero; fall back to a uniform draw instead of dividing by zero.
        weights = squared / total if total > 0 else uniform_weights

    return centers_array

def random_init(X, n_clusters, p=2, seed=0):
    """Pick ``n_clusters`` distinct rows of ``X`` uniformly at random.

    Parameters
    ----------
    X : numpy.ndarray of shape (n_samples, n_features)
        Samples to pick the centers from.
    n_clusters : int
        Number of rows to select (sampling is without replacement).
    p : int, default=2
        Not used by this function; present in the signature alongside
        ``kmeans_pp_init``.
    seed : int, default=0
        Seed passed to ``np.random.default_rng``.

    Returns
    -------
    numpy.ndarray of shape (n_clusters, n_features)
        The selected rows of ``X``.
    """
    generator = np.random.default_rng(seed=seed)

    # Unpacking two values also rejects anything that is not 2-D.
    n_samples, _n_features = X.shape

    chosen_rows = generator.choice(np.arange(n_samples), n_clusters, replace=False)
    return X[chosen_rows]

In [3]:
# Sample set consumed by the initialisation and plotting cells below.
# NOTE(review): `d` presumably controls the spacing between the blobs — confirm
# against dataset.gaussian_blobs.
samples = gaussian_blobs(d=1.5)

In [4]:
# Draw 5 starting centers with each strategy, each with its own fixed seed.
kmeanspp_centers = kmeans_pp_init(X=samples, n_clusters=5, seed=5)
random_centers = random_init(X=samples, n_clusters=5, seed=10)

In [5]:
# Side-by-side view of the two initialisations over the same samples:
# left panel shows the kmeans++ centers, right panel the random ones.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=["kmeans++", "random initialization"],
)

# Centers are numbered 1..k; the bold HTML variant is drawn next to each marker.
labels = np.arange(1, kmeanspp_centers.shape[0] + 1).astype(str)
text = [f"<b>{label}</b>" for label in labels]

# Background layer: the raw samples, added once to each panel.
trace_X = go.Scatter(
    x=samples[:, 0],
    y=samples[:, 1],
    mode="markers",
    marker=dict(color="#a5b3cf"),
    showlegend=False,
)
fig.add_traces([trace_X, trace_X], rows=1, cols=[1, 2])

# Foreground layer: the labelled centers, kmeans++ first, then random.
for panel_col, panel_centers in enumerate((kmeanspp_centers, random_centers), start=1):
    center_traces = px.scatter(
        x=panel_centers[:, 0],
        y=panel_centers[:, 1],
        color=labels,
        text=text,
    ).data
    fig.add_traces(center_traces, rows=1, cols=panel_col)

fig.update_layout(
    title="Initialization",
    margin={"t":50, "r":5, "l":5, "b":5},
    width=700,
    height=300,
    showlegend=False,
)
# Tie the y axis scale to x so both axes use the same unit length.
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_traces(textposition='top center')
# Last expression of the cell; the empty font dict sets no font property.
fig.update_annotations(font=dict())
