In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Загрузим датасет


In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean=[0.48501961, 0.45795686, 0.40760392], std=[0.229, 0.224, 0.225])])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


ResNet

In [0]:
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(512, 10)
resnet18 = resnet18.to(device)
optimizer = optim.Adam(resnet18.parameters(), lr = 1e-4)
criterion = nn.CrossEntropyLoss()

In [0]:
def train(model, optimizer, criterion, epoch, alpha = 1, teacher = None):
    model.train()

    for e in range(epoch):
        total_loss = 0
        total_correct = 0
        for img, label in trainloader:
            img = img.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            model_label = model(img)
            loss_teacher = 0
            if teacher is not None:
                teacher_label = teacher(img)
                loss_teacher = F.mse_loss(F.softmax(model_label,-1), F.softmax(teacher_label,-1))
            else:
                alpha = 1
            loss = alpha*criterion(model_label, label) + (1 - alpha)*loss_teacher

            loss.backward()
            optimizer.step()
            total_loss+=loss
            _, predict_labels = torch.max(F.softmax(model_label,-1), 1)
            total_correct+= torch.sum(predict_labels == label.data)
    print("Epoch: {}, epoch_loss: {:.5}, epoch_correct: {}".format((e+1), total_loss, total_correct))
    return model

In [0]:
def test(model):
    acc = 0.0
    for img, label in testloader:
        img = img.to(device)
        label = label.to(device)

        model_label = model(img)
        _, predict_labels = torch.max(F.softmax(model_label,-1), 1)
        acc += torch.sum(predict_labels == label.data)
    return acc/len(testset)

In [7]:
teacher = train(resnet18, optimizer, criterion, 10)

Epoch: 1, epoch_loss: 1545.5, epoch_correct: 33011
Epoch: 2, epoch_loss: 985.5, epoch_correct: 39160
Epoch: 3, epoch_loss: 746.78, epoch_correct: 41803
Epoch: 4, epoch_loss: 563.1, epoch_correct: 43762
Epoch: 5, epoch_loss: 441.34, epoch_correct: 45065
Epoch: 6, epoch_loss: 337.31, epoch_correct: 46249
Epoch: 7, epoch_loss: 270.61, epoch_correct: 47029
Epoch: 8, epoch_loss: 223.82, epoch_correct: 47471
Epoch: 9, epoch_loss: 187.55, epoch_correct: 47929
Epoch: 10, epoch_loss: 159.47, epoch_correct: 48281


In [8]:
teacher.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
acc = test(teacher)
print("Accuracy = ", acc.item())

Accuracy =  0.8287999629974365


In [0]:
class Learner(nn.Module):
    def __init__(self):
        super(Learner, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, kernel_size = 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size = 5)
        self.linear1 = nn.Linear(16* 5* 5, 64)
        self.linear2 = nn.Linear(64, 10)

    def forward(self, img):
        x = F.relu(self.conv1(img))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 16* 5* 5)
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
    return x




Посмотрим насколько может обучиться маленькая модель

In [11]:
learner_model = Learner()
learner_model = learner_model.to(device)
learner_opt = optim.Adam(learner_model.parameters(), lr=1e-3)
learner_model = train(learner_model, learner_opt, criterion, 10)

Epoch: 1, epoch_loss: 2457.7, epoch_correct: 21453
Epoch: 2, epoch_loss: 2005.2, epoch_correct: 27185
Epoch: 3, epoch_loss: 1834.1, epoch_correct: 29266
Epoch: 4, epoch_loss: 1717.9, epoch_correct: 30662
Epoch: 5, epoch_loss: 1631.4, epoch_correct: 31679
Epoch: 6, epoch_loss: 1560.1, epoch_correct: 32363
Epoch: 7, epoch_loss: 1512.7, epoch_correct: 33084
Epoch: 8, epoch_loss: 1465.0, epoch_correct: 33547
Epoch: 9, epoch_loss: 1428.2, epoch_correct: 33932
Epoch: 10, epoch_loss: 1400.2, epoch_correct: 34164


In [12]:
acc = test(learner_model)
print("Learner Accuracy = ", acc.item())

Learner Accuracy =  0.6327999830245972


In [21]:
learner = Learner()
learner = learner.to(device)
opt = optim.Adam(learner.parameters(), lr=1e-3)
learner_distillation = train(learner, opt, criterion, 10, 0.5, teacher)

Epoch: 1, epoch_loss: 1268.2, epoch_correct: 21786
Epoch: 2, epoch_loss: 1053.1, epoch_correct: 26807
Epoch: 3, epoch_loss: 956.42, epoch_correct: 29173
Epoch: 4, epoch_loss: 892.95, epoch_correct: 30559
Epoch: 5, epoch_loss: 850.71, epoch_correct: 31565
Epoch: 6, epoch_loss: 814.8, epoch_correct: 32325
Epoch: 7, epoch_loss: 788.68, epoch_correct: 32894
Epoch: 8, epoch_loss: 764.15, epoch_correct: 33565
Epoch: 9, epoch_loss: 746.09, epoch_correct: 33979
Epoch: 10, epoch_loss: 728.79, epoch_correct: 34218


In [22]:
acc = test(learner_distillation)
print("Distillation Accuracy = ", acc.item())

Distillation Accuracy =  0.6523000001907349
