# Quantization of a Cat vs. Dog Classifier Model
<br />
<font size=3>
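This example uses a PyTorch CNN model to classify cats and dogs, and then quantizes the model. Quantization tools are usually built around many APIs so that they can handle all kinds of models. Here, to make each step of the quantization process easy to explain, the code is written out step by step and is less compact. If you are interested, you can try turning it into reusable APIs yourself.<br />
<br />
Here is a brief outline of the actual process of quantizing a whole model:<br />
<br />
1. Build the model and insert data-observation points into it:<br />
&emsp;&emsp;We usually insert a range-observation function at the outputs of the parameterized layers (e.g. CONV, BN, FC), so that the quantization scale of these layers can be computed later.<br />
<br />
2. Run inference on a calibration dataset:<br />
&emsp;&emsp;Before quantizing, we run a dataset through the model once so that the range-observation functions embedded earlier can record each data range.<br />
<br />
3. Quantize:<br />
&emsp;&emsp;a. Before quantizing, fold the CONV and BN weights and biases together. This step can also be done before the calibration run.<br />
&emsp;&emsp;b. Compute the output scale of every layer, then compute the weight and bias scales, and quantize the whole model.<br />
&emsp;&emsp;c. Once the model is quantized, run quantized inference to test it.<br />

<br />
Now we can start writing code. If your computer has a GPU, the code runs on CUDA; otherwise it falls back to the CPU. The speed difference is small because no training is done here.<br />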
</font>
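The three steps above can be sketched end to end on a single toy dense layer. This is a minimal pure-Python illustration, not the notebook's code: the weights, calibration inputs, and helper names (`float_layer`, `int_layer`) are hypothetical, and plain lists stand in for tensors. It uses the same symmetric scheme as the model below (scale = 127 / range, bias clipped to a 16-bit limit).

```python
QMAX = 2 ** (8 - 1) - 1        # 127, like x_bit / w_bit / y_bit
B_QMAX = 2 ** (8 * 2 - 1) - 1  # 32767, like b_bit

def quantize(values, scale, qmax):
    # round to nearest, then clip to the symmetric range [-qmax, qmax]
    return [max(-qmax, min(qmax, round(v * scale))) for v in values]

def float_layer(x, w, b):
    # reference FP32 layer: ReLU(w . x + b)
    return max(0.0, sum(xi * wi for xi, wi in zip(x, w)) + b)

# Hypothetical weights and calibration inputs
w, b = [0.8, -0.35, 0.3], 0.1
calib = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75], [-2.0, 1.0, 0.5]]

# Step 1: calibration, record the largest |value| at the input and at the output
range_in = max(abs(v) for x in calib for v in x)
range_out = max(float_layer(x, w, b) for x in calib)

# Step 2: activation scales, then weight and bias quantization
s_in, s_out = QMAX / range_in, QMAX / range_out
s_w = QMAX / max(abs(v) for v in w)
wq = quantize(w, s_w, QMAX)
bq = quantize([b], s_in * s_w, B_QMAX)[0]  # bias scale = input scale * weight scale

# Step 3: quantized inference
def int_layer(x):
    xq = quantize(x, s_in, QMAX)
    acc = max(0, sum(a * c for a, c in zip(xq, wq)) + bq)  # integer accumulate + ReLU
    yq = quantize([acc], s_out / (s_in * s_w), QMAX)[0]    # requantize to the output scale
    return yq / s_out                                       # dequantize for comparison

for x in calib:
    print(round(float_layer(x, w, b), 4), round(int_layer(x), 4))
```

The dequantized INT8 outputs land close to the FP32 ones; the remaining error comes from rounding the inputs, weights, and outputs to 8 bits.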

In [12]:
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchnet.meter as tnt
from torch.utils.data import Dataset, DataLoader

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Create the CNN model with quantization functions<br />
<br />
<font size=3>
The model part is somewhat complex, but it becomes much simpler once each piece is explained, so let's go through the functions one at a time.<br />
</font>
<br />

### def __init__(self):<br />
<font size=3>
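First, the model structure: four CONV layers and two FC layers. Each CONV is followed by BN, ReLU, and MaxPool, and the first FC is followed by ReLU; the last layer normally has no ReLU. That is the base model.<br />
Next come the quantization parameters. We add two of them, `scale` and `range`, at the model input, after each CONV block's ReLU, and at the outputs of FC1 and FC2. (They could also go after the MaxPool layers; the result would be nearly the same, since max pooling keeps the largest value in each window.) `range` stores the observed data range during calibration inference, and `scale` is computed when the model is quantized.<br />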
</font>
<br />
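The input size of `fc1` follows from the layer shapes. A quick check, assuming the 224×224 input used by the dataloader below (stride 1 and no padding for every CONV, 2×2 max pooling with floor rounding); the helper names are illustrative:

```python
def conv_out(n, k, stride=1, padding=0):
    # output size of nn.Conv2d along one spatial dimension
    return (n + 2 * padding - k) // stride + 1

def pool_out(n, k=2):
    # nn.MaxPool2d(kernel_size=k): stride defaults to k, floor rounding
    return (n - k) // k + 1

size = 224                       # input_shape=(3,224,224)
for k in (5, 5, 3, 3):           # kernel sizes of cnn1..cnn4
    size = pool_out(conv_out(size, k))
    print(size)
flat = 8 * size * size           # cnn4 has 8 output channels
print(flat)
```

This gives 110, 53, 25, 11 and a flattened size of 968 = 8 × 11 × 11, which matches `nn.Linear(8 * 11 * 11, 512)`.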

### def update_range(self, data, range_data):<br />
<font size=3>
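This function is simple: it keeps updating the data range of a layer. If the largest absolute value in the current batch exceeds the range recorded so far, the stored range is replaced with it.<br />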
</font>
<br />
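A plain-Python sketch of the same running maximum, with lists standing in for batch tensors (the batch values are made up):

```python
def update_range(batch, range_data):
    # keep the largest |value| seen so far (same rule as CNN_Model.update_range)
    now_range = max(abs(v) for v in batch)
    return now_range if now_range > range_data else range_data

r = 1.0  # __init__ starts every range at torch.tensor(1.0)
for batch in ([0.2, -0.4], [0.3, 0.9], [-2.5, 1.0], [0.1, 1.7]):
    r = update_range(batch, r)
    print(r)
```

The first two batches never exceed 1.0, so the range stays at its initial value. Because every `range_*` starts at `torch.tensor(1.0)`, a layer whose activations stay below 1 keeps a range of 1.0 rather than its true maximum.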

### quantize(self, x, scale, int_range):<br />
<font size=3>
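As the name suggests, this function quantizes data: `torch.round(x * scale)`. The part that needs explaining is `torch.clip(xq, -int_range, int_range)`, which keeps the quantized values inside the INT8 type. With `int_range = 2**7 - 1 = 127`, values are limited to -127 to 127. INT8 itself spans -128 to 127; leaving -128 unused keeps the range symmetric around zero.<br />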
</font>
<br />
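A small pure-Python version of `quantize` shows the rounding and the clipping, assuming a calibrated range of 2.0 (so scale = 127 / 2.0 = 63.5). Python's `round` rounds half to even, like `torch.round`:

```python
def quantize(xs, scale, int_range=127):
    # round to nearest, then clip to [-int_range, int_range]
    return [max(-int_range, min(int_range, round(x * scale))) for x in xs]

data_range = 2.0
scale = 127 / data_range   # 63.5
print(quantize([-3.1, -2.0, -0.3, 0.0, 0.7, 2.0, 2.6], scale))
```

Values beyond the calibrated range (here -3.1 and 2.6) saturate at ±127, and -128 is never produced.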

### weight_quant(self, w, b, in_scale):<br />
<font size=3>
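This uses the same quantization formula as before. The point to note is that the bias scale is derived from the weight and input scales (`scale_b = scale_w * in_scale`) rather than from the bias itself: the product of the quantized input and weight is at scale `in_scale * scale_w`, so the bias must be at that same scale to be added to the accumulator. The quantized bias is clipped to `b_bit = 2**15 - 1` instead of the 8-bit limit.<br />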
</font>
<br />
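A numeric sketch of why the bias scale must be `scale_w * in_scale`: the integer dot product of the quantized input and weights is at scale s_x·s_w, so a bias quantized at any other scale adds the wrong amount. The input, weights, bias, and the input range of 2.0 are hypothetical values:

```python
QMAX = 127
x, w, b = [1.2, -0.7, 0.4], [0.26, 0.5, -0.1], 0.3  # hypothetical input, weights, bias

s_x = QMAX / 2.0                       # input scale, assuming a calibrated range of 2.0
s_w = QMAX / max(abs(v) for v in w)    # weight scale, as in weight_quant
xq = [round(v * s_x) for v in x]
wq = [round(v * s_w) for v in w]
acc = sum(a * c for a, c in zip(xq, wq))   # integer dot product, at scale s_x * s_w

exact = sum(a * c for a, c in zip(x, w)) + b
good = (acc + round(b * s_x * s_w)) / (s_x * s_w)  # bias quantized at s_x * s_w
bad = (acc + round(b * s_w)) / (s_x * s_w)         # bias quantized at the weight scale only
print(exact, good, bad)
```

`good` lands within rounding error of `exact`, while `bad` is far off.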

### fusion_conv_bn_quant(self, conv, bn, in_scale):<br />
<font size=3>
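This fuses the CONV and BN layers, as in the earlier example, and then quantizes the fused weight and bias with `weight_quant`. Afterwards the BN parameters are reset (mean 0, variance 1, weight 1, bias 0), which makes BN close to an identity; the INT8 path of `forward` does not call the BN layers at all.<br />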
</font>
<br />
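For reference, the folding performed by `fusion_conv_bn_quant` can be written out as follows, with $\gamma$ = `bn.weight`, $\beta$ = `bn.bias`, $\mu$ = `bn.running_mean`, $\sigma^2$ = `bn.running_var`, $\epsilon$ = `bn.eps`. The factor $\gamma'$ is per output channel, which is why the code broadcasts it with `unsqueeze`:

$$
\begin{aligned}
\mathrm{BN}(y) &= \gamma\,\frac{y-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta, \qquad y = W * x + b \\
\gamma' &= \frac{\gamma}{\sqrt{\sigma^2+\epsilon}} \\
\mathrm{BN}(W * x + b) &= (\gamma' W) * x + \gamma'(b-\mu) + \beta \\
W_{\text{fused}} &= \gamma' W, \qquad b_{\text{fused}} = \gamma'(b-\mu) + \beta
\end{aligned}
$$

In the code, `bn_beta = bn.bias - bn.running_mean * bn_gamma` and `fusion_b = conv_b * bn_gamma + bn_beta`, which expands to the same $b_{\text{fused}}$.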

### cl_scale(self):<br />
<font size=3>
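After calibration, this function computes the scale of every layer's feature map (`127 / range`) and weights, and quantizes the weights and biases inside the model, converting it from FP32 to INT8. Each parameterized layer receives the scale of the activation that feeds it as `in_scale`. The quantized values are still stored in float tensors, so the INT8 arithmetic is simulated rather than run on integer kernels.<br />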
</font>
<br />
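The introduction mentions that this code could be turned into APIs. As one hypothetical direction, the scale bookkeeping in `cl_scale` can be expressed as a table in which each parameterized layer takes the previous activation's scale as its input scale. The ranges below are made-up stand-ins for `range_input`, `range_cv1`, ... after calibration, `CHAIN` and `layer_scales` are illustrative names, and weight quantization is left out:

```python
QMAX = 2 ** (8 - 1) - 1  # 127

def activation_scales(ranges):
    # scale = (2**(bit-1) - 1) / range, as in cl_scale
    return {name: QMAX / r for name, r in ranges.items()}

# (parameter layer, activation feeding it, activation it produces), mirroring cl_scale
CHAIN = [("cnn1", "input", "cv1"), ("cnn2", "cv1", "cv2"), ("cnn3", "cv2", "cv3"),
         ("cnn4", "cv3", "cv4"), ("fc1", "cv4", "fc1"), ("fc2", "fc1", "fc2")]

def layer_scales(ranges):
    # map each layer to (in_scale, out_scale)
    act = activation_scales(ranges)
    return {layer: (act[src], act[dst]) for layer, src, dst in CHAIN}

# hypothetical calibrated ranges
ranges = {"input": 255.0, "cv1": 40.0, "cv2": 25.0, "cv3": 12.0,
          "cv4": 8.0, "fc1": 10.0, "fc2": 6.0}
for layer, (s_in, s_out) in layer_scales(ranges).items():
    print(f"{layer}: in_scale={s_in:.4f}  out_scale={s_out:.4f}")
```

Adding a layer then means adding one row to the table instead of another hand-written line in `cl_scale`.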

### forward(self, x):<br />
<font size=3>
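`forward` is split into two modes (FP32 and INT8). In FP32 mode, a range-observation call follows each ReLU output (and the FC1 and FC2 outputs). In INT8 mode, a quantization step is placed at the model input and after each ReLU layer (and after FC1 and FC2); the BN layers are skipped because they have already been folded into the CONV weights. Note that `relu5` is defined in `__init__` but is not called in `forward` in either mode.<br />
Quantizing the input is straightforward, so here is the requantization after ReLU. After a parameterized layer (CONV, BN, FC), the feature map carries both the input scale and the weight scale, so its values are not yet in the layer's real output range, and on integer hardware they would be INT32 rather than INT8. To bring the feature map into its output range, we first undo the input and weight scales and then quantize with this layer's own scale: `out * scale_out / (scale_in * w_scale)`, rounded and clipped to ±127.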
</font>
<br />
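The requantization step can be checked by hand. A pure-Python sketch of one output value, with hypothetical scales (input range 255, weight max |w| = 0.5, output range 40) and a hypothetical helper name `requantize`:

```python
QMAX = 127

def requantize(acc, s_in, s_w, s_out, qmax=QMAX):
    # acc: CONV/FC accumulator after ReLU, at scale s_in * s_w (INT32 on real hardware)
    # undo s_in * s_w, apply the output scale s_out, round, and clip to the INT8 range
    return max(-qmax, min(qmax, round(acc * s_out / (s_in * s_w))))

# hypothetical scales
s_in, s_w, s_out = QMAX / 255.0, QMAX / 0.5, QMAX / 40.0
acc = 1265                       # about 10.0 * s_in * s_w, i.e. a real output near 10.0
yq = requantize(acc, s_in, s_w, s_out)
print(yq, yq / s_out)            # INT8 value and its real-valued equivalent
```

The INT8 value 32 corresponds to about 10.08, close to the real output of about 10.0 that the accumulator represents.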

In [13]:
class CNN_Model(nn.Module):
    def __init__(self):
        super(CNN_Model, self).__init__()
        # Convolution 1, input_shape=(3,224,224)
        self.scale_input = torch.tensor(1.0) # quant-scale
        self.range_input = torch.tensor(1.0) # quant-range
        self.cnn1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.scale_cv1 = torch.tensor(1.0) # quant-scale
        self.range_cv1 = torch.tensor(1.0) # quant-range
        # Max pool 1
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        # Convolution 2
        self.cnn2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        self.scale_cv2 = torch.tensor(1.0) # quant-scale
        self.range_cv2 = torch.tensor(1.0) # quant-range
        # Max pool 2
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        # Convolution 3
        self.cnn3 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(16)
        self.relu3 = nn.ReLU()
        self.scale_cv3 = torch.tensor(1.0) # quant-scale
        self.range_cv3 = torch.tensor(1.0) # quant-range
        # Max pool 3
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        # Convolution 4
        self.cnn4 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(8)
        self.relu4 = nn.ReLU()
        self.scale_cv4 = torch.tensor(1.0) # quant-scale
        self.range_cv4 = torch.tensor(1.0) # quant-range
        # Max pool 4
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        # Fully connected 1, input_shape=(8*11*11)
        self.fc1 = nn.Linear(8 * 11 * 11, 512) 
        self.relu5 = nn.ReLU()
        self.scale_fc1 = torch.tensor(1.0) # quant-scale
        self.range_fc1 = torch.tensor(1.0) # quant-range
        self.fc2 = nn.Linear(512, 2)
        self.scale_fc2 = torch.tensor(1.0) # quant-scale
        self.range_fc2 = torch.tensor(1.0) # quant-range
        self.output = nn.Softmax(dim=1)
        
        self.bit = 8
        self.x_bit = 2**(self.bit - 1) - 1
        self.w_bit = 2**(self.bit - 1) - 1
        self.y_bit = 2**(self.bit - 1) - 1
        self.b_bit = 2**(self.bit*2 - 1) - 1
        
        self.Quantization = False
        
    def update_range(self, data, range_data):
        now_range = data.abs().max()
        if now_range > range_data:
            return now_range
        else:
            return range_data
        
    def quantize(self, x, scale, int_range):
        xq = torch.round(x * scale)
        xq = torch.clip(xq, -int_range, int_range)
        return xq
        
    def weight_quant(self, w, b, in_scale):
        # scale of weight
        scale_w = (2**(self.bit - 1) - 1) / w.abs().max()
        # quantize weight
        wq = self.quantize(w, scale_w, self.w_bit)
        # scale of bias
        scale_b = scale_w * in_scale
        # quantize bias
        bq = self.quantize(b, scale_b, self.b_bit)
        
        return wq, bq, scale_w
        
    def fusion_conv_bn_quant(self, conv, bn, in_scale):
        # fusion conv and bn parameter
        conv_w = conv.weight
        conv_b = conv.bias
        bn_gamma = bn.weight / torch.pow(bn.running_var + bn.eps, 0.5)
        bn_beta = bn.bias - bn.running_mean * bn_gamma
        fusion_w = conv_w * bn_gamma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        fusion_b = conv_b * bn_gamma + bn_beta
        
        wq, bq, scale_w = self.weight_quant(fusion_w, fusion_b, in_scale)
        
        # update parameter
        conv.weight.data = wq
        conv.bias.data = bq
        bn.running_var /= bn.running_var
        bn.running_mean *= 0.0
        bn.weight.data /= bn.weight
        bn.bias.data *= 0.0
        
        return conv, bn, scale_w
        
    def cl_scale(self):
        # Quantize activation tensor
        self.scale_input = (2**(self.bit - 1) - 1) / self.range_input
        self.scale_cv1 = (2**(self.bit - 1) - 1) / self.range_cv1
        self.scale_cv2 = (2**(self.bit - 1) - 1) / self.range_cv2
        self.scale_cv3 = (2**(self.bit - 1) - 1) / self.range_cv3
        self.scale_cv4 = (2**(self.bit - 1) - 1) / self.range_cv4
        self.scale_fc1 = (2**(self.bit - 1) - 1) / self.range_fc1
        self.scale_fc2 = (2**(self.bit - 1) - 1) / self.range_fc2
        
        self.cnn1, self.bn1, self.cnn1_w_scale = self.fusion_conv_bn_quant(self.cnn1, self.bn1, self.scale_input)
        self.cnn2, self.bn2, self.cnn2_w_scale = self.fusion_conv_bn_quant(self.cnn2, self.bn2, self.scale_cv1)
        self.cnn3, self.bn3, self.cnn3_w_scale = self.fusion_conv_bn_quant(self.cnn3, self.bn3, self.scale_cv2)
        self.cnn4, self.bn4, self.cnn4_w_scale = self.fusion_conv_bn_quant(self.cnn4, self.bn4, self.scale_cv3)
        
        self.fc1.weight.data, self.fc1.bias.data, self.fc1_w_scale = self.weight_quant(self.fc1.weight, self.fc1.bias, self.scale_cv4)
        self.fc2.weight.data, self.fc2.bias.data, self.fc2_w_scale = self.weight_quant(self.fc2.weight, self.fc2.bias, self.scale_fc1)
        
        
    
    def forward(self, x):
        if not self.Quantization:
            self.range_input = self.update_range(x, self.range_input) # quant: observe data range
            out = self.cnn1(x)
            out = self.bn1(out)
            out = self.relu1(out)
            self.range_cv1 = self.update_range(out, self.range_cv1) # quant: observe data range
            out = self.maxpool1(out)
            out = self.cnn2(out)
            out = self.bn2(out)
            out = self.relu2(out)
            self.range_cv2 = self.update_range(out, self.range_cv2) # quant: observe data range
            out = self.maxpool2(out)
            out = self.cnn3(out)
            out = self.bn3(out)
            out = self.relu3(out)
            self.range_cv3 = self.update_range(out, self.range_cv3) # quant: observe data range
            out = self.maxpool3(out)
            out = self.cnn4(out)
            out = self.bn4(out)
            out = self.relu4(out)
            self.range_cv4 = self.update_range(out, self.range_cv4) # quant: observe data range
            out = self.maxpool4(out)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            self.range_fc1 = self.update_range(out, self.range_fc1) # quant: observe data range
            out = self.fc2(out)
            self.range_fc2 = self.update_range(out, self.range_fc2) # quant: observe data range
        else:
            xq = self.quantize(x, self.scale_input, self.x_bit) # quant fp32 to int8
            out = self.cnn1(xq)
            out = self.relu1(out)
            out = torch.clip(torch.round(out * self.scale_cv1 / (self.scale_input * self.cnn1_w_scale)), -self.x_bit, self.x_bit)
            out = self.maxpool1(out)
            out = self.cnn2(out)
            out = self.relu2(out)
            out = torch.clip(torch.round(out * self.scale_cv2 / (self.scale_cv1 * self.cnn2_w_scale)), -self.x_bit, self.x_bit)
            out = self.maxpool2(out)
            out = self.cnn3(out)
            out = self.relu3(out)
            out = torch.clip(torch.round(out * self.scale_cv3 / (self.scale_cv2 * self.cnn3_w_scale)), -self.x_bit, self.x_bit)
            out = self.maxpool3(out)
            out = self.cnn4(out)
            out = self.relu4(out)
            out = torch.clip(torch.round(out * self.scale_cv4 / (self.scale_cv3 * self.cnn4_w_scale)), -self.x_bit, self.x_bit)
            out = self.maxpool4(out)
            out = out.view(out.size(0), -1)
            out = self.fc1(out)
            out = torch.clip(torch.round(out * self.scale_fc1 / (self.scale_cv4 * self.fc1_w_scale)), -self.x_bit, self.x_bit)
            out = self.fc2(out)
            out = torch.clip(torch.round(out * self.scale_fc2 / (self.scale_fc1 * self.fc2_w_scale)), -self.x_bit, self.x_bit)
            

        return out

## Create the dataloader<br />
This cell defines the dataset API used by the dataloader; we won't go into detail here.<br />

In [14]:
class dataset(Dataset):
    def __init__(self, datapath, datasize=(224, 224), mode='train'):
        self.datapath = datapath
        self.datasize = datasize
        self.mode = mode
        self.datanames = os.listdir(os.path.join(self.datapath, self.mode))

    def __len__(self):
        return len(self.datanames)
    
    def __getitem__(self, idx):
        
        image_path = os.path.join(self.datapath, self.mode, self.datanames[idx])
        image = Image.open(image_path, mode='r')
        image = image.convert('RGB')
        image = image.resize(self.datasize)
        image = np.array(image).transpose(2,0,1)
        
        if 'cat' in self.datanames[idx]:
            label = [0]
        else:
            label = [1]
        
        image = torch.FloatTensor(image)
        label = torch.LongTensor(label)
        
        return image, label
     
    def collate_fn(self, batch):
        imgs = list()
        labels = list()
        for b in batch:
            imgs.append(b[0])
            labels.append(b[1])
        imgs = torch.stack(imgs, dim=0)
        labels = torch.stack(labels, dim=0)
        return imgs, labels

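## Dataloader settings<br />
Set up the dataloader. The code expects a `cat_dog` folder in the same directory as this notebook, containing a `val` folder with the images. Labels come from the file names: a file whose name contains `cat` is labeled 0, and every other file is labeled 1.<br />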
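Since labels come only from the file name, it can help to check what each file in the folder will be labeled before running the notebook. A minimal sketch using the same rule as `dataset.__getitem__` (the helper names and the example file names are hypothetical):

```python
import os

def label_from_name(filename):
    # same rule as dataset.__getitem__: 'cat' anywhere in the name -> 0, otherwise -> 1 (dog)
    return 0 if "cat" in filename else 1

def count_labels(folder):
    # count how many files in folder would get each label
    counts = {0: 0, 1: 0}
    for name in os.listdir(folder):
        counts[label_from_name(name)] += 1
    return counts

val_dir = os.path.join("cat_dog", "val")
if os.path.isdir(val_dir):
    print(count_labels(val_dir))
```

The match is case-sensitive, so a file named `Cat_01.jpg` is labeled as a dog. `os.listdir` also returns non-image entries such as `.DS_Store`, which would be labeled 1 and then fail in `Image.open`.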

In [15]:
# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 5
train_dataset = dataset('cat_dog', datasize=(224, 224), mode='val')
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers,
                          collate_fn=train_dataset.collate_fn)

## Load the pre-trained model


In [16]:
model = CNN_Model()

# Load pre-trained model
checkpoint_backbone = torch.load('BEST_checkpoint.pth.tar', map_location=torch.device('cpu'))
best_model = checkpoint_backbone['model']
pretrained_dict = best_model.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
print("pretrained_dict:\n", pretrained_dict.keys())
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

model.to(device)
model.eval()

pretrained_dict:
 dict_keys(['cnn1.weight', 'cnn1.bias', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'cnn2.weight', 'cnn2.bias', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var', 'bn2.num_batches_tracked', 'cnn3.weight', 'cnn3.bias', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var', 'bn3.num_batches_tracked', 'cnn4.weight', 'cnn4.bias', 'bn4.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var', 'bn4.num_batches_tracked', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])


CNN_Model(
  (cnn1): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
  (bn3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (cnn4): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1))
  (bn4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu4): ReLU()
  (maxpool4): MaxPool2d(kernel_size=2, stride=2, padding=

## Run float model inference<br />
First run all of the calibration data through the model once, so that the output of every layer records its data range.<br />

In [17]:
classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 2))

end_i = train_loader.__len__()

for i, data in enumerate(train_loader, 0):
    inputs, labels = data
    inputs = inputs.to(device)
    labels = labels.to(device).squeeze(-1)

    outputs = model(inputs)

    classerr.add(outputs.data, labels)

    Top1 = classerr.value()[0]

    print('FP32 Top1: {:.3f}, {}/{}\n'.format(Top1, i, end_i))

FP32 Top1: 100.000, 0/20

FP32 Top1: 100.000, 1/20

FP32 Top1: 100.000, 2/20

FP32 Top1: 100.000, 3/20

FP32 Top1: 100.000, 4/20

FP32 Top1: 100.000, 5/20

FP32 Top1: 100.000, 6/20

FP32 Top1: 97.500, 7/20

FP32 Top1: 97.778, 8/20

FP32 Top1: 98.000, 9/20

FP32 Top1: 98.182, 10/20

FP32 Top1: 98.333, 11/20

FP32 Top1: 98.462, 12/20

FP32 Top1: 98.571, 13/20

FP32 Top1: 97.333, 14/20

FP32 Top1: 97.500, 15/20

FP32 Top1: 97.647, 16/20

FP32 Top1: 97.778, 17/20

FP32 Top1: 96.842, 18/20

FP32 Top1: 97.000, 19/20



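## Run quantization inference<br />
Compute the scales, switch on `Quantization`, and run inference with the quantized model. Note that `classerr` is not reset before this loop, so each printed INT8 Top1 is cumulative over the 100 FP32 predictions above plus the INT8 predictions so far; call `classerr.reset()` before the loop to measure INT8 accuracy on its own.<br />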

In [18]:
model.cl_scale()
model.Quantization = True

for i, data in enumerate(train_loader, 0):
    inputs, labels = data
    inputs = inputs.to(device)
    labels = labels.to(device).squeeze(-1)

    outputs = model(inputs)

    outputs /= model.scale_fc2  # dequantize: real logits = INT8 output / output scale (argmax unchanged)
    # print(outputs)

    classerr.add(outputs.data, labels)

    Top1 = classerr.value()[0]

    print('INT8 Top1: {:.3f}, {}/{}\n'.format(Top1, i, end_i))

INT8 Top1: 97.143, 0/20

INT8 Top1: 97.273, 1/20

INT8 Top1: 97.391, 2/20

INT8 Top1: 96.667, 3/20

INT8 Top1: 96.800, 4/20

INT8 Top1: 96.923, 5/20

INT8 Top1: 97.037, 6/20

INT8 Top1: 96.429, 7/20

INT8 Top1: 95.862, 8/20

INT8 Top1: 96.000, 9/20

INT8 Top1: 96.129, 10/20

INT8 Top1: 96.250, 11/20

INT8 Top1: 96.364, 12/20

INT8 Top1: 96.471, 13/20

INT8 Top1: 96.000, 14/20

INT8 Top1: 96.111, 15/20

INT8 Top1: 96.216, 16/20

INT8 Top1: 96.316, 17/20

INT8 Top1: 96.410, 18/20

INT8 Top1: 96.500, 19/20
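Because `classerr` accumulates over both loops, the INT8-only accuracy can be recovered from the logged values: after the FP32 loop the meter reports 97.000 over 100 images (20 batches of 5), and after the INT8 loop 96.500 over 200. A small helper (hypothetical name) for that arithmetic:

```python
def split_accuracy(acc_first, n_first, acc_total, n_total):
    # ClassErrorMeter(accuracy=True) reports percent correct over everything added so far
    correct_first = acc_first * n_first / 100
    correct_total = acc_total * n_total / 100
    return 100 * (correct_total - correct_first) / (n_total - n_first)

print(split_accuracy(97.0, 100, 96.5, 200))  # INT8-only top-1 (%)
```

This gives 96.0, so INT8 top-1 on these 100 images is one point below FP32 (97.0).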

