In [1]:
%matplotlib inline

In [2]:
import torch
import torch.optim as optim
import numpy as np

![rnn_architectures](../_static/rnn_architectures.png)

(ref: http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture10.pdf)

본 실습에서는 문자열 입력을 받아 문자열을 예측하므로, 5번째 그림(many to many)에 해당합니다.

## 입력 데이터

RNN에 문장을 입력하기 위해선 모든 문자를 텐서 형태로 바꿔주어야 합니다. 따라서 앞장의 경우와 마찬가지로 각 문자가 서로 다른 벡터로 매핑되도록 변환해줍니다. 아래는 그 예시입니다. 

In [3]:
sample = " I think therefore I am."

문장 내 각 문자가 서로 다른 인덱스를 갖도록 dictionary 형태로 매핑

In [4]:
char_set = list(set(sample))
char_dic = {c: i for i, c in enumerate(char_set)}
print(char_dic)

{'h': 0, 'e': 1, '.': 2, 't': 3, 'o': 4, 'k': 5, 'r': 6, 'm': 7, 'I': 8, 'a': 9, ' ': 10, 'n': 11, 'f': 12, 'i': 13}


In [5]:
# hyper parameters
dic_size = len(char_dic)
hidden_size = len(char_dic)
learning_rate = 0.1

In [6]:
dic_size, hidden_size, learning_rate

(14, 14, 0.1)

dictionary 로 서로 다른 인덱스로 매핑된 각 문자 기반으로 입력 데이터를 벡터화 (one-hot encoding)

In [7]:
sample_idx = [char_dic[c] for c in sample]
x_data = [sample_idx[:-1]]
x_one_hot = [np.eye(dic_size)[x] for x in x_data]
y_data = [sample_idx[1:]]

In [8]:
# transform as torch tensor variable
X = torch.tensor(x_one_hot).float()
Y = torch.tensor(y_data)

입력 데이터 확인

In [9]:
X.view(-1, dic_size), Y.view(-1)

(tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 

In [10]:
X.shape

torch.Size([1, 23, 14])

## RNN 선언

In [11]:
dic_size, hidden_size

(14, 14)

In [12]:
# declare RNN
rnn = torch.nn.LSTM(dic_size, hidden_size, batch_first=True)

RNN의 아웃풋이 마지막 레이어의 가중치로부터 계산된 logit 값들이므로 손실 함수로 nn.CrossEntropyLoss 가 적합합니다.  
  
optimizer 로 optim.Adam 을 사용하도록 합니다.

In [13]:
# loss & optimizer setting
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn.parameters(), learning_rate)

입력 데이터에 대한 RNN 모델의 출력 확인

In [14]:
outputs, _status = rnn(X)

result = outputs.data.numpy().argmax(axis=2)
result_str = ''.join([char_set[c] for c in np.squeeze(result)])

print("prediction: ", result)
print("true Y: ", y_data)
print("prediction str: ", result_str)

prediction:  [[13 13 13 13  7  5  5  5 13 13  7  7 13 13  7 13 13 13 13 13 13 13  7]]
true Y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  iiiimkkkiimmiimiiiiiiim


## 학습

각 학습 루프는 다음과 같습니다:

-  각 문자를 읽기

   -  다음 문자를 위한 은닉 상태 유지

-  목표와 최종 출력 비교
-  역전파
-  출력과 손실 반환

In [15]:
# start training
for i in range(300):
    optimizer.zero_grad()
    outputs, _status = rnn(X)
    loss = criterion(outputs.view(-1, dic_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    result = outputs.data.numpy().argmax(axis=2)
    result_str = ''.join([char_set[c] for c in np.squeeze(result)])
    
    print("Epoch :", i, "- loss: ", loss.item())
    print("pred_y: ", result)
    print("true_y: ", y_data)
    print("prediction str: ", result_str)
    print("========================")

Epoch : 0 - loss:  2.6722252368927
pred_y:  [[13 13 13 13  7  5  5  5 13 13  7  7 13 13  7 13 13 13 13 13 13 13  7]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  iiiimkkkiimmiimiiiiiiim
Epoch : 1 - loss:  2.564969539642334
pred_y:  [[8 8 8 6 6 6 5 6 6 6 6 6 8 6 6 6 6 6 6 6 6 6 6]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  IIIrrrkrrrrrIrrrrrrrrrr
Epoch : 2 - loss:  2.483957529067993
pred_y:  [[8 8 8 6 1 6 6 6 6 6 6 6 6 6 6 6 1 6 6 6 6 6 6]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  IIIrerrrrrrrrrrrerrrrrr
Epoch : 3 - loss:  2.3673200607299805
pred_y:  [[ 8 10  8  1  1  1  5  6  6  1  1  1  1  1  1  1  1  1  1  1  1  1  1]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I Ieeekrreeeeeeeeeeeeee
Epoch : 4 - loss:  2.288546562194824
pred_y:  [[ 8 10  

Epoch : 38 - loss:  1.3052399158477783
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 39 - loss:  1.3035937547683716
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 40 - loss:  1.3011397123336792
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 41 - loss:  1.297919750213623
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
E

true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 75 - loss:  1.2705596685409546
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 76 - loss:  1.2686604261398315
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 77 - loss:  1.268069863319397
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 78 - loss:  1.266727328300476
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
tr

Epoch : 114 - loss:  1.2428336143493652
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 115 - loss:  1.242582082748413
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 116 - loss:  1.2423185110092163
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 117 - loss:  1.242039442062378
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am

Epoch : 151 - loss:  1.236158013343811
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 152 - loss:  1.236038088798523
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 153 - loss:  1.2359193563461304
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 154 - loss:  1.2358016967773438
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am

Epoch : 183 - loss:  1.2327927350997925
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 184 - loss:  1.2326956987380981
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 185 - loss:  1.2325985431671143
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 186 - loss:  1.23250150680542
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am

Epoch : 230 - loss:  1.2285492420196533
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 231 - loss:  1.228488564491272
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 232 - loss:  1.228428840637207
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am.
Epoch : 233 - loss:  1.2283705472946167
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  4  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefore I am

pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  6  6  1 10  8 10  9  2  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefrre I a..
Epoch : 265 - loss:  1.3349125385284424
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  6  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefrre I am.
Epoch : 266 - loss:  1.313093900680542
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  6  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefrre I am.
Epoch : 267 - loss:  1.313328504562378
pred_y:  [[ 8 10  3  0 13 11  5 10  3  0  1  6  1 12  6  6  1 10  8 10  9  7  2]]
true_y:  [[8, 10, 3, 0, 13, 11, 5, 10, 3, 0, 1, 6, 1, 12, 4, 6, 1, 10, 8, 10, 9, 7, 2]]
prediction str:  I think therefrre I am.
Epoch : 268 - loss:  1.316688418388366

------------------------------------------------------------------------------------------




# Practice

#### Q1. 입력 문장을 " I think therefore I am." 로 변경한 뒤, 본 실습을 수행해볼 것.

#### Q2. 학습이 잘 되는 지 확인하시오.

#### Q2. 학습이 잘 되지 않는다면 RNN 모듈을 LSTM 모듈로 바꿔 학습을 수행해보시오.

------------------------------------------------------------------------------------------




__(ref: https://github.com/deeplearningzerotoall/PyTorch)__