<a href="https://colab.research.google.com/github/Myashka/MIPT_Applied_ML/blob/main/week_13_knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms

from torchsummary import summary

from tqdm.auto import trange

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = torchvision.datasets.CIFAR10(
    root='datasets/',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.CIFAR10(
    root='datasets/',
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to datasets/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting datasets/cifar-10-python.tar.gz to datasets/
Files already downloaded and verified


In [9]:
class CNN_Block(nn.Module):
    def __init__(self,in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    
    def forward(self, input):
        return self.block(input)

class CNN(nn.Module):
    def __init__(self, in_channels=3, inter_channels=None, image_size=32, n_classes=10):
        super().__init__()
        
        if inter_channels is None:
            inter_channels = [64, 128, 256]
            
        model = []
        for out_channels in inter_channels:
            model.append(CNN_Block(in_channels, out_channels))
            in_channels = out_channels
        
        model.append(nn.Flatten())
        
        image_size = int(image_size / (2 ** len(inter_channels)))
        in_features = out_channels * image_size * image_size
        out_features = int(in_features / 4)
        model.append(nn.Linear(in_features, out_features))
        model.append(nn.ReLU())
        
        in_features = out_features
        out_features = n_classes
        model.append(nn.Linear(in_features, out_features))
        
        self.model = nn.Sequential(*model)
        
    def forward(self, input):
        return self.model(input)

In [4]:
model = CNN().to(device)
summary(model, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 16, 16]          36,864
       BatchNorm2d-5           [-1, 64, 16, 16]             128
              ReLU-6           [-1, 64, 16, 16]               0
         CNN_Block-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,728
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11            [-1, 128, 8, 8]         147,456
      BatchNorm2d-12            [-1, 128, 8, 8]             256
             ReLU-13            [-1, 128, 8, 8]               0
        CNN_Block-14            [-1, 12

In [5]:
def train(model, n_epoches=15):
    model.to(device)
    optim = torch.optim.AdamW(model.parameters())
    
    best_epoch, best_accuracy = -1, 0
    for epoch in trange(n_epoches):
        model.train()
        for input, target in train_loader:
            input, target = input.to(device), target.to(device)
            pred = model(input)
            loss = F.cross_entropy(pred, target)
            
            optim.zero_grad()
            loss.backward()
            optim.step()
            
        model.eval()
        correct, total = 0, 0
        for input, target in test_loader:
            input, target = input.to(device), target.to(device)
            pred = model(input)
            pred = pred.argmax(dim=1)
            correct += (pred == target).sum()
            total += target.numel()
        
        accuracy = 100 * correct / total
        print(f"Epoch {epoch}: accuracy {accuracy:.1f}%")
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_epoch = epoch
        
    print(f'Best accuracy {best_accuracy:.1f}% after epoch {best_epoch}')

In [6]:
train(model)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 0: accuracy 40.7%
Epoch 1: accuracy 59.4%
Epoch 2: accuracy 67.9%
Epoch 3: accuracy 70.8%
Epoch 4: accuracy 72.8%
Epoch 5: accuracy 75.3%
Epoch 6: accuracy 77.4%
Epoch 7: accuracy 75.6%
Epoch 8: accuracy 74.5%
Epoch 9: accuracy 78.5%
Epoch 10: accuracy 77.2%
Epoch 11: accuracy 78.1%
Epoch 12: accuracy 77.7%
Epoch 13: accuracy 75.8%
Epoch 14: accuracy 76.9%
Best accuracy 78.5% after epoch 9


In [11]:
smaller_model = CNN(inter_channels=[16, 32]).to(device)
summary(smaller_model, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 16, 16, 16]           2,304
       BatchNorm2d-5           [-1, 16, 16, 16]              32
              ReLU-6           [-1, 16, 16, 16]               0
         CNN_Block-7           [-1, 16, 16, 16]               0
            Conv2d-8           [-1, 32, 16, 16]           4,608
       BatchNorm2d-9           [-1, 32, 16, 16]              64
             ReLU-10           [-1, 32, 16, 16]               0
           Conv2d-11             [-1, 32, 8, 8]           9,216
      BatchNorm2d-12             [-1, 32, 8, 8]              64
             ReLU-13             [-1, 32, 8, 8]               0
        CNN_Block-14             [-1, 3

In [12]:
train(smaller_model)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 0: accuracy 52.5%
Epoch 1: accuracy 58.0%
Epoch 2: accuracy 61.3%
Epoch 3: accuracy 64.2%
Epoch 4: accuracy 64.9%
Epoch 5: accuracy 65.3%
Epoch 6: accuracy 66.5%
Epoch 7: accuracy 65.4%
Epoch 8: accuracy 67.2%
Epoch 9: accuracy 66.8%
Epoch 10: accuracy 67.6%
Epoch 11: accuracy 67.9%
Epoch 12: accuracy 67.7%
Epoch 13: accuracy 67.6%
Epoch 14: accuracy 67.4%
Best accuracy 67.9% after epoch 11


In [15]:
def train_distill(student_model, teacher_model, n_epoches=15, alpha=0.7, T=5):
    student_model.to(device)
    teacher_model.to(device)
    optim = torch.optim.AdamW(student_model.parameters())

    best_epoch, best_accuracy = -1, 0
    for epoch in trange(n_epoches):
      student_model.train()
      for input, target in train_loader:
        input, target = input.to(device), target.to(device)
        student_pred = student_model(input)
        teacher_pred = teacher_model(input)

        student_logprobs = F.log_softmax(student_pred / T, dim=-1)
        teacher_probs = F.softmax(teacher_pred / T,dim=-1)
        distill_loss = F.kl_div(student_logprobs, teacher_probs)

        ce_loss = F.cross_entropy(student_pred, target)
        loss = alpha * distill_loss * (T ** 2) + (1 - alpha) * ce_loss

        optim.zero_grad()
        loss.backward()
        optim.step()

      student_model.eval()
      correct, total = 0, 0
      for input, target in test_loader:
        input, target = input.to(device), target.to(device)
        pred = student_model(input)
        pred = pred.argmax(dim=1)
        correct += (pred == target).sum()
        total += target.numel()

      accuracy = 100 * correct / total
      print(f"Epoch {epoch}: accuracy {accuracy:.1f}%")
      if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_epoch = epoch
        
    print(f'Best accuracy {best_accuracy:.1f}% after epoch {best_epoch}')
    return student_model

In [18]:
student_model = CNN(inter_channels=[16, 32])
teacher_model = model
student_model = train_distill(student_model, teacher_model, alpha=0.9, T=7)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 0: accuracy 43.2%
Epoch 1: accuracy 56.6%
Epoch 2: accuracy 60.5%
Epoch 3: accuracy 64.3%
Epoch 4: accuracy 66.2%
Epoch 5: accuracy 67.1%
Epoch 6: accuracy 66.4%
Epoch 7: accuracy 67.6%
Epoch 8: accuracy 69.6%
Epoch 9: accuracy 70.7%
Epoch 10: accuracy 70.8%
Epoch 11: accuracy 71.0%
Epoch 12: accuracy 70.6%
Epoch 13: accuracy 72.1%
Epoch 14: accuracy 72.5%
Best accuracy 72.5% after epoch 14
