
SalamanderXing/jax-min-batch-kmeans


Mini-Batch KMeans written in JAX

Just-in-time compiled and accelerated ⚡ implementation of Mini-Batch KMeans [1].
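In Mini-Batch KMeans [1], each iteration samples a small batch, assigns every sampled point to its nearest centroid, and then moves each centroid toward its assigned points with a per-centroid learning rate of 1/v, where v counts the points that centroid has absorbed so far. The following is a minimal JAX sketch of that update, not this repository's implementation. `mini_batch_step` and `sketch_fit` are hypothetical names, batches are assumed to be drawn with replacement, and a batch's per-sample updates are folded into a single running-mean step. This is equivalent to applying them one by one, because the assignments are cached before any centroid moves. Only the step function is wrapped in `jax.jit`, with `batch_size` marked static so the batch shape is known at compile time.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="batch_size")
def mini_batch_step(key, xs, centroids, counts, batch_size):
    """One mini-batch update in the style of Sculley (2010); a sketch only."""
    # Sample a mini-batch of rows uniformly at random (with replacement).
    idx = jax.random.randint(key, (batch_size,), 0, xs.shape[0])
    batch = xs[idx]
    # Cache every sample's nearest centroid before any centroid moves.
    sq_dists = jnp.sum((batch[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    nearest = jnp.argmin(sq_dists, axis=1)
    one_hot = jax.nn.one_hot(nearest, centroids.shape[0], dtype=xs.dtype)  # (b, k)
    batch_counts = one_hot.sum(axis=0)  # samples assigned to each centroid
    batch_sums = one_hot.T @ batch  # per-centroid sum of assigned samples, (k, d)
    new_counts = counts + batch_counts
    # Applying c <- (1 - 1/v) c + x / v once per assigned sample equals taking
    # a running mean, so the whole batch can be folded in at once.
    denom = jnp.maximum(new_counts, 1.0)[:, None]
    updated = (counts[:, None] * centroids + batch_sums) / denom
    # Centroids that received no samples in this batch stay where they are.
    new_centroids = jnp.where(batch_counts[:, None] > 0, updated, centroids)
    return new_centroids, new_counts


def sketch_fit(xs, k, batch_size, iters, seed=0):
    """Hypothetical driver: random init from the data, then `iters` mini-batch steps."""
    xs = jnp.asarray(xs, dtype=jnp.float32)
    key, init_key = jax.random.split(jax.random.PRNGKey(seed))
    # Initialise the k centroids with k distinct samples picked at random.
    init_idx = jax.random.choice(init_key, xs.shape[0], (k,), replace=False)
    centroids = xs[init_idx]
    counts = jnp.zeros(k, dtype=xs.dtype)  # v in the paper: samples absorbed so far
    for _ in range(iters):
        key, step_key = jax.random.split(key)
        centroids, counts = mini_batch_step(
            step_key, xs, centroids, counts, batch_size=batch_size
        )
    return centroids


if __name__ == "__main__":
    # Two well-separated toy clusters at (0, 0) and (10, 10).
    data = jnp.concatenate([jnp.zeros((500, 2)), jnp.full((500, 2), 10.0)])
    print(sketch_fit(data, k=2, batch_size=100, iters=200))
```

On the toy data the two centroids should settle near (0, 0) and (10, 10). A Python loop around a jitted step keeps the sketch short; `jax.lax.fori_loop` would be one way to compile the whole loop instead.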

Requirements

Installation

```shell
git clone https://github.com/GiulioZani/jax-min-batch-kmeans
cd jax-min-batch-kmeans
```

Usage

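```python
import numpy as np

from mini_batch_kmeans import MiniBatchKMeans


def main():
    # A 2D array of shape (number of samples, number of features);
    # random data is used here only as a placeholder.
    xs = np.random.default_rng(0).random((10_000, 2))
    mini_batch_kmeans = MiniBatchKMeans(
        xs,  # can be a NumPy or JAX array
        k=4,  # number of clusters
        batch_size=1000,  # number of samples per mini-batch
        iter=1000,  # number of iterations
        random_state=0,  # random seed
    )
    mini_batch_kmeans.fit()

    print(f"{mini_batch_kmeans.centroids=}")


if __name__ == "__main__":
    main()
```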
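The usage example above only prints the learned centroids. To label samples, one option is a nearest-centroid lookup. The helper below is a hypothetical sketch (`assign_clusters` is not part of this repository) and assumes the centroids form an array of shape (k, number of features).

```python
import jax
import jax.numpy as jnp


@jax.jit
def assign_clusters(xs, centroids):
    """Return the index of the nearest centroid for every row of xs."""
    # Pairwise squared Euclidean distances, shape (n_samples, k).
    sq_dists = jnp.sum((xs[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    # Ties resolve to the lowest centroid index, as jnp.argmin does.
    return jnp.argmin(sq_dists, axis=1)


# Toy data: two centroids and three samples to label.
centroids = jnp.array([[0.0, 0.0], [10.0, 10.0]])
samples = jnp.array([[0.0, 0.0], [9.0, 9.5], [0.5, -0.2]])
print(assign_clusters(samples, centroids).tolist())  # [0, 1, 0]
```

With the class from this repository, the call would presumably be `assign_clusters(jnp.asarray(xs), mini_batch_kmeans.centroids)`. The (n_samples, k, n_features) intermediate grows with the data, so very large inputs may need to be labeled in chunks.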

References

[1] D. Sculley. 2010. Web-scale k-means clustering. In Proceedings of the 19th international conference on World wide web (WWW '10). Association for Computing Machinery, New York, NY, USA, 1177–1178. https://doi.org/10.1145/1772690.1772862
