In [1]:
import torch
from torch import nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
# Hyperparameters
EPOCH = 1           # number of passes over train_loader
batch_size = 64
time_size = 28      # time steps per sample: one per MNIST image row
input_size = 28     # features per time step: pixels in one image row
lr = 0.01
SEED = 0            # seed for torch's default RNG

# Weight initialisation and DataLoader shuffling both draw from torch's default
# generator, so seeding here makes runs repeatable (torch.manual_seed also seeds
# every CUDA device). NOTE: some cuDNN kernels may still be nondeterministic on GPU.
torch.manual_seed(SEED)

In [3]:
# Load the MNIST dataset. ToTensor() turns each training image into floats in [0, 1].
train_data = datasets.MNIST(root='.mnist', train=True,
                            transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root='.mnist', train=False,
                           transform=transforms.ToTensor(), download=True)

# Fixed evaluation slice used while training. `.data` holds the raw pixels (the
# transform is only applied when indexing the dataset), so divide by 255 by hand
# to match ToTensor's [0, 1] range.
eval_size = 2000
# Slice before converting so only eval_size images are cast, not the whole test set.
test_x = test_data.data[:eval_size].float() / 255
test_y = test_data.targets[:eval_size].numpy()

# Optional: preview the first training image.
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.show()

In [4]:
# Wrap the training set in a DataLoader: mini-batches of `batch_size`,
# reshuffled at the start of every epoch.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)

In [5]:
# Define the network
class RNN(nn.Module):
    """Sequence classifier: a GRU followed by a linear head on the last time step.

    In this notebook each MNIST image row is fed as one time step.

    Args:
        input_size: features per time step. When omitted, the notebook-level
            ``input_size`` hyperparameter is used (looked up when the model is built).
        hidden_size: GRU hidden units; also the input width of the linear head.
        num_layers: number of stacked GRU layers.
        num_classes: number of output logits (10 digit classes by default).
    """

    def __init__(self, input_size=None, hidden_size=64, num_layers=1, num_classes=10):
        super().__init__()
        if input_size is None:
            # The parameter shadows the module-level hyperparameter of the same
            # name, so fetch it explicitly to keep plain ``RNN()`` working as before.
            input_size = globals()['input_size']
        self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True)
        # Sharing `hidden_size` keeps the GRU output and the head's input in sync.
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` must be (batch, seq_len, input_size) because the GRU is built
        with ``batch_first=True``.
        """
        r_out, _ = self.rnn(x)            # r_out: (batch, seq_len, hidden_size)
        return self.out(r_out[:, -1, :])  # classify from the final time step only

In [6]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cuda = device  # alias kept for backward compatibility with the original name
rnn = RNN().to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss()  # takes raw logits and integer class labels

In [7]:
# Train & validate: every `log_every` steps, report the current batch loss and
# the accuracy on the held-out slice test_x / test_y.
device = next(rnn.parameters()).device  # follow wherever the model lives
test_x_dev = test_x.to(device)          # move the eval slice once, not every log step
log_every = 100

for epoch in range(EPOCH):
    rnn.train()
    for step, (b_x, b_y) in enumerate(train_loader):
        # Treat each image row as one time step: (batch, time_size, input_size).
        b_x = b_x.view(-1, time_size, input_size).to(device)
        output = rnn(b_x)
        loss = loss_func(output, b_y.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % log_every == 0:
            rnn.eval()
            # Evaluation needs no autograd graph; skipping it saves memory and time.
            with torch.no_grad():
                test_output = rnn(test_x_dev)
            pred_y = test_output.argmax(dim=1).cpu().numpy()
            accuracy = float((pred_y == test_y).mean())
            print('Epoch: {}, Step: {}, loss: {}, accuracy: {}'.format(epoch, step, loss.item(), accuracy))
            rnn.train()  # resume training mode for the next step


Epoch: 0, Step: 0, loss: 2.296750783920288, accuracy: 0.104
Epoch: 0, Step: 100, loss: 0.474337100982666, accuracy: 0.801
Epoch: 0, Step: 200, loss: 0.4759863018989563, accuracy: 0.8905
Epoch: 0, Step: 300, loss: 0.07621700316667557, accuracy: 0.916
Epoch: 0, Step: 400, loss: 0.36875393986701965, accuracy: 0.951
Epoch: 0, Step: 500, loss: 0.06785818934440613, accuracy: 0.9495
Epoch: 0, Step: 600, loss: 0.05693020671606064, accuracy: 0.9565
Epoch: 0, Step: 700, loss: 0.11919639259576797, accuracy: 0.959
Epoch: 0, Step: 800, loss: 0.17804071307182312, accuracy: 0.957
Epoch: 0, Step: 900, loss: 0.10407861322164536, accuracy: 0.962


In [8]:
# Predict the first few test digits and print them next to the true labels.
n_show = 10
device = next(rnn.parameters()).device
rnn.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    # Move just the slice we need instead of rebinding the shared test_x to the GPU.
    batch = test_x[:n_show].view(-1, time_size, input_size).to(device)
    test_output = rnn(batch)
pred_y = test_output.argmax(dim=1).cpu().numpy()
print('预测数字: ', pred_y)
print('实际数字: ', test_y[:n_show])

预测数字:  [7 2 1 0 4 1 4 9 5 9]
实际数字:  [7 2 1 0 4 1 4 9 5 9]
