# Model Quantization
> "How to use the PyTorch quantization API for model quantization"

- toc: false
- branch: master
- badges: false
- comments: true
- author: Atmadeep Banerjee
- use_math: true

## What?

So you have trained a neural network and want to deploy it. Performance (speed and computational cost, *not* just accuracy) matters a *lot* in production. If your model can achieve low enough latency on a CPU instance, your deployment costs will be massively lower than on a GPU instance. Lower costs mean higher profits.

Model quantization is (usually) the easiest way to massively speed up your model. If you want to learn more about the theory behind quantization and how it works, check out this [blogpost](https://sharechat.com/blogs/data-science/neural-network-compression-using-quantization). Feeling too lazy to read through all that? Here’s a quick summary.

Quantization gives us a way to compress the weights of our model. Weights are usually represented as 32-bit floats, but we "quantize" them down to 8-bit integers instead. Static quantization, used in this post, does the same for the activations. You can go even further and use as little as 1 bit per parameter, creating binary neural networks, but that is beyond the scope of this post.

While quantization directly reduces model size by 4x (8 bits instead of 32 per weight), that is not the most important part. Using reduced precision *significantly* reduces the time taken for matrix multiplication and addition. These are not measly 10-20% gains: you can often expect a 3-5x speedup when quantizing a model from FP32 to INT8. Gains of that size close much of the performance gap between a CPU and a GPU, making real-time inference possible on a CPU.
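
To make the summary concrete, here is a minimal pure-Python sketch of one common scheme, affine INT8 quantization. It uses a real-valued scale plus an integer zero point. The helper names `choose_qparams`, `quantize` and `dequantize` are hypothetical, and the weight values are made up. The sketch maps a handful of float weights onto the 256 INT8 levels and back, and shows the 4x storage saving.

```python
from array import array


def choose_qparams(x_min, x_max, qmin=-128, qmax=127):
    """Pick an affine scale/zero point so [x_min, x_max] covers the INT8 range."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # keep 0.0 exactly representable
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    q = int(round(x / scale)) + zero_point  # round to the nearest integer level
    return max(qmin, min(qmax, q))          # clamp into the INT8 range


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


weights = [-0.51, 0.0, 0.2, 0.73, 1.0]  # made-up FP32 weights
scale, zp = choose_qparams(min(weights), max(weights))
q_weights = [quantize(w, scale, zp) for w in weights]
restored = [dequantize(q, scale, zp) for q in q_weights]

fp32 = array("f", weights)    # 4 bytes per value
int8 = array("b", q_weights)  # 1 byte per value
print("int8 values:", q_weights)
print("restored:", [round(r, 4) for r in restored])
print(len(fp32) * fp32.itemsize, "bytes ->", len(int8) * int8.itemsize, "bytes")
```

Each restored value is within half a quantization step (`scale / 2`) of the original. That rounding error is the price paid for the smaller, faster representation.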

So… what’s the catch, you ask? Lower precision means a much smaller range of representable values. Anything outside that range gets clipped to the nearest representable value (the integer-arithmetic version of overflow), and everything inside it is rounded to one of just 256 levels. Both effects can hurt accuracy. There are ways to reduce this error (more on this later), but it never disappears completely.
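
Here is a small illustration of that catch using `torch.quantize_per_tensor`. The scale is an assumption: it pretends calibration only saw values in roughly [-1, 1]. In-range values come back with a small rounding error, while an unseen outlier (3.0) is clipped to the top of the INT8 range.

```python
import torch

# Assumed calibration range of about [-1, 1]: 127 INT8 steps cover 1.0.
scale = 1.0 / 127
x = torch.tensor([0.0, 0.4242, -0.9, 3.0])  # 3.0 lies outside the calibrated range

q = torch.quantize_per_tensor(x, scale, 0, torch.qint8)
restored = q.dequantize()
error = (restored - x).abs()

print("stored int8:", q.int_repr().tolist())
print("restored:   ", restored.tolist())
print("abs error:  ", error.tolist())
```

The in-range errors stay below `scale / 2`, but the clipped outlier is off by about 2.0. This is why choosing good ranges (calibration, or quantization-aware training) matters.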

<!-- Knowledge distillation is also a cool thing to try as well but unlike quantization, it might need you to make non-trivial changes to your training loop. Especially if you are doing things more complex than standard classification. Depending on your task you might also need a lot of extra experimentation to get knowledge distillation to work well. But if you can get it to work, you can seriously reduce your model size with a minor hit in performance. Combine that with quantization and you will have blazing fast inference. -->


## How?

Quantizing common PyTorch models is pretty simple thanks to PyTorch's quantization API. You need to perform the following steps to get a basic quantized model:

#### Step 0: Create a model

Let's create a basic ResNet-18 model with a binary classification head. Note that we need to use the quantization version of resnet18 (`torchvision.models.quantization.resnet18`) instead of the standard torchvision version, because the latter will give an error. I will explain the reason for this later.

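```python
import torch
import torch.nn as nn
from torchvision import models

# children()[:-3] keeps conv1 ... avgpool and drops fc plus the model's own QuantStub/DeQuantStub.
# Note: newer torchvision releases deprecate pretrained=True in favour of the weights= argument.
resnet = nn.Sequential(
    *list(models.quantization.resnet18(pretrained=True).children())[:-3],
    nn.Flatten(),
    nn.Linear(512, 2),
).cuda()
```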
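
The full reason comes later, but here is a rough preview of one difference, offered as a sketch. The quantizable torchvision blocks send the residual addition through `torch.nn.quantized.FloatFunctional` instead of a bare `+`. A plain `+` is just a Python operator, so quantization has no module to attach an observer to, and no module to swap for a quantized add later. `ResidualBlock` below is a hypothetical minimal block showing that pattern, not torchvision's actual code.

```python
import torch
import torch.nn as nn
from torch.nn.quantized import FloatFunctional


class ResidualBlock(nn.Module):
    """Hypothetical quantization-friendly residual block."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        # Wrapping the skip connection in a module gives prepare() a place to put an observer.
        self.skip_add = FloatFunctional()

    def forward(self, x):
        # add_relu computes relu(conv(x) + x) through the wrapper module.
        return self.skip_add.add_relu(self.conv(x), x)


block = ResidualBlock(4).eval()
x = torch.randn(1, 4, 8, 8)
out = block(x)  # behaves like an ordinary float block before quantization

block.qconfig = torch.quantization.get_default_qconfig("fbgemm")
prepared = torch.quantization.prepare(block)  # returns a copy with observers attached
print(type(prepared.skip_add.activation_post_process).__name__)
```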

#### Step 1: Fuse layers

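In this step we will "fuse" adjacent layers of our model. Each BatchNorm is folded into the convolution before it, and a following ReLU is merged in as well, so each sequence runs as a single operation. This step is not strictly part of quantization, but it does give extra speedups.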

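```python
# Fuse the stem: conv1 ('0') + bn1 ('1') + relu ('2').
# Note: newer PyTorch releases only allow Conv+BN fusion in eval mode via fuse_modules;
# for a train-mode model headed for QAT, use torch.ao.quantization.fuse_modules_qat instead.
torch.quantization.fuse_modules(resnet, [['0', '1', '2']], inplace=True)

# resnet[4] .. resnet[7] are layer1 .. layer4, each holding two BasicBlocks.
for i in range(4, 8):
    # First block: conv1+bn1+relu and conv2+bn2 (the last ReLU is applied with the residual add).
    torch.quantization.fuse_modules(resnet[i][0],
                                    [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']],
                                    inplace=True)
    # layer2 .. layer4 start with a strided block whose shortcut is a Conv+BN 'downsample'.
    if resnet[i][0].downsample is not None:
        torch.quantization.fuse_modules(resnet[i][0].downsample,
                                        [['0', '1']],
                                        inplace=True)
    # Second block: same pattern, no downsample.
    torch.quantization.fuse_modules(resnet[i][1],
                                    [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']],
                                    inplace=True)
```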
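
To see what fusion does numerically, here is a small sketch on a hypothetical toy `Conv2d -> BatchNorm2d -> ReLU` stack, not the resnet above. It fuses in eval mode, where BatchNorm is folded into the convolution's weights and bias. It uses `inplace=False` so the unfused original is kept for comparison, and checks that both versions give the same outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stack (not the resnet from this post).
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())

# Give BatchNorm non-trivial statistics so folding actually changes the conv weights.
with torch.no_grad():
    model[1].running_mean.uniform_(-1, 1)
    model[1].running_var.uniform_(0.5, 2.0)
    model[1].weight.uniform_(0.5, 1.5)
    model[1].bias.uniform_(-0.5, 0.5)

model.eval()  # eval-mode fusion folds BN using its running statistics
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    before = model(x)

# inplace=False returns a fused copy; '0' becomes a single ConvReLU2d, '1' and '2' become Identity.
fused = torch.quantization.fuse_modules(model, [["0", "1", "2"]], inplace=False)
with torch.no_grad():
    after = fused(x)

print(fused)
print("max abs difference:", (before - after).abs().max().item())
```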

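#### Step 2: Prepare for QAT

Prepare the model for quantization-aware training (QAT). `prepare_qat` swaps in modules that simulate INT8 rounding ("fake quantization") during training, so the weights learn to tolerate it.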

```python
resnet.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
resnet = torch.quantization.prepare_qat(resnet).cuda()
```
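
Roughly speaking, the fake-quantization modules that `prepare_qat` inserts do two things:

- In the forward pass, they round values to the INT8 grid while still working in floating point.
- In the backward pass, they let gradients pass straight through, except for values that were clipped.

The sketch below shows that behaviour by calling `torch.fake_quantize_per_tensor_affine` directly. The scale of 0.1 and zero point of 0 are arbitrary values chosen for illustration.

```python
import torch

scale, zero_point = 0.1, 0  # arbitrary illustrative quantization parameters
x = torch.tensor([0.03, 0.26, -0.74, 20.0], requires_grad=True)

# Forward: snap to the INT8 grid (multiples of 0.1 here) and clamp to [-128, 127] steps.
y = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, -128, 127)

# Backward: straight-through estimator, gradient 1 inside the range and 0 where values were clipped.
y.sum().backward()

print("fake-quantized:", y.tolist())
print("gradients:     ", x.grad.tolist())
```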

#### Step 3: Train normally
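
Nothing about the training loop itself changes after `prepare_qat`. The sketch below uses a hypothetical toy model and synthetic data, not the resnet above, to show an ordinary loop training a QAT-prepared model. One detail is worth noting: the optimizer is built *after* `prepare_qat`, because with the default `inplace=False` it returns a copy with new parameter objects.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and synthetic data standing in for the resnet and a real DataLoader.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
model = torch.quantization.prepare_qat(model.train())  # same call as in Step 2

inputs = torch.randn(512, 16)
targets = (inputs[:, 0] > 0).long()  # an easy, learnable label

# Build the optimizer from the prepared model's parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()


def full_loss():
    with torch.no_grad():
        return loss_fn(model(inputs), targets).item()


initial_loss = full_loss()
for epoch in range(10):
    for i in range(0, len(inputs), 64):
        x, y = inputs[i:i + 64], targets[i:i + 64]
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
final_loss = full_loss()

print(f"loss {initial_loss:.3f} -> {final_loss:.3f}")
# In the real workflow you would save the best weights, e.g. with
# torch.save(model.state_dict(), "best_model.pth"), for Step 4 to load.
```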

#### Step 4: Post-training steps

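```python
import torch
import torch.nn as nn


class Qresnet(nn.Module):
    """Quantize inputs on the way in and dequantize outputs on the way out."""

    def __init__(self, m):
        super().__init__()
        self.q = torch.quantization.QuantStub()
        self.m = m
        self.dq = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dq(self.m(self.q(x)))


# load the best model from the training phase
resnet.load_state_dict(torch.load('best_model.pth'))

# wrap the QAT resnet with quant/dequant stubs
qmodel = Qresnet(resnet)

# add the quantization recipe to the model
qmodel.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# prepare the modules in the model to be quantized (adds an observer to the new QuantStub)
qmodel = torch.quantization.prepare(qmodel).cuda()

# switch to eval mode so calibration does not update BatchNorm statistics
qmodel = qmodel.eval()

# calibrate: observers record activation ranges (train_loader comes from the training phase)
with torch.no_grad():
    for x, _ in train_loader:
        qmodel(x.cuda())

# actually quantize the trained model; quantized kernels run on the CPU
qmodel = torch.quantization.convert(qmodel.cpu())

# script the model using TorchScript for easy deployment
scripted = torch.jit.script(qmodel)
scripted.save('quantized_resnet.pt')  # file name is arbitrary
```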
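
The same wrap, prepare, calibrate and convert sequence can be tried end to end on a tiny model, without a GPU or a training run. This is a sketch under several assumptions:

- The toy two-layer model, the random calibration inputs, and the `Wrapped` and `size_bytes` helpers are all hypothetical.
- It uses plain post-training quantization, with no QAT step.
- It picks the `qnnpack` backend when `fbgemm` is not available (e.g. on ARM).

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# Use a quantized backend that exists on this machine (fbgemm on x86, qnnpack on ARM).
engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = engine


class Wrapped(nn.Module):
    """Same quant -> model -> dequant pattern as Qresnet, around a hypothetical toy model."""

    def __init__(self, m):
        super().__init__()
        self.q = torch.quantization.QuantStub()
        self.m = m
        self.dq = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dq(self.m(self.q(x)))


def size_bytes(module):
    """Serialized size of a module's state_dict, as a rough proxy for model size on disk."""
    buf = io.BytesIO()
    torch.save(module.state_dict(), buf)
    return buf.getbuffer().nbytes


float_model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))
qmodel = Wrapped(float_model).eval()
qmodel.qconfig = torch.quantization.get_default_qconfig(engine)
qmodel = torch.quantization.prepare(qmodel)  # returns a copy; float_model stays untouched

# Calibration: observers record activation ranges from representative inputs.
with torch.no_grad():
    for _ in range(10):
        qmodel(torch.randn(32, 256))

qmodel = torch.quantization.convert(qmodel)  # Linear layers now hold INT8 weights

x = torch.randn(4, 256)
with torch.no_grad():
    out = qmodel(x)  # float in, float out; INT8 arithmetic in between

float_size, int8_size = size_bytes(float_model), size_bytes(qmodel)
print(f"{engine}: {float_size} bytes -> {int8_size} bytes")
```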

## Going deeper

