In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torch.nn.functional as F

import time

# Pick the best available accelerator instead of hard-coding "mps":
# torch.device("mps") itself constructs fine, but the first .to(device) fails
# on machines without an MPS backend. Preference: MPS -> CUDA -> CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
def load_data(dataset = "mnist", 
              path = 'data', 
              train = True, 
              batch_size = 256, 
              transforms = torchvision.transforms.ToTensor(),
              download  = True,
              shuffle = True):
    '''
    Returns the dataset and dataloader for the specified dataset.

    Supported datasets: [mnist, cifar, fashion, emnist, kmnist, svhn]
    (names are matched case-insensitively).

    Args:
        dataset: name of the dataset to load.
        path: root directory the data is read from / downloaded to.
            SVHN is stored in the sub-directory ``path + '/SVHN'``.
        train: load the training split if True, otherwise the test split.
        batch_size: batch size of the returned DataLoader.
        transforms: transform applied to every sample. The default ToTensor
            instance is created once, when this function is defined, and is
            shared by all calls.
        download: download the data if it is not already present under `path`.
        shuffle: whether the DataLoader reshuffles the data every epoch.
            Defaults to True (previous behaviour); evaluation loaders can pass
            False since ordering does not affect accuracy.

    Returns:
        (dataset, loader): the torchvision dataset and a DataLoader over it.

    Raises:
        ValueError: if `dataset` is not one of the supported names.

    NOTE(review): 'emnist' loads the 'letters' split, whose labels are
    presumably not the 10 digit classes — confirm the class count before
    pairing it with a 10-output model.
    '''
    # Datasets whose constructors share the (root, train=..., transform=...,
    # download=...) signature.
    same_signature = {
        'mnist': torchvision.datasets.MNIST,
        'fashion': torchvision.datasets.FashionMNIST,
        'cifar': torchvision.datasets.CIFAR10,
        'kmnist': torchvision.datasets.KMNIST,
    }

    name = dataset.lower()
    if name in same_signature:
        data = same_signature[name](path, train=train, transform=transforms, download=download)
    elif name == 'emnist':
        data = torchvision.datasets.EMNIST(path, train=train, transform=transforms, download=download, split='letters')
    elif name == 'svhn':
        # SVHN selects the split by name rather than a `train` flag.
        data = torchvision.datasets.SVHN(path + '/SVHN', split='train' if train else 'test', transform=transforms, download=download)
    else:
        raise ValueError('Invalid dataset. Options: [mnist, cifar, fashion, emnist, kmnist, svhn]')

    loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=shuffle)
    return data, loader

In [4]:
class MLP(nn.Module):
    '''
    Fully connected classifier: in_features -> 512 -> 256 -> 64 -> num_classes,
    with a ReLU after each hidden layer.

    The output is raw logits (no softmax), which is the form
    nn.CrossEntropyLoss expects.

    Args:
        in_features: length of the flattened input vector. The default 784
            matches a 28x28 single-channel image.
        num_classes: number of output logits.
    '''
    def __init__(self, in_features=784, num_classes=10):
        super().__init__()

        # Layer names are kept as fc1/fc2/fc3/out so existing state_dicts load.
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)
        self.out = nn.Linear(64, num_classes)

    def forward(self, x):
        '''
        Compute logits for `x`.

        `x` is either already-flattened features whose last dimension is
        `in_features`, or a 4-D image batch (B, C, H, W) with
        C*H*W == in_features. Image batches are flattened here, so callers
        need not do it themselves. Returns logits whose last dimension is
        `num_classes`.
        '''
        if x.dim() == 4:
            x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.out(x)

In [5]:
# Hyperparameters for this run.
SEED = 0
N_EPOCHS = 10
BATCH_SIZE = 64
LR = 0.01

# Seed before building the model so weight init and DataLoader shuffling are
# reproducible under Restart & Run All.
torch.manual_seed(SEED)

_, train_loader = load_data(batch_size=BATCH_SIZE)
_, test_loader = load_data(batch_size=BATCH_SIZE, train=False)

model = MLP().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
criterion = torch.nn.CrossEntropyLoss()


def train_one_epoch(model, loader, optimizer, criterion):
    '''Run one full pass of `optimizer` over `loader`, updating `model` in place.'''
    model.train()
    dev = next(model.parameters()).device
    for images, labels in loader:
        images, labels = images.to(dev), labels.to(dev)
        logits = model(images.flatten(start_dim=1))
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


@torch.no_grad()  # evaluation needs no autograd graph; avoids wasted memory/compute
def evaluate(model, loader):
    '''Return the fraction of samples in `loader` that `model` classifies correctly.'''
    model.eval()
    dev = next(model.parameters()).device
    # Accumulate on the device and sync once at the end, rather than per batch.
    correct = torch.zeros((), dtype=torch.long, device=dev)
    total = 0
    for images, labels in loader:
        images, labels = images.to(dev), labels.to(dev)
        preds = model(images.flatten(start_dim=1)).argmax(dim=-1)
        correct += (preds == labels).sum()
        total += len(labels)
    # .item() blocks until all queued device work is done, so timings taken
    # after this call include it.
    return correct.item() / total


for epoch in range(N_EPOCHS):
    tic = time.perf_counter()
    train_one_epoch(model, train_loader, optimizer, criterion)
    accuracy = evaluate(model, test_loader)
    toc = time.perf_counter()
    print(
        f"Epoch: {epoch+1}, Test Accuracy: {accuracy:.4f}",
        f"Time: {toc - tic:.3f}"
    )

Epoch: 1, Test Accuracy: 0.6136000156402588 Time: 4.588
Epoch: 2, Test Accuracy: 0.852400004863739 Time: 4.123
Epoch: 3, Test Accuracy: 0.8822000026702881 Time: 4.274
Epoch: 4, Test Accuracy: 0.9016000032424927 Time: 4.259
Epoch: 5, Test Accuracy: 0.9124000072479248 Time: 4.226
Epoch: 6, Test Accuracy: 0.9241999983787537 Time: 4.138
Epoch: 7, Test Accuracy: 0.9336000084877014 Time: 4.173
Epoch: 8, Test Accuracy: 0.9376000165939331 Time: 4.026
Epoch: 9, Test Accuracy: 0.9444000124931335 Time: 4.032
Epoch: 10, Test Accuracy: 0.9474999904632568 Time: 4.016
