### Based on: https://www.youtube.com/watch?v=5Lxuu16_28o

### Imports

In [1]:
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
from collections import OrderedDict

### Model

In [2]:
class FusableModel(torch.nn.Module):
    """Small Conv-BN-ReLU -> Linear-ReLU network laid out for static quantization.

    The conv/bn/relu and linear/relu groups are kept as separately named
    submodules so they can be passed to ``torch.quantization.fuse_modules``.
    ``QuantStub``/``DeQuantStub`` mark where tensors enter and leave the
    quantized region.

    Args:
        input_size: Height and width of the (square) input images. The model
            expects input of shape ``n x 3 x input_size x input_size``.
            Defaults to 224, which yields ``5 * 222 * 222`` linear input
            features.

    Raises:
        ValueError: If ``input_size`` is smaller than the 3x3 conv kernel.
    """

    def __init__(self, input_size=224):
        super().__init__()

        if input_size < 3:
            raise ValueError(
                f"input_size must be at least 3 (the conv kernel size), got {input_size}"
            )

        # A 3x3 kernel with stride 1 and no padding trims one pixel per border.
        conv_out_size = input_size - 2
        self.flat_features = 5 * conv_out_size * conv_out_size

        self.quant = torch.quantization.QuantStub()

        # Assumes n x 3 x input_size x input_size input
        self.conv_bn_relu = torch.nn.Sequential(OrderedDict([
            ('conv', torch.nn.Conv2d(3, 5, (3, 3), bias=False).to(dtype=torch.float)),
            ('bn', torch.nn.BatchNorm2d(5).to(dtype=torch.float)),
            ('relu', torch.nn.ReLU(inplace=True))
        ]))
        self.linear = torch.nn.Linear(self.flat_features, 100)
        self.relu = torch.nn.ReLU(inplace=True)

        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Run the network on ``x`` and return a tensor with 100 features per row."""
        x = self.quant(x)
        x = self.conv_bn_relu(x)
        # reshape (not view): view raises when the activation is not contiguous
        # in NCHW order (e.g. channels-last layout); reshape copies if needed.
        x = x.reshape(-1, self.flat_features)
        x = self.linear(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

In [3]:
# Build the float model and switch it to inference mode in one step
# (Module.eval() returns the module itself), then show its layout.
model = FusableModel().eval()
print(model)

FusableModel(
  (quant): QuantStub()
  (conv_bn_relu): Sequential(
    (conv): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (linear): Linear(in_features=246420, out_features=100, bias=True)
  (relu): ReLU(inplace=True)
  (dequant): DeQuantStub()
)


### Fuse layers

In [4]:
# Each inner list names, in execution order, the submodules that are merged
# into a single fused module; names are paths relative to `model`.
modules_to_fuse = [
    ['conv_bn_relu.conv', 'conv_bn_relu.bn', 'conv_bn_relu.relu'],
    ['linear', 'relu'],
]
# Fuse in place so `model` itself is rewritten rather than copied.
torch.quantization.fuse_modules(model, modules_to_fuse=modules_to_fuse, inplace=True)
print(model)

FusableModel(
  (quant): QuantStub()
  (conv_bn_relu): Sequential(
    (conv): ConvReLU2d(
      (0): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU(inplace=True)
    )
    (bn): Identity()
    (relu): Identity()
  )
  (linear): LinearReLU(
    (0): Linear(in_features=246420, out_features=100, bias=True)
    (1): ReLU(inplace=True)
  )
  (relu): Identity()
  (dequant): DeQuantStub()
)


### Quantize

In [5]:
# The qconfig and the quantized engine should both match the backend the model
# will run on. Select the qnnpack engine when this build supports it; otherwise
# leave the engine untouched and only attach the qnnpack qconfig.
backend = 'qnnpack'
if backend in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = backend
model.qconfig = torch.quantization.get_default_qconfig(backend)

# Insert observers in place so calibration can record activation statistics.
torch.quantization.prepare(model, inplace=True)

def calibrate(model, calibration_data):
    """Feed calibration batches through ``model`` so its observers see data.

    Args:
        model: A module to run; for static quantization this is the model
            returned by ``torch.quantization.prepare``.
        calibration_data: Iterable of batches. Each batch is either an input
            tensor or a tuple/list whose first element is the input tensor
            (e.g. ``(inputs, labels)``). An empty iterable runs nothing.

    Returns:
        None. The model is run only for its side effects.
    """
    was_training = model.training
    model.eval()
    try:
        # Gradients are not needed for forward-only calibration passes.
        with torch.no_grad():
            for batch in calibration_data:
                inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
                model(inputs)
    finally:
        # Restore whatever train/eval mode the caller had set.
        model.train(was_training)

# NOTE: no calibration batches are supplied here; pass representative inputs
# so the observers can record real activation ranges before conversion.
calibrate(model, calibration_data=[])

# Swap the observed float modules for their quantized counterparts in place.
model = torch.quantization.convert(module=model, inplace=True)



### Export to TorchScript and Optimize

In [6]:
# Compile the quantized model to TorchScript.
torchscript_model = torch.jit.script(model)

# Apply the mobile-specific optimization passes to the scripted module.
torchscript_model_optimized = optimize_for_mobile(torchscript_model)

# Serialize the optimized module to disk.
torchscript_model_optimized.save('model.pt')