## Install Packages

In [1]:
!conda install -y -c conda-forge faiss-gpu
!apt-get -y update
!apt-get -y install libatlas-base-dev

done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed with repodata from current_repodata.json, will retry with next repodata source.
done
Solving environment: done


  current version: 4.10.3
  latest version: 25.7.0

Please update conda by running

    $ conda update -n base -c defaults conda



## Package Plan ##

  environment location: /opt/conda

  added / updated specs:
    - faiss-gpu


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2025.8.3   |       hbd8a1cb_0         151 KB  conda-forge
    certifi-2024.8.30          |     pyhd8ed1ab_0         160 KB  conda-forge
    conda-4.12.0               |   py37h89c1867_0         1.0 MB  conda-forge
    faiss-1.7.1                |py37cuda111h7f21d35_1_cuda         2.0 MB  conda-forge
    faiss-gpu-1.7.1            |       h788eb59_1          15 KB  conda-f

## Load Data

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.random_projection import GaussianRandomProjection

from tqdm import tqdm

import faiss

In [3]:
df = pd.read_csv("story_dataset.csv")
df

Unnamed: 0,prompt_id,prompt,story,hidden_state_file,len_generated_story,len_new_story
0,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Blaz...,./hidden_states/prompt_1.npz,270,271
1,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Spar...,./hidden_states/prompt_1.npz,349,350
2,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Scor...,./hidden_states/prompt_1.npz,278,278
3,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,117,118
4,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,129,130
...,...,...,...,...,...,...
9995,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,289,290
9996,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,119,119
9997,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,127,128
9998,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,441,441


In [4]:
max_story_len = max(df["len_generated_story"])
max_story_len

522

In [5]:
hidden_states_by_layer = {}
NUM_PROMPTS = 10

for prompt_id in range(1, NUM_PROMPTS + 1):
    with np.load(f'./hidden_states/prompt_{prompt_id}.npz') as loaded_data:
        for i in tqdm(range(1000)):
            curr_hidden_states = loaded_data[f"arr_{i}"][0]
#             print(curr_hidden_states.shape)
            for layer in range(2, 3):
                padded_arr = np.zeros((max_story_len, 512))
                padded_arr_len = len(curr_hidden_states[layer][0])
                
                padded_arr[:padded_arr_len] = curr_hidden_states[layer][0]
                
                padded_arr = padded_arr.flatten().astype('float32') #FAISS expects data in type float32 instead of float64 - saves memory too!
#                 print(padded_arr.shape)
                
                if(f"layer_{layer}" in hidden_states_by_layer):
                    hidden_states_by_layer[f"layer_{layer}"].append(padded_arr)
                else:
                    hidden_states_by_layer[f"layer_{layer}"] = [padded_arr]

100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:22<00:00,  3.82it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:40<00:00,  4.53it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:24<00:00,  4.89it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:52<00:00,  4.30it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:57<00:00,  4.21it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:40<00:00,  4.54it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:12<00:00,  3.96it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:42<00:00,  4.49it/s]
100%|███████████████████████████

In [6]:
layer_hs_array = np.array(hidden_states_by_layer["layer_2"])
layer_hs_array.shape

(10000, 267264)

## Layer 2 Clustering

In [7]:
# Use original vectors for clustering - uncomment next line and comment out last two lines

dim_reduced_vecs = layer_hs_array

# random_projector = GaussianRandomProjection(random_state = 42)
# dim_reduced_vecs = random_projector.fit_transform(layer_hs_array).astype('float32')

In [8]:
dim_reduced_vecs = np.array([v / np.linalg.norm(v) for v in dim_reduced_vecs])
dim_reduced_vecs.shape

(10000, 267264)

In [9]:
# K-means Clustering

ncentroids = NUM_PROMPTS
niter = 20
verbose = True
dim = dim_reduced_vecs.shape[1]
kmeans = faiss.Kmeans(dim, ncentroids, niter = niter, verbose = verbose, gpu = True, nredo = 10, spherical = True, max_points_per_centroid = 1000)
kmeans.train(dim_reduced_vecs)

Clustering 10000 points in 267264D to 10 clusters, redo 10 times, 20 iterations
  Preprocessing in 1.59 s
Outer iteration 0 / 10
  Iteration 19 (26.00 s, search 17.91 s): objective=4885.53 imbalance=1.107 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 10
  Iteration 19 (52.07 s, search 35.93 s): objective=4877.22 imbalance=1.393 nsplit=0       
Outer iteration 2 / 10
  Iteration 19 (78.12 s, search 53.84 s): objective=4886.87 imbalance=1.092 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 10
  Iteration 19 (104.40 s, search 71.79 s): objective=4892.7 imbalance=1.047 nsplit=0        
Objective improved: keep new clusters
Outer iteration 4 / 10
  Iteration 19 (130.62 s, search 89.71 s): objective=4886.17 imbalance=1.104 nsplit=0       
Outer iteration 5 / 10
  Iteration 19 (156.66 s, search 107.65 s): objective=4875.28 imbalance=1.356 nsplit=0       
Outer iteration 6 / 10
  Iteration 19 (182.70 s, search 125.56 s): objective=4880.82 i

4892.703125

In [10]:
kmeans.centroids #cluster centers

array([[ 5.74872084e-03,  3.08733503e-03,  1.43712182e-02, ...,
        -1.81957247e-07,  6.38515576e-06,  1.52332259e-05],
       [ 5.23568364e-03,  2.81181093e-03,  1.30886734e-02, ...,
         1.84849878e-08, -1.16551428e-06, -1.38975122e-06],
       [ 6.16712356e-03,  3.31203826e-03,  1.54171558e-02, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       ...,
       [ 6.02112757e-03,  3.23363463e-03,  1.50522031e-02, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 6.13717828e-03,  3.29595641e-03,  1.53423175e-02, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 5.26678376e-03,  2.82851164e-03,  1.31664146e-02, ...,
         4.61338942e-07, -5.33175211e-08,  3.62436163e-06]], dtype=float32)

In [11]:
for centroid in kmeans.centroids:
    print(np.linalg.norm(centroid))

1.0000554
1.0000361
1.0000104
1.0000465
1.0000411
1.0000107
1.0000515
1.0000116
1.0000114
1.0000287


In [12]:
kmeans.obj #inertia at each iteration

array([3034.62646484, 4718.64599609, 4790.91015625, 4827.23583984,
       4848.84423828, 4864.20800781, 4870.99560547, 4874.46337891,
       4876.77050781, 4878.75683594, 4880.36083984, 4881.63330078,
       4882.41162109, 4883.06640625, 4883.50390625, 4883.88671875,
       4884.22949219, 4884.60742188, 4885.02050781, 4885.53173828,
       3195.13061523, 4709.79443359, 4754.78173828, 4785.34375   ,
       4821.73779297, 4852.14550781, 4864.74072266, 4868.94726562,
       4870.29736328, 4870.99169922, 4871.55371094, 4872.140625  ,
       4872.93408203, 4873.95410156, 4875.17431641, 4876.10449219,
       4876.66357422, 4876.94970703, 4877.13427734, 4877.22460938,
       2980.88500977, 4714.51123047, 4797.36474609, 4837.08447266,
       4852.89941406, 4860.71777344, 4866.62548828, 4871.31884766,
       4874.86279297, 4877.27294922, 4878.60644531, 4879.63085938,
       4880.53271484, 4881.62890625, 4882.68652344, 4883.71289062,
       4884.72460938, 4885.64746094, 4886.28369141, 4886.87353

In [13]:
normalized_vecs = [v / np.linalg.norm(v) for v in dim_reduced_vecs]

In [14]:
cos_similarities = normalized_vecs @ kmeans.centroids.T
classifications = np.argmax(cos_similarities, axis=1)

In [15]:
pd.Series(classifications).value_counts()

2    1408
8    1228
9    1211
3    1053
6    1046
0     886
1     844
5     822
7     796
4     706
dtype: int64

In [16]:
pd.Series(kmeans.index.search(dim_reduced_vecs.astype(np.float32), 1)[1].flatten()).value_counts()

2    1408
8    1228
9    1211
3    1053
6    1046
0     886
1     844
5     822
7     796
4     706
dtype: int64

In [17]:
prompt_ids = df["prompt_id"]
prompt_ids = prompt_ids.to_numpy()
prompt_ids

array([ 1,  1,  1, ..., 10, 10, 10])

In [18]:
# Get most common centroid for each 1000 points (same label)
max_centroid_per_label = [pd.Series(classifications[i * 1000:(i + 1) * 1000]).value_counts().idxmax() for i in range(10)]
max_centroid_per_label

[2, 8, 2, 4, 3, 8, 4, 3, 1, 0]

In [19]:
# Get most common label for each point classified to a centroid (same centroid)
centroid_labels = [np.where(classifications == i)[0] for i in range(10)]
max_label_per_centroid = [pd.Series(prompt_ids[centroid_labels[i]]).value_counts().idxmax() for i in range(10)]
max_label_per_centroid

[10, 9, 3, 5, 4, 4, 3, 5, 2, 2]

In [20]:
max_centroids = [centroid for centroid in max_centroid_per_label for _ in range(1000)]

max_labels = [label for label in max_label_per_centroid for _ in range(1000)]

In [21]:
np.array(max_centroids)

array([2, 2, 2, ..., 0, 0, 0])

In [22]:
np.array(max_labels)

array([10, 10, 10, ...,  2,  2,  2])

In [23]:
label_to_centroid = {idx + 1 : max_centroid_per_label[idx] for idx in range(len(max_centroid_per_label))}

centroid_to_label = {idx : max_label_per_centroid[idx] for idx in range(len(max_label_per_centroid))}

In [24]:
label_to_centroid

{1: 2, 2: 8, 3: 2, 4: 4, 5: 3, 6: 8, 7: 4, 8: 3, 9: 1, 10: 0}

In [25]:
centroid_to_label

{0: 10, 1: 9, 2: 3, 3: 5, 4: 4, 5: 4, 6: 3, 7: 5, 8: 2, 9: 2}

In [26]:
vectorized_map = np.vectorize(centroid_to_label.get)
classifications_to_label = vectorized_map(classifications)

classifications_to_label

array([ 5,  9,  9, ..., 10, 10, 10])

In [27]:
for i in range(10):
    print(f"Prompt {i + 1} Accuracy: ", np.mean(classifications_to_label[i * 1000:(i + 1) * 1000] == (i + 1)))

Prompt 1 Accuracy:  0.0
Prompt 2 Accuracy:  0.535
Prompt 3 Accuracy:  0.433
Prompt 4 Accuracy:  0.366
Prompt 5 Accuracy:  0.598
Prompt 6 Accuracy:  0.0
Prompt 7 Accuracy:  0.0
Prompt 8 Accuracy:  0.0
Prompt 9 Accuracy:  0.258
Prompt 10 Accuracy:  0.872


In [28]:
print(f"Overall Accuracy: ", np.mean(classifications_to_label == prompt_ids))

Overall Accuracy:  0.3062
