In [1]:
import torch
import pytorch_lightning as pl
import numpy as np
from gensim.models import Word2Vec
from train import Lightning

In [2]:
# Load the pretrained word2vec vocabulary and the trained checkpoint.
# NOTE(review): both paths are relative to the notebook's working directory.
W2V_PATH = 'data/w2v/news_200d.bin'
CKPT_PATH = 'lightning_logs/exp7_kmean/version_0/checkpoints/epoch=0.ckpt'

w2v = Word2Vec.load(W2V_PATH)
model = Lightning.load_from_checkpoint(CKPT_PATH)
# Everything below is inference/inspection only: switch any dropout or
# batch-norm layers to inference behaviour so the scores shown are reproducible.
model.eval()
doc_encoder = model.nrms.doc_encoder

# word -> integer id (gensim vocabulary index), and its inverse.
# NOTE(review): `wv.vocab` is the gensim 3.x API; gensim >= 4.0 replaced it
# with `wv.key_to_index` (a ready-made word -> index dict).
w2id = {word: entry.index for word, entry in w2v.wv.vocab.items()}
id2w = {i: word for word, i in w2id.items()}


In [3]:
N = 10  # number of words listed per aspect

# NOTE(review): get_ae_aspects() is assumed to return a pair of equal-length
# sequences (per-aspect projections, per-aspect word-id rankings); only the
# rankings are used here — TODO confirm against the encoder implementation.
projs = doc_encoder.get_ae_aspects()

# For each aspect, keep the last N word ids of its ranking (presumably the N
# highest-scoring words if the ranking is an ascending argsort — TODO confirm)
# and render them as a comma-separated string. A comprehension avoids leaking
# loop variables such as `idx` into the notebook namespace.
result = [
    ', '.join(id2w[i] for i in ranking[-N:].detach().cpu().tolist())
    for _, ranking in zip(*projs)
]
result

["accessory, anytime, cmass, battalion, smarts, spear, 311, 'everybody, legislators, 11/9/19",
 'shall, streetwise, cnn10, husqvarna, cerabino, superdome, onwuasor, headphone, lookout, rams-steelers',
 'dowling, 311, shall, husqvarna, cerabino, legislators, onwuasor, headphone, lookout, rams-steelers',
 "markus, educate, 'corrupt, cursing, spear, jeffersonville, tyrell, saagar, cnn10, apl",
 'misinformation, stockholm, chorus, caldwell-pope, pretzel, marking, mixer, peabody, churchill, dowling',
 'caldwell-pope, charlestown, cnn10, stockholm, marking, gingerly, onwuasor, cerabino, allegan, apl',
 "leslie, 'ready, cerabino, frenzy, tyrell, onwuasor, filip, scorers, peabody, administrators",
 'xhaka, 311, onwuasor, cerabino, apl, legislators, glasgow, saagar, peabody, tyrell',
 'peabody, chancellor, tyrell, hines, saagar, 11/14, emery, mauldin, streetwise, superdome',
 'dyson, peabody, os, caldwell-pope, cnn10, rowhome, cmass, stockholm, i-team, marking',
 "fleming, dowling, churchill, 1

In [4]:
# Bind the document encoder and make sure it is in inference mode
# (`nn.Module.eval()` returns the module itself). The cells below feed it
# single headlines to inspect its per-word attention weights.
doc_encoder = model.nrms.doc_encoder.eval()

In [5]:
# Tokenise an example headline on whitespace and map each lower-cased token to
# its word2vec id.
# NOTE(review): out-of-vocabulary tokens fall back to id 0, which (gensim ids
# are 0-based) is presumably a real vocabulary word — confirm this matches how
# titles were encoded during training.
title = 'Best PS5 games : top PlayStation 5 titles to look forward to'.split()
# Explicit integer dtype: token ids must be integral, and torch.tensor([])
# would otherwise default to float32 for an empty title.
idx = torch.tensor([w2id.get(w.lower(), 0) for w in title], dtype=torch.long)
idx

tensor([   31, 10810,   241,     1,    32,  8881,    67,  4443,    12,   148,
         1089,    12])

In [6]:
# Encode the headline as a batch of one and keep the per-word weights
# (`score`). The second positional argument is passed through unchanged;
# its meaning is defined by the encoder — TODO document.
with torch.no_grad():  # pure inspection: no autograd graph needed
    o, score, _ = doc_encoder(
        idx.unsqueeze(0).to(next(doc_encoder.parameters()).device), False
    )
# Flatten to one weight per token. Unlike `.squeeze()`, this still yields a
# list (not a bare float) for a one-word title, so the zip below keeps working.
score = score.reshape(-1).tolist()

In [7]:
# Pair every headline token with the weight the encoder assigned to it.
[(word, weight) for word, weight in zip(title, score)]

[('Best', 0.07582545280456543),
 ('PS5', 0.11388859897851944),
 ('games', 0.1108662486076355),
 (':', 0.0679372176527977),
 ('top', 0.08917580544948578),
 ('PlayStation', 0.08669701218605042),
 ('5', 0.10396092385053635),
 ('titles', 0.08928592503070831),
 ('to', 0.030614886432886124),
 ('look', 0.10592625290155411),
 ('forward', 0.0855594351887703),
 ('to', 0.04026225954294205)]

In [8]:
# Same per-word attention inspection for a second headline.
# NOTE(review): OOV tokens fall back to id 0, presumably a real vocabulary
# word — confirm this matches the training-time encoding.
title = 'US president backpedals on meeting parents of children killed by recalled dressers'.split()
# Explicit integer dtype so an empty title still yields an id tensor.
idx = torch.tensor([w2id.get(w.lower(), 0) for w in title], dtype=torch.long)
with torch.no_grad():  # inspection only: no autograd graph needed
    o, score, _ = doc_encoder(
        idx.unsqueeze(0).to(next(doc_encoder.parameters()).device), False
    )
# One weight per token; reshape(-1) (unlike squeeze) keeps a list for 1-word titles.
score = score.reshape(-1).tolist()
list(zip(title, score))

[('US', 0.0858328640460968),
 ('president', 0.0758572444319725),
 ('backpedals', 0.06443890929222107),
 ('on', 0.11488490551710129),
 ('meeting', 0.07776381820440292),
 ('parents', 0.08054882287979126),
 ('of', 0.09722773730754852),
 ('children', 0.0674239918589592),
 ('killed', 0.08062224090099335),
 ('by', 0.11293116211891174),
 ('recalled', 0.07033807039260864),
 ('dressers', 0.07213034480810165)]