## Install Packages

In [1]:
!conda install -y -c conda-forge faiss-gpu
!apt-get -y update
!apt-get -y install libatlas-base-dev

done
Solving environment: done


  current version: 4.14.0
  latest version: 25.5.1

Please update conda by running

    $ conda update -n base -c conda-forge conda



## Package Plan ##

  environment location: /opt/conda

  added / updated specs:
    - faiss-gpu


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2025.8.3   |       hbd8a1cb_0         151 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         151 KB

The following packages will be UPDATED:

  ca-certificates                      2025.7.14-hbd8a1cb_0 --> 2025.8.3-hbd8a1cb_0



Downloading and Extracting Packages
ca-certificates-2025 | 151 KB    | ##################################### | 100% 
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
Retrieving notices: ...working... done
Hit:1 http:/

## Load Data

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.random_projection import GaussianRandomProjection

from tqdm import tqdm

import faiss

In [3]:
df = pd.read_csv("story_dataset.csv")
df

Unnamed: 0,prompt_id,prompt,story,hidden_state_file,len_generated_story,len_new_story
0,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Blaz...,./hidden_states/prompt_1.npz,270,271
1,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Spar...,./hidden_states/prompt_1.npz,349,350
2,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Scor...,./hidden_states/prompt_1.npz,278,278
3,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,117,118
4,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,129,130
...,...,...,...,...,...,...
9995,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,289,290
9996,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,119,119
9997,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,127,128
9998,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,441,441


In [4]:
max_story_len = max(df["len_generated_story"])
max_story_len

522

In [5]:
hidden_states_by_layer = {}
NUM_PROMPTS = 10

for prompt_id in range(1, NUM_PROMPTS + 1):
    with np.load(f'./hidden_states/prompt_{prompt_id}.npz') as loaded_data:
        for i in tqdm(range(1000)):
            curr_hidden_states = loaded_data[f"arr_{i}"][0]
#             print(curr_hidden_states.shape)
            for layer in range(5, 6):
                padded_arr = np.zeros((max_story_len, 512))
                padded_arr_len = len(curr_hidden_states[layer][0])
                
                padded_arr[:padded_arr_len] = curr_hidden_states[layer][0]
                
                padded_arr = padded_arr.flatten().astype('float32') #FAISS expects data in type float32 instead of float64 - saves memory too!
#                 print(padded_arr.shape)
                
                if(f"layer_{layer}" in hidden_states_by_layer):
                    hidden_states_by_layer[f"layer_{layer}"].append(padded_arr)
                else:
                    hidden_states_by_layer[f"layer_{layer}"] = [padded_arr]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:08<00:00,  3.24it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:42<00:00,  3.54it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:47<00:00,  4.39it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:43<00:00,  3.53it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [04:58<00:00,  3.36it/s]
100%|█████████████████████████████████████████████████████████████████

In [6]:
layer_hs_array = np.array(hidden_states_by_layer["layer_5"])
layer_hs_array.shape

(10000, 267264)

## Layer 5 Clustering

In [7]:
# Use original vectors for clustering - uncomment next line and comment out last two lines

# dim_reduced_vecs = layer_hs_array

random_projector = GaussianRandomProjection(random_state = 42)
dim_reduced_vecs = random_projector.fit_transform(layer_hs_array).astype('float32')

In [8]:
dim_reduced_vecs = np.array([v / np.linalg.norm(v) for v in dim_reduced_vecs])
dim_reduced_vecs.shape

(10000, 7894)

In [9]:
# K-means Clustering

ncentroids = NUM_PROMPTS
niter = 20
verbose = True
dim = dim_reduced_vecs.shape[1]
kmeans = faiss.Kmeans(dim, ncentroids, niter = niter, verbose = verbose, gpu = True, nredo = 10, spherical = True, max_points_per_centroid = 1000)
kmeans.train(dim_reduced_vecs)

Clustering 10000 points in 7894D to 10 clusters, redo 10 times, 20 iterations
  Preprocessing in 0.04 s
Outer iteration 0 / 10
  Iteration 19 (0.76 s, search 0.53 s): objective=5425.93 imbalance=1.613 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 10
  Iteration 19 (1.52 s, search 1.07 s): objective=5432.01 imbalance=2.115 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 10
  Iteration 19 (2.28 s, search 1.60 s): objective=5448.12 imbalance=1.334 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 10
  Iteration 19 (3.05 s, search 2.13 s): objective=5442.44 imbalance=1.288 nsplit=0       
Outer iteration 4 / 10
  Iteration 19 (3.81 s, search 2.67 s): objective=5438.22 imbalance=1.187 nsplit=0       
Outer iteration 5 / 10
  Iteration 19 (4.57 s, search 3.20 s): objective=5434.79 imbalance=1.577 nsplit=0       
Outer iteration 6 / 10
  Iteration 19 (5.33 s, search 3.74 s): objective=5425.1 imbalance=1.655 nsplit=0

5455.923828125

In [10]:
kmeans.centroids #cluster centers

array([[ 0.0072792 , -0.01527321,  0.00621297, ..., -0.01300611,
         0.00983786,  0.0020636 ],
       [ 0.00692313, -0.01829372,  0.01082267, ..., -0.00569263,
         0.00908529,  0.00967556],
       [ 0.01100189, -0.01232994,  0.00653843, ..., -0.00977939,
         0.00474763,  0.00390089],
       ...,
       [ 0.00940989, -0.01930307,  0.00886085, ..., -0.01090563,
         0.00713332,  0.00276421],
       [ 0.00265495, -0.01216395,  0.01212723, ..., -0.00677637,
         0.00801237, -0.00356999],
       [ 0.00792372, -0.01597789,  0.01099395, ..., -0.00616443,
         0.00649714,  0.00765719]], dtype=float32)

In [11]:
for centroid in kmeans.centroids:
    print(np.linalg.norm(centroid))

0.9999998
0.99999994
0.99999976
1.0
0.99999994
0.99999994
0.99999964
0.9999999
1.0
1.0000005


In [12]:
kmeans.obj #inertia at each iteration

array([3467.33520508, 5262.19628906, 5370.63134766, 5409.11425781,
       5416.80126953, 5419.55810547, 5421.05859375, 5422.08935547,
       5422.83642578, 5423.35205078, 5423.86083984, 5424.14990234,
       5424.35253906, 5424.55957031, 5424.75146484, 5424.99951172,
       5425.27783203, 5425.53613281, 5425.71972656, 5425.93261719,
       3626.39868164, 5246.43798828, 5341.88916016, 5392.92724609,
       5411.24072266, 5416.234375  , 5419.27636719, 5423.18408203,
       5427.38476562, 5429.58154297, 5430.69238281, 5431.31933594,
       5431.54541016, 5431.65136719, 5431.73193359, 5431.82666016,
       5431.88427734, 5431.93994141, 5431.99023438, 5432.00830078,
       3417.40161133, 5269.80761719, 5376.26904297, 5409.54541016,
       5422.02783203, 5430.91210938, 5436.97802734, 5439.3046875 ,
       5440.41455078, 5441.25292969, 5441.89111328, 5442.44238281,
       5443.10302734, 5443.91845703, 5445.03271484, 5445.68603516,
       5446.07617188, 5446.54736328, 5447.43164062, 5448.12060

In [13]:
pd.Series(kmeans.index.search(dim_reduced_vecs.astype(np.float32), 1)[1].flatten()).value_counts()

1    2219
6    1590
7    1552
4    1070
3     916
0     735
2     547
5     539
9     416
8     416
dtype: int64

In [14]:
normalized_vecs = [v / np.linalg.norm(v) for v in dim_reduced_vecs]

In [15]:
cos_similarities = normalized_vecs @ kmeans.centroids.T
classifications = np.argmax(cos_similarities, axis=1)

In [16]:
pd.Series(classifications).value_counts()

1    2219
6    1590
7    1552
4    1070
3     916
0     735
2     547
5     539
9     416
8     416
dtype: int64