在MNIST数据集上，从头训练教师网络、从头训练学生网络、知识蒸馏训练学生网络，比较性能。



# 导入工具包

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

In [4]:
# 设置随机数种子，便于复现
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
device

device(type='cpu')

In [6]:
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True

# 载入MNIST数据集

In [7]:
# 载入训练集
train_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

# 载入测试集
test_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

# 生成dataloader
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



# 教师模型

In [8]:
class TeacherModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        
        return x  

# 从头训练教师模型

In [9]:
model = TeacherModel()
model = model.to(device)

In [10]:
summary(model)

Layer (type:depth-idx)                   Param #
TeacherModel                             --
├─ReLU: 1-1                              --
├─Linear: 1-2                            942,000
├─Linear: 1-3                            1,441,200
├─Linear: 1-4                            12,010
├─Dropout: 1-5                           --
Total params: 2,395,210
Trainable params: 2,395,210
Non-trainable params: 0

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [12]:
epochs = 6
for epoch in range(epochs):
    model.train()
    
    # 训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # 前向预测
        preds = model(data)
        loss = criterion(preds, targets)
        
        # 反向传播，优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # 测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))

100%|███████████████████████████████████████| 1875/1875 [00:26<00:00, 71.23it/s]


Epoch:1	 Accuracy:0.9418


100%|███████████████████████████████████████| 1875/1875 [00:26<00:00, 71.46it/s]


Epoch:2	 Accuracy:0.9622


100%|███████████████████████████████████████| 1875/1875 [00:28<00:00, 66.57it/s]


Epoch:3	 Accuracy:0.9691


100%|███████████████████████████████████████| 1875/1875 [00:27<00:00, 68.60it/s]


Epoch:4	 Accuracy:0.9736


100%|███████████████████████████████████████| 1875/1875 [00:30<00:00, 61.19it/s]


Epoch:5	 Accuracy:0.9774


100%|███████████████████████████████████████| 1875/1875 [00:27<00:00, 68.37it/s]


Epoch:6	 Accuracy:0.9792


In [13]:
teacher_model = model

# 学生模型

In [14]:
class StudentModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        
        return x  

# 从头训练学生模型

In [15]:
model = StudentModel()
model = model.to(device)

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [17]:
epochs = 3
for epoch in range(epochs):
    model.train()
    
    # 训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # 前向预测
        preds = model(data)
        loss = criterion(preds, targets)
        
        # 反向传播，优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # 测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))

100%|██████████████████████████████████████| 1875/1875 [00:04<00:00, 415.71it/s]


Epoch:1	 Accuracy:0.8632


100%|██████████████████████████████████████| 1875/1875 [00:04<00:00, 425.10it/s]


Epoch:2	 Accuracy:0.8936


100%|██████████████████████████████████████| 1875/1875 [00:04<00:00, 417.05it/s]


Epoch:3	 Accuracy:0.9026


In [18]:
student_model_scratch = model

# 知识蒸馏训练学生模型

In [19]:
# 准备预训练好的教师模型
teacher_model.eval()

# 准备新的学生模型
model = StudentModel()
model = model.to(device)
model.train()

# 蒸馏温度
temp = 7

In [20]:
# hard_loss
hard_loss = nn.CrossEntropyLoss()
# hard_loss 权重
alpha = 0.3

# soft_loss
soft_loss = nn.KLDivLoss(reduction="batchmean")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [21]:
epochs = 3
for epoch in range(epochs):
    
    # 训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # 教师模型预测
        with torch.no_grad():
            teacher_preds = teacher_model(data)
        
        # 学生模型预测
        student_preds = model(data)
        # 计算hard_loss
        student_loss = hard_loss(student_preds, targets)
        
        # 计算蒸馏后的预测结果及soft_loss
        ditillation_loss = soft_loss(
            F.softmax(student_preds / temp, dim=1),
            F.softmax(teacher_preds / temp, dim=1)
        )
        
        # 将hard_loss和soft_loss加权求和
        loss = alpha * student_loss + (1 - alpha) * ditillation_loss

        # 反向传播，优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # 测试集上评估模型性能
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch+1, acc))

100%|██████████████████████████████████████| 1875/1875 [00:10<00:00, 185.49it/s]


Epoch:1	 Accuracy:0.8576


100%|██████████████████████████████████████| 1875/1875 [00:09<00:00, 202.37it/s]


Epoch:2	 Accuracy:0.8896


100%|██████████████████████████████████████| 1875/1875 [00:09<00:00, 207.49it/s]


Epoch:3	 Accuracy:0.8975
