In [1]:
# An LSTM is an RNN extended with a component responsible for long-term memory.
# It adds a cell state alongside the hidden state that a plain RNN already has.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [11]:
# Training sentence and the character vocabulary used for one-hot encoding.
# string_to_onehot marks the start / end rows at vocabulary indices -2 and -1,
# i.e. the trailing "0" and "1" entries of `chars`.
string = "hello pytorch. how long can a rnn cell remember? show me your limit!"
chars = "abcdefghijklmnopqrstuvwxyz ?!.,:;01"

char_list = list(chars)
char_len = len(char_list)

print(len(string))
print(char_len)

68
35


In [12]:
# One-hot encoding: characters -> rows of 0/1 values.
def string_to_onehot(string, vocab=None):
    """Encode `string` as a one-hot matrix framed by a start row and an end row.

    Args:
        string: Text to encode; every character must occur in `vocab`.
        vocab: Sequence of characters defining the one-hot columns. Defaults to
            the notebook-level `char_list`.

    Returns:
        Integer ndarray of shape (len(string) + 2, len(vocab)). Row 0 is the
        start marker (column -2 set), the last row is the end marker
        (column -1 set), and row k + 1 is the one-hot code of string[k].

    Raises:
        ValueError: If `string` contains a character that is not in `vocab`.
    """
    if vocab is None:
        vocab = char_list  # fall back to the notebook-level vocabulary

    # First occurrence wins, mirroring what list.index() would return.
    column_of = {}
    for column, symbol in enumerate(vocab):
        column_of.setdefault(symbol, column)

    # Allocate the whole matrix once instead of re-stacking a growing array
    # for every character (which copies all previous rows each time).
    encoded = np.zeros((len(string) + 2, len(vocab)), dtype=int)
    encoded[0, -2] = 1   # start marker
    encoded[-1, -1] = 1  # end marker

    for row, symbol in enumerate(string, start=1):
        if symbol not in column_of:
            raise ValueError(f"{symbol!r} is not in the vocabulary")
        encoded[row, column_of[symbol]] = 1
    return encoded

In [13]:
# Decoding: a one-hot (or score) vector -> its character.
def onehot_to_word(onehot_1, vocab=None):
    """Return the vocabulary character at the position of the largest value.

    Args:
        onehot_1: Tensor of scores; it is treated as flattened, so any shape
            whose element count does not exceed len(vocab) works.
        vocab: Indexable sequence of characters. Defaults to the notebook-level
            `char_list`.

    Returns:
        The single character `vocab[argmax(onehot_1)]`.
    """
    if vocab is None:
        vocab = char_list  # fall back to the notebook-level vocabulary
    # torch.argmax works directly on the tensor, so callers no longer have to
    # strip autograd history (e.g. via `.data`) before decoding.
    return vocab[int(torch.argmax(onehot_1))]

In [15]:
# --- Hyperparameters -------------------------------------------------------
batch_size = 1  # one sample at a time; kept fixed at 1
seq_len = 1  # each input is processed on its own; 2 would feed two inputs at once

num_layers = 3
input_size = char_len  # one input feature per vocabulary character (35)
hidden_size = 35
lr = 0.01
num_epochs = 1000

# One-hot encode the training sentence and cast it to float32 for the LSTM.
one_hot = torch.from_numpy(string_to_onehot(string)).float()

# The sentence has 68 characters; the added start and end rows make it 70.
print(one_hot.size())

torch.Size([70, 35])


In [16]:
class RNN(nn.Module):
    """Multi-layer `nn.LSTM` wrapper that passes hidden and cell state explicitly.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden and cell state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # PyTorch ships a ready-made LSTM layer; no need to hand-roll the gates.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)

    def forward(self, input_, hidden, cell):
        """Run the LSTM over `input_` starting from the given state.

        Args:
            input_: Tensor of shape (seq_len, batch, input_size).
            hidden: Hidden state of shape (num_layers, batch, hidden_size).
            cell: Cell state of shape (num_layers, batch, hidden_size).

        Returns:
            Tuple (output, hidden, cell) with the per-step outputs and the
            updated hidden and cell state.
        """
        output, (hidden, cell) = self.lstm(input_, (hidden, cell))
        return output, hidden, cell

    def init_hidden_cell(self, batch=None):
        """Create zero-filled initial hidden and cell state for this model.

        The shapes come from this instance's own `num_layers` / `hidden_size`,
        so they always match the wrapped LSTM regardless of notebook globals.

        Args:
            batch: Batch dimension of the state. Defaults to the notebook-level
                `batch_size`.

        Returns:
            Tuple (hidden, cell), each of shape (num_layers, batch, hidden_size).
            The cell state is the part an LSTM adds over a plain RNN.
        """
        if batch is None:
            batch = batch_size  # fall back to the notebook-level setting
        state_shape = (self.num_layers, batch, self.hidden_size)
        hidden = torch.zeros(state_shape)
        cell = torch.zeros(state_shape)
        return hidden, cell

# Instantiate the model from the notebook-level hyperparameters.
rnn = RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)

In [17]:
# Adam updates every parameter of the model; the loss is mean-squared error.
optimizer = optim.Adam(rnn.parameters(), lr=lr)
loss_func = nn.MSELoss()

In [18]:
# Shape check: feed the first one-hot row through the model once.
j = 0
first_rows = one_hot[j:j + seq_len]
input_data = first_rows.view(seq_len, batch_size, input_size)  # seq_len was set to 1 earlier
print(input_data.size())

# Fresh zero-filled state for every layer.
hidden, cell = rnn.init_hidden_cell()
print(hidden.size(), cell.size())

# One forward step returns the output plus the updated hidden / cell state.
output, hidden, cell = rnn(input_data, hidden, cell)
print(output.size(), hidden.size(), cell.size())

torch.Size([1, 1, 35])
torch.Size([3, 1, 35]) torch.Size([3, 1, 35])
torch.Size([1, 1, 35]) torch.Size([3, 1, 35]) torch.Size([3, 1, 35])


In [19]:
# Number of (input, next-row) training pairs: every row except the last one has a successor.
unroll_len = one_hot.size()[0] // seq_len - 1
for i in range(num_epochs):
    # Each epoch replays the sentence from zero-filled hidden / cell state.
    hidden, cell = rnn.init_hidden_cell()

    # Gradients are cleared once per epoch: backward() runs a single time,
    # after the loss has been accumulated over the whole sequence.
    optimizer.zero_grad()

    loss = 0
    for j in range(unroll_len):  # 69 steps in total
        # Input is row j (e.g. "p" of "pytorch") ...
        input_data = one_hot[j:j + seq_len].view(seq_len, batch_size, input_size)
        # ... and the target is the following row (e.g. "y"), hence j + 1.
        label = one_hot[j + 1:j + seq_len + 1].view(seq_len, batch_size, input_size)

        output, hidden, cell = rnn(input_data, hidden, cell)
        loss += loss_func(output.view(1, -1), label.view(1, -1))

    # Backpropagate through the full unrolled sequence, then update the weights.
    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        print(loss)

tensor(2.2671, grad_fn=<AddBackward0>)
tensor(0.0555, grad_fn=<AddBackward0>)
tensor(0.0074, grad_fn=<AddBackward0>)
tensor(0.0051, grad_fn=<AddBackward0>)
tensor(0.0044, grad_fn=<AddBackward0>)
tensor(0.0040, grad_fn=<AddBackward0>)
tensor(0.0038, grad_fn=<AddBackward0>)
tensor(0.0037, grad_fn=<AddBackward0>)
tensor(0.0035, grad_fn=<AddBackward0>)
tensor(0.0034, grad_fn=<AddBackward0>)


In [20]:
# Replay the sentence through the trained model from zero-filled state and
# print the character predicted at every step.
hidden, cell = rnn.init_hidden_cell()

# Inference only: no gradients are needed, so skip autograd tracking.
with torch.no_grad():
    # One step fewer than in training, so the step whose target is the
    # end-marker row is not run or printed.
    for j in range(unroll_len - 1):
        # The input is a one-hot row, so its width is input_size
        # (hidden_size only happened to have the same value).
        input_data = one_hot[j:j + 1].view(1, batch_size, input_size)

        output, hidden, cell = rnn(input_data, hidden, cell)
        print(onehot_to_word(output), end="")

# In the output below, the same input character (e.g. a space) is followed by
# different predictions at different positions, which shows that the LSTM is
# relying on its memory of the earlier context.

hello pytorch. how long can a rnn cell remember? show me your limit!