In [11]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [12]:
transformer = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])])

trainset = torchvision.datasets.CIFAR10(root="./data", train=True,
                                        transform=transformer, download=True)
testset = torchvision.datasets.CIFAR10(root="./data", train=False,
                                       transform=transformer, download=True)

trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

In [13]:
print(len(trainset))
print(len(testset))

50000
10000


In [14]:
trainset[0][0].shape # 3 цветовых канала, сами картинки размером 32 на 32 пикселя

torch.Size([3, 32, 32])

In [15]:
class MyOwnNet(nn.Module):
    def __init__(self):
      super(MyOwnNet, self).__init__()
      self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3)
      self.bn1 = nn.BatchNorm2d(12)
      self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
      self.conv2 = nn.Conv2d(in_channels=12, out_channels=16, kernel_size=3)
      self.bn2 = nn.BatchNorm2d(16)
      self.conv3 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=3)
      self.bn3 = nn.BatchNorm2d(24)

      self.dropout = nn.Dropout(p=0.3)
      self._to_linear = None
      self._get_conv_size(torch.randn(1, 3, 32, 32))

      self.flatten = nn.Flatten()
      self.fc1 = nn.Linear(self._to_linear, 120)
      self.fc2 = nn.Linear(120, 84)
      self.fc3 = nn.Linear(84, 10)

    def _get_conv_size(self, x):
          x = self.pool(F.relu(self.bn1(self.conv1(x))))
          x = self.pool(F.relu(self.bn2(self.conv2(x))))
          x = self.pool(F.relu(self.bn3(self.conv3(x))))
          self._to_linear = x.flatten(1).shape[1]

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.flatten(x)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x) # Последний слой без ФА, т.к. в Лосс уже встроен софтмакс
        return x

In [17]:
model = MyOwnNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [18]:
dataloaders = {"train": trainloader, "test": testloader}

In [19]:
model.train()
epochs = 100
accuracy = {"train": [], "test": []}
for epoch in tqdm(range(epochs)):
  for action, dataloader in dataloaders.items():
    epoch_correct = 0
    epoch_all = 0
    for X_batch, y_batch in dataloader:
      X_batch = X_batch.to(device)
      y_batch = y_batch.to(device)
      if action == "train":
        model.train()
        optimizer.zero_grad()
        logits = model(X_batch)
      else:
        model.eval()
        with torch.no_grad():
          logits = model(X_batch)
      y_pred = logits.argmax(-1)
      correct_preds = (y_batch == y_pred).sum()
      epoch_correct += correct_preds.item()
      epoch_all += y_batch.shape[0]
      if action == "train":
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
    if action == "train":
      print(f"Epoch №{epoch + 1}")
    print(f"Loader: {action}, accuracy: {epoch_correct/epoch_all}")
    accuracy[action].append(epoch_correct/epoch_all)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch №1
Loader: train, accuracy: 0.2977


  1%|          | 1/100 [00:28<46:46, 28.35s/it]

Loader: test, accuracy: 0.39634
Epoch №2
Loader: train, accuracy: 0.4111


  2%|▏         | 2/100 [00:53<43:07, 26.40s/it]

Loader: test, accuracy: 0.46958
Epoch №3
Loader: train, accuracy: 0.4613


  3%|▎         | 3/100 [01:17<41:09, 25.45s/it]

Loader: test, accuracy: 0.50928
Epoch №4
Loader: train, accuracy: 0.49312


  4%|▍         | 4/100 [01:42<40:27, 25.29s/it]

Loader: test, accuracy: 0.54858
Epoch №5
Loader: train, accuracy: 0.5081


  5%|▌         | 5/100 [02:07<39:39, 25.05s/it]

Loader: test, accuracy: 0.56822
Epoch №6
Loader: train, accuracy: 0.5258


  6%|▌         | 6/100 [02:31<38:58, 24.87s/it]

Loader: test, accuracy: 0.58228
Epoch №7
Loader: train, accuracy: 0.53754


  7%|▋         | 7/100 [02:55<38:09, 24.62s/it]

Loader: test, accuracy: 0.58532
Epoch №8
Loader: train, accuracy: 0.54728


  8%|▊         | 8/100 [03:20<37:30, 24.47s/it]

Loader: test, accuracy: 0.60432
Epoch №9
Loader: train, accuracy: 0.55986


  9%|▉         | 9/100 [03:43<36:47, 24.25s/it]

Loader: test, accuracy: 0.61568
Epoch №10
Loader: train, accuracy: 0.56126


 10%|█         | 10/100 [04:07<36:14, 24.16s/it]

Loader: test, accuracy: 0.6067
Epoch №11
Loader: train, accuracy: 0.5693


 11%|█         | 11/100 [04:31<35:45, 24.10s/it]

Loader: test, accuracy: 0.62898
Epoch №12
Loader: train, accuracy: 0.5725


 12%|█▏        | 12/100 [04:55<35:20, 24.09s/it]

Loader: test, accuracy: 0.63408
Epoch №13
Loader: train, accuracy: 0.58374


 13%|█▎        | 13/100 [05:19<34:54, 24.07s/it]

Loader: test, accuracy: 0.6419
Epoch №14
Loader: train, accuracy: 0.58604


 14%|█▍        | 14/100 [05:44<34:31, 24.09s/it]

Loader: test, accuracy: 0.64766
Epoch №15
Loader: train, accuracy: 0.59214


 15%|█▌        | 15/100 [06:08<34:13, 24.16s/it]

Loader: test, accuracy: 0.64646
Epoch №16
Loader: train, accuracy: 0.59658


 16%|█▌        | 16/100 [06:32<33:49, 24.16s/it]

Loader: test, accuracy: 0.65356
Epoch №17
Loader: train, accuracy: 0.59662


 17%|█▋        | 17/100 [06:56<33:24, 24.15s/it]

Loader: test, accuracy: 0.65648
Epoch №18
Loader: train, accuracy: 0.6024


 18%|█▊        | 18/100 [07:20<32:52, 24.05s/it]

Loader: test, accuracy: 0.66304
Epoch №19
Loader: train, accuracy: 0.60158


 19%|█▉        | 19/100 [07:44<32:18, 23.93s/it]

Loader: test, accuracy: 0.65404
Epoch №20
Loader: train, accuracy: 0.60734


 20%|██        | 20/100 [08:08<32:02, 24.03s/it]

Loader: test, accuracy: 0.66116
Epoch №21
Loader: train, accuracy: 0.6084


 21%|██        | 21/100 [08:32<31:48, 24.16s/it]

Loader: test, accuracy: 0.67248
Epoch №22
Loader: train, accuracy: 0.61288


 22%|██▏       | 22/100 [08:57<31:39, 24.35s/it]

Loader: test, accuracy: 0.67504
Epoch №23
Loader: train, accuracy: 0.615


 23%|██▎       | 23/100 [09:21<31:11, 24.31s/it]

Loader: test, accuracy: 0.68282
Epoch №24
Loader: train, accuracy: 0.6149


 24%|██▍       | 24/100 [09:46<30:46, 24.29s/it]

Loader: test, accuracy: 0.67392
Epoch №25
Loader: train, accuracy: 0.61992


 25%|██▌       | 25/100 [10:10<30:17, 24.24s/it]

Loader: test, accuracy: 0.67832
Epoch №26
Loader: train, accuracy: 0.62152


 26%|██▌       | 26/100 [10:34<29:50, 24.19s/it]

Loader: test, accuracy: 0.688
Epoch №27
Loader: train, accuracy: 0.62128


 27%|██▋       | 27/100 [10:58<29:17, 24.08s/it]

Loader: test, accuracy: 0.68432
Epoch №28
Loader: train, accuracy: 0.62368


 28%|██▊       | 28/100 [11:21<28:41, 23.91s/it]

Loader: test, accuracy: 0.68964
Epoch №29
Loader: train, accuracy: 0.62624


 29%|██▉       | 29/100 [11:45<28:17, 23.91s/it]

Loader: test, accuracy: 0.6842
Epoch №30
Loader: train, accuracy: 0.6256


 30%|███       | 30/100 [12:09<28:04, 24.07s/it]

Loader: test, accuracy: 0.68772
Epoch №31
Loader: train, accuracy: 0.63084


 31%|███       | 31/100 [12:34<27:41, 24.09s/it]

Loader: test, accuracy: 0.6965
Epoch №32
Loader: train, accuracy: 0.63158


 32%|███▏      | 32/100 [12:58<27:23, 24.17s/it]

Loader: test, accuracy: 0.69058
Epoch №33
Loader: train, accuracy: 0.63144


 33%|███▎      | 33/100 [13:22<27:04, 24.24s/it]

Loader: test, accuracy: 0.69658
Epoch №34
Loader: train, accuracy: 0.6336


 34%|███▍      | 34/100 [13:47<26:48, 24.37s/it]

Loader: test, accuracy: 0.70018
Epoch №35
Loader: train, accuracy: 0.6315


 35%|███▌      | 35/100 [14:12<26:29, 24.45s/it]

Loader: test, accuracy: 0.70104
Epoch №36
Loader: train, accuracy: 0.63486


 36%|███▌      | 36/100 [14:36<26:07, 24.50s/it]

Loader: test, accuracy: 0.70658
Epoch №37
Loader: train, accuracy: 0.63676


 37%|███▋      | 37/100 [15:01<25:40, 24.44s/it]

Loader: test, accuracy: 0.70366
Epoch №38
Loader: train, accuracy: 0.63614


 38%|███▊      | 38/100 [15:25<25:13, 24.42s/it]

Loader: test, accuracy: 0.70124
Epoch №39
Loader: train, accuracy: 0.63708


 39%|███▉      | 39/100 [15:49<24:37, 24.22s/it]

Loader: test, accuracy: 0.69748
Epoch №40
Loader: train, accuracy: 0.63896


 40%|████      | 40/100 [16:12<23:59, 23.99s/it]

Loader: test, accuracy: 0.71012
Epoch №41
Loader: train, accuracy: 0.6409


 41%|████      | 41/100 [16:36<23:35, 23.99s/it]

Loader: test, accuracy: 0.7152
Epoch №42
Loader: train, accuracy: 0.64132


 42%|████▏     | 42/100 [17:00<23:11, 23.99s/it]

Loader: test, accuracy: 0.70618
Epoch №43
Loader: train, accuracy: 0.6433


 43%|████▎     | 43/100 [17:24<22:44, 23.95s/it]

Loader: test, accuracy: 0.708
Epoch №44
Loader: train, accuracy: 0.64212


 44%|████▍     | 44/100 [17:48<22:21, 23.95s/it]

Loader: test, accuracy: 0.70686
Epoch №45
Loader: train, accuracy: 0.64276


 45%|████▌     | 45/100 [18:12<21:57, 23.96s/it]

Loader: test, accuracy: 0.71038
Epoch №46
Loader: train, accuracy: 0.6469


 46%|████▌     | 46/100 [18:36<21:34, 23.97s/it]

Loader: test, accuracy: 0.70792
Epoch №47
Loader: train, accuracy: 0.6426


 47%|████▋     | 47/100 [18:59<21:03, 23.84s/it]

Loader: test, accuracy: 0.71452
Epoch №48
Loader: train, accuracy: 0.64618


 48%|████▊     | 48/100 [19:23<20:35, 23.76s/it]

Loader: test, accuracy: 0.7201
Epoch №49
Loader: train, accuracy: 0.6474


 49%|████▉     | 49/100 [19:47<20:19, 23.91s/it]

Loader: test, accuracy: 0.71758
Epoch №50
Loader: train, accuracy: 0.64616


 50%|█████     | 50/100 [20:12<20:00, 24.01s/it]

Loader: test, accuracy: 0.71086
Epoch №51
Loader: train, accuracy: 0.64638


 51%|█████     | 51/100 [20:36<19:37, 24.03s/it]

Loader: test, accuracy: 0.71944
Epoch №52
Loader: train, accuracy: 0.65132


 52%|█████▏    | 52/100 [21:00<19:13, 24.04s/it]

Loader: test, accuracy: 0.71522
Epoch №53
Loader: train, accuracy: 0.6505


 53%|█████▎    | 53/100 [21:24<18:49, 24.02s/it]

Loader: test, accuracy: 0.7175
Epoch №54
Loader: train, accuracy: 0.65012


 54%|█████▍    | 54/100 [21:48<18:25, 24.03s/it]

Loader: test, accuracy: 0.71346
Epoch №55
Loader: train, accuracy: 0.65126


 55%|█████▌    | 55/100 [22:12<18:01, 24.02s/it]

Loader: test, accuracy: 0.7244
Epoch №56
Loader: train, accuracy: 0.65086


 56%|█████▌    | 56/100 [22:35<17:28, 23.83s/it]

Loader: test, accuracy: 0.72696
Epoch №57
Loader: train, accuracy: 0.65304


 57%|█████▋    | 57/100 [22:59<17:01, 23.74s/it]

Loader: test, accuracy: 0.70564
Epoch №58
Loader: train, accuracy: 0.64966


 58%|█████▊    | 58/100 [23:23<16:39, 23.80s/it]

Loader: test, accuracy: 0.72854
Epoch №59
Loader: train, accuracy: 0.65132


 59%|█████▉    | 59/100 [23:47<16:17, 23.84s/it]

Loader: test, accuracy: 0.73194
Epoch №60
Loader: train, accuracy: 0.6546


 60%|██████    | 60/100 [24:13<16:22, 24.57s/it]

Loader: test, accuracy: 0.723
Epoch №61
Loader: train, accuracy: 0.65112


 61%|██████    | 61/100 [24:38<16:02, 24.67s/it]

Loader: test, accuracy: 0.72782
Epoch №62
Loader: train, accuracy: 0.65664


 62%|██████▏   | 62/100 [25:01<15:24, 24.32s/it]

Loader: test, accuracy: 0.73076
Epoch №63
Loader: train, accuracy: 0.65608


 63%|██████▎   | 63/100 [25:25<14:50, 24.06s/it]

Loader: test, accuracy: 0.72638
Epoch №64
Loader: train, accuracy: 0.65662


 64%|██████▍   | 64/100 [25:48<14:21, 23.92s/it]

Loader: test, accuracy: 0.7275
Epoch №65
Loader: train, accuracy: 0.65644


 65%|██████▌   | 65/100 [26:12<13:50, 23.73s/it]

Loader: test, accuracy: 0.7325
Epoch №66
Loader: train, accuracy: 0.65654


 66%|██████▌   | 66/100 [26:35<13:24, 23.67s/it]

Loader: test, accuracy: 0.73118
Epoch №67
Loader: train, accuracy: 0.66062


 67%|██████▋   | 67/100 [26:59<13:03, 23.74s/it]

Loader: test, accuracy: 0.73158
Epoch №68
Loader: train, accuracy: 0.65728


 68%|██████▊   | 68/100 [27:23<12:42, 23.84s/it]

Loader: test, accuracy: 0.7356
Epoch №69
Loader: train, accuracy: 0.65716


 69%|██████▉   | 69/100 [27:47<12:20, 23.89s/it]

Loader: test, accuracy: 0.72054
Epoch №70
Loader: train, accuracy: 0.6599


 70%|███████   | 70/100 [28:11<11:58, 23.93s/it]

Loader: test, accuracy: 0.73438
Epoch №71
Loader: train, accuracy: 0.65944


 71%|███████   | 71/100 [28:35<11:35, 23.99s/it]

Loader: test, accuracy: 0.73478
Epoch №72
Loader: train, accuracy: 0.6588


 72%|███████▏  | 72/100 [28:59<11:12, 24.02s/it]

Loader: test, accuracy: 0.71904
Epoch №73
Loader: train, accuracy: 0.657


 73%|███████▎  | 73/100 [29:23<10:47, 23.99s/it]

Loader: test, accuracy: 0.73162
Epoch №74
Loader: train, accuracy: 0.65866


 74%|███████▍  | 74/100 [29:47<10:22, 23.94s/it]

Loader: test, accuracy: 0.73292
Epoch №75
Loader: train, accuracy: 0.66236


 75%|███████▌  | 75/100 [30:11<10:01, 24.07s/it]

Loader: test, accuracy: 0.71672
Epoch №76
Loader: train, accuracy: 0.65966


 76%|███████▌  | 76/100 [30:36<09:39, 24.13s/it]

Loader: test, accuracy: 0.73624
Epoch №77
Loader: train, accuracy: 0.66144


 77%|███████▋  | 77/100 [31:00<09:15, 24.14s/it]

Loader: test, accuracy: 0.7259
Epoch №78
Loader: train, accuracy: 0.65854


 78%|███████▊  | 78/100 [31:24<08:52, 24.21s/it]

Loader: test, accuracy: 0.72914
Epoch №79
Loader: train, accuracy: 0.66092


 79%|███████▉  | 79/100 [31:48<08:27, 24.17s/it]

Loader: test, accuracy: 0.73186
Epoch №80
Loader: train, accuracy: 0.66144


 80%|████████  | 80/100 [32:12<08:02, 24.10s/it]

Loader: test, accuracy: 0.73306
Epoch №81
Loader: train, accuracy: 0.66384


 81%|████████  | 81/100 [32:36<07:36, 24.05s/it]

Loader: test, accuracy: 0.73076
Epoch №82
Loader: train, accuracy: 0.6611


 82%|████████▏ | 82/100 [33:00<07:11, 23.97s/it]

Loader: test, accuracy: 0.738
Epoch №83
Loader: train, accuracy: 0.6621


 83%|████████▎ | 83/100 [33:23<06:44, 23.81s/it]

Loader: test, accuracy: 0.73814
Epoch №84
Loader: train, accuracy: 0.66162


 84%|████████▍ | 84/100 [33:47<06:21, 23.87s/it]

Loader: test, accuracy: 0.73326
Epoch №85
Loader: train, accuracy: 0.66672


 85%|████████▌ | 85/100 [34:12<05:59, 23.96s/it]

Loader: test, accuracy: 0.73494
Epoch №86
Loader: train, accuracy: 0.6625


 86%|████████▌ | 86/100 [34:36<05:36, 24.04s/it]

Loader: test, accuracy: 0.72952
Epoch №87
Loader: train, accuracy: 0.6653


 87%|████████▋ | 87/100 [35:00<05:13, 24.11s/it]

Loader: test, accuracy: 0.74314
Epoch №88
Loader: train, accuracy: 0.66796


 88%|████████▊ | 88/100 [35:24<04:49, 24.11s/it]

Loader: test, accuracy: 0.73654
Epoch №89
Loader: train, accuracy: 0.66398


 89%|████████▉ | 89/100 [35:48<04:24, 24.08s/it]

Loader: test, accuracy: 0.74576
Epoch №90
Loader: train, accuracy: 0.66698


 90%|█████████ | 90/100 [36:12<04:00, 24.07s/it]

Loader: test, accuracy: 0.7427
Epoch №91
Loader: train, accuracy: 0.66528


 91%|█████████ | 91/100 [36:36<03:35, 23.95s/it]

Loader: test, accuracy: 0.7434
Epoch №92
Loader: train, accuracy: 0.66678


 92%|█████████▏| 92/100 [36:59<03:10, 23.81s/it]

Loader: test, accuracy: 0.74164
Epoch №93
Loader: train, accuracy: 0.67038


 93%|█████████▎| 93/100 [37:24<02:47, 23.95s/it]

Loader: test, accuracy: 0.74304
Epoch №94
Loader: train, accuracy: 0.6663


 94%|█████████▍| 94/100 [37:48<02:24, 24.05s/it]

Loader: test, accuracy: 0.74228
Epoch №95
Loader: train, accuracy: 0.66638


 95%|█████████▌| 95/100 [38:12<02:00, 24.09s/it]

Loader: test, accuracy: 0.7433
Epoch №96
Loader: train, accuracy: 0.6686


 96%|█████████▌| 96/100 [38:36<01:36, 24.05s/it]

Loader: test, accuracy: 0.74706
Epoch №97
Loader: train, accuracy: 0.6688


 97%|█████████▋| 97/100 [39:00<01:12, 24.07s/it]

Loader: test, accuracy: 0.74292
Epoch №98
Loader: train, accuracy: 0.66762


 98%|█████████▊| 98/100 [39:24<00:48, 24.08s/it]

Loader: test, accuracy: 0.73854
Epoch №99
Loader: train, accuracy: 0.66874


 99%|█████████▉| 99/100 [39:48<00:24, 24.08s/it]

Loader: test, accuracy: 0.74266
Epoch №100
Loader: train, accuracy: 0.6679


100%|██████████| 100/100 [40:12<00:00, 24.12s/it]

Loader: test, accuracy: 0.73566





In [20]:
torch.save(model.state_dict(), 'CIFAR10_v2_100_EPOCHS.pth')