In [33]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [34]:
def make_batch():
    input_batch, target_batch = [], []
    
    input_vector = np.eye(vocab_size)
    
    sentence_words = " ".join(sentences).split()
    for i, word in enumerate(sentence_words[:-1]):
        input = [word_to_idx[word] for word in sentence_words[:i + 1]] 
        input = input + [0] * (max_len - len(input))
        target = word_to_idx[sentence_words[i + 1]]
        
        input_batch.append(input_vector[input])
        target_batch.append(target)
        
    return input_batch, target_batch

In [35]:
class BiLSTM(nn.Module):
    ### LSTM과 코드가 달라진 부분 ###
    # self.lstm의 bidirectional 파라미터
    # self.W의 차원
    # hidden_state의 차원
    # cell_state의 차원
    
    def __init__(self):
        super(BiLSTM, self).__init__()
        
        self.lstm = nn.LSTM(input_size=vocab_size, hidden_size=n_hidden, bidirectional=True)
        self.W = nn.Linear(n_hidden * 2, vocab_size, bias=False)
        self.b = nn.Parameter(torch.ones([vocab_size]))
        
    def forward(self, X):
        # X: (batch_size, n_step, vocab_size)
        batch_size = len(X)
        
        X = X.transpose(0, 1) # X: (n_step, batch_size, vocab_size)
        hidden_state = torch.zeros(1*2, batch_size, n_hidden) # 1*2 = num_layers*num_directions
        cell_state = torch.zeros(1*2, batch_size, n_hidden)
        
        outputs, (_, _) = self.lstm(X, (hidden_state, cell_state))
        outputs = outputs[-1] # (batch_size, n_hidden * 2)
        result = self.W(outputs) + self.b # (batch_size, vocab_size)
        return result

In [36]:
n_hidden = 5

In [37]:
sentences = ["Lorem ipsum dolor sit amet consectetur adipisicing elit", "sed do eiusmod tempor incididunt ut labore et dolore magna", "aliqua Ut enim ad minim veniam quis nostrud exercitation"]

words = " ".join(sentences).split()
max_len = len(words)
words = list(set(words)) # 중복 단어 제거

word_to_idx = {word: idx for idx, word in enumerate(words)}
idx_to_word = {idx: word for idx, word in enumerate(words)}
vocab_size = len(words)

In [38]:
model = BiLSTM()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

input_batch, target_batch = make_batch()
input_batch = torch.FloatTensor(input_batch)
target_batch = torch.LongTensor(target_batch)

In [39]:
# Training
for epoch in range(10000):
    optimizer.zero_grad()
    output = model(input_batch)
    loss = criterion(output, target_batch)
    
    if (epoch + 1) % 1000 == 0:
        print("Epoch : {:4d}  loss : {:.6f}".format(epoch + 1, loss))
        
    loss.backward()
    optimizer.step()

Epoch : 1000  loss : 1.823740
Epoch : 2000  loss : 1.491018
Epoch : 3000  loss : 1.247711
Epoch : 4000  loss : 1.144609
Epoch : 5000  loss : 1.002288
Epoch : 6000  loss : 0.513472
Epoch : 7000  loss : 0.311971
Epoch : 8000  loss : 0.210143
Epoch : 9000  loss : 0.150546
Epoch : 10000  loss : 0.118463


In [41]:
# Predict
predicts = model(input_batch).data.max(1, keepdim=True)[1]

print(sentences)
print([idx_to_word[predict.item()] for predict in predicts.squeeze()])

['Lorem ipsum dolor sit amet consectetur adipisicing elit', 'sed do eiusmod tempor incididunt ut labore et dolore magna', 'aliqua Ut enim ad minim veniam quis nostrud exercitation']
['ipsum', 'dolor', 'sit', 'amet', 'amet', 'adipisicing', 'elit', 'sed', 'do', 'eiusmod', 'tempor', 'incididunt', 'ut', 'labore', 'et', 'dolore', 'magna', 'aliqua', 'Ut', 'enim', 'ad', 'minim', 'veniam', 'quis', 'nostrud', 'exercitation']
