In [1]:
!conda install -y -c conda-forge faiss-gpu
!apt-get -y update
!apt-get -y install libatlas-base-dev

done
Solving environment: done

## Package Plan ##

  environment location: /opt/conda

  added / updated specs:
    - faiss-gpu


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    conda-4.14.0               |   py37h89c1867_0        1010 KB  conda-forge
    toolz-0.12.1               |     pyhd8ed1ab_0          51 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         1.0 MB

The following NEW packages will be INSTALLED:

  toolz              conda-forge/noarch::toolz-0.12.1-pyhd8ed1ab_0

The following packages will be UPDATED:

  conda                               4.12.0-py37h89c1867_0 --> 4.14.0-py37h89c1867_0



Downloading and Extracting Packages
conda-4.14.0         | 1010 KB   | ##################################### | 100% 
toolz-0.12.1         | 51 KB     | ##################################### | 1

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.random_projection import GaussianRandomProjection

from tqdm import tqdm

import faiss

In [3]:
df = pd.read_csv("story_dataset.csv")
df

Unnamed: 0,prompt_id,prompt,story,hidden_state_file,len_generated_story,len_new_story
0,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Blaz...,./hidden_states/prompt_1.npz,270,271
1,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Spar...,./hidden_states/prompt_1.npz,349,350
2,1,Once upon a time there was a dragon,Once upon a time there was a dragon named Scor...,./hidden_states/prompt_1.npz,278,278
3,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,117,118
4,1,Once upon a time there was a dragon,Once upon a time there was a dragon. The drago...,./hidden_states/prompt_1.npz,129,130
...,...,...,...,...,...,...
9995,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,289,290
9996,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,119,119
9997,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,127,128
9998,10,Once upon a time there was a poor boy,Once upon a time there was a poor boy named Ti...,./hidden_states/prompt_10.npz,441,441


In [4]:
max_story_len = max(df["len_generated_story"])
max_story_len

522

In [5]:
hidden_states_by_layer = {}
NUM_PROMPTS = 2

# for prompt_id in range(1, 11):
for prompt_id in range(1, NUM_PROMPTS + 1):
    with np.load(f'./hidden_states/prompt_{prompt_id}.npz') as loaded_data:
        for i in tqdm(range(1000)):
            curr_hidden_states = loaded_data[f"arr_{i}"][0]
#             print(curr_hidden_states.shape)
            for layer in range(1):
                padded_arr = np.zeros((max_story_len, 512))
                padded_arr_len = len(curr_hidden_states[layer][0])
                
                padded_arr[:padded_arr_len] = curr_hidden_states[layer][0]
                
                padded_arr = padded_arr.flatten().astype('float32') #FAISS expects data in type float32 instead of float64 - saves memory too!
#                 print(padded_arr.shape)
                
                if(f"layer_{layer}" in hidden_states_by_layer):
                    hidden_states_by_layer[f"layer_{layer}"].append(padded_arr)
                else:
                    hidden_states_by_layer[f"layer_{layer}"] = [padded_arr]

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:26<00:00, 37.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:24<00:00, 40.16it/s]


In [6]:
np.array(hidden_states_by_layer["layer_0"]).shape

(2000, 267264)

In [7]:
# Layer 0 clustering

random_projector = GaussianRandomProjection(random_state = 42)
dim_reduced_vecs = random_projector.fit_transform(np.array(hidden_states_by_layer["layer_0"])).astype('float32')

In [8]:
dim_reduced_vecs.shape

(2000, 6515)

In [9]:
# K-means Clustering

ncentroids = NUM_PROMPTS
niter = 50
verbose = True
dim = dim_reduced_vecs.shape[1]
kmeans = faiss.Kmeans(dim, ncentroids, niter=niter, verbose=verbose, gpu=True, nredo = 20)
kmeans.train(dim_reduced_vecs)

Sampling a subset of 512 / 2000 for training
Clustering 512 points in 6515D to 2 clusters, redo 20 times, 50 iterations
  Preprocessing in 0.01 s
Outer iteration 0 / 20
  Iteration 0 (0.00 s, search 0.00 s): objective=232818 imbalance=1.850 nsplit=0         Iteration 1 (0.00 s, search 0.00 s): objective=167797 imbalance=1.864 nsplit=0         Iteration 2 (0.01 s, search 0.00 s): objective=167774 imbalance=1.879 nsplit=0         Iteration 3 (0.01 s, search 0.00 s): objective=167743 imbalance=1.886 nsplit=0         Iteration 4 (0.01 s, search 0.00 s): objective=167733 imbalance=1.886 nsplit=0         Iteration 5 (0.01 s, search 0.00 s): objective=167733 imbalance=1.886 nsplit=0         Iteration 6 (0.01 s, search 0.01 s): objective=167733 imbalance=1.886 nsplit=0         Iteration 7 (0.01 s, search 0.01 s): objective=167733 imbalance=1.886 nsplit=0         Iteration 8 (0.01 s, search 0.01 s): objective=167733 imbalance=1.886 nsplit=0         Iteration 9 (0.02 s, search 0.01 s): 

  Iteration 0 (0.15 s, search 0.08 s): objective=350738 imbalance=1.992 nsplit=0         Iteration 1 (0.15 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 2 (0.15 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 3 (0.15 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 4 (0.16 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 5 (0.16 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 6 (0.16 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 7 (0.16 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 8 (0.16 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 9 (0.16 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 10 (0.16 s, search 0.08 s): objective=167751 imbalance=1.992 nsplit=0         Iteration 11 (0.17 s, search 

  Iteration 0 (0.37 s, search 0.19 s): objective=351361 imbalance=1.992 nsplit=0         Iteration 1 (0.37 s, search 0.19 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 2 (0.37 s, search 0.19 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 3 (0.37 s, search 0.19 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 4 (0.37 s, search 0.19 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 5 (0.38 s, search 0.19 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 6 (0.38 s, search 0.19 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 7 (0.38 s, search 0.20 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 8 (0.38 s, search 0.20 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 9 (0.38 s, search 0.20 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 10 (0.38 s, search 0.20 s): objective=167700 imbalance=1.992 nsplit=0         Iteration 11 (0.38 s, search 

  Iteration 0 (0.58 s, search 0.30 s): objective=340176 imbalance=1.992 nsplit=0         Iteration 1 (0.59 s, search 0.30 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 2 (0.59 s, search 0.30 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 3 (0.59 s, search 0.30 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 4 (0.59 s, search 0.30 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 5 (0.59 s, search 0.31 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 6 (0.59 s, search 0.31 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 7 (0.59 s, search 0.31 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 8 (0.60 s, search 0.31 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 9 (0.60 s, search 0.31 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 10 (0.60 s, search 0.31 s): objective=167760 imbalance=1.992 nsplit=0         Iteration 11 (0.60 s, search 

  Iteration 0 (0.80 s, search 0.41 s): objective=215414 imbalance=1.992 nsplit=0         Iteration 1 (0.80 s, search 0.41 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 2 (0.80 s, search 0.41 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 3 (0.80 s, search 0.42 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 4 (0.80 s, search 0.42 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 5 (0.81 s, search 0.42 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 6 (0.81 s, search 0.42 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 7 (0.81 s, search 0.42 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 8 (0.81 s, search 0.42 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 9 (0.81 s, search 0.42 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 10 (0.81 s, search 0.42 s): objective=168158 imbalance=1.992 nsplit=0         Iteration 11 (0.82 s, search 

  Iteration 0 (1.02 s, search 0.53 s): objective=219995 imbalance=1.992 nsplit=0         Iteration 1 (1.02 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 2 (1.02 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 3 (1.02 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 4 (1.02 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 5 (1.02 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 6 (1.03 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 7 (1.03 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 8 (1.03 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 9 (1.03 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 10 (1.03 s, search 0.53 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 11 (1.03 s, search 

  Iteration 0 (1.16 s, search 0.60 s): objective=222221 imbalance=1.992 nsplit=0         Iteration 1 (1.16 s, search 0.60 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 2 (1.16 s, search 0.60 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 3 (1.16 s, search 0.60 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 4 (1.16 s, search 0.60 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 5 (1.17 s, search 0.60 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 6 (1.17 s, search 0.61 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 7 (1.17 s, search 0.61 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 8 (1.17 s, search 0.61 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 9 (1.17 s, search 0.61 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 10 (1.17 s, search 0.61 s): objective=167828 imbalance=1.992 nsplit=0         Iteration 11 (1.17 s, search 

  Iteration 0 (1.37 s, search 0.71 s): objective=246898 imbalance=1.992 nsplit=0         Iteration 1 (1.37 s, search 0.71 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 2 (1.38 s, search 0.71 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 3 (1.38 s, search 0.71 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 4 (1.38 s, search 0.71 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 5 (1.38 s, search 0.72 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 6 (1.38 s, search 0.72 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 7 (1.38 s, search 0.72 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 8 (1.38 s, search 0.72 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 9 (1.38 s, search 0.72 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 10 (1.39 s, search 0.72 s): objective=167638 imbalance=1.992 nsplit=0         Iteration 11 (1.39 s, search 

167629.703125

In [10]:
kmeans.centroids #cluster centers

array([[ 0.03878354,  0.05904993,  0.0282885 , ..., -0.03968159,
        -0.02476441, -0.05587641],
       [ 0.01763256, -0.21713637,  0.29737765, ...,  0.00347101,
         0.01532576,  0.01871864]], dtype=float32)

In [11]:
kmeans.obj #inertia at each iteration

array([232817.984375, 167796.890625, 167773.9375  , 167743.390625,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 167732.9375  , 167732.9375  ,
       167732.9375  , 167732.9375  , 233874.84375 , 167812.5     ,
       167812.5     , 167812.5     , 167812.5     , 167812.5     ,
       167812.5     , 167812.5     , 167812.5     , 167812.5  

In [12]:
pd.Series(kmeans.index.search(dim_reduced_vecs.astype(np.float32), 1)[1].flatten()).value_counts()

0    1999
1       1
dtype: int64

In [13]:
classifications = []
for dim_red_vec in dim_reduced_vecs:
    dist_from_first_centroid = np.linalg.norm((dim_red_vec - kmeans.centroids[0]))
    dist_from_second_centroid = np.linalg.norm((dim_red_vec - kmeans.centroids[1]))
    classification = 0 if dist_from_first_centroid < dist_from_second_centroid else 1
    
    classifications.append(classification)
    
pd.Series(classifications).value_counts()

0    1999
1       1
dtype: int64