## Install Packages

In [1]:
!conda install -y -c conda-forge faiss-gpu
!apt-get -y update
!apt-get -y install libatlas-base-dev

done
Solving environment: done


  current version: 4.14.0
  latest version: 25.5.1

Please update conda by running

    $ conda update -n base -c conda-forge conda



# All requested packages already installed.

Retrieving notices: ...working... done
Hit:1 http://archive.ubuntu.com/ubuntu bionic InRelease
Hit:2 http://archive.ubuntu.com/ubuntu bionic-updates InRelease                
Hit:3 http://archive.ubuntu.com/ubuntu bionic-backports InRelease              
Hit:4 http://security.ubuntu.com/ubuntu bionic-security InRelease              
Reading package lists... Done
Reading package lists... Done
Building dependency tree       
Reading state information... Done
libatlas-base-dev is already the newest version (3.10.3-5).
0 upgraded, 0 newly installed, 0 to remove and 83 not upgraded.


## Load Data

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.random_projection import GaussianRandomProjection

from tqdm import tqdm

import faiss

In [3]:
df = pd.read_csv("story_dataset.csv")
df

Unnamed: 0,prompt_id,prompt,story,hidden_state_file,len_generated_story,len_new_story
0,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Blaz...,./hidden_states/prompt_1.npz,270,271
1,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Spar...,./hidden_states/prompt_1.npz,349,350
2,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Scor...,./hidden_states/prompt_1.npz,278,278
3,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,117,118
4,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,129,130
...,...,...,...,...,...,...
9995,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,289,290
9996,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,119,119
9997,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,127,128
9998,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,441,441


In [4]:
max_story_len = max(df["len_generated_story"])
max_story_len

522

In [5]:
hidden_states_by_layer = {}
NUM_PROMPTS = 10

for prompt_id in range(1, NUM_PROMPTS + 1):
    with np.load(f'./hidden_states/prompt_{prompt_id}.npz') as loaded_data:
        for i in tqdm(range(1000)):
            curr_hidden_states = loaded_data[f"arr_{i}"][0]
#             print(curr_hidden_states.shape)
            for layer in range(6, 7):
                padded_arr = np.zeros((max_story_len, 512))
                padded_arr_len = len(curr_hidden_states[layer][0])
                
                padded_arr[:padded_arr_len] = curr_hidden_states[layer][0]
                
                padded_arr = padded_arr.flatten().astype('float32') #FAISS expects data in type float32 instead of float64 - saves memory too!
#                 print(padded_arr.shape)
                
                if(f"layer_{layer}" in hidden_states_by_layer):
                    hidden_states_by_layer[f"layer_{layer}"].append(padded_arr)
                else:
                    hidden_states_by_layer[f"layer_{layer}"] = [padded_arr]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:26<00:00, 11.53it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:18<00:00, 12.74it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:17<00:00, 12.97it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:15<00:00, 13.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:12<00:00, 13.78it/s]
100%|█████████████████████████████████████████████████████████████████

In [6]:
layer_hs_array = np.array(hidden_states_by_layer["layer_6"])
layer_hs_array.shape

(10000, 267264)

## Layer 6 Clustering

In [7]:
# Use original vectors for clustering - uncomment next line and comment out last two lines

# dim_reduced_vecs = layer_hs_array

random_projector = GaussianRandomProjection(random_state = 42)
dim_reduced_vecs = random_projector.fit_transform(layer_hs_array).astype('float32')

In [8]:
dim_reduced_vecs = np.array([v / np.linalg.norm(v) for v in dim_reduced_vecs])
dim_reduced_vecs.shape

(10000, 7894)

In [9]:
# K-means Clustering

ncentroids = NUM_PROMPTS
niter = 20
verbose = True
dim = dim_reduced_vecs.shape[1]
kmeans = faiss.Kmeans(dim, ncentroids, niter = niter, verbose = verbose, gpu = True, nredo = 10, spherical = True)
kmeans.train(dim_reduced_vecs)

Sampling a subset of 2560 / 10000 for training
Clustering 2560 points in 7894D to 10 clusters, redo 10 times, 20 iterations
  Preprocessing in 0.07 s
Outer iteration 0 / 10
  Iteration 19 (0.15 s, search 0.08 s): objective=754.796 imbalance=1.178 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 10
  Iteration 19 (0.31 s, search 0.17 s): objective=769.023 imbalance=1.422 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 10
  Iteration 19 (0.46 s, search 0.25 s): objective=751.45 imbalance=1.514 nsplit=0        
Outer iteration 3 / 10
  Iteration 19 (0.61 s, search 0.33 s): objective=766.077 imbalance=1.265 nsplit=0       
Outer iteration 4 / 10
  Iteration 19 (0.77 s, search 0.41 s): objective=752.943 imbalance=1.446 nsplit=0       
Outer iteration 5 / 10
  Iteration 19 (0.92 s, search 0.49 s): objective=764.044 imbalance=1.383 nsplit=0       
Outer iteration 6 / 10
  Iteration 19 (1.07 s, search 0.57 s): objective=751.41 imbalance=1.628 

769.0228881835938

In [10]:
kmeans.centroids #cluster centers

array([[ 0.01673672,  0.00401446,  0.01086004, ..., -0.00742791,
        -0.00702039,  0.00149118],
       [ 0.02889898,  0.02086297,  0.00117888, ..., -0.00232237,
        -0.00899892, -0.01132663],
       [ 0.00957622,  0.01992157, -0.00987406, ...,  0.00061228,
        -0.00289478, -0.00183564],
       ...,
       [ 0.00278024,  0.01324403,  0.00109854, ..., -0.01671038,
        -0.01153946, -0.00393818],
       [ 0.01942942,  0.00382576,  0.00502452, ..., -0.0129452 ,
        -0.00836569, -0.01600607],
       [ 0.02642076,  0.00428157,  0.0057397 , ..., -0.01018341,
        -0.01322437, -0.01052065]], dtype=float32)

In [11]:
for centroid in kmeans.centroids:
    print(np.linalg.norm(centroid))

0.9999997
1.0
0.9999998
1.0000002
1.0000001
1.0000001
1.0000001
1.0000005
0.99999976
1.0


In [12]:
kmeans.obj #inertia at each iteration

array([276.28485107, 715.73626709, 738.60113525, 744.57910156,
       746.39331055, 748.50994873, 750.86102295, 752.63293457,
       753.18487549, 753.40435791, 753.6048584 , 753.89715576,
       754.50030518, 754.69390869, 754.76251221, 754.78485107,
       754.7958374 , 754.7958374 , 754.7958374 , 754.7958374 ,
       267.88897705, 724.83129883, 750.06140137, 758.92926025,
       764.2376709 , 766.54315186, 768.05639648, 768.68273926,
       768.89825439, 768.99328613, 769.02288818, 769.02288818,
       769.02288818, 769.02288818, 769.02288818, 769.02288818,
       769.02288818, 769.02288818, 769.02288818, 769.02288818])

In [13]:
pd.Series(kmeans.index.search(dim_reduced_vecs.astype(np.float32), 1)[1].flatten()).value_counts()

3    2734
7    1595
8    1587
4     754
5     712
6     684
0     664
2     587
1     411
9     272
dtype: int64

In [15]:
normalized_vecs = [v / np.linalg.norm(v) for v in dim_reduced_vecs]

In [16]:
cos_similarities = normalized_vecs @ kmeans.centroids.T
classifications = np.argmax(cos_similarities, axis=1)

In [17]:
pd.Series(classifications).value_counts()

3    2734
7    1595
8    1587
4     754
5     712
6     684
0     664
2     587
1     411
9     272
dtype: int64