## 知识蒸馏-mnist手写数据集

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import utils.calculate_param as cp
from tqdm import tqdm

### 数据准备

In [2]:
# 没有就下载-手写数据集
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

In [3]:
batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break


Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


### 设备准备

In [4]:
# 获取设备类型
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

### 定义教师模型

In [5]:
class TeacherModel(nn.Module):
    def __init__(self, num_classes= 10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p = 0.5)

    # 处理逻辑：fc1->dropout->relu->fc2->dropout->relu->fc3
    def forward(self, x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)
        return x


### 教师模型设置

In [6]:
model = TeacherModel().to(device)

criterion = nn.CrossEntropyLoss() # 设置使用交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4) # 使用Adam优化器，学习率为lr=1e-4

### 教师模型信息

In [7]:
# 输出模型的参数信息-100w参数
cp.get_summary(model, input_size=(64, 1, 28,28))

torchinfo信息如下：
Layer (type:depth-idx)                   Output Shape              Param #
TeacherModel                             [64, 10]                  --
├─Linear: 1-1                            [64, 1200]                942,000
├─Dropout: 1-2                           [64, 1200]                --
├─ReLU: 1-3                              [64, 1200]                --
├─Linear: 1-4                            [64, 1200]                1,441,200
├─Dropout: 1-5                           [64, 1200]                --
├─ReLU: 1-6                              [64, 1200]                --
├─Linear: 1-7                            [64, 10]                  12,010
Total params: 2,395,210
Trainable params: 2,395,210
Non-trainable params: 0
Total mult-adds (M): 153.29
Input size (MB): 0.20
Forward/backward pass size (MB): 1.23
Params size (MB): 9.58
Estimated Total Size (MB): 11.02


### 教师模型训练&评估

In [8]:
epochs = 6 # 训练6轮
for epoch in range(epochs):
    model.train()
    for data,targets in tqdm(train_dataloader):
        # 前向预测
        outputs = model(data)
        loss = criterion(outputs, targets)

        # 反向传播，优化权重
        optimizer.zero_grad()  # 把梯度置为0
        loss.backward()
        optimizer.step()

    # 测试集上评估性能 进入评估模式
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x,y in test_dataloader:
            outputs = model(x)
            predictions = outputs.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()

    model.train()  # 再次进入训练模式
    print("Epoch:{}\t Accuracy:{:4f}".format(epoch + 1, acc))

# 暂存教师模型为teacher_model
teacher_model = model


100%|██████████| 938/938 [00:21<00:00, 44.27it/s]


Epoch:1	 Accuracy:0.830500


100%|██████████| 938/938 [00:21<00:00, 44.47it/s]


Epoch:2	 Accuracy:0.845900


100%|██████████| 938/938 [00:21<00:00, 42.76it/s]


Epoch:3	 Accuracy:0.857800


100%|██████████| 938/938 [00:20<00:00, 46.57it/s]


Epoch:4	 Accuracy:0.865400


100%|██████████| 938/938 [00:20<00:00, 46.21it/s]


Epoch:5	 Accuracy:0.867100


100%|██████████| 938/938 [00:20<00:00, 46.11it/s]


Epoch:6	 Accuracy:0.872900


### 定义学生模型

In [10]:
# 学生模型
class StudentModel(nn.Module):
    def __init__( self, num_class=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_class)
        self.dropout = nn.Dropout(p = 0.5)

    # 处理逻辑：fc1->relu->fc2->relu->fc3
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)

        # x = self.fc2(x)
        # x = self.dropout(x)
        # x = self.relu(x)

        x = self.fc3(x)
        return x


### 学生模型设置

In [11]:
# 从头先训练一下学生模型
model = StudentModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)

### 学生模型信息

In [13]:
# 输出模型的参数信息-1w参数
cp.get_summary(model, input_size=(64, 1, 28,28))

torchinfo信息如下：
Layer (type:depth-idx)                   Output Shape              Param #
StudentModel                             [64, 10]                  210
├─Linear: 1-1                            [64, 20]                  15,700
├─ReLU: 1-2                              [64, 20]                  --
├─Linear: 1-3                            [64, 20]                  420
├─Linear: 1-7                            [64, 10]                  (recursive)
├─Dropout: 1-5                           --                        --
├─ReLU: 1-6                              [64, 20]                  --
├─Linear: 1-7                            [64, 10]                  (recursive)
Total params: 16,330
Trainable params: 16,330
Non-trainable params: 0
Total mult-adds (M): 1.06
Input size (MB): 0.20
Forward/backward pass size (MB): 0.02
Params size (MB): 0.06
Estimated Total Size (MB): 0.29


### 学生模型训练&评估

In [14]:
epochs = 6

for epoch in range(epochs):
    model.train()
    for data,targets in tqdm(train_dataloader):
        # 前向预测
        outputs = model(data)
        loss = criterion(outputs, targets)

        # 反向传播，优化权重
        optimizer.zero_grad() # 把梯度置为0
        loss.backward()
        optimizer.step()

    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x,y in  test_dataloader:
            outputs = model(x)
            predictions = outputs.max(1).indices
            num_correct += (predictions==y).sum()
            num_samples += predictions.size(0)
            acc = (num_correct / num_samples).item()

    model.train()
    print("Epoch:{}\t Accuracy:{:4f}".format(epoch + 1, acc))

# 暂存不加蒸馏学生模型为student_model
student_model = model


100%|██████████| 938/938 [00:05<00:00, 177.43it/s]


Epoch:1	 Accuracy:0.686100


100%|██████████| 938/938 [00:05<00:00, 174.70it/s]


Epoch:2	 Accuracy:0.755400


100%|██████████| 938/938 [00:05<00:00, 175.49it/s]


Epoch:3	 Accuracy:0.783600


100%|██████████| 938/938 [00:07<00:00, 125.87it/s]


Epoch:4	 Accuracy:0.798700


100%|██████████| 938/938 [00:07<00:00, 125.51it/s]


Epoch:5	 Accuracy:0.804600


100%|██████████| 938/938 [00:06<00:00, 136.04it/s]


Epoch:6	 Accuracy:0.814500


### 知识蒸馏准备

In [15]:
# 准备预训练好的教师模型
teacher_model.eval()

# 准备新的学生模型
model = StudentModel().to(device)

# 蒸馏温度
T = 7

### 蒸馏参数设置

In [16]:
# hard_loss
hard_loss = nn.CrossEntropyLoss()
# hard_loss权重
alpha = 0.3
# soft_loss kl散度
soft_loss = nn.KLDivLoss(reduction='batchmean')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

### 知识蒸馏训练&评估

In [17]:
epochs = 6
for epoch in range(epochs):
    for data,targets in tqdm(train_dataloader):
        data, targets = data.to(device), targets.to(device)
        # 教师模型预测
        with torch.no_grad():
            teacher_outputs = teacher_model(data)
        # 学生模型预测
        student_outputs = model(data)
        student_loss = hard_loss(student_outputs, targets)
        # 计算蒸馏后的预测结果及soft_loss
        distillation_loss = soft_loss(
            F.softmax(student_outputs/T, dim=1),
            F.softmax(teacher_outputs/T, dim=1)
        )
        # 将 hard_loss 和 soft_loss 加权求和
        loss = alpha * student_loss + (1-alpha) * distillation_loss
        # 反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估性能
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x,y in test_dataloader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            pred = outputs.max(1).indices
            num_correct += (pred == y).sum()
            num_samples += pred.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print("Epoch:{}\t Accuracy:{:4f}".format(epoch + 1, acc))


100%|██████████| 938/938 [00:10<00:00, 93.05it/s] 


Epoch:1	 Accuracy:0.666400


100%|██████████| 938/938 [00:09<00:00, 102.55it/s]


Epoch:2	 Accuracy:0.745100


100%|██████████| 938/938 [00:09<00:00, 97.97it/s] 


Epoch:3	 Accuracy:0.777100


100%|██████████| 938/938 [00:09<00:00, 102.24it/s]


Epoch:4	 Accuracy:0.791300


100%|██████████| 938/938 [00:08<00:00, 107.52it/s]


Epoch:5	 Accuracy:0.802900


100%|██████████| 938/938 [00:08<00:00, 108.41it/s]


Epoch:6	 Accuracy:0.810800


### 蒸馏模型保存

In [18]:
# 保存模型
torch.save(model.state_dict(), "./models/distillation_model.pth")