In [8]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# repeat from run to run after a kernel restart. (Some GPU kernels can still
# be nondeterministic; this does not force deterministic algorithms.)
SEED = 0
torch.manual_seed(SEED)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### 1. Datasets and DataLoaders

In [7]:
# Build the MNIST training split and test split with identical settings;
# only the `train` flag differs. The training split is constructed first.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root='../../data',
        train=is_train,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train in (True, False)
)

In [9]:
# Mini-batches of 32 samples: the training split is reshuffled on every pass,
# the test split is read in dataset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

### 2. Model

In [14]:
class ANN(nn.Module):
    """Fully connected classifier with a single hidden ReLU layer.

    The network computes
    ``Linear(input_size, hidden_size) -> ReLU -> Linear(hidden_size, num_classes)``
    and returns the raw output of the last linear layer.

    Args:
        num_classes: Number of output scores produced per sample.
        input_size: Length of each (already flattened) input vector.
            Defaults to ``28 * 28``.
        hidden_size: Width of the hidden layer. Defaults to ``500``.
    """

    def __init__(self, num_classes, input_size=28 * 28, hidden_size=500):
        super().__init__()
        # Attribute names stay l1 / relu / l2 so the state_dict keys are
        # unchanged for any previously saved weights.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Compute class scores for ``x``.

        Args:
            x: Tensor whose last dimension equals ``input_size``. No
                flattening is done here; the caller must reshape images first.

        Returns:
            Tensor whose last dimension equals ``num_classes``. No softmax is
            applied.
        """
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)

In [16]:
# Ten output scores per sample; parameters are then moved to `device`.
model = ANN(num_classes=10)
model = model.to(device)

### 3. Training

In [21]:
# Cross-entropy loss on the model's raw output scores.
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the model, with a learning rate of 0.001.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [23]:
# Make sure the model is in training mode, whatever an earlier cell left it in.
model.train()

for epoch in range(5):
    print('Starting epoch:', epoch)
    for images, labels in train_loader:
        # Flatten every image in the batch to a 784-long vector, the input
        # width of the model's first linear layer.
        images = images.reshape(-1, 28 * 28).to(device)
        labels = labels.to(device)

        # Clear gradients *before* computing new ones. If a previous run of
        # this cell was interrupted between backward() and the reset, stale
        # gradients would otherwise be accumulated into this step.
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

Starting epoch: 0
Starting epoch: 1
Starting epoch: 2
Starting epoch: 3
Starting epoch: 4


### 4. Evaluate

In [24]:
# Measure accuracy in eval mode, then restore whichever mode the model was in,
# so running this cell never changes how a later training cell behaves.
was_training = model.training
model.eval()

try:
    # No gradients are needed for scoring.
    with torch.no_grad():
        correct = 0
        samples = len(test_loader.dataset)

        for images, labels in test_loader:
            # Flatten every image in the batch to a 784-long vector.
            images = images.reshape(-1, 28 * 28).to(device)
            labels = labels.to(device)

            outputs = model(images)

            # Predicted class = index of the highest score in each row.
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
finally:
    model.train(was_training)

# Fraction of test samples classified correctly.
print(correct/samples)

0.9807
