# Quantization of Convolution, BN, ReLU

In [1]:
import torch
import torch.nn as nn

# Seed PyTorch's global RNG so the random weight initialisation and the random
# input generated below are the same on every run.
torch.manual_seed(100)

class conv_bn_relu(nn.Module):
    """Minimal Conv2d -> BatchNorm2d -> ReLU block used as the quantization example.

    The defaults reproduce the originally hard-coded layer (1 input channel,
    3 output channels, 5x5 kernel, stride 1, padding 1), so ``conv_bn_relu()``
    builds exactly the same network as before.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of convolution filters; also the BatchNorm width.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added to both sides of each spatial dimension.
    """

    def __init__(self, in_channels=1, out_channels=3, kernel_size=5, stride=1, padding=1):
        super().__init__()
        # groups=1 and bias=True stay fixed: the Conv/BN fusion code in this
        # notebook reads conv.bias and scales the weight per output channel.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=1,
            bias=True,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, targets=None):
        """Apply convolution, batch normalisation and ReLU to ``x``.

        Args:
            x: input tensor accepted by ``nn.Conv2d``.
            targets: unused; kept only so existing callers keep working.

        Returns:
            The activated feature map.
        """
        return self.relu(self.bn(self.conv(x)))

C:\ProgramData\Anaconda3\lib\site-packages\numpy\.libs\libopenblas.PYQHXLVVQ7VESDPUVUADXEVJOBGHJPAY.gfortran-win_amd64.dll
C:\ProgramData\Anaconda3\lib\site-packages\numpy\.libs\libopenblas.TXA6YQSD3GCQQC22GEQ54J2UDCXDXHWN.gfortran-win_amd64.dll
  stacklevel=1)


## 1. Setting quantization parameters

In [2]:
# Bit width used for the quantized input, weight and output.
bit = 8
# Largest magnitude of a signed `bit`-bit integer (127 for 8 bits); shared by
# the input (x), weight (w) and output (y) clipping ranges.
x_bit = w_bit = y_bit = (1 << (bit - 1)) - 1
# The bias is clipped to a signed integer twice as wide (2 * bit bits).
b_bit = (1 << (2 * bit - 1)) - 1

## 2. Float model inference

In [3]:
# Build the float reference model; eval() makes BatchNorm normalise with its
# running statistics instead of per-batch statistics.
model = conv_bn_relu()
model.eval()
# Random input tensor laid out as N*C*H*W (batch, channels, height, width).
xf = torch.rand(1, 1, 7, 7)
# Float reference inference. No gradients are needed here, so autograd is
# disabled and yf is a plain tensor without a grad_fn.
with torch.no_grad():
    yf = model(xf)

## 3. Input quantization

In [4]:
# Symmetric per-tensor scale: the largest |xf| is mapped onto x_bit.
scale_x = x_bit / torch.max(torch.abs(xf))
# Round onto the integer grid, then saturate to [-x_bit, x_bit].
xq = torch.clamp(torch.round(xf * scale_x), min=-x_bit, max=x_bit)

## 4. Weight quantization

In [5]:
# Fold BatchNorm into the convolution, quantize the folded parameters and write
# them back into `model`.
# NOTE: this cell mutates `model` in place, so it must run exactly once after
# the float inference cell; re-running it would fold the already-quantized
# weights a second time.
conv = model.conv
bn = model.bn

# Everything below is plain parameter arithmetic, so autograd is switched off.
with torch.no_grad():
    conv_w = conv.weight
    conv_b = conv.bias
    # Eval-mode BN computes gamma * (x - mean) / sqrt(var + eps) + beta,
    # which is the affine map bn_gamma * x + bn_beta with:
    bn_gamma = bn.weight / torch.pow(bn.running_var + bn.eps, 0.5)
    bn_beta = bn.bias - bn.running_mean * bn_gamma
    # Scale each output channel's kernel (dim 0 of the weight) by its bn_gamma.
    fusion_w = conv_w * bn_gamma.view(-1, 1, 1, 1)
    fusion_b = conv_b * bn_gamma + bn_beta

    # scale of weight: the largest |w| is mapped onto the top of the int range
    scale_w = (2**(bit - 1) - 1) / fusion_w.abs().max()
    # quantize weight (integer-valued float tensor)
    wq = torch.round(fusion_w * scale_w)
    wq = torch.clip(wq, -w_bit, w_bit)
    # scale of bias: the bias is added to products of xq and wq, so it has to
    # live at the accumulator scale scale_x * scale_w
    scale_b = scale_w * scale_x
    # quantize bias
    bq = torch.round(fusion_b * scale_b)
    bq = torch.clip(bq, -b_bit, b_bit)

    # update parameter: the convolution now carries the quantized values
    conv.weight.copy_(wq)
    conv.bias.copy_(bq)
    # Neutralise BN. Explicit fills are used instead of `x /= x`, which would
    # produce NaN for any entry that happens to be 0.
    bn.running_mean.zero_()
    bn.running_var.fill_(1.0)
    bn.weight.fill_(1.0)
    bn.bias.zero_()

# With eps left at its default, BN would still divide by sqrt(1 + eps) and
# slightly rescale the integer accumulator; eps = 0 makes it an exact identity.
# (bn.eps was already consumed above when computing bn_gamma.)
bn.eps = 0.0

## 5. Output quantization

In [6]:
# Output scale, calibrated on the float reference output: the largest |yf|
# is mapped onto the top of the signed `bit`-bit range.
scale_y = y_bit / torch.max(torch.abs(yf))

## 6. Quantized inference

In [8]:
# Run the model on the quantized input. Its conv weight/bias were replaced by
# their quantized values above, so y_pip lives at the accumulator scale
# scale_x * scale_w. This is pure inference, hence autograd is disabled.
with torch.no_grad():
    y_pip = model(xq)
    # Requantize: move from the accumulator scale to the output scale.
    yq = torch.round(y_pip * scale_y / (scale_x * scale_w))
    yq = torch.clip(yq, -y_bit, y_bit)
    # Dequantize so the result can be compared with the float output.
    yfq = yq / scale_y

print('Float inference:\n', yf)
print('Int8 inference:\n', yfq)

Float inference:
 tensor([[[[0.0000, 0.0545, 0.0000, 0.0000, 0.0533],
          [0.0000, 0.0000, 0.0000, 0.1838, 0.0000],
          [0.0986, 0.0781, 0.1405, 0.0000, 0.0000],
          [0.1124, 0.0923, 0.0378, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0463, 0.4461, 0.0000, 0.1777],
          [0.3421, 0.0896, 0.1494, 0.0811, 0.3406],
          [0.1011, 0.1305, 0.0357, 0.3610, 0.1571],
          [0.2170, 0.0591, 0.1802, 0.0000, 0.0000],
          [0.1733, 0.3323, 0.1169, 0.1065, 0.3219]],

         [[0.4899, 0.5766, 0.2208, 0.2846, 0.4856],
          [0.4484, 0.2111, 0.4620, 0.7209, 0.1935],
          [0.2549, 0.3707, 0.6399, 0.5517, 0.3339],
          [0.4240, 0.4600, 0.5220, 0.5862, 0.5554],
          [0.5264, 0.5483, 0.5431, 0.5070, 0.2373]]]], grad_fn=<ReluBackward0>)
Int8 inference:
 tensor([[[[0.0000, 0.0511, 0.0000, 0.0000, 0.0511],
          [0.0000, 0.0000, 0.0000, 0.1816, 0.0000],
          [0.1022, 0.0795, 0.1419, 0.0000, 0.0000