In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary

In [3]:
#设置随机种子
torch.manual_seed(0)
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
#使用cuda进行加速卷积运算
torch.backends.cudnn.benchmark=True
#载入训练集
train_dataset=torchvision.datasets.MNIST(root="minst/",train=True,transform=transforms.ToTensor(),download=True)
test_dateset=torchvision.datasets.MNIST(root="minst/",train=False,transform=transforms.ToTensor(),download=True)
train_dataloder=DataLoader(train_dataset,batch_size=32,shuffle=True)
test_dataloder=DataLoader(test_dateset,batch_size=32,shuffle=True)

In [4]:
#搭建教师网络
class Teacher_model(nn.Module):
    def __init__(self,in_channels=1,num_class=10):
        super(Teacher_model, self).__init__()
        self.fc1=nn.Linear(784,1200)
        self.fc2=nn.Linear(1200,1200)
        self.fc3=nn.Linear(1200,10)
        self.relu=nn.ReLU()
        self.dropout=nn.Dropout(0.5)
    def forward(self,x):
        x=x.view(-1,784)
        x=self.fc1(x)
        x=self.dropout(x)
        x=self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
 
model=Teacher_model()
model=model.to(device)
 
#损失函数和优化器
loss_function=nn.CrossEntropyLoss()
optim=torch.optim.Adam(model.parameters(),lr=0.0001)

#训练
epoches=6
for epoch in range(epoches):
    model.train()
    for image,label in train_dataloder:
        image,label=image.to(device),label.to(device)
        optim.zero_grad()
        out=model(image)
        loss=loss_function(out,label)
        loss.backward()
        optim.step()
 
    model.eval()
    num_correct=0
    num_samples=0
    with torch.no_grad():
        for image,label in test_dataloder:
            image=image.to(device)
            label=label.to(device)
            out=model(image)
            pre=out.max(1).indices
            num_correct+=(pre==label).sum()
            num_samples+=pre.size(0)
        acc=(num_correct/num_samples).item()
 
    model.train()
    print("epoches:{},accurate={}".format(epoch,acc))

teacher_model=model

epoches:0,accurate=0.9406999945640564
epoches:1,accurate=0.9606999754905701
epoches:2,accurate=0.967799961566925
epoches:3,accurate=0.9739999771118164
epoches:4,accurate=0.977899968624115
epoches:5,accurate=0.9792999625205994


In [8]:
#搭建学生网络
class Student_model(nn.Module):
    def __init__(self,in_channels=1,num_class=10):
        super(Student_model, self).__init__()
        self.fc1 = nn.Linear(784, 15)
        self.fc2 = nn.Linear(15, 15)
        self.fc3 = nn.Linear(15, 10)
        self.relu = nn.ReLU()
        #self.dropout = nn.Dropout(0.5)
 
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        #x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        #x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
 
model=Student_model()
model=model.to(device)
 
#损失函数和优化器
loss_function=nn.CrossEntropyLoss()
optim=torch.optim.Adam(model.parameters(),lr=0.0001)

#训练
epoches=6
for epoch in range(epoches):
    model.train()
    for image,label in train_dataloder:
        image,label=image.to(device),label.to(device)
        optim.zero_grad()
        out=model(image)
        loss=loss_function(out,label)
        loss.backward()
        optim.step()
 
    model.eval()
    num_correct=0
    num_samples=0
    with torch.no_grad():
        for image,label in test_dataloder:
            image=image.to(device)
            label=label.to(device)
            out=model(image)
            pre=out.max(1).indices
            num_correct+=(pre==label).sum()
            num_samples+=pre.size(0)
        acc=(num_correct/num_samples).item()
 
    model.train()
    print("epoches:{},accurate={}".format(epoch,acc))
    
student_model=model

epoches:0,accurate=0.7882999777793884
epoches:1,accurate=0.8601999878883362
epoches:2,accurate=0.8833999633789062
epoches:3,accurate=0.8973999619483948
epoches:4,accurate=0.9031999707221985
epoches:5,accurate=0.9063999652862549


## 知识蒸馏

In [15]:
 #开始进行知识蒸馏算法
teacher_model.eval()
model=Student_model()
model=model.to(device)
#蒸馏温度
T=4
hard_loss=nn.CrossEntropyLoss()
alpha=0.3
soft_loss=nn.KLDivLoss(reduction="batchmean")
optim=torch.optim.Adam(model.parameters(),lr=0.0001)

#训练与结果预测
epoches=6
for epoch in range(epoches):
    model.train()
    for image,label in train_dataloder:
        image,label=image.to(device),label.to(device)
        with torch.no_grad():
            teacher_output=teacher_model(image)
        optim.zero_grad()
        out=model(image)
        loss=hard_loss(out,label)
        ditillation_loss=soft_loss(F.softmax(out/T,dim=1),F.softmax(teacher_output/T,dim=1))
        loss_all=loss*alpha+ditillation_loss*(1-alpha)
        loss.backward()
        optim.step()
 
    model.eval()
    num_correct=0
    num_samples=0
    with torch.no_grad():
        for image,label in test_dataloder:
            image=image.to(device)
            label=label.to(device)
            out=model(image)
            pre=out.max(1).indices
            num_correct+=(pre==label).sum()
            num_samples+=pre.size(0)
        acc=(num_correct/num_samples).item()
 
    model.train()
    print("epoches:{},accurate={}".format(epoch,acc))
    
student_model=model

epoches:0,accurate=0.857699990272522
epoches:1,accurate=0.88919997215271
epoches:2,accurate=0.9013999700546265
epoches:3,accurate=0.9060999751091003
epoches:4,accurate=0.9098999500274658
epoches:5,accurate=0.913100004196167
