In [None]:
import torch 
from model import FFNN
import re
from gensim.models import KeyedVectors
import pickle

In [None]:
# Seed torch's global RNG so the random fallback vectors drawn for
# out-of-vocabulary words are reproducible across runs.
torch.random.manual_seed(3137)

In [None]:
# Mapping from the model's output class index to its tag (used below to decode
# the argmax of the predictions). `with` closes the file instead of leaking it.
# NOTE: pickle.load can execute arbitrary code — only load trusted project files.
with open('../data/index_to_tag.pickle', 'rb') as index2tag_file:
    index2tag = pickle.load(index2tag_file)

In [None]:
# Pretrained GoogleNews word2vec embeddings, stored in the binary word2vec format.
word2vec = KeyedVectors.load_word2vec_format(
    '../word_vectors/GoogleNews-vectors-negative300.bin',
    binary=True,
)

In [None]:
# Vectors for words missing from word2vec, looked up by gen_wv after word2vec
# misses (values are presumably torch tensors — gen_wv .clone()s them; verify).
# `with` closes the file instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code — only load trusted project files.
with open('../word_vectors/unknown_words.pickle', 'rb') as unknown_words_file:
    unknown_words = pickle.load(unknown_words_file)

In [None]:
def _wv_row(vec):
    """Return a detached copy of `vec` (numpy array or tensor) reshaped to a single row."""
    # Copy so callers can never mutate the storage of word2vec / unknown_words.
    return torch.as_tensor(vec).detach().clone().reshape(1, -1)


def _wv_lookup(word, vectors, unknowns):
    """Resolve `word` in `vectors`, then in `unknowns`; return a (1, D) row or None."""
    try:
        return _wv_row(vectors[word])
    except KeyError:  # gensim KeyedVectors (like dict) raise KeyError for missing words
        pass
    if word in unknowns:
        return _wv_row(unknowns[word])
    return None


def gen_wv(word, rand=True, vectors=None, unknowns=None, dim=300):
    """Generate a word vector for `word` as a (1, D) tensor.

    The word is lower-cased, then resolved in this order:
      1. the whole word, in `vectors` and then in `unknowns`;
      2. if it contains an apostrophe, the part before the first apostrophe
         (same two lookups), e.g. "don't" -> "don";
      3. otherwise, if it contains a hyphen, the mean of the vectors of the
         hyphen-separated parts that could be resolved;
      4. a fallback of shape (1, dim): standard-normal noise if `rand`, else
         zeros. The noise comes from torch's global RNG.

    Args:
        word: token to embed.
        rand: use a random (True) or zero (False) fallback vector for
            unresolvable words.
        vectors: word -> vector mapping that raises KeyError on a miss;
            defaults to the global `word2vec`.
        unknowns: dict of word -> vector for words absent from `vectors`;
            defaults to the global `unknown_words`.
        dim: width of the fallback vector.

    Returns:
        A freshly allocated (1, D) tensor (never a view of the lookup tables).
    """
    if vectors is None:
        vectors = word2vec
    if unknowns is None:
        unknowns = unknown_words
    word = word.lower()

    vec = _wv_lookup(word, vectors, unknowns)
    if vec is not None:
        return vec

    if "'" in word:
        # Words with an apostrophe are queried by the part before it.
        vec = _wv_lookup(word.split("'", 1)[0], vectors, unknowns)
    elif '-' in word:
        # Compound words: average the vectors of the parts that were found.
        # NOTE(review): the previous version seeded the sum with random noise
        # when rand=True and divided by all parts (found or not); if the model
        # was trained with that behaviour, compound-word inputs now differ.
        found = [
            part_vec
            for part_vec in (_wv_lookup(part, vectors, unknowns) for part in word.split('-'))
            if part_vec is not None
        ]
        if found:
            vec = torch.cat(found).mean(dim=0, keepdim=True)

    if vec is not None:
        return vec
    return torch.randn((1, dim)) if rand else torch.zeros((1, dim))

In [None]:
# Input width is a 3-word context window of 300-d vectors (built further down);
# presumably the arguments are (input_dim, n_tags, hidden_dims) — see model.py.
# Hidden sizes must match the checkpoint's run (hidden_dim=[1024, 512, 512]).
model = FFNN(3 * 300, 12, [1024, 512, 512])

In [None]:
# Restore trained weights. The run directory name encodes the hyper-parameters,
# which must match the FFNN constructed above.
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# kernel; all inputs in this notebook are CPU tensors anyway.
model.load_state_dict(torch.load(
    'runs/FINAL_epochs=8,batch_size=1024,hidden_dim=[1024, 512, 512],timestamp=2023-03-10_01-25-08/final_model.pt',
    map_location='cpu',
))

In [None]:
# Example input. Tokens are separated by single spaces and punctuation is its
# own token, because the cells below tokenize with str.split().
sentence = (
    'The wheat cultivation in my town suffered '
    'due to scarcity of water .'
)

In [None]:
# One gen_wv vector per whitespace-separated token, in sentence order.
# (The leftover debug print of each token is dropped; the last cell prints
# every token next to its predicted tag.)
sentence_vector = [gen_wv(word) for word in sentence.split()]
    


In [None]:
# Stack the per-token rows into one matrix, one row per token
# (assumes each entry is a (1, D) row, as gen_wv returns on most paths).
sentence_vector = torch.cat(sentence_vector, dim=0)

In [None]:
# Build one FFNN input row per token: the word vectors of the token and its
# `context_size` neighbours on each side, concatenated in sentence order, with
# zero vectors standing in for positions outside the sentence.
# Row width is (2 * context_size + 1) * word_vec_dim and must equal the model's
# input size (900 = 3 * 300 here).
word_vec_dim = sentence_vector.shape[1]  # taken from the data so padding always matches the embeddings
context_size = 1
len_sentence = sentence_vector.shape[0]
window_size = 2 * context_size + 1

# Pad once with `context_size` zero rows at each end, so every window is a plain
# slice of the padded matrix instead of an index-select plus per-window padding.
padding = torch.zeros((context_size, word_vec_dim), dtype=sentence_vector.dtype)
padded_sentence = torch.cat((padding, sentence_vector, padding))

ffnn_input = [
    padded_sentence[word_index:word_index + window_size].reshape(1, -1)
    for word_index in range(len_sentence)
]

In [None]:
# One row per token: shape (len_sentence, (2 * context_size + 1) * word_vec_dim).
ffnn_input = torch.cat(ffnn_input, dim=0)

In [None]:
# Inference mode: disables dropout / batch-norm updates if FFNN uses them.
model.eval()
with torch.no_grad():
    # The trailing `.dim` previously here made `output` the bound method
    # Tensor.dim instead of the scores, which torch.argmax cannot consume.
    # Presumably predict returns a (len_sentence, n_tags) score tensor — the
    # next cell takes argmax over dim=1.
    output = model.predict(ffnn_input)

In [None]:
# Highest-scoring class index per token, mapped back to its tag via index2tag.
tags = [index2tag[class_idx] for class_idx in torch.argmax(output, dim=1).tolist()]

In [None]:
# Show each token next to its predicted tag; both sequences have len_sentence
# entries by construction (one FFNN input row per token).
words = sentence.split()
for token, tag in zip(words, tags):
    print(token, tag)