In [2]:
import torch
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
import torchvision.models as models

# 加载一个预训练的 ResNet18 模型
float_model = models.resnet18(pretrained=True)

# 确保模型处于评估模式
float_model.eval()

# 创建一个符合模型输入要求的样例输入
# 假设我们使用一个单张3通道224x224的图片
example_inputs = torch.rand(1, 3, 224, 224)

# 设置量化配置
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)

# 准备模型：融合模块并插入观察者
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)

# 动态量化不需要校准步骤
# 将准备好的模型转换为量化模型
quantized_model = convert_fx(prepared_model)

# quantized_model 现在是量化后的模型
print(quantized_model)

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


GraphModule(
  (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=1.0, zero_point=0, padding=(3, 3))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
    )
    (1): Module(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
    )
  )
  (layer2): Module(
    (0): Module(
      (conv1): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1))
      (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zer

In [3]:
# 保存状态字典
torch.save(quantized_model.state_dict(), 'quantized_model_state_dict.pth')

In [4]:
# 将模型转换为 TorchScript
scripted_quantized_model = torch.jit.script(quantized_model)

# 保存序列化的模型
torch.jit.save(scripted_quantized_model, 'quantized_model_scripted.pth')


