In [1]:
import torch
import torchvision


# Directory the MNIST files are downloaded to / read from.
DATA_ROOT = "./data"

# Mean and standard deviation handed to Normalize (one entry: MNIST is
# single-channel). These are the commonly quoted MNIST training-set statistics.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# A single preprocessing pipeline shared by both splits, so train and test
# inputs are guaranteed to be scaled identically: image -> tensor -> normalized.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Training split (downloaded on first use).
train_mnist = torchvision.datasets.MNIST(
    DATA_ROOT,
    train=True,
    download=True,
    transform=mnist_transform,
)

# Held-out test split, preprocessed exactly like the training split.
test_mnist = torchvision.datasets.MNIST(
    DATA_ROOT,
    train=False,
    download=True,
    transform=mnist_transform,
)


In [2]:
# Fully connected classifier: flattened 28x28 image -> 300 -> 300 -> 10 class scores.
model = torch.nn.Sequential(
    torch.nn.Linear(28 * 28, 300),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(300, 300),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(300, 10),
)


# Smoke test: run the first training sample through the untrained network and
# print the output shape.
digit, cls = train_mnist[0]

# A single dataset item has no batch axis (its leading axis is the channel), so
# add an explicit batch of 1 rather than reusing digit.shape[0], which only
# happened to equal 1 because MNIST is single-channel.
digit = digit.view(1, 28 * 28)

# No gradients are needed for a shape check.
with torch.no_grad():
    print(model(digit).shape)

torch.Size([1, 10])


In [3]:
from tqdm import tqdm

# Training setup: Adam (lr 1e-3) on cross-entropy, fed by shuffled
# mini-batches of 32 training samples.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
dl = torch.utils.data.DataLoader(train_mnist, shuffle=True, batch_size=32)

for epoch in range(3):
    bar = tqdm(dl)
    for digit, cls in bar:
        # Drop gradients left over from the previous step before the new pass.
        optimizer.zero_grad()

        # Flatten every image in the batch into one 784-element row.
        digit = digit.view(-1, 28 * 28)
        pred = model(digit)
        loss = loss_fn(pred, cls)

        loss.backward()
        optimizer.step()

        # Per-batch metrics shown on the progress bar; both come from the
        # forward pass above, i.e. before this step's weight update.
        accuracy = pred.argmax(dim=1).eq(cls).float().mean()
        bar.set_description(f"Loss: {loss.item():.4f}, Accuracy: {accuracy.item():.4f}")

Loss: 0.1753, Accuracy: 0.9375: 100%|██████████| 1875/1875 [00:18<00:00, 100.71it/s]
Loss: 0.1332, Accuracy: 0.9375: 100%|██████████| 1875/1875 [00:18<00:00, 103.82it/s]
Loss: 0.0150, Accuracy: 1.0000: 100%|██████████| 1875/1875 [00:15<00:00, 118.27it/s]


In [4]:
# Evaluation on the held-out test split. Shuffling adds nothing to a metric
# computed over the whole set, so keep dataset order: predictions then line up
# with test_mnist indices.
dl = torch.utils.data.DataLoader(test_mnist, batch_size=32, shuffle=False)

bar = tqdm(dl)

# Per-batch network outputs and the matching ground-truth labels.
preds = []
target = []

# Switch to inference behaviour. This is a no-op for the Linear/LeakyReLU stack
# used here, but keeps the cell correct if dropout or batch-norm layers are
# added; call model.train() before training again in that case.
model.eval()

# Disable autograd: otherwise every stored prediction would keep its whole
# computation graph alive for the entire test set.
with torch.no_grad():
    for digit, cls in bar:
        # Flatten each image in the batch to a 784-element row.
        digit = digit.view(digit.shape[0], 28 * 28)
        pred = model(digit)

        preds.append(pred)
        target.append(cls)


100%|██████████| 313/313 [00:01<00:00, 244.99it/s]


In [5]:
# Join the per-batch outputs and labels along the batch axis.
p = torch.cat(preds, dim=0)
t = torch.cat(target, dim=0)


# Test accuracy: fraction of samples whose top-scoring class equals the label.
p.argmax(dim=1).eq(t).float().mean()

tensor(0.9752)