In [1]:
import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**Hyper Parameters**

In [3]:
# Training configuration: every value a reader might want to tune lives here.
EPOCH = 2          # number of full passes over the training set
BATCH_SIZE = 64    # samples per optimisation step
TIME_STEP = 28     # sequence length: each image row is fed as one time step
INPUT_SIZE = 28    # features per time step: the pixels in one image row
LR = 0.01          # learning rate handed to the optimizer
SEED = 0           # fixed seed so reruns start from the same random state

# Seed torch's global generator once, before the model is built and the data
# is shuffled, so weight initialisation and batch order repeat across runs.
# (Some GPU kernels are non-deterministic, so CUDA runs may still differ slightly.)
torch.manual_seed(SEED)

**Prepare Dataset**

In [4]:
# MNIST training split, downloaded on first use; ToTensor converts each image
# to a float tensor with values in [0, 1].
# NOTE(review): '/.mnist' is an absolute path at the filesystem root, not a
# folder next to the notebook — confirm this is intended (it needs write
# access to '/').
train_data = dsets.MNIST(
    root='/.mnist',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [5]:
# Mini-batch iterator over the training split; the order is reshuffled every epoch.
data_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [6]:
# Held-out MNIST split, used only for the accuracy read-outs.
test_data = dsets.MNIST(root='/.mnist', train=False, transform=transforms.ToTensor(), download=True)

# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels`
# aliases, which merely forward to them after emitting a warning.
# Reading `.data` directly bypasses the ToTensor transform, so the raw pixel
# values are scaled to [0, 1] by hand. Slicing happens first so that only the
# 2000 images actually used are converted to float.
test_x = test_data.data[:2000].float().to(device) / 255.   # shape (2000, 28, 28), values in [0, 1]
test_y = test_data.targets[:2000].to(device)               # shape (2000,), digit labels



**RNN Model**

In [7]:
class RNN(nn.Module):
    """LSTM sequence classifier.

    The input sequence is passed through a stacked LSTM; the features emitted
    at the final time step are mapped to class logits by a linear layer.

    Args:
        input_size: Number of features per time step. ``None`` (the default)
            falls back to the notebook-level ``INPUT_SIZE`` constant, so a
            bare ``RNN()`` builds the same network as before.
        hidden_size: Width of each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=None, hidden_size=64, num_layers=2, num_classes=10) -> None:
        super().__init__()

        # Resolved at call time (not in the signature) so the class can be
        # defined before, or independently of, the configuration cell.
        if input_size is None:
            input_size = INPUT_SIZE

        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # tensors are laid out as (batch, time_step, input_size)
        )
        self.out = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for x of shape (batch, time_step, input_size)."""
        # No initial state is passed, so the LSTM starts from zeros. The final
        # hidden/cell states are not needed, only the per-step outputs.
        r_out, _ = self.rnn(x)
        # Classify from the top layer's output at the last time step.
        return self.out(r_out[:, -1, :])

In [8]:
# Build the network, then move its parameters to the selected device
# (nn.Module.to works in place, so no re-assignment is needed).
rnn = RNN()
rnn.to(device)
print(rnn)

RNN(
  (rnn): LSTM(28, 64, num_layers=2, batch_first=True)
  (out): Linear(in_features=64, out_features=10, bias=True)
)


**Loss and Optimizer**

In [9]:
# Cross-entropy takes the raw logits together with integer class labels.
loss_func = nn.CrossEntropyLoss()
# Adam updates every parameter of the network, using LR as the learning rate.
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=LR)

**Training Loop**

In [10]:
LOG_EVERY = 50  # report loss / test accuracy every this many optimisation steps

for epoch in range(EPOCH):
    for step, (x, y) in enumerate(data_loader):
        # Drop the channel axis: (batch, 1, 28, 28) -> (batch, time_step, input_size).
        b_x = x.view(-1, TIME_STEP, INPUT_SIZE).to(device)
        b_y = y.to(device)

        output = rnn(b_x)
        loss = loss_func(output, b_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log in every epoch. The previous gate also required `epoch % 50 == 0`,
        # which silenced every epoch except the first.
        if step % LOG_EVERY == 0:
            # Evaluation needs no gradients; skipping autograd avoids building
            # a graph over the whole 2000-sample test batch.
            with torch.no_grad():
                test_output = rnn(test_x)
            pred_y = test_output.argmax(dim=1)
            # Vectorised mean instead of Python's builtin sum over tensor elements.
            accuracy = (pred_y == test_y).float().mean().item()
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)

Epoch:  0 | train loss: 2.3073 | test accuracy: 0.12
Epoch:  0 | train loss: 1.3189 | test accuracy: 0.55
Epoch:  0 | train loss: 0.7608 | test accuracy: 0.71
Epoch:  0 | train loss: 0.6731 | test accuracy: 0.75
Epoch:  0 | train loss: 0.6736 | test accuracy: 0.79
Epoch:  0 | train loss: 0.5627 | test accuracy: 0.89
Epoch:  0 | train loss: 0.3103 | test accuracy: 0.90
Epoch:  0 | train loss: 0.2329 | test accuracy: 0.88
Epoch:  0 | train loss: 0.3947 | test accuracy: 0.93
Epoch:  0 | train loss: 0.2507 | test accuracy: 0.93
Epoch:  0 | train loss: 0.3160 | test accuracy: 0.92
Epoch:  0 | train loss: 0.2171 | test accuracy: 0.94
Epoch:  0 | train loss: 0.1461 | test accuracy: 0.94
Epoch:  0 | train loss: 0.1845 | test accuracy: 0.95
Epoch:  0 | train loss: 0.0915 | test accuracy: 0.95
Epoch:  0 | train loss: 0.1404 | test accuracy: 0.95
Epoch:  0 | train loss: 0.2208 | test accuracy: 0.94
Epoch:  0 | train loss: 0.0743 | test accuracy: 0.96
Epoch:  0 | train loss: 0.2177 | test accuracy

**Predictions on the First 10 Test Samples**

In [13]:
# Compare the model's predictions with the true labels for the first 10 test images.
# Inference only, so autograd tracking is switched off.
with torch.no_grad():
    test_output = rnn(test_x[:10].view(-1, TIME_STEP, INPUT_SIZE))
# The index of the largest logit is the predicted digit; move it to the CPU
# as a NumPy array for printing.
pred_y = test_output.argmax(dim=1).cpu().numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

[7 2 1 0 4 1 4 9 5 9] prediction number
tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], device='cuda:0') real number
