In [1]:
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms

In [2]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

# Batch sizes for the training and test DataLoaders.
Train_batch_size = 64
Test_batch_size = 10000

In [3]:
# Root directory for the MNIST files. Both splits request download=True so the
# test split does not rely on the train line having fetched the files first;
# torchvision skips the download when the files are already present.
DATA_DIR = "./mnist"
train_data = dset.MNIST(DATA_DIR, train=True, transform=transforms.ToTensor(), download=True)
test_data = dset.MNIST(DATA_DIR, train=False, transform=transforms.ToTensor(), download=True)

In [4]:
train_data.data.shape  # shape of the raw training image tensor

torch.Size([60000, 28, 28])

In [5]:
# Only the training data is shuffled; accuracy over the full test set does not
# depend on sample order, and skipping the shuffle leaves the RNG stream to training.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=Train_batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=Test_batch_size, shuffle=False)

In [6]:
class RNN(torch.nn.Module):
    """LSTM sequence classifier that predicts from the output at the last time step.

    The defaults reproduce the configuration used in this notebook: sequences
    of 28-feature steps, a 2-layer LSTM with 100 hidden units, and 10 output logits.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_size=28, hidden_size=100, num_layers=2, num_classes=10):
        super().__init__()
        self.rnn = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # inputs are laid out as (batch, seq, feature)
        )
        self.out = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map x of shape (batch, seq, input_size) to logits of shape (batch, num_classes)."""
        # No initial (h, c) is passed, so the LSTM starts from zero states;
        # the final hidden/cell states are not needed here.
        r_out, _ = self.rnn(x)
        # Classify from the top layer's output at the final time step.
        return self.out(r_out[:, -1, :])

In [7]:
# Build the classifier; the bare name on the last line displays its layer summary.
rnn = RNN()
rnn

RNN(
  (rnn): LSTM(28, 100, num_layers=2, batch_first=True)
  (out): Linear(in_features=100, out_features=10, bias=True)
)

In [8]:
# Plain SGD over every model parameter. The loss is applied to the raw logits
# that RNN.forward returns (no softmax is applied in the model).
LEARNING_RATE = 0.01
optimizer = torch.optim.SGD(params=rnn.parameters(), lr=LEARNING_RATE)
loss_func = torch.nn.CrossEntropyLoss()

In [10]:
def evaluate(model, loader):
    """Return the fraction of correctly classified samples over all batches of ``loader``.

    Runs in eval mode with autograd disabled, then restores the model's previous
    train/eval mode. Images are reshaped to (batch, 28, 28) exactly as in the
    training loop below. Returns NaN if the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # evaluation needs no autograd graph; avoids holding one for the whole test set
        for images, labels in loader:
            logits = model(images.view(-1, 28, 28))
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    model.train(was_training)
    return correct / total if total else float('nan')


epochs = 10
LOG_EVERY = 500  # optimisation steps between progress reports

step = 0
for e in range(epochs):
    for x, y in train_loader:
        # Each 28x28 image is fed as a sequence of 28 rows (one row per time step).
        out = rnn(x.view(-1, 28, 28))
        loss = loss_func(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % LOG_EVERY == 0:
            # Accuracy on the current mini-batch only. argmax of the logits equals
            # argmax of their softmax, and .mean() divides by the actual batch size
            # (the final batch of an epoch can be smaller than Train_batch_size).
            acc = (out.argmax(dim=1) == y).float().mean().item()
            # Accuracy aggregated over the entire test set.
            te_acc = evaluate(rnn, test_loader)
            print('step:', step)
            print('Train Accuracy=%.2f' % acc)
            print('Test Accuracy=%.2f' % te_acc)
            print('------------------')
        step += 1

step: 0
Train Accuracy=0.97
Test Accuracy=0.97
------------------
step: 500
Train Accuracy=0.97
Test Accuracy=0.97
------------------
step: 1000
Train Accuracy=0.92
Test Accuracy=0.96
------------------
step: 1500
Train Accuracy=0.97
Test Accuracy=0.93
------------------
step: 2000
Train Accuracy=0.73
Test Accuracy=0.77
------------------
step: 2500
Train Accuracy=0.84
Test Accuracy=0.83
------------------
step: 3000
Train Accuracy=0.88
Test Accuracy=0.89
------------------
step: 3500
Train Accuracy=0.91
Test Accuracy=0.93
------------------
step: 4000
Train Accuracy=0.92
Test Accuracy=0.93
------------------
step: 4500
Train Accuracy=0.95
Test Accuracy=0.93
------------------
step: 5000
Train Accuracy=0.94
Test Accuracy=0.93
------------------
step: 5500
Train Accuracy=0.94
Test Accuracy=0.93
------------------
step: 6000
Train Accuracy=0.84
Test Accuracy=0.91
------------------
step: 6500
Train Accuracy=0.84
Test Accuracy=0.88
------------------
step: 7000
Train Accuracy=0.95
Test Ac