Just-in-time compiled and accelerated ⚡ implementation of Mini-Batch KMeans [1].
- JAX 😎 >= 0.3.17
git clone https://github.com/GiulioZani/jax-min-batch-kmeans
cd jax-min-batch-kmeans
import numpy as np
from mini_batch_kmeans import MiniBatchKMeans


def main():
    # a 2D array of shape (number of samples, number of features);
    # random placeholder data, replace with your own
    xs = np.random.default_rng(0).normal(size=(10_000, 2))
    mini_batch_kmeans = MiniBatchKMeans(
        xs,  # can be a numpy or jax array
        k=4,  # number of clusters
        batch_size=1000,  # batch size
        iter=1000,  # number of iterations
        random_state=0,
    )
    mini_batch_kmeans.fit()
    print(f"{mini_batch_kmeans.centroids=}")


if __name__ == "__main__":
    main()
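For intuition, the update rule described in [1] can be sketched in a few lines of plain Python. This is only an illustrative sketch, not this package's JAX implementation: the function name `mini_batch_kmeans_sketch`, its parameters, and the toy data are all hypothetical, and it assumes squared Euclidean distance and random initialization from the samples.

```python
import random


def mini_batch_kmeans_sketch(xs, k, batch_size, n_iter, seed=0):
    """Plain-Python sketch of the mini-batch k-means update from [1]."""
    rng = random.Random(seed)
    # Initialize the centroids with k samples picked at random.
    centroids = [list(x) for x in rng.sample(xs, k)]
    counts = [0] * k  # how many samples each centroid has absorbed so far

    def nearest(x):
        # Index of the centroid with the smallest squared Euclidean distance.
        return min(
            range(k),
            key=lambda j: sum((a - b) ** 2 for a, b in zip(x, centroids[j])),
        )

    for _ in range(n_iter):
        batch = [rng.choice(xs) for _ in range(batch_size)]
        # Cache all assignments before moving any centroid.
        assignments = [nearest(x) for x in batch]
        for x, j in zip(batch, assignments):
            counts[j] += 1
            eta = 1.0 / counts[j]  # per-centroid learning rate
            centroids[j] = [
                (1 - eta) * c + eta * a for c, a in zip(centroids[j], x)
            ]
    return centroids


# Toy data (hypothetical): two tight blobs around (0, 0) and (5, 5).
data_rng = random.Random(1)
blob_a = [[data_rng.gauss(0.0, 0.1), data_rng.gauss(0.0, 0.1)] for _ in range(200)]
blob_b = [[data_rng.gauss(5.0, 0.1), data_rng.gauss(5.0, 0.1)] for _ in range(200)]
centroids = mini_batch_kmeans_sketch(blob_a + blob_b, k=2, batch_size=50, n_iter=20)
print(centroids)
```

Because the learning rate is `1 / counts[j]`, each centroid is the running mean of every sample assigned to it so far, which is why no full pass over the data is needed per iteration.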
[1] D. Sculley. 2010. Web-scale k-means clustering. In Proceedings of the 19th international conference on World wide web (WWW '10). Association for Computing Machinery, New York, NY, USA, 1177–1178. https://doi.org/10.1145/1772690.1772862