# 先下载CIFAR-10数据集

In [2]:
import torch
import torchvision
from torchvision.transforms import transforms

from utils.criterion import test_model, count_parameters
from utils.micro_ghost_resnet import MicroResNetGhost

# Load the CIFAR-10 dataset

model_path = "../new_model_weights/microresnet_cifar10_best.pth"

# Per-channel mean / std handed to Normalize; shared by the train and test pipelines.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: augmentation first, then tensor conversion and normalization.
train_steps = [
    transforms.RandomCrop(32, padding=4),  # random crop to 32 with padding=4
    transforms.RandomHorizontalFlip(),     # random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
]
transform_train = transforms.Compose(train_steps)

# Test pipeline: no augmentation, only tensor conversion and the same normalization.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
]
transform_test = transforms.Compose(test_steps)

# Download the dataset (both splits live under the same root directory).
dataset_args = {"root": '../data', "download": True}
trainset = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **dataset_args)
testset = torchvision.datasets.CIFAR10(train=False, transform=transform_test, **dataset_args)

# Data loaders: the training split is shuffled, the test split keeps its order.
make_loader = torch.utils.data.DataLoader
trainloader = make_loader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = make_loader(testset, batch_size=100, shuffle=False, num_workers=2)

In [3]:
from utils.micro_resnet import MicroResNet
from utils.res_net import ResNet20
import torch

# Evaluate on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory holding every checkpoint loaded in this cell.
_WEIGHTS_DIR = "../new_model_weights"


def _load_pretrained(model_cls, weights_file):
    """Build ``model_cls(num_classes=10)`` on ``device`` and load its checkpoint.

    Args:
        model_cls: Model class to instantiate; called as ``model_cls(num_classes=10)``.
        weights_file: File name of the state dict inside ``_WEIGHTS_DIR``.

    Returns:
        The model, moved to ``device``, with the checkpoint's state dict loaded.
    """
    model = model_cls(num_classes=10).to(device)
    # map_location remaps the stored tensors onto ``device``; without it a
    # checkpoint saved from CUDA tensors cannot be loaded on a CPU-only
    # machine, which defeats the "cpu" fallback chosen above.
    state_dict = torch.load(f"{_WEIGHTS_DIR}/{weights_file}", map_location=device)
    model.load_state_dict(state_dict)
    return model


teacher_model = _load_pretrained(ResNet20, "resnet20_cifar10_best.pth")

vanilla_student_model = _load_pretrained(MicroResNet, "microresnet_cifar10_best.pth")

base_student_model = _load_pretrained(MicroResNet, "microresnet_cifar10_distill_best.pth")

progressive_student_model = _load_pretrained(MicroResNet, "progressive_distillation_result.pth")

feature_adaptive_student_model = _load_pretrained(MicroResNet, "feature_adaptive_distillation_result.pth")

ghost_model = _load_pretrained(MicroResNetGhost, "microresnet_ghost_cifar10_distill_best.pth")

# Parameter counts as reported by the project helper ``count_parameters``.
print("The size of teacher model:", count_parameters(teacher_model))
print("The size of student model", count_parameters(vanilla_student_model))
print("The size of ghost student model", count_parameters(ghost_model))
print("=" * 60)
# NOTE(review): assumes ``test_model`` switches the model to eval mode and
# disables gradients itself — confirm in utils.criterion.
print("The best accuracy teacher model can do:", test_model(teacher_model, testloader, device))
print("The best accuracy student model can do:", test_model(vanilla_student_model, testloader, device))
print("The best accuracy distilled student model can do:", test_model(base_student_model, testloader, device))
print("The best accuracy progressive distilled student model can do:", test_model(progressive_student_model, testloader, device))
print("The best accuracy feature-based distilled student model can do:", test_model(feature_adaptive_student_model, testloader, device))
print("The best accuracy ghost student model can do:", test_model(ghost_model, testloader, device))


The size of teacher model: 17444682
The size of student model 75290
The size of ghost student model 11098
The best accuracy teacher model can do: 94.01
The best accuracy student model can do: 83.45
The best accuracy distilled student model can do: 84.7
The best accuracy progressive distilled student model can do: 84.05
The best accuracy feature-based distilled student model can do: 83.31
The best accuracy ghost student model can do: 83.25
