In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

In [None]:
# Define TemperatureSoftmax
class TemperatureSoftmax(nn.Module):
    """Softmax over ``dim=1`` with the input divided by a temperature first.

    Args:
        temperature: Positive divisor applied to the input before the softmax.
            The default of 1 gives a plain softmax; larger values produce a
            softer (more uniform) distribution.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature=1.0):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature

    def forward(self, input_tensor):
        """Return ``softmax(input_tensor / temperature)`` computed along ``dim=1``."""
        # Temperature-scaled softmax along dim=1 (the class dimension of a batch).
        return torch.softmax(input_tensor / self.temperature, dim=1)

# Define teacher model
# Three identical conv stages (3x3 conv -> ReLU -> 2x2 max-pool -> dropout 0.25)
# that widen the channels 1 -> 32 -> 64 -> 128, followed by a two-layer classifier.
# For a 1x28x28 input the spatial size shrinks to 1x1, hence Linear(128, ...).
# The final TemperatureSoftmax means the model outputs probabilities, not logits.
teacher = nn.Sequential(
    *(
        layer
        for in_channels, out_channels in ((1, 32), (32, 64), (64, 128))
        for layer in (
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),
        )
    ),
    nn.Flatten(),
    nn.Linear(128, 625),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(625, 10),
    TemperatureSoftmax(),
)

# Debug aid: uncomment to print the output shape after each teacher layer.
# X = torch.randn(1, 1, 28, 28)
# for layer in teacher:
    # X = layer(X)
    # print(layer.__class__.__name__, "output shape:\t", X.shape)

In [None]:
# Define student model
# Small fully connected network: flatten -> 1000 sigmoid units -> dropout -> 10 outputs.
# There is no softmax at the end, so the student emits raw logits.
student = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=28 * 28, out_features=1000),
    nn.Sigmoid(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=1000, out_features=10),
)

In [None]:
def weights_init(m):
    """Xavier-initialise the weight of an ``nn.Linear`` layer and zero its bias.

    Intended to be passed to ``Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched.

    Args:
        m: The module to (possibly) re-initialise in place.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # A layer built with bias=False has m.bias set to None; skip it instead of crashing.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

# Run weights_init over the student and its submodules (only nn.Linear layers are changed).
student.apply(weights_init)

In [None]:
# Load MNIST dataset and create data loaders
# Convert images to tensors, then normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split first, test split second; both are downloaded to ../data if missing.
trainset, testset = (
    torchvision.datasets.MNIST(root="../data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps the dataset order.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Pick the GPU when CUDA is available (otherwise the CPU) and move both models onto it.
# Batches are moved to `device` later, inside the training and evaluation loops.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
teacher, student = teacher.to(device), student.to(device)

In [None]:
# Define losses and optimizers
def criterion_teacher(probs, target):
    """Cross-entropy for a model that already outputs class probabilities.

    The teacher ends in a softmax layer, so its output must not be fed to
    ``nn.CrossEntropyLoss``: that loss applies log-softmax itself and would
    normalise the distribution a second time. Instead, take the log of the
    probabilities and apply the negative log-likelihood directly.

    Args:
        probs: Tensor of class probabilities with classes along ``dim=1``.
        target: Tensor of integer class labels.

    Returns:
        Scalar loss tensor (mean over the batch).
    """
    # Clamp before the log so an underflowed probability cannot produce -inf.
    return F.nll_loss(torch.log(probs.clamp_min(1e-12)), target)


# The student has no softmax layer (it emits logits), so CrossEntropyLoss is correct for it.
criterion_student = nn.CrossEntropyLoss()
optimizer_teacher = optim.RMSprop(teacher.parameters(), lr=1e-4)
optimizer_student = optim.SGD(student.parameters(), lr=0.01)

In [None]:
# Define teacher evaluation function
def evaluate_teacher(net, criterion, testloader, device):
    """Evaluate ``net`` on ``testloader``, print the metrics and return them.

    Args:
        net: Model to evaluate. It is switched to eval mode and left in it.
        criterion: Callable ``criterion(y_hat, y)`` returning a scalar loss tensor.
        testloader: Iterable of ``(X, y)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the batch losses averaged over the
        number of batches, and the fraction of samples whose arg-max along
        ``dim=1`` equals the label. Both are ``nan`` if no batch was seen.
    """
    net.eval()  # Set model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted here so loaders without __len__ also work
    with torch.no_grad():
        for X, y in testloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)

            running_loss += criterion(y_hat, y).item()
            predicted = y_hat.argmax(dim=1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
            num_batches += 1

    # Avoid ZeroDivisionError on an empty loader.
    mean_loss = running_loss / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    print(f"Test loss = {mean_loss}, accuracy = {accuracy}")
    return mean_loss, accuracy

In [None]:
# Define teacher training function
def train_teacher(net, optimizer, criterion, trainloader, testloader, device, num_epochs=10, eval_every=5):
    """Train ``net`` on ``trainloader`` and periodically evaluate it.

    Args:
        net: Model to train (updated in place).
        optimizer: Optimizer holding ``net``'s parameters.
        criterion: Callable ``criterion(y_hat, y)`` returning a scalar loss tensor.
        trainloader: Iterable of ``(X, y)`` training batches.
        testloader: Iterable of ``(X, y)`` batches passed to ``evaluate_teacher``.
        device: Device each batch is moved to before the forward pass.
        num_epochs: Number of passes over ``trainloader``.
        eval_every: Run ``evaluate_teacher`` after every ``eval_every``-th epoch.
            ``0`` or ``None`` disables evaluation. The default of 5 matches the
            previously hard-coded interval.
    """
    for epoch in range(num_epochs):
        net.train()  # evaluation switches to eval mode, so re-enable training each epoch
        running_loss = 0.0
        num_batches = 0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Avoid ZeroDivisionError when the loader yields nothing.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}, loss = {mean_loss}")
        if eval_every and (epoch + 1) % eval_every == 0:
            evaluate_teacher(net, criterion, testloader, device)

In [None]:
# Train the teacher with the default number of epochs (10); test metrics are printed every 5th epoch.
train_teacher(teacher, optimizer_teacher, criterion_teacher, trainloader, testloader, device)

In [None]:
# Define student evaluation function
def evaluate_student(net, criterion, testloader, device):
    """Evaluate ``net`` on ``testloader``, print the metrics and return them.

    Args:
        net: Model to evaluate. It is switched to eval mode and left in it.
        criterion: Callable ``criterion(y_hat, y)`` returning a scalar loss tensor.
        testloader: Iterable of ``(X, y)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the batch losses averaged over the
        number of batches, and the fraction of samples whose arg-max along
        ``dim=1`` equals the label. Both are ``nan`` if no batch was seen.
    """
    net.eval()  # Set model to evaluation mode
    loss_sum = 0.0
    num_correct = 0
    num_samples = 0
    num_batches = 0  # counted here so loaders without __len__ also work
    with torch.no_grad():
        for X, y in testloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)

            loss_sum += criterion(y_hat, y).item()
            predicted = y_hat.argmax(dim=1)
            num_samples += y.size(0)
            num_correct += predicted.eq(y).sum().item()
            num_batches += 1

    # Avoid ZeroDivisionError on an empty loader.
    mean_loss = loss_sum / num_batches if num_batches else float("nan")
    accuracy = num_correct / num_samples if num_samples else float("nan")
    print(f"Test loss = {mean_loss}, accuracy = {accuracy}")
    return mean_loss, accuracy

In [None]:
# Define student training function
def train_student(net, optimizer, criterion, trainloader, testloader, device, num_epochs=10, eval_every=5):
    """Train ``net`` on hard labels only and periodically evaluate it.

    Args:
        net: Model to train (updated in place).
        optimizer: Optimizer holding ``net``'s parameters.
        criterion: Callable ``criterion(y_hat, y)`` returning a scalar loss tensor.
        trainloader: Iterable of ``(X, y)`` training batches.
        testloader: Iterable of ``(X, y)`` batches passed to ``evaluate_student``.
        device: Device each batch is moved to before the forward pass.
        num_epochs: Number of passes over ``trainloader``.
        eval_every: Run ``evaluate_student`` after every ``eval_every``-th epoch.
            ``0`` or ``None`` disables evaluation. The default of 5 matches the
            previously hard-coded interval.
    """
    for epoch in range(num_epochs):
        net.train()  # evaluation switches to eval mode, so re-enable training each epoch
        loss_sum = 0.0
        num_batches = 0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            num_batches += 1

        # Avoid ZeroDivisionError when the loader yields nothing.
        mean_loss = loss_sum / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}, loss = {mean_loss}")
        if eval_every and (epoch + 1) % eval_every == 0:
            evaluate_student(net, criterion, testloader, device)

In [None]:
# Baseline: train the student on the hard labels only, using the default 10 epochs.
train_student(student, optimizer_student, criterion_student, trainloader, testloader, device)

In [None]:
def train_distill(teacher, student, optimizer_student, criterion_student, trainloader, testloader, device, num_epochs=10, eval_every=5):
    """Train ``student`` against both the labels and a frozen teacher's predictions.

    The loss per batch is ``5 * criterion_student(student_logits, y)`` plus the
    KL divergence between the teacher's distribution and the student's softmax.

    Args:
        teacher: Model whose output is a probability distribution along
            ``dim=1`` (used as the KL target). It is kept in eval mode and is
            not updated.
        student: Model that outputs raw logits; updated in place.
        optimizer_student: Optimizer holding ``student``'s parameters.
        criterion_student: Hard-label loss, called as ``criterion_student(logits, y)``.
        trainloader: Iterable of ``(X, y)`` training batches.
        testloader: Iterable of ``(X, y)`` batches passed to ``evaluate_student``.
        device: Device each batch is moved to before the forward pass.
        num_epochs: Number of passes over ``trainloader``.
        eval_every: Run ``evaluate_student`` after every ``eval_every``-th epoch.
            ``0`` or ``None`` disables evaluation. The default of 5 matches the
            previously hard-coded interval.
    """
    for epoch in range(num_epochs):
        teacher.eval()
        student.train()  # evaluation switches to eval mode, so re-enable training each epoch
        running_loss = 0.0
        num_batches = 0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            # The teacher is only a fixed target: skip building its autograd graph.
            with torch.no_grad():
                y_teacher = teacher(X)
            y_hat = student(X)
            # F.kl_div expects log-probabilities as input. The student emits logits,
            # so use log_softmax; a plain .log() on logits is NaN for negative values.
            soft_loss = F.kl_div(F.log_softmax(y_hat, dim=1), y_teacher, reduction="batchmean")
            loss = 5 * criterion_student(y_hat, y) + soft_loss

            optimizer_student.zero_grad()
            loss.backward()
            optimizer_student.step()
            running_loss += loss.item()
            num_batches += 1

        # Avoid ZeroDivisionError when the loader yields nothing.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}, loss = {mean_loss}")
        if eval_every and (epoch + 1) % eval_every == 0:
            evaluate_student(student, criterion_student, testloader, device)

In [None]:
# Distil the teacher into the student for 50 epochs.
# NOTE(review): `student` and `optimizer_student` were already used by the train_student
# run earlier in this notebook, so this continues from those trained weights rather than
# from a fresh initialisation. Re-create the student first if a from-scratch comparison is intended.
train_distill(teacher, student, optimizer_student, criterion_student, trainloader, testloader, device, num_epochs=50)