In [44]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [45]:
def make_batch():
    input_batch, target_batch = [], []
    
    for sentence in sentences:
        sentence_words = sentence.split()
        input = [word_to_idx[word] for word in sentence_words[:-1]]
        target = word_to_idx[sentence_words[-1]]
        
        input_vector = np.eye(vocab_size)
        
        input_batch.append(input_vector[input])
        target_batch.append(target)
        
    return input_batch, target_batch

In [55]:
class TextLSTM(nn.Module):
    def __init__(self):
        super(TextLSTM, self).__init__()
        
        self.lstm = nn.LSTM(input_size=vocab_size, hidden_size=n_hidden)
        self.W = nn.Linear(n_hidden, vocab_size, bias=False) # n_hidden x vocab_size
        self.b = nn.Parameter(torch.ones([vocab_size]))
        
    def forward(self, X):
        # X: (batch_size, n_step, vocab_size)
        
        X = X.transpose(0, 1) # (n_step, batch_size, vocab_size)
        hidden_state = torch.zeros(1, batch_size, n_hidden)
        cell_state = torch.zeros(1, batch_size, n_hidden)
        
        outputs, (_, _) = self.lstm(X, (hidden_state, cell_state))
        outputs = outputs[-1] # (batch_size, n_hidden)
        result = self.W(outputs) + self.b # (batch_size, vocab_size)
        return result

In [69]:
n_step = 2 # number of cells(steps), = seq_len - 1
n_hidden = 10 # number of hidden units of one cell

In [70]:
sentences = ['i drink water', 'i eat food', 'i read book', 'i play guitar', 'dog is cute', 'tiger is scary', 'i love you', 'i hate worm', 'i listen song']
word_list = list(set(" ".join(sentences).split()))

word_to_idx = {word: idx for idx, word in enumerate(word_list)}
idx_to_word = {idx: word for idx, word in enumerate(word_list)}

vocab_size = len(word_list) # number of vocab
batch_size = len(sentences)
seq_len = 3

In [71]:
model = TextLSTM()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [72]:
input_batch, target_batch = make_batch()
input_batch = torch.FloatTensor(input_batch)
target_batch = torch.LongTensor(target_batch)

In [73]:
# Training
for epoch in range(1000):
    optimizer.zero_grad()
    
    output = model(input_batch)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 100 == 0:
        print('Epoch : {:4d}  loss : {:.6f}'.format(epoch + 1, loss))
        
    loss.backward()
    optimizer.step()

Epoch :  100  loss : 2.607230
Epoch :  200  loss : 2.055153
Epoch :  300  loss : 1.376193
Epoch :  400  loss : 0.763571
Epoch :  500  loss : 0.458528
Epoch :  600  loss : 0.282576
Epoch :  700  loss : 0.182857
Epoch :  800  loss : 0.129173
Epoch :  900  loss : 0.097505
Epoch : 1000  loss : 0.076790


In [74]:
# Predict
inputs = [" ".join(sentence.split()[:-1]) for sentence in sentences]
predicts = model(input_batch).data.max(1, keepdim=True)[1].squeeze() 

In [75]:
predicts # (batch_size)

tensor([14,  8,  6, 13, 15, 18,  4,  5,  1])

In [76]:
for i, predict in enumerate(predicts):
    print(inputs[i], '->', idx_to_word[predict.item()])

i drink -> water
i eat -> food
i read -> book
i play -> guitar
dog is -> cute
tiger is -> scary
i love -> you
i hate -> worm
i listen -> song
