In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision as tv

from torchvision.datasets import MNIST
from tqdm import tqdm

Загружаем датасет и проводим трансформации над ним: сначала трансформируем каждую картинку в тензор, затем нормализуем с указанными параметрамми мат.ожидания и среднеквадратического отклонения.

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)

batch_size=128
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

Строим классификатор. Данный классификатор представляет собой известную нейросеть LeNet, с небольшими изменениями, чтобы подстроиться под разрешение картинки в MNIST. Также добавлены дополнительные полносвязные слои в конец нейросети, т.к. это позволило добиться точности более 99%.

In [26]:
class Classificator(nn.Module):
    def __init__(self):
        super(Classificator, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=9) # after that step: 20x20x6
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # after that step: 10x10x6
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3) # after that step: 8x8x16
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # after that step: 4x4x16
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=32)
        self.fc4 = nn.Linear(in_features=32, out_features=16)
        self.fc5 = nn.Linear(in_features=16, out_features=8)
        self.fc6 = nn.Linear(in_features=8, out_features=10)

        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        x = self.pool1(self.leaky_relu(self.conv1(x)))
        x = self.pool2(self.leaky_relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = self.leaky_relu(self.fc1(x))
        x = self.leaky_relu(self.fc2(x))
        x = self.leaky_relu(self.fc3(x))
        x = self.leaky_relu(self.fc4(x))
        x = self.leaky_relu(self.fc5(x))
        x = self.fc6(x)

        return x


Создаём объект нейросети, в качестве функции потерь выбираем кросс-энтропию, в качестве оптимизирующего алгоритма - Adam. Важно отметить, что реализация nn.CrossEntropyLoss() в Pytorch совмещает в себе softmax activation function и neagative log likelihood в единой функции, поэтому явно слой с softmax в нейросети не используется. Также инициализируем гиперпараметры нейросети, затем начинаем обучение.

In [31]:
clf = Classificator()
criterion = nn.CrossEntropyLoss()
learning_rate = 0.001
num_epochs = 100
optimizer = torch.optim.Adam(clf.parameters(), lr=learning_rate)

for epoch in tqdm(range(num_epochs)):
    for data in train_loader:
        img, labels = data
        optimizer.zero_grad()
        output = clf(img)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

  1%|          | 1/100 [00:18<30:46, 18.65s/it]

Epoch [1/100], Loss: 0.3374


  2%|▏         | 2/100 [00:34<27:32, 16.87s/it]

Epoch [2/100], Loss: 0.1665


  3%|▎         | 3/100 [00:57<31:56, 19.76s/it]

Epoch [3/100], Loss: 0.1464


  4%|▍         | 4/100 [01:13<29:16, 18.30s/it]

Epoch [4/100], Loss: 0.1414


  5%|▌         | 5/100 [01:37<32:07, 20.29s/it]

Epoch [5/100], Loss: 0.1340


  6%|▌         | 6/100 [01:55<30:22, 19.39s/it]

Epoch [6/100], Loss: 0.1306


  7%|▋         | 7/100 [02:11<28:38, 18.48s/it]

Epoch [7/100], Loss: 0.1348


  8%|▊         | 8/100 [02:28<27:41, 18.06s/it]

Epoch [8/100], Loss: 0.1238


  9%|▉         | 9/100 [02:47<27:46, 18.32s/it]

Epoch [9/100], Loss: 0.1321


 10%|█         | 10/100 [03:18<33:10, 22.12s/it]

Epoch [10/100], Loss: 0.1329


 11%|█         | 11/100 [03:35<30:37, 20.65s/it]

Epoch [11/100], Loss: 0.1111


 12%|█▏        | 12/100 [03:52<28:35, 19.50s/it]

Epoch [12/100], Loss: 0.1203


 13%|█▎        | 13/100 [04:08<26:56, 18.58s/it]

Epoch [13/100], Loss: 0.1192


 14%|█▍        | 14/100 [04:26<26:01, 18.15s/it]

Epoch [14/100], Loss: 0.1031


 15%|█▌        | 15/100 [04:42<24:48, 17.52s/it]

Epoch [15/100], Loss: 0.1038


 16%|█▌        | 16/100 [04:59<24:24, 17.43s/it]

Epoch [16/100], Loss: 0.0840


 17%|█▋        | 17/100 [05:15<23:37, 17.08s/it]

Epoch [17/100], Loss: 0.0888


 18%|█▊        | 18/100 [05:32<23:22, 17.10s/it]

Epoch [18/100], Loss: 0.0755


 19%|█▉        | 19/100 [05:50<23:22, 17.31s/it]

Epoch [19/100], Loss: 0.0828


 20%|██        | 20/100 [06:09<23:36, 17.71s/it]

Epoch [20/100], Loss: 0.1119


 21%|██        | 21/100 [06:24<22:29, 17.09s/it]

Epoch [21/100], Loss: 0.0622


 22%|██▏       | 22/100 [06:47<24:30, 18.85s/it]

Epoch [22/100], Loss: 0.0753


 23%|██▎       | 23/100 [07:12<26:17, 20.48s/it]

Epoch [23/100], Loss: 0.0504


 24%|██▍       | 24/100 [07:30<25:16, 19.95s/it]

Epoch [24/100], Loss: 0.0588


 25%|██▌       | 25/100 [08:00<28:32, 22.83s/it]

Epoch [25/100], Loss: 0.0541


 26%|██▌       | 26/100 [08:16<25:50, 20.96s/it]

Epoch [26/100], Loss: 0.0860


 27%|██▋       | 27/100 [08:35<24:45, 20.34s/it]

Epoch [27/100], Loss: 0.0654


 28%|██▊       | 28/100 [08:53<23:35, 19.66s/it]

Epoch [28/100], Loss: 0.0522


 29%|██▉       | 29/100 [09:15<23:46, 20.09s/it]

Epoch [29/100], Loss: 0.0701


 30%|███       | 30/100 [09:38<24:27, 20.96s/it]

Epoch [30/100], Loss: 0.0405


 31%|███       | 31/100 [09:57<23:39, 20.58s/it]

Epoch [31/100], Loss: 0.0726


 32%|███▏      | 32/100 [10:17<23:11, 20.47s/it]

Epoch [32/100], Loss: 0.0518


 33%|███▎      | 33/100 [10:37<22:34, 20.21s/it]

Epoch [33/100], Loss: 0.0540


 34%|███▍      | 34/100 [10:56<21:57, 19.96s/it]

Epoch [34/100], Loss: 0.1072


 35%|███▌      | 35/100 [11:15<21:03, 19.44s/it]

Epoch [35/100], Loss: 0.0591


 36%|███▌      | 36/100 [11:32<20:01, 18.78s/it]

Epoch [36/100], Loss: 0.0382


 37%|███▋      | 37/100 [11:49<19:20, 18.42s/it]

Epoch [37/100], Loss: 0.0433


 38%|███▊      | 38/100 [12:07<18:41, 18.09s/it]

Epoch [38/100], Loss: 0.0336


 39%|███▉      | 39/100 [12:25<18:18, 18.01s/it]

Epoch [39/100], Loss: 0.0498


 40%|████      | 40/100 [12:43<18:13, 18.23s/it]

Epoch [40/100], Loss: 0.0214


 41%|████      | 41/100 [13:00<17:34, 17.87s/it]

Epoch [41/100], Loss: 0.0339


 42%|████▏     | 42/100 [13:19<17:34, 18.18s/it]

Epoch [42/100], Loss: 0.0407


 43%|████▎     | 43/100 [13:36<16:56, 17.83s/it]

Epoch [43/100], Loss: 0.0252


 44%|████▍     | 44/100 [13:49<15:18, 16.41s/it]

Epoch [44/100], Loss: 0.0458


 45%|████▌     | 45/100 [14:06<15:08, 16.51s/it]

Epoch [45/100], Loss: 0.0226


 46%|████▌     | 46/100 [14:29<16:36, 18.45s/it]

Epoch [46/100], Loss: 0.0131


 47%|████▋     | 47/100 [14:51<17:13, 19.50s/it]

Epoch [47/100], Loss: 0.0156


 48%|████▊     | 48/100 [15:12<17:08, 19.79s/it]

Epoch [48/100], Loss: 0.0833


 49%|████▉     | 49/100 [15:33<17:21, 20.41s/it]

Epoch [49/100], Loss: 0.0182


 50%|█████     | 50/100 [15:55<17:15, 20.70s/it]

Epoch [50/100], Loss: 0.0201


 51%|█████     | 51/100 [16:13<16:12, 19.85s/it]

Epoch [51/100], Loss: 0.0269


 52%|█████▏    | 52/100 [16:32<15:45, 19.70s/it]

Epoch [52/100], Loss: 0.0121


 53%|█████▎    | 53/100 [16:53<15:39, 19.99s/it]

Epoch [53/100], Loss: 0.0496


 54%|█████▍    | 54/100 [17:16<16:08, 21.06s/it]

Epoch [54/100], Loss: 0.0477


 55%|█████▌    | 55/100 [17:35<15:15, 20.35s/it]

Epoch [55/100], Loss: 0.0156


 56%|█████▌    | 56/100 [17:54<14:40, 20.02s/it]

Epoch [56/100], Loss: 0.0116


 57%|█████▋    | 57/100 [18:11<13:44, 19.18s/it]

Epoch [57/100], Loss: 0.0194


 58%|█████▊    | 58/100 [18:27<12:38, 18.06s/it]

Epoch [58/100], Loss: 0.0554


 59%|█████▉    | 59/100 [18:43<11:56, 17.48s/it]

Epoch [59/100], Loss: 0.0293


 60%|██████    | 60/100 [19:01<11:45, 17.64s/it]

Epoch [60/100], Loss: 0.0252


 61%|██████    | 61/100 [19:16<10:55, 16.82s/it]

Epoch [61/100], Loss: 0.0108


 62%|██████▏   | 62/100 [19:31<10:25, 16.47s/it]

Epoch [62/100], Loss: 0.0483


 63%|██████▎   | 63/100 [19:48<10:10, 16.51s/it]

Epoch [63/100], Loss: 0.0170


 64%|██████▍   | 64/100 [20:01<09:10, 15.28s/it]

Epoch [64/100], Loss: 0.0128


 65%|██████▌   | 65/100 [20:15<08:42, 14.93s/it]

Epoch [65/100], Loss: 0.0170


 66%|██████▌   | 66/100 [20:28<08:10, 14.42s/it]

Epoch [66/100], Loss: 0.0240


 67%|██████▋   | 67/100 [20:43<08:01, 14.59s/it]

Epoch [67/100], Loss: 0.0123


 68%|██████▊   | 68/100 [20:58<07:48, 14.64s/it]

Epoch [68/100], Loss: 0.0103


 69%|██████▉   | 69/100 [21:12<07:27, 14.45s/it]

Epoch [69/100], Loss: 0.0185


 70%|███████   | 70/100 [21:26<07:09, 14.31s/it]

Epoch [70/100], Loss: 0.0123


 71%|███████   | 71/100 [21:41<07:01, 14.54s/it]

Epoch [71/100], Loss: 0.0145


 72%|███████▏  | 72/100 [21:54<06:34, 14.10s/it]

Epoch [72/100], Loss: 0.0174


 73%|███████▎  | 73/100 [22:06<06:04, 13.50s/it]

Epoch [73/100], Loss: 0.0123


 74%|███████▍  | 74/100 [22:18<05:40, 13.09s/it]

Epoch [74/100], Loss: 0.0097


 75%|███████▌  | 75/100 [22:30<05:21, 12.86s/it]

Epoch [75/100], Loss: 0.0083


 76%|███████▌  | 76/100 [22:44<05:16, 13.19s/it]

Epoch [76/100], Loss: 0.0074


 77%|███████▋  | 77/100 [22:58<05:05, 13.28s/it]

Epoch [77/100], Loss: 0.0083


 78%|███████▊  | 78/100 [23:23<06:08, 16.77s/it]

Epoch [78/100], Loss: 0.0209


 79%|███████▉  | 79/100 [23:39<05:51, 16.76s/it]

Epoch [79/100], Loss: 0.0085


 80%|████████  | 80/100 [23:55<05:27, 16.36s/it]

Epoch [80/100], Loss: 0.0059


 81%|████████  | 81/100 [24:09<04:57, 15.65s/it]

Epoch [81/100], Loss: 0.0133


 82%|████████▏ | 82/100 [24:21<04:25, 14.74s/it]

Epoch [82/100], Loss: 0.0067


 83%|████████▎ | 83/100 [24:36<04:10, 14.72s/it]

Epoch [83/100], Loss: 0.0076


 84%|████████▍ | 84/100 [24:54<04:13, 15.81s/it]

Epoch [84/100], Loss: 0.0042


 85%|████████▌ | 85/100 [25:07<03:42, 14.83s/it]

Epoch [85/100], Loss: 0.0033


 86%|████████▌ | 86/100 [25:20<03:18, 14.15s/it]

Epoch [86/100], Loss: 0.0053


 87%|████████▋ | 87/100 [25:33<03:00, 13.89s/it]

Epoch [87/100], Loss: 0.0177


 88%|████████▊ | 88/100 [25:45<02:40, 13.37s/it]

Epoch [88/100], Loss: 0.0058


 89%|████████▉ | 89/100 [25:58<02:26, 13.31s/it]

Epoch [89/100], Loss: 0.0151


 90%|█████████ | 90/100 [26:19<02:35, 15.55s/it]

Epoch [90/100], Loss: 0.0051


 91%|█████████ | 91/100 [26:39<02:31, 16.81s/it]

Epoch [91/100], Loss: 0.0019


 92%|█████████▏| 92/100 [26:57<02:18, 17.29s/it]

Epoch [92/100], Loss: 0.0097


 93%|█████████▎| 93/100 [27:15<02:02, 17.54s/it]

Epoch [93/100], Loss: 0.0075


 94%|█████████▍| 94/100 [27:34<01:47, 17.85s/it]

Epoch [94/100], Loss: 0.0059


 95%|█████████▌| 95/100 [27:53<01:30, 18.13s/it]

Epoch [95/100], Loss: 0.0047


 96%|█████████▌| 96/100 [28:18<01:21, 20.30s/it]

Epoch [96/100], Loss: 0.0430


 97%|█████████▋| 97/100 [28:37<00:59, 19.83s/it]

Epoch [97/100], Loss: 0.0011


 98%|█████████▊| 98/100 [28:54<00:37, 18.98s/it]

Epoch [98/100], Loss: 0.0070


 99%|█████████▉| 99/100 [29:07<00:17, 17.29s/it]

Epoch [99/100], Loss: 0.0019


100%|██████████| 100/100 [29:20<00:00, 17.60s/it]

Epoch [100/100], Loss: 0.0005





Осталось протестировать полученную модель. Как видно, за 100 итераций обучения с такой архитектурой нейросети поулчилось достичь точности в 99.15%.

In [39]:
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = clf(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy on the test set:  {(100 * correct / total)} %')

Accuracy on the test set:  99.15 %
