In [1]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

In [40]:
# --- Hyperparameters -------------------------------------------------------
EPOCH = 1         # number of full passes over the training set
BATCH_SIZE = 50   # samples per optimisation step
TIME_STEP = 28    # sequence length; not referenced by the visible cells
INPUT_SIZE = 28   # features per time step, consumed by the LSTM's input_size
LR = 0.01         # Adam learning rate

# --- Training data ---------------------------------------------------------
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),  # uint8 [0, 255] -> float [0, 1]
    # download=True only fetches the files when they are missing under `root`,
    # so the cell also works on a fresh checkout / fresh kernel.
    download=True,
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# --- Test data -------------------------------------------------------------
# No transform is attached here: the raw uint8 images are read directly and
# scaled by hand so that they match the [0, 1] range ToTensor gives the
# training batches.
test_data = torchvision.datasets.MNIST(root='./mnist', train=False, download=True)
# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels`.
test_x = test_data.data.float() / 255.
test_y = test_data.targets

In [9]:
class RNN(nn.Module):
    """LSTM classifier that maps a sequence to class logits.

    The sequence is run through a stacked LSTM and the output at the last
    time step is projected to ``num_classes`` logits by a linear layer.

    Args:
        input_size: Number of features per time step. When omitted, the
            module-level ``INPUT_SIZE`` constant is used.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=None, hidden_size=64, num_layers=2, num_classes=10):
        super().__init__()
        if input_size is None:
            # Looked up at construction time so that a plain ``RNN()`` behaves
            # exactly as it did when the value was hard-wired.
            input_size = INPUT_SIZE
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # Keep a consistent habit: use batch_first wherever possible, so
            # tensors are laid out as (batch, time, feature).
            batch_first=True,
        )
        # The classifier width is tied to the LSTM's hidden size.
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        Args:
            x: Input of shape (batch, time, input_size).
        """
        # No initial state is supplied, so the LSTM starts from zeros.
        r_out, _ = self.rnn(x)
        # r_out is (batch, time, hidden_size). Take the last time step; the
        # batch and feature dimensions are left unchanged.
        return self.out(r_out[:, -1, :])

In [15]:
# Instantiate the network first; the optimizer needs its parameters.
rnn = RNN()

# Classification loss applied to the network's raw logits.
loss_func = nn.CrossEntropyLoss()

# Adam over every trainable parameter of the network.
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=LR)

In [11]:
# Show the module's layer summary as the cell output.
rnn

RNN(
  (rnn): LSTM(28, 64, num_layers=2, batch_first=True)
  (out): Linear(in_features=64, out_features=10, bias=True)
)

In [42]:
# Train for EPOCH passes and print test-set accuracy every 50 steps.
for epoch in range(EPOCH):
    for step, (batch_x, batch_y) in enumerate(train_loader):
        # Drop the size-1 channel axis so each batch is a 3-D
        # (batch, time, feature) tensor, as the batch_first LSTM expects.
        batch_x = batch_x.squeeze(1)
        out_pre = rnn(batch_x)
        loss = loss_func(out_pre, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            # Evaluation needs no gradients: no_grad() avoids building an
            # autograd graph for the whole test set on every check.
            with torch.no_grad():
                test_pre = rnn(test_x).argmax(dim=1)
            # Compare tensor with tensor; no NumPy round trip is needed.
            acc = (test_pre == test_y).sum().item() / test_y.shape[0]
            print(acc)
            

0.8525
0.9106
0.9313
0.9369
0.9238
0.9014
0.9477
0.9444
0.9398
0.9514
0.9511
0.9422
0.9476
0.9584
0.9536
0.9561
0.9612
0.9545
0.9692
0.9662
0.9697
0.9661
0.9629
0.9719


In [34]:
# Scratch tensors for shape experiments (presumably mimicking the
# logits/targets pair fed to nn.CrossEntropyLoss -- confirm intent).
# A 3 x 5 float tensor of standard-normal samples that tracks gradients.
input_tsets = torch.randn(3, 5, requires_grad=True)
# Three int64 values, each drawn at random from {0, ..., 4}.
target_tsets = torch.empty(3, dtype=torch.long).random_(5)

In [37]:
# Inspect the shape of the float scratch tensor.
input_tsets.shape

torch.Size([3, 5])

In [38]:
# Inspect the shape of the integer scratch tensor.
target_tsets.shape

torch.Size([3])