In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import copy
from torchvision import transforms
from torch.utils.data import DataLoader

In [2]:
from utils.data_util import *
from utils.model_util import LeNet5

torch.set_printoptions(precision=2,
                       threshold=1000,
                       edgeitems=5,
                       linewidth=1000,
                       sci_mode=False)
# 是否使用显卡加速
if torch.cuda.is_available():
    device = 'cuda'
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(device)

cuda


In [3]:
train_dataset, test_dateset, c, h, w = get_dataset()
DataSplit = SplitData(train_dataset)
[teacher_dataset, student_dataset, distill_dataset, test_dataset] = DataSplit.all_iid(4, 3200)
num_target = DataSplit.num_target

train_dataloder = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloder = DataLoader(test_dateset, batch_size=32, shuffle=True)

In [4]:
model = LeNet5(h, w, c, num_target)
# model = Teacher_model()
model = model.to(device)

# 损失函数和优化器
loss_function = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=0.0001)

epoches = 6
for epoch in range(epoches):
    model.train()
    for image, label in train_dataloder:
        image, label = image.to(device), label.to(device)
        optim.zero_grad()
        out = model(image)
        loss = loss_function(out, label)
        loss.backward()
        optim.step()

    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for image, label in test_dataloder:
            image = image.to(device)
            label = label.to(device)
            out = model(image)
            pre = out.max(1).indices
            num_correct += (pre == label).sum()
            num_samples += pre.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print("epoches:{},accurate={}".format(epoch, acc))

epoches:0,accurate=0.8804000020027161
epoches:1,accurate=0.920199990272522
epoches:2,accurate=0.9350000023841858
epoches:3,accurate=0.9460999965667725
epoches:4,accurate=0.9554999470710754
epoches:5,accurate=0.9637999534606934


In [5]:
# 开始进行知识蒸馏算法
teacher_model = copy.deepcopy(model)
teacher_model.eval()
# model = Student_model()
hard_loss = nn.CrossEntropyLoss()
alpha = 0
soft_loss = nn.KLDivLoss(reduction="batchmean")
epoches = 5
# 蒸馏温度
for i in range(1, 20):
    alpha = i * 0.05
    for t in range(1, 10):
        T = t * 0.5
        for epoch in range(epoches):
            model = LeNet5(h, w, c, num_target)
            model = model.to(device)
            model.train()
            optim = torch.optim.Adam(model.parameters(), lr=0.0001)
            for image, label in train_dataloder:
                image, label = image.to(device), label.to(device)
                with torch.no_grad():
                    teacher_output = teacher_model(image)
                optim.zero_grad()
                out = model(image)
                loss = hard_loss(out, label)
                ditillation_loss = soft_loss(
                    F.softmax(out/T, dim=1), F.softmax(teacher_output/T, dim=1))
                loss_all = loss * alpha + ditillation_loss * (1 - alpha)
                loss_all.backward()
                optim.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for image, label in test_dataloder:
                    image = image.to(device)
                    label = label.to(device)
                    out = model(image)
                    pre = out.max(1).indices
                    num_correct += (pre == label).sum()
                    num_samples += pre.size(0)
                acc = (num_correct/num_samples).item()

            model.train()
            print("alpha:{:.2f}, T:{:.1f}, epoches:{}, accurate={:.3f}".format(alpha, T, epoch, acc))

alpha:0.05, T:0.5, epoches:0, accurate=0.879
alpha:0.05, T:0.5, epoches:1, accurate=0.891
alpha:0.05, T:0.5, epoches:2, accurate=0.887
alpha:0.05, T:0.5, epoches:3, accurate=0.900
alpha:0.05, T:0.5, epoches:4, accurate=0.888
alpha:0.05, T:1.0, epoches:0, accurate=0.887
alpha:0.05, T:1.0, epoches:1, accurate=0.884
alpha:0.05, T:1.0, epoches:2, accurate=0.893
alpha:0.05, T:1.0, epoches:3, accurate=0.875
alpha:0.05, T:1.0, epoches:4, accurate=0.880
alpha:0.05, T:1.5, epoches:0, accurate=0.874
alpha:0.05, T:1.5, epoches:1, accurate=0.876
alpha:0.05, T:1.5, epoches:2, accurate=0.867
alpha:0.05, T:1.5, epoches:3, accurate=0.868
alpha:0.05, T:1.5, epoches:4, accurate=0.874
alpha:0.05, T:2.0, epoches:0, accurate=0.875
alpha:0.05, T:2.0, epoches:1, accurate=0.874
alpha:0.05, T:2.0, epoches:2, accurate=0.881
alpha:0.05, T:2.0, epoches:3, accurate=0.870
alpha:0.05, T:2.0, epoches:4, accurate=0.871
alpha:0.05, T:2.5, epoches:0, accurate=0.872
alpha:0.05, T:2.5, epoches:1, accurate=0.869
alpha:0.05