In [5]:
import torch
import pickle
from utils import Skipgram, SkipgramNeg, Glove
import torch.nn.functional as F

In [6]:
# Load the preprocessed data bundle (vocabulary, word->index map, and sizes).
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load pickles produced by a trusted source.
# The `with` block guarantees the file handle is closed even if unpickling fails.
with open('/home/koala/NLP/NPL-A1/data/data.pkl', 'rb') as data_file:
    Data = pickle.load(data_file)
vocab = Data['vocab']
word2index = Data['word2index']
voc_size = Data['voc_size']
embed_size = Data['emb_size']

# Load the models.
# Each model is rebuilt with the stored sizes, its saved state dict is mapped
# onto the CPU, and it is switched to eval mode for inference.
skipgram = Skipgram(voc_size, embed_size)
skipgram.load_state_dict(torch.load('/home/koala/NLP/NPL-A1/models/skipgram.pth', map_location=torch.device('cpu')))
skipgram.eval()

skipgramNeg = SkipgramNeg(voc_size, embed_size)
skipgramNeg.load_state_dict(torch.load('/home/koala/NLP/NPL-A1/models/skipgramNEG.pth', map_location=torch.device('cpu')))
skipgramNeg.eval()

glove = Glove(voc_size, embed_size)
glove.load_state_dict(torch.load('/home/koala/NLP/NPL-A1/models/GloVe.pth', map_location=torch.device('cpu')))
# Kept as the last expression so the notebook cell still displays the module repr.
glove.eval()

Glove(
  (embedding_center): Embedding(8558, 30)
  (embedding_outside): Embedding(8558, 30)
  (center_bias): Embedding(8558, 1)
  (outside_bias): Embedding(8558, 1)
)

In [8]:
def get_similar_words(model, user_inputs, top_k=10):
    """Return the vocabulary words most similar to a whitespace-separated query.

    Every query token is lower-cased and looked up with ``model.get_embed``;
    tokens that are not in the module-level ``vocab`` fall back to the
    ``'<UNK>'`` embedding. The token embeddings are summed into one query
    vector, which is compared against every vocabulary embedding using cosine
    similarity.

    Args:
        model: Object exposing ``get_embed(word)``. Assumed to return a tensor
            holding a single embedding (e.g. shape ``(emb,)`` or ``(1, emb)``);
            TODO confirm against the model classes in ``utils``.
        user_inputs: Query string; split on whitespace.
        top_k: Maximum number of words to return (default 10, matching the
            previous hard-coded value). Clamped to the vocabulary size.

    Returns:
        A list of up to ``top_k`` words from ``vocab``, most similar first.
        For an empty query the query vector is all zeros, so every similarity
        is 0 and the returned words carry no meaning.
    """
    # Inference only: avoid building an autograd graph over the whole vocabulary.
    with torch.no_grad():
        # Flatten each embedding to 1-D so the matrix is always (vocab, emb),
        # whatever leading singleton dimensions get_embed adds.
        all_word_vectors = torch.stack(
            [model.get_embed(word).reshape(-1) for word in vocab]
        )

        known_words = set(vocab)
        input_vectors = []
        for word in user_inputs.split():
            token = word.lower()
            if token not in known_words:
                token = '<UNK>'
            input_vectors.append(model.get_embed(token).reshape(-1))

        if input_vectors:
            # Out-of-place sum: never writes into a tensor returned by get_embed.
            result_vector = torch.stack(input_vectors).sum(dim=0)
        else:
            # Empty query: fall back to a zero vector of the embedding size.
            result_vector = torch.zeros_like(all_word_vectors[0])

        # Reduce explicitly over the embedding axis -> one score per vocab word.
        similarities = F.cosine_similarity(
            result_vector.unsqueeze(0), all_word_vectors, dim=1
        )

        # topk raises if k exceeds the number of elements, so clamp it.
        k = max(0, min(top_k, similarities.numel()))
        top_indices = torch.topk(similarities, k).indices

    return [vocab[index] for index in top_indices.tolist()]
