In [19]:
import torch
import pickle
from sklearn.metrics.pairwise import cosine_similarity

gcn_embed_weights = torch.load('./embed/exhird_h_normal_gcn_word_embed_cpu.pt')
normal_embed_weights = torch.load('./embed/exhird_h_normal_word_embed_cpu.pt')

with open('./vocab2idx.pkl', 'rb') as f:
    vocab2idx = pickle.load(f)
    f.close()



In [4]:
gcn_embed_weights.shape

torch.Size([50004, 100])

In [52]:
from collections import defaultdict
from tqdm import tqdm

def get_cosine_similarity(feature_vec_1, feature_vec_2):    
    return cosine_similarity(feature_vec_1.reshape(1, -1), feature_vec_2.reshape(1, -1))[0][0]

def get_top_k_from_embed(embed, words, vocab, k):
    words_idx = [vocab[word] for word in words]
    scores = dict({vocab[word]: [] for word in words})
    
    for word_idx in words_idx:
        for vocab_idx in tqdm(range(len(vocab))):
            if word_idx == vocab_idx: continue
            scores[word_idx].append((vocab_idx, get_cosine_similarity(embed[word_idx], embed[vocab_idx])))
            
    for score_idx in scores:
        scores[score_idx] = sorted(scores[score_idx], key=lambda x: x[1], reverse=True)[:k]
            
    return scores

In [53]:
words = ['performance', 'optimal', 'homogeneous', 'subsequent', 'biomedical', 'involvement', 'nlp', 'dimensionality', 'hypothesis', 'configurations']
len(words)

10

In [54]:
gcn_similarity_scores = get_top_k_from_embed(gcn_embed_weights, words, vocab2idx, 50)

100%|██████████| 50004/50004 [00:12<00:00, 3858.15it/s]
100%|██████████| 50004/50004 [00:12<00:00, 4069.17it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4170.97it/s]
100%|██████████| 50004/50004 [00:12<00:00, 4127.23it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4211.28it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4219.98it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4206.98it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4199.33it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4193.66it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4207.78it/s]


In [55]:
baseline_similarity_scores = get_top_k_from_embed(normal_embed_weights, words, vocab2idx, 50)

100%|██████████| 50004/50004 [00:11<00:00, 4197.72it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4193.05it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4218.51it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4223.57it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4223.76it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4213.32it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4202.76it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4192.42it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4222.70it/s]
100%|██████████| 50004/50004 [00:11<00:00, 4220.11it/s]


In [56]:
baseline_similarity_scores

{51: [(1426, 0.5770947),
  (347, 0.5592501),
  (2289, 0.5566488),
  (7807, 0.522533),
  (6511, 0.5204188),
  (7557, 0.5138279),
  (547, 0.50611025),
  (597, 0.5039511),
  (22985, 0.5025718),
  (212, 0.49712214),
  (392, 0.49381307),
  (5485, 0.4869167),
  (2049, 0.48580068),
  (26873, 0.4855883),
  (858, 0.4835335),
  (207, 0.48348865),
  (145, 0.48127922),
  (37245, 0.47923642),
  (2776, 0.47871134),
  (39307, 0.47761554),
  (2053, 0.47719175),
  (2084, 0.476963),
  (7952, 0.47624382),
  (57, 0.47371143),
  (10508, 0.47092035),
  (231, 0.47056216),
  (47138, 0.47045994),
  (328, 0.47024295),
  (2219, 0.4699378),
  (29698, 0.46949127),
  (18196, 0.4684319),
  (1206, 0.46494296),
  (38066, 0.46113724),
  (473, 0.4600024),
  (868, 0.45984957),
  (5588, 0.45830068),
  (2453, 0.45495492),
  (6953, 0.45374778),
  (9462, 0.45312268),
  (5864, 0.45187882),
  (9540, 0.45086926),
  (1142, 0.44993484),
  (4119, 0.4481955),
  (1599, 0.44802433),
  (18197, 0.44794974),
  (296, 0.44688705),
  (598,