In [1]:
import numpy as np
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import tqdm

In [2]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [3]:
BATCH_SIZE = 64
set_seed(42)

In [4]:
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

display(train_dataset[0][0].shape)
display(train_dataset)

torch.Size([1, 28, 28])

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
          torch.nn.Linear(28*28, 512),
          torch.nn.ReLU(),
          torch.nn.Linear(512, 256),
          torch.nn.ReLU(),
          torch.nn.Linear(256, 10),
          torch.nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.model(x)
        return x


In [7]:
model = Model()
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
  total_loss = 0

  for images, labels in tqdm.tqdm(train_loader):
    optimizer.zero_grad()
    output = model(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

  # test
  correct = 0
  total = 0
  with torch.no_grad():
    for images, labels in test_loader:
      output = model(images)
      _, predicted = torch.max(output, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  print("epoch:", epoch, "loss:", total_loss, "accuracy:", correct/total)

100%|██████████| 938/938 [00:55<00:00, 16.93it/s]


epoch: 0 loss: 220.9954596562311 accuracy: 0.9599


100%|██████████| 938/938 [00:54<00:00, 17.23it/s]


epoch: 1 loss: 82.91203230479732 accuracy: 0.9758


100%|██████████| 938/938 [00:48<00:00, 19.47it/s]


epoch: 2 loss: 54.472215139190666 accuracy: 0.9802


100%|██████████| 938/938 [00:20<00:00, 46.35it/s]


epoch: 3 loss: 40.87254710053094 accuracy: 0.9794


100%|██████████| 938/938 [00:25<00:00, 36.52it/s]


epoch: 4 loss: 29.861049971485045 accuracy: 0.9827


100%|██████████| 938/938 [00:24<00:00, 39.04it/s]


epoch: 5 loss: 22.457174307273817 accuracy: 0.9767


100%|██████████| 938/938 [00:22<00:00, 41.90it/s]


epoch: 6 loss: 21.707078236151574 accuracy: 0.9816


100%|██████████| 938/938 [00:24<00:00, 38.77it/s]


epoch: 7 loss: 17.611939444963355 accuracy: 0.9798


100%|██████████| 938/938 [00:23<00:00, 39.45it/s]


epoch: 8 loss: 17.869373987246945 accuracy: 0.9805


100%|██████████| 938/938 [00:20<00:00, 45.05it/s]


epoch: 9 loss: 12.696594428125536 accuracy: 0.9789


In [8]:
# test
correct = 0
total = 0
with torch.no_grad():
  for images, labels in test_loader:
    output = model(images)
    _, predicted = torch.max(output, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print("accuracy:", correct/total)

accuracy: 0.9789


In [9]:
# save model
torch.save(model.state_dict(), "model/nnmodel.pth")