In [26]:
import numpy as np
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import tqdm

In [27]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [28]:
BATCH_SIZE = 64
set_seed(42)

In [29]:
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

display(train_dataset[0][0].shape)
display(train_dataset)

torch.Size([1, 28, 28])

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

In [30]:
class CNNModel(torch.nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)
        self.pool = torch.nn.MaxPool2d(2, stride=2)

        self.conv1 = torch.nn.Conv2d(1, 16, 3)
        self.conv2 = torch.nn.Conv2d(16, 32, 3)

        self.fc1 = torch.nn.Linear(32*5*5, 128)
        self.fc2 = torch.nn.Linear(128, 10)


    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

In [31]:
model = CNNModel()
display(model)

CNNModel(
  (relu): ReLU()
  (softmax): Softmax(dim=1)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [32]:
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
  total_loss = 0

  for images, labels in tqdm.tqdm(train_loader):
    optimizer.zero_grad()
    output = model(images)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    # test
    correct = 0
    total = 0
  for images, labels in test_loader:
      output = model(images)
      _, predicted = torch.max(output, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  accuracy = correct / total
  print("epoch:", epoch, "loss:", total_loss, "accuracy:", accuracy)

100%|██████████| 938/938 [00:17<00:00, 54.42it/s]


epoch: 0 loss: -678.119718067348 accuracy: 0.8802


100%|██████████| 938/938 [00:17<00:00, 53.11it/s]


epoch: 1 loss: -873.8343593478203 accuracy: 0.9692


100%|██████████| 938/938 [00:17<00:00, 53.53it/s]


epoch: 2 loss: -913.6214557886124 accuracy: 0.9759


100%|██████████| 938/938 [00:17<00:00, 54.78it/s]


epoch: 3 loss: -918.6647735238075 accuracy: 0.9812


100%|██████████| 938/938 [00:23<00:00, 40.33it/s]


epoch: 4 loss: -921.8934329748154 accuracy: 0.9841


100%|██████████| 938/938 [00:34<00:00, 27.52it/s]


epoch: 5 loss: -923.3962613344193 accuracy: 0.9831


100%|██████████| 938/938 [00:29<00:00, 32.01it/s]


epoch: 6 loss: -925.3201131820679 accuracy: 0.9846


100%|██████████| 938/938 [00:43<00:00, 21.77it/s]


epoch: 7 loss: -926.6385991573334 accuracy: 0.9849


100%|██████████| 938/938 [01:04<00:00, 14.63it/s]


epoch: 8 loss: -927.6002976298332 accuracy: 0.987


100%|██████████| 938/938 [01:00<00:00, 15.54it/s]


epoch: 9 loss: -927.7109317779541 accuracy: 0.9839


In [36]:
# test
correct = 0
total = 0
with torch.no_grad():
  for images, labels in tqdm.tqdm(test_loader):
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print("Accuracy: ", correct / total)

100%|██████████| 157/157 [00:05<00:00, 29.36it/s]

Accuracy:  0.9839





In [37]:
# save model
torch.save(model.state_dict(), 'model/cnnmodel.pth')