In [1]:
import torchvision.models.quantization as models

# You will need the number of filters in the `fc` for future use.
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_fe = models.resnet18(pretrained=True, progress=True, quantize=True)
num_ftrs = model_fe.fc.in_features

Downloading: "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth" to /home/ss/.cache/torch/hub/checkpoints/resnet18_fbgemm_16fa66dd.pth
100%|██████████| 11.2M/11.2M [00:01<00:00, 6.10MB/s]
  scales = torch.tensor(scales, dtype=torch.double, device=storage.device)


In [4]:
import torch

In [13]:
from torch import nn

def create_combined_model(model_fe):
  # Step 1. Isolate the feature extractor.
  model_fe_features = nn.Sequential(
    model_fe.quant,  # Quantize the input
    model_fe.conv1,
    model_fe.bn1,
    model_fe.relu,
    model_fe.maxpool,
    model_fe.layer1,
    model_fe.layer2,
    model_fe.layer3,
    model_fe.layer4,
    model_fe.avgpool,
    model_fe.dequant,  # Dequantize the output
  )

  # Step 2. Create a new "head"
  new_head = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(num_ftrs, 2),
  )

  # Step 3. Combine, and don't forget the quant stubs.
  new_model = nn.Sequential(
    model_fe_features,
    nn.Flatten(1),
    new_head,
  )
  
  print(new_model)
  
  return new_model

In [14]:
new_model = create_combined_model(model_fe)

Sequential(
  (0): Sequential(
    (0): Quantize(scale=tensor([0.0374]), zero_point=tensor([57]), dtype=torch.quint8)
    (1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.028605546802282333, zero_point=0, padding=(3, 3))
    (2): Identity()
    (3): Identity()
    (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (5): Sequential(
      (0): QuantizableBasicBlock(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.016524722799658775, zero_point=0, padding=(1, 1))
        (bn1): Identity()
        (relu): Identity()
        (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.04645531252026558, zero_point=75, padding=(1, 1))
        (bn2): Identity()
        (add_relu): QFunctional(
          scale=0.03447607904672623, zero_point=0
          (activation_post_process): Identity()
        )
      )
      (1): QuantizableBasicBlock(
        (conv1): QuantizedConvReLU2d(64,

In [10]:
new_model.eval()

x = new_model(torch.randn(1,3,224,224))

In [11]:
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
    os.remove('tmp.pt')

In [12]:
print_model_size(new_model)

11.30 MB


In [16]:
import torch

# Load the model
model = torch.jit.load("./trained_mobile_net.pt")

# Set the model to evaluation mode

In [17]:

model.eval()

# Get an example input
example = torch.randn(1, 3, 28, 28)

# Make a prediction
output = model(example)

In [18]:
output

tensor([[0.2116, 0.7884]], grad_fn=<SoftmaxBackward0>)

In [19]:
import os

In [None]:
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
    os.remove('tmp.pt')

In [20]:
print_model_size(model)

11.76 MB


In [21]:
model_dynamic_quantized = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
)

In [22]:
print_model_size(model_dynamic_quantized)

11.76 MB


In [24]:
model.eval()



# Specify quantization configuration
# Start with simple min/max range estimation and per-tensor quantization of weights
model.qconfig = torch.ao.quantization.default_qconfig
print(model.qconfig)
torch.ao.quantization.prepare(model, inplace=True)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})


RecursiveScriptModule(
  original_name=MobileNet
  (model): RecursiveScriptModule(
    original_name=MobileNetV2
    (features): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(
        original_name=Conv2dNormActivation
        (0): RecursiveScriptModule(original_name=Conv2d)
        (1): RecursiveScriptModule(original_name=BatchNorm2d)
        (2): RecursiveScriptModule(original_name=ReLU6)
      )
      (1): RecursiveScriptModule(
        original_name=InvertedResidual
        (conv): RecursiveScriptModule(
          original_name=Sequential
          (0): RecursiveScriptModule(
            original_name=Conv2dNormActivation
            (0): RecursiveScriptModule(original_name=Conv2d)
            (1): RecursiveScriptModule(original_name=BatchNorm2d)
            (2): RecursiveScriptModule(original_name=ReLU6)
          )
          (1): RecursiveScriptModule(original_name=Conv2d)
          (2): RecursiveScriptModule(original_name=BatchNorm2d)
   

In [25]:
qmodel = torch.ao.quantization.convert(model, inplace=False)

In [26]:
print_model_size(qmodel)

11.76 MB
