Task 3: Обучить с нуля глубокую сеть на небольшом датасете. Сравнить результат с дообучением предобученной сети

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

# Трансформации и загрузка датасета CIFAR-10
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

# Объявление пользовательской модели
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

# Загрузка предобученной модели ResNet18 и настройка под текущую задачу
resnet18 = models.resnet18(pretrained=True)
for param in resnet18.parameters():
    param.requires_grad = False

resnet18.fc = nn.Linear(resnet18.fc.in_features, 10)

# Функция потерь и оптимизатор
criterion = nn.CrossEntropyLoss()
optimizer_net = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer_resnet18 = optim.SGD(resnet18.fc.parameters(), lr=0.001, momentum=0.9)

# Функция для обучения модели
def train_model(model, optimizer, epochs=2):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:    # печать каждые 2000 мини-пакетов
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0
    print('Finished Training')

# Функция для тестирования модели
def test_model(model):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

# Обучение сети с нуля
print("Training custom network...")
train_model(net, optimizer_net)
test_model(net)

# Дообучение предобученной сети
print("Fine-tuning ResNet18...")
train_model(resnet18, optimizer_resnet18)
test_model(resnet18)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|████████████████████████████████████████████████████████████████| 170498071/170498071 [04:24<00:00, 643506.11it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\Roadmarshal/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|█████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:08<00:00, 5.26MB/s]


Training custom network...
[1,  2000] loss: 2.158
[1,  4000] loss: 1.812
[1,  6000] loss: 1.653
[1,  8000] loss: 1.541
[1, 10000] loss: 1.511
[1, 12000] loss: 1.448
[2,  2000] loss: 1.375
[2,  4000] loss: 1.365
[2,  6000] loss: 1.326
[2,  8000] loss: 1.322
[2, 10000] loss: 1.277
[2, 12000] loss: 1.274
Finished Training
Accuracy of the network on the 10000 test images: 55.29%
Fine-tuning ResNet18...
[1,  2000] loss: 2.542
[1,  4000] loss: 2.608
[1,  6000] loss: 2.595
[1,  8000] loss: 2.596
[1, 10000] loss: 2.583
[1, 12000] loss: 2.658
[2,  2000] loss: 2.598
[2,  4000] loss: 2.621
[2,  6000] loss: 2.641
[2,  8000] loss: 2.601
[2, 10000] loss: 2.611
[2, 12000] loss: 2.616
Finished Training
Accuracy of the network on the 10000 test images: 26.19%


снижение качества?! вот это неожиданно :DD