In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import re
import numpy as np
import pickle
from io import open

In [76]:
WIFile = open("WI.pickle","rb")
word_indices = pickle.load(WIFile)

gramFile = open("QG.pickle","rb")
quatragrams = pickle.load(gramFile)

emFile = open("Em.pickle","rb")
embeddings = pickle.load(emFile)

vFile = open("Vocab.pickle","rb")
vocab = pickle.load(vFile)

for i in range(len(embeddings)):
    embeddings[i] = embeddings[i].view(1,len(embeddings[i]))

In [77]:
CONTEXT_SIZE = 3
EMBEDDING_DIM = 50
VOCAB_LEN = len(word_indices)

In [66]:
class RNN(nn.Module):
    def __init__(self, vocab_size, hidden_size, embedding_dim):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        
        self.i2h = nn.Linear(CONTEXT_SIZE*embedding_dim + hidden_size, hidden_size)
        self.i2o = nn.Linear(CONTEXT_SIZE*embedding_dim + hidden_size, EMBEDDING_DIM)
        self.i2o2 = nn.Linear(EMBEDDING_DIM, VOCAB_LEN)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.i2o2(output)
        output = self.softmax(output)
        
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

In [67]:
losses = []

n_hidden = 128
rn = RNN(VOCAB_LEN, n_hidden, EMBEDDING_DIM)
optimizer = optim.SGD(rn.parameters(), lr=0.01)
loss_function = nn.NLLLoss()

In [68]:
for epoch in range(10):
    if (epoch % 1 == 0):
        print("Epoch: ", epoch)
    total_loss = 0
    hidden = rn.initHidden()
    for context, target in quatragrams:

        one = word_indices[context[0]]
        two = word_indices[context[1]]
        three = word_indices[context[2]]
        target = word_indices[target]

    
        cat = torch.cat((embeddings[one], embeddings[two], embeddings[three]), 1)
        context_idxs = cat

        rn.zero_grad()

        out, hidden = rn(context_idxs, hidden)
        
        loss = loss_function(out, torch.tensor([target], dtype=torch.long))

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        hidden.detach_()
    losses.append(round(total_loss, 3))
print(losses)

Epoch:  0
Epoch:  1
Epoch:  2
Epoch:  3
Epoch:  4
Epoch:  5
Epoch:  6
Epoch:  7
Epoch:  8
Epoch:  9
[38507.21, 33700.611, 31413.142, 29538.134, 27909.615, 26491.437, 25269.952, 24231.293, 23365.909, 22658.578]


In [69]:
one = word_indices["preheat"]
two = word_indices["the"]
three = word_indices["oven"]


cat = torch.cat((embeddings[one], embeddings[two], embeddings[three]), 1)
context_idxs = cat

out, hidden = rn(context_idxs, hidden)

out = out[0]
values, indices = out.max(0)
print(values, indices)
print(list(vocab)[indices])

tensor(1.00000e-02 *
       -8.9979) tensor(1141)
cashew_nuts


In [72]:
sentence = ["combine", "the"]
for j in range(0,100):

    two = word_indices[sentence[-1]]
    one = word_indices[sentence[-2]]
    
    cat = torch.cat((embeddings[one], embeddings[two], embeddings[three]), 1)
    context_idxs = cat

    out, hidden = rn(context_idxs, hidden)

    out = out[0]
    values, indices = out.max(0)
    
    sentence.append(list(vocab)[indices])


print(sentence)

['combine', 'the', 'cashew_nuts', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts', 'parmesan', 'pour', 'baking', 'servingssuggested', 'cashew_nuts', 'spoonfuls', 'pour', 'cashew_nuts