In [10]:
import pickle
import os
from typing import Set, Tuple, List, Dict

import torch
import numpy as np
import scipy
from sklearn.metrics import pairwise
from tqdm import tqdm

np.random.seed(1)


class SkipGramNegativeSampling():
    
    def __init__(self, source: str, path: str):
        if source == 'pytorch':
            self.init_from_pytorch(path)
        elif source == 'plain_text':
            self.init_from_plain_text(path)

#     def from_pytorch(path: str):
#         from embeddings.pretrain_embeddings import SkipGramNegativeSampling
#         model = SkipGramNegativeSampling().load_state_dict(torch.load(model_path)) 
#         with open(path, 'rb') as model_file:
#             state_dict = torch.load(model_file, map_location='cpu')
#         print(state_dict.keys())
#         embeddings = state_dict['center_embedding.weight'].numpy()
#         assert False

    def init_from_pytorch(self, paths: Tuple[str, str]) -> None:
        model_path, vocab_path = paths
        with open(model_path, 'rb') as model_file:
            state_dict = torch.load(model_file, map_location='cpu')
    #     print(state_dict.keys())
        self.embedding = state_dict['center_embedding.weight'].numpy()
        with open(vocab_path, 'rb') as vocab_file:
            self.word_to_id, self.id_to_word, _ = pickle.load(vocab_file)

    def init_from_plain_text(self, path: str) -> Tuple[np.array, Dict[str, int]]:
        id_generator = 0
        word_to_id: Dict[str, int] = {}
        embeddings: List[float] = []
        embedding_file = open(path)
        vocab_size, num_dimensions = map(int, embedding_file.readline().split())
        print(f'vocab_size = {vocab_size:,}, num_dimensions = {num_dimensions}')
        print(f'Loading embeddings from {path}', flush=True)
        for line in embedding_file:
            line: List[str] = line.split()  # type: ignore
            word = line[0]
            vector = np.array(line[-num_dimensions:], dtype=np.float64)
            embeddings.append(vector)
            word_to_id[word] = id_generator
            id_generator += 1
        embedding_file.close()
        print('Done')
        self.id_to_word = {val: key for key, val in word_to_id.items()}
        self.word_to_id = word_to_id
        self.embedding = np.array(embeddings)

    def cosine_similarity(self, query1: str, query2: str) -> float:
        try:
            query1_id = self.word_to_id[query1]
        except KeyError:
            print(f'{query1} is out of vocabulary. Sorry!')
            return
        try:
            query2_id = self.word_to_id[query1]
        except KeyError:
            print(f'{query2} is out of vocabulary. Sorry!')
            return
        vectors = self.embedding[(query1_id, query2_id), :]
        similarity = 1 - scipy.spatial.distance.cosine(vectors[0], vectors[1])
        return similarity

    def nearest_neighbor(self, query: str, top_k: int = 10):
        try:
            query_id = self.word_to_id[query]
        except KeyError:
            raise f'{query} is out of vocabulary. Sorry!'    
        query_vec = self.embedding[query_id]
        
        distances = [scipy.spatial.distance.cosine(query_vec, vec) 
                     for vec in self.embedding]
        neighbors = np.argsort(distances)
        print(f"{query}'s neareset neighbors:")
        for ranking in range(1, top_k + 1):
            word_id = neighbors[ranking]
            word = self.id_to_word[word_id]
            cosine_similarity = 1 - distances[word_id]
            print(f'{ranking}: {word}\t\t{cosine_similarity:.4f}')
            if ranking > top_k:
                break
        print()
        

In [2]:
decade_model = SkipGramNegativeSampling(
    'pytorch',
    ('../../results/skip_gram/decade_bs32anneal_1e-5/epoch10.pt',
     '../../data/processed/skip_gram/decade_paper1e-5/vocab.pickle'))

In [3]:
postwar_model = SkipGramNegativeSampling(
    'pytorch',
    ('../../results/skip_gram/postwar_1e-5/epoch10.pt',
     '../../data/processed/skip_gram/postwar_paper1e-5/vocab.pickle'))

In [4]:
ref_decade_model = SkipGramNegativeSampling(
    'plain_text', '../../results/baseline/word2vec_decade.txt')

vocab_size = 109,123, num_dimensions = 300
Loading embeddings from ../../results/baseline/word2vec_decade.txt
Done


In [5]:
ref_postwar_model = SkipGramNegativeSampling(
    'plain_text', '../../results/baseline/word2vec_decade.txt')

vocab_size = 109,123, num_dimensions = 300
Loading embeddings from ../../results/baseline/word2vec_decade.txt
Done


In [18]:
model = ref_postwar_model
decade_model.nearest_neighbor('unemployment_benefits', top_k=30)
ref_decade_model.nearest_neighbor('unemployment_benefits', top_k=30)
# model.nearest_neighbor('estate_tax')
# model.nearest_neighbor('death_tax')
# model.nearest_neighbor('undocumented_immigrants')
# model.nearest_neighbor('illegal_aliens')
# model.nearest_neighbor('illegal_aliens')

unemployment_benefits's neareset neighbors:
1: unemployment_insurance		0.8851
2: unemployed_workers		0.8455
3: unemployment_compensation		0.8432
4: extension_of_unemployment		0.8398
5: unemployment_insurance_benefits		0.8396
6: unemployed		0.8355
7: extend_unemployment_benefits		0.8248
8: longterm_unemployed		0.7791
9: jobless		0.7768
10: federal_unemployment_benefits		0.7647
11: looking_for_work		0.7640
12: extended_benefits		0.7517
13: extended_unemployment_benefits		0.7513
14: unemployment		0.7480
15: extending_unemployment_benefits		0.7465
16: extend_unemployment		0.7429
17: teuc		0.7388
18: outofwork		0.7349
19: exhausted_their_benefits		0.7321
20: euc		0.7303
21: extra_weeks		0.7285
22: lost_their_jobs		0.7173
23: extending_unemployment		0.7171
24: exhausted_their_unemployment		0.7137
25: find_work		0.7095
26: unemployed_americans		0.7048
27: exhaustees		0.6915
28: emergency_unemployment_compensation		0.6915
29: unemployment_benefitsand		0.6873
30: ui		0.6805

unemployment_benefi