In [35]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torchvision import datasets,transforms
from tqdm import tqdm
from d2l import torch as d2l

# Fix the PyTorch random seeds (default generator and CUDA) so runs are repeatable.
torch.manual_seed(0)
torch.cuda.manual_seed(0)

## 先训练老师模型

In [36]:
# Network definition of the teacher model
class TeacherNet(nn.Module):
    """Convolutional classifier used as the teacher model.

    Layer stack: two 3x3 convolutions (1 -> 32 -> 64 channels), each followed
    by ReLU, a 2x2 max pooling, channel-wise dropout, flatten, and a
    9216 -> 128 -> 10 fully connected head with element-wise dropout before
    the last layer.

    The first linear layer expects 9216 = 64 * 12 * 12 features, which is what
    the convolution/pooling stack produces for single-channel 28x28 inputs.
    """

    def __init__(self) -> None:
        super(TeacherNet, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
            nn.MaxPool2d(2),
            # Channel-wise dropout is appropriate here: the input is still 4-D.
            nn.Dropout2d(0.3), nn.Flatten(),
            nn.Linear(9216, 128), nn.ReLU(),
            # The activations are 2-D (batch, features) at this point, so plain
            # element-wise dropout is used; Dropout2d is meant for 3-D/4-D
            # inputs and passing it 2-D data is deprecated in PyTorch.
            nn.Dropout(0.5),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Return the raw class scores (logits) produced by the layer stack for ``x``."""
        return self.net(x)

In [48]:
def train_teacher(model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader`` and print a progress bar.

    Each batch is moved to ``device``, the cross-entropy loss between the
    model output and the targets is back-propagated, and ``optimizer`` takes
    one step. After every batch a single-line progress bar is rewritten in
    place (carriage return, no newline).

    Args:
        model: network to optimise; it is put into training mode.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a sized ``dataset`` attribute.
        optimizer: optimizer holding the parameters of ``model``.
        epoch: epoch number shown in the progress line.

    Returns:
        None.
    """
    model.train()
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)
    trained_samples = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        trained_samples += len(data)
        # batch_idx is 0-based, so count the batch that has just finished.
        # This makes the bar reach exactly 50 ticks (100%) on the last batch,
        # including for loaders with only a few batches.
        progress = math.ceil((batch_idx + 1) * 50 / num_batches)
        print("\rTrain epoch %d:%d/%d,[%-51s]%d%%" % (epoch, trained_samples, num_samples, '-' * progress + '>', progress * 2), end='')

def test_teacher(model,device,test_loader):
    model.eval()
    test_loss=0
    correct=0

    with torch.no_grad():
        for data,target in test_loader:
            data,target =data.to(device),target.to(device)
            output=model(data)
            test_loss+=F.cross_entropy(output,target,reduction='sum').item()
            pred=output.argmax(dim=1,keepdim=True)
            correct+=pred.eq(target.view_as(pred)).sum().item()

    test_loss/=len(test_loader.dataset)

    print('\nTest:average loss:{:.4f},accuracy:{}/{} ({:.0f}%)'.format(test_loss,correct,len(test_loader.dataset),100.*correct/len(test_loader.dataset)))

    return test_loss,correct/len(test_loader.dataset)

In [49]:
def teacher_main(num_epochs=10, batch_size=64, data_root="../data/MNIST", checkpoint_path="teacher.pt"):
    """Train the teacher network on MNIST and evaluate it after every epoch.

    The state dict is written to ``checkpoint_path`` whenever the accuracy
    returned by ``test_teacher`` improves on the best value seen so far.
    Note that this checkpoint selection uses the test split, and that the
    returned model holds the weights of the last epoch, which is not
    necessarily the checkpointed (best) one.

    Args:
        num_epochs: number of training epochs.
        batch_size: batch size of the training loader (the test loader uses 1000).
        data_root: directory where the MNIST data is stored / downloaded to.
        checkpoint_path: file the best-scoring state dict is saved to.

    Returns:
        Tuple ``(model, teacher_history)`` where ``teacher_history`` is a list
        with one ``(test_loss, test_accuracy)`` entry per epoch.
    """
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Both splits use the same preprocessing, so build the pipeline once.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_loader = torch.utils.data.DataLoader(
        # download takes a boolean; the string "True" only worked because any
        # non-empty string is truthy.
        datasets.MNIST(data_root, train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_root, train=False, download=True, transform=transform),
        batch_size=1000, shuffle=True,
    )

    model = TeacherNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())

    teacher_history = []
    max_acc = 0.0
    for epoch in range(num_epochs):
        train_teacher(model, device, train_loader, optimizer, epoch)
        loss, acc = test_teacher(model, device, test_loader)

        teacher_history.append((loss, acc))
        # Keep only the best-scoring weights on disk.
        if acc > max_acc:
            torch.save(model.state_dict(), checkpoint_path)
            max_acc = acc

    return model, teacher_history

In [50]:
# Run teacher training; keep the returned model and the per-epoch (loss, accuracy) history.
teacher_model,teacher_history=teacher_main()

Train epoch 0:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0510,accuracy:9854/10000 (99%)
Train epoch 1:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0406,accuracy:9871/10000 (99%)
Train epoch 2:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0323,accuracy:9888/10000 (99%)
Train epoch 3:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0352,accuracy:9883/10000 (99%)
Train epoch 4:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0332,accuracy:9898/10000 (99%)
Train epoch 5:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0313,accuracy:9908/10000 (99%)
Train epoch 6:60000/60000,[-------------------------------------------------->]100%
Test:average loss:0.0337,accuracy:9908/10000 (99%)
Train epoch 7:60000/60000,[----------------------------