# The `Cluster` Module in `TabuLLM`: Spherical K-Means for L2-Normalized Text Embeddings

## Why Clustering?

A common complain about embeddings is their lack of explainability, which is part of the broader 'black-box' narrative about ML/AI algorithms. This is especially acute given that embedding vectors produced by modern LLMs are often high-dimensional. For instance, OpenAI's `text-embedding-3-large` outputs a vector of length 3072! In many predictive applications, including a such high-dimensional vector would lead to several issues. First is that the vector consumes too many degrees of freedom and including the full vector may be impractical, especially for small datasets. Secondly, interpreting - or explaining - the resulting model would be challenging.

One way to validate or 'explain' the embeddings is to use them to cluster the observations, and examine the resulting clusters. Cluster labels can also be used as a categorical feature in downstream predictive models. As such, clusterng can server as a dimensionality-reduction technique. These are the motivations behind the `cluster` and `explain` modules of `TabuLLM`.

## Why *Spherical* K-Means?

A useful clustering technique that is often used with text embeddings is k-means. In its standard form, k-means uses the Euclidean - or L2 norm - to measure the distance between a pair of vectors. This is is how data points are assigned to cluster centroids in the Lloyd's algorithm for training k-means. Correspondingly, in the centroid-update step, the mean of vectors mapped to each centroid are calculated to update the centroid vector. The mean vector minimizes the total L2 distance from all cluster members.

Text embeddings, however, only contain directional information, and no magnitude. This means they are either L2 normalized, or must be L2 normalized by the user. This can be seen by examining the output of sentence transformer models:

In [5]:
from TabuLLM.embed import TextColumnTransformer
import pandas as pd

df = pd.read_csv('../data/raw.csv')

X = TextColumnTransformer(
    model_type = 'st'
).fit_transform(df[['diagnoses']])
#print((X**2).sum(axis=1)) # should be all 1's



0      1.0
1      1.0
2      1.0
3      1.0
4      1.0
      ... 
825    1.0
826    1.0
827    1.0
828    1.0
829    1.0
Length: 830, dtype: float32


Given the above property, we propose that the distance metric used for applying K-means to L2-normalized embedding vectors must be the cosine distance, rather than L2 norm. It must be noted that, during the assignment step of Lloyd's algorithm, cosine and L2 distances would produce identical outcomes, since they would rank any collectiond of pairs of points the same way.

In the centroid update step, the cosine-distance approach would take the mean of all cluster members and apply an L2 normalization. This last step creates a divergence between standard k-means and spherical k-means.

## The `SphericalKMeans` Class

The `SphericalKMeans` class in `TabuLLM` implements the familiar interface of `KMeans` in the `scikit-learn` package, including class constructor arguments and public methods. Besides the key difference with standard k-means that is using cosine distance instead of L2 distance, there are a few differences in the implementation of `SphericalKMeans` that are worth highlighting:
1. Applying unique to rows of X before random subset selection (for initialization)
1. Handling of empty clusters

Other topics to comment on:
- Speed vs. standard k-means
- Proof of superiority (or lack thereof) in clustering embeddings
- Acknowledge that there are other clustering techniques besides k-means, perhaps argue why k-means is still a better choice than seemingly more advanced techniques.
- Discuss the nuanced difference between fit_predict and fit_transform

### Cluster Consistency and Sensitivity to Initialization

A major concern with kmeans - or other clustering techniques that require some form of cluster initialization - is that the final results are sensitive to the starting point. Similar to scikit-learn's `KMeans`, `SphericalKMeans` allows users to perform multiple runs with different random initializations (via `n_init`), and use the best result, i.e., one with lowest total within-cluster distance (aka *inertia*).

It will be interesting to study how changing `n_init` impacts cluster stability. To quantify the latter, we use two metric, adjusted rand index (ARI) and adjusted mutual information (AMI), both available in scikit-learn. We calculate ARI and AMI between multiple runs of `SphericalKMeans`, and while changing `n_init` from one set of runs to the next. We hope to see that as `n_init` is increased, so does the average ARI/AMI values amongst pairs of runs.

In [9]:
import numpy as np
from TabuLLM.cluster import SphericalKMeans
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score

n_init_list = [1, 3, 10, 30, 100]
nrun = 10
ari_vec = np.zeros(len(n_init_list))
ami_vec = np.zeros(len(n_init_list))
for idx, n_init in enumerate(n_init_list):
    print(f'n_init = {n_init}')
    # initialize a numpy matrix to hold all X's
    labels_all = np.zeros((df.shape[0], nrun))
    for i in range(nrun):
        labels = SphericalKMeans(n_clusters=10, n_init=n_init).fit_predict(X)
        labels_all[:, i] = labels
    # calculate adjusted rand index between all pairs of columns in X_all
    ari = 0.0
    ami = 0.0
    for i in range(nrun):
        for j in range(i+1, nrun):
            ari = ari + adjusted_rand_score(labels_all[:, i], labels_all[:, j])
            ami = ari + adjusted_mutual_info_score(labels_all[:, i], labels_all[:, j])
    ari = ari / (nrun * (nrun - 1) / 2)
    ari_vec[idx] = ari
    ami = ami / (nrun * (nrun - 1) / 2)
    ami_vec[idx] = ami

    print(f'ARI = {ari}, AMI = {ami}')

n_init = 1
ARI = 0.5794053807032611, AMI = 0.592852688790101
n_init = 3
ARI = 0.6350779645549716, AMI = 0.6499854324630371
n_init = 10
ARI = 0.7423067530123156, AMI = 0.759498012313953
n_init = 30
ARI = 0.774646284122888, AMI = 0.7921260431790932
n_init = 100
ARI = 0.7546001383487387, AMI = 0.7710290978101386


As can be seen above, somewhere between 30 and 100 initializations is likely to be sufficient to achieve maximum cluster stability.

### Hard vs. Soft Clusters as Features

A fitted `SphericalKMeans` object has both a `transform` method and a `predict` method, which produce hard and soft clusters, respectively:

In [11]:
skmeans = SphericalKMeans(n_clusters=10, n_init=30).fit(X)
clusters_soft = skmeans.transform(X)
clusters_hard = skmeans.predict(X)
print(f'Shape of clusters_soft: {clusters_soft.shape}')
print(f'Shape of clusters_hard: {clusters_hard.shape}')

Shape of clusters_soft: (830, 10)
Shape of clusters_hard: (830,)


We can use both hard and soft clusters as features in a prediction model. We start with the soft clusters, and first prepare the data:

In [24]:
varnames_cluster_soft = ['X' + str(n) for n in range(clusters_soft.shape[1])]
varnames_baseline = ['is_female', 'age', 'height', 'weight', 'optime']
dfCluster_soft = pd.DataFrame(clusters_soft, columns=varnames_cluster_soft)
dfCombined_soft = pd.concat([df, dfCluster_soft], axis=1)
X_soft, y = dfCombined_soft[varnames_baseline + varnames_cluster_soft], dfCombined_soft['aki_severity']

dfCluster_hard = pd.DataFrame(clusters_hard, columns=['cluster'])
dfCombined_hard = pd.concat([df, dfCluster_hard], axis=1)
X_hard = dfCombined_hard[varnames_baseline + ['cluster']]

We can now train a classifier, such as logistic regression:

In [31]:
from sklearn.linear_model import LogisticRegression
obj = LogisticRegression(max_iter=10, solver = 'newton-cholesky', penalty=None).fit(X_soft, y)

As for the had clusters, we note that they are categorical labels, even if displayed as integers. We therefore use one-hot encoding to transform them to binary dummies:

In [32]:
# one-hot encoder followed by logistic regression
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

preprocessor = ColumnTransformer(
    transformers=[
        ('num', 'passthrough', varnames_baseline),
        ('cat', OneHotEncoder(drop = 'first'), ['cluster'])
    ]
    , remainder='drop'
)
clf = Pipeline(steps=[('preprocessor', preprocessor),
                      ('classifier', LogisticRegression(max_iter=10, solver = 'newton-cholesky', penalty=None))])
clf.fit(X_hard, y)

## Recap

We motivated the application of clustering to the text embeddings to facilitate their interpretation (more on this in the next tutorial), and also as dimensionality reduction for predictive models. We also argued why spherical k-means is more appropriate for L2-normalized text embeddings. Finally, we showed how hard and soft clusters can be used as features in a predictive model alongside other variables.