In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import copy

from torchvision import transforms
from torch.utils.data import DataLoader

In [2]:
class Teacher_model(nn.Module):
    def __init__(self, in_channels=1, num_class=10):
        super(Teacher_model, self).__init__()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class Student_model(nn.Module):
    def __init__(self, in_channels=1, num_class=10):
        super(Student_model, self).__init__()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 10)
        self.relu = nn.ReLU()
        # self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

In [3]:
torch.manual_seed(0)
# 使用GPU进行加速卷积运算
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(device)
# 载入训练集
train_dataset = torchvision.datasets.MNIST(
    root="./data/raw-data", train=True, transform=transforms.ToTensor(), download=True)
test_dateset = torchvision.datasets.MNIST(
    root="./data/raw-data", train=False, transform=transforms.ToTensor(), download=True)
train_dataloder = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloder = DataLoader(test_dateset, batch_size=32, shuffle=True)

cuda


In [4]:
# model = Teacher_model()
teacher_model = Student_model()
teacher_model = teacher_model.to(device)

# 损失函数和优化器
loss_function = nn.CrossEntropyLoss()
optim = torch.optim.Adam(teacher_model.parameters(), lr=0.0001)

epoches = 6
for epoch in range(epoches):
    teacher_model.train()
    for image, label in train_dataloder:
        image, label = image.to(device), label.to(device)
        optim.zero_grad()
        out = teacher_model(image)
        loss = loss_function(out, label)
        loss.backward()
        optim.step()

    teacher_model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for image, label in test_dataloder:
            image = image.to(device)
            label = label.to(device)
            out = teacher_model(image)
            pre = out.max(1).indices
            num_correct += (pre == label).sum()
            num_samples += pre.size(0)
        acc = (num_correct/num_samples).item()

    teacher_model.train()
    print("epoches:{},accurate={}".format(epoch, acc))

epoches:0,accurate=0.856499969959259
epoches:1,accurate=0.8878999948501587
epoches:2,accurate=0.9003999829292297
epoches:3,accurate=0.9080999493598938
epoches:4,accurate=0.911899983882904
epoches:5,accurate=0.9154999852180481


In [5]:
# model = Student_model()
# model = model.to(device)

# 损失函数和优化器
# loss_function = nn.CrossEntropyLoss()
# optim = torch.optim.Adam(model.parameters(), lr=0.0001)

# # 学生网络的训练和预测结果
# epoches = 6
# for epoch in range(epoches):
#     model.train()
#     for image, label in train_dataloder:
#         image, label = image.to(device), label.to(device)
#         optim.zero_grad()
#         out = model(image)
#         loss = loss_function(out, label)
#         loss.backward()
#         optim.step()

#     model.eval()
#     num_correct = 0
#     num_samples = 0
#     with torch.no_grad():
#         for image, label in test_dataloder:
#             image = image.to(device)
#             label = label.to(device)
#             out = model(image)
#             pre = out.max(1).indices
#             num_correct += (pre == label).sum()
#             num_samples += pre.size(0)
#         acc = (num_correct/num_samples).item()

#     model.train()
#     print("epoches:{},accurate={}".format(epoch, acc))

In [7]:
# 开始进行知识蒸馏算法
teacher_model.eval()
model = Student_model()
model = model.to(device)
# 蒸馏温度
T = 7
hard_loss = nn.CrossEntropyLoss()
alpha = 0
soft_loss = nn.KLDivLoss(reduction="batchmean")
optim = torch.optim.Adam(model.parameters(), lr=0.0001)

epoches = 5
for epoch in range(epoches):
    model.train()
    for image, label in train_dataloder:
        image, label = image.to(device), label.to(device)
        with torch.no_grad():
            teacher_output = teacher_model(image)
        optim.zero_grad()
        out = model(image)
        loss = hard_loss(out, label)
        ditillation_loss = soft_loss(
            F.softmax(out/T, dim=1), F.softmax(teacher_output/T, dim=1))
        loss_all = loss * alpha + ditillation_loss * (1 - alpha)
        loss_all.backward()
        optim.step()

    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for image, label in test_dataloder:
            image = image.to(device)
            label = label.to(device)
            out = model(image)
            pre = out.max(1).indices
            num_correct += (pre == label).sum()
            num_samples += pre.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print("epoches:{},accurate={}".format(epoch, acc))

epoches:0,accurate=0.34529998898506165
epoches:1,accurate=0.4495999813079834
epoches:2,accurate=0.6771999597549438
epoches:3,accurate=0.6991999745368958
epoches:4,accurate=0.7091999650001526
