In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

## Character-level RNN

### 데이터 전처리

In [24]:
# 훈련 데이터
sentence = ("if you want to build a ship, don't drum up people together to "
            "collect wood and don't assign them tasks and work, but rather "
            "teach them to long for the endless immensity of the sea.")

In [25]:
# 중복을 제거한 문자 집합 생성 후 정수 인코딩
char_set = list(set(sentence)) 
char_dic = {c: i for i, c in enumerate(char_set)}
dic_size = len(char_dic)
print(len(char_dic))
print(char_dic)

25
{'p': 0, 'r': 1, 'm': 2, 'u': 3, 'a': 4, 'b': 5, 'y': 6, ' ': 7, 'w': 8, 'l': 9, 'f': 10, 'e': 11, 'k': 12, 'n': 13, 't': 14, 'i': 15, ',': 16, 'o': 17, 'h': 18, '.': 19, 's': 20, 'c': 21, "'": 22, 'g': 23, 'd': 24}


In [26]:
# 하이퍼파라미터 설정
hidden_size = dic_size
sequence_length = 10  # 데이터 분할 길이
learning_rate = 0.1

In [27]:
# 데이터 구성
x_data = []
y_data = []

for i in range(0, len(sentence) - sequence_length):
    x_str = sentence[i:i + sequence_length]
    y_str = sentence[i + 1: i + sequence_length + 1]
    print(i, x_str, '->', y_str)

    x_data.append([char_dic[c] for c in x_str])  # x str to index
    y_data.append([char_dic[c] for c in y_str])  # y str to index

0 if you wan -> f you want
1 f you want ->  you want 
2  you want  -> you want t
3 you want t -> ou want to
4 ou want to -> u want to 
5 u want to  ->  want to b
6  want to b -> want to bu
7 want to bu -> ant to bui
8 ant to bui -> nt to buil
9 nt to buil -> t to build
10 t to build ->  to build 
11  to build  -> to build a
12 to build a -> o build a 
13 o build a  ->  build a s
14  build a s -> build a sh
15 build a sh -> uild a shi
16 uild a shi -> ild a ship
17 ild a ship -> ld a ship,
18 ld a ship, -> d a ship, 
19 d a ship,  ->  a ship, d
20  a ship, d -> a ship, do
21 a ship, do ->  ship, don
22  ship, don -> ship, don'
23 ship, don' -> hip, don't
24 hip, don't -> ip, don't 
25 ip, don't  -> p, don't d
26 p, don't d -> , don't dr
27 , don't dr ->  don't dru
28  don't dru -> don't drum
29 don't drum -> on't drum 
30 on't drum  -> n't drum u
31 n't drum u -> 't drum up
32 't drum up -> t drum up 
33 t drum up  ->  drum up p
34  drum up p -> drum up pe
35 drum up pe -> rum up peo
36

In [28]:
print(x_data[0])
print(y_data[0])

[15, 10, 7, 6, 17, 3, 7, 8, 4, 13]
[10, 7, 6, 17, 3, 7, 8, 4, 13, 14]


In [29]:
# x 데이터는 원-핫 인코딩
x_one_hot = [np.eye(dic_size)[x] for x in x_data]
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)
print(X.shape)
print(Y.shape)

torch.Size([170, 10, 25])
torch.Size([170, 10])


### 모델 구현

In [30]:
class Net(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, layers):
        super(Net, self).__init__()
        self.rnn = torch.nn.RNN(input_dim, hidden_dim, num_layers=layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x):
        x, _status = self.rnn(x)
        x = self.fc(x)
        return x
    
net = Net(dic_size, hidden_size, 2) # 레이어 2개

In [31]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), learning_rate)

In [32]:
outputs = net(X)
print(outputs.shape)
print(outputs.view(-1, dic_size).shape)

torch.Size([170, 10, 25])
torch.Size([1700, 25])


In [33]:
print(Y.shape)
print(Y.view(-1).shape)

torch.Size([170, 10])
torch.Size([1700])


In [34]:
for i in range(100):
    optimizer.zero_grad()
    outputs = net(X) # (170, 10, 25)
    loss = loss_fn(outputs.view(-1, dic_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    results = outputs.argmax(dim=2) # (170, 10)
    predict_str = ""
    for j, result in enumerate(results):
        # 처음에는 예측 결과를 전부 가져오지만
        # 그 다음에는 마지막 글자만 반복 추가
        if j == 0: 
            predict_str += ''.join([char_set[t] for t in result])
        else: 
            predict_str += char_set[result[-1]]

    print(predict_str)

oeobnoy,oooneswseowootoososyweowoeneooeoyowoeoowooeoooesoooeowsowooooosoooooookeoooenooooeooeoeeeeoostnoooktoosnwsoenyoesooweooooeeseeeeoweooonno,oeseeeoooooonoeeoooseseonoeseeoyo
 n nn t n n nn nnnnnnnt nnn n n n t n nnnnnnnnnnn nn n nnnnnn nnnnn nnnnnnnn nn n n tnn nn nnnn nnn n tnt n n t tnn t n nnnnnn nnn nn nnn nnnnnnn n nntnn nn nnnn n nnn nnnn nntnn 
e t t t s s othe s etos s s s s s s s s s s  s e sht s e s et s e s es s e e et s s s s e s s osee st s s s s s s s s s s s e  s s s e s et s s eot e s s s sos s e s e  s s s soe 
 o'o  fo    o oo on  o o  o  ono  no oono  o  n   ono   n  onwh o     o ooto  no o   o  oo oo  oto  n d   no    do ono o  oto    oh o th ro    onodo ooo  no  no  o  n  oo  o oo   
    w  wo    w  t t w ww   ttw t   w t ttt w      t t wtw w  tt t, t ww t ww    wt w   t g t t     tw t, tt  t t w  tw w w t t   t, w t  tw t  w t wt w  w w  t wt gt  tt  tw t  tt
  tp  to t t  t t   t   tt  t    tt  t t  to t ttt    pt  t  tot     t   tt   t   ttt   t     t t   