In [15]:

import os
import pickle

with open('../temp/embs_150.pkl', 'rb') as f:
    embs = pickle.load(f)

In [16]:
# cluster emb with kmeans and compute clustering coefficient
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.metrics.pairwise import cosine_distances

def get_clustering_score(embedding):
    clusterer = KMeans(n_clusters=2, random_state=10)
    cluster_labels = clusterer.fit_predict(embedding)
    silhouette_avg = silhouette_score(embedding, cluster_labels)
    return silhouette_avg

scores_with_labels_train = []
scores_with_labels_test = []
for i in embs['train']:
    score = get_clustering_score(i['embs'])
    scores_with_labels_train.append((score,int(i['y'][0])))
for i in embs['test']:
    score = get_clustering_score(i['embs'])
    scores_with_labels_test.append((score,int(i['y'][0])))

KeyboardInterrupt: 

In [9]:
scores_with_labels_train.sort(key=lambda x: x[0], reverse=True)
labels = [l for s,l in scores_with_labels_train]

In [13]:
scores_with_labels_train

[(0.8531033, 1),
 (0.80788547, 1),
 (0.8077914, 1),
 (0.80743945, 1),
 (0.7982832, 1),
 (0.79752755, 1),
 (0.796381, 1),
 (0.79513615, 1),
 (0.79424584, 1),
 (0.7921556, 1),
 (0.7896109, 1),
 (0.7893833, 1),
 (0.788917, 1),
 (0.7878195, 1),
 (0.7874459, 1),
 (0.7840251, 1),
 (0.7798263, 1),
 (0.7756173, 1),
 (0.7752906, 1),
 (0.77514946, 1),
 (0.77359784, 1),
 (0.7727221, 1),
 (0.77229154, 1),
 (0.77190477, 1),
 (0.77175224, 1),
 (0.7710217, 1),
 (0.76960117, 1),
 (0.76713777, 1),
 (0.76407206, 1),
 (0.7606886, 1),
 (0.75907004, 1),
 (0.7584586, 1),
 (0.75778985, 1),
 (0.75777704, 1),
 (0.7570033, 1),
 (0.7551208, 1),
 (0.7538001, 1),
 (0.75310725, 1),
 (0.75257844, 1),
 (0.75137365, 1),
 (0.75119096, 1),
 (0.7497859, 1),
 (0.74752325, 1),
 (0.7473325, 1),
 (0.74522144, 1),
 (0.7439779, 1),
 (0.74384135, 1),
 (0.7426074, 1),
 (0.7425889, 1),
 (0.7408212, 1),
 (0.73882926, 1),
 (0.7384659, 1),
 (0.73687047, 1),
 (0.73568594, 1),
 (0.7351522, 1),
 (0.73453486, 1),
 (0.73374015, 1),
 (0.7

In [11]:
def find_split_point(sequence):
    count_ones = 0
    count_zeros = 0
    max_diff = -1
    split_point = -1
    
    for i, bit in enumerate(sequence):
        if bit == 1:
            count_ones += 1
        else:
            count_zeros += 1
        
        diff = count_ones - count_zeros
        if diff > max_diff:
            max_diff = diff
            split_point = i
    
    return split_point

split_point = find_split_point(labels)
value = scores_with_labels_train[split_point]
value

(0.61600953, 1)

In [12]:

preds = []
thresh = value[0]
for i in scores_with_labels_test:
    if i[0] > thresh:
        preds.append(1==i[1])
    else:
        preds.append(0==i[1])
sum(preds)/len(preds)

0.8932

## Compute cluster correlation across formulas

In [66]:
import numpy as np


cluster_vectors = []

def group_clusters(embedding, cluster_labels):
    cluster_0 = []
    cluster_1 = []
    for i, label in enumerate(cluster_labels):
        if label == 0:
            cluster_0.append(embedding[i])
        else:
            cluster_1.append(embedding[i])
    # get average embedding for each cluster
    cluster_0_avg = np.mean(cluster_0, axis=0)
    cluster_1_avg = np.mean(cluster_1, axis=0)
    return (cluster_0_avg, cluster_1_avg)

def get_cluster_vectors(embedding):
    clusterer = KMeans(n_clusters=2, random_state=10)
    cluster_labels = clusterer.fit_predict(embedding)
    silhouette_avg = silhouette_score(embedding, cluster_labels)
    if silhouette_avg > value[0]:
        return group_clusters(embedding, cluster_labels)
    else:
        return None
    
for embedding in embs['train']:
    vecs = get_cluster_vectors(embedding['embs'])
    if vecs != None:
        cluster_vectors.append(vecs)

In [69]:
mean_vecs = [i for i,j in cluster_vectors] + [j for i,j in cluster_vectors]
mean_vecs = np.array(mean_vecs)
truth_vecs = get_cluster_vectors(mean_vecs)
truth_vecs

(array([ 0.21540566,  0.04394149, -0.8167139 ,  0.93003595, -0.97237647,
        -0.86760235, -0.88844824, -0.7071766 , -0.8585163 , -0.9751871 ,
        -0.58062494,  0.9863596 ,  0.8402313 ,  0.6771322 ,  0.33159262,
        -0.73869634], dtype=float32),
 array([-0.72970885,  0.04942944,  0.9779272 , -0.71400034, -0.9095677 ,
         0.048223  , -0.62863165,  0.10765669, -0.9106606 , -0.85211015,
        -0.8941389 ,  0.97597337, -0.03721797, -0.14009655,  0.879907  ,
        -0.2283434 ], dtype=float32))