In [36]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [37]:
# S: decoder input의 시작을 나타냄
# E: decoder output의 끝을 나타냄
# P: 현재 batch의 단어 길이가 n_step보다 작은 경우 빈 곳을 'P'로 채움

def make_batch():
    input_batch, output_batch, target_batch = [], [], []
    
    onehot_vector = np.eye(n_class)
    
    for seq in seq_data:
        for i in range(2):
            seq[i] = seq[i] + 'P' * (n_step - len(seq[i])) # 빈 곳 채우기
        
        input = [ch_to_idx[ch] for ch in seq[0]]
        output = [ch_to_idx[ch] for ch in ('S' + seq[1])] # decoder의 input
        target = [ch_to_idx[ch] for ch in (seq[1] + 'E')]
        
        input_batch.append(onehot_vector[input])
        output_batch.append(onehot_vector[output])
        target_batch.append(target)
        
    # Tensor로 변환
    input_batch = torch.FloatTensor(input_batch)
    output_batch = torch.FloatTensor(output_batch)
    target_batch = torch.LongTensor(target_batch)
        
    return input_batch, output_batch, target_batch

In [52]:
def make_testbatch(input_word):
    input_batch, output_batch = [], []
    
    onehot_vector = np.eye(n_class)
    
    input_word = input_word + 'P' * (n_step - len(input_word))
    input = [ch_to_idx[ch] for ch in input_word]
    output_word = 'S' + 'P' * n_step
    output = [ch_to_idx[ch] for ch in output_word]
    
    input_batch.append(onehot_vector[input])
    output_batch.append(onehot_vector[output])
    
    input_batch = torch.FloatTensor(input_batch)
    output_batch = torch.FloatTensor(output_batch)
    
    return input_batch, output_batch

In [39]:
class Seq2Seq(nn.Module):
    def __init__(self):
        super(Seq2Seq, self).__init__()
        
        self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.fc = nn.Linear(n_hidden, n_class)
        
    def forward(self, enc_input, enc_hidden, dec_input):
        enc_input = enc_input.transpose(0, 1) # (n_step, batch_size, n_class)
        dec_input = dec_input.transpose(0, 1) # (n_step, batch_size, n_class)
        
        _, enc_states = self.enc_cell(enc_input, enc_hidden) # enc_states: (num_layers(=1)*num_directions(=1), batch_size, n_hidden), hidden state를 의미
        outputs, _ = self.dec_cell(dec_input, enc_states)   # outputs: (n_step+1, batch_size, n_class)

        result = self.fc(outputs) # (n_step+1, batch_size, n_class)
        
        return result

In [66]:
ch_list = [ch for ch in 'SEPabcdefghijklmnopqrstuvwxyz']
ch_to_idx = {ch: idx for idx, ch in enumerate(ch_list)}
idx_to_ch = {idx: ch for idx, ch in enumerate(ch_list)}

seq_data = [['man', 'woman'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]

In [41]:
n_step = 5
n_hidden = 128

n_class = len(ch_list)
batch_size = len(seq_data)

In [42]:
model = Seq2Seq()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

  "num_layers={}".format(dropout, num_layers))


In [43]:
input_batch, output_batch, target_batch = make_batch()

In [44]:
# Training
for epoch in range(5000):
    optimizer.zero_grad()
    
    hidden = torch.zeros(1, batch_size, n_hidden)
    
    output = model(input_batch, hidden, output_batch)
    output = output.transpose(0, 1) # (batch_size, n_step+1, n_class)
    
    loss = 0
    for i in range(len(output)):
        loss += criterion(output[i], target_batch[i])
    
    if (epoch + 1) % 1000 == 0:
        print("Epoch: {:4d}  loss: {:.6f}".format(epoch + 1, loss))
    
    loss.backward()
    optimizer.step()

Epoch: 1000  loss: 0.003422
Epoch: 2000  loss: 0.000942
Epoch: 3000  loss: 0.000402
Epoch: 4000  loss: 0.000202
Epoch: 5000  loss: 0.000109


In [83]:
# Test
def test(word):
    input_batch, output_batch = make_testbatch(word)
    
    hidden = torch.zeros(1*1, 1, n_hidden) # (num_layers*num_directions, batch_size, n_hidden)
    output = model(input_batch, hidden, output_batch)
    
    predicts = output.data.max(2, keepdim=True)[1]
    predicts_decoded = [idx_to_ch[predict.item()] for predict in predicts]
    end_idx = predicts_decoded.index('E')
    result = "".join(predicts_decoded[:end_idx]).replace('P', '')
    
    return result

In [85]:
print('======== test =========')

for seq in seq_data:
    word = seq[0]
    print(word, '->', test(word))

man -> woman
black -> white
king -> queen
girl -> boy
up -> down
high -> low
