In [12]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [13]:
# Hyperparams

# Input layout: the model consumes tensors shaped (batch, sequence_length, input_size).
sequence_length = 28
input_size = 28

# Network capacity and output width.
hidden_size = 256
num_layers = 2
num_classes = 10

# Optimisation settings (the optimizer cell uses the same 1e-3 step size).
batch_size = 64
learning_rate = 1e-3
num_epochs = 2

In [66]:
class RNN(nn.Module):
	"""LSTM classifier that scores the concatenation of every time step's hidden state.

	Args:
		input_size: Number of features per time step.
		num_classes: Number of output logits.
		hidden_size: Width of each LSTM layer's hidden state.
		num_layers: Number of stacked LSTM layers.
		sequence_length: Number of time steps in every input. The final linear
			layer is sized as ``hidden_size * sequence_length``, so inputs with a
			different number of steps will fail in ``forward``. Defaults to 28,
			the value used throughout this notebook.
	"""

	def __init__(self, input_size, num_classes, hidden_size, num_layers, sequence_length=28):
		super().__init__()
		# Keep the sizes on the instance so forward() never depends on notebook globals.
		self.hidden_size = hidden_size
		self.num_layers = num_layers
		self.sequence_length = sequence_length

		# batch_first=True: inputs and outputs are laid out as N x time_seq x features.
		self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
		# The classifier sees all time steps at once, flattened into one vector per sample.
		self.fc = nn.Linear(hidden_size * sequence_length, num_classes)

	def forward(self, x):
		"""Return class logits of shape (batch, num_classes) for x of shape (batch, time, features)."""
		# Zero initial hidden/cell states, created on the same device and dtype as the
		# input so the module keeps working after .to(device) or a dtype change.
		h0 = torch.zeros(
			self.num_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype
		)
		c0 = torch.zeros_like(h0)

		# Forward prop
		out, _ = self.lstm(x, (h0, c0))
		# (batch, time, hidden) -> (batch, time * hidden)
		out = out.reshape(out.shape[0], -1)
		return self.fc(out)

In [67]:
def _one_hot(label):
	"""Turn an integer class label into a float one-hot vector of length num_classes."""
	encoded = torch.zeros(num_classes, dtype=torch.float)
	return encoded.scatter_(dim=0, index=torch.tensor(label), value=1)


def _mnist_split(is_train):
	"""Load (downloading if needed) one MNIST split with tensor images and one-hot targets."""
	return datasets.MNIST(
		root='data',
		train=is_train,
		download=True,
		transform=transforms.ToTensor(),
		target_transform=transforms.Lambda(_one_hot),
	)


# Both splits share the same image and target transforms.
train_dataset = _mnist_split(True)
test_dataset = _mnist_split(False)

In [68]:
# Build the classifier from the hyperparameters defined earlier and show its layers.
model = RNN(
	input_size=input_size,
	num_classes=num_classes,
	hidden_size=hidden_size,
	num_layers=num_layers,
)
print(model)

RNN(
  (lstm): LSTM(28, 256, num_layers=2, batch_first=True)
  (fc): Linear(in_features=7168, out_features=10, bias=True)
)


In [69]:
# Shuffle the training split so batches are not drawn in the dataset's stored order;
# the seeded generator keeps that order reproducible after a kernel restart.
train_dataloader = DataLoader(
	train_dataset,
	batch_size=batch_size,
	shuffle=True,
	generator=torch.Generator().manual_seed(0),
)
# Evaluation only aggregates over the whole split, so it is left unshuffled.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [70]:
# The model emits raw logits and the datasets emit one-hot float targets;
# CrossEntropyLoss accepts such targets as per-class probabilities.
loss_fn = nn.CrossEntropyLoss()
# Take the step size from the hyperparameter cell instead of repeating the literal.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [71]:
def train(dataloader, model, loss, optimizer):
	"""Run one optimisation pass over ``dataloader``.

	Args:
		dataloader: Iterable of ``(X, y)`` batches that exposes ``.dataset``.
		model: Module called as ``model(X)`` to produce predictions.
		loss: Callable ``loss(pred, y)`` returning a scalar tensor to backpropagate.
		optimizer: Optimizer that updates ``model``'s parameters.

	Prints the batch loss and the number of samples seen every 100 batches.
	"""
	size = len(dataloader.dataset)
	model.train()

	for batch, (X, y) in enumerate(dataloader):
		# Drop dimension 1 when it has size 1 (assumes batches arrive as
		# (N, 1, time, features) — TODO confirm for other datasets).
		X = X.squeeze(1)
		pred = model(X)
		# Use the criterion passed in by the caller rather than a notebook global.
		batch_loss = loss(pred, y)

		optimizer.zero_grad()
		batch_loss.backward()
		optimizer.step()

		if batch % 100 == 0:
			current = batch * len(X)
			print(f"Loss: {batch_loss.item()}, [{current} / {size}]")

In [72]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.squeeze(1)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [73]:
# Single pass over the training set; train() logs the loss every 100 batches.
# NOTE(review): num_epochs is defined in the hyperparameter cell but is not used
# by this call — confirm whether more than one epoch was intended.
train(train_dataloader, model, loss_fn, optimizer)

Loss: 2.300607442855835, [0 / 60000]
Loss: 0.3475913405418396, [6400 / 60000]
Loss: 0.37505102157592773, [12800 / 60000]
Loss: 0.2531822919845581, [19200 / 60000]
Loss: 0.07182702422142029, [25600 / 60000]
Loss: 0.23771074414253235, [32000 / 60000]
Loss: 0.09833714365959167, [38400 / 60000]
Loss: 0.2677958607673645, [44800 / 60000]
Loss: 0.29919883608818054, [51200 / 60000]
Loss: 0.0704730972647667, [57600 / 60000]


In [74]:

# Print accuracy and average loss of the trained model on the test split.
test(test_dataloader, model, loss_fn)

Test Error: 
 Accuracy: 97.0%, Avg loss: 0.096175 

