## В данном ноутбуке будет описан полный цикл обучения нейронной сети на PyTorch

In [None]:
# Fetch the course helper modules (util.py, mnist.py) from the
# yandexdataschool/Practical_DL repository, pinned to a fixed commit.
# The `!` lines are IPython shell escapes and only work inside a notebook.
!wget https://raw.githubusercontent.com/yandexdataschool/Practical_DL/35c067adcc1ab364c8803830cdb34d0d50eea37e/week01_backprop/util.py -O util.py
!wget https://raw.githubusercontent.com/yandexdataschool/Practical_DL/35c067adcc1ab364c8803830cdb34d0d50eea37e/week01_backprop/mnist.py -O mnist.py
# No-op on Python 3; only changes `print` semantics on Python 2.
from __future__ import print_function
import numpy as np
# Seed NumPy's global RNG for reproducibility.
# NOTE(review): this does not seed PyTorch's RNG, which is what initialises
# the network weights created later in the notebook.
np.random.seed(42)

In [None]:
import matplotlib.pyplot as plt
#%matplotlib inline

from mnist import load_dataset
X_train, y_train, X_val, y_val, X_test, y_test = load_dataset(flatten=True)

# Show the first four training images (stored flattened) in a 2x2 grid,
# each titled with its label.
plt.figure(figsize=[6, 6])
for position in range(1, 5):
    sample = position - 1
    plt.subplot(2, 2, position)
    plt.title("Label: %i" % y_train[sample])
    plt.imshow(X_train[sample].reshape([28, 28]), cmap='gray')

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over paired feature / label arrays.

    Each item is a dict with:
      * ``'input'``  - float32 tensor copied from ``X[idx]``;
      * ``'target'`` - int64 tensor of shape ``(1,)`` holding ``y[idx]``
        (so a DataLoader batch of targets has shape ``(batch, 1)``).

    Subclasses ``torch.utils.data.Dataset`` rather than ``TensorDataset``:
    ``TensorDataset`` expects to be initialised with tensors and to store them
    in ``self.tensors``, which this class never does.
    """

    def __init__(self, X, y):
        """Wrap features ``X`` and labels ``y``; both must have equal length.

        Raises:
            ValueError: if ``len(X) != len(y)``.
        """
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return sample ``idx`` as an ``{'input', 'target'}`` dict of tensors."""
        # torch.tensor always copies, so callers cannot mutate X through the
        # returned tensor; the explicit dtype keeps inputs float32 even if X
        # is float64.
        return {
            'input': torch.tensor(self.X[idx], dtype=torch.float32),
            'target': torch.tensor([self.y[idx]], dtype=torch.long),
        }
    
train_dataset = Dataset(X_train, y_train)
val_dataset = Dataset(X_val, y_val)


# Reshuffle the training set every epoch so mini-batches differ between
# epochs, as stochastic gradient descent expects. Validation order does not
# affect accuracy, so it stays fixed.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
)

In [None]:
import torch
from torch import nn
from torch import optim

class Net(nn.Module):
    """Three-layer fully connected classifier.

    ``config`` must provide the integer sizes ``'input_dim'``, ``'l1_out'``,
    ``'l2_out'`` and ``'output_dim'``. The two hidden layers are followed by
    ReLU; the last layer returns raw logits (no softmax).
    """

    def __init__(self, config):
        super().__init__()
        d_in = config['input_dim']
        hidden1 = config['l1_out']
        hidden2 = config['l2_out']
        d_out = config['output_dim']
        # Layers are built first-to-last, so weight initialisation consumes
        # torch's RNG in that order.
        self.linear1 = nn.Linear(d_in, hidden1)
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.linear3 = nn.Linear(hidden2, d_out)

    def forward(self, x):
        """Map a batch of feature vectors to unnormalised class scores."""
        hidden = self.linear2(self.linear1(x).relu()).relu()
        return self.linear3(hidden)
    
# Layer sizes for Net: 784 = 28 * 28 flattened pixels in, 10 class scores out.
config = {
    'input_dim': 784,
    'l1_out': 200,
    'l2_out': 100,
    'output_dim': 10
}

# Seed PyTorch's global RNG so the initial weights (and anything else drawn
# from torch's RNG afterwards) are reproducible, mirroring np.random.seed(42).
torch.manual_seed(42)
net = Net(config)
# lr=1e-3 is Adam's default, spelled out so it is visible and easy to tune.
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
# Takes raw logits and integer class indices.
loss_func = nn.CrossEntropyLoss()

In [None]:
def count_metric(net, loader):
    """Return the classification accuracy of ``net`` over ``loader``.

    Args:
        net: module mapping ``batch['input']`` to logits whose last dimension
            indexes classes.
        loader: iterable of dict batches with ``'input'`` and ``'target'``
            tensors; ``'target'`` carries a trailing singleton dimension that
            is squeezed before comparison.

    Returns:
        Fraction of samples whose arg-max logit equals the target.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.

    The pass runs in eval mode without autograd. Afterwards the network is
    restored to whatever train/eval mode it was in before the call.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        # Inference only: don't build autograd graphs we never backpropagate.
        with torch.no_grad():
            for batch in loader:
                logits = net(batch['input'])
                predictions = logits.argmax(dim=-1)
                targets = batch['target'].squeeze(-1)
                correct += (predictions == targets).sum().item()
                total += len(logits)
    finally:
        net.train(was_training)
    return correct / total

In [None]:
from IPython.display import clear_output
# Per-epoch accuracy histories; two separate list objects.
train_log, val_log = [], []

In [None]:
NUM_EPOCHS = 5

for epoch in range(NUM_EPOCHS):
    # One optimisation pass over the training data.
    net.train()
    for batch in train_loader:
        logits = net(batch['input'])
        # Dataset items carry targets of shape (1,), so batched targets are
        # (batch, 1); CrossEntropyLoss wants class indices of shape (batch,).
        loss = loss_func(logits, batch['target'].squeeze(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Re-measure accuracy on both splits after the epoch.
    train_log.append(count_metric(net, train_loader))
    val_log.append(count_metric(net, val_loader))

    clear_output()
    print("Epoch", epoch + 1)
    print("Train accuracy:", train_log[-1])
    print("Val accuracy:", val_log[-1])

    # Plot against 1-based epoch numbers so the x-axis matches the printout.
    epochs_seen = range(1, len(train_log) + 1)
    plt.plot(epochs_seen, train_log, label='train accuracy')
    plt.plot(epochs_seen, val_log, label='val accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.legend(loc='best')
    plt.grid()
    plt.show()
    