In [1]:
from pathlib import Path

import tabulate
import torch.nn.functional as F
import torch

In [2]:
# Name of the corpus; used below as a sub-directory name when locating the vocabulary file.
dataset_name = 'wikipedia_20220101'

In [3]:
# Load the saved checkpoint. The path is built from `dataset_name` so the corpus
# name is configured in one place, and `map_location='cpu'` lets the notebook run
# on a machine without a GPU even if the checkpoint was saved from one.
# NOTE: torch.load unpickles the file -- only open checkpoints from a trusted source.
model_dict = torch.load(
    f'../data/models/{dataset_name}/20230117_1M_half_epoch.pth',
    map_location='cpu',
)

In [4]:
# Display the loaded checkpoint: a mapping from parameter names
# (e.g. 'w_embeddings.weight', 'c_embeddings.weight') to tensors.
model_dict

OrderedDict([('w_embeddings.weight',
              tensor([[ 0.8157, -0.0253,  0.0314,  ..., -0.0088,  0.0067, -0.0545],
                      [ 1.1868, -0.0469, -0.0204,  ..., -0.1442,  0.1593, -0.3008],
                      [ 0.5395, -0.0428, -0.1247,  ..., -0.0987, -0.1637, -0.0772],
                      ...,
                      [-0.1531,  0.3876,  0.2110,  ..., -1.5737,  0.0395, -0.4951],
                      [ 1.0544,  0.0212, -0.0971,  ..., -1.3920,  0.8410, -0.2747],
                      [ 0.1835, -0.5065,  0.1714,  ...,  0.3168,  0.2803,  0.1444]])),
             ('c_embeddings.weight',
              tensor([[-0.6713, -0.4170, -0.2485,  ...,  0.2466,  0.1604,  0.1674],
                      [-0.2657, -0.2486,  0.1616,  ..., -0.1634,  0.1333,  0.0580],
                      [-0.2388,  0.1429, -0.0976,  ...,  0.3170, -0.1464, -0.0903],
                      ...,
                      [-0.8478,  0.4977,  0.2796,  ...,  1.7972, -0.2456,  0.7156],
                      [-0.978

In [5]:
# Absolute path of the `data` directory one level above the current working directory.
data_path = Path('..').resolve() / 'data'

In [6]:
# Build the two vocabulary lookup tables from the vocabulary file (one token per line).
#   i2w    : list, line index -> token
#   wvocab : dict, token -> line index
# Line i is presumably the token for row i of the embedding matrices -- TODO confirm
# against the training code.
with open(data_path.joinpath('vocab', dataset_name, 'wvocab.txt'), encoding='utf-8') as infile:
    # Keep exactly one entry per line (iterating the file lazily). Deriving this
    # list from the dict keys instead would drop duplicate/blank lines and shift
    # every later index out of step with the line numbers stored in `wvocab`.
    i2w = [line.strip() for line in infile]

# If a token occurs on several lines, the last occurrence wins.
wvocab = {word: index for index, word in enumerate(i2w)}

In [7]:
def top_w_sims(model_dict, word, k=5, vocab=None, index_to_word=None):
    """Yield the ``k`` entries of the word-embedding table most similar to ``word``.

    Similarity is the cosine similarity between the row of
    ``model_dict['w_embeddings.weight']`` selected by ``word`` and every row of
    that same matrix. The query word itself is part of the ranking.

    Args:
        model_dict: Mapping that contains a 2-D tensor under the key
            ``'w_embeddings.weight'`` (one row per vocabulary entry).
        word: Token to look up; must be a key of ``vocab``.
        k: Maximum number of neighbours to yield. Values larger than the
            number of rows are clamped instead of raising.
        vocab: Optional token -> row-index mapping. Defaults to the
            notebook-level ``wvocab``.
        index_to_word: Optional row-index -> token sequence. Defaults to the
            notebook-level ``i2w``.

    Yields:
        ``(token, similarity)`` tuples in order of decreasing similarity.

    Raises:
        KeyError: If ``word`` is not present in ``vocab``.
    """
    # Fall back to the notebook-level lookup tables when none are supplied.
    if vocab is None:
        vocab = wvocab
    if index_to_word is None:
        index_to_word = i2w

    embeddings = model_dict['w_embeddings.weight']

    # Make the query explicitly (1, dim) so the comparison along dim=1 does not
    # depend on implicit broadcasting of a 1-D tensor.
    query = embeddings[vocab[word]].unsqueeze(0)
    similarities = F.cosine_similarity(query, embeddings, dim=1)

    # topk raises if asked for more elements than exist, so clamp k.
    best = similarities.topk(min(k, similarities.numel()))

    for row, score in zip(best.indices.tolist(), best.values.tolist()):
        yield index_to_word[row], score

In [9]:
# Ten most similar entries to 'move-to'; the query itself is included in the ranking.
list(top_w_sims(model_dict, 'move-to', 10))

[('move-to', 1.0),
 ('move-from', 0.8199008703231812),
 ('relocate-to', 0.8038907051086426),
 ('transfer-to', 0.772900402545929),
 ('move-into', 0.765338122844696),
 ('send-to', 0.7322393655776978),
 ('return-to', 0.7277364730834961),
 ('work-in', 0.7264373898506165),
 ('relocate-from', 0.7230533361434937),
 ('migrate-to', 0.7226387858390808)]

In [10]:
# Ten most similar entries to 'be_president_of'; the query itself is included in the ranking.
list(top_w_sims(model_dict, 'be_president_of', 10))

[('be_president_of', 0.9999999403953552),
 ('be_chairman_of', 0.8024470806121826),
 ('be_chair_of', 0.764552891254425),
 ('be_manager_of', 0.7625577449798584),
 ('be_founder_of', 0.7623101472854614),
 ('be_director_of', 0.7604293823242188),
 ('be_ceo_of', 0.7582588791847229),
 ('be_member_of', 0.7575125694274902),
 ('appos_chairman_of', 0.7554025650024414),
 ('work-at', 0.7434799075126648)]