How to fuse a list of PyTorch modules, and how to compare the performance of the fused model with the non-fused version

### Define the example model

In [11]:
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile
from copy import deepcopy

In [2]:
class AnnotatedConvBnReluModel(nn.Module):
    """Small conv -> batch-norm -> ReLU network wrapped in quant/dequant stubs.

    ``quant`` and ``dequant`` mark the entry and exit of the region that
    eager-mode quantization converts; everything between them is the
    conv / bn / relu chain.
    """

    def __init__(self):
        super().__init__()
        # 3 input channels -> 5 output channels, 3x3 kernel, no bias term.
        self.conv = nn.Conv2d(3, 5, 3, bias=False)
        self.bn = nn.BatchNorm2d(5)
        self.relu = nn.ReLU(inplace=True)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # Force a contiguous layout before the tensor enters the stubbed region.
        entered = self.quant(x.contiguous())
        activated = self.relu(self.bn(self.conv(entered)))
        return self.dequant(activated)

### Generate two models: one with fused modules and one without

In [7]:
# Select the QNNPACK quantized backend; this matches the
# get_default_qconfig("qnnpack") call made inside prepare_save.
torch.backends.quantized.engine = 'qnnpack'

In [12]:
# Float, non-fused instance of the example model; both the non-fused export
# and the fuse_modules call start from this object.
model = AnnotatedConvBnReluModel()

def prepare_save(model, fused, *, calibration_input=None):
    """Quantize a copy of ``model`` for mobile and save it as TorchScript.

    Eager-mode post-training static quantization with the "qnnpack" default
    qconfig is applied to a deep copy, so the caller's ``model`` is left
    untouched. The quantized copy is scripted, run through
    ``optimize_for_mobile`` and saved to ``model_fused.pt`` when ``fused`` is
    true, otherwise to ``model.pt`` (both relative to the working directory).

    Args:
        model: float ``nn.Module`` to quantize; presumably it contains
            ``QuantStub``/``DeQuantStub`` markers like the example model.
        fused: only selects the output file name.
        calibration_input: optional tensor (or tuple of positional forward
            arguments) fed through the observed model between ``prepare``
            and ``convert`` so observers record real activation ranges.
            When omitted, no calibration runs; the observers then have no
            statistics and PyTorch presumably falls back to default qparams.
    """
    m = deepcopy(model)
    # Post-training static quantization (and calibration through BatchNorm
    # running stats) is intended to run in eval mode.
    m.eval()
    # The qconfig must be attached to the copy that is prepared. Setting it on
    # ``model`` after the copy was taken left ``m`` unquantized and leaked the
    # qconfig into the caller's object (and into anything copied from it).
    m.qconfig = torch.quantization.get_default_qconfig("qnnpack")
    torch.quantization.prepare(m, inplace=True)
    if calibration_input is not None:
        args = calibration_input if isinstance(calibration_input, tuple) else (calibration_input,)
        with torch.no_grad():
            m(*args)
    torch.quantization.convert(m, inplace=True)
    torchscript_model = torch.jit.script(m)
    torchscript_model_optimized = optimize_for_mobile(torchscript_model)
    torch.jit.save(torchscript_model_optimized, "model.pt" if not fused else "model_fused.pt")

In [13]:
# Quantize, optimize and save the non-fused model (written to model.pt).
prepare_save(model, fused=False)



In [18]:
# Fuse the BatchNorm2d + ReLU pair into one module. With inplace=False the
# original ``model`` is left as-is and a fused copy is returned.
model_fused = torch.quantization.fuse_modules(
    model,
    modules_to_fuse=[["bn", "relu"]],
    inplace=False,
)

In [20]:
# Quantize, optimize and save the fused model (written to model_fused.pt).
prepare_save(model_fused, fused=True)

