In [5]:
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn
from matplotlib import pyplot as plt

In [None]:
# Seed the global torch RNG (all devices) so weight init and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Use the first GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")

True

In [7]:
# Convert PIL images to tensors, then standardise with the commonly quoted
# MNIST training-set pixel mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

BATCH_SIZE = 128

# NOTE(review): MNIST is downloaded separately under each root below; a single
# shared root would avoid the duplicate download.
training_data = MNIST(
    'train_data',
    train=True,
    transform=transform,
    download=True
)
test_data = MNIST(
    'test_data',
    train=False,
    transform=transform,
    download=True
)

# Training: shuffle every epoch and drop the trailing partial batch so every
# optimisation step sees exactly BATCH_SIZE samples.
dl_train = DataLoader(
    dataset=training_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True
)
# Evaluation must visit every test sample exactly once: drop_last=True would
# silently skip the final partial batch (10000 % 128 = 16 images), and
# shuffling has no effect on an accuracy computed over the whole set.
dl_test = DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False
)

In [8]:
FAN_IN = 28 * 28      # flattened MNIST image size
HIDDEN_DIM = 128      # width of both hidden layers
N_CLASSES = 10        # digits 0-9

class MLP(nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        input_dim: number of features after flattening each sample
            (defaults to ``FAN_IN``).
        hidden_dim: width of the two hidden layers (defaults to ``HIDDEN_DIM``).
        output_dim: number of output logits / classes (defaults to ``N_CLASSES``).
    """

    def __init__(
        self,
        input_dim: int = FAN_IN,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = N_CLASSES,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.net = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores of shape ``(batch, output_dim)``.

        Every dimension after the first (batch) dimension is flattened, so the
        product of those dimensions must equal ``input_dim`` — e.g. a
        ``(B, 1, 28, 28)`` image batch becomes ``(B, 784)``.
        """
        x = x.flatten(1)
        return self.net(x)

In [None]:
def compute_accuracy(dl: DataLoader, model: nn.Module, device: "torch.device | None" = None):
    """Return the top-1 accuracy of ``model`` over every batch yielded by ``dl``.

    The model is put in eval mode for the pass and its previous train/eval
    mode is restored afterwards, even if an exception is raised.

    Args:
        dl: loader yielding ``(inputs, labels)`` batches; labels are compared
            against ``argmax(model(inputs), dim=1)``, so they are expected to be
            integer class indices.
        model: any module whose output has one score per class along dim 1.
        device: where each batch is moved before the forward pass; defaults to
            the notebook-level ``DEVICE``.

    Returns:
        A 0-dim float tensor: number of correct predictions / number of samples.

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    if device is None:
        device = DEVICE

    was_training = model.training
    # No-op for layers without train/eval differences, but required once
    # dropout or batch-norm are present.
    model.eval()
    correct = 0
    total = 0
    try:
        # Inference only: don't build an autograd graph for these forward passes.
        with torch.no_grad():
            for x, y in dl:
                x, y = x.to(device), y.to(device)
                pred = torch.argmax(model(x), dim=1)
                correct += torch.sum(pred == y)
                total += pred.shape[0]
    finally:
        model.train(was_training)  # restore the caller's mode

    if total == 0:
        raise ValueError("compute_accuracy: data loader yielded no samples")
    return correct / total