In [1]:
import torch
from torch import nn
import torchvision
from torchvision import datasets, transforms
from torch.utils import data
from torch.nn import functional as F

In [2]:
def load_data_from_FashionMNIST(batch_size, resize=None, num_workers=4, root='./data'):
    """Build train/test DataLoaders for the FashionMNIST dataset.

    Args:
        batch_size: number of samples per mini-batch for both loaders.
        resize: optional size passed to ``transforms.Resize``; when falsy the
            images keep their native size.
        num_workers: worker processes used by each DataLoader (default 4).
        root: directory the dataset is downloaded to / read from
            (default ``'./data'``).

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        its samples; the test loader iterates in a fixed order.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Prepend Resize so it is applied before ToTensor.
        trans.insert(0, transforms.Resize(resize))
    # Compose chains the individual transforms into a single callable.
    trans = transforms.Compose(trans)
    train_set = torchvision.datasets.FashionMNIST(root=root, train=True, transform=trans, download=True)
    test_set = torchvision.datasets.FashionMNIST(root=root, train=False, transform=trans, download=True)
    train_loader = data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers)
    test_loader = data.DataLoader(test_set, batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader

In [3]:
# Mini-batches of 256 images for both the training and the test loader.
train_loader, test_loader = load_data_from_FashionMNIST(batch_size=256)

In [4]:
def to_one_hot(test_batch, num_classes=10):
    """Convert a batch of integer class labels into one-hot float rows.

    Args:
        test_batch: tensor of class indices, one per sample (shape ``(B,)``;
            a ``(B, 1)`` column is flattened first).
        num_classes: width of each one-hot row (default 10, the number of
            FashionMNIST classes).

    Returns:
        Tensor of shape ``(B, num_classes)`` in the default floating dtype,
        with a single 1 per row at the label's index and 0 elsewhere. It is
        created on the same device as ``test_batch``.

    Raises:
        RuntimeError: if a label is negative or ``>= num_classes``.
    """
    # F.one_hot requires int64 indices and returns int64, so cast the result
    # back to the default float dtype to use it directly as a loss target.
    labels = test_batch.reshape(-1).long()
    return F.one_hot(labels, num_classes=num_classes).to(torch.get_default_dtype())

In [5]:
class MLP(nn.Module):
    """Two-layer perceptron: flatten -> Linear(784, 256) -> ReLU -> Linear(256, 10).

    Everything after the batch dimension is flattened and must total 784
    features (e.g. 1x28x28 images). ``forward`` returns raw logits of shape
    ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()
        # Creation order (flatten, fc1, relu, fc2) fixes both the RNG draws used
        # for weight initialisation and the order of state_dict keys.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)
import torch.nn as nn

class LeNet(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages followed by three linear layers.

    Args:
        in_channels: channels of the input image (default 1, grayscale).
        num_classes: number of output logits (default 10).
        input_size: height/width of the square input image (default 28).
            The input width of ``fc1`` is derived from it, so other image
            sizes no longer require editing that layer by hand.

    ``forward`` returns raw logits of shape ``(batch, num_classes)`` with no
    final activation, so pair it with a logits-based loss such as
    ``nn.CrossEntropyLoss`` or ``F.binary_cross_entropy_with_logits``.

    Raises:
        ValueError: if ``input_size`` is too small to survive the two
            conv/pool stages.
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=28):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.flatten = nn.Flatten()
        side = self._feature_side(input_size)
        # 16 feature maps of side x side (5 x 5 for the default 28 x 28 input).
        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        self.relu = nn.ReLU()

    @staticmethod
    def _feature_side(input_size):
        """Side length of the feature map that reaches ``flatten``."""
        side = input_size      # conv1: kernel 5, padding 2 keeps the size
        side = side // 2       # pool: kernel 2, stride 2
        side = side - 4        # conv2: kernel 5, no padding
        side = side // 2       # pool: kernel 2, stride 2
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for LeNet")
        return side

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)  # no activation: raw logits for a logits-based loss
        return x


In [6]:
class Accumulator:
    """Keep ``n`` running float sums that are updated together.

    Example: ``metric = Accumulator(2); metric.add(correct, total)``;
    then ``metric[0]`` and ``metric[1]`` hold the two totals.
    """

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add one value to each running sum, position by position.

        Raises:
            ValueError: if the number of values differs from the number of
                sums. A bare ``zip`` would silently truncate on a mismatch,
                shrinking ``self.data`` and discarding accumulated totals.
        """
        if len(args) != len(self.data):
            raise ValueError(f"expected {len(self.data)} values, got {len(args)}")
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        """Set every running sum back to 0.0, keeping the number of sums."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, item):
        return self.data[item]
def count_accurate(y_hat, y):
    """Count how many predictions match the true labels.

    Args:
        y_hat: either a ``(batch, num_classes)`` tensor of scores/logits, in
            which case the arg-max class of each row is used, or a tensor of
            already-predicted class indices.
        y: tensor of true class indices, one per sample.

    Returns:
        Number of correct predictions as a Python ``float``.
    """
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(dim=1)
    # A single vectorised comparison replaces the per-element Python loop.
    # Both sides are flattened so a (batch, 1) prediction column cannot
    # broadcast against a (batch,) label vector into a (batch, batch) matrix.
    matches = y_hat.type(y.dtype).reshape(-1) == y.reshape(-1)
    return float(matches.sum())
def calc_accuracy(net, data_iter):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

    Args:
        net: callable (typically a ``torch.nn.Module``) mapping a batch ``X``
            to class scores.
        data_iter: iterable of ``(X, y)`` mini-batches, e.g. a DataLoader.

    Returns:
        Accuracy as a float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.

    An ``nn.Module`` is put in eval mode for the evaluation and restored to
    its previous train/eval mode afterwards, so this has no lasting effect
    on the model's mode.
    """
    is_module = isinstance(net, torch.nn.Module)
    was_training = net.training if is_module else False
    if is_module:
        net.eval()  # enter evaluation mode
    metric = Accumulator(2)  # (correct predictions, samples seen)
    try:
        with torch.no_grad():
            for X, y in data_iter:
                metric.add(count_accurate(net(X), y), y.numel())
    finally:
        if is_module:
            net.train(was_training)
    return metric[0] / metric[1]

In [7]:
num_epochs = 20
lr1 = 0.1
student_model = MLP()
optimizer_s = torch.optim.SGD(params=student_model.parameters(), lr=lr1)
# Per-class sigmoid BCE on logits against one-hot targets (reused by later cells).
criterion = F.binary_cross_entropy_with_logits
# student model: baseline MLP trained directly on the hard labels
print("student model training")
for _ in range(num_epochs):
    total_loss1 = 0.0
    train_correct = 0
    train_num = 0
    student_model.train()
    for x, y in train_loader:
        one_hot_labels = to_one_hot(y)
        pred1 = student_model(x)
        train_num += one_hot_labels.shape[0]
        train_correct += count_accurate(pred1, y)
        # Unreduced loss: mean for the gradient step, sum for logging.
        l1 = criterion(pred1, one_hot_labels, reduction='none')
        optimizer_s.zero_grad()
        l1.mean().backward()
        optimizer_s.step()
        # .item() accumulates a plain float; adding the tensor itself would keep
        # its autograd history alive and chain it across the whole epoch.
        total_loss1 += l1.sum().item()
    train_acc = train_correct / train_num
    test_acc = calc_accuracy(student_model, test_loader)
    # Reported loss is per sample, summed over the per-class BCE terms.
    print(f'loss:{total_loss1 / train_num}, train_acc:{train_acc}, test_acc:{test_acc}')

student model training
loss:3.0192267894744873, train_acc:0.5307166666666666, test_acc:0.6516
loss:1.970055103302002, train_acc:0.6667166666666666, test_acc:0.6588
loss:1.626470923423767, train_acc:0.6799, test_acc:0.6712
loss:1.4566562175750732, train_acc:0.7018833333333333, test_acc:0.7057
loss:1.3486512899398804, train_acc:0.73075, test_acc:0.7297
loss:1.2700363397598267, train_acc:0.7522333333333333, test_acc:0.7481
loss:1.2082586288452148, train_acc:0.76695, test_acc:0.7632
loss:1.1580709218978882, train_acc:0.7793833333333333, test_acc:0.7731
loss:1.1164493560791016, train_acc:0.7902333333333333, test_acc:0.7837
loss:1.0817862749099731, train_acc:0.7976, test_acc:0.7856
loss:1.0526635646820068, train_acc:0.80405, test_acc:0.7944
loss:1.0276020765304565, train_acc:0.8084666666666667, test_acc:0.7977
loss:1.0055186748504639, train_acc:0.8128, test_acc:0.8028
loss:0.9866871237754822, train_acc:0.8168833333333333, test_acc:0.8082
loss:0.9700549244880676, train_acc:0.8195166666666667,

In [8]:
num_epochs = 20
lr2 = 0.9
teacher_model = LeNet()
optimizer_t = torch.optim.SGD(params=teacher_model.parameters(), lr=lr2)
# teacher model: LeNet trained on the hard labels with the same criterion as the student
print("teacher model training")
for _ in range(num_epochs):
    total_loss2 = 0.0
    train_correct = 0
    train_num = 0
    teacher_model.train()
    for x, y in train_loader:
        one_hot_labels = to_one_hot(y)
        pred2 = teacher_model(x)
        train_num += one_hot_labels.shape[0]
        train_correct += count_accurate(pred2, y)
        # Unreduced loss: mean for the gradient step, sum for logging.
        l2 = criterion(pred2, one_hot_labels, reduction='none')
        optimizer_t.zero_grad()
        l2.mean().backward()
        optimizer_t.step()
        # .item() accumulates a plain float; adding the tensor itself would keep
        # its autograd history alive and chain it across the whole epoch.
        total_loss2 += l2.sum().item()
    train_acc = train_correct / train_num
    test_acc = calc_accuracy(teacher_model, test_loader)
    # Reported loss is per sample, summed over the per-class BCE terms.
    print(f'loss:{total_loss2 / train_num}, train_acc:{train_acc}, test_acc:{test_acc}')

teacher model training
loss:2.3286850452423096, train_acc:0.47696666666666665, test_acc:0.7064
loss:1.103438377380371, train_acc:0.75505, test_acc:0.7658
loss:0.9096980690956116, train_acc:0.8043166666666667, test_acc:0.7987
loss:0.8129816651344299, train_acc:0.829, test_acc:0.8345
loss:0.741916835308075, train_acc:0.8459166666666667, test_acc:0.8529
loss:0.692064106464386, train_acc:0.85575, test_acc:0.8297
loss:0.6487722396850586, train_acc:0.8658333333333333, test_acc:0.8579
loss:0.622210681438446, train_acc:0.8725666666666667, test_acc:0.8541
loss:0.594935953617096, train_acc:0.8773666666666666, test_acc:0.8591
loss:0.5793455839157104, train_acc:0.8813, test_acc:0.8544
loss:0.5590655207633972, train_acc:0.88565, test_acc:0.8719
loss:0.5408010482788086, train_acc:0.8895166666666666, test_acc:0.8777
loss:0.5255268812179565, train_acc:0.8934166666666666, test_acc:0.8793
loss:0.5084717273712158, train_acc:0.8966833333333334, test_acc:0.8576
loss:0.49864572286605835, train_acc:0.8981333

In [11]:
# distill: train a fresh MLP student on the frozen teacher's softmax outputs
num_epochs = 20
lr3 = 0.5
print("distill training")
teacher_model.eval()  # the teacher is only used for inference here
student_model = MLP()
optimizer_s = torch.optim.SGD(params=student_model.parameters(), lr=lr3)
for _ in range(num_epochs):
    total_loss1 = 0.0
    train_correct = 0
    train_num = 0
    student_model.train()
    for x, y in train_loader:
        # The teacher's targets are constants for the student's loss. Without
        # no_grad the teacher's graph is built every batch, and the student's
        # backward() also propagates through the BCE target into the teacher,
        # accumulating .grad on teacher parameters that is never cleared.
        with torch.no_grad():
            soft_labels = F.softmax(teacher_model(x), dim=1)
        pred1 = student_model(x)
        train_num += soft_labels.shape[0]
        # Training accuracy is still measured against the true labels y.
        train_correct += count_accurate(pred1, y)
        l1 = criterion(pred1, soft_labels, reduction='none')
        optimizer_s.zero_grad()
        l1.mean().backward()
        optimizer_s.step()
        # .item(): accumulate a float, not a tensor that keeps autograd history.
        total_loss1 += l1.sum().item()
    train_acc = train_correct / train_num
    test_acc = calc_accuracy(student_model, test_loader)
    print(f'loss:{total_loss1 / train_num}, train_acc:{train_acc}, test_acc:{test_acc}')

distill training
loss:1.8577848672866821, train_acc:0.6235666666666667, test_acc:0.6999
loss:1.1260950565338135, train_acc:0.7608666666666667, test_acc:0.7715
loss:0.9580980539321899, train_acc:0.7983333333333333, test_acc:0.798
loss:0.8693897724151611, train_acc:0.8115666666666667, test_acc:0.803
loss:0.8128560185432434, train_acc:0.8203666666666667, test_acc:0.7985
loss:0.7778867483139038, train_acc:0.8237, test_acc:0.8148
loss:0.7482032179832458, train_acc:0.8277833333333333, test_acc:0.8106
loss:0.7225842475891113, train_acc:0.8323333333333334, test_acc:0.8185
loss:0.7027120590209961, train_acc:0.8356666666666667, test_acc:0.8201
loss:0.6861705183982849, train_acc:0.8387, test_acc:0.8196
loss:0.6714637279510498, train_acc:0.8411166666666666, test_acc:0.8333
loss:0.6561890244483948, train_acc:0.8437166666666667, test_acc:0.8343
loss:0.6449607610702515, train_acc:0.8450333333333333, test_acc:0.8279
loss:0.6346712708473206, train_acc:0.8467666666666667, test_acc:0.8338
loss:0.62397187