In [None]:
import torch 
from model import Seq2SeqPOSTagger
from model2 import Encoder 
import re
from gensim.models import KeyedVectors
import pickle

In [None]:
# Seed torch's global RNG so the random placeholder vectors that gen_wv draws
# with torch.randn for unknown words are the same on every fresh run.
torch.manual_seed(seed=3137)

In [None]:
# Lookup table from class index to tag label; it is used further down to turn
# argmax indices of the model outputs into readable tags.
# NOTE(security): pickle.load can execute arbitrary code -- only load files you trust.
with open('../data/index_to_tag.pickle', 'rb') as tag_file:  # closes the handle deterministically
    index2tag = pickle.load(tag_file)

In [None]:
# Pretrained word2vec vectors in the binary word2vec format. gen_wv indexes this
# object by lower-cased word. (Presumably 300-dimensional, going by the file name
# and the 300-wide fallbacks in gen_wv -- confirm if the file is swapped.)
w2v_path = '../word_vectors/GoogleNews-vectors-negative300.bin'
word2vec = KeyedVectors.load_word2vec_format(w2v_path, binary=True)

In [None]:
# Table of vectors for words that are missing from word2vec. gen_wv looks words up
# here by lower-cased key and treats the values as torch tensors.
# NOTE(security): pickle.load can execute arbitrary code -- only load files you trust.
with open('../word_vectors/unknown_words.pickle', 'rb') as unknown_file:  # closes the handle deterministically
    unknown_words = pickle.load(unknown_file)

In [None]:
def _known_wv(token):
    """Return a fresh ``(1, dim)`` vector for ``token``, or ``None`` if it is unknown.

    The pretrained ``word2vec`` vectors are consulted first and the
    ``unknown_words`` table second. The result is always a new tensor, so the
    caller can modify it without touching either lookup table.
    """
    try:
        # Explicit conversion: word2vec yields NumPy arrays, the rest of the code uses tensors.
        return torch.tensor(word2vec[token]).reshape(1, -1)
    except KeyError:  # token is not in the word2vec vocabulary
        pass
    if token in unknown_words:
        return unknown_words[token].clone().detach().reshape(1, -1)
    return None


def _fallback_wv(rand, dim=300):
    """Return the ``(1, dim)`` placeholder for a word with no known vector.

    Standard-normal noise when ``rand`` is true, all zeros otherwise.
    """
    return torch.randn((1, dim)) if rand else torch.zeros((1, dim))


def gen_wv(word, rand=True):  # generate word vectors
    """Return a ``(1, dim)`` tensor holding the vector for ``word``.

    The word is lower-cased and then resolved in this order:

    1. the word itself, from ``word2vec`` or else from ``unknown_words``;
    2. if it contains an apostrophe, the part before the first apostrophe
       (e.g. ``"don't"`` is looked up as ``"don"``);
    3. if it contains a hyphen, the mean of the vectors of those hyphen-separated
       components that can be resolved (components without a vector are skipped);
    4. otherwise a placeholder: random normal values if ``rand`` is true, zeros if not.

    An apostrophe takes precedence over a hyphen, so a word containing both is
    handled by rule 2 only. When rule 2 or 3 finds nothing, the rule 4 placeholder
    is returned.

    Args:
        word: the token to embed (``str``).
        rand: choose a random (True) or all-zero (False) placeholder for words
            that cannot be resolved at all.

    Returns:
        A new ``torch.Tensor`` of shape ``(1, dim)``; it never aliases an entry
        of ``unknown_words``.
    """
    word = word.lower()

    vec = _known_wv(word)
    if vec is not None:
        return vec

    if "'" in word:
        vec = _known_wv(word.split("'")[0])
        return vec if vec is not None else _fallback_wv(rand)

    if '-' in word:
        found = [v for v in map(_known_wv, word.split('-')) if v is not None]
        if found:
            # Plain mean of the resolved components -- no random term is mixed in.
            return sum(found) / len(found)
        return _fallback_wv(rand)

    return _fallback_wv(rand)

In [None]:
# Encoder-only model. The positional arguments presumably mirror the keyword
# arguments below (input dim, hidden dim, layers, output dim) -- TODO confirm
# against model2.Encoder.
model = Encoder(300, 128, 1, 12)

# Encoder-decoder tagger, configured through explicit keyword arguments.
seq2seq_kwargs = dict(
    encoder_input_dim=300,
    decoder_input_dim=268,
    output_dim=12,
    hidden_dim=128,
    num_layers=1,
)
encoder_decoder = Seq2SeqPOSTagger(**seq2seq_kwargs)

In [None]:
# Restore the trained encoder-decoder weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU be read on any
# machine; load_state_dict then copies the values into the module's existing
# parameters, so the module stays on whatever device it already lives on.
# NOTE(security): torch.load unpickles the file -- only load checkpoints you trust.
encoder_decoder.load_state_dict(
    torch.load(
        'runs/epochs=5,batch_size=128,hidden_dim=128,timestamp=2023-03-09_21-31-24/final_model.pt',
        map_location='cpu',
    )
)

In [None]:
# Restore the trained encoder-only weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU be read on any
# machine; load_state_dict then copies the values into the module's existing
# parameters, so the module stays on whatever device it already lives on.
# NOTE(security): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(
    torch.load(
        'runs/enc_only,epochs=5,batch_size=128,hidden_dim=128,timestamp=2023-03-10_01-27-26/final_model.pt',
        map_location='cpu',
    )
)

In [None]:
# Demo input. It is tokenised later with str.split(), so tokens must be separated
# by spaces -- which is why the question mark stands on its own.
sentence = 'can the can-opener open the can ?'

In [None]:
# Embed each whitespace-separated token, echoing it as it is processed.
token_vectors = []
for word in sentence.split():
    print(word)
    token_vectors.append(gen_wv(word))

# Join the per-token rows along axis 0, then add a leading axis of size 1
# (presumably the batch dimension the models expect).
sentence_vector = torch.cat(token_vectors, dim=0).unsqueeze(dim=0)

In [None]:
# Sanity check: the per-token rows were concatenated and given a leading axis of
# size 1, so this should read (1, number_of_tokens, vector_dim).
sentence_vector.shape

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Normalises over axis 1 of the squeezed output; the encoder-decoder cell reuses it.
softmax = torch.nn.Softmax(dim=1)

model.to(device)  # parameters must live on the same device as the input
model.eval()      # inference mode: disables dropout / batch-norm updates, if the model has any
with torch.no_grad():
    enc_output = model(sentence_vector.to(device))
    # Drop the leading size-1 axis, then turn each row of scores into probabilities.
    enc_output = softmax(enc_output.squeeze(0))

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

encoder_decoder.to(device)  # parameters must live on the same device as the input
encoder_decoder.eval()      # inference mode: disables dropout / batch-norm updates, if the model has any
with torch.no_grad():
    enc_dec_output = encoder_decoder.predict(sentence_vector.to(device))
    # Drop the leading size-1 axis, then turn each row of scores into probabilities.
    enc_dec_output = softmax(enc_dec_output.squeeze(0))

In [None]:
# Inspect the shape of the encoder-decoder output after squeeze(0) and softmax.
enc_dec_output.shape

In [None]:
# Pick the highest-scoring class in every row and translate its index to a tag label.
enc_tags = [index2tag[idx] for idx in torch.argmax(enc_output, dim=1).tolist()]
enc_dec_tags = [index2tag[idx] for idx in torch.argmax(enc_dec_output, dim=1).tolist()]

In [None]:
# Print one line per token: the token, the encoder-only tag, the encoder-decoder tag.
words = sentence.split()
for i, token in enumerate(words):
    print(token, "\t", enc_tags[i], "\t", enc_dec_tags[i])