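# Image Classification with an RNN
Earlier we saw that RNNs are well suited to sequence data. Can an RNN also be used for image classification, the way a CNN is? Below we use the MNIST handwritten digit dataset to show how it can be done: each 28×28 image is treated as a sequence of 28 steps with 28 features per step. This is not a mainstream approach; it is presented here purely as an illustration.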

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torchvision import transforms as tfs
from torchvision.datasets import MNIST

In [2]:
# Define the data
data_tf = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5], [0.5])   # normalize: (x - 0.5) / 0.5
])

train_set = MNIST('data/', train=True, transform=data_tf)
test_set = MNIST('data/', train=False, transform=data_tf)

train_data = DataLoader(train_set, 64, shuffle=True, num_workers=4)
test_data = DataLoader(test_set, 128, shuffle=False, num_workers=4)
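
To make the preprocessing concrete, here is a minimal sketch that reproduces the `Normalize([0.5], [0.5])` arithmetic on a synthetic image, so it runs without downloading MNIST. The tensor `fake_image` is a made-up stand-in for one digit after `ToTensor()`, and the batch is built by hand rather than by the `DataLoader`.

```python
import torch

# Hypothetical stand-in for one MNIST image after ToTensor():
# shape (1, 28, 28), values in [0, 1).
fake_image = torch.rand(1, 28, 28)

# Normalize([0.5], [0.5]) computes (x - mean) / std for the single channel.
mean, std = 0.5, 0.5
normalized = (fake_image - mean) / std

# Stacking 64 such images mimics the shape of one training batch.
batch = torch.stack([normalized] * 64)

print(batch.shape)  # torch.Size([64, 1, 28, 28])
print(normalized.min().item() >= -1.0, normalized.max().item() <= 1.0)
```

After normalization the pixel values lie in roughly [-1, 1], and every batch the model receives has the layout (batch, channel, height, width).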

In [3]:
# Define the model
class rnn_classify(nn.Module):
    def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
        super(rnn_classify, self).__init__()
        self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)      # two-layer LSTM
        self.classifier = nn.Linear(hidden_feature, num_class)   # fully connected layer mapping the last LSTM output to class scores
        
    def forward(self, x):
        '''
        x has shape (batch, 1, 28, 28), so it must be converted to the
        RNN input layout (seq_len, batch, feature), i.e. (28, batch, 28)
        '''
        x = x.squeeze(1) # drop the channel axis: (batch, 1, 28, 28) -> (batch, 28, 28)
        x = x.permute(2, 0, 1) # move the last axis to the front: (28, batch, 28), so each time step is one image column
        out, _ = self.rnn(x) # default (zero) initial hidden state; out is (28, batch, hidden_feature)
        out = out[-1, :, :] # keep the last time step: (batch, hidden_feature)
        out = self.classifier(out) # class scores
        return out
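
The shape changes inside `forward` can be checked on a dummy batch. The sketch below is an illustration only: it uses random tensors instead of MNIST, an arbitrary batch size of 4, and builds the LSTM and linear layer directly rather than through `rnn_classify`. It also shows that `squeeze(1)` removes only the channel axis, whereas a bare `squeeze()` would additionally drop a batch axis of size 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size = 4                               # arbitrary, for illustration
x = torch.randn(batch_size, 1, 28, 28)       # dummy stand-in for an MNIST batch

x = x.squeeze(1)                             # (batch, 28, 28)
x = x.permute(2, 0, 1)                       # (seq_len=28, batch, feature=28)

lstm = nn.LSTM(input_size=28, hidden_size=100, num_layers=2)
out, (h_n, c_n) = lstm(x)

print(out.shape)   # torch.Size([28, 4, 100])
print(h_n.shape)   # torch.Size([2, 4, 100])

# For a unidirectional LSTM, the last time step of `out` is the final
# hidden state of the top layer.
same = torch.allclose(out[-1], h_n[-1])
print(same)        # True

logits = nn.Linear(100, 10)(out[-1])
print(logits.shape)  # torch.Size([4, 10])

# Pitfall with a single-image batch: squeeze() without an argument
# removes every axis of size 1, including the batch axis.
single = torch.randn(1, 1, 28, 28)
print(single.squeeze().shape)    # torch.Size([28, 28])
print(single.squeeze(1).shape)   # torch.Size([1, 28, 28])
```

Taking `out[-1]` means the classifier only sees the hidden state after the whole image has been read, which is the usual way to summarize a sequence with a single vector.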

In [5]:
# Instantiate the model, loss function, and optimizer
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 0.1)
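
`nn.CrossEntropyLoss` expects raw, unnormalized scores and integer class labels, which is why the model's `forward` ends at the linear layer without a softmax. A minimal sketch with made-up logits and labels, showing that the loss matches log-softmax followed by negative log-likelihood:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(5, 10)              # made-up scores: 5 samples, 10 classes
labels = torch.tensor([3, 0, 9, 1, 7])   # made-up integer class labels

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# The same quantity by hand: mean negative log-probability of the true class.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(torch.allclose(loss, manual))      # True
```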

In [7]:
# Start training
# `train` is a helper from a local utils module that is not shown in this section
from utils import train
train(net, train_data, test_data, 10, optimizer, criterion)


Epoch 1. Train Loss: 0.056684, Train Acc: 0.982526, Valid Loss: 0.065243, Valid Acc: 0.980123, Time 00:00:12
Epoch 2. Train Loss: 0.050928, Train Acc: 0.984542, Valid Loss: 0.069091, Valid Acc: 0.979035, Time 00:00:13
Epoch 3. Train Loss: 0.046779, Train Acc: 0.985241, Valid Loss: 0.068747, Valid Acc: 0.980024, Time 00:00:13
Epoch 4. Train Loss: 0.043971, Train Acc: 0.986257, Valid Loss: 0.055970, Valid Acc: 0.983188, Time 00:00:13
Epoch 5. Train Loss: 0.039695, Train Acc: 0.987673, Valid Loss: 0.057191, Valid Acc: 0.983485, Time 00:00:13
Epoch 6. Train Loss: 0.036519, Train Acc: 0.988806, Valid Loss: 0.058995, Valid Acc: 0.982002, Time 00:00:14
Epoch 7. Train Loss: 0.033205, Train Acc: 0.989972, Valid Loss: 0.056425, Valid Acc: 0.982496, Time 00:00:13
Epoch 8. Train Loss: 0.032026, Train Acc: 0.990089, Valid Loss: 0.057818, Valid Acc: 0.983188, Time 00:00:13
Epoch 9. Train Loss: 0.028377, Train Acc: 0.990988, Valid Loss: 0.055412, Valid Acc: 0.983782, Time 00:00:13
Epoch 10. Train Los
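
The `train` helper imported from `utils` is not shown in this section. As a rough idea of what such a loop might look like in current PyTorch, here is a sketch under stated assumptions: `train_sketch` is a hypothetical name, its argument order merely mirrors the call above, and it is not the actual `utils.train`. It uses `torch.no_grad()` for evaluation and is exercised on random tensors with a toy linear model, not on MNIST.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Hypothetical stand-in for utils.train.

    Returns a list of (valid_loss, valid_acc) tuples, one per epoch.
    """
    device = next(net.parameters()).device   # run wherever the model already lives
    history = []
    for epoch in range(num_epochs):
        net.train()
        for im, label in train_data:
            im, label = im.to(device), label.to(device)
            loss = criterion(net(im), label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        net.eval()
        total_loss, correct, count = 0.0, 0, 0
        with torch.no_grad():                 # no gradients needed for evaluation
            for im, label in valid_data:
                im, label = im.to(device), label.to(device)
                output = net(im)
                total_loss += criterion(output, label).item() * label.size(0)
                correct += (output.argmax(dim=1) == label).sum().item()
                count += label.size(0)
        valid_loss, valid_acc = total_loss / count, correct / count
        history.append((valid_loss, valid_acc))
        print('Epoch %d. Valid Loss: %f, Valid Acc: %f'
              % (epoch + 1, valid_loss, valid_acc))
    return history


# Tiny synthetic run (random tensors, not MNIST) to show the call pattern.
torch.manual_seed(0)
images = torch.randn(32, 1, 28, 28)
labels = torch.randint(0, 10, (32,))
loader = DataLoader(TensorDataset(images, labels), batch_size=8)

toy_net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
toy_opt = torch.optim.Adadelta(toy_net.parameters(), lr=0.1)
history = train_sketch(toy_net, loader, loader, 2, toy_opt, nn.CrossEntropyLoss())
```

Because the data here is random, the printed numbers carry no meaning; the point is only the structure of the loop: a training pass with gradient updates, followed by an evaluation pass with gradients disabled.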

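As the log shows, after 10 epochs of training the model reaches about 98% accuracy on the simple MNIST dataset, so an RNN can indeed handle simple image classification. This is not where RNNs shine, though; in the next lesson we will look at one of their typical use cases, time series forecasting.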