### imports

FetchModel

In [26]:
from torch.utils.mobile_optimizer import optimize_for_mobile
import torch
import torchvision

from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

print(torch.__version__)

2.0.0


In [None]:
torchvision.models.list_models()

In [24]:
model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.eval()

# model.half()

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

Convert model to ptl file

In [25]:
traced_script_module = torch.jit.script(model)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("model_fp16.ptl")

RuntimeError: "rsqrt_cpu" not implemented for 'Half'

quantization

In [20]:
model_to_quantize = copy.deepcopy(model)
model_to_quantize.eval()
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
qconfig_dict = {"": qconfig_mapping}
# qconfig_dict[""].weight_dtype = torch.float16

example_input = torch.rand(1, 3, 224, 224)

model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_input)

calibration_data = [torch.randn(1, 3, 224, 224) for _ in range(100)]
for i in range(len(calibration_data)):
   model_prepared(calibration_data[i])


model_quantized = quantize_fx.convert_fx(model_prepared)

In [21]:
traced_script_module = torch.jit.script(model_quantized)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("model_quantized.ptl")



fusion

In [None]:
model_to_quantize = copy.deepcopy(model)
model_fused = quantize_fx.fuse_fx(model_to_quantize)

snipets of code

In [None]:
import copy
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torchvision.models import resnet50
fp32_model = resnet50().eval()
model = copy.deepcopy(fp32_model)
# `qconfig` means quantization configuration, it specifies how should we
# observe the activation and weight of an operator
# `qconfig_dict`, specifies the `qconfig` for each operator in the model
# we can specify `qconfig` for certain types of modules
# we can specify `qconfig` for a specific submodule in the model
# we can specify `qconfig` for some functioanl calls in the model
# we can also set `qconfig` to None to skip quantization for some operators
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
# `prepare_fx` inserts observers in the model based on the configuration in `qconfig_dict`
model_prepared = prepare_fx(model, qconfig_dict)
# calibration runs the model with some sample data, which allows observers to record the statistics of
# the activation and weigths of the operators
calibration_data = [torch.randn(1, 3, 224, 224) for _ in range(100)]
for i in range(len(calibration_data)):
   model_prepared(calibration_data[i])
# `convert_fx` converts a calibrated model to a quantized model, this includes inserting
# quantize, dequantize operators to the model and swap floating point operators with quantized operators
model_quantized = convert_fx(copy.deepcopy(model_prepared))
# benchmark
x = torch.randn(1, 3, 224, 224)
%timeit fp32_model(x)
%timeit model_quantized(x)

In [None]:
torchvision.models.alexnet()
torchvision.models.convnext_base()
torchvision.models.convnext_large()
torchvision.models.convnext_small()
torchvision.models.convnext_tiny()
torchvision.models.deeplabv3_mobilenet_v3_large()
torchvision.models.deeplabv3_resnet101()
torchvision.models.deeplabv3_resnet50()
torchvision.models.densenet121()
torchvision.models.densenet161()
torchvision.models.densenet169()
torchvision.models.densenet201()
torchvision.models.efficientnet_b0()
torchvision.models.efficientnet_b1()
torchvision.models.efficientnet_b2()
torchvision.models.efficientnet_b3()
torchvision.models.efficientnet_b4()
torchvision.models.efficientnet_b5()
torchvision.models.efficientnet_b6()
torchvision.models.efficientnet_b7()
torchvision.models.efficientnet_v2_l()
torchvision.models.efficientnet_v2_m()
torchvision.models.efficientnet_v2_s()
torchvision.models.fasterrcnn_mobilenet_v3_large_320_fpn()
torchvision.models.fasterrcnn_mobilenet_v3_large_fpn()
torchvision.models.fasterrcnn_resnet50_fpn()
torchvision.models.fasterrcnn_resnet50_fpn_v2()
torchvision.models.fcn_resnet101()
torchvision.models.fcn_resnet50()
torchvision.models.fcos_resnet50_fpn()
torchvision.models.googlenet()
torchvision.models.inception_v3()
torchvision.models.keypointrcnn_resnet50_fpn()
torchvision.models.lraspp_mobilenet_v3_large()
torchvision.models.maskrcnn_resnet50_fpn()
torchvision.models.maskrcnn_resnet50_fpn_v2()
torchvision.models.maxvit_t()
torchvision.models.mc3_18()
torchvision.models.mnasnet0_5()
torchvision.models.mnasnet0_75()
torchvision.models.mnasnet1_0()
torchvision.models.mnasnet1_3()
torchvision.models.mobilenet_v2()
torchvision.models.mobilenet_v3_large()
torchvision.models.mobilenet_v3_small()
torchvision.models.mvit_v1_b()
torchvision.models.mvit_v2_s()
torchvision.models.quantized_googlenet()
torchvision.models.quantized_inception_v3()
torchvision.models.quantized_mobilenet_v2()
torchvision.models.quantized_mobilenet_v3_large()
torchvision.models.quantized_resnet18()
torchvision.models.quantized_resnet50()
torchvision.models.quantized_resnext101_32x8d()
torchvision.models.quantized_resnext101_64x4d()
torchvision.models.quantized_shufflenet_v2_x0_5()
torchvision.models.quantized_shufflenet_v2_x1_0()
torchvision.models.quantized_shufflenet_v2_x1_5()
torchvision.models.quantized_shufflenet_v2_x2_0()
torchvision.models.r2plus1d_18()
torchvision.models.r3d_18()
torchvision.models.raft_large()
torchvision.models.raft_small()
torchvision.models.regnet_x_16gf()
torchvision.models.regnet_x_1_6gf()
torchvision.models.regnet_x_32gf()
torchvision.models.regnet_x_3_2gf()
torchvision.models.regnet_x_400mf()
torchvision.models.regnet_x_800mf()
torchvision.models.regnet_x_8gf()
torchvision.models.regnet_y_128gf()
torchvision.models.regnet_y_16gf()
torchvision.models.regnet_y_1_6gf()
torchvision.models.regnet_y_32gf()
torchvision.models.regnet_y_3_2gf()
torchvision.models.regnet_y_400mf()
torchvision.models.regnet_y_800mf()
torchvision.models.regnet_y_8gf()
torchvision.models.resnet101()
torchvision.models.resnet152()
torchvision.models.resnet18()
torchvision.models.resnet34()
torchvision.models.resnet50()
torchvision.models.resnext101_32x8d()
torchvision.models.resnext101_64x4d()
torchvision.models.resnext50_32x4d()
torchvision.models.retinanet_resnet50_fpn()
torchvision.models.retinanet_resnet50_fpn_v2()
torchvision.models.s3d()
torchvision.models.shufflenet_v2_x0_5()
torchvision.models.shufflenet_v2_x1_0()
torchvision.models.shufflenet_v2_x1_5()
torchvision.models.shufflenet_v2_x2_0()
torchvision.models.squeezenet1_0()
torchvision.models.squeezenet1_1()
torchvision.models.ssd300_vgg16()
torchvision.models.ssdlite320_mobilenet_v3_large()
torchvision.models.swin3d_b()
torchvision.models.swin3d_s()
torchvision.models.swin3d_t()
torchvision.models.swin_b()
torchvision.models.swin_s()
torchvision.models.swin_t()
torchvision.models.swin_v2_b()
torchvision.models.swin_v2_s()
torchvision.models.swin_v2_t()
torchvision.models.vgg11()
torchvision.models.vgg11_bn()
torchvision.models.vgg13()
torchvision.models.vgg13_bn()
torchvision.models.vgg16()
torchvision.models.vgg16_bn()
torchvision.models.vgg19()
torchvision.models.vgg19_bn()
torchvision.models.vit_b_16()
torchvision.models.vit_b_32()
torchvision.models.vit_h_14()
torchvision.models.vit_l_16()
torchvision.models.vit_l_32()
torchvision.models.wide_resnet101_2()
torchvision.models.wide_resnet50_2()
