In [1]:
## Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim  as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torchvision.transforms import transforms

In [2]:
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Hyperparameters
seq_length = 28
input_size= 28
num_layers = 2
hidden_size = 256
num_classes = 10
batch = 100
num_epoch = 2
learning_rate = 0.001

In [17]:
torch.rand(100,28,10)[:,-1,:].shape

torch.Size([100, 10])

In [4]:
## Create Fully connected network

class RNN(nn.Module):
    def __init__(self,input_size,hidden_size,num_layers,seq_length,num_classes = 10):
        super(RNN,self).__init__()
        
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        self.lstm = nn.LSTM(input_size,hidden_size,num_layers,batch_first= True)
        
        self.fc1 = nn.Linear(hidden_size,num_classes)
        
    def forward(self,x):
        h0 = torch.zeros(self.num_layers,x.size(0),self.hidden_size).to(device) 
        c0 = torch.zeros(self.num_layers,x.size(0),self.hidden_size).to(device) 
        out , _ = self.lstm(x,(h0,c0))
        # out = out.reshape(out.shape[0],-1)
        out = self.fc1(out[:,-1,:])
        return out
        
    

In [6]:
## load Dataset
train_dataset = datasets.MNIST(root = 'datasets/',train=True,download=True,transform = transforms.ToTensor())
train_loader = DataLoader(dataset=train_dataset,batch_size=batch,shuffle=True)

test_dataset = datasets.MNIST(root = 'datasets/',train=False,download=True,transform = transforms.ToTensor())
test_loader = DataLoader(dataset=test_dataset,batch_size=batch,shuffle=True)

In [9]:
for data, label in train_loader:
    print(data.shape,label.shape)
    break

torch.Size([100, 1, 28, 28]) torch.Size([100])


In [7]:
## Initialize model 
model = RNN(input_size,hidden_size,num_layers,seq_length,num_classes).to(device)

In [8]:
## Losss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=learning_rate)

In [12]:
data = next(iter(train_loader))

In [29]:
out, _ = nn.LSTM(28,256,2,batch_first= True)(data[0].squeeze(1))

In [38]:
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

torch.Size([2, 100, 256])

In [8]:
import torch 
torch.rand(100,3,28,28).squeeze(0).shape

torch.Size([100, 3, 28, 28])

In [10]:
%%time
## Train Network 
for epoch in range(num_epoch):
    for batch_idx, (data,target) in enumerate(train_loader):
        # Get data to cuda 
        data = data.to(device=device).squeeze(1)
        target =target.to(device=device)
        
        # # reshape
        # data = data.view(data.shape[0],-1)
        
        # forward
        scores = model(data)
        loss = criterion(scores,target)
         
        # backward
        optimizer.zero_grad()
        loss.backward()
        
        # gradient descent
        optimizer.step()

CPU times: user 26min 44s, sys: 13min 6s, total: 39min 50s
Wall time: 5min 47s


In [11]:
# check the accuracy of out trained model 
def check_accuracy(loader,model):
    for data,target in loader:
        num_correct = 0
        num_sample = 0
        model.eval()
        with torch.no_grad():
            data = data.to(device=device).squeeze(1)
            target = target.to(device=device)

            # # reshape 
            # data = data.view(data.shape[0],-1)


            scores = model(data)
            _, pred = scores.max(1)
            # print(list(zip(pred,target)))
            num_correct += sum(pred == target)
            num_sample  += pred.shape[0]
    print(f'Total {num_correct} correct  / out of {num_sample} - accuracy {num_correct/num_sample :.3f} ')
    model.train()
            
            
            

In [12]:
# on test dataset
check_accuracy(test_loader,model)

Total 99 correct  / out of 100 - accuracy 0.990 


In [13]:
# on train datasets
check_accuracy(train_loader,model)

Total 99 correct  / out of 100 - accuracy 0.990 
