## Install Packages

In [1]:
# Environment setup (IPython shell commands; only needed once per machine/container).
# faiss-gpu supplies the `faiss` module imported below, including the GPU k-means (gpu=True).
# libatlas-base-dev (ATLAS BLAS) is presumably a system dependency of faiss — NOTE(review): confirm still required.
# NOTE(review): package versions are unpinned, so re-runs may pick up different faiss builds.
!conda install -y -c conda-forge faiss-gpu
!apt-get -y update
!apt-get -y install libatlas-base-dev

done
Solving environment: done


  current version: 4.14.0
  latest version: 25.5.1

Please update conda by running

    $ conda update -n base -c conda-forge conda



# All requested packages already installed.

Retrieving notices: ...working... done
Hit:1 http://security.ubuntu.com/ubuntu bionic-security InRelease
Hit:2 http://archive.ubuntu.com/ubuntu bionic InRelease
Hit:3 http://archive.ubuntu.com/ubuntu bionic-updates InRelease
Hit:4 http://archive.ubuntu.com/ubuntu bionic-backports InRelease
Reading package lists... Done                     
Reading package lists... Done
Building dependency tree       
Reading state information... Done
libatlas-base-dev is already the newest version (3.10.3-5).
0 upgraded, 0 newly installed, 0 to remove and 83 not upgraded.


## Load Data

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.random_projection import GaussianRandomProjection

from tqdm import tqdm

import faiss

In [3]:
# One row per generated story. The CSV was written together with its index, which would
# otherwise reappear as a spurious "Unnamed: 0" column, so read it back as the index.
df = pd.read_csv("story_dataset.csv", index_col=0)
df

Unnamed: 0,prompt_id,prompt,story,hidden_state_file,len_generated_story,len_new_story
0,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Blaz...,./hidden_states/prompt_1.npz,270,271
1,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Spar...,./hidden_states/prompt_1.npz,349,350
2,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Scor...,./hidden_states/prompt_1.npz,278,278
3,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,117,118
4,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,129,130
...,...,...,...,...,...,...
9995,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,289,290
9996,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,119,119
9997,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,127,128
9998,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,441,441


In [4]:
# Longest generated story (presumably in tokens); every story's hidden states are
# zero-padded to this many positions in the loading cell below.
max_story_len = int(df["len_generated_story"].max())
max_story_len

522

In [5]:
# Load the selected layers' hidden states for every story and flatten each one into a
# fixed-length float32 vector (zero-padded to `max_story_len` positions), so that all
# stories can later be stacked into a single matrix for FAISS.
hidden_states_by_layer = {}
NUM_PROMPTS = 10
STORIES_PER_PROMPT = 1000  # arrays per .npz file, stored under the keys arr_0 ... arr_999
HIDDEN_DIM = 512           # width of each per-position hidden-state vector
LAYERS = range(8, 9)       # only layer 8 is clustered below

for prompt_id in range(1, NUM_PROMPTS + 1):
    with np.load(f'./hidden_states/prompt_{prompt_id}.npz') as loaded_data:
        for i in tqdm(range(STORIES_PER_PROMPT), desc=f"prompt {prompt_id}"):
            curr_hidden_states = loaded_data[f"arr_{i}"][0]
            for layer in LAYERS:
                # assumed (seq_len, HIDDEN_DIM) for this layer — TODO confirm against the generator
                layer_states = curr_hidden_states[layer][0]
                # Allocate float32 directly (FAISS expects float32) rather than float64 + astype:
                # halves the per-story allocation and skips a full copy.
                padded_arr = np.zeros((max_story_len, HIDDEN_DIM), dtype=np.float32)
                padded_arr[:len(layer_states)] = layer_states
                hidden_states_by_layer.setdefault(f"layer_{layer}", []).append(padded_arr.ravel())

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:41<00:00,  9.88it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:34<00:00, 10.63it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:46<00:00,  9.42it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:51<00:00,  8.95it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:44<00:00,  9.59it/s]
100%|█████████████████████████████████████████████████████████████████

In [6]:
# Stack the per-story layer-8 vectors into one matrix: one row per story,
# row length = max_story_len * hidden size. np.stack raises if any row has a
# different length, whereas np.array could build an object array on older NumPy.
layer_hs_array = np.stack(hidden_states_by_layer["layer_8"])
layer_hs_array.shape

(10000, 267264)

## Layer 8 Clustering

In [7]:
# Reduce dimensionality with a Gaussian random projection before clustering.
# Set USE_RANDOM_PROJECTION = False to cluster the original (unprojected) vectors instead.
USE_RANDOM_PROJECTION = True
RANDOM_SEED = 42

if USE_RANDOM_PROJECTION:
    # n_components is left at its default ('auto'), which produced 7894 dimensions here.
    random_projector = GaussianRandomProjection(random_state=RANDOM_SEED)
    dim_reduced_vecs = random_projector.fit_transform(layer_hs_array).astype('float32')
else:
    dim_reduced_vecs = layer_hs_array

In [8]:
# L2-normalise every row so inner products equal cosine similarities (used by the
# spherical k-means below). Vectorised instead of a per-row Python loop + re-array.
row_norms = np.linalg.norm(dim_reduced_vecs, axis=1, keepdims=True)
assert np.all(row_norms > 0), "zero vector cannot be normalised"
dim_reduced_vecs = dim_reduced_vecs / row_norms
dim_reduced_vecs.shape

(10000, 7894)

In [9]:
# K-means Clustering

# Spherical k-means with one centroid per prompt: with spherical=True the centroids
# stay unit-norm, so each vector is matched by inner product (= cosine similarity).
ncentroids = NUM_PROMPTS
niter = 20
nredo = 10  # independent restarts; faiss keeps the run with the best objective
verbose = True
dim = dim_reduced_vecs.shape[1]
kmeans = faiss.Kmeans(
    dim,
    ncentroids,
    niter=niter,
    verbose=verbose,
    gpu=True,
    nredo=nredo,
    spherical=True,
    # faiss subsamples when n > k * max_points_per_centroid; allowing up to all
    # points per centroid guarantees every vector is used for any k.
    max_points_per_centroid=len(dim_reduced_vecs),
    seed=1234,  # faiss's default seed, made explicit for reproducibility
)
kmeans.train(dim_reduced_vecs)

Clustering 10000 points in 7894D to 10 clusters, redo 10 times, 20 iterations
  Preprocessing in 0.07 s
Outer iteration 0 / 10
  Iteration 19 (1.99 s, search 1.23 s): objective=6100.76 imbalance=1.217 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 10
  Iteration 19 (3.96 s, search 2.46 s): objective=6100.93 imbalance=1.319 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 10
  Iteration 19 (5.96 s, search 3.69 s): objective=6102.28 imbalance=1.060 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 10
  Iteration 19 (7.95 s, search 4.93 s): objective=6100.56 imbalance=1.066 nsplit=0       
Outer iteration 4 / 10
  Iteration 19 (9.95 s, search 6.16 s): objective=6078.47 imbalance=1.265 nsplit=0       
Outer iteration 5 / 10
  Iteration 19 (11.91 s, search 7.39 s): objective=6100.17 imbalance=1.345 nsplit=0       
Outer iteration 6 / 10
  Iteration 19 (13.92 s, search 8.64 s): objective=6102.78 imbalance=1.085 nspli

6102.7841796875

In [10]:
kmeans.centroids  # trained cluster centres: one float32 row per centroid

array([[-0.01518083, -0.0015754 , -0.0008974 , ..., -0.01520034,
        -0.00685114,  0.00065254],
       [-0.00106073, -0.00399352,  0.02793805, ..., -0.0022775 ,
        -0.02300845, -0.00254802],
       [-0.02250788, -0.00108937, -0.00447611, ..., -0.01290773,
        -0.0060121 ,  0.00877293],
       ...,
       [-0.01440714, -0.00092288,  0.01395446, ..., -0.01005596,
        -0.01149299,  0.0005758 ],
       [-0.008718  , -0.00099671,  0.01780267, ..., -0.01418401,
        -0.01185136,  0.00403163],
       [-0.00639699,  0.01060635,  0.03331428, ..., -0.0016776 ,
        -0.01840435, -0.00039815]], dtype=float32)

In [11]:
# Spherical k-means keeps every centroid on the unit sphere; verify and display the norms.
centroid_norms = np.linalg.norm(kmeans.centroids, axis=1)
np.testing.assert_allclose(centroid_norms, 1.0, rtol=1e-5)
centroid_norms

1.0000001
0.99999964
1.0
0.99999976
1.0
0.99999994
1.0000004
1.0000001
0.9999998
1.0


In [12]:
# Per-iteration k-means objective for every restart, concatenated by faiss; reshape to
# one row per restart (niter columns). The objective increases across iterations here,
# so higher is better — it is a summed similarity, not an inertia.
kmeans.obj.reshape(-1, niter)

array([3705.50390625, 6017.96679688, 6072.06103516, 6078.33251953,
       6081.09619141, 6082.72021484, 6084.28125   , 6086.15527344,
       6088.49609375, 6091.51806641, 6093.48632812, 6095.06298828,
       6098.69384766, 6100.04443359, 6100.26025391, 6100.39501953,
       6100.50097656, 6100.62646484, 6100.69482422, 6100.76220703,
       3632.60791016, 5975.41943359, 6034.75976562, 6065.55908203,
       6087.60595703, 6091.69042969, 6094.29931641, 6096.44482422,
       6097.93505859, 6098.67041016, 6099.09326172, 6099.4140625 ,
       6099.66455078, 6099.93652344, 6100.19677734, 6100.45849609,
       6100.65283203, 6100.78466797, 6100.87353516, 6100.93164062,
       3772.68652344, 6044.81787109, 6081.34082031, 6088.53662109,
       6091.44873047, 6092.56396484, 6092.96875   , 6093.18017578,
       6093.30029297, 6093.43017578, 6093.5859375 , 6093.75244141,
       6094.12548828, 6095.00048828, 6097.91894531, 6099.62890625,
       6100.44970703, 6101.21484375, 6101.80224609, 6102.28320

In [13]:
# Nearest-centroid assignment via the trained faiss index; search returns (distances, ids).
# ascontiguousarray only copies when the input is not already C-contiguous float32,
# unlike astype, which always made a full copy of the matrix.
nearest_ids = kmeans.index.search(np.ascontiguousarray(dim_reduced_vecs, dtype=np.float32), 1)[1]
pd.Series(nearest_ids.ravel()).value_counts()

3    1519
4    1287
5    1208
8    1134
0    1091
2    1024
7     922
1     620
6     612
9     583
dtype: int64

In [14]:
# Cosine similarity needs unit vectors: normalise row-wise (a no-op if the normalisation
# cell In [8] has already run) and keep the result as a 2-D array rather than a list.
normalized_vecs = dim_reduced_vecs / np.linalg.norm(dim_reduced_vecs, axis=1, keepdims=True)

In [15]:
# Score every vector against every centroid (dot products of unit vectors = cosine
# similarity), then label each vector with its best-scoring centroid.
cos_similarities = np.matmul(normalized_vecs, kmeans.centroids.T)
classifications = cos_similarities.argmax(axis=1)

In [16]:
# Cluster sizes from the explicit cosine-similarity argmax. These counts match the faiss
# index-search counts from In [13], as expected when both pick the nearest unit-norm centroid.
pd.Series(classifications).value_counts()

3    1519
4    1287
5    1208
8    1134
0    1091
2    1024
7     922
1     620
6     612
9     583
dtype: int64