In [None]:
import torch
from torch import nn
import torchvision
from torchvision import datasets, transforms
from torch.utils import data
from torch.nn import functional as F

In [None]:
def load_data_from_FashionMNIST(batch_size, resize=None, num_workers=4, root='./data'):
    """Build the Fashion-MNIST train/test DataLoaders.

    Args:
        batch_size: Samples per mini-batch for both loaders.
        resize: Optional size passed to ``transforms.Resize``; when given it is
            inserted before ``ToTensor`` in the transform pipeline.
        num_workers: Worker processes per DataLoader. Defaults to the previous
            hard-coded value (4); pass 0 to load in the main process, which
            avoids multiprocessing issues some notebook setups have.
        root: Directory torchvision downloads the dataset to / reads it from.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Put Resize first so it runs before ToTensor.
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)  # chain the transforms into one callable
    train_set = torchvision.datasets.FashionMNIST(root=root, train=True, transform=trans, download=True)
    test_set = torchvision.datasets.FashionMNIST(root=root, train=False, transform=trans, download=True)
    return (data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers),
            data.DataLoader(test_set, batch_size, shuffle=False, num_workers=num_workers))

In [None]:
# Shared Fashion-MNIST loaders (mini-batches of 256) consumed by the training cells below.
train_loader, test_loader = load_data_from_FashionMNIST(batch_size=256)

In [None]:
def to_one_hot(test_batch, num_classes=10):
    """Convert a batch of integer class labels into one-hot float rows.

    Args:
        test_batch: Tensor of class indices with one entry per sample
            (shape ``(N,)``; ``(N, 1)`` is also accepted). Despite the name it
            is used for training labels as well.
        num_classes: Width of each one-hot row. Defaults to 10, the number of
            logits the notebook's models produce.

    Returns:
        Tensor of shape ``(N, num_classes)`` created by ``torch.zeros`` (default
        float dtype) with a single 1 in column ``label`` of each row.
    """
    batch_size = test_batch.shape[0]
    # .long() truncates toward zero, matching int(label) element-wise.
    labels = test_batch.reshape(batch_size).long()
    one_hot = torch.zeros(size=(batch_size, num_classes))
    # Vectorized assignment: one_hot[i, labels[i]] = 1 for every row at once,
    # instead of a Python loop over the batch.
    one_hot[torch.arange(batch_size), labels] = 1
    return one_hot

In [None]:
class MLP(nn.Module):
    """Single-hidden-layer perceptron: flatten -> Linear -> ReLU -> Linear.

    Used as the student network. The defaults reproduce the fixed
    784 -> 256 -> 10 architecture (784 = a flattened 1x28x28 Fashion-MNIST
    image, 10 = class logits). The sizes are constructor arguments so the class
    can be reused for other shapes.

    Args:
        num_inputs: Features per sample after flattening.
        num_hiddens: Width of the hidden layer.
        num_outputs: Number of output logits.
    """

    def __init__(self, num_inputs=784, num_hiddens=256, num_outputs=10):
        super().__init__()
        # Attribute names/order are kept stable so state_dict keys do not change.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(num_inputs, num_hiddens)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(num_hiddens, num_outputs)

    def forward(self, x):
        """Return raw (un-normalized) logits of shape ``(batch, num_outputs)``."""
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
import torch.nn as nn

class LeNet(nn.Module):
    """LeNet-style CNN used as the teacher network.

    Layout::

        conv(1->6, 5x5, pad 2) -> ReLU -> 2x2 max-pool
        conv(6->16, 5x5)       -> ReLU -> 2x2 max-pool
        flatten -> 120 -> ReLU -> 84 -> ReLU -> 10 logits

    ``fc1`` expects 16*5*5 flattened features, which matches 1x28x28 inputs
    (28 -> 28 -> 14 -> 10 -> 5). Adjust it if the input resolution changes.
    """

    def __init__(self):
        super().__init__()
        # Keep this construction order: it determines how parameter
        # initialization consumes the RNG and the parameters()/state_dict order.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an image batch to raw class logits (no softmax; the loss takes logits)."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(self.relu(conv(x)))
        x = self.flatten(x)
        for fc in (self.fc1, self.fc2):
            x = self.relu(fc(x))
        return self.fc3(x)


In [None]:
class Accumulator:
    """Keep ``n`` running float totals, e.g. ``[correct predictions, samples seen]``.

    Args:
        n: Number of independent totals to track.
    """

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add ``args[i]`` (converted with ``float``) to total ``i``.

        Raises:
            ValueError: if ``len(args)`` differs from the number of totals.
                The count is checked explicitly because ``zip`` would silently
                drop unmatched totals (shrinking ``self.data``) or ignore
                extra values.
        """
        if len(args) != len(self.data):
            raise ValueError(
                f"expected {len(self.data)} values, got {len(args)}")
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        """Set every total back to 0.0; the number of totals is unchanged."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, item):
        """Return total ``item`` (a slice returns a list of totals)."""
        return self.data[item]
def count_accurate(y_hat, y):
    """Count how many predictions in ``y_hat`` match the labels ``y``.

    Args:
        y_hat: Either a 2-D tensor of per-class scores (argmax over dim 1 is
            taken when it has more than one column) or a tensor of predicted
            labels with one entry per sample.
        y: Tensor of ground-truth labels, one per sample.

    Returns:
        Number of correct predictions as a Python ``float``.
    """
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(dim=1)
    # Vectorized comparison instead of a per-element Python loop. Reshaping to
    # y.shape makes an (N, 1) prediction column compare element-wise with (N,)
    # labels rather than broadcasting to an (N, N) matrix.
    matches = y_hat.type(y.dtype).reshape(y.shape) == y
    return float(matches.sum())
def calc_accuracy(net, data_iter):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    Runs under ``torch.no_grad()``. If ``net`` is an ``nn.Module`` it is put in
    evaluation mode for the pass. Afterwards it is returned to whatever mode
    (train/eval) it was in, so the caller's training state is not changed
    as a side effect.

    Args:
        net: Callable mapping an input batch to class scores or predicted labels.
        data_iter: Iterable of ``(X, y)`` batches.

    Returns:
        Fraction of correctly classified samples (correct / total).

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    was_training = None
    if isinstance(net, torch.nn.Module):
        was_training = net.training
        net.eval()  # enter evaluation mode
    metric = Accumulator(2)  # [correct predictions, samples seen]
    try:
        with torch.no_grad():
            for X, y in data_iter:
                metric.add(count_accurate(net(X), y), y.numel())
    finally:
        if was_training is not None:
            net.train(was_training)
    return metric[0] / metric[1]

In [None]:
# Baseline: train the MLP student directly on the hard (one-hot) labels.
num_epochs = 20
lr1 = 0.5
student_model = MLP()
optimizer_s = torch.optim.SGD(params=student_model.parameters(), lr=lr1)
# Loss on raw logits; also reused by the teacher and distillation cells.
criterion = F.binary_cross_entropy_with_logits
print("student model training")
for _ in range(num_epochs):
    total_loss1 = 0.0
    train_correct = 0
    train_num = 0
    student_model.train()  # (re-)enable training mode at the start of every epoch
    for x, y in train_loader:
        one_hot_labels = to_one_hot(y)
        pred1 = student_model(x)
        train_num += one_hot_labels.shape[0]
        train_correct += count_accurate(pred1, y)
        # Per-element loss so the epoch total can be reported below.
        l1 = criterion(pred1, one_hot_labels, reduction='none')
        optimizer_s.zero_grad()
        l1.mean().backward()
        optimizer_s.step()
        # .item() yields a plain float. Accumulating the tensor itself would
        # chain autograd history across the epoch and print with grad_fn.
        total_loss1 += l1.sum().item()
    train_acc = train_correct / train_num
    test_acc = calc_accuracy(student_model, test_loader)
    print(f'loss:{total_loss1 / train_num}, train_acc:{train_acc}, test_acc:{test_acc}')

In [None]:
# Teacher: train the LeNet CNN on the hard (one-hot) labels with the same loss.
num_epochs = 20
lr2 = 0.9
teacher_model = LeNet()
optimizer_t = torch.optim.SGD(params=teacher_model.parameters(), lr=lr2)
print("teacher model training")
for _ in range(num_epochs):
    total_loss2 = 0.0
    train_correct = 0
    train_num = 0
    teacher_model.train()  # (re-)enable training mode at the start of every epoch
    for x, y in train_loader:
        one_hot_labels = to_one_hot(y)
        pred2 = teacher_model(x)
        train_num += one_hot_labels.shape[0]
        train_correct += count_accurate(pred2, y)
        # Per-element loss so the epoch total can be reported below.
        l2 = criterion(pred2, one_hot_labels, reduction='none')
        optimizer_t.zero_grad()
        l2.mean().backward()
        optimizer_t.step()
        # .item() yields a plain float. Accumulating the tensor itself would
        # chain autograd history across the epoch and print with grad_fn.
        total_loss2 += l2.sum().item()
    train_acc = train_correct / train_num
    test_acc = calc_accuracy(teacher_model, test_loader)
    print(f'loss:{total_loss2 / train_num}, train_acc:{train_acc}, test_acc:{test_acc}')

In [None]:
# Distillation: train a fresh MLP student on the teacher's softmax outputs
# (soft labels) instead of the one-hot ground truth.
num_epochs = 10
lr3 = 0.5
print("distill training")
teacher_model.eval()
student_model = MLP()
optimizer_s = torch.optim.SGD(params=student_model.parameters(), lr=lr3)
for _ in range(num_epochs):
    total_loss1 = 0.0
    train_correct = 0
    train_num = 0
    student_model.train()  # (re-)enable training mode at the start of every epoch
    for x, y in train_loader:
        # The teacher is frozen. Without no_grad its forward pass is recorded by
        # autograd every batch, and the student's loss gradient can flow back
        # through the soft targets into the teacher's parameters.
        with torch.no_grad():
            soft_labels = F.softmax(teacher_model(x), dim=1)
        pred1 = student_model(x)
        train_num += soft_labels.shape[0]
        train_correct += count_accurate(pred1, y)  # accuracy still uses the true labels
        l1 = criterion(pred1, soft_labels, reduction='none')
        optimizer_s.zero_grad()
        l1.mean().backward()
        optimizer_s.step()
        # .item() yields a plain float. Accumulating the tensor itself would
        # chain autograd history across the epoch and print with grad_fn.
        total_loss1 += l1.sum().item()
    train_acc = train_correct / train_num
    test_acc = calc_accuracy(student_model, test_loader)
    print(f'loss:{total_loss1 / train_num}, train_acc:{train_acc}, test_acc:{test_acc}')