In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

### STEP1: Loading MNIST Dataset

In [None]:
# Load both MNIST splits from ./data/ (download=True lets torchvision fetch the
# files when they are not already present). ToTensor() is applied per sample
# when items are read through the dataset.
train_dataset = datasets.MNIST(
    root='./data/',
    train=True,
    transform=transforms.ToTensor(),
    target_transform=None,
    download=True,
)

test_dataset = datasets.MNIST(
    root='./data/',
    train=False,
    transform=transforms.ToTensor(),
    target_transform=None,
    download=True,
)

# Sanity check: shapes of the raw image and label tensors for each split.
print(train_dataset.data.shape, train_dataset.targets.shape)
print(test_dataset.data.shape, test_dataset.targets.shape)

### STEP2: Make Dataset Iterable

In [None]:
batch_size = 100
num_epochs = 5

# Shuffle only the training split; a fixed test order keeps evaluation repeatable.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Peek at one batch. Use the builtin next(): Python 3 iterators (and current
# DataLoader iterators) do not provide a .next() method.
f, s = next(iter(train_loader))
print(len(train_loader), f.shape, s.shape)

### STEP3: Create Model Class

In [None]:
class LSTMModel(nn.Module):
    """Sequence classifier: an LSTM followed by a linear layer on the last time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        layer_dim: Number of stacked LSTM layers.
        output_dim: Number of output scores produced by the linear layer.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim

        # batch_first=True: inputs and outputs are (batch, seq, feature).
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim,
                            batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the LSTM over ``x`` and classify from the final time step.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim).

        Returns:
            Tensor of shape (batch, output_dim) with unnormalised scores.
        """
        batch = x.size(0)
        # Zero initial hidden/cell states, allocated with the same dtype and
        # device as the input so the model also works after .to(device) or
        # .double(). new_zeros() tensors do not require grad, so no gradient
        # flows into the initial state.
        h0 = x.new_zeros(self.layer_dim, batch, self.hidden_dim)
        c0 = x.new_zeros(self.layer_dim, batch, self.hidden_dim)

        out, _ = self.lstm(x, (h0, c0))
        # Only the output at the last time step feeds the classifier.
        return self.fc(out[:, -1, :])

### STEP4: Instantiate Model Class

In [None]:
# Hyper-parameters handed to LSTMModel.
input_dim = 28     # features per time step
hidden_dim = 100   # LSTM hidden-state size
layer_dim = 1      # number of stacked LSTM layers
output_dim = 10    # number of output scores

model = LSTMModel(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    layer_dim=layer_dim,
    output_dim=output_dim,
)
print(model)

# List the shape of every learnable tensor in the model.
for p in model.parameters():
    print(p.shape)

### STEP5: Instantiate Loss Class

In [None]:
# Cross-entropy over raw scores; "mean" (the default) averages over the batch.
criterion = nn.CrossEntropyLoss(reduction="mean")

### STEP6: Instantiate Optimizer Class

In [None]:
# Plain stochastic gradient descent over every model parameter.
learning_rate = 0.1
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
)

### STEP7: Train Model

In [None]:
# Number of time steps per image: each batch is reshaped below to
# (batch, seq_dim, input_dim) before it is fed to the model.
seq_dim = 28

# Named `iteration` rather than `iter` so the builtin iter() is not shadowed
# for the rest of the kernel session (re-running earlier cells would break).
iteration = 0
for epoch in range(num_epochs):
    for images, labels in train_loader:
        # The input batch does not need gradients; only the parameters do.
        images = images.view(-1, seq_dim, input_dim)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        iteration += 1

        # Every 500 training steps, measure accuracy on the whole test set.
        if iteration % 500 == 0:
            total = 0
            correct = 0
            model.eval()
            # no_grad(): skip building autograd graphs during evaluation.
            with torch.no_grad():
                # Separate names so the training batch is not overwritten.
                for test_images, test_labels in test_loader:
                    test_images = test_images.view(-1, seq_dim, input_dim)
                    test_outputs = model(test_images)
                    predicted = test_outputs.argmax(dim=1)
                    total += test_labels.size(0)
                    # .item() keeps the counter a Python int, so the accuracy
                    # prints as a number rather than as tensor(...).
                    correct += (predicted == test_labels).sum().item()
            # Back to training mode before the next optimisation step.
            model.train()

            accuracy = 100.0 * correct / total
            print("Iter:{}. Loss:{}. Accu:{}".format(
                iteration, loss.item(), accuracy))