In [3]:
import torch
import numpy as np
import pandas as pd
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
# from torchinfo import summary
from tqdm import tqdm


import matplotlib.pyplot as plt

In [4]:
#设置随机种子

torch.manual_seed(0)

# device = torch.device("cuda" if torch.cuda.is_available() else "pcu")

<torch._C.Generator at 0x21ed766e350>

In [5]:
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark=True

In [6]:
# 载入MNIST数据集
# 载入训练集
train_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
# 载入测试集
test_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)
train_loder = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
test_loder  = DataLoader(dataset=test_dataset, batch_size=32,shuffle=False)

In [7]:
# 教师模型
class TeacherModel(nn.Module):
    def __init__(self,in_channels=1,num_classes=10):
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,1200)
        self.fc2 = nn.Linear(1200,1200)
        self.fc3 = nn.Linear(1200,num_classes)
        self.dropout = nn.Dropout(p = 0.5)

    def forward(self,x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

In [8]:
model = TeacherModel()

In [11]:
criterion = nn.CrossEntropyLoss() # 设置使用交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4) # 使用Adam优化器，学习率为lr=1e-4

epochs = 1 # 训练6轮
for epoch in range(epochs):
    model.train()

    for data,targets in tqdm(train_loder):
        # 前向预测
        preds = model(data)
        loss = criterion(preds,targets)

        # 反向传播，优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估性能
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x,y in test_loder:
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct / num_samples).item()
    
    model.train()
    teacher_model = model
    print(("Epoch:{}\t Accuracy:{:4f}").format(epoch+1,acc))

100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:14<00:00, 127.94it/s]


Epoch:1	 Accuracy:0.906300


In [18]:
# 学生模型
class StudentModel(nn.Module):
    def __init__( self,inchannels=1,num_class=10):
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_class)
        #self.dropout = nn.Dropout(p = 0.5)

    def forward(self,x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        #x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        #x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc3(x)

        return x

model = StudentModel() # 从头先训练一下学生模型
# 设置交叉损失函数 和 激活函数
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
epochs = 3
# 训练集上训练权重
for epoch in range(epochs):
    model.train()

    for data,targets in tqdm(train_loder):
        # 前向预测
        preds = model(data)
        loss = criterion(preds,targets)

        # 反向传播，优化权重
        optimizer.zero_grad() # 把梯度置为0
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        for x,y in  test_loder:
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions==y).sum()
            num_samples += predictions.size(0)
            acc = (num_correct / num_samples).item()

    model.train()
    print(("Epoch:{}\t Accuracy:{:4f}").format(epoch+1,acc))

100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:10<00:00, 171.88it/s]


Epoch:1	 Accuracy:0.879980


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:10<00:00, 182.26it/s]


Epoch:2	 Accuracy:0.881233


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:09<00:00, 188.08it/s]


Epoch:3	 Accuracy:0.884443


In [19]:
student_model_scratch = model

In [20]:
# 准备好预训练好的教师模型
teacher_model.eval()

# 准备新的学生模型
model = StudentModel()
model.train()

StudentModel(
  (relu): ReLU()
  (fc1): Linear(in_features=784, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=10, bias=True)
)

In [21]:
# 蒸馏温度
temp = 7

# hard_loss
hard_loss = nn.CrossEntropyLoss()
# hard_loss权重
alpha = 0.3

# soft_loss
soft_loss = nn.KLDivLoss(reduction="batchmean")

optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)


In [16]:
epochs = 3
for epoch in range(epochs):
    for data,targets in tqdm(train_loder):
        # 教师模型预测
        with torch.no_grad():
            teacher_preds = teacher_model(data)

        # 学生模型预测
        student_preds = student_model_scratch(data)

        student_loss = hard_loss(student_preds,targets)

        # 计算蒸馏后的预测结果及soft_loss
        distillation_loss = soft_loss(
            F.softmax(student_preds/temp, dim=1),
            F.softmax(teacher_preds/temp, dim=1)
        )

        # 将 hard_loss 和 soft_loss 加权求和
        loss = alpha * student_loss + (1-alpha) * distillation_loss

        # 反向传播,优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估性能
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x,y in test_loder:
            preds = student_model_scratch(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print(("Epoch:{}\t Accuracy:{:4f}").format(epoch+1,acc))

100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:10<00:00, 176.48it/s]


Epoch:1	 Accuracy:0.899300


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:10<00:00, 180.12it/s]


Epoch:2	 Accuracy:0.899300


100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:09<00:00, 195.39it/s]


Epoch:3	 Accuracy:0.899300
