In [5]:
from sklearn.metrics import euclidean_distances
import numpy as np

In [2]:
a = np.array([[1, 0], [0,1]], dtype=np.float32)


In [108]:
X = np.ones([2, 100])
Y = np.zeros([1, 100])

In [8]:
euclidean_distances(X, Y)

array([[3.74165739]])

In [14]:
14 / np.sqrt(14)

3.7416573867739413

In [15]:
import jax
import jax.numpy as jnp

import equinox as eqx
from equinox.jit import filter_jit

import numpy as np


def _row_norms(X, squared=False):
    norms = jnp.einsum('ij, ij->i', X, X)

    if not squared:
        norms = jnp.sqrt(norms)

    return norms

In [17]:
_row_norms(X, squared=True)

DeviceArray([14], dtype=int32)

In [18]:
from sklearn.utils.extmath import row_norms

In [22]:
row_norms(X, squared=True)

array([14.])

In [24]:
jnp.maximum(X, 0, out=X)

TypeError: _one_to_one_binop.<locals>.<lambda>() got an unexpected keyword argument 'out'

In [62]:
from functools import partial

In [77]:
@partial(jax.jit, static_argnums=[1])
def _row_norms(x, squared=False):
    norms = jnp.einsum('ij, ij->i', x, x)

    if not squared:
        norms = jnp.sqrt(norms)
    return norms


@partial(jax.jit, static_argnums=[2, 3, 4])
def _euclidean_distances(x, y, x_norm_squared=None, y_norm_squared=None, squared=False):
    """Computational part of euclidean_distances

    Assumes inputs are already checked.

    If norms are passed as float32, they are unused. If arrays are passed as
    float32, norms needs to be recomputed on upcast chunks.
    """
    if x_norm_squared is not None:
        xx = x_norm_squared.reshape(-1, 1)
    else:
        xx = _row_norms(x, squared=True)[:, jnp.newaxis]

    if y is x:
        yy = xx.T
    else:
        if y_norm_squared is not None:
            yy = y_norm_squared.reshape(1, -1)
        else:
            yy = _row_norms(y, squared=True)[jnp.newaxis, :]

    distances = -2 * jnp.dot(x, y.T)
    distances += xx
    distances += yy
    distances = jnp.maximum(distances, 0)

    if squared:
        return distances
    else:
        return jnp.sqrt(distances)

In [64]:
_euclidean_distances(X, Y)

DeviceArray([[10.]], dtype=float32)

In [87]:
import jax.random as jrand
from jax import lax

In [276]:
@partial(jax.jit, static_argnums=[1, 2])
def _kmeans_plusplus(x, n_clusters, x_squared_norms=None, *, key=None):

    x = jnp.asarray(x)

    n_samples, n_features = x.shape

    centers = jnp.zeros((n_clusters, n_features))

    center_id = jrand.randint(key, (1, ), minval=0, maxval=n_samples)[0]
    indices = jnp.full((n_clusters, ), -1, dtype=jnp.int32)

    centers = centers.at[0].set(x[center_id])
    indices = indices.at[0].set(center_id)

    closest_dist_sq = _euclidean_distances(centers[0, jnp.newaxis],
                                           x,
                                           y_norm_squared=x_squared_norms,
                                           squared=True)

    current_pot = closest_dist_sq.sum()

    def _step(carry, inp=None):
        (i, _centers, _indices, _current_pot, _closest_dist_sq) = carry

        rand_vals = jrand.uniform(key=key, shape=(n_clusters + 2, )) * _current_pot

        candidate_ids = jnp.searchsorted(jnp.cumsum(_closest_dist_sq), rand_vals)

        candidate_ids = jnp.clip(candidate_ids, None, _closest_dist_sq.size - 1)

        # Compute distances to center candidates
        distance_to_candidates = _euclidean_distances(
            x[candidate_ids], x, y_norm_squared=x_squared_norms, squared=True
        )

        distance_to_candidates = jnp.minimum(_closest_dist_sq, distance_to_candidates)
        candidates_pot = distance_to_candidates.sum(axis=1)

        # Decide which candidate is the best
        best_candidate = jnp.argmin(candidates_pot)
        _current_pot = candidates_pot[best_candidate]
        _closest_dist_sq = distance_to_candidates[best_candidate]
        best_candidate = candidate_ids[best_candidate]

        _centers = _centers.at[i].set(x[best_candidate])
        _indices = _indices.at[i].set(best_candidate)

        carry = (i + 1, _centers, _indices, _current_pot, _closest_dist_sq)

        return carry, (_centers, _indices)

    init = (1, centers, indices, current_pot, closest_dist_sq[0])
    _, (centers, indices) =  lax.scan(_step, init, xs=None, length=n_clusters - 1)

    return centers, indices

In [277]:
X = np.array([[1., 2.], [1., 4.], [1., 0.],
               [10., 2.], [10., 4.], [10., 0.]])

In [238]:
key = jrand.PRNGKey(0)

In [281]:
_kmeans_plusplus(X, 2, key=jrand.PRNGKey(1024))

(DeviceArray([[[10.,  4.],
               [ 1.,  2.]]], dtype=float32),
 DeviceArray([[4, 0]], dtype=int32))

In [79]:
%timeit -n100 -r3 _euclidean_distances(X, Y)

6.55 µs ± 2.97 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [75]:
%timeit -n100 -r3 euclidean_distances(X, Y)

29 µs ± 1.32 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [83]:
z = jnp.array(X)

In [84]:
z.size

100

In [137]:
jrand.uniform(key=key, shape=(2, ))

DeviceArray([0.21629536, 0.8041241 ], dtype=float32)