In [9]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [2]:
transformer = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])])

trainset = torchvision.datasets.CIFAR10(root="./data", train=True,
                                        transform=transformer, download=True)
testset = torchvision.datasets.CIFAR10(root="./data", train=False,
                                       transform=transformer, download=True)

trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

100%|██████████| 170M/170M [00:05<00:00, 29.1MB/s]


In [3]:
print(len(trainset))
print(len(testset))

50000
10000


In [4]:
trainset[0][0].shape # 3 цветовых канала, сами картинки размером 32 на 32 пикселя

torch.Size([3, 32, 32])

In [5]:
class MyOwnNet(nn.Module):
    def __init__(self):
      super(MyOwnNet, self).__init__()
      self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3)
      self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
      self.conv2 = nn.Conv2d(in_channels=12, out_channels=16, kernel_size=3)
      self.conv3 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=3)

      self._to_linear = None
      self._get_conv_size(torch.randn(1, 3, 32, 32))

      self.flatten = nn.Flatten()
      self.fc1 = nn.Linear(self._to_linear, 120)
      self.fc2 = nn.Linear(120, 84)
      self.fc3 = nn.Linear(84, 10)

    def _get_conv_size(self, x):
          x = self.pool(F.relu(self.conv1(x)))
          x = self.pool(F.relu(self.conv2(x)))
          x = self.pool(F.relu(self.conv3(x)))
          self._to_linear = x.flatten(1).shape[1]

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x) # Последний слой без ФА, т.к. в Лосс уже встроен софтмакс
        return x

In [22]:
model = MyOwnNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [23]:
dataloaders = {"train": trainloader, "test": testloader}

In [24]:
model.train()
epochs = 100
accuracy = {"train": [], "test": []}
for epoch in tqdm(range(epochs)):
  for action, dataloader in dataloaders.items():
    epoch_correct = 0
    epoch_all = 0
    for X_batch, y_batch in dataloader:
      X_batch = X_batch.to(device)
      y_batch = y_batch.to(device)
      if action == "train":
        model.train()
        optimizer.zero_grad()
        logits = model(X_batch)
      else:
        model.eval()
        with torch.no_grad():
          logits = model(X_batch)
      y_pred = logits.argmax(-1)
      correct_preds = (y_batch == y_pred).sum()
      epoch_correct += correct_preds.item()
      epoch_all += y_batch.shape[0]
      if action == "train":
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
    if action == "train":
      print(f"Epoch №{epoch + 1}")
    print(f"Loader: {action}, accuracy: {epoch_correct/epoch_all}")
    accuracy[action].append(epoch_correct/epoch_all)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch №1
Loader: train, accuracy: 0.32292


  1%|          | 1/100 [00:23<38:47, 23.51s/it]

Loader: test, accuracy: 0.41654
Epoch №2
Loader: train, accuracy: 0.44932


  2%|▏         | 2/100 [00:46<38:17, 23.44s/it]

Loader: test, accuracy: 0.49136
Epoch №3
Loader: train, accuracy: 0.50728


  3%|▎         | 3/100 [01:10<37:49, 23.40s/it]

Loader: test, accuracy: 0.53154
Epoch №4
Loader: train, accuracy: 0.54074


  4%|▍         | 4/100 [01:33<37:30, 23.44s/it]

Loader: test, accuracy: 0.556
Epoch №5
Loader: train, accuracy: 0.5652


  5%|▌         | 5/100 [01:57<37:00, 23.37s/it]

Loader: test, accuracy: 0.58246
Epoch №6
Loader: train, accuracy: 0.58904


  6%|▌         | 6/100 [02:19<36:04, 23.03s/it]

Loader: test, accuracy: 0.59066
Epoch №7
Loader: train, accuracy: 0.60388


  7%|▋         | 7/100 [02:42<35:47, 23.09s/it]

Loader: test, accuracy: 0.61772
Epoch №8
Loader: train, accuracy: 0.62052


  8%|▊         | 8/100 [03:05<35:22, 23.07s/it]

Loader: test, accuracy: 0.62858
Epoch №9
Loader: train, accuracy: 0.63438


  9%|▉         | 9/100 [03:28<34:55, 23.03s/it]

Loader: test, accuracy: 0.64802
Epoch №10
Loader: train, accuracy: 0.64102


 10%|█         | 10/100 [03:51<34:36, 23.07s/it]

Loader: test, accuracy: 0.65634
Epoch №11
Loader: train, accuracy: 0.6518


 11%|█         | 11/100 [04:13<33:50, 22.81s/it]

Loader: test, accuracy: 0.66752
Epoch №12
Loader: train, accuracy: 0.65974


 12%|█▏        | 12/100 [04:37<33:57, 23.15s/it]

Loader: test, accuracy: 0.67852
Epoch №13
Loader: train, accuracy: 0.66658


 13%|█▎        | 13/100 [05:00<33:33, 23.14s/it]

Loader: test, accuracy: 0.66992
Epoch №14
Loader: train, accuracy: 0.6763


 14%|█▍        | 14/100 [05:23<33:05, 23.08s/it]

Loader: test, accuracy: 0.6906
Epoch №15
Loader: train, accuracy: 0.68276


 15%|█▌        | 15/100 [05:46<32:39, 23.06s/it]

Loader: test, accuracy: 0.70004
Epoch №16
Loader: train, accuracy: 0.68832


 16%|█▌        | 16/100 [06:09<31:54, 22.79s/it]

Loader: test, accuracy: 0.701
Epoch №17
Loader: train, accuracy: 0.69348


 17%|█▋        | 17/100 [06:31<31:31, 22.79s/it]

Loader: test, accuracy: 0.7077
Epoch №18
Loader: train, accuracy: 0.70064


 18%|█▊        | 18/100 [06:54<31:09, 22.80s/it]

Loader: test, accuracy: 0.70722
Epoch №19
Loader: train, accuracy: 0.70564


 19%|█▉        | 19/100 [07:17<30:50, 22.84s/it]

Loader: test, accuracy: 0.7212
Epoch №20
Loader: train, accuracy: 0.71448


 20%|██        | 20/100 [07:40<30:31, 22.89s/it]

Loader: test, accuracy: 0.71466
Epoch №21
Loader: train, accuracy: 0.7142


 21%|██        | 21/100 [08:03<29:56, 22.74s/it]

Loader: test, accuracy: 0.71804
Epoch №22
Loader: train, accuracy: 0.71802


 22%|██▏       | 22/100 [08:25<29:32, 22.72s/it]

Loader: test, accuracy: 0.72672
Epoch №23
Loader: train, accuracy: 0.72066


 23%|██▎       | 23/100 [08:49<29:40, 23.12s/it]

Loader: test, accuracy: 0.73186
Epoch №24
Loader: train, accuracy: 0.72546


 24%|██▍       | 24/100 [09:13<29:20, 23.16s/it]

Loader: test, accuracy: 0.7284
Epoch №25
Loader: train, accuracy: 0.72842


 25%|██▌       | 25/100 [09:36<28:57, 23.17s/it]

Loader: test, accuracy: 0.72842
Epoch №26
Loader: train, accuracy: 0.73096


 26%|██▌       | 26/100 [09:59<28:27, 23.07s/it]

Loader: test, accuracy: 0.7222
Epoch №27
Loader: train, accuracy: 0.73484


 27%|██▋       | 27/100 [10:21<27:47, 22.84s/it]

Loader: test, accuracy: 0.74228
Epoch №28
Loader: train, accuracy: 0.734


 28%|██▊       | 28/100 [10:44<27:32, 22.95s/it]

Loader: test, accuracy: 0.7452
Epoch №29
Loader: train, accuracy: 0.74128


 29%|██▉       | 29/100 [11:07<27:12, 22.99s/it]

Loader: test, accuracy: 0.74712
Epoch №30
Loader: train, accuracy: 0.74348


 30%|███       | 30/100 [11:30<26:52, 23.03s/it]

Loader: test, accuracy: 0.75916
Epoch №31
Loader: train, accuracy: 0.74676


 31%|███       | 31/100 [11:54<26:43, 23.23s/it]

Loader: test, accuracy: 0.7616
Epoch №32
Loader: train, accuracy: 0.74926


 32%|███▏      | 32/100 [12:17<26:17, 23.20s/it]

Loader: test, accuracy: 0.7652
Epoch №33
Loader: train, accuracy: 0.75124


 33%|███▎      | 33/100 [12:40<25:44, 23.05s/it]

Loader: test, accuracy: 0.76564
Epoch №34
Loader: train, accuracy: 0.7539


 34%|███▍      | 34/100 [13:04<25:48, 23.46s/it]

Loader: test, accuracy: 0.76232
Epoch №35
Loader: train, accuracy: 0.75946


 35%|███▌      | 35/100 [13:28<25:23, 23.43s/it]

Loader: test, accuracy: 0.76894
Epoch №36
Loader: train, accuracy: 0.7588


 36%|███▌      | 36/100 [13:51<24:57, 23.39s/it]

Loader: test, accuracy: 0.75734
Epoch №37
Loader: train, accuracy: 0.76282


 37%|███▋      | 37/100 [14:14<24:35, 23.42s/it]

Loader: test, accuracy: 0.78156
Epoch №38
Loader: train, accuracy: 0.76276


 38%|███▊      | 38/100 [14:38<24:09, 23.38s/it]

Loader: test, accuracy: 0.77136
Epoch №39
Loader: train, accuracy: 0.76594


 39%|███▉      | 39/100 [15:00<23:21, 22.98s/it]

Loader: test, accuracy: 0.77396
Epoch №40
Loader: train, accuracy: 0.77104


 40%|████      | 40/100 [15:23<23:02, 23.04s/it]

Loader: test, accuracy: 0.77774
Epoch №41
Loader: train, accuracy: 0.76976


 41%|████      | 41/100 [15:46<22:45, 23.15s/it]

Loader: test, accuracy: 0.78606
Epoch №42
Loader: train, accuracy: 0.77042


 42%|████▏     | 42/100 [16:10<22:27, 23.23s/it]

Loader: test, accuracy: 0.77876
Epoch №43
Loader: train, accuracy: 0.77408


 43%|████▎     | 43/100 [16:33<22:05, 23.25s/it]

Loader: test, accuracy: 0.78334
Epoch №44
Loader: train, accuracy: 0.77762


 44%|████▍     | 44/100 [16:56<21:43, 23.28s/it]

Loader: test, accuracy: 0.78682
Epoch №45
Loader: train, accuracy: 0.78032


 45%|████▌     | 45/100 [17:20<21:25, 23.38s/it]

Loader: test, accuracy: 0.79198
Epoch №46
Loader: train, accuracy: 0.77916


 46%|████▌     | 46/100 [17:43<20:51, 23.18s/it]

Loader: test, accuracy: 0.79592
Epoch №47
Loader: train, accuracy: 0.78328


 47%|████▋     | 47/100 [18:05<20:21, 23.05s/it]

Loader: test, accuracy: 0.79434
Epoch №48
Loader: train, accuracy: 0.78376


 48%|████▊     | 48/100 [18:28<19:57, 23.02s/it]

Loader: test, accuracy: 0.79712
Epoch №49
Loader: train, accuracy: 0.78736


 49%|████▉     | 49/100 [18:51<19:31, 22.98s/it]

Loader: test, accuracy: 0.80528
Epoch №50
Loader: train, accuracy: 0.78876


 50%|█████     | 50/100 [19:13<18:55, 22.71s/it]

Loader: test, accuracy: 0.81204
Epoch №51
Loader: train, accuracy: 0.79046


 51%|█████     | 51/100 [19:36<18:35, 22.76s/it]

Loader: test, accuracy: 0.78778
Epoch №52
Loader: train, accuracy: 0.79112


 52%|█████▏    | 52/100 [19:59<18:17, 22.86s/it]

Loader: test, accuracy: 0.8109
Epoch №53
Loader: train, accuracy: 0.79154


 53%|█████▎    | 53/100 [20:22<17:54, 22.85s/it]

Loader: test, accuracy: 0.80148
Epoch №54
Loader: train, accuracy: 0.79412


 54%|█████▍    | 54/100 [20:45<17:32, 22.88s/it]

Loader: test, accuracy: 0.80912
Epoch №55
Loader: train, accuracy: 0.79702


 55%|█████▌    | 55/100 [21:07<16:56, 22.60s/it]

Loader: test, accuracy: 0.79816
Epoch №56
Loader: train, accuracy: 0.8002


 56%|█████▌    | 56/100 [21:30<16:38, 22.69s/it]

Loader: test, accuracy: 0.81342
Epoch №57
Loader: train, accuracy: 0.79816


 57%|█████▋    | 57/100 [21:54<16:33, 23.11s/it]

Loader: test, accuracy: 0.81364
Epoch №58
Loader: train, accuracy: 0.79984


 58%|█████▊    | 58/100 [22:17<16:08, 23.07s/it]

Loader: test, accuracy: 0.81356
Epoch №59
Loader: train, accuracy: 0.8017


 59%|█████▉    | 59/100 [22:40<15:44, 23.03s/it]

Loader: test, accuracy: 0.80128
Epoch №60
Loader: train, accuracy: 0.80458


 60%|██████    | 60/100 [23:02<15:07, 22.69s/it]

Loader: test, accuracy: 0.82038
Epoch №61
Loader: train, accuracy: 0.80598


 61%|██████    | 61/100 [23:25<14:47, 22.77s/it]

Loader: test, accuracy: 0.8229
Epoch №62
Loader: train, accuracy: 0.80572


 62%|██████▏   | 62/100 [23:48<14:27, 22.83s/it]

Loader: test, accuracy: 0.80356
Epoch №63
Loader: train, accuracy: 0.81026


 63%|██████▎   | 63/100 [24:11<14:10, 22.97s/it]

Loader: test, accuracy: 0.82992
Epoch №64
Loader: train, accuracy: 0.80894


 64%|██████▍   | 64/100 [24:34<13:48, 23.01s/it]

Loader: test, accuracy: 0.81708
Epoch №65
Loader: train, accuracy: 0.81224


 65%|██████▌   | 65/100 [24:56<13:17, 22.78s/it]

Loader: test, accuracy: 0.81636
Epoch №66
Loader: train, accuracy: 0.8124


 66%|██████▌   | 66/100 [25:19<12:55, 22.82s/it]

Loader: test, accuracy: 0.82868
Epoch №67
Loader: train, accuracy: 0.813


 67%|██████▋   | 67/100 [25:42<12:34, 22.86s/it]

Loader: test, accuracy: 0.8259
Epoch №68
Loader: train, accuracy: 0.81682


 68%|██████▊   | 68/100 [26:06<12:24, 23.26s/it]

Loader: test, accuracy: 0.82502
Epoch №69
Loader: train, accuracy: 0.81466


 69%|██████▉   | 69/100 [26:31<12:09, 23.54s/it]

Loader: test, accuracy: 0.82336
Epoch №70
Loader: train, accuracy: 0.81782


 70%|███████   | 70/100 [26:54<11:41, 23.39s/it]

Loader: test, accuracy: 0.81468
Epoch №71
Loader: train, accuracy: 0.81948


 71%|███████   | 71/100 [27:16<11:12, 23.20s/it]

Loader: test, accuracy: 0.84338
Epoch №72
Loader: train, accuracy: 0.81956


 72%|███████▏  | 72/100 [27:40<10:49, 23.19s/it]

Loader: test, accuracy: 0.82842
Epoch №73
Loader: train, accuracy: 0.82132


 73%|███████▎  | 73/100 [28:03<10:27, 23.24s/it]

Loader: test, accuracy: 0.842
Epoch №74
Loader: train, accuracy: 0.8228


 74%|███████▍  | 74/100 [28:26<10:06, 23.32s/it]

Loader: test, accuracy: 0.83298
Epoch №75
Loader: train, accuracy: 0.82488


 75%|███████▌  | 75/100 [28:50<09:45, 23.42s/it]

Loader: test, accuracy: 0.83742
Epoch №76
Loader: train, accuracy: 0.82296


 76%|███████▌  | 76/100 [29:13<09:19, 23.33s/it]

Loader: test, accuracy: 0.83732
Epoch №77
Loader: train, accuracy: 0.82328


 77%|███████▋  | 77/100 [29:36<08:55, 23.27s/it]

Loader: test, accuracy: 0.84622
Epoch №78
Loader: train, accuracy: 0.82646


 78%|███████▊  | 78/100 [30:00<08:34, 23.39s/it]

Loader: test, accuracy: 0.83626
Epoch №79
Loader: train, accuracy: 0.82618


 79%|███████▉  | 79/100 [30:25<08:18, 23.75s/it]

Loader: test, accuracy: 0.8414
Epoch №80
Loader: train, accuracy: 0.83096


 80%|████████  | 80/100 [30:48<07:52, 23.64s/it]

Loader: test, accuracy: 0.84516
Epoch №81
Loader: train, accuracy: 0.82944


 81%|████████  | 81/100 [31:11<07:27, 23.56s/it]

Loader: test, accuracy: 0.84214
Epoch №82
Loader: train, accuracy: 0.83054


 82%|████████▏ | 82/100 [31:35<07:03, 23.55s/it]

Loader: test, accuracy: 0.84886
Epoch №83
Loader: train, accuracy: 0.83184


 83%|████████▎ | 83/100 [31:58<06:36, 23.30s/it]

Loader: test, accuracy: 0.84538
Epoch №84
Loader: train, accuracy: 0.83336


 84%|████████▍ | 84/100 [32:21<06:13, 23.34s/it]

Loader: test, accuracy: 0.84066
Epoch №85
Loader: train, accuracy: 0.83374


 85%|████████▌ | 85/100 [32:45<05:52, 23.48s/it]

Loader: test, accuracy: 0.83874
Epoch №86
Loader: train, accuracy: 0.83382


 86%|████████▌ | 86/100 [33:09<05:29, 23.53s/it]

Loader: test, accuracy: 0.85364
Epoch №87
Loader: train, accuracy: 0.83666


 87%|████████▋ | 87/100 [33:33<05:07, 23.67s/it]

Loader: test, accuracy: 0.84982
Epoch №88
Loader: train, accuracy: 0.83306


 88%|████████▊ | 88/100 [33:56<04:44, 23.68s/it]

Loader: test, accuracy: 0.84456
Epoch №89
Loader: train, accuracy: 0.83644


 89%|████████▉ | 89/100 [34:20<04:22, 23.82s/it]

Loader: test, accuracy: 0.85378
Epoch №90
Loader: train, accuracy: 0.8376


 90%|█████████ | 90/100 [34:44<03:57, 23.72s/it]

Loader: test, accuracy: 0.83842
Epoch №91
Loader: train, accuracy: 0.8402


 91%|█████████ | 91/100 [35:08<03:33, 23.71s/it]

Loader: test, accuracy: 0.8493
Epoch №92
Loader: train, accuracy: 0.84192


 92%|█████████▏| 92/100 [35:31<03:09, 23.75s/it]

Loader: test, accuracy: 0.85906
Epoch №93
Loader: train, accuracy: 0.84062


 93%|█████████▎| 93/100 [35:55<02:46, 23.76s/it]

Loader: test, accuracy: 0.85628
Epoch №94
Loader: train, accuracy: 0.84484


 94%|█████████▍| 94/100 [36:19<02:22, 23.81s/it]

Loader: test, accuracy: 0.86342
Epoch №95
Loader: train, accuracy: 0.84112


 95%|█████████▌| 95/100 [36:43<01:58, 23.77s/it]

Loader: test, accuracy: 0.85104
Epoch №96
Loader: train, accuracy: 0.84214


 96%|█████████▌| 96/100 [37:06<01:34, 23.66s/it]

Loader: test, accuracy: 0.8542
Epoch №97
Loader: train, accuracy: 0.84132


 97%|█████████▋| 97/100 [37:29<01:10, 23.51s/it]

Loader: test, accuracy: 0.86286
Epoch №98
Loader: train, accuracy: 0.8465


 98%|█████████▊| 98/100 [37:53<00:47, 23.63s/it]

Loader: test, accuracy: 0.86558
Epoch №99
Loader: train, accuracy: 0.84598


 99%|█████████▉| 99/100 [38:17<00:23, 23.64s/it]

Loader: test, accuracy: 0.85734
Epoch №100
Loader: train, accuracy: 0.84482


100%|██████████| 100/100 [38:42<00:00, 23.22s/it]

Loader: test, accuracy: 0.85814





In [25]:
torch.save(model.state_dict(), 'CIFAR10_v1_100_EPOCHS.pth')