## Install Packages

In [1]:
# Install the GPU build of FAISS from the conda-forge channel (-y: no prompt).
# NOTE(review): no version is pinned, so a fresh environment may resolve a different build.
!conda install -y -c conda-forge faiss-gpu
# Refresh the apt package index, then install the ATLAS development package.
# NOTE(review): presumably a BLAS/LAPACK dependency for FAISS — confirm it is still required.
!apt-get -y update
!apt-get -y install libatlas-base-dev

done
Solving environment: done


  current version: 4.14.0
  latest version: 25.7.0

Please update conda by running

    $ conda update -n base -c conda-forge conda



# All requested packages already installed.

Retrieving notices: ...working... done
Hit:1 http://security.ubuntu.com/ubuntu bionic-security InRelease
Hit:2 http://archive.ubuntu.com/ubuntu bionic InRelease
Get:3 http://archive.ubuntu.com/ubuntu bionic-updates InRelease [102 kB]
Get:4 http://archive.ubuntu.com/ubuntu bionic-backports InRelease [102 kB]
Fetched 204 kB in 1s (370 kB/s)                            
Reading package lists... Done
Reading package lists... Done
Building dependency tree       
Reading state information... Done
libatlas-base-dev is already the newest version (3.10.3-5).
0 upgraded, 0 newly installed, 0 to remove and 83 not upgraded.


## Load Data

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.random_projection import GaussianRandomProjection

from tqdm import tqdm

import faiss

In [3]:
# Load the story table: one row per generated story, with its prompt id, the
# story text, the .npz file holding its hidden states, and story lengths.
# NOTE(review): the `Unnamed: 0` column in the output looks like a row index that
# was written into the CSV; `index_col=0` would absorb it — confirm before changing.
df = pd.read_csv("story_dataset.csv")
df

Unnamed: 0,prompt_id,prompt,story,hidden_state_file,len_generated_story,len_new_story
0,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Blaz...,./hidden_states/prompt_1.npz,270,271
1,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Spar...,./hidden_states/prompt_1.npz,349,350
2,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Scor...,./hidden_states/prompt_1.npz,278,278
3,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,117,118
4,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,129,130
...,...,...,...,...,...,...
9995,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,289,290
9996,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,119,119
9997,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,127,128
9998,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,441,441


In [4]:
# Longest generated story; every hidden-state matrix is zero-padded to this
# many rows further down so all stories share one vector length.
max_story_len = int(df["len_generated_story"].max())
max_story_len

522

In [5]:
# Collect, per layer, one flat float32 vector per story.
#
# For every prompt archive `./hidden_states/prompt_<id>.npz` the arrays
# `arr_0 .. arr_{STORIES_PER_PROMPT - 1}` are read in order, the selected layer's
# hidden-state rows are copied into a zero matrix of shape
# (max_story_len, HIDDEN_SIZE), and the matrix is flattened. Vectors are appended
# in prompt order, then story order.
hidden_states_by_layer = {}
NUM_PROMPTS = 10
STORIES_PER_PROMPT = 1000  # number of arr_<i> entries read from each archive
HIDDEN_SIZE = 512          # width of one hidden-state row
LAYERS = range(3, 4)       # only layer 3 is extracted

for prompt_id in range(1, NUM_PROMPTS + 1):
    with np.load(f'./hidden_states/prompt_{prompt_id}.npz') as loaded_data:
        for i in tqdm(range(STORIES_PER_PROMPT)):
            curr_hidden_states = loaded_data[f"arr_{i}"][0]
            for layer in LAYERS:
                layer_states = curr_hidden_states[layer][0]

                # Allocate directly as float32 (the dtype FAISS expects) rather than
                # building a float64 buffer and converting it afterwards.
                padded_arr = np.zeros((max_story_len, HIDDEN_SIZE), dtype='float32')
                # Rows beyond the story's own length stay zero (padding).
                padded_arr[:len(layer_states)] = layer_states

                hidden_states_by_layer.setdefault(f"layer_{layer}", []).append(padded_arr.ravel())

100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:50<00:00,  4.34it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:33<00:00,  4.69it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:58<00:00,  5.59it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:23<00:00,  4.92it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:33<00:00,  4.69it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:26<00:00,  4.85it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:26<00:00,  4.85it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:21<00:00,  4.97it/s]
100%|███████████████████████████

In [6]:
# Stack the per-story layer-3 vectors into one (n_stories, max_story_len * 512) matrix.
layer_3_vectors = hidden_states_by_layer["layer_3"]
layer_hs_array = np.stack(layer_3_vectors, axis=0)
layer_hs_array.shape

(10000, 267264)

## Layer 3 Clustering

In [7]:
# Cluster on the original (un-projected) hidden-state vectors.
# To cluster on a Gaussian random projection instead, comment out the assignment
# below and uncomment the two `random_projector` lines at the end of this cell.

dim_reduced_vecs = layer_hs_array

# random_projector = GaussianRandomProjection(random_state = 42)
# dim_reduced_vecs = random_projector.fit_transform(layer_hs_array).astype('float32')

In [8]:
# L2-normalise every row so that inner products between rows are cosine similarities.
# Done with one vectorised division instead of a per-row Python loop, which built a
# list of 10,000 arrays and then copied them into a new matrix.
row_norms = np.linalg.norm(dim_reduced_vecs, axis=1, keepdims=True)
# An all-zero row would otherwise divide by zero and put NaN/inf into the clustering
# input; dividing by 1 leaves such a row as zeros.
row_norms[row_norms == 0] = 1.0
# Out-of-place division: `layer_hs_array` (which may alias this name) is left untouched.
dim_reduced_vecs = dim_reduced_vecs / row_norms
dim_reduced_vecs.shape

(10000, 267264)

In [9]:
# K-means Clustering

# Spherical k-means on the unit-normalised vectors, one centroid per prompt.
# Training is restarted `nredo` times and the run with the best objective is kept
# (see the "Objective improved: keep new clusters" lines in the log).
ncentroids = NUM_PROMPTS
niter = 20
verbose = True
dim = dim_reduced_vecs.shape[1]

kmeans_options = dict(
    niter=niter,
    verbose=verbose,
    gpu=True,
    nredo=10,
    spherical=True,
    max_points_per_centroid=1000,
)
kmeans = faiss.Kmeans(dim, ncentroids, **kmeans_options)
kmeans.train(dim_reduced_vecs)

Clustering 10000 points in 267264D to 10 clusters, redo 10 times, 20 iterations
  Preprocessing in 1.60 s
Outer iteration 0 / 10
  Iteration 19 (26.17 s, search 17.92 s): objective=4906.07 imbalance=1.130 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 10
  Iteration 19 (52.03 s, search 35.74 s): objective=4890.32 imbalance=1.507 nsplit=0       
Outer iteration 2 / 10
  Iteration 19 (77.92 s, search 53.57 s): objective=4898.44 imbalance=1.316 nsplit=0       
Outer iteration 3 / 10
  Iteration 19 (103.91 s, search 71.41 s): objective=4912.11 imbalance=1.036 nsplit=0       
Objective improved: keep new clusters
Outer iteration 4 / 10
  Iteration 19 (129.91 s, search 89.23 s): objective=4908.69 imbalance=1.103 nsplit=0       
Outer iteration 5 / 10
  Iteration 19 (155.77 s, search 107.05 s): objective=4896.37 imbalance=1.353 nsplit=0       
Outer iteration 6 / 10
  Iteration 19 (181.65 s, search 124.91 s): objective=4892.11 imbalance=1.540 nsplit=0       
Outer i

4916.31005859375

In [10]:
kmeans.centroids  # cluster centres, one row per centroid; trailing entries fall in the zero-padded region of the vectors

array([[ 3.7964880e-03,  1.3179081e-03,  9.7678797e-03, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 4.0940675e-03,  1.4212129e-03,  1.0533521e-02, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 3.8711505e-03,  1.3438276e-03,  9.9599911e-03, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       ...,
       [ 3.9874315e-03,  1.3841932e-03,  1.0259150e-02, ...,
        -1.9801594e-06,  6.2586798e-07,  6.5047693e-06],
       [ 4.4617304e-03,  1.5488414e-03,  1.1479470e-02, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 4.5087570e-03,  1.5651671e-03,  1.1600457e-02, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00]], dtype=float32)

In [11]:
# Sanity check: print the L2 norm of each centroid (spherical k-means should
# leave them at roughly unit length).
for centroid in kmeans.centroids:
    centroid_norm = np.linalg.norm(centroid)
    print(centroid_norm)

1.0000291
1.0000548
1.0000271
1.0000266
1.000009
1.0000542
1.0000498
1.0000446
1.0000128
1.000009


In [12]:
kmeans.obj  # objective after every iteration, concatenated over all restarts (niter values per restart); here it increases as training improves, so it is a similarity total rather than an inertia

array([3043.52978516, 4738.26318359, 4812.96191406, 4852.26904297,
       4873.41015625, 4885.33496094, 4892.20166016, 4895.76025391,
       4898.19921875, 4899.66503906, 4900.73583984, 4901.5234375 ,
       4902.17138672, 4902.73291016, 4903.20605469, 4903.73876953,
       4904.28955078, 4905.01513672, 4905.82373047, 4906.07275391,
       3202.37329102, 4740.50878906, 4783.265625  , 4806.73144531,
       4830.73193359, 4854.86572266, 4867.85986328, 4874.83837891,
       4878.26904297, 4881.65673828, 4885.52246094, 4888.16162109,
       4889.53027344, 4890.01904297, 4890.18408203, 4890.24658203,
       4890.28417969, 4890.30273438, 4890.31689453, 4890.32324219,
       2985.88671875, 4734.44189453, 4809.625     , 4847.25927734,
       4869.97265625, 4883.16015625, 4887.98046875, 4891.42041016,
       4894.27001953, 4896.10498047, 4897.22363281, 4897.81005859,
       4898.07421875, 4898.19335938, 4898.28466797, 4898.3125    ,
       4898.36669922, 4898.39599609, 4898.42480469, 4898.44238

In [13]:
# Unit-normalise the rows for the cosine-similarity check in the next cell.
# `dim_reduced_vecs` was already normalised earlier in the notebook, so this is
# effectively a no-op kept as a safeguard. It is vectorised and yields an ndarray
# rather than a Python list of 10,000 row arrays.
vec_norms = np.linalg.norm(dim_reduced_vecs, axis=1, keepdims=True)
vec_norms[vec_norms == 0] = 1.0  # keep all-zero rows as zeros instead of NaN/inf
normalized_vecs = dim_reduced_vecs / vec_norms

In [14]:
# Similarity of every story vector to every centroid (rows: stories, columns:
# centroids); each story is assigned to its most similar centroid.
cos_similarities = np.matmul(normalized_vecs, kmeans.centroids.T)
classifications = cos_similarities.argmax(axis=1)

In [15]:
# Number of stories assigned to each centroid, largest cluster first.
pd.Series(classifications).value_counts()

9    1563
1    1277
2    1199
3    1150
8    1088
4     928
6     870
5     762
0     694
7     469
dtype: int64

In [16]:
# Cross-check: ask the trained FAISS index for each vector's single nearest
# centroid and count cluster sizes; these should match the manual argmax above.
# copy=False avoids duplicating the whole matrix when it is already float32.
search_vectors = dim_reduced_vecs.astype(np.float32, copy=False)
nearest_centroid_ids = kmeans.index.search(search_vectors, 1)[1]  # [1] = ids, [0] = scores
pd.Series(nearest_centroid_ids.flatten()).value_counts()

9    1563
1    1277
2    1199
3    1150
8    1088
4     928
6     870
5     762
0     694
7     469
dtype: int64

In [17]:
# Ground-truth prompt id of every story, as a NumPy array in dataframe row order.
prompt_ids = df["prompt_id"].to_numpy()
prompt_ids

array([ 1,  1,  1, ..., 10, 10, 10])

In [18]:
# Get most common centroid for each 1000 points (same label)
max_centroid_per_label = [pd.Series(classifications[i * 1000:(i + 1) * 1000]).value_counts().idxmax() for i in range(10)]
max_centroid_per_label

[9, 4, 7, 8, 3, 8, 8, 3, 9, 6]

In [19]:
# Get most common label for each point classified to a centroid (same centroid)
centroid_labels = [np.where(classifications == i)[0] for i in range(10)]
max_label_per_centroid = [pd.Series(prompt_ids[centroid_labels[i]]).value_counts().idxmax() for i in range(10)]
max_label_per_centroid

[9, 6, 7, 5, 2, 5, 10, 3, 6, 1]

In [20]:
# Repeat each prompt's dominant centroid 1000 times, giving one entry per story
# in the same block order as the data.
max_centroids = []
for centroid in max_centroid_per_label:
    max_centroids.extend([centroid] * 1000)

# Repeat each centroid's dominant label 1000 times. This list is ordered by
# centroid id, not by story, so entry i does not describe story i.
max_labels = []
for label in max_label_per_centroid:
    max_labels.extend([label] * 1000)

In [21]:
# Preview the per-story dominant-centroid list as an array.
np.array(max_centroids)

array([9, 9, 9, ..., 6, 6, 6])

In [22]:
# Preview the expanded per-centroid dominant-label list as an array.
np.array(max_labels)

array([9, 9, 9, ..., 1, 1, 1])

In [23]:
# Prompt label (1-based) -> centroid most of that prompt's stories landed in.
label_to_centroid = {
    label: centroid
    for label, centroid in enumerate(max_centroid_per_label, start=1)
}

# Centroid id (0-based) -> most common prompt label among its assigned stories.
centroid_to_label = dict(enumerate(max_label_per_centroid))

In [24]:
# Display the label -> dominant centroid mapping; several labels can share one centroid.
label_to_centroid

{1: 9, 2: 4, 3: 7, 4: 8, 5: 3, 6: 8, 7: 8, 8: 3, 9: 9, 10: 6}

In [25]:
# Display the centroid -> majority label mapping; a label may appear for several centroids or for none.
centroid_to_label

{0: 9, 1: 6, 2: 7, 3: 5, 4: 2, 5: 5, 6: 10, 7: 3, 8: 6, 9: 1}

In [26]:
# Translate each story's assigned centroid id into that centroid's majority
# prompt label, element-wise via the centroid_to_label dictionary.
vectorized_map = np.vectorize(centroid_to_label.get)
classifications_to_label = vectorized_map(classifications)

classifications_to_label

array([ 6,  9,  9, ..., 10, 10, 10])

In [27]:
# Per-prompt accuracy: within each block of 1000 stories, the fraction whose
# mapped label equals that block's prompt id (block k holds prompt k + 1).
for block_idx in range(10):
    prompt_number = block_idx + 1
    block_predictions = classifications_to_label[block_idx * 1000:prompt_number * 1000]
    block_hits = block_predictions == prompt_number
    print(f"Prompt {prompt_number} Accuracy: ", np.mean(block_hits))

Prompt 1 Accuracy:  0.285
Prompt 2 Accuracy:  0.257
Prompt 3 Accuracy:  0.398
Prompt 4 Accuracy:  0.0
Prompt 5 Accuracy:  0.723
Prompt 6 Accuracy:  0.502
Prompt 7 Accuracy:  0.247
Prompt 8 Accuracy:  0.0
Prompt 9 Accuracy:  0.246
Prompt 10 Accuracy:  0.863


In [28]:
# Overall accuracy: fraction of stories whose mapped label equals the true prompt id.
# The shapes are checked first because, on older NumPy versions, comparing arrays of
# different lengths yields a single False and the "accuracy" silently becomes 0.0.
assert classifications_to_label.shape == prompt_ids.shape, "predictions and prompt ids must align row-for-row"
print("Overall Accuracy: ", np.mean(classifications_to_label == prompt_ids))

Overall Accuracy:  0.3521
