### Импорт библиотек

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

### Определение функций

In [9]:

def save_model(model, epoch, final_model=False):
  '''
    Сохранение модели в файл
  '''
  fulename = './models/model.pth' if final_model else f'./models/model_epoch_{str(epoch)}.pth'
  torch.save(model.state_dict(), fulename)


def load_model(model, epoch):
  '''
    Загрузка модели из файла
  '''
  filename = f'./models/model_epoch_{str(epoch)}.pth'
  return model.load_state_dict(torch.load(filename))


def get_accuracy(model, test_loader):
  '''
    Вычисление метрики accuracy
  '''
  model.eval()
  correct = 0
  total = 0

  for data in test_loader:
    images, labels = data
    output = model(images)
    _, predicted = torch.max(output.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

  return 100 * correct / total


def get_loss(current_loss):
  '''
    Вычисление фнкции потерь
  '''
  return current_loss / 50000

### Параметры набора данных

In [10]:
batch_size = 50
num_workers = 4

### Параметры трансформации данных

In [11]:
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [12]:
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Files already downloaded and verified
Files already downloaded and verified


### Описание класса модели

In [13]:
class CifarClassifier(nn.Module):
  def __init__(self):
    super(CifarClassifier, self).__init__()

    self.conv1 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=(3,3), padding=(1,1))
    self.conv2 = nn.Conv2d(in_channels=48, out_channels=96, kernel_size=(3,3), padding=(1,1))
    self.conv3 = nn.Conv2d(in_channels=96, out_channels=192, kernel_size=(3,3), padding=(1,1))
    self.conv4 = nn.Conv2d(in_channels=192, out_channels=256, kernel_size=(3,3), padding=(1,1))
    self.pool = nn.MaxPool2d(2,2)
    self.fc1 = nn.Linear(in_features=8*8*256, out_features=512)
    self.fc2 = nn.Linear(in_features=512, out_features=64)
    self.Dropout = nn.Dropout(0.25)
    self.fc3 = nn.Linear(in_features=64, out_features=10)

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = self.pool(x)
    x = self.Dropout(x)
    x = F.relu(self.conv3(x))
    x = F.relu(self.conv4(x))
    x = self.pool(x)
    x = self.Dropout(x)
    x = x.view(-1, 8*8*256)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.Dropout(x)
    x = self.fc3(x)
    return x

### Задание параметров, создание экземпляра и обучение модели

In [14]:
model = CifarClassifier()

epochs = 10
started_epoch = 1
use_pretrained_model = False
learning_rate = 0.001

if use_pretrained_model:
  model = load_model(model, started_epoch)

criterion = nn.CrossEntropyLoss()

for epoch in range(started_epoch, epochs + 1):
  current_loss = 0.0
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  for i, data in enumerate(train_loader):
    inputs, labels = data
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)

    loss.backward()
    optimizer.step()

    current_loss += loss.item()

  save_model(model, epoch)

  print(f'Epoch: {epoch}, accuracy: {get_accuracy(model, test_loader)}')
  print(f'Epoch: {epoch}, loss: {get_loss(current_loss)}')

save_model(model, epoch, final_model=True)

Epoch: 1, accuracy: 47.84000015258789
Epoch: 1, loss: 0.035449444587230684
Epoch: 2, accuracy: 62.58000183105469
Epoch: 2, loss: 0.025338293136358263
Epoch: 3, accuracy: 70.26000213623047
Epoch: 3, loss: 0.02018069205760956
Epoch: 4, accuracy: 72.86000061035156
Epoch: 4, loss: 0.017539769438505173
Epoch: 5, accuracy: 74.30999755859375
Epoch: 5, loss: 0.015944108557105065
Epoch: 6, accuracy: 76.66000366210938
Epoch: 6, loss: 0.014666720368862152
Epoch: 7, accuracy: 76.87000274658203
Epoch: 7, loss: 0.014003162535429001
Epoch: 8, accuracy: 78.94999694824219
Epoch: 8, loss: 0.013054108299911021
Epoch: 9, accuracy: 79.4000015258789
Epoch: 9, loss: 0.012554109606742858
Epoch: 10, accuracy: 80.30000305175781
Epoch: 10, loss: 0.011958377012014389
