In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

In [3]:
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True

In [5]:
# 载入训练集
train_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

# 载入测试集
test_dataset = torchvision.datasets.MNIST(
    root="dataset/",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

# 生成dataloader
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [33]:
print(train_dataset)

Dataset MNIST
    Number of datapoints: 60000
    Root location: dataset/
    Split: Train
    StandardTransform
Transform: ToTensor()


In [8]:
class TeacherModel(nn.Module):
    def __init__(self,in_channels=1,num_classes=10):
        super().__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,1024)
        self.fc2 = nn.Linear(1024,1024)
        self.fc3 = nn.Linear(1024,num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        
        return x

In [9]:
model = TeacherModel()
model = model.to(device)

In [10]:
summary(model)

Layer (type:depth-idx)                   Param #
TeacherModel                             --
├─ReLU: 1-1                              --
├─Linear: 1-2                            803,840
├─Linear: 1-3                            1,049,600
├─Linear: 1-4                            10,250
├─Dropout: 1-5                           --
Total params: 1,863,690
Trainable params: 1,863,690
Non-trainable params: 0

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr = 1e-4)

In [13]:
epochs = 6
for epoch in range(epochs):
    model.train()
    
    for data,targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        preds = model(data)
        loss = criterion(preds,targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x,y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
    
    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

100%|██████████| 1875/1875 [00:05<00:00, 348.01it/s]


Epoch:1	 Accuracy:0.9386


100%|██████████| 1875/1875 [00:05<00:00, 362.72it/s]


Epoch:2	 Accuracy:0.9616


100%|██████████| 1875/1875 [00:05<00:00, 364.28it/s]


Epoch:3	 Accuracy:0.9682


100%|██████████| 1875/1875 [00:05<00:00, 358.76it/s]


Epoch:4	 Accuracy:0.9731


100%|██████████| 1875/1875 [00:05<00:00, 363.01it/s]


Epoch:5	 Accuracy:0.9747


100%|██████████| 1875/1875 [00:05<00:00, 359.78it/s]


Epoch:6	 Accuracy:0.9783


In [14]:
teacher_model = model

In [19]:
class StudentModel(nn.Module):
    def __init__(self,in_channels = 1,num_classes = 10):
        super().__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784,32)
        self.fc2 = nn.Linear(32,32)
        self.fc3 = nn.Linear(32,num_classes)

    def forward(self,x):
        x = x.view(-1,784)
        
        x = self.fc1(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        return x

In [20]:
model = StudentModel()
model = model.to(device)

In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr = 1e-4)

In [22]:
epochs = 3
for epoch in range(epochs):
    model.train()
    
    for data,targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        preds = model(data)
        loss = criterion(preds,targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x,y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
    
    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

100%|██████████| 1875/1875 [00:05<00:00, 372.31it/s]


Epoch:1	 Accuracy:0.8832


100%|██████████| 1875/1875 [00:05<00:00, 370.94it/s]


Epoch:2	 Accuracy:0.9067


100%|██████████| 1875/1875 [00:05<00:00, 372.76it/s]


Epoch:3	 Accuracy:0.9148


In [24]:
student_model_scratch = model

# 知识蒸馏训练学生模型

In [30]:
teacher_model.eval()

model = StudentModel()
model = model.to(device)
model.train()

Temp = 3 # 超参数之一

In [36]:
hard_loss = nn.CrossEntropyLoss()
alpha = 0.3 # 超参数之一

soft_loss = nn.KLDivLoss(reduction="batchmean")

optimizer = torch.optim.Adam(model.parameters(),lr = 1e-4)

In [37]:
epochs = 3
for epoch in range(epochs):
    
    for data,targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        
        with torch.no_grad():
            teacher_preds = teacher_model(data)
            
        student_preds = model(data)
        
        loss = alpha * (Temp * Temp) * soft_loss(F.softmax(student_preds / Temp,dim = 1),F.softmax(teacher_preds / Temp,dim = 1)) +(1-alpha)* hard_loss(student_preds,targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    model.eval()
    num_correct = 0
    num_samples = 0
    
    with torch.no_grad():
        for x,y in test_loader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()
    
    model.train()
    print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

100%|██████████| 1875/1875 [00:05<00:00, 325.35it/s]


Epoch:1	 Accuracy:0.9237


100%|██████████| 1875/1875 [00:05<00:00, 324.93it/s]


Epoch:2	 Accuracy:0.9274


100%|██████████| 1875/1875 [00:05<00:00, 327.66it/s]


Epoch:3	 Accuracy:0.9295
