In [1]:
import numpy as np
import matplotlib as plt
import time
from tqdm import *


In [2]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim

In [3]:
import torch.nn as nn
# parameters
num_class = 4        # number of output classes: one logit per character in the vocabulary
input_size = 4       # vocabulary size, i.e. number of rows in the nn.Embedding table
hidden_size = 8      # width of the RNN hidden state
embedding_size = 10  # dimension of each character embedding (RNN input width)
num_layers = 2       # number of stacked RNN layers
batch_size = 1       # sequences per batch; must match the shape of the training data
seq_len = 5          # characters per sequence; must match the shape of the training data
seed = 0             # RNG seed: makes weight initialisation and the training log reproducible

# Seed before any nn.Module is constructed so Restart & Run All reproduces the same run.
torch.manual_seed(seed)

class Model(nn.Module):
    """Character-level RNN: Embedding -> multi-layer nn.RNN -> Linear.

    Maps a batch of character-index sequences to per-time-step class logits,
    flattened to 2-D so they can be passed straight to ``nn.CrossEntropyLoss``.

    Every constructor argument is optional; when omitted it falls back to the
    module-level config constant with the same role (``input_size``,
    ``embedding_size``, ``hidden_size``, ``num_layers``, ``num_class``), so a
    bare ``Model()`` builds the same network as before.

    Args:
        vocab_size: number of distinct input indices (rows of the embedding table).
        emb_dim: embedding dimension, which is also the RNN input width.
        hidden_dim: width of the RNN hidden state.
        n_layers: number of stacked RNN layers.
        n_classes: number of output logits per time step.
    """

    def __init__(self, vocab_size=None, emb_dim=None, hidden_dim=None,
                 n_layers=None, n_classes=None):
        super().__init__()
        # Resolve defaults lazily: a global is only read when its argument is omitted.
        vocab_size = input_size if vocab_size is None else vocab_size
        emb_dim = embedding_size if emb_dim is None else emb_dim
        hidden_dim = hidden_size if hidden_dim is None else hidden_dim
        n_layers = num_layers if n_layers is None else n_layers
        n_classes = num_class if n_classes is None else n_classes

        # Layers are created in the same order as before, so a given seed yields the same init.
        self.emb = nn.Embedding(vocab_size, emb_dim)  # embedding layer
        self.rnn = nn.RNN(input_size=emb_dim, hidden_size=hidden_dim,
                          num_layers=n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, n_classes)  # fully connected output layer

    def forward(self, x):
        """Compute flattened logits for a batch of index sequences.

        Args:
            x: LongTensor of character indices, shape (batch, seq_len).

        Returns:
            Float tensor of shape (batch * seq_len, n_classes).
        """
        x = self.emb(x)     # (batch, seq_len, emb_dim)
        # h_0 is omitted on purpose: nn.RNN then starts from zeros with the input's
        # dtype and device, so the model still works after .double() or .to(device).
        x, _ = self.rnn(x)  # (batch, seq_len, hidden_dim)
        x = self.fc(x)      # (batch, seq_len, n_classes)
        # Fold the time axis into the batch axis, as CrossEntropyLoss expects (N, C) logits.
        return x.reshape(-1, self.fc.out_features)
# Build the network, the loss function and the optimiser used by the training cell.
net = Model()
criterion = nn.CrossEntropyLoss()  # classification loss over the flattened per-step logits
optimizer = optim.Adam(params=net.parameters(), lr=0.05)  # Adam, fixed learning rate 0.05

In [4]:
idx2char = ['e', 'h', 'l', 'o']  # class index -> character
x_data = [[1, 0, 2, 2, 3]]       # "hello" as indices, (batch, seq_len)
y_data = [3, 1, 2, 3, 2]         # "ohlol" as indices, (batch * seq_len)

inputs = torch.LongTensor(x_data)  # Input should be LongTensor: (batchSize, seqLen), 2-D
labels = torch.LongTensor(y_data)  # Target should be LongTensor: (batchSize * seqLen), 1-D

# Sanity checks: the data must agree with the config constants the model was built from.
assert inputs.shape == (batch_size, seq_len)
assert labels.shape == (batch_size * seq_len,)
assert len(idx2char) == num_class == input_size

epochs = 15

for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = net(inputs)              # (batch * seq_len, num_class) logits
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()                   # update the weights using the gradients just computed

    # Decode this epoch's forward pass (taken before the update) into characters.
    predicted = ''.join(idx2char[i] for i in outputs.argmax(dim=1).tolist())
    print('Predicted:  %s, Epoch [%d/%d] loss = %.3f' % (predicted, epoch + 1, epochs, loss.item()))


Predicted:  lllll, Epoch [1/15] loss = 1.288
Predicted:  ollll, Epoch [2/15] loss = 1.060
Predicted:  ohlll, Epoch [3/15] loss = 0.860
Predicted:  ohlll, Epoch [4/15] loss = 0.632
Predicted:  ohlol, Epoch [5/15] loss = 0.418
Predicted:  ohlol, Epoch [6/15] loss = 0.270
Predicted:  ohlol, Epoch [7/15] loss = 0.175
Predicted:  ohlol, Epoch [8/15] loss = 0.113
Predicted:  ohlol, Epoch [9/15] loss = 0.073
Predicted:  ohlol, Epoch [10/15] loss = 0.047
Predicted:  ohlol, Epoch [11/15] loss = 0.031
Predicted:  ohlol, Epoch [12/15] loss = 0.022
Predicted:  ohlol, Epoch [13/15] loss = 0.016
Predicted:  ohlol, Epoch [14/15] loss = 0.012
Predicted:  ohlol, Epoch [15/15] loss = 0.009
