In [14]:
import numpy as np
import matplotlib as plt
import time
from tqdm import *


In [15]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim

In [28]:
# Vocabulary: index -> character. Every character is one-hot encoded, so the
# input width equals the vocabulary size.
idx2char=['e','h','l','o']
input_size=len(idx2char)
# The RNN's hidden state is used directly as the class logits (there is no
# output layer), so hidden_size must equal the number of output classes.
hidden_size=len(idx2char)
num_layers=1
batch_size=1
seq_len=5

# Prepare data: learn to map "hello" -> "ohlol".
x_data=[1,0,2,2,3] # hello
y_data=[3,1,2,3,2] # ohlol
assert len(x_data) == len(y_data) == seq_len, "input/label sequences must have seq_len steps"

# Row i of the identity matrix is the one-hot vector for character index i
# (here: rows for indices 0, 1, 2, 3).
one_hot_lookup=[[int(row == col) for col in range(input_size)]
                for row in range(input_size)]
x_one_hot=[one_hot_lookup[x] for x in x_data] # sequence of one-hot vectors
print('x_one_hot:',x_one_hot)

# Build the input sequence and the labels.
# inputs has shape (seq_len, batch_size, input_size), the default
# (batch_first=False) layout expected by torch.nn.RNN.
inputs=torch.tensor(x_one_hot, dtype=torch.float32).view(-1,batch_size,input_size)
# Labels are passed as a 1-D tensor of class indices, one per time step:
# CrossEntropyLoss takes 1-D integer targets for 2-D logits.
# (In the earlier RNNCell version the sequence was fed one step at a time, so the
# labels had to be reshaped to pull out one label per loop iteration.)
labels=torch.tensor(y_data, dtype=torch.long)
print(labels.shape)

x_one_hot: [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
torch.Size([5])


In [29]:
# design model
class Model(torch.nn.Module):
    """A single torch.nn.RNN whose per-step outputs are used directly as class logits.

    Args:
        input_size: number of features per time step (the one-hot width here).
        hidden_size: size of the RNN hidden state. Because there is no output
            layer, it also equals the number of output classes.
        batch_size: batch size the model was configured with. Kept as an
            attribute for backward compatibility; forward() reads the actual
            batch size from its input.
        num_layers: number of stacked RNN layers.
    """

    def __init__(self,input_size,hidden_size,batch_size,num_layers=1):
        super().__init__()
        self.num_layers=num_layers
        self.batch_size=batch_size
        self.input_size=input_size
        self.hidden_size=hidden_size
        self.rnn=torch.nn.RNN(input_size=self.input_size,
                              hidden_size=self.hidden_size,
                              num_layers=self.num_layers)

    def forward(self,input):
        """Run the RNN over the whole sequence starting from a zero hidden state.

        Args:
            input: tensor of shape (seq_len, batch, input_size), the default
                (batch_first=False) layout of torch.nn.RNN.

        Returns:
            Tensor of shape (seq_len * batch, hidden_size): the outputs of all
            time steps flattened to 2-D, which is the logits shape
            CrossEntropyLoss takes.
        """
        # Derive batch size, dtype and device from the input rather than the
        # constructor argument, so other batch sizes and devices also work.
        hidden=torch.zeros(self.num_layers,input.size(1),self.hidden_size,
                           dtype=input.dtype,device=input.device)
        out, _=self.rnn(input,hidden)
        return out.reshape(-1,self.hidden_size)

# Seed PyTorch's global RNG before building the model so that the random weight
# initialization, and therefore the training log, is reproducible on re-run.
SEED=0
torch.manual_seed(SEED)
net=Model(input_size,hidden_size,batch_size,num_layers)

In [30]:
# Loss function and optimizer.
# CrossEntropyLoss receives the model's (seq_len * batch, num_classes) logits
# together with a 1-D tensor of target class indices.
criterion = torch.nn.CrossEntropyLoss()
LEARNING_RATE = 0.05
optimizer = torch.optim.Adam(params=net.parameters(), lr=LEARNING_RATE)

In [31]:
# Training loop: each epoch is one optimization step on the full sequence.
# The prediction printed for an epoch comes from the forward pass made before
# that epoch's parameter update.
num_epochs=20
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    optimizer.step()

    # Index of the highest-scoring character at each time step.
    idx=outputs.argmax(dim=1).tolist()
    print('Predicted: ',''.join(idx2char[x] for x in idx),end='')
    # Report against num_epochs rather than a hard-coded total so the log stays
    # correct if the number of epochs is changed.
    print(',Epoch [%d/%d] loss=%.3f' % (epoch+1, num_epochs, loss.item()))

Predicted:  eello,Epoch [1/20] loss=1.318
Predicted:  eello,Epoch [2/20] loss=1.214
Predicted:  ehloo,Epoch [3/20] loss=1.128
Predicted:  ohloo,Epoch [4/20] loss=1.050
Predicted:  ohloo,Epoch [5/20] loss=0.976
Predicted:  ohloo,Epoch [6/20] loss=0.906
Predicted:  ohlol,Epoch [7/20] loss=0.839
Predicted:  ohlol,Epoch [8/20] loss=0.777
Predicted:  ohlol,Epoch [9/20] loss=0.722
Predicted:  ohlol,Epoch [10/20] loss=0.674
Predicted:  ohlol,Epoch [11/20] loss=0.635
Predicted:  ohlol,Epoch [12/20] loss=0.602
Predicted:  ohlol,Epoch [13/20] loss=0.575
Predicted:  ohlol,Epoch [14/20] loss=0.551
Predicted:  ohlol,Epoch [15/20] loss=0.530
Predicted:  ohlol,Epoch [16/20] loss=0.511
Predicted:  ohlol,Epoch [17/20] loss=0.493
Predicted:  ohlol,Epoch [18/20] loss=0.475
Predicted:  ohlol,Epoch [19/20] loss=0.458
Predicted:  ohlol,Epoch [20/20] loss=0.443
