# The `Cluster` Module in `TabuLLM`: Spherical K-Means for L2-Normalized Text Embeddings

## Why Clustering?

A common complaint about embeddings is their lack of explainability, which is part of the broader 'black-box' narrative about ML/AI algorithms. The problem is especially acute because the embedding vectors produced by modern LLMs are often high-dimensional: OpenAI's `text-embedding-3-large`, for instance, outputs a vector of length 3072! In many predictive applications, including such a high-dimensional vector leads to two issues. First, the vector consumes too many degrees of freedom, so including it in full may be impractical, especially for small datasets. Second, interpreting (or explaining) the resulting model would be challenging.

One way to validate or 'explain' the embeddings is to use them to cluster the observations and then examine the resulting clusters. Cluster labels can also be used as a categorical feature in downstream predictive models. In this way, clustering can serve as a dimensionality-reduction technique, replacing a long embedding vector with a single categorical variable. These are the motivations behind the `cluster` and `explain` modules of `TabuLLM`.

## Why *Spherical* K-Means?

k-means is a useful clustering technique that is often applied to text embeddings. In its standard form, k-means measures the distance between a pair of vectors with the Euclidean (L2) distance. Lloyd's algorithm, the usual procedure for training k-means, uses this distance to assign data points to cluster centroids. Correspondingly, in the centroid-update step, each centroid is replaced by the mean of the vectors assigned to it, because the mean is the vector that minimizes the sum of squared L2 distances to all cluster members.
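For reference, the sketch below implements one iteration of standard Lloyd's algorithm in NumPy. The helper name `lloyd_step` is hypothetical. Keeping the old centroid when a cluster ends up empty is a simplification assumed here and is not part of the description above.

```python
import numpy as np

def lloyd_step(X, centroids):
    """One iteration of standard (Euclidean) Lloyd's algorithm (illustrative sketch)."""
    # Assignment step: squared Euclidean distance from every point to every centroid
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = d2.argmin(axis=1)
    # Update step: each centroid becomes the arithmetic mean of its members
    # (assumed simplification: an empty cluster keeps its previous centroid)
    new_centroids = np.vstack([
        X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
        for j in range(centroids.shape[0])
    ])
    return labels, new_centroids

X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centroids = np.array([[0.0, 0.0], [10.0, 10.0]])
labels, centroids = lloyd_step(X, centroids)
print(labels)  # [0 0 1 1]
```

Nothing in the update step constrains the norm of the new centroids, which becomes relevant for normalized embeddings.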

Text embeddings, however, carry only directional information; their magnitude is not meaningful. This means they are either already L2-normalized or must be L2-normalized by the user. You can see this by examining the output of a sentence-transformer model, where the squared L2 norm of every embedding equals 1:

```python
from TabuLLM.embed import TextColumnTransformer
import pandas as pd

df = pd.DataFrame({
    'text': ['hello world', 'goodbye world', 'hello goodbye']
})

# Embed the 'text' column with a sentence-transformer model
X = TextColumnTransformer(
    type = 'st'
    , embedding_model_st = 'sentence-transformers/all-MiniLM-L6-v2'
).fit_transform(df)
print((X**2).sum(axis=1)) # squared L2 norm of each row; should be all 1's
```

```
0    1.0
1    1.0
2    1.0
dtype: float32
```
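Some embedding models do not return unit-length vectors. In that case, you can normalize the rows before clustering. Below is a minimal NumPy sketch of this step; `l2_normalize` is a hypothetical helper, and `sklearn.preprocessing.normalize(X, norm='l2')` performs the same row-wise scaling.

```python
import numpy as np

def l2_normalize(X, eps=1e-12):
    """Scale each row of X to unit Euclidean length (all-zero rows stay zero)."""
    X = np.asarray(X, dtype=float)
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.maximum(norms, eps)

E = np.array([[3.0, 4.0], [1.0, 0.0], [0.5, 0.5]])
E_unit = l2_normalize(E)
print((E_unit ** 2).sum(axis=1))  # every row now has squared norm 1
```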


Given this property, we propose that the distance metric used when applying k-means to L2-normalized embedding vectors must be the cosine distance rather than the L2 distance. Note that during the assignment step of Lloyd's algorithm, cosine and L2 distances produce identical outcomes as long as the centroids are also unit-normalized. For unit vectors `a` and `b`, `||a - b||^2 = 2 * (1 - cos(a, b))`, so both metrics rank any collection of pairs of points the same way.
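The sketch below checks this equivalence numerically, using random vectors purely for illustration. It first confirms the identity for unit vectors. It then rescales the centroids to show that, once they leave the unit sphere, the Euclidean assignment is no longer guaranteed to match the cosine one.

```python
import numpy as np

rng = np.random.default_rng(0)

def unit(A):
    """L2-normalize along the last axis."""
    return A / np.linalg.norm(A, axis=-1, keepdims=True)

x = unit(rng.normal(size=8))          # one L2-normalized "embedding"
C = unit(rng.normal(size=(5, 8)))     # five unit-norm centroids

sq_euclid = ((C - x) ** 2).sum(axis=1)   # squared L2 distance to each centroid
cos_dist = 1.0 - C @ x                   # cosine distance to each centroid

# For unit vectors: ||x - c||^2 = 2 * (1 - cos(x, c))
print(np.allclose(sq_euclid, 2 * cos_dist))          # True
print(np.argmin(sq_euclid) == np.argmin(cos_dist))   # True

# With non-unit centroids the Euclidean choice may differ from the cosine choice,
# because ||x - c||^2 = 1 + ||c||^2 - 2 * x.c now depends on each centroid's norm.
C_scaled = C * rng.uniform(0.1, 1.0, size=(5, 1))
print(np.argmin(((C_scaled - x) ** 2).sum(axis=1)), np.argmin(1.0 - unit(C_scaled) @ x))
```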

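In the centroid-update step, the cosine-distance approach takes the mean of all cluster members and then applies L2 normalization. The result is the unit vector that maximizes the sum of cosine similarities to the cluster members. This last step is where spherical k-means diverges from standard k-means, whose centroids (means of unit vectors) generally lie inside the unit sphere.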
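Below is a minimal sketch of the spherical centroid update, assuming NumPy arrays and integer cluster labels. `spherical_update` is a hypothetical helper, not part of `TabuLLM`.

```python
import numpy as np

def spherical_update(X, labels, k):
    """Spherical k-means centroid update: mean of the members, then L2-normalized.

    Assumes every cluster 0..k-1 has at least one member and that no cluster
    mean is exactly the zero vector.
    """
    centroids = np.vstack([X[labels == j].mean(axis=0) for j in range(k)])
    return centroids / np.linalg.norm(centroids, axis=1, keepdims=True)

X = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])   # unit-norm rows
labels = np.array([0, 0, 1])
C = spherical_update(X, labels, k=2)
# Cluster 0's centroid is [1, 1] / sqrt(2), not the raw mean [0.5, 0.5]
print(C)
```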

## The `SphericalKMeans` Class

The `SphericalKMeans` class in `TabuLLM` implements the familiar interface of `KMeans` from the `scikit-learn` package, including its constructor arguments and public methods. The key difference from standard k-means is the use of cosine distance instead of L2 distance. Beyond that, a few implementation details of `SphericalKMeans` are worth highlighting:

1. Applying `unique` to the rows of `X` before selecting the random subset used for initialization
1. Handling of empty clusters
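To make these two points concrete, here is a hypothetical, minimal spherical k-means estimator with a scikit-learn-style interface (`fit`, `predict`, `fit_predict`, `labels_`, `cluster_centers_`). It is an illustrative sketch, not `TabuLLM`'s `SphericalKMeans`. In particular, its empty-cluster strategy is an assumption chosen for illustration and may differ from the library's: an empty cluster is re-seeded with the point least similar to its current centroid.

```python
import numpy as np

class MiniSphericalKMeans:
    """Hypothetical minimal spherical k-means (sketch only, not TabuLLM's class)."""

    def __init__(self, n_clusters=8, max_iter=300, tol=1e-6, random_state=None):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state

    @staticmethod
    def _normalize(A, eps=1e-12):
        # Row-wise L2 normalization; all-zero rows stay zero.
        return A / np.maximum(np.linalg.norm(A, axis=1, keepdims=True), eps)

    def fit(self, X, y=None):
        X = self._normalize(np.asarray(X, dtype=float))
        rng = np.random.default_rng(self.random_state)
        # Deduplicate rows first so the initial centroids are distinct points.
        uniq = np.unique(X, axis=0)
        if len(uniq) < self.n_clusters:
            raise ValueError("fewer distinct rows than n_clusters")
        centers = uniq[rng.choice(len(uniq), size=self.n_clusters, replace=False)]
        for _ in range(self.max_iter):
            sims = X @ centers.T                    # cosine similarity (unit-norm rows)
            labels = sims.argmax(axis=1)            # assignment step
            own_sim = sims[np.arange(len(X)), labels]
            new_centers = np.empty_like(centers)
            for j in range(self.n_clusters):
                members = X[labels == j]
                if len(members) == 0:
                    # Assumed strategy: re-seed with the point least similar to its
                    # current centroid, using each point at most once.
                    worst = own_sim.argmin()
                    own_sim[worst] = np.inf
                    new_centers[j] = X[worst]
                else:
                    new_centers[j] = members.mean(axis=0)
            new_centers = self._normalize(new_centers)  # spherical centroid update
            shift = np.abs(new_centers - centers).max()
            centers = new_centers
            if shift < self.tol:
                break
        self.cluster_centers_ = centers
        self.labels_ = (X @ centers.T).argmax(axis=1)
        return self

    def predict(self, X):
        X = self._normalize(np.asarray(X, dtype=float))
        return (X @ self.cluster_centers_.T).argmax(axis=1)

    def fit_predict(self, X, y=None):
        return self.fit(X).labels_

# Rows 1 and 2 are duplicates, as repeated texts would produce.
X = np.array([[1.0, 0.1], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.1, 1.0]])
km = MiniSphericalKMeans(n_clusters=2, random_state=0).fit(X)
print(km.labels_)
```

In this sketch, deduplication matters because identical texts yield identical embeddings. If two identical rows were drawn as initial centroids, `argmax` would always prefer the first of them, leaving the second with no members from the very first iteration.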

Other topics to comment on:
- Speed vs. standard k-means
- Proof of superiority (or lack thereof) in clustering embeddings
- Acknowledge that there are other clustering techniques besides k-means, perhaps argue why k-means is still a better choice than seemingly more advanced techniques.