In [4]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [5]:
# Compute device: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hyper-parameters.
# Each image is fed to the network row by row: one row per time step.
sequence_length, input_size = 28, 28   # image height (time steps), image width (features per step)
hidden_size, num_layers = 128, 2       # LSTM hidden units per layer, number of stacked LSTM layers
num_classes = 10                       # size of the classifier output
batch_size, num_epochs = 100, 2
learning_rate = 0.01

rnn预测一张图片的分类问题，它需要对一张图片从上到下，每一行像素进行分析，也就是图片高度从上到下为一个时间序列（一行一行pixel观察）
sequence_length: 图片的高度，序列的长度，就是一共有多少个时间点
input_size: 图片的宽度，即每个时间点要分析的内容量
num_layers: 对于每一个时间点输入 x<t>，都可以用多层LSTM layer处理得到输出y<t>
hidden_size: 对于当前t时刻输入 x<t>和前一时刻a<t-1>，用多少个神经元处理（一层LSTM）

In [8]:
# MNIST dataset: the training split is downloaded into ./data/ when missing;
# both splits convert each image to a tensor via ToTensor().
train_dataset = torchvision.datasets.MNIST(
    root='./data/', train=True, transform=transforms.ToTensor(), download=True
)
test_dataset = torchvision.datasets.MNIST(
    root='./data/', train=False, transform=transforms.ToTensor()
)

# Data loaders: identical settings for both splits, except that only the
# training loader shuffles its samples.
loader_options = dict(batch_size=batch_size, num_workers=2)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **loader_options)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


nn.LSTM()会返回out, (h, c)两个值：
第一个 out 是LSTM最后一层在每个时间点的输出，格式（batch_first=True 时）：(batch, seq_len, hidden_size)，所以最后要取seq_len=-1的结果，即最后一个时间点out[:, -1, :]。

第二个 (h, c) 是每一层在最后一个时间点的隐藏状态和细胞状态，格式均为 (num_layers, batch, hidden_size)；LSTM内部会在每个时间点递推更新它们，但这里我们用不上。

In [9]:
# RNN(many->one)
class RNN(nn.Module):
    """Many-to-one LSTM classifier.

    A stacked LSTM reads the whole input sequence; the output of its top
    layer at the last time step is passed through a linear layer to produce
    one score per class.

    Args:
        input_size: number of features in each time step of the input.
        hidden_size: number of hidden units in each LSTM layer.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output scores produced per sequence.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: the LSTM takes and returns (batch, seq_len, features).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: tensor of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised class scores.
        """
        # Zero initial hidden and cell states, allocated with the input's own
        # device and dtype so the module does not depend on a global `device`
        # and keeps working if the model/input are moved or cast.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)

        # out: (batch, seq_len, hidden_size); the final (h, c) states are not needed.
        out, _ = self.lstm(x, (h0, c0))

        # Decode the top-layer output of the last time step.
        return self.fc(out[:, -1, :])

In [11]:
# Build the network and move its parameters to the selected device.
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Cross-entropy objective optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training: one optimiser update per mini-batch, progress logged every 100 batches.
total_step = len(train_loader)
progress_line = 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Present every image as a sequence: (batch_size, seq_len, input_size).
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass and loss.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Clear stale gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only report on every 100th batch.
        if (i + 1) % 100 != 0:
            continue
        print(progress_line.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

Epoch [1/2], Step [100/600], Loss: 0.4880
Epoch [1/2], Step [200/600], Loss: 0.2549
Epoch [1/2], Step [300/600], Loss: 0.2265
Epoch [1/2], Step [400/600], Loss: 0.0704
Epoch [1/2], Step [500/600], Loss: 0.1325
Epoch [1/2], Step [600/600], Loss: 0.1325
Epoch [2/2], Step [100/600], Loss: 0.1045
Epoch [2/2], Step [200/600], Loss: 0.0370
Epoch [2/2], Step [300/600], Loss: 0.1254
Epoch [2/2], Step [400/600], Loss: 0.0918
Epoch [2/2], Step [500/600], Loss: 0.1791
Epoch [2/2], Step [600/600], Loss: 0.0336


In [18]:
# Test the model
model.eval()  # put the model in evaluation mode before measuring accuracy
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        # Same reshaping as in training: (batch_size, seq_len, input_size).
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest score in each row.
        # argmax replaces torch.max(outputs.data, 1): `.data` is unnecessary
        # under no_grad, and the old form assigned to `_`, shadowing
        # IPython's last-output variable.
        predict = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predict == labels).sum().item()

print(correct, total)
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
# torch.save(model.state_dict(), 'model.ckpt')

9782 10000
Test Accuracy of the model on the 10000 test images: 97.82 %
