In [35]:
import os
import torch
from torchvision import datasets, transforms
from torch import nn
from torch.utils.data import DataLoader
from torch import optim

### Loading Dataset

In [36]:
# ToTensor -> float pixels in [0, 1]; Normalize(0.5, 0.5) -> (x - 0.5) / 0.5, i.e. [-1, 1].
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])

def mnist_dataset(root_dir="./data", transform=None):
    """Load the MNIST train and test splits, downloading them if necessary.

    Args:
        root_dir: Directory where torchvision stores / looks up the raw files.
        transform: Per-image transform applied by both splits. ``None`` means
            the module-level ``data_tf``.

    Returns:
        ``(train_dataset, test_dataset)`` as ``torchvision.datasets.MNIST``.
    """
    if transform is None:
        transform = data_tf
    os.makedirs(root_dir, exist_ok=True)
    # download=True on both splits so each constructor is self-sufficient instead of
    # the test split relying on the train-split call having fetched its files first.
    # torchvision skips the download when the files are already present.
    train_dataset = datasets.MNIST(root=root_dir, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=root_dir, train=False, transform=transform, download=True)
    return train_dataset, test_dataset

train_dataset, test_dataset = mnist_dataset()

### Network Definition

In [37]:
class Network(nn.Module):
    """Multi-layer perceptron: ``nn.Linear`` layers with a ReLU between consecutive ones.

    Layer widths are ``[input_dims, *mlp_dims, output_dims]``, giving
    ``len(mlp_dims) + 1`` linear layers. The last layer has no activation, so
    ``forward`` returns raw logits (no softmax), which is the input
    ``nn.CrossEntropyLoss`` expects.

    Args:
        input_dims: Size of the last dimension of the input tensor.
        mlp_dims: Hidden-layer widths; any iterable of ints (list, tuple, ...),
            possibly empty for a single linear layer.
        output_dims: Number of output logits (e.g. number of classes).
    """
    def __init__(self, input_dims, mlp_dims, output_dims):
        super().__init__()
        # Star-unpacking accepts any iterable and copies it, so the caller's
        # sequence is never aliased or mutated.
        self.mlp_dims = [input_dims, *mlp_dims, output_dims]
        n_linear = len(self.mlp_dims) - 1
        layers = []
        for i, (d_in, d_out) in enumerate(zip(self.mlp_dims[:-1], self.mlp_dims[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < n_linear - 1:
                # ReLU between hidden layers only; the output layer stays linear (logits).
                layers.append(nn.ReLU())
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Map input of shape ``(*, input_dims)`` to logits of shape ``(*, output_dims)``.

        Images must be flattened by the caller first (the training loop uses
        ``imgs.view(-1, 28 * 28)`` for MNIST).
        """
        return self.fc(x)

### Training Loop

In [47]:
def mnist_training(learning_rate=1e-2, batch_size=64, epochs=10, eval_interval=None, seed=None):
    """Train an MLP classifier on MNIST with plain SGD, printing per-epoch metrics.

    Reads the module-level ``train_dataset`` / ``test_dataset`` and builds a
    ``Network`` with layer widths 784 -> 300 -> 100 -> 10. Training loss and
    accuracy are printed every epoch. Test-set loss and accuracy are printed every
    ``eval_interval`` epochs after epoch 0, and always after the final epoch.

    Loss and accuracy are averaged over *samples*, not batches, so the smaller
    last batch of each loader is not over-weighted.

    Args:
        learning_rate: SGD step size.
        batch_size: Mini-batch size for both the train and test loaders.
        epochs: Number of passes over the training set.
        eval_interval: Epoch period for test evaluation; ``None`` means
            ``max(epochs // 3, 1)``. Must be >= 1.
        seed: If given, seeds torch's global RNG (weight init and shuffling)
            so the run is reproducible.

    Raises:
        ValueError: If ``eval_interval`` is less than 1.
    """
    if eval_interval is None:
        eval_interval = max(epochs // 3, 1)
    if eval_interval < 1:
        raise ValueError(f"eval_interval must be >= 1, got {eval_interval}")
    if seed is not None:
        torch.manual_seed(seed)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Sample order does not change evaluation metrics, so don't shuffle the test set.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    criterion = nn.CrossEntropyLoss()  # takes raw logits; default reduction is the batch mean
    model = Network(input_dims=28 * 28, mlp_dims=[300, 100], output_dims=10)  # 10 classes classification
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    def count_correct(logits, labels):
        """Number of rows whose arg-max logit equals the label, as a Python int."""
        return (logits.argmax(dim=1) == labels).sum().item()

    def run_epoch(loader, training):
        """One full pass over ``loader``; returns per-sample (mean loss, accuracy)."""
        model.train(training)
        total_loss, total_correct, total_seen = 0.0, 0, 0
        # Only build the autograd graph when backward() will actually be called.
        with torch.set_grad_enabled(training):
            for imgs, labels in loader:
                imgs = imgs.view(-1, 28 * 28)  # flatten each image to a 784-vector
                logits = model(imgs)
                loss = criterion(logits, labels)
                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                n = labels.size(0)
                # criterion returns the batch mean; re-weight by batch size so the
                # final (smaller) batch counts proportionally.
                total_loss += loss.item() * n
                total_correct += count_correct(logits, labels)
                total_seen += n
        if total_seen == 0:  # empty loader: nothing to average
            return float("nan"), float("nan")
        return total_loss / total_seen, total_correct / total_seen

    # training loop
    for i in range(epochs):
        train_loss, train_acc = run_epoch(train_loader, training=True)
        print(f"Epoch {i}, loss: {train_loss:.2f}, accuracy: {train_acc * 100:.2f}%")
        # Periodic evaluation (skipping epoch 0), plus the final model is always evaluated.
        if (i and i % eval_interval == 0) or i == epochs - 1:
            eval_loss, eval_acc = run_epoch(test_loader, training=False)
            print(f"Evaluating network in epoch {i}, loss: {eval_loss:.2f}, accuracy: {eval_acc * 100:.2f}%")

In [48]:
# Run training with the function's default hyperparameters; per-epoch metrics are printed.
mnist_training()

Epoch 0, loss: 1.00, accuracy: 76.01%
Epoch 1, loss: 0.37, accuracy: 89.24%
Epoch 2, loss: 0.32, accuracy: 90.83%
Epoch 3, loss: 0.28, accuracy: 91.72%
Evaluating network in epoch 3, loss: 0.28, accuracy: 92.09%
Epoch 4, loss: 0.26, accuracy: 92.41%
Epoch 5, loss: 0.24, accuracy: 93.07%
Epoch 6, loss: 0.22, accuracy: 93.66%
Evaluating network in epoch 6, loss: 0.21, accuracy: 93.86%
Epoch 7, loss: 0.20, accuracy: 94.26%
Epoch 8, loss: 0.18, accuracy: 94.74%
Epoch 9, loss: 0.17, accuracy: 95.13%
Evaluating network in epoch 9, loss: 0.16, accuracy: 95.21%
