In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

In [2]:
# Define TemperatureSoftmax
class TemperatureSoftmax(nn.Module):
    def __init__(self):
        super(TemperatureSoftmax, self).__init__()

    def forward(self, input_tensor):
        # 对 dim=1 进行带有温度系数的 softmax
        return torch.softmax(input_tensor / 1, dim=1)

# Define teacher model
teacher = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Dropout(0.25),
    nn.Conv2d(32, 64, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Dropout(0.25),
    nn.Conv2d(64, 128, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Dropout(0.25),
    nn.Flatten(),
    nn.Linear(128, 625),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(625, 10),
    TemperatureSoftmax(),
)

# X = torch.randn(1, 1, 28, 28)
# for layer in teacher:
    # X = layer(X)
    # print(layer.__class__.__name__, "output shape:\t", X.shape)

In [3]:
# Define student model
student = nn.Sequential(
    nn.Flatten(), nn.Linear(28 * 28, 1000), nn.Sigmoid(),nn.Dropout(0.5), nn.Linear(1000, 10)
)

In [4]:
def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0.0)

student.apply(weights_init)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=1000, bias=True)
  (2): Sigmoid()
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=1000, out_features=10, bias=True)
)

In [None]:
# Load MNIST dataset and create data loaders
import data_loader

train_iter, test_iter = data_loader.load_data_fashion_mnist(batch_size=128)

In [13]:
for X, y in train_iter:
    print(X.shape, y.shape)
    break

torch.Size([128, 1, 28, 28]) torch.Size([128])


In [None]:
from d2l import torch as d2l
d2l.load_data_fashion_mnist(batch_size=128, resize=None)

In [6]:
# Move models and data to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = teacher.to(device)
student = student.to(device)

In [7]:
# Define loss and optimizer
criterion_teacher = nn.CrossEntropyLoss()
criterion_student = nn.CrossEntropyLoss()
optimizer_teacher = optim.RMSprop(teacher.parameters(), lr=1e-4)
optimizer_student = optim.SGD(student.parameters(), lr=0.01)

In [8]:
# Define teacher evaluation function
def evaluate_teacher(net, criterion, testloader, device):
    net.eval()	# Set model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in testloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)
            
            running_loss += loss.item()
            _, predicted = y_hat.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    print(f"Test loss = {running_loss / len(testloader)}, accuracy = {correct / total}")

In [9]:
# Define teacher training function
def train_teacher(net, optimizer, criterion, trainloader, testloader, device, num_epochs=10):
    for epoch in range(num_epochs):
        net.train()	# Set model to training mode
        running_loss = 0.0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, loss = {running_loss / len(trainloader)}")
        if epoch % 5 == 5 - 1:
            evaluate_teacher(net, criterion, testloader, device)

In [10]:
train_teacher(teacher, optimizer_teacher, criterion_teacher, trainloader, testloader, device)

Epoch 1, loss = 1.8600780948647049
Epoch 2, loss = 1.6323614804221114
Epoch 3, loss = 1.575677398171252
Epoch 4, loss = 1.5538970979292002
Epoch 5, loss = 1.540759448303597
Test loss = 1.508636889578421, accuracy = 0.9548
Epoch 6, loss = 1.5320327200615076
Epoch 7, loss = 1.5264999281877139
Epoch 8, loss = 1.5212465628886274
Epoch 9, loss = 1.5161578268892983
Epoch 10, loss = 1.5114354317122176
Test loss = 1.491573232638685, accuracy = 0.9701


In [11]:
# Define student evaluation function
def evaluate_student(net, criterion, testloader, device):
    net.eval()	# Set model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in testloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)
            
            running_loss += loss.item()
            _, predicted = y_hat.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    print(f"Test loss = {running_loss / len(testloader)}, accuracy = {correct / total}")

In [12]:
# Define student training function
def train_student(net, optimizer, criterion, trainloader,testloader, device, num_epochs=10):
    for epoch in range(num_epochs):
        net.train()	# Set model to training mode
        running_loss = 0.0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, loss = {running_loss / len(trainloader)}")
        if epoch % 5 == 5 - 1:
            evaluate_student(net, criterion, testloader, device)

In [13]:
train_student(student, optimizer_student, criterion_student, trainloader, testloader, device)

Epoch 1, loss = 1.629913806406928
Epoch 2, loss = 0.914098279054231
Epoch 3, loss = 0.7158152863287977
Epoch 4, loss = 0.6220535689960919
Epoch 5, loss = 0.5684157444088698
Test loss = 0.4377809646952001, accuracy = 0.8846
Epoch 6, loss = 0.5328225425438586
Epoch 7, loss = 0.5026534875191605
Epoch 8, loss = 0.48351090104341
Epoch 9, loss = 0.464187063642148
Epoch 10, loss = 0.45269791257661035
Test loss = 0.36054175620592094, accuracy = 0.8988


In [14]:
def train_distill(teacher, student, optimizer_student, criterion_student, trainloader, testloader, device, num_epochs=10):
    for epoch in range(num_epochs):
        teacher.eval()
        student.train()
        running_loss = 0.0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            y_teacher = teacher(X)
            y_hat = student(X)
            loss = 5 * criterion_student(y_hat, y) + F.kl_div(y_hat.log(), y_teacher, reduction="batchmean")
            
            optimizer_student.zero_grad()
            loss.backward()
            optimizer_student.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, loss = {running_loss / len(trainloader)}")
        if epoch % 5 == 5 - 1:
            evaluate_student(student, criterion_student, testloader, device)

In [15]:
train_distill(teacher, student, optimizer_student, criterion_student, trainloader, testloader, device, num_epochs=50)

Epoch 1, loss = nan
Epoch 2, loss = nan
Epoch 3, loss = nan
Epoch 4, loss = -1.4691381464634876
Epoch 5, loss = -1.658426731380064
Test loss = 0.2927684231390116, accuracy = 0.9144
Epoch 6, loss = -1.822038261112628
Epoch 7, loss = -1.9480363682134827
Epoch 8, loss = -2.0780909328318353
Epoch 9, loss = -2.1798136198698583
Epoch 10, loss = -2.3125697112540955
Test loss = 0.2574114916209556, accuracy = 0.925
Epoch 11, loss = -2.3981384280393883
Epoch 12, loss = -2.477684957386334
Epoch 13, loss = -2.592115228364208
Epoch 14, loss = -2.667604061077907
Epoch 15, loss = -2.7550074886411493
Test loss = 0.23327444908739645, accuracy = 0.9301
Epoch 16, loss = -2.8285503107855825
Epoch 17, loss = -2.8967212115777836
Epoch 18, loss = -2.971695560382119
Epoch 19, loss = -3.039699140133888
Epoch 20, loss = -3.09438220880179
Test loss = 0.20313910609464855, accuracy = 0.9407
Epoch 21, loss = -3.162474264722389
Epoch 22, loss = -3.232210760685935
Epoch 23, loss = -3.28618391185427
Epoch 24, loss = -