In [1]:
import torch

batch_size = 1
seq_len = 5
input_size = 4
hidden_size = 4
num_layers = 1

# 生成数据

In [2]:
idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]
y_data = [3, 1, 2, 3, 2]

one_hot_lookup = [[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]  # 生成独热向量

inputs = torch.Tensor(x_one_hot).view(seq_len, batch_size, input_size)
labels = torch.LongTensor(y_data)

# 定义模型

In [3]:
class Model(torch.nn.Module):
    def __init__(self, input_size, hidden_size, batch_size, num_layers=1):
        """
        继承Module类
        super超类
        重写__init__, forward方法
        调用RNN函数
        """
        super(Model, self).__init__()
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnn = torch.nn.RNN(input_size=self.input_size,
                               hidden_size=self.hidden_size,
                               num_layers=num_layers)
    
    def forward(self, input):
        hidden = torch.zeros(self.num_layers,
                            self.batch_size,
                            self.hidden_size)
        out, _ = self.rnn(input, hidden)
        
        return out.view(-1, self.hidden_size)

In [4]:
net = Model(input_size, hidden_size, batch_size, num_layers)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

In [5]:
for epoch in range(30):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    _, idx = outputs.max(dim=1)
    idx = idx.data.numpy()
    print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
    print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

Predicted:  ooooo, Epoch [1/15] loss = 1.379
Predicted:  ooooo, Epoch [2/15] loss = 1.243
Predicted:  ooooo, Epoch [3/15] loss = 1.120
Predicted:  ohloo, Epoch [4/15] loss = 1.028
Predicted:  hhlol, Epoch [5/15] loss = 0.968
Predicted:  hhlol, Epoch [6/15] loss = 0.924
Predicted:  hhlol, Epoch [7/15] loss = 0.882
Predicted:  hhlol, Epoch [8/15] loss = 0.836
Predicted:  ohlol, Epoch [9/15] loss = 0.789
Predicted:  ohlol, Epoch [10/15] loss = 0.743
Predicted:  ohlol, Epoch [11/15] loss = 0.701
Predicted:  ohlol, Epoch [12/15] loss = 0.664
Predicted:  ohlol, Epoch [13/15] loss = 0.631
Predicted:  ohlol, Epoch [14/15] loss = 0.601
Predicted:  ohlol, Epoch [15/15] loss = 0.575
Predicted:  ohlol, Epoch [16/15] loss = 0.551
Predicted:  ohlol, Epoch [17/15] loss = 0.530
Predicted:  ohlol, Epoch [18/15] loss = 0.510
Predicted:  ohlol, Epoch [19/15] loss = 0.491
Predicted:  ohlol, Epoch [20/15] loss = 0.475
Predicted:  ohlol, Epoch [21/15] loss = 0.460
Predicted:  ohlol, Epoch [22/15] loss = 0.4