In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [5]:
def make_batch(sequences=None, char_to_index=None, num_classes=None):
    """Build one-hot inputs and integer targets for next-character prediction.

    Each sequence is split into a prefix (every character except the last),
    which becomes the model input, and its last character, which becomes the
    target class.

    Args:
        sequences: Iterable of non-empty strings to encode. Defaults to the
            notebook-level ``seq_data``.
        char_to_index: Mapping from character to class index. Defaults to the
            notebook-level ``ch_to_idx``.
        num_classes: Width of each one-hot vector. Defaults to the
            notebook-level ``n_class``.

    Returns:
        A tuple ``(input_batch, target_batch)``. ``input_batch`` is a list
        holding one ``(len(seq) - 1, num_classes)`` one-hot array per sequence;
        ``target_batch`` is a list holding the class index of each sequence's
        last character.

    Raises:
        KeyError: If a sequence contains a character missing from the mapping.
        IndexError: If a sequence is empty.
    """
    if sequences is None:
        sequences = seq_data
    if char_to_index is None:
        char_to_index = ch_to_idx
    if num_classes is None:
        num_classes = n_class

    # Row i of the identity matrix is the one-hot vector of class i. The table
    # is the same for every sequence, so build it once instead of per iteration.
    one_hot = np.eye(num_classes)

    input_batch, target_batch = [], []
    for seq in sequences:
        prefix_ids = [char_to_index[ch] for ch in seq[:-1]]

        # Indexing with a list of ids selects one one-hot row per character.
        input_batch.append(one_hot[prefix_ids])
        target_batch.append(char_to_index[seq[-1]])

    return input_batch, target_batch

In [7]:
class TextLSTM(nn.Module):
    """Single-layer LSTM that predicts the character following a sequence.

    The LSTM reads a one-hot encoded character sequence; its output at the
    last time step is projected to one score (logit) per character class.
    """

    def __init__(self, vocab_size=None, hidden_size=None):
        """Create the LSTM and the output projection.

        Args:
            vocab_size: Number of character classes, used both as the LSTM
                input width and as the number of output scores. Defaults to
                the notebook-level ``n_class``.
            hidden_size: Number of LSTM hidden units. Defaults to the
                notebook-level ``n_hidden``.
        """
        super(TextLSTM, self).__init__()

        if vocab_size is None:
            vocab_size = n_class
        if hidden_size is None:
            hidden_size = n_hidden

        self.lstm = nn.LSTM(input_size=vocab_size, hidden_size=hidden_size)
        # Projection from hidden units to class scores; the bias is kept as the
        # separate parameter ``b`` below rather than inside the linear layer.
        self.W = nn.Linear(hidden_size, vocab_size, bias=False)
        self.b = nn.Parameter(torch.ones([vocab_size]))

    def forward(self, X):
        """Return class scores for the character following each sequence.

        Args:
            X: Float tensor of shape (batch_size, n_step, vocab_size).

        Returns:
            Tensor of shape (batch_size, vocab_size) holding unnormalised
            scores.
        """
        X = X.transpose(0, 1)  # (n_step, batch_size, vocab_size)

        # With no initial state given, nn.LSTM starts from zero hidden and cell
        # states sized to the incoming batch, so the model accepts any batch
        # size instead of only the notebook-level ``batch_size``.
        outputs, _ = self.lstm(X)

        last_step = outputs[-1]  # (batch_size, hidden_size)
        return self.W(last_step) + self.b  # (batch_size, vocab_size)

In [41]:
# Model hyper-parameters.
n_step = 3     # time steps fed to the LSTM (word length minus one)
n_hidden = 10  # hidden units in the LSTM cell

# Vocabulary: the lowercase alphabet with lookup tables in both directions.
ch_list = list('abcdefghijklmnopqrstuvwxyz')
idx_to_ch = dict(enumerate(ch_list))
ch_to_idx = {ch: idx for idx, ch in idx_to_ch.items()}
n_class = len(ch_list)  # vocabulary size

In [42]:
# Training words: the first three letters are the input, the fourth the target.
seq_data = 'make need coal word love hate live home hash'.split()

# All words are processed together as a single batch.
batch_size = len(seq_data)

In [43]:
# Seed PyTorch's random number generator so the randomly initialised weights,
# and therefore the training log and predictions, repeat on every fresh run.
torch.manual_seed(0)

model = TextLSTM()

# Cross-entropy takes the model's raw class scores and integer class targets.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [44]:
input_batch, target_batch = make_batch()
# make_batch returns a Python list of equally shaped arrays. Stack them into a
# single ndarray first: building a tensor straight from a list of ndarrays
# converts element by element, which is slow and triggers a PyTorch warning.
input_batch = torch.FloatTensor(np.array(input_batch))  # (batch_size, n_step, n_class)
target_batch = torch.LongTensor(target_batch)           # (batch_size,)

In [45]:
# Training: full-batch gradient descent, logging the loss every 100 epochs.
for epoch in range(1000):
    # Forward pass over the whole batch.
    output = model(input_batch)
    loss = criterion(output, target_batch)

    # Discard the previous step's gradients, back-propagate, update weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # The logged value is the loss computed before this epoch's update.
    if not (epoch + 1) % 100:
        print('Epoch : {:4d}  loss : {:.6f}'.format(epoch + 1, loss))

Epoch :  100  loss : 2.512672
Epoch :  200  loss : 1.317537
Epoch :  300  loss : 1.041553
Epoch :  400  loss : 0.732688
Epoch :  500  loss : 0.461911
Epoch :  600  loss : 0.289294
Epoch :  700  loss : 0.174738
Epoch :  800  loss : 0.102245
Epoch :  900  loss : 0.069598
Epoch : 1000  loss : 0.051509


In [46]:
# Predict
inputs = [seq[:-1] for seq in seq_data]  # the three-letter prefixes fed to the model

# Inference needs no gradients, so run it under no_grad instead of reading
# `.data`. argmax over the class dimension always returns a 1-D tensor of shape
# (batch_size,), so the result stays iterable even for a batch of one sequence.
with torch.no_grad():
    predict = model(input_batch).argmax(dim=1)

In [47]:
predict # predicted class index for each input sequence, shape (batch_size,)

tensor([ 4,  3, 11,  3,  4,  4,  4,  4,  7])

In [48]:
# Show each input prefix next to the character the model predicts for it.
for i, idx in enumerate(predict.tolist()):
    predicted_ch = idx_to_ch[idx]
    print(inputs[i], '->', predicted_ch)

mak -> e
nee -> d
coa -> l
wor -> d
lov -> e
hat -> e
liv -> e
hom -> e
has -> h
