In [2]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
import torchvision.models as models

# 数据集的转换操作
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 将图片大小调整为 224 x 224
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

# 下载 CIFAR10 训练集
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

# 下载 CIFAR10 测试集（用于校准）
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32, shuffle=False)

# 加载一个预训练的 ResNet 模型
float_model = models.resnet18(pretrained=True)
float_model.eval()

# 设置量化配置
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)

# 定义校准函数
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# 获取一个样例输入
example_inputs = (next(iter(trainloader))[0])

# 准备模型：融合模块并插入观察者
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)

# 校准模型
calibrate(prepared_model, testloader)

# 将校准后的模型转换为量化模型
quantized_model = convert_fx(prepared_model)

# quantized_model 现在是量化后的模型


Files already downloaded and verified


  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


In [3]:
print(quantized_model)

GraphModule(
  (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.010527534410357475, zero_point=0, padding=(3, 3))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.006821439601480961, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.028011683374643326, zero_point=63, padding=(1, 1))
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.009053519926965237, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.03019980899989605, zero_point=70, padding=(1, 1))
    )
  )
  (layer2): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.008322453126311302, zero_point=0, pad

In [4]:
# 保存状态字典
torch.save(quantized_model.state_dict(), 'quantized_model_state_dict.pth')


In [5]:
# 将模型转换为 TorchScript
scripted_quantized_model = torch.jit.script(quantized_model)

# 保存序列化的模型
torch.jit.save(scripted_quantized_model, 'quantized_model_scripted.pth')


