## Фреймворк PyTorch для разработки искусственных нейронных сетей.

### Урок 2. Feed-forward neural network.

**1. Сделаем необходимые импорты**

In [3]:
import numpy as np

import torch
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch import nn

from tqdm import tqdm

**2. Загрузим датасет CIFAR-100, сразу же создадим dataloader для него. Если вам не хватает вычислительных ресурсов, то можно вернуться к CIFAR-10.**

In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

BATCH_SIZE = 4

train_set = torchvision.datasets.CIFAR100(root='./temp_data', train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR100(root='./temp_data', train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(test_set, batch_size = BATCH_SIZE, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./temp_data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./temp_data/cifar-100-python.tar.gz to ./temp_data
Files already downloaded and verified


In [5]:
def get_img_vector_size(args)->int:
    dataiter = iter(train_loader)
    images, labels = dataiter.next()
    size = images[0].size()
    return size[0] * size[1] * size[2]

**3.  Создайте собственную архитектуру! Можете использовать все, что угодно, но в ограничении на использование линейные слои (пока без сверток). Давайте добавим ограниченный Leaky_relu, то есть output = max(0.1x, 0.5x). Ваша задача добавить его в архитектуру сети как функцию активации.**

In [6]:
class CustomLeakyRelu(nn.Module):
    def __init__(self, border:float, alpha:float) -> None:
        super().__init__()
        self.border = border
        self.alpha = alpha
        
    def forward(self, input):
        x = F.leaky_relu(input)
        return torch.where(x>self.border, x*self.alpha, x)

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        img_vector_size = 3072
        self.fc1 = nn.Linear(img_vector_size,128)
        self.fc2 = nn.Linear(128,64)
        self.fc3 = nn.Linear(64,100)
        self.cust_relu = CustomLeakyRelu(0.1, 0.5)
        
    def forward(self, x):
        x = x.view(x.shape[0], -1)        
        x = self.cust_relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def predict(self, x):
        x = x.view(x.shape[0], -1)
        x = self.cust_relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x
        
net = Net()

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

**4. Запустить обучение (по аналогии с тем, что делали на паре).**

In [7]:
for epoch in tqdm(range(10)):  
    running_loss = 0.0 
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0], data[1]

        optimizer.zero_grad() # Обнуление градиента

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() # Вывод статистики о процессе обучения
        if i % 300 == 0:    # Вывод каждых 300 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('End of training.')

  0%|          | 0/10 [00:00<?, ?it/s]

[1,     1] loss: 0.002
[1,   301] loss: 0.685
[1,   601] loss: 0.651
[1,   901] loss: 0.636
[1,  1201] loss: 0.629
[1,  1501] loss: 0.619
[1,  1801] loss: 0.605
[1,  2101] loss: 0.602
[1,  2401] loss: 0.606
[1,  2701] loss: 0.604
[1,  3001] loss: 0.594
[1,  3301] loss: 0.598
[1,  3601] loss: 0.602
[1,  3901] loss: 0.582
[1,  4201] loss: 0.582
[1,  4501] loss: 0.579
[1,  4801] loss: 0.583
[1,  5101] loss: 0.588
[1,  5401] loss: 0.591
[1,  5701] loss: 0.579
[1,  6001] loss: 0.574
[1,  6301] loss: 0.578
[1,  6601] loss: 0.577
[1,  6901] loss: 0.590
[1,  7201] loss: 0.570
[1,  7501] loss: 0.570
[1,  7801] loss: 0.572
[1,  8101] loss: 0.567
[1,  8401] loss: 0.575
[1,  8701] loss: 0.564
[1,  9001] loss: 0.567
[1,  9301] loss: 0.563
[1,  9601] loss: 0.561
[1,  9901] loss: 0.572
[1, 10201] loss: 0.571
[1, 10501] loss: 0.567
[1, 10801] loss: 0.558
[1, 11101] loss: 0.569
[1, 11401] loss: 0.554
[1, 11701] loss: 0.563
[1, 12001] loss: 0.550
[1, 12301] loss: 0.572


 10%|█         | 1/10 [01:44<15:37, 104.13s/it]

[2,     1] loss: 0.002
[2,   301] loss: 0.548
[2,   601] loss: 0.555
[2,   901] loss: 0.550
[2,  1201] loss: 0.557
[2,  1501] loss: 0.562
[2,  1801] loss: 0.557
[2,  2101] loss: 0.541
[2,  2401] loss: 0.541
[2,  2701] loss: 0.558
[2,  3001] loss: 0.544
[2,  3301] loss: 0.554
[2,  3601] loss: 0.556
[2,  3901] loss: 0.543
[2,  4201] loss: 0.551
[2,  4501] loss: 0.556
[2,  4801] loss: 0.547
[2,  5101] loss: 0.553
[2,  5401] loss: 0.543
[2,  5701] loss: 0.552
[2,  6001] loss: 0.553
[2,  6301] loss: 0.549
[2,  6601] loss: 0.544
[2,  6901] loss: 0.556
[2,  7201] loss: 0.551
[2,  7501] loss: 0.556
[2,  7801] loss: 0.535
[2,  8101] loss: 0.550
[2,  8401] loss: 0.542
[2,  8701] loss: 0.553
[2,  9001] loss: 0.562
[2,  9301] loss: 0.559
[2,  9601] loss: 0.558
[2,  9901] loss: 0.551
[2, 10201] loss: 0.551
[2, 10501] loss: 0.559
[2, 10801] loss: 0.560
[2, 11101] loss: 0.555
[2, 11401] loss: 0.546
[2, 11701] loss: 0.543
[2, 12001] loss: 0.540
[2, 12301] loss: 0.556


 20%|██        | 2/10 [03:26<13:42, 102.87s/it]

[3,     1] loss: 0.001
[3,   301] loss: 0.539
[3,   601] loss: 0.529
[3,   901] loss: 0.561
[3,  1201] loss: 0.539
[3,  1501] loss: 0.541
[3,  1801] loss: 0.536
[3,  2101] loss: 0.543
[3,  2401] loss: 0.549
[3,  2701] loss: 0.530
[3,  3001] loss: 0.540
[3,  3301] loss: 0.537
[3,  3601] loss: 0.534
[3,  3901] loss: 0.543
[3,  4201] loss: 0.548
[3,  4501] loss: 0.544
[3,  4801] loss: 0.539
[3,  5101] loss: 0.537
[3,  5401] loss: 0.545
[3,  5701] loss: 0.541
[3,  6001] loss: 0.540
[3,  6301] loss: 0.532
[3,  6601] loss: 0.539
[3,  6901] loss: 0.540
[3,  7201] loss: 0.554
[3,  7501] loss: 0.534
[3,  7801] loss: 0.541
[3,  8101] loss: 0.540
[3,  8401] loss: 0.536
[3,  8701] loss: 0.548
[3,  9001] loss: 0.543
[3,  9301] loss: 0.546
[3,  9601] loss: 0.542
[3,  9901] loss: 0.539
[3, 10201] loss: 0.554
[3, 10501] loss: 0.544
[3, 10801] loss: 0.535
[3, 11101] loss: 0.542
[3, 11401] loss: 0.551
[3, 11701] loss: 0.553
[3, 12001] loss: 0.547
[3, 12301] loss: 0.536


 30%|███       | 3/10 [05:08<11:59, 102.83s/it]

[4,     1] loss: 0.002
[4,   301] loss: 0.512
[4,   601] loss: 0.529
[4,   901] loss: 0.538
[4,  1201] loss: 0.531
[4,  1501] loss: 0.529
[4,  1801] loss: 0.540
[4,  2101] loss: 0.531
[4,  2401] loss: 0.539
[4,  2701] loss: 0.541
[4,  3001] loss: 0.529
[4,  3301] loss: 0.532
[4,  3601] loss: 0.540
[4,  3901] loss: 0.545
[4,  4201] loss: 0.527
[4,  4501] loss: 0.534
[4,  4801] loss: 0.532
[4,  5101] loss: 0.541
[4,  5401] loss: 0.548
[4,  5701] loss: 0.519
[4,  6001] loss: 0.543
[4,  6301] loss: 0.548
[4,  6601] loss: 0.526
[4,  6901] loss: 0.534
[4,  7201] loss: 0.552
[4,  7501] loss: 0.537
[4,  7801] loss: 0.544
[4,  8101] loss: 0.536
[4,  8401] loss: 0.537
[4,  8701] loss: 0.544
[4,  9001] loss: 0.527
[4,  9301] loss: 0.543
[4,  9601] loss: 0.544
[4,  9901] loss: 0.546
[4, 10201] loss: 0.548
[4, 10501] loss: 0.534
[4, 10801] loss: 0.541
[4, 11101] loss: 0.546
[4, 11401] loss: 0.544
[4, 11701] loss: 0.549
[4, 12001] loss: 0.534
[4, 12301] loss: 0.539


 40%|████      | 4/10 [06:57<10:30, 105.15s/it]

[5,     1] loss: 0.002
[5,   301] loss: 0.516
[5,   601] loss: 0.539
[5,   901] loss: 0.518
[5,  1201] loss: 0.529
[5,  1501] loss: 0.521
[5,  1801] loss: 0.525
[5,  2101] loss: 0.528
[5,  2401] loss: 0.519
[5,  2701] loss: 0.541
[5,  3001] loss: 0.529
[5,  3301] loss: 0.541
[5,  3601] loss: 0.533
[5,  3901] loss: 0.542
[5,  4201] loss: 0.533
[5,  4501] loss: 0.531
[5,  4801] loss: 0.534
[5,  5101] loss: 0.543
[5,  5401] loss: 0.542
[5,  5701] loss: 0.525
[5,  6001] loss: 0.533
[5,  6301] loss: 0.518
[5,  6601] loss: 0.547
[5,  6901] loss: 0.549
[5,  7201] loss: 0.535
[5,  7501] loss: 0.545
[5,  7801] loss: 0.534
[5,  8101] loss: 0.534
[5,  8401] loss: 0.521
[5,  8701] loss: 0.537
[5,  9001] loss: 0.536
[5,  9301] loss: 0.545
[5,  9601] loss: 0.527
[5,  9901] loss: 0.541
[5, 10201] loss: 0.529
[5, 10501] loss: 0.532
[5, 10801] loss: 0.548
[5, 11101] loss: 0.535
[5, 11401] loss: 0.530
[5, 11701] loss: 0.539
[5, 12001] loss: 0.541
[5, 12301] loss: 0.525


 50%|█████     | 5/10 [08:43<08:46, 105.33s/it]

[6,     1] loss: 0.002
[6,   301] loss: 0.516
[6,   601] loss: 0.524
[6,   901] loss: 0.525
[6,  1201] loss: 0.519
[6,  1501] loss: 0.534
[6,  1801] loss: 0.535
[6,  2101] loss: 0.519
[6,  2401] loss: 0.513
[6,  2701] loss: 0.532
[6,  3001] loss: 0.537
[6,  3301] loss: 0.534
[6,  3601] loss: 0.524
[6,  3901] loss: 0.531
[6,  4201] loss: 0.545
[6,  4501] loss: 0.524
[6,  4801] loss: 0.541
[6,  5101] loss: 0.544
[6,  5401] loss: 0.544
[6,  5701] loss: 0.522
[6,  6001] loss: 0.523
[6,  6301] loss: 0.521
[6,  6601] loss: 0.546
[6,  6901] loss: 0.533
[6,  7201] loss: 0.530
[6,  7501] loss: 0.536
[6,  7801] loss: 0.529
[6,  8101] loss: 0.521
[6,  8401] loss: 0.546
[6,  8701] loss: 0.532
[6,  9001] loss: 0.518
[6,  9301] loss: 0.537
[6,  9601] loss: 0.538
[6,  9901] loss: 0.525
[6, 10201] loss: 0.554
[6, 10501] loss: 0.535
[6, 10801] loss: 0.534
[6, 11101] loss: 0.537
[6, 11401] loss: 0.529
[6, 11701] loss: 0.536
[6, 12001] loss: 0.535
[6, 12301] loss: 0.536


 60%|██████    | 6/10 [10:29<07:02, 105.63s/it]

[7,     1] loss: 0.001
[7,   301] loss: 0.517
[7,   601] loss: 0.516
[7,   901] loss: 0.521
[7,  1201] loss: 0.507
[7,  1501] loss: 0.517
[7,  1801] loss: 0.521
[7,  2101] loss: 0.529
[7,  2401] loss: 0.535
[7,  2701] loss: 0.524
[7,  3001] loss: 0.527
[7,  3301] loss: 0.528
[7,  3601] loss: 0.517
[7,  3901] loss: 0.536
[7,  4201] loss: 0.534
[7,  4501] loss: 0.533
[7,  4801] loss: 0.526
[7,  5101] loss: 0.522
[7,  5401] loss: 0.527
[7,  5701] loss: 0.522
[7,  6001] loss: 0.530
[7,  6301] loss: 0.527
[7,  6601] loss: 0.520
[7,  6901] loss: 0.524
[7,  7201] loss: 0.538
[7,  7501] loss: 0.528
[7,  7801] loss: 0.515
[7,  8101] loss: 0.536
[7,  8401] loss: 0.536
[7,  8701] loss: 0.527
[7,  9001] loss: 0.520
[7,  9301] loss: 0.521
[7,  9601] loss: 0.531
[7,  9901] loss: 0.540
[7, 10201] loss: 0.526
[7, 10501] loss: 0.526
[7, 10801] loss: 0.542
[7, 11101] loss: 0.543
[7, 11401] loss: 0.537
[7, 11701] loss: 0.530
[7, 12001] loss: 0.531
[7, 12301] loss: 0.539


 70%|███████   | 7/10 [12:22<05:23, 107.92s/it]

[8,     1] loss: 0.002
[8,   301] loss: 0.512
[8,   601] loss: 0.533
[8,   901] loss: 0.512
[8,  1201] loss: 0.528
[8,  1501] loss: 0.519
[8,  1801] loss: 0.522
[8,  2101] loss: 0.517
[8,  2401] loss: 0.514
[8,  2701] loss: 0.519
[8,  3001] loss: 0.518
[8,  3301] loss: 0.516
[8,  3601] loss: 0.521
[8,  3901] loss: 0.517
[8,  4201] loss: 0.537
[8,  4501] loss: 0.535
[8,  4801] loss: 0.524
[8,  5101] loss: 0.545
[8,  5401] loss: 0.527
[8,  5701] loss: 0.526
[8,  6001] loss: 0.528
[8,  6301] loss: 0.540
[8,  6601] loss: 0.519
[8,  6901] loss: 0.535
[8,  7201] loss: 0.526
[8,  7501] loss: 0.533
[8,  7801] loss: 0.521
[8,  8101] loss: 0.528
[8,  8401] loss: 0.526
[8,  8701] loss: 0.525
[8,  9001] loss: 0.532
[8,  9301] loss: 0.528
[8,  9601] loss: 0.520
[8,  9901] loss: 0.517
[8, 10201] loss: 0.530
[8, 10501] loss: 0.535
[8, 10801] loss: 0.537
[8, 11101] loss: 0.532
[8, 11401] loss: 0.537
[8, 11701] loss: 0.532
[8, 12001] loss: 0.537
[8, 12301] loss: 0.530


 80%|████████  | 8/10 [14:10<03:36, 108.18s/it]

[9,     1] loss: 0.002
[9,   301] loss: 0.515
[9,   601] loss: 0.520
[9,   901] loss: 0.526
[9,  1201] loss: 0.529
[9,  1501] loss: 0.515
[9,  1801] loss: 0.519
[9,  2101] loss: 0.543
[9,  2401] loss: 0.522
[9,  2701] loss: 0.529
[9,  3001] loss: 0.521
[9,  3301] loss: 0.522
[9,  3601] loss: 0.525
[9,  3901] loss: 0.525
[9,  4201] loss: 0.508
[9,  4501] loss: 0.531
[9,  4801] loss: 0.529
[9,  5101] loss: 0.514
[9,  5401] loss: 0.520
[9,  5701] loss: 0.520
[9,  6001] loss: 0.526
[9,  6301] loss: 0.525
[9,  6601] loss: 0.530
[9,  6901] loss: 0.519
[9,  7201] loss: 0.542
[9,  7501] loss: 0.527
[9,  7801] loss: 0.532
[9,  8101] loss: 0.543
[9,  8401] loss: 0.534
[9,  8701] loss: 0.525
[9,  9001] loss: 0.524
[9,  9301] loss: 0.528
[9,  9601] loss: 0.524
[9,  9901] loss: 0.525
[9, 10201] loss: 0.535
[9, 10501] loss: 0.531
[9, 10801] loss: 0.536
[9, 11101] loss: 0.531
[9, 11401] loss: 0.532
[9, 11701] loss: 0.531
[9, 12001] loss: 0.543
[9, 12301] loss: 0.530


 90%|█████████ | 9/10 [16:00<01:48, 108.61s/it]

[10,     1] loss: 0.002
[10,   301] loss: 0.515
[10,   601] loss: 0.520
[10,   901] loss: 0.515
[10,  1201] loss: 0.523
[10,  1501] loss: 0.513
[10,  1801] loss: 0.508
[10,  2101] loss: 0.518
[10,  2401] loss: 0.517
[10,  2701] loss: 0.511
[10,  3001] loss: 0.524
[10,  3301] loss: 0.530
[10,  3601] loss: 0.522
[10,  3901] loss: 0.523
[10,  4201] loss: 0.524
[10,  4501] loss: 0.528
[10,  4801] loss: 0.532
[10,  5101] loss: 0.526
[10,  5401] loss: 0.526
[10,  5701] loss: 0.534
[10,  6001] loss: 0.523
[10,  6301] loss: 0.515
[10,  6601] loss: 0.521
[10,  6901] loss: 0.531
[10,  7201] loss: 0.519
[10,  7501] loss: 0.527
[10,  7801] loss: 0.530
[10,  8101] loss: 0.533
[10,  8401] loss: 0.538
[10,  8701] loss: 0.535
[10,  9001] loss: 0.530
[10,  9301] loss: 0.530
[10,  9601] loss: 0.522
[10,  9901] loss: 0.528
[10, 10201] loss: 0.523
[10, 10501] loss: 0.545
[10, 10801] loss: 0.525
[10, 11101] loss: 0.526
[10, 11401] loss: 0.517
[10, 11701] loss: 0.535
[10, 12001] loss: 0.512
[10, 12301] loss

100%|██████████| 10/10 [17:46<00:00, 106.68s/it]

End of training.



