In [1]:
import functools as ft
import math

import numpy as np

import jax
from jax import jit, lax, vmap
import jax.numpy as jnp
import jax.random as jrand
import tensorflow_probability.substrates.jax.distributions as jd

from sklearn.cluster._kmeans import kmeans_plusplus

In [2]:

@ft.partial(jit, static_argnums=[1])
def _row_norms(x, squared=False):
    norms = jnp.einsum('ij, ij->i', x, x)

    if not squared:
        norms = jnp.sqrt(norms)
    return norms

In [3]:
@ft.partial(jit, static_argnums=[2, 3, 4])
def _euclidean_distances(x, y, x_norm_squared=None, y_norm_squared=None, squared=False):
    """Computational part of euclidean_distances

    Assumes inputs are already checked.

    If norms are passed as float32, they are unused. If arrays are passed as
    float32, norms needs to be recomputed on upcast chunks.
    """
    if x_norm_squared is not None:
        xx = x_norm_squared.reshape(-1, 1)
    else:
        xx = _row_norms(x, squared=True)[:, jnp.newaxis]

    if y is x:
        yy = xx.T
    else:
        if y_norm_squared is not None:
            yy = y_norm_squared.reshape(1, -1)
        else:
            yy = _row_norms(y, squared=True)[jnp.newaxis, :]

    distances = -2 * jnp.dot(x, y.T)
    distances += xx
    distances += yy
    distances = jnp.maximum(distances, 0)

    if squared:
        return distances
    else:
        return jnp.sqrt(distances)

In [4]:
@ft.partial(jit, static_argnums=[1, 2])
def _kmeans_plusplus(x, n_clusters, x_squared_norms=None, *, key=None):
    n_samples, n_features = x.shape

    n_clusters = min(n_samples, n_clusters)

    k0, k1 = jrand.split(key)

    def euclidean_distance_square(x, y):
        return jnp.square(x - y).sum(axis=-1)

    center_id = jrand.randint(k0, (1,), minval=0, maxval=n_samples)[0]

    initial_center = x[center_id]

    initial_indice = center_id

    initial_closest_dist_sq = euclidean_distance_square(x[center_id], x)


    current_pot = initial_closest_dist_sq.sum()

    def _step(carry, inp=None):
        (_current_pot, _closest_dist_sq, key) = carry

        k0, k1 = jrand.split(key)

        candidate_ids = jd.Categorical(logits=jnp.log(_closest_dist_sq)).sample(
            seed=k0, sample_shape=(int(math.log(n_clusters) + 2), ))

        # rand_vals = jrand.uniform(key=key, shape=(
        #     int(math.log(n_clusters)) + 2,)) * _current_pot

        # candidate_ids = jnp.searchsorted(
        #     jnp.cumsum(_closest_dist_sq), rand_vals)

        candidate_ids = jnp.clip(
            candidate_ids, a_max=n_samples - 1)

        # Compute distances to center candidates
        distance_to_candidates = vmap(lambda x, y: euclidean_distance_square(x, y),
                                      in_axes=(0, None))(x[candidate_ids], x)
        # distance_to_candidates = vmap(lambda x, y: _euclidean_distances(x, y),
        #                               in_axes=(None, 0))(x[candidate_ids], x)

        # distance_to_candidates = _euclidean_distances(
        #     x[candidate_ids], x, y_norm_squared=x_squared_norms, squared=True
        # )

        distance_to_candidates = vmap(jnp.minimum, in_axes=(0, None))(
            distance_to_candidates, _closest_dist_sq)

        # distance_to_candidates = jnp.minimum(
        #     _closest_dist_sq, distance_to_candidates)
        candidates_pot = distance_to_candidates.sum(axis=-1)

        # Decide which candidate is the best
        best_candidate = jnp.argmin(candidates_pot)
        _current_pot = candidates_pot[best_candidate]
        _closest_dist_sq = distance_to_candidates[best_candidate]
        best_candidate = candidate_ids[best_candidate]

        carry = (_current_pot, _closest_dist_sq, k1)

        return carry, (x[best_candidate], best_candidate)

    init = (current_pot, initial_closest_dist_sq, k1)
    _, (centers, indices) = lax.scan(
        _step, init, xs=None, length=n_clusters - 1)

    centers = jnp.vstack([initial_center, centers])
    indices = jnp.vstack([initial_indice, indices])

    return centers, indices

In [36]:
# Toy data set: three 2-D points at x = 1 and three at x = 10.
X = np.array([[x0, x1] for x0 in (1., 10.) for x1 in (2., 4., 0.)])

In [5]:
# Root PRNG key for the notebook; a fixed seed keeps the runs reproducible.
k = jrand.PRNGKey(seed=0)

In [6]:
# Benchmark data: 1000 samples with 100 features, standard normal scaled by 10.
# The data is drawn from a key derived from `k` rather than from `k` itself,
# because `k` is also passed to `_kmeans_plusplus` below and a JAX PRNG key
# should not be consumed twice.
X = jrand.normal(jrand.fold_in(k, 1), shape=(1000, 100)) * 10.

In [14]:
%timeit -n100 -r10 _kmeans_plusplus(X, 2, key=k)

151 µs ± 4.41 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)


In [15]:
%timeit -n100 -r10 kmeans_plusplus(np.array(X), 2)

5.64 ms ± 389 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)
