# 知识蒸馏Pytorch

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.backends
from torch.utils.data import DataLoader
from torchinfo import summary

from tqdm import tqdm

In [12]:
# 设置随机数种子，便于复现
torch.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
device
torch.backends.cudnn.benchmark = True

# 载入MNIST数据集

In [16]:
# ----------------------------------------------------------------------
# 载入训练集
# ----------------------------------------------------------------------
train_dataset = torchvision.datasets.MNIST(
    root = '/home/xianghao/data/',
    train = True,
    transform = torchvision.transforms.ToTensor(),
    download=False
)

test_dataset = torchvision.datasets.MNIST(
    root = '/home/xianghao/data/',
    train = False,
    transform = torchvision.transforms.ToTensor(),
    download=False
)

# 生成dataloader
train_dataloader = DataLoader(
    dataset = train_dataset,
    batch_size = 32,
    shuffle = True
)

test_dataloader = DataLoader(
    dataset = test_dataset, 
    batch_size = 32, 
    shuffle = False
)

# 教师模型

In [26]:
class TeacherModel(nn.Module):
    def __init__(self, in_channels = 1, num_classes = 10):
        super().__init__()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 2000)
        self.fc3 = nn.Linear(2000, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        
        return x

# 从头训练模型

In [27]:
model = TeacherModel()
model = model.to(device)
summary(model)

Layer (type:depth-idx)                   Param #
TeacherModel                             --
├─Linear: 1-1                            942,000
├─Linear: 1-2                            2,402,000
├─Linear: 1-3                            20,010
├─ReLU: 1-4                              --
├─Dropout: 1-5                           --
Total params: 3,364,010
Trainable params: 3,364,010
Non-trainable params: 0

In [28]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
epochs = 6
for epoch in range(epochs):
    model.train()
    for batch_idx, (inputs, targets) in tqdm(enumerate(train_dataloader)):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # 测试集评估
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)
            
            preds = model(x) # (batch, 10)
            predictions = preds.max(dim=1).indices
            num_correct += (predictions==y).sum()
            num_samples = predictions.size(0)
        acc = (num_correct/num_samples).item()
    print(f"Epoch:{epoch}\tAccuracy:{acc:.4f}")
        

1875it [00:14, 133.17it/s]


Epoch:0	Accuracy:592.0625


1875it [00:06, 308.83it/s]


Epoch:1	Accuracy:601.9375


256it [00:00, 314.33it/s]