In [1]:
import torch
from torch.autograd import Variable
import torchvision
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# Dataset location, batch size and the image transform are shared by both
# splits, so they are named once here.
DATA_DIR = './data'
BATCH_SIZE = 32
to_tensor = transforms.ToTensor()

# Fetch (if not already on disk) the MNIST training and test splits.
mnist_train = torchvision.datasets.MNIST(
    DATA_DIR, train=True, transform=to_tensor, download=True
)
mnist_test = torchvision.datasets.MNIST(
    DATA_DIR, train=False, transform=to_tensor, download=True
)

# Only the training batches are shuffled; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=BATCH_SIZE, shuffle=False
)

In [3]:
class Net(nn.Module):
  """Six-layer fully connected classifier mapping 784 inputs to 10 outputs.

  Layer widths are 784 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10. Every layer
  except the last is followed by a ReLU, and the final layer's output is
  passed through ``log_softmax``.
  """

  def __init__(self):
    super().__init__()
    # Layers are created in order fc1..fc6 and keep these attribute names,
    # which are also the keys that appear in the module's state dict.
    self.fc1 = nn.Linear(784, 512)
    self.fc2 = nn.Linear(512, 256)
    self.fc3 = nn.Linear(256, 128)
    self.fc4 = nn.Linear(128, 64)
    self.fc5 = nn.Linear(64, 32)
    self.fc6 = nn.Linear(32, 10)

  def forward(self, x):
    """Return log-probabilities over the 10 outputs for each input row.

    Args:
      x: tensor that can be viewed as ``(-1, 784)``, e.g. a batch of
        1x28x28 images.

    Returns:
      Tensor of shape ``(N, 10)`` holding ``log_softmax`` values along
      ``dim=1``.
    """
    hidden = x.view(-1, 784)  # flatten everything but the batch dimension
    for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
      hidden = F.relu(layer(hidden))
    logits = self.fc6(hidden)
    return F.log_softmax(logits, dim=1)

In [4]:
# Seed PyTorch's global RNG so weight initialisation and the training
# loader's shuffling are repeatable when the notebook is re-run from here.
SEED = 0
torch.manual_seed(SEED)

model = Net()
# Net.forward already returns log-probabilities (log_softmax), so the matching
# criterion is NLLLoss. CrossEntropyLoss would apply log_softmax a second time
# internally, which is redundant.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [5]:
# Train with one optimizer step per mini-batch, reporting the running
# training accuracy after every second epoch.
num_epochs = 20
train_size = len(train_loader.dataset)

for epoch in range(1, num_epochs + 1):
    model.train()
    acc = 0  # correctly classified training samples seen during this epoch

    for data, label in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # Accuracy uses the forward pass made before this step's update.
        pred = output.argmax(dim=1)
        acc += (pred == label).sum().item()

    if epoch % 2 != 0:
        continue
    print('Epoch {}/{}: Accuracy {}'.format(epoch, num_epochs, 100 * acc / train_size))

# Measure accuracy on the test split.
model.eval()
ans = 0  # number of correctly classified test samples

# Inference only: with autograd disabled, no graph is recorded for the
# forward passes, which saves memory and time without changing the outputs.
with torch.no_grad():
    for data, label in test_loader:
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        ans += pred.eq(label.view_as(pred)).sum().item()

print('Test Accuracy: {}'.format(100 * ans / len(test_loader.dataset)))

Epoch 2/20: Accuracy 26.82
Epoch 4/20: Accuracy 90.345
Epoch 6/20: Accuracy 95.39333333333333
Epoch 8/20: Accuracy 97.04333333333334
Epoch 10/20: Accuracy 97.95
Epoch 12/20: Accuracy 98.535
Epoch 14/20: Accuracy 98.97166666666666
Epoch 16/20: Accuracy 99.415
Epoch 18/20: Accuracy 99.63166666666666
Epoch 20/20: Accuracy 99.74
Test Accuracy: 97.2
