In [1]:
import torch

# Seed torch's RNG so the random RNNCell weight initialisation (and therefore
# the whole training trace) is reproducible under Restart & Run All.
seed = 42
torch.manual_seed(seed)

input_size = 4   # width of each one-hot input vector (one slot per vocabulary character)
hidden_size = 4  # the hidden state is scored directly as per-character logits in training
batch_size = 1

# 准备数据

In [2]:
# Character vocabulary: a character's position in this list is its class id,
# i.e. e:0, h:1, l:2, o:3.
idx2char = ['e', 'h', 'l', 'o']
# Encode the strings through the vocabulary rather than hand-writing indices,
# so the ids can never drift out of sync with idx2char.
x_data = [idx2char.index(ch) for ch in 'hello']  # [1, 0, 2, 2, 3]
y_data = [idx2char.index(ch) for ch in 'ohlol']  # [3, 1, 2, 3, 2]

# Identity matrix: row i is the one-hot vector for class id i.
one_hot_lookup = [[1 if row == col else 0 for col in range(len(idx2char))]
                  for row in range(len(idx2char))]
x_one_hot = [one_hot_lookup[x] for x in x_data]  # one-hot encoding of "hello"

In [3]:
x_one_hot  # one one-hot row per character of "hello"

[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]]

In [4]:
# Reshape the one-hot rows to (seq_len, batch_size, input_size) so that
# iterating over `inputs` yields one (batch_size, input_size) slice per time step.
inputs = torch.tensor(x_one_hot, dtype=torch.float32).view(-1, batch_size, input_size)
# Class-index targets shaped (seq_len, batch_size), using the same batch layout
# as `inputs`, so each row pairs with one time step's (batch_size, C) output.
labels = torch.tensor(y_data, dtype=torch.long).view(-1, batch_size)

In [5]:
inputs   # shape (seq_len, batch_size, input_size) = (5, 1, 4)

tensor([[[0., 1., 0., 0.]],

        [[1., 0., 0., 0.]],

        [[0., 0., 1., 0.]],

        [[0., 0., 1., 0.]],

        [[0., 0., 0., 1.]]])

In [6]:
labels  # shape (seq_len, 1): one target class index per time step

tensor([[3],
        [1],
        [2],
        [3],
        [2]])

# 定义模型并实例化

In [7]:
class Model(torch.nn.Module):
    """One-step RNN model wrapping a single ``torch.nn.RNNCell``.

    ``forward`` advances the recurrence by one time step and returns the new
    hidden state. The training loop feeds that hidden state straight into
    ``CrossEntropyLoss``, so ``hidden_size`` doubles as the number of classes.
    """

    def __init__(self, input_size, hidden_size, batch_size):
        super().__init__()
        # Sizes are kept on the instance; init_hidden reads batch_size/hidden_size.
        self.batch_size, self.input_size, self.hidden_size = batch_size, input_size, hidden_size
        self.rnncell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

    def forward(self, input, hidden):
        """Return the hidden state after consuming one time step of ``input``."""
        return self.rnncell(input, hidden)

    def init_hidden(self):
        """Return an all-zero initial hidden state h0 of shape (batch_size, hidden_size)."""
        h0_shape = (self.batch_size, self.hidden_size)
        return torch.zeros(h0_shape)
    
# Instantiate the model with the sizes from the configuration cell.
net = Model(input_size=input_size, hidden_size=hidden_size, batch_size=batch_size)

# 训练过程

In [8]:
# Cross-entropy loss: each time step's hidden state is scored against a class index.
criterion = torch.nn.CrossEntropyLoss()
# Adam over every parameter of `net`, learning rate 0.1.
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.1)

# 使用RNN Cell

In [9]:
# One constant drives both the loop bound and the "[epoch/total]" progress
# text, so the two cannot disagree.
num_epochs = 20
for epoch in range(num_epochs):
    loss = 0
    optimizer.zero_grad()
    hidden = net.init_hidden()  # restart from h0 = zeros every epoch
    print('Predicted string: ', end='')
    # `step_input` (not `input`) so the loop does not shadow the built-in
    # input() in the notebook namespace.
    for step_input, step_label in zip(inputs, labels):
        hidden = net(step_input, hidden)
        # Sum the loss over every time step; backward() then propagates
        # through the whole unrolled sequence.
        loss += criterion(hidden, step_label)
        _, idx = hidden.max(dim=1)  # greedy predicted character for this step
        print(idx2char[idx.item()], end='')
    loss.backward()
    optimizer.step()
    print(', Epoch [%d/%d] loss=%.4f' % (epoch + 1, num_epochs, loss.item()))

Predicted string: oholo, Epoch [1/15] loss=5.4520
Predicted string: ooolo, Epoch [2/15] loss=4.8232
Predicted string: oolll, Epoch [3/15] loss=4.5272
Predicted string: ohlol, Epoch [4/15] loss=4.2553
Predicted string: ohlol, Epoch [5/15] loss=3.9649
Predicted string: ohlol, Epoch [6/15] loss=3.6029
Predicted string: ohlol, Epoch [7/15] loss=3.1870
Predicted string: ohlol, Epoch [8/15] loss=2.7634
Predicted string: ohlol, Epoch [9/15] loss=2.4562
Predicted string: ohlol, Epoch [10/15] loss=2.2705
Predicted string: ohlol, Epoch [11/15] loss=2.1491
Predicted string: ohlol, Epoch [12/15] loss=2.0704
Predicted string: ohlol, Epoch [13/15] loss=2.0208
Predicted string: ohlol, Epoch [14/15] loss=1.9786
Predicted string: ohlol, Epoch [15/15] loss=1.9357
Predicted string: ohlol, Epoch [16/15] loss=1.8953
Predicted string: ohlol, Epoch [17/15] loss=1.8600
Predicted string: ohlol, Epoch [18/15] loss=1.8304
Predicted string: ohlol, Epoch [19/15] loss=1.8068
Predicted string: ohlol, Epoch [20/15] l