In [1]:
# HW №5
# Distilation ResNet-18 -> CNN

In [2]:
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import time
import io
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
sns.set(style='whitegrid')

In [4]:
SEED = 42
torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [5]:
BATCH_SIZE = 128
mean, std = (0.4914,0.4822,0.4465), (0.2470,0.2435,0.2616)

In [6]:
train_tfms = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean,std),
])
test_tfms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean,std),
])

In [7]:
train_ds = datasets.CIFAR10('.', train=True , download=True, transform=train_tfms)
test_ds  = datasets.CIFAR10('.', train=False, download=True, transform=test_tfms)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 35.1MB/s] 


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


In [8]:
def accuracy(model, loader):
    model.eval(); correct=total=0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            pred = model(x).argmax(1)
            correct += (pred==y).sum().item()
            total   += y.size(0)
    return 100*correct/total

In [9]:
def count_params(model, trainable_only=False):
    return sum(p.numel() for p in model.parameters() if (p.requires_grad or not trainable_only))

In [10]:
def model_size_mb(model):
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes/1024/1024

In [11]:
@torch.no_grad()
def inference_speed(model, shape=(1,3,224,224), reps=100):
    model.eval(); x=torch.randn(shape).to(device)
    for _ in range(10): _=model(x)
    torch.cuda.synchronize(); t0=time.time()
    for _ in range(reps): _=model(x)
    torch.cuda.synchronize()
    return (time.time()-t0)/reps*1000

In [12]:
teacher = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
teacher.fc = nn.Linear(512,10)
teacher.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 184MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [13]:
for p in teacher.parameters(): p.requires_grad=False
for p in teacher.fc.parameters(): p.requires_grad=True

criterion = nn.CrossEntropyLoss()
opt_t = torch.optim.Adam(teacher.fc.parameters(), lr=3e-4)

EPOCHS_T = 2
for ep in range(EPOCHS_T):
    teacher.train()
    for x,y in tqdm(train_loader, desc=f'Teacher {ep+1}/{EPOCHS_T}'):
        x,y=x.to(device),y.to(device)
        opt_t.zero_grad()
        loss=criterion(teacher(x),y)
        loss.backward(); opt_t.step()
    print(f'Val acc: {accuracy(teacher,test_loader):.2f}%')

Teacher 1/2: 100%|██████████| 391/391 [00:56<00:00,  6.94it/s]


Val acc: 74.85%


Teacher 2/2: 100%|██████████| 391/391 [00:54<00:00,  7.24it/s]


Val acc: 77.02%


In [14]:
#CNN
class StudentNet(nn.Module):
    def __init__(self,nc=10):
        super().__init__()
        self.feat = nn.Sequential(
            nn.Conv2d(3,32,3,1,1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,3,1,1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64,128,3,1,1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128,256,3,1,1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        self.cls = nn.Linear(256,nc)
    def forward(self,x):
        x=self.feat(x); x=torch.flatten(x,1)
        return self.cls(x)

In [15]:
student = StudentNet().to(device)

In [16]:
T=4.0; ALPHA=0.5
opt_s = torch.optim.Adam(student.parameters(), lr=1e-3)
kl = nn.KLDivLoss(reduction='batchmean'); ce = nn.CrossEntropyLoss()
teacher.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [17]:
def kd_loss(s,t,y):
    l_kd = kl(F.log_softmax(s/T,1), F.softmax(t/T,1))*(T*T)
    l_ce = ce(s,y)
    return ALPHA*l_kd + (1-ALPHA)*l_ce

In [18]:
EPOCHS_S=8
for ep in range(EPOCHS_S):
    student.train()
    for x,y in tqdm(train_loader, desc=f'Student {ep+1}/{EPOCHS_S}'):
        x,y=x.to(device),y.to(device)
        with torch.no_grad(): t_logits=teacher(x)
        s_logits=student(x)
        loss=kd_loss(s_logits,t_logits,y)
        opt_s.zero_grad(); loss.backward(); opt_s.step()
    print(f'Val acc: {accuracy(student,test_loader):.2f}%')

Student 1/8: 100%|██████████| 391/391 [01:28<00:00,  4.40it/s]


Val acc: 49.34%


Student 2/8: 100%|██████████| 391/391 [01:28<00:00,  4.42it/s]


Val acc: 53.64%


Student 3/8: 100%|██████████| 391/391 [01:28<00:00,  4.41it/s]


Val acc: 54.15%


Student 4/8: 100%|██████████| 391/391 [01:28<00:00,  4.41it/s]


Val acc: 56.52%


Student 5/8: 100%|██████████| 391/391 [01:28<00:00,  4.42it/s]


Val acc: 61.52%


Student 7/8: 100%|██████████| 391/391 [01:28<00:00,  4.41it/s]


Val acc: 52.61%


Student 8/8: 100%|██████████| 391/391 [01:28<00:00,  4.41it/s]


Val acc: 58.28%


In [19]:
teacher_acc  = accuracy(teacher,test_loader)
student_acc  = accuracy(student,test_loader)

teacher_ms = inference_speed(teacher)
student_ms = inference_speed(student)

teacher_params = count_params(teacher) / 1e6
student_params = count_params(student) / 1e6

teacher_mb = model_size_mb(teacher)
student_mb = model_size_mb(student)

In [20]:
results = pd.DataFrame({
    'Модель':['Teacher (ResNet18)', 'Student (CNN)'],
    'Параметры, млн':[teacher_params, student_params],
    'Вес, MB':[teacher_mb, student_mb],
    'Точность, %':[teacher_acc, student_acc],
    'Ср. время инференса, ms':[teacher_ms, student_ms]
})

In [21]:
improvements = pd.DataFrame({
    'Модель':['% выигрыш Student'],
    'Параметры, млн':[f"-{(1-student_params/teacher_params)*100:.1f}%"],
    'Вес, MB':[f"-{(1-student_mb/teacher_mb)*100:.1f}%"],
    'Точность, %':[f"-{teacher_acc-student_acc:.2f} pp"],
    'Ср. время инференса, ms':[f"+{(teacher_ms/student_ms-1)*100:.1f}% faster"]
})

In [22]:
display(results.style.format({'Параметры, млн':'{:.2f}',
                              'Вес, MB':'{:.1f}',
                              'Точность, %':'{:.2f}',
                              'Ср. время инференса, ms':'{:.2f}'}).background_gradient(cmap='YlGnBu'))

Unnamed: 0,Модель,"Параметры, млн","Вес, MB","Точность, %","Ср. время инференса, ms"
0,Teacher (ResNet18),11.18,42.7,77.02,2.2
1,Student (CNN),0.39,1.5,58.28,0.63


In [23]:
display(improvements)

Unnamed: 0,Модель,"Параметры, млн","Вес, MB","Точность, %","Ср. время инференса, ms"
0,% выигрыш Student,-96.5%,-96.5%,-18.74 pp,+248.6% faster


In [25]:
print(f"""Итоги:
• Параметры ↓  : {teacher_params:.2f} M → {student_params:.2f} M  ({(1-student_params/teacher_params)*100:.1f}% меньше)
• Вес модели ↓: {teacher_mb:.1f} MB → {student_mb:.1f} MB ({(1-student_mb/teacher_mb)*100:.1f}% меньше)
• Скорость ↑  : {teacher_ms:.2f} ms → {student_ms:.2f} ms ({(teacher_ms/student_ms):.1f}× быстрее)
• Точность    : {teacher_acc:.2f}% → {student_acc:.2f}%  (-{teacher_acc-student_acc:.2f} pp)
""")

Итоги:
• Параметры ↓  : 11.18 M → 0.39 M  (96.5% меньше)
• Вес модели ↓: 42.7 MB → 1.5 MB (96.5% меньше)
• Скорость ↑  : 2.20 ms → 0.63 ms (3.5× быстрее)
• Точность    : 77.02% → 58.28%  (-18.74 pp)

