In [83]:
import  time

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
from    torch.autograd import Variable

import  torchvision.datasets as dsets
import  torchvision.transforms as trans


In [84]:
batch_size = 1000
mnist_root = '../data/'
to_tensor = trans.ToTensor()

# Only the train split asks for download=True; torchvision's MNIST download
# fetches the files for both splits, so the test split can reuse them.
train_set = dsets.MNIST(root=mnist_root, train=True, transform=to_tensor, download=True)
test_set = dsets.MNIST(root=mnist_root, train=False, transform=to_tensor)

# Shuffle the training stream only; evaluation order is kept fixed.
train_dl = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)


In [85]:
class GRU(nn.Module):
    """Sequence classifier: a stacked GRU followed by a linear head on the last time step.

    Args:
        input_size: number of features per time step (last dim of the input).
        hidden_size: width of each GRU layer's hidden state.
        num_layers: number of stacked GRU layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(GRU, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: inputs are laid out as (batch, seq_len, input_size).
        self.feature = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for x of shape (batch, seq_len, input_size)."""
        # Zero initial hidden state allocated directly with x's dtype and device,
        # so the model also works after .double()/.half() and needs no extra copy.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.feature(x, h0)
        # Classify from the top layer's output at the final time step.
        return self.classifier(out[:, -1, :])

In [86]:
input_size = 28   # features per time step (pixels in one image row)
hidden_size = 200
seq_len = 28      # time steps per sequence (image rows)
num_layers = 2
num_classes = 10
# Batches are reshaped with .view(-1, seq_len, input_size); that only lines up
# with the labels if the sequence exactly tiles one 28x28 MNIST image.
assert seq_len * input_size == 28 * 28, 'seq_len * input_size must equal 784 MNIST pixels'

lr = 0.001
nepochs = 30

# Seed before building the model so weight init (and the default RNG that
# DataLoader shuffling draws from) is reproducible on Restart & Run All.
seed = 0
torch.manual_seed(seed)

net = GRU(input_size, hidden_size, num_layers, num_classes)
if torch.backends.mps.is_available():
    net = net.to('mps')

In [87]:
# Adam over every model parameter; CrossEntropyLoss takes raw logits and integer class labels.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()


In [88]:
def eval(model, criterion, dataloader, seq_len=None, input_size=None):
    """Compute mean loss and accuracy (%) of `model` over every sample in `dataloader`.

    Each input batch is reshaped to (batch, seq_len, input_size) before being
    passed to the model.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        criterion: loss with mean reduction (e.g. nn.CrossEntropyLoss()).
        dataloader: iterable of (inputs, labels) batches.
        seq_len, input_size: reshape target; default to the notebook-level
            `seq_len` / `input_size` globals.

    Returns:
        (loss, accuracy) as Python floats: the sample-weighted mean loss and the
        accuracy in percent. Both are NaN if `dataloader` yields no samples.
    """
    if seq_len is None:
        seq_len = globals()['seq_len']
    if input_size is None:
        input_size = globals()['input_size']

    # Evaluate on whichever device the model lives on (CPU, CUDA or MPS).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    try:
        # Evaluation needs no gradients; skipping graph construction saves memory and time.
        with torch.no_grad():
            for batch_x, batch_y in dataloader:
                batch_x = batch_x.view(-1, seq_len, input_size).to(device)
                batch_y = batch_y.to(device)
                n = batch_y.size(0)
                logits = model(batch_x)
                # The criterion returns a per-batch mean; weight it by batch size so a
                # short final batch does not skew the overall average.
                total_loss += criterion(logits, batch_y).item() * n
                correct += (logits.argmax(dim=1) == batch_y).sum().item()
                seen += n
    finally:
        # Restore the caller's train/eval mode so evaluation has no lasting side effect.
        model.train(was_training)

    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / seen, correct * 100.0 / seen

In [89]:
for epoch in range(nepochs):
    since = time.time()
    # Explicitly (re-)enter training mode each epoch, independent of any mode
    # changes made by evaluation code.
    net.train()
    for batch_x, batch_y in train_dl:
        # Reshape each image batch to (batch, seq_len, input_size) for the GRU.
        batch_x = batch_x.view(-1, seq_len, input_size)
        if torch.backends.mps.is_available():
            batch_x = batch_x.to('mps')
            batch_y = batch_y.to('mps')

        optimizer.zero_grad()
        logits = net(batch_x)
        error = criterion(logits, batch_y)
        error.backward()
        optimizer.step()

    # Elapsed time covers the training pass only, not the evaluations below.
    now = time.time()
    train_loss, train_acc = eval(net, criterion, train_dl)
    test_loss, test_acc = eval(net, criterion, test_dl)
    print('%2d/%d, %.0f sec|\t train loss: %.4f, train acc: %.2f |\t  test loss: %.4f, test acc: %.2f' % (epoch+1,nepochs,now-since,train_loss,train_acc,test_loss,test_acc))

 1/30, 14 sec|	 train loss: 0.6914, train acc: 77.02 |	  test loss: 0.6830, test acc: 77.61
 2/30, 13 sec|	 train loss: 0.3220, train acc: 90.13 |	  test loss: 0.3108, test acc: 90.45
 3/30, 13 sec|	 train loss: 0.1811, train acc: 94.52 |	  test loss: 0.1787, test acc: 94.98
 4/30, 13 sec|	 train loss: 0.1511, train acc: 95.50 |	  test loss: 0.1559, test acc: 95.67
 5/30, 13 sec|	 train loss: 0.1025, train acc: 96.92 |	  test loss: 0.1066, test acc: 96.77
 6/30, 13 sec|	 train loss: 0.0842, train acc: 97.43 |	  test loss: 0.0866, test acc: 97.44
 7/30, 13 sec|	 train loss: 0.0702, train acc: 97.87 |	  test loss: 0.0786, test acc: 97.53
 8/30, 13 sec|	 train loss: 0.0687, train acc: 97.91 |	  test loss: 0.0754, test acc: 97.68
 9/30, 13 sec|	 train loss: 0.0480, train acc: 98.59 |	  test loss: 0.0599, test acc: 98.10
10/30, 13 sec|	 train loss: 0.0429, train acc: 98.69 |	  test loss: 0.0545, test acc: 98.33
11/30, 13 sec|	 train loss: 0.0416, train acc: 98.70 |	  test loss: 0.0547, test