In [1]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchsummary import summary

In [2]:
# Training hyperparameters
EPOCHS = 20
BATCH_SIZE = 128
LEARNING_RATE = 0.001

# Seed torch's global RNG (CPU and all CUDA devices) so weight init, data
# shuffling and augmentation are reproducible on Restart & Run All.
# This does not force deterministic cuDNN kernels.
SEED = 42
torch.manual_seed(SEED)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load data

In [3]:
# Per-channel normalization statistics shared by the train and val pipelines,
# defined once so the two cannot drift apart.
# NOTE(review): these std values are the widely-copied CIFAR-10 constants
# (the true per-channel std is closer to (0.247, 0.243, 0.261)). Presumably
# the teacher checkpoint was trained with exactly these values -- confirm
# before changing them, since teacher inputs must be normalized identically.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training: light augmentation (random 32x32 crop from a 4-px padded image,
# random horizontal flip), then tensor conversion and normalization.
train_transforms = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Validation: deterministic -- no augmentation, same normalization as training.
val_transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

In [4]:
# CIFAR-10 train/test splits; the test split serves as the validation set.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transforms)

NUM_WORKERS = 8  # background loader processes; tune to the machine's CPU count

# Page-locked (pinned) host memory speeds up host->GPU copies; it only helps
# when training on CUDA, so enable it conditionally.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=DEVICE.type == "cuda",
)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


# Init models (pretrained teacher and untrained student)

In [5]:
class TeacherCNN(nn.Module):
    """Large VGG-style CNN used as the distillation teacher.

    Four conv blocks take a 3x32x32 input down to a 512x2x2 feature map (each
    block halves the spatial size), followed by a three-layer MLP classifier.

    NOTE: the attribute names and the layer order inside each ``nn.Sequential``
    determine the ``state_dict`` keys (``block1.0.weight`` ...
    ``classifier.4.bias``), which must match the pretrained checkpoint.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.block1 = self._make_conv_block(in_ch=3, out_ch=64)
        self.block2 = self._make_conv_block(in_ch=64, out_ch=128)
        self.block3 = self._make_conv_block(in_ch=128, out_ch=256)
        self.block4 = self._make_conv_block(in_ch=256, out_ch=512)

        # Flattened 512x2x2 feature map -> 2048 -> 1024 -> logits,
        # with a ReLU between consecutive Linear layers.
        widths = (512 * 2 * 2, 2048, 1024)
        head = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            head += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)]
        head.append(nn.Linear(widths[-1], num_classes))
        self.classifier = nn.Sequential(*head)

    def _make_conv_block(self, in_ch, out_ch):
        """Build [conv3x3 -> BatchNorm -> ReLU] x 2 followed by a 2x2 max-pool.

        The convolutions are bias-free because each is immediately followed by
        a BatchNorm with its own learnable shift. Padding 1 preserves the
        spatial size; the pool halves it.
        """
        layers = []
        for src_ch in (in_ch, out_ch):
            layers.extend([
                nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        layers.append(nn.MaxPool2d(kernel_size=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a (N, 3, 32, 32) batch to (N, num_classes) logits."""
        for block in (self.block1, self.block2, self.block3, self.block4):
            x = block(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

teacher_model = TeacherCNN(num_classes=10).to(DEVICE)
# map_location remaps stored tensors onto DEVICE, so a checkpoint saved on a
# GPU also loads on a CPU-only machine; weights_only=True refuses to unpickle
# arbitrary Python objects from the file.
teacher_model.load_state_dict(
    torch.load('./model.pt', map_location=DEVICE, weights_only=True)
)
# The teacher is frozen and only used for inference: switch to eval mode right
# away so any forward pass (including one the summary utility may run) cannot
# update its BatchNorm running statistics.
teacher_model.eval()

_ = summary(teacher_model, input_size=(BATCH_SIZE, 3, 32, 32), device=DEVICE, depth=4)

Layer (type:depth-idx)                   Param #
├─Sequential: 1-1                        --
|    └─Conv2d: 2-1                       1,728
|    └─BatchNorm2d: 2-2                  128
|    └─ReLU: 2-3                         --
|    └─Conv2d: 2-4                       36,864
|    └─BatchNorm2d: 2-5                  128
|    └─ReLU: 2-6                         --
|    └─MaxPool2d: 2-7                    --
├─Sequential: 1-2                        --
|    └─Conv2d: 2-8                       73,728
|    └─BatchNorm2d: 2-9                  256
|    └─ReLU: 2-10                        --
|    └─Conv2d: 2-11                      147,456
|    └─BatchNorm2d: 2-12                 256
|    └─ReLU: 2-13                        --
|    └─MaxPool2d: 2-14                   --
├─Sequential: 1-3                        --
|    └─Conv2d: 2-15                      294,912
|    └─BatchNorm2d: 2-16                 512
|    └─ReLU: 2-17                        --
|    └─Conv2d: 2-18                      589,

In [6]:
class StudentCNN(nn.Module):
    """Small two-stage CNN trained by distillation from the teacher.

    Each stage is conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool, so a 3x32x32
    input becomes a 64x8x8 feature map that a two-layer MLP turns into logits.
    Layer order inside the ``nn.Sequential`` containers fixes the
    ``state_dict`` keys (``features.0.weight`` ... ``classifier.2.bias``).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        stages = []
        for c_in, c_out in ((3, 32), (32, 64)):
            stages += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),  # halves H and W: 32 -> 16 -> 8
            ]
        self.features = nn.Sequential(*stages)

        flat_dim = 64 * 8 * 8
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return (N, num_classes) logits for a (N, 3, 32, 32) batch."""
        feats = self.features(x)
        return self.classifier(feats.view(feats.size(0), -1))

# Fresh, untrained student moved to the training device
# (Module.to returns the same module instance).
student_model = StudentCNN(num_classes=10)
student_model = student_model.to(DEVICE)

_ = summary(student_model, input_size=(BATCH_SIZE, 3, 32, 32), device=DEVICE, depth=4)

Layer (type:depth-idx)                   Param #
├─Sequential: 1-1                        --
|    └─Conv2d: 2-1                       896
|    └─BatchNorm2d: 2-2                  64
|    └─ReLU: 2-3                         --
|    └─MaxPool2d: 2-4                    --
|    └─Conv2d: 2-5                       18,496
|    └─BatchNorm2d: 2-6                  128
|    └─ReLU: 2-7                         --
|    └─MaxPool2d: 2-8                    --
├─Sequential: 1-2                        --
|    └─Linear: 2-9                       2,097,664
|    └─ReLU: 2-10                        --
|    └─Linear: 2-11                      5,130
Total params: 2,122,378
Trainable params: 2,122,378
Non-trainable params: 0


# Train

In [7]:
def distillation_loss_fn(student_outputs, teacher_outputs, labels, alpha=0.5, T=2.0):
    """Knowledge-distillation loss: hard-label CE blended with soft-target KL.

    loss = alpha * CE(student, labels)
           + (1 - alpha) * T^2 * KL(softmax(teacher / T) || softmax(student / T))

    Args:
        student_outputs: student logits, classes along dim 1.
        teacher_outputs: teacher logits, same shape as ``student_outputs``.
        labels: integer class targets as accepted by ``cross_entropy``.
        alpha: weight of the hard-label term; ``1 - alpha`` weights the soft term.
        T: softmax temperature. The soft term is multiplied by ``T * T`` to keep
            its gradient scale comparable across temperatures.

    Returns:
        Scalar loss tensor.
    """
    F = nn.functional
    hard_loss = F.cross_entropy(student_outputs, labels)

    student_log_probs = F.log_softmax(student_outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=1)
    # kl_div expects log-probabilities as input and probabilities as target;
    # 'batchmean' divides the summed KL by the batch size.
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)

    return alpha * hard_loss + (1 - alpha) * soft_loss

In [8]:
def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a percentage.

    The model is switched to eval mode and intentionally left there, so
    BatchNorm uses its running statistics; callers that keep training must call
    ``model.train()`` again. Gradients are disabled for the whole pass.

    Args:
        model: classifier returning logits with classes along dim 1.
        loader: iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        device: where each batch is moved; defaults to the notebook's ``DEVICE``.

    Returns:
        Accuracy in [0, 100] as a float.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined).
    """
    if device is None:
        device = DEVICE
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_, y_ in loader:
            X_, y_ = X_.to(device), y_.to(device)
            preds = model(X_).argmax(dim=1)
            correct += (preds == y_).sum().item()
            total += y_.size(0)
    if total == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("evaluate() got an empty loader; accuracy is undefined")
    return 100.0 * correct / total

In [9]:
# Knowledge-distillation hyperparameters. Deliberately NOT named `T`: that name
# is the `torchvision.transforms` alias from the imports cell, and rebinding it
# here breaks re-running the transforms cell.
KD_ALPHA = 0.5        # weight of the hard-label cross-entropy term
KD_TEMPERATURE = 2.0  # softmax temperature applied to teacher/student logits

# The teacher stays frozen (eval mode + no_grad); only the student is optimized.
# NOTE: re-running this cell keeps training the current student with a fresh
# optimizer; re-run the student instantiation cell first for a clean start.
teacher_model.eval()
optimizer = optim.Adam(student_model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    student_model.train()  # evaluate() below leaves the student in eval mode
    train_loss = 0
    correct, total = 0, 0
    for X_, y_ in train_loader:
        X_, y_ = X_.to(DEVICE), y_.to(DEVICE)

        with torch.no_grad():
            teacher_outputs = teacher_model(X_)
        student_outputs = student_model(X_)

        loss = distillation_loss_fn(student_outputs, teacher_outputs, y_,
                                    alpha=KD_ALPHA, T=KD_TEMPERATURE)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Running train accuracy on the (augmented) batches seen this epoch.
        preds = student_outputs.argmax(dim=1)
        correct += (preds == y_).sum().item()
        total += y_.size(0)

    train_acc = 100.0 * correct / total
    val_acc = evaluate(student_model, val_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}]")
    print(f"Train Loss: {train_loss/len(train_loader):.4f}")
    print(f"Train Acc:   {train_acc:.2f}% | Val Acc:   {val_acc:.2f}%")

Epoch [1/20]
Train Loss: 2.7757
Train Acc:   47.17% | Val Acc:   58.38%
Epoch [2/20]
Train Loss: 2.0446
Train Acc:   59.61% | Val Acc:   65.83%
Epoch [3/20]
Train Loss: 1.7780
Train Acc:   64.48% | Val Acc:   67.75%
Epoch [4/20]
Train Loss: 1.6156
Train Acc:   67.19% | Val Acc:   70.36%
Epoch [5/20]
Train Loss: 1.5076
Train Acc:   69.18% | Val Acc:   72.18%
Epoch [6/20]
Train Loss: 1.4231
Train Acc:   70.48% | Val Acc:   73.21%
Epoch [7/20]
Train Loss: 1.3465
Train Acc:   72.04% | Val Acc:   73.07%
Epoch [8/20]
Train Loss: 1.2844
Train Acc:   73.04% | Val Acc:   74.69%
Epoch [9/20]
Train Loss: 1.2454
Train Acc:   73.67% | Val Acc:   74.25%
Epoch [10/20]
Train Loss: 1.2027
Train Acc:   74.41% | Val Acc:   76.96%
Epoch [11/20]
Train Loss: 1.1604
Train Acc:   75.05% | Val Acc:   76.59%
Epoch [12/20]
Train Loss: 1.1277
Train Acc:   75.81% | Val Acc:   77.47%
Epoch [13/20]
Train Loss: 1.0942
Train Acc:   76.30% | Val Acc:   77.06%
Epoch [14/20]
Train Loss: 1.0616
Train Acc:   76.69% | Val A

# Eval: student model size, parameter count and inference latency

In [12]:
# In-memory footprint of the student: parameters plus registered buffers
# (e.g. BatchNorm running statistics). Divided by 1024**2, i.e. MiB,
# although the label prints "MB".
param_size = sum(p.nelement() * p.element_size() for p in student_model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in student_model.buffers())

size_all_mb = (param_size + buffer_size) / 1024**2
print('Model size: {:.3f}MB'.format(size_all_mb))

Model size: 8.097MB


In [13]:
def count_parameters(model):
    """Return the number of scalar parameters in ``model`` with ``requires_grad=True``."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total

# Trainable parameter count of the distilled student.
params = count_parameters(model=student_model)
print("Params:", params)

Params: 2122378


In [15]:
# Single-image CPU latency of the student.
# - eval() + no_grad(): BatchNorm uses running stats and no autograd graph is
#   built, so the timing reflects inference only.
# - warm-up runs absorb one-time allocation/initialization costs.
# - perf_counter is a monotonic high-resolution clock; time.time can jump with
#   system clock adjustments and may be coarse.
inp = torch.randn(1, 3, 32, 32)

student_model.to('cpu')
student_model.eval()
num_samples = 100
num_warmup = 10

with torch.no_grad():
    for _ in range(num_warmup):
        student_model(inp)
    start_time = time.perf_counter()
    for _ in range(num_samples):
        output = student_model(inp)
    end_time = time.perf_counter()

infer_time = ((end_time - start_time) / num_samples) * 1000
print(f'CPU Avg inference time: {infer_time:.4f} ms')

CPU Avg inference time: 0.3996 ms


In [17]:
# Single-image latency on DEVICE.
# CUDA kernels launch asynchronously, so the clock is read only after
# torch.cuda.synchronize(); without it the loop mostly measures launch
# overhead. The input is copied to the device once, outside the timed loop,
# so host->device transfer is excluded (move `inp.to(DEVICE)` back into the
# loop to measure end-to-end latency instead).
student_model.to(DEVICE)
student_model.eval()
inp_device = inp.to(DEVICE)
num_warmup = 10
is_cuda = DEVICE.type == 'cuda'

with torch.no_grad():
    for _ in range(num_warmup):
        student_model(inp_device)
    if is_cuda:
        torch.cuda.synchronize(DEVICE)
    start_time = time.perf_counter()
    for _ in range(num_samples):
        output = student_model(inp_device)
    if is_cuda:
        torch.cuda.synchronize(DEVICE)
    end_time = time.perf_counter()

infer_time = ((end_time - start_time) / num_samples) * 1000
# Only claim "GPU" when actually running on CUDA.
device_label = 'GPU' if is_cuda else DEVICE.type.upper()
print(f'{device_label} Avg inference time: {infer_time:.4f} ms')

GPU Avg inference time: 0.1437 ms


In [19]:
# Persist only the learned tensors (state_dict), not the pickled module object.
torch.save(obj=student_model.state_dict(), f='./student_model.pt')