In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Training hyperparameters. They size the model below and are encoded in the
# checkpoint file name, so they select which saved run gets loaded.
VOCAB_SIZE = 3000
EMBEDDING_SIZE = 10
BATCH_SIZE = 128
EPOCH = 30
# Display count (not used in the cells visible here — TODO confirm where it is consumed).
DISNUM = 50
# Base name of the checkpoint, e.g. 'rnn2-epo30ebd10vcb3000'.
name = f'rnn2-epo{EPOCH}ebd{EMBEDDING_SIZE}vcb{VOCAB_SIZE}'

class RNN(nn.Module):
    """Embedding -> vanilla RNN -> linear projection language model.

    Args:
        vocab_size: number of rows in the embedding table and size of the
            output distribution.
        embedding_size: dimensionality of each word embedding.
        hidden_size: width of the recurrent hidden state.
        num_layers: number of stacked recurrent layers.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, num_layers):
        super().__init__()
        # Keep the hyperparameters around for later inspection.
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Attribute names (and creation order) determine the state_dict keys
        # and the seeded initialisation, so they are kept as-is.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_size,
        )
        self.rnn = nn.RNN(
            input_size=embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, inputs):
        """Return log-probabilities over the vocabulary for every position.

        `inputs` holds token indices; because the RNN is built with
        `batch_first=True`, the leading dimension is the batch.
        """
        embedded = self.embedding(inputs)
        # The final hidden state is not needed, only the per-step outputs.
        step_outputs, _ = self.rnn(embedded)
        logits = self.fc(step_outputs)
        return F.log_softmax(logits, dim=-1)

# Rebuild the network with the same shape as the saved checkpoint.
# VOCAB_SIZE + 1 rows: indices 0..VOCAB_SIZE-1 are real words and index
# VOCAB_SIZE is the shared <UNK> slot. The hidden size (128) and layer count (1)
# are hard-coded here and must match the values used when the checkpoint was written.
model = RNN(VOCAB_SIZE + 1, EMBEDDING_SIZE, 128, 1)
# map_location='cpu' lets a checkpoint that was saved from a GPU be restored on
# a machine without CUDA; without it torch.load tries to put the tensors back
# on their original device and fails.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(name + '.pth', map_location='cpu'))
# Snapshot the embedding table as a NumPy array (one row per vocabulary index).
# .detach() is the supported replacement for .data; .copy() keeps the array
# independent of the model's parameter storage, as np.array(...) did before.
word_vectors = model.embedding.weight.detach().cpu().numpy().copy()

In [7]:
# Load the corpus: each line of the file becomes one list of tokens, obtained
# by stripping the line and splitting on single spaces.
corpus = []
with open('news2-cleaned.txt', 'r', encoding='utf-8') as corpus_file:
    corpus.extend(raw_line.strip().split(' ') for raw_line in corpus_file)

# Assign every distinct token an index in order of first appearance.
word_to_ix = {}
for tokens in corpus:
    for token in tokens:
        # setdefault only inserts when the token is new, so existing indices are kept.
        word_to_ix.setdefault(token, len(word_to_ix))
# Tokens beyond the first VOCAB_SIZE distinct ones all collapse onto the
# out-of-vocabulary index VOCAB_SIZE, which is also given the explicit '<UNK>' key.
word_to_ix = {token: min(ix, VOCAB_SIZE) for token, ix in word_to_ix.items()}
word_to_ix['<UNK>'] = VOCAB_SIZE

In [10]:
def get_similar_words(word, n=10):
    """Print the `n` vocabulary words most similar to `word`.

    Similarity is the cosine similarity between rows of the module-level
    `word_vectors` matrix; rows are addressed through the module-level
    `word_to_ix` mapping.

    Args:
        word: query word. If it is not a key of `word_to_ix` a message is
            printed and nothing else happens. Words that share the '<UNK>'
            index are compared using the '<UNK>' embedding.
        n: maximum number of neighbours to print (fewer are printed when the
            vocabulary is smaller).

    Returns:
        None. The result is written to stdout: a header line followed by one
        word per line, most similar first.
    """
    if word not in word_to_ix:
        print('Word not in vocabulary')
        return

    query_ix = word_to_ix[word]
    word_vector = word_vectors[query_ix]

    # Cosine similarity of every row against the query vector. Rows with zero
    # norm get an infinite denominator, i.e. similarity 0 instead of NaN.
    norms = np.linalg.norm(word_vectors, axis=1) * np.linalg.norm(word_vector)
    similarities = np.dot(word_vectors, word_vector) / np.where(norms == 0, np.inf, norms)

    # Build the reverse index once per call instead of scanning the dict for
    # every result. Several words can share one index (the out-of-vocabulary
    # slot); label that slot '<UNK>' rather than whichever word came first.
    ix_to_word = {}
    for known_word, ix in word_to_ix.items():
        ix_to_word.setdefault(ix, known_word)
    if '<UNK>' in word_to_ix:
        ix_to_word[word_to_ix['<UNK>']] = '<UNK>'

    print('Top ' + str(n) + ' similar words to ' + word + ':')
    shown = 0
    for row in np.argsort(similarities)[::-1].tolist():
        if shown >= n:
            break
        # Skip the query itself by index (not by assuming it is ranked first)
        # and any embedding row that has no word attached to it.
        if row == query_ix or row not in ix_to_word:
            continue
        print(ix_to_word[row])
        shown += 1
# Interactive entry point: prompt for a query word and print its nearest neighbours.
get_similar_words(
    word=input('Enter a word: '),
)

Top 10 similar words to his:
biden’s
trump’s
their
president’s
mcconnell’s
facebook’s
its
initial
controversial
campaign’s
