在 MNIST 数据集上做一个知识蒸馏的小 demo.

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

from tqdm import tqdm

In [3]:
torch.manual_seed(42)  # 固定 CPU 随机数种子
torch.cuda.manual_seed(42)  # 固定 GPU 随机数种子

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
train_data = torchvision.datasets.MNIST(
    root='../data/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

test_data = torchvision.datasets.MNIST(
    root='../data/',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

train_loader = data.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = data.DataLoader(dataset=test_data, batch_size=32, shuffle=False)

定义 Teacher 模型

In [5]:
class TeacherModel(nn.Module):
    def __init__(self, num_classes=10, dropout_rate=0.5):
        super(TeacherModel, self).__init__()
        self.fc1 = nn.Linear(in_features=784, out_features=1200)
        self.fc2 = nn.Linear(in_features=1200, out_features=1200)
        self.fc3 = nn.Linear(in_features=1200, out_features=num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = x.reshape(-1, 784)
        x = self.relu(self.dropout(self.fc1(x)))
        x = self.relu(self.dropout(self.fc2(x)))
        x = self.fc3(x)

        return x

定义 Student 模型

In [6]:
class StudentModel(nn.Module):
    def __init__(self, num_classes=10):
        super(StudentModel, self).__init__()
        self.fc1 = nn.Linear(in_features=784, out_features=20)
        self.fc2 = nn.Linear(in_features=20, out_features=20)
        self.fc3 = nn.Linear(in_features=20, out_features=num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.reshape(-1, 784)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x

训练 Teacher 模型

In [7]:
teacher_model = TeacherModel().to(device)
optimizer = optim.Adam(teacher_model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

In [8]:
# 训练 Teacher 模型
teacher_model.train()
epochs = 5
for epoch in range(epochs):
    for images, labels in tqdm(train_loader, desc='Epoch {}/{}'.format(epoch + 1, epochs)):
        images = images.to(device)
        labels = labels.to(device)

        # 前向传播
        preds = teacher_model(images)
        loss = criterion(preds, labels)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch 1/5: 100%|██████████| 1875/1875 [01:39<00:00, 18.84it/s]
Epoch 2/5: 100%|██████████| 1875/1875 [01:43<00:00, 18.16it/s]
Epoch 3/5: 100%|██████████| 1875/1875 [01:43<00:00, 18.08it/s]
Epoch 4/5: 100%|██████████| 1875/1875 [01:46<00:00, 17.57it/s]
Epoch 5/5: 100%|██████████| 1875/1875 [01:54<00:00, 16.43it/s]


In [9]:
# 评估 Teacher 模型
teacher_model.eval()
num_correct = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        preds = teacher_model(images)
        num_correct += preds.argmax(dim=1).eq(labels).sum().item()

    print('Teacher 模型预测正确的数量为: {}'.format(num_correct))

100%|██████████| 313/313 [00:03<00:00, 93.96it/s] 

Teacher 模型预测正确的数量为: 9812





不使用知识蒸馏技术，训练 Student 模型

In [10]:
student_model = StudentModel().to(device)
optimizer = optim.Adam(student_model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

In [11]:
# 训练 Student 模型
student_model.train()
epochs = 5
for epoch in range(epochs):
    for images, labels in tqdm(train_loader, desc='Epoch {}/{}'.format(epoch + 1, epochs)):
        images = images.to(device)
        labels = labels.to(device)

        # 前向传播
        preds = student_model(images)
        loss = criterion(preds, labels)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch 1/5: 100%|██████████| 1875/1875 [00:19<00:00, 94.03it/s] 
Epoch 2/5: 100%|██████████| 1875/1875 [00:16<00:00, 111.59it/s]
Epoch 3/5: 100%|██████████| 1875/1875 [00:18<00:00, 102.66it/s]
Epoch 4/5: 100%|██████████| 1875/1875 [00:23<00:00, 80.87it/s]
Epoch 5/5: 100%|██████████| 1875/1875 [00:24<00:00, 75.97it/s]


In [12]:
# 评估 Student 模型
student_model.eval()
num_correct = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        preds = student_model(images)
        num_correct += preds.argmax(dim=1).eq(labels).sum().item()

    print('Student 模型预测正确的数量为: {}'.format(num_correct))

100%|██████████| 313/313 [00:03<00:00, 95.71it/s] 


Student 模型预测正确的数量为: 9322


使用知识蒸馏来训练 Student 模型

In [16]:
# 准备 Teacher 模型
teacher_model.eval()
# 准备新的 Student 模型
student_model = StudentModel().to(device)
# 定义 hard loss
hard_loss = nn.CrossEntropyLoss()
# 定义 soft loss
soft_loss = nn.KLDivLoss(reduction='batchmean')
# hard loss 占比
alpha = 0.5
# 蒸馏温度
temperature = 20
# 定义优化器
optimizer = optim.Adam(student_model.parameters(), lr=3e-4)

In [17]:
# 训练 Student 模型
student_model.train()
epochs = 5
for epoch in range(epochs):
    for images, labels in tqdm(train_loader, desc='Epoch {}/{}'.format(epoch + 1, epochs)):
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            # Teacher 模型的预测结果
            teacher_preds = teacher_model(images)

        # 前向传播
        student_preds = student_model(images)
        student_loss = hard_loss(student_preds, labels)
        distillation_loss = soft_loss(
            torch.softmax(student_preds / temperature, dim=1),
            torch.softmax(teacher_preds / temperature, dim=1)
        )
        loss = alpha * student_loss + (1 - alpha) * distillation_loss

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch 1/5: 100%|██████████| 1875/1875 [00:33<00:00, 55.87it/s]
Epoch 2/5: 100%|██████████| 1875/1875 [00:33<00:00, 55.93it/s]
Epoch 3/5: 100%|██████████| 1875/1875 [00:33<00:00, 55.46it/s]
Epoch 4/5: 100%|██████████| 1875/1875 [00:35<00:00, 53.15it/s]
Epoch 5/5: 100%|██████████| 1875/1875 [00:32<00:00, 57.81it/s]


In [18]:
# 评估 Student 模型
student_model.eval()
num_correct = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        preds = student_model(images)
        num_correct += preds.argmax(dim=1).eq(labels).sum().item()

    print('使用知识蒸馏训练的 Student 模型预测正确的数量为: {}'.format(num_correct))

100%|██████████| 313/313 [00:02<00:00, 150.55it/s]

使用知识蒸馏训练的 Student 模型预测正确的数量为: 9332



