In [1]:
import numpy as np
import torch
import torchvision as tv

In [53]:
BATCH_SIZE=256

In [54]:
train_dataset = tv.datasets.MNIST('.', train=True, transform=tv.transforms.ToTensor(), download=True)
test_dataset = tv.datasets.MNIST('.', train=False, transform=tv.transforms.ToTensor(), download=True)
train = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)
test = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [55]:
for i in train:
    print(i[1])
    break

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1,
        1, 2, 4, 3, 2, 7, 3, 8, 6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, 5,
        9, 3, 3, 0, 7, 4, 9, 8, 0, 9, 4, 1, 4, 4, 6, 0, 4, 5, 6, 1, 0, 0, 1, 7,
        1, 6, 3, 0, 2, 1, 1, 7, 9, 0, 2, 6, 7, 8, 3, 9, 0, 4, 6, 7, 4, 6, 8, 0,
        7, 8, 3, 1, 5, 7, 1, 7, 1, 1, 6, 3, 0, 2, 9, 3, 1, 1, 0, 4, 9, 2, 0, 0,
        2, 0, 2, 7, 1, 8, 6, 4, 1, 6, 3, 4, 5, 9, 1, 3, 3, 8, 5, 4, 7, 7, 4, 2,
        8, 5, 8, 6, 7, 3, 4, 6, 1, 9, 9, 6, 0, 3, 7, 2, 8, 2, 9, 4, 4, 6, 4, 9,
        7, 0, 9, 2, 9, 5, 1, 5, 9, 1, 2, 3, 2, 3, 5, 9, 1, 7, 6, 2, 8, 2, 2, 5,
        0, 7, 4, 9, 7, 8, 3, 2, 1, 1, 8, 3, 6, 1, 0, 3, 1, 0, 0, 1, 7, 2, 7, 3,
        0, 4, 6, 5, 2, 6, 4, 7, 1, 8, 9, 9, 3, 0, 7, 1, 0, 2, 0, 3, 5, 4, 6, 5,
        8, 6, 3, 7, 5, 8, 0, 9, 1, 0, 3, 1, 2, 2, 3, 3])


In [56]:
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 256),
    torch.nn.BatchNorm1d(256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10) 
)

In [57]:
model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU()
  (4): Linear(in_features=256, out_features=10, bias=True)
)

In [58]:
loss = torch.nn.CrossEntropyLoss() 
trainer = torch.optim.Adam(model.parameters(), lr=0.01)
num_epochs = 10

In [59]:
def train_model(): # функция обучения модели
    for ep in range(num_epochs):
        train_iters, train_passed  = 0, 0 # фиксируем характеристики обучения
        train_loss, train_acc = 0., 0.
        
        model.train()
        for X, y in train: # итерируемся по тренировочной части датасета
            trainer.zero_grad() # сбрасываем градиенты нашей модели
            y_pred = model(X) # получаем предсказание
            l = loss(y_pred, y) # сравниваем с помощью функции потерь
            l.backward() # считаем градиенты
            trainer.step() # делаем шаг
            train_loss += l.item() # фиксируем количество элементов, которые мы прогнали
            train_acc += (y_pred.argmax(dim=1) == y).sum().item() # количество правильно предсказанных элементов
            
            train_iters += 1 # добавляем количество итераций
            train_passed += len(X) # смотрим сколько данных прогнали через нашу модель
            
        
        
        test_iters, test_passed  = 0, 0
        test_loss, test_acc = 0., 0.
        model.eval()
        for X, y in test:
            y_pred = model(X)
            l = loss(y_pred, y)
            test_loss += l.item()
            test_acc += (y_pred.argmax(dim=1) == y).sum().item()
            test_iters += 1
            test_passed += len(X)
            
        print("ep: {}, train_loss: {}, train_acc: {}, test_loss: {}, test_acc: {}".format(
            ep, train_loss / train_iters, train_acc / train_passed,
            test_loss / test_iters, test_acc / test_passed)
        )

In [60]:
train_model()

ep: 0, train_loss: 0.22148384976180943, train_acc: 0.9336666666666666, test_loss: 0.12321777335600928, test_acc: 0.9624
ep: 1, train_loss: 0.09412939180520938, train_acc: 0.97165, test_loss: 0.0950196890727966, test_acc: 0.9701
ep: 2, train_loss: 0.05782979813780873, train_acc: 0.9829333333333333, test_loss: 0.09875793446190073, test_acc: 0.9692
ep: 3, train_loss: 0.037989741112006474, train_acc: 0.9889, test_loss: 0.1010832858748472, test_acc: 0.9697
ep: 4, train_loss: 0.026711100204154216, train_acc: 0.9919666666666667, test_loss: 0.10773993961147425, test_acc: 0.9708
ep: 5, train_loss: 0.022753404461570637, train_acc: 0.99255, test_loss: 0.1250790519295606, test_acc: 0.9697
ep: 6, train_loss: 0.019124501230482486, train_acc: 0.9938666666666667, test_loss: 0.11003976158199294, test_acc: 0.9747
ep: 7, train_loss: 0.016650893367291923, train_acc: 0.9947833333333334, test_loss: 0.10147093100811162, test_acc: 0.9736
ep: 8, train_loss: 0.012375930091861557, train_acc: 0.9960833333333333, 