In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torchvision import datasets,transforms
from tqdm import tqdm
from d2l import torch as d2l

torch.manual_seed(0)
torch.cuda.manual_seed(0)

## 先训练老师模型

In [4]:
#定义老师模型的网络
class TeacherNet(nn.Module):
    def __init__(self) -> None:
        super(TeacherNet,self).__init__()
        self.net=nn.Sequential(
            nn.Conv2d(1,32,3,1),nn.ReLU(),
            nn.Conv2d(32,64,3,1),nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.3),nn.Flatten(),
            nn.Linear(9216,128),nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.Linear(128,10),
        )
        
    def forward(self,x):
        return self.net(x)

In [48]:
def train_teacher(model,device,train_loader,optimizer,epoch):
    model.train()
    trained_samples=0
    train_loss,train_acc=[],[]
    
    for batch_idx,(data,target) in enumerate(train_loader):
        data,target=data.to(device),target.to(device)
        optimizer.zero_grad()
        output=model(data)
        loss=F.cross_entropy(output,target)
        loss.backward()
        optimizer.step()

        trained_samples+=len(data)
        progress=math.ceil(batch_idx/len(train_loader)*50)
        print("\rTrain epoch %d:%d/%d,[%-51s]%d%%" % (epoch,trained_samples,len(train_loader.dataset),'-'*progress+'>',progress*2),end='')
        # acc=(output.argmax(dim=-1)==target).float().mean()

        # #记录损失和精度
        # train_loss.append(loss.item())
        # train_acc.append(acc)
        # print(f'训练周期：{epoch+1:03d}/{num_epoch:03d} 训练损失：{train_loss:.5f} 训练精度：{train_acc:.5f}')

def test_teacher(model,device,test_loader):
    model.eval()
    test_loss=0
    correct=0

    with torch.no_grad():
        for data,target in test_loader:
            data,target =data.to(device),target.to(device)
            output=model(data)
            test_loss+=F.cross_entropy(output,target,reduction='sum').item()
            pred=output.argmax(dim=1,keepdim=True)
            correct+=pred.eq(target.view_as(pred)).sum().item()

    test_loss/=len(test_loader.dataset)

    print('\nTest:average loss:{:.4f},accuracy:{}/{} ({:.0f}%)'.format(test_loss,correct,len(test_loader.dataset),100.*correct/len(test_loader.dataset)))

    return test_loss,correct/len(test_loader.dataset)

In [49]:
def teacher_main():
    num_epochs=10
    batch_size=64
    torch.manual_seed(0)

    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader=torch.utils.data.DataLoader(
        datasets.MNIST("../data/MNIST",train=True,download="True",
                        transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307,),(0.3081,))
        ])),
        batch_size=batch_size,shuffle=True
    )

    test_loader=torch.utils.data.DataLoader(
        datasets.MNIST("../data/MNIST",train=False,download=True,
                        transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307,),(0.3081,))
        ])),
        batch_size=1000,shuffle=True
    )

    model=TeacherNet().to(device)
    optimizer=torch.optim.Adadelta(model.parameters())

    teacher_history=[]
    max_acc=0.0
    for epoch in range(num_epochs):
        train_teacher(model,device,train_loader,optimizer,epoch)
        loss,acc=test_teacher(model,device,test_loader)

        teacher_history.append((loss,acc))
        if acc>max_acc:
            torch.save(model.state_dict(),"teacher.pt")
            max_acc=acc
    
    return model,teacher_history

In [50]:
teacher_model,teacher_history=teacher_main()

Train epoch 0:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0510,accuracy:9854/10000 (99%)
Train epoch 1:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0406,accuracy:9871/10000 (99%)
Train epoch 2:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0323,accuracy:9888/10000 (99%)
Train epoch 3:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0352,accuracy:9883/10000 (99%)
Train epoch 4:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0332,accuracy:9898/10000 (99%)
Train epoch 5:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0313,accuracy:9908/10000 (99%)
Train epoch 6:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0337,accuracy:9908/10000 (99%)
Train epoch 7:60000/60000,[----------------------------

In [10]:
#让老师教学生网络
class StudentNet(nn.Module):
    def __init__(self) -> None:
        super(StudentNet,self).__init__()
        self.fc1=nn.Linear(28*28,128)
        self.fc2=nn.Linear(128,64)
        self.fc3=nn.Linear(64,10)
    
    def forward(self,x):
        x=torch.flatten(x,1)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        output=F.relu(self.fc3(x))
        return output

In [11]:
#关键：定义KD的loss
def distillationLoss(y,labels,teacher_scores,temp,alpha):
    return nn.KLDivLoss()(F.log_softmax(y/temp,dim=1),F.softmax(teacher_scores/temp,dim=1))*(temp*temp*2.0*alpha)+F.cross_entropy(y,labels)*(1-alpha)

In [27]:
teacher=TeacherNet()
teacher.load_state_dict(torch.load("teacher.pt"))

def train_student_kd(model,device,train_loader,optimizer,epoch):
    model.train()
    trained_samples=0

    for batch_idx,(data,target) in enumerate(train_loader):
        data,target=data.to(device),target.to(device)
        optimizer.zero_grad()
        output=model(data)
        teacher_output=teacher(data)
        loss=distillationLoss(output,target,teacher_output,temp=5.0,alpha=0.7)
        loss.backward()
        optimizer.step()

        trained_samples+=len(data)
        progress=math.ceil(batch_idx/len(train_loader)*50)
        print("\rTrain epoch %d:%d/%d,[%-51s]%d%%" % (epoch,trained_samples,len(train_loader.dataset),'-'*progress+'>',progress*2),end='')

def test_student_kd(model,device,test_loader):
    model.eval()
    test_loss=0
    correct=0

    with torch.no_grad():
        for data,target in test_loader:
            data,target=data.to(device),target.to(device)
            output=model(data)
            test_loss+=F.cross_entropy(output,target,reduction="sum").item()
            pred=output.argmax(dim=1,keepdim=True)
            correct+=pred.eq(target.view_as(pred)).sum().item()
    
    test_loss/=len(test_loader.dataset)

    print("\nTest:average loss:{:.4f},accuracy:{}/{} ({:.0f}%)".format(test_loss,correct,len(test_loader.dataset),100.*correct/len(test_loader.dataset)))

    return test_loss,correct/len(test_loader.dataset)

In [30]:
def student_kd_main():
    epochs=10
    batch_size=64
    torch.manual_seed(0)

    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader=torch.utils.data.DataLoader(
        datasets.MNIST("../data/MNIST",train=True,download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,),(0.3081,))
        ])),
        batch_size=batch_size,shuffle=True
    )

    test_loader=torch.utils.data.DataLoader(
        datasets.MNIST("../data/MNIST",train=False,download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,),(0.3081,))
        ])),
        batch_size=1000,shuffle=True
    )

    model=StudentNet().to(device)
    optimizer=torch.optim.Adadelta(model.parameters())

    student_history=[]
    max_acc=0.0
    for epoch in range(1,epochs+1):
        train_student_kd(model,device,train_loader,optimizer,epoch)
        loss,acc=test_student_kd(model,device,test_loader)
        student_history.append((loss,acc))

        if acc>max_acc:
            torch.save(model.state_dict(),"student_kd.pt")
            max_acc=acc
    
    return model,student_history

In [31]:
student_kd_model,student_kd_history=student_kd_main()



Train epoch 1:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.1555,accuracy:9663/10000 (97%)
Train epoch 2:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.1268,accuracy:9716/10000 (97%)
Train epoch 3:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0989,accuracy:9774/10000 (98%)
Train epoch 4:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0888,accuracy:9796/10000 (98%)
Train epoch 5:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0795,accuracy:9802/10000 (98%)
Train epoch 6:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0785,accuracy:9821/10000 (98%)
Train epoch 7:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0689,accuracy:9817/10000 (98%)
Train epoch 8:60000/60000,[----------------------------

In [41]:
def train_student(model,device,train_loader,optimizer,epoch):
    model.train()


    for data,target in tqdm(train_loader):
        data,target=data.to(device),target.to(device)
        optimizer.zero_grad()
        output=model(data)
        loss=F.cross_entropy(output,target)
        loss.backward()
        optimizer.step()

    print('epoch:%d finished' % epoch)

def test_student(model,device,test_loader):
    model.eval()
    loss=0
    acc=0

    with torch.no_grad():
        for data,target in test_loader:
            data,target=data.to(device),target.to(device)
            output=model(data)
            loss+=F.cross_entropy(output,target,reduction='sum').item()
            pred=output.argmax(dim=1,keepdim=True)
            acc+=pred.eq(target.view_as(pred)).sum().item()
        
        loss/=len(test_loader.dataset)
    
    print(f'test acc:{acc/len(test_loader.dataset)},average loss:{loss}')    

    return loss,acc/len(test_loader.dataset)

In [42]:
def student_main():
    train_loader=torch.utils.data.DataLoader(
            datasets.MNIST("../data/MNIST",train=True,download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,),(0.3081,))
            ])),
            batch_size=64,shuffle=True
        )

    test_loader=torch.utils.data.DataLoader(
        datasets.MNIST("../data/MNIST",train=False,download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,),(0.3081,))
        ])),
        batch_size=1000,shuffle=True
    )

    epochs=10
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model=StudentNet().to(device)
    optimizer=torch.optim.Adadelta(model.parameters())
    max_acc=0.0

    for epoch in range(epochs):
        train_student(model,device,train_loader,optimizer,epoch+1)
        loss,acc=test_student(model,device,test_loader)
        
        if acc>max_acc:
            torch.save(model.state_dict(),"student.pt")
            max_acc=acc
        
    return model
    

In [43]:
model=student_main()

100%|██████████| 938/938 [00:21<00:00, 43.65it/s]


epoch:1 finished
test acc:0.8709,average loss:0.3392891387939453


100%|██████████| 938/938 [00:21<00:00, 44.02it/s]


epoch:2 finished
test acc:0.8779,average loss:0.31562213897705077


100%|██████████| 938/938 [00:21<00:00, 43.49it/s]


epoch:3 finished
test acc:0.8806,average loss:0.30705920104980466


100%|██████████| 938/938 [00:21<00:00, 43.42it/s]


epoch:4 finished
test acc:0.8787,average loss:0.3131327026367188


100%|██████████| 938/938 [00:21<00:00, 43.39it/s]


epoch:5 finished
test acc:0.8856,average loss:0.2978225158691406


100%|██████████| 938/938 [00:21<00:00, 44.00it/s]


epoch:6 finished
test acc:0.8816,average loss:0.31645680236816404


100%|██████████| 938/938 [00:21<00:00, 44.02it/s]


epoch:7 finished
test acc:0.8807,average loss:0.32659322814941405


100%|██████████| 938/938 [00:22<00:00, 41.74it/s]


epoch:8 finished
test acc:0.8821,average loss:0.32430076904296873


100%|██████████| 938/938 [00:20<00:00, 44.75it/s]


epoch:9 finished
test acc:0.8835,average loss:0.3294613433837891


100%|██████████| 938/938 [00:20<00:00, 45.14it/s]


epoch:10 finished
test acc:0.8836,average loss:0.3229833068847656
