In [1]:
# detection models
detection_model_path = 'model/faster_rcnn_state.pth'

In [2]:
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from PIL import Image
import cv2
import numpy as np
import torchvision
import os
import torch
from torch.quantization import QuantStub, DeQuantStub

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

In [3]:
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
    os.remove('tmp.pt')

In [4]:
torch.backends.quantized.supported_engines

['none', 'onednn', 'x86', 'fbgemm']

In [5]:
map_location="cpu"
backend = "qnnpack"

In [7]:
# load  a model; pre-trained on COCO
detection_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)  # (pretrained=False, pretrained_backbone=False)
WEIGHTS_FILE = detection_model_path
num_classes = 23
# get number of input features for the classifier
in_features = detection_model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
detection_model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# Load the traines weights
detection_model.load_state_dict(torch.load(WEIGHTS_FILE))

In [None]:
# detection_model.eval()

In [None]:
detection_model.backbone.qconfig = torch.quantization.get_default_qconfig('qnnpack')
detection_model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
model_static_quantized = torch.quantization.prepare(detection_model, inplace=False)
static_quantized_model = torch.quantization.convert(model_static_quantized, inplace=False)

print_model_size(detection_model) # will print original model size
print_model_size(static_quantized_model) ## will print quantized model size



166.15 MB
42.14 MB


In [None]:
script_model_detection = torch.jit.script(static_quantized_model)
script_model_detection.eval()

RecursiveScriptModule(
  original_name=FasterRCNN
  (transform): RecursiveScriptModule(original_name=GeneralizedRCNNTransform)
  (backbone): RecursiveScriptModule(
    original_name=BackboneWithFPN
    (body): RecursiveScriptModule(
      original_name=IntermediateLayerGetter
      (conv1): RecursiveScriptModule(original_name=Conv2d)
      (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
      (relu): RecursiveScriptModule(original_name=ReLU)
      (maxpool): RecursiveScriptModule(original_name=MaxPool2d)
      (layer1): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=Bottleneck
          (conv1): RecursiveScriptModule(original_name=Conv2d)
          (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
          (conv2): RecursiveScriptModule(original_name=Conv2d)
          (bn2): RecursiveScriptModule(original_name=BatchNorm2d)
          (conv3): RecursiveScriptModule(original_name=Conv2d)
          (bn3): R

In [None]:
# quantized model path
# quantized_model_path = '/content/drive/My Drive/models/quantized/detection_quantized_model.pt'
mobile_models_path = '/content/drive/My Drive/models/mobile/'

In [None]:
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(script_model_detection, example)
optimized_traced_model = optimize_for_mobile(traced_script_module, backend='CPU') #backend='CPU'




In [None]:
optimized_traced_model

RecursiveScriptModule(
  original_name=FasterRCNN
  (rpn): RecursiveScriptModule(
    original_name=RegionProposalNetwork
    (anchor_generator): RecursiveScriptModule(original_name=AnchorGenerator)
  )
  (roi_heads): RecursiveScriptModule(
    original_name=RoIHeads
    (box_roi_pool): RecursiveScriptModule(original_name=MultiScaleRoIAlign)
  )
)

In [None]:
# optimized_traced_model._save_for_lite_interpreter(mobile_models_path + "detection_model.pt")
torch.jit.save(optimized_traced_model, mobile_models_path + "detection_model.pt")