## Embedded ML Lab - Excercise 2 - Quantization


The goal of this exercise is to take a given network, fuse its operators, and finally quantize it. For that we will do the following steps
* 1) We define the quantized network with fused operators
* 2) We determine how to fuse `conv-bn-relu` structures into a single quantized operation.
* 3) We fuse the weights from the pre-trained state dict and quantize them
* 4) We use a calibration batch from the pretrained network to determine all required scales
* 5) Done :)

For this lab the non-quantized version of the net we use is already implemented in `net.py`. It contains 6 conv, 6 batchnorm, 6 relu layers, and only has a very small linear part at the end. Take a look at it.

<img src="src/cifarnet.png" alt="drawing" width="800"/>

In [1]:
from net import CifarNet

In [2]:
import torch

torch.backends.quantized.engine = 'qnnpack'

import torchvision
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testloader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10('data/', train=False, download=True, transform=tf), batch_size=32)

Files already downloaded and verified


## Preliminaries

To measure the effects of quantization we want to measure the time it takes to calculate a batch with the quantized and the unquatized network to run on the cpu. Additionally, we want to know what the accuracy penalty is.

<span style="color:green">Your Tasks:</span>
* <span style="color:green">Implement a function `net_time` that measures the time it takes (forward pass) to process a batch with size 32 of cifar100. You can use `t_now = time.time()` to get the current time.</span>
    * <span style="color:green">NOTE: To save time, you do not have to iterate over the whole dataset.</span>
* <span style="color:green">Implement a function `net_acc` that measures the accuracy of the net class, and takes the class type, a state_dict, and a dataloader as input.</span>
    * <span style="color:green">NOTE: To save time, you do not have to iterate over the whole dataset.</span>
    * <span style="color:green">NOTE: You can reuse code from the last lab exercises.</span>



In [27]:
import time 

def net_time(model_class, testloader):
    
    #----to-be-done-by-student-------------------
    ###
    ###
    #----to-be-done-by-student-------------------
    t = 0.0
    device = torch.device("cpu")
    model = model_class()
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for inference
        inputs, _ = next(iter(testloader))  # Get a batch of inputs
        inputs = inputs.to(device)
        t_start = time.time()  # Record the start time
        _ = model(inputs)  # Perform forward pass
        t_end = time.time()  # Record the end time
        t = t + t_end - t_start  # Calculate elapsed time
    return t

def net_acc(model_class, state_dict, testloader):
    #----to-be-done-by-student-------------------
    ###
    ###
    #----to-be-done-by-student-------------------
    correct = 0
    total = 0
    #accuracy = 0.0
    #accuracy = test(model, test_loader, device=device)
    model = model_class()
    model.load_state_dict(state_dict)
    #model.eval()  # Set the model to evaluation mode
    for idx, (inputs, targets) in enumerate(testloader):
        outputs = model(inputs)
        _, out_class = torch.max(outputs, dim=1)
        correct += (out_class==targets).sum().item()
        total += targets.size(0)
        if (idx==6):
            break
    accuracy = correct / total
        
    #with torch.no_grad():  # Disable gradient calculation for inference
    #    for inputs, labels in testloader:
    #        inputs, labels = inputs.to(device), labels.to(device)
    #        outputs = model(inputs)  # Perform forward pass
    #        _, predicted = torch.max(outputs, 1)
    #        total += labels.size(0)
    #        correct += (predicted == labels).sum().item()
    #accuracy = correct / total
    return accuracy
    

In [52]:
print(f'Time unquantized: {net_time(CifarNet, testloader)} s')
print(f"Accuracy unquantized: {net_acc(CifarNet, torch.load('state_dict.pt'), testloader):.4%}")

Time unquantized: 1.3777589797973633 s
Accuracy unquantized: 78.0600%


## Quantized network
Now we define the quantized version of CifarNet with fused operators ( conv-bn-relu -> qfused_conv_relu). The resulting network has a structure as shown below:

<img src="src/cifarnet_quantized.png" alt="drawing" width="600"/>

<span style="color:green">Your Tasks:</span>
* <span style="color:green">Take the provided image as well as the CifarNet implementation as reference and implemenet the **forward pass** of QCifarNet.</span>
    * <span style="color:green">The required modules `Conv2drelu` and `QLinear` are already provided and can be used like any other module we have seen before. Note that these modules require their weights to be quantized, the bias is unquantized. The forward pass of these modules require an quantized input and return an quantized output. The modules are essentially only a wrapper with parameters around `torch.ops.quantized.conv2d_relu` and `torch.ops.quantized.linear`. Additionally these modules have an paramter called `scale`, that is used as ouput scale for the operation.</span>
    * <span style="color:green">You might require some other "stateless" operators such as `torch.quantize_per_tensor`, `torch.dequantize`,`torch.flatten`, and `torch.nn.quantized.functional.max_pool2d`.</span>
* <span style="color:green">Profile the resulting net and compare its forward pass time to the non-quantized implementation.</span>



In [4]:
import torch.nn as nn
import torch.nn.functional as F

def f_sd(sd, endswith_key_string):
    keys = [i for i in sd.keys() if i.endswith(endswith_key_string)]
    if not keys:
        raise KeyError(endswith_key_string)
    return sd[keys[0]]

#Quantized Conv2dReLU Module
class QConv2dReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(QConv2dReLU, self).__init__()

        self.weight = torch.nn.Parameter(torch.quantize_per_tensor(torch.Tensor(
                out_channels, in_channels // 1, *(kernel_size, kernel_size)), scale=0.1, zero_point = 0, dtype=torch.qint8), requires_grad=False)
        self.bias = torch.nn.Parameter(torch.Tensor(out_channels), requires_grad=False)

        self.register_buffer('scale', torch.tensor(0.1))

        self.stride = stride
        self.padding = padding
        
        self._prepack = self._prepare_prepack(self.weight, self.bias, stride, padding)
        self._register_load_state_dict_pre_hook(self._sd_hook)

    def _prepare_prepack(self, qweight, bias, stride, padding):
        assert qweight.is_quantized, "QConv2dReLU requires a quantized weight."
        assert not bias.is_quantized, "QConv2dReLU requires a float bias."
        return torch.ops.quantized.conv2d_prepack(qweight, bias, stride=[stride, stride], dilation=[1,1], padding=[padding, padding], groups=1)

    
    def _sd_hook(self, state_dict, prefix, *_):
        self._prepack = self._prepare_prepack(f_sd(state_dict, prefix + 'weight'), f_sd(state_dict, prefix + 'bias'),
                                             self.stride, self.padding)
    
    def forward(self, x):
        return torch.ops.quantized.conv2d_relu(x, self._prepack, self.scale, 64)

    
#Quantized Linear Module
class QLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(QLinear, self).__init__()

        self.weight = torch.nn.Parameter(torch.quantize_per_tensor(torch.Tensor(out_features, in_features), scale=0.1, zero_point = 0, dtype=torch.qint8), requires_grad=False)
        self.bias = torch.nn.Parameter(torch.Tensor(out_features))

        self.register_buffer('scale', torch.tensor(0.1))
        
        self._prepack = self._prepare_prepack(self.weight, self.bias)
        
        self._register_load_state_dict_pre_hook(self._sd_hook)
        
    def _prepare_prepack(self, qweight, bias):
        assert qweight.is_quantized, "QConv2dReLU requires a quantized weight."
        assert not bias.is_quantized, "QConv2dReLU requires a float bias."
        return torch.ops.quantized.linear_prepack(qweight, bias)
    
    def _sd_hook(self, state_dict, prefix, *_):
        self._prepack = self._prepare_prepack(f_sd(state_dict, prefix + 'weight'), f_sd(state_dict, prefix + 'bias'))
        return

    def forward(self, x):
        return torch.ops.quantized.linear(x, self._prepack, self.scale, 64)

In [5]:
print('state_dict of QConv2dReLU')
qconv2drelu = QConv2dReLU(3, 16)
for key in qconv2drelu.state_dict(): print(key, qconv2drelu.state_dict()[key].dtype)
print('\nstate_dict of QLinear')
qlinear = QLinear(10, 10)
for key in qlinear.state_dict(): print(key, qlinear.state_dict()[key].dtype)

state_dict of QConv2dReLU
weight torch.qint8
bias torch.float32
scale torch.float32

state_dict of QLinear
weight torch.qint8
bias torch.float32
scale torch.float32


In [6]:
class QCifarNet(nn.Module):
    def __init__(self):
        super(QCifarNet, self).__init__()
        
        self.register_buffer("scale", torch.tensor(0.1))

        self.conv1 = QConv2dReLU(3, 16, 3, 1, padding=1)
        self.conv2 = QConv2dReLU(16,16, 3, 1, padding=1)

        self.conv3 = QConv2dReLU(16, 32, 3, 1, padding=1)
        self.conv4 = QConv2dReLU(32, 32, 3, 1, padding=1)

        self.conv5 = QConv2dReLU(32, 64, 3, 1, padding=1)
        self.conv6 = QConv2dReLU(64, 64, 3, 1, padding=1)

        self.fc = QLinear(1024, 10)
        
    def forward(self, x):
        #to-be-done-by-student
        ###
        ###
        #to-be-done-by-student
        x = torch.quantize_per_tensor(x, scale=self.scale, zero_point=64, dtype=torch.quint8)
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.nn.quantized.functional.max_pool2d(x, 2, stride=2)
        
        x = self.conv3(x)
        x = self.conv4(x)
        x = torch.nn.quantized.functional.max_pool2d(x, 2, stride=2)
        
        x = self.conv5(x)
        x = self.conv6(x)
        x = torch.nn.quantized.functional.max_pool2d(x, 2, stride=2)
        
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = torch.dequantize(x)
        
        return x

In [72]:
#We evaulate how fast the quantized verions of CifarNet is
print(f"Time quantized: {net_time(QCifarNet, testloader)} s")

Time quantized: 0.1098482608795166 s


## Calibration and Operator Fusion

First we focus on the operator fusion:
* We need calculate the new weights (fused conv and batchnorm weights). After we have weights, we can quantize them using the `tensor_scale` equation from earlier.
    * A Conv2d convolution can be expressed as $y_i = \boldsymbol{ W_{i}} \star x + b_{_i}$, where $y_i$ is the channel wise output of the convolution and $\boldsymbol{ W_{i}}$ is a $\text{channel_in} \times \text{kernel_size} \times \text{kernel_size}$ kernel.
    * The batch_norm operation looks like this: $\hat x_i = \frac{x_i - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}}$, where for each output channel of a convolution $i \in C$, we scale and shift the input to be zero mean and unit variance, where $\mu_i$ is the channel wise input mean, and $\sigma^2_i$ is the channels wise variance. Parameter $\epsilon$ is added for numerical stability.
    * After this shift and scale operation trainable weight and bias terms are added
 $y_i = \gamma_i \hat x_i + \beta_i$, where $\gamma_i$ is a channel wise scale factor and $\beta_i$ is a channel wise bias.
    * We can express the batchnorm operation as $y_i = (\frac{\gamma_i} {\sqrt{\sigma_i^2 + \epsilon}})x_i +  (\frac{ - \mu_i \gamma_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i)$ and fuse it with the convolution kernel by using $y_i = (\frac{\gamma_i} {\sqrt{\sigma_i^2 + \epsilon}} \boldsymbol{ W_i}) \star x_i +  (\frac{ \gamma_i ( b_i - \mu_i)}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i)$, s.t. the fused kernel (output channel wise) can be expressed as $\tilde{\boldsymbol{ W_{i}}} = (\frac{\gamma_i} {\sqrt{\sigma_i^2 + \epsilon}}) \boldsymbol{ W_i}$ and the fused bias (output channel wise) as $\tilde{b_i} = (\frac{ \gamma_i ( b_i - \mu_i)}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i)$ .
 

<span style="color:green">Your Tasks:</span>
* <span style="color:green">Implement a function `fuse_conv_bn_weights` that fuses the weights and bias of the convolution with the weights, bias, running_mean and running_var of the batchnorm_layer</span>
    * <span style="color:green"> determine $\tilde{b}$ and $\tilde{\boldsymbol{ W}}$</span>
    * <span style="color:green"> You can either do this channel by channel or compleatly vectorized</span>

In [7]:
def tensor_scale(input):
    return float(2*torch.max(torch.abs(torch.max(input)), torch.abs(torch.min(input))))/127.0

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b):
    """
    Input:
        conv_w: shape=(output_channels, in_channels, kernel_size, kernel_size)
        conv_b: shape=(output_channels)
        bn_rm:  shape=(output_channels)
        bn_rv:  shape=(output_channels)
        bn_w:   shape=(output_channels)
        bn_b:   shape=(output_channels)
    
    Output:
        fused_conv_w = shape=conv_w
        fused_conv_b = shape=conv_b
    """
    bn_eps = 1e-05

    fused_conv = torch.zeros(conv_w.shape)
    fused_bias = torch.zeros(conv_b.shape)
    
    #to-be-done-by-student
    ###
    ###
    #to-be-done-by-student
    for i in range(conv_w.shape[0]):
        gamma = bn_w[i] / torch.sqrt(bn_rv[i] + bn_eps)
        fused_conv[i] = conv_w[i] * gamma.view(-1, 1, 1, 1)
        fused_bias[i] = gamma * (conv_b[i] - bn_rm[i]) / torch.sqrt(bn_rv[i] + bn_eps) + bn_b[i]
        
    # Quantize the fused weights and bias
    #scale_w = tensor_scale(fused_conv)
    #scale_b = tensor_scale(fused_bias)

    #fused_conv = torch.quantize_per_tensor(fused_conv, scale=scale_w, zero_point=0, dtype=torch.qint8)
    #fused_bias = torch.quantize_per_tensor(fused_bias, scale=scale_b, zero_point=0, dtype=torch.qint8)
    

    return fused_conv, fused_bias

Now that we know how to fuse conv and batchnorm layers, we can setup the quantized state dict. We have to take the unfused unquantized parameters of the unquantized pretrained network (`state_dict.pt`) and fuse and quantize them.

<span style="color:green">Your Tasks:</span>  
* <span style="color:green">Now for each Conv weights and biases, load the pre-trained float weights and biases from the saved state_dict, fuse the corresponding weights and biases with the batch norm weights, biases, mean, and variance, and store the fused quantized weight into the quantized state_dict `qsd`</span>
* <span style="color:green">Some Tips:</span>
    * <span style="color:green">Print out the keys from the unquantized and quantized state_dict and see what is inside.</span>
    * <span style="color:green">You can ignore the scales for now, we will take care of them later.</span>
    * <span style="color:green">Reuse the function `tensor_scale`</span>
    * <span style="color:green">Weights require to be of type torch.qint8, therefor have a zero_point of 0.</span>

In [8]:
#prints keys from quantized net
qnet = QCifarNet()
qsd = qnet.state_dict()
for key in qsd: print(key, qsd[key].dtype)

print('-------end------')

sd = torch.load('state_dict.pt')

for key in sd:print(key, sd[key].dtype)

print('---------end---------')

#-to-be-done- by student 
###
###
#-to-be-done- by student
   
for name, module in qnet.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        conv_name = name
        bn_name = name.replace('conv', 'bn')
        if bn_name in sd:
            fused_conv_w, fused_conv_b = fuse_conv_bn_weights(
                sd[conv_name + '.weight'], 
                sd[conv_name + '.bias'], 
                sd[bn_name + '.running_mean'], 
                sd[bn_name + '.running_var'], 
                sd[bn_name + '.weight'], 
                sd[bn_name + '.bias']
            )
            scale = tensor_scale(fused_conv_w)
            qsd[conv_name + '.weight'] = torch.quantize_per_tensor(fused_conv, scale=scale, zero_point=0, dtype=torch.qint8)
            
            #qsd[conv_name + '.weight'] = fused_conv_w
            #qsd[conv_name + '.bias'] = fused_conv_b        

for key in qsd: print(key, qsd[key].dtype)
torch.save(qsd, 'quantized_state_dict.pt')

scale torch.float32
conv1.weight torch.qint8
conv1.bias torch.float32
conv1.scale torch.float32
conv2.weight torch.qint8
conv2.bias torch.float32
conv2.scale torch.float32
conv3.weight torch.qint8
conv3.bias torch.float32
conv3.scale torch.float32
conv4.weight torch.qint8
conv4.bias torch.float32
conv4.scale torch.float32
conv5.weight torch.qint8
conv5.bias torch.float32
conv5.scale torch.float32
conv6.weight torch.qint8
conv6.bias torch.float32
conv6.scale torch.float32
fc.weight torch.qint8
fc.bias torch.float32
fc.scale torch.float32
-------end------
conv1.weight torch.float32
conv1.bias torch.float32
conv2.weight torch.float32
conv2.bias torch.float32
conv3.weight torch.float32
conv3.bias torch.float32
conv4.weight torch.float32
conv4.bias torch.float32
conv5.weight torch.float32
conv5.bias torch.float32
conv6.weight torch.float32
conv6.bias torch.float32
bn1.weight torch.float32
bn1.bias torch.float32
bn1.running_mean torch.float32
bn1.running_var torch.float32
bn1.num_batches_tra

In [31]:
#wrong code
#for key in sd:
#    if key.endswith('.weight'):
 #       layer_name = key[:-7]  # Remove '.weight' suffix to get layer name
  #      conv_w = sd[key]
   #     conv_b = sd[layer_name + '.bias']
    #    layer_name2 = 
     #   # Fuse and quantize the convolution weights and biases with batch normalization parameters
      #  fused_conv_w, fused_conv_b = fuse_conv_bn_weights(conv_w, conv_b, bn_rm[layer_name], bn_rv[layer_name], bn_w[layer_name], bn_b[layer_name])
       # # Store the fused and quantized weights and biases into the quantized state_dict
        #qsd[key] = fused_conv_w
        #qsd[layer_name + '.bias'] = fused_conv_b

SyntaxError: invalid syntax (<ipython-input-31-dcc6201c0b1b>, line 7)

Now that we have the fused parameters, we still require the right scales for the activations. For that we "observe" the activation scales in the unquantized network using a calibration "batch", reuse the function `tensor_scale`

<span style="color:green">Your Tasks:</span>  
* <span style="color:green">Directly calculate the required scales in the forward pass, e.g. the scale for the inital quantization, and the output scale for each fused operation, and final output scale (the output of the FC layer).</span>
* <span style="color:green">There is already an inherited version of CifarNet provided, where you only have to redefine the forward pass and add the calculated scales to the `calibration_dict`. We will later use them to set the remaining scales in our quantized state_dict.</span>
* <span style="color:green">It is sufficient to estimate the scales in only one forward pass (we can make the batchsize large).</span>

In [22]:
class CifarNetCalibration(CifarNet):
    def __init__(self):
        super(CifarNetCalibration, self).__init__()
        
        #we add a new dict for the corresponding scales
        self.calibration_dict = {}
        
    def forward(self, x):
        
        #to-be-done-by-student
        ###
        ###
        ###
        #--to---be---done---by---student
        # Initial quantization scale
        self.calibration_dict['input_scale'] = torch.tensor(tensor_scale(x))
        
        sd = self.state_dict()
        convlayers = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6]
        
        for layer in range(1,7):
            c = 'conv' + str(layer)
            b = 'bn' +str(layer)
            
            fused_conv, fused_bias = fuse_conv_bn_weights(sd[c + '.weight'], 
                                                          sd[c + '.bias'], 
                                                          sd[b + '.running_mean'], 
                                                          sd[b + '.running_var'], 
                                                          sd[b + '.weight'], 
                                                          sd[b + '.bias'])
            
                
            convlayers[layer-1].weight = nn.Parameter(fused_conv)
            convlayers[layer-1].bias = nn.Parameter(fused_bias)
            
            x = convlayers[layer-1](x)
            
            conv_scale = tensor_scale(x)
            self.calibration_dict[c + '.weight'] = torch.quantize_per_tensor(fused_conv, scale = conv_scale, zero_point = 0, dtype = torch.qint8)
            #self.calibration_dict[c + '.bias'] = fused_bias
            self.calibration_dict[c + '.scale'] = conv_scale
            
            if layer % 2 == 0:
                x = F.max_pool2d(x, 2, stride = 2)
        
        x = torch.flatten(x, 1)
        x = self.fc(x)
        fc_scale = torch.tensor(tensor_scale(x))
        self.calibration_dict['fc.weight'] = torch.quantize_per_tensor(sd['fc.weight'], scale = fc_scale, zero_point = 0, dtype = torch.qint8)
        self.calibration_dict['fc.bias'] = sd['fc.bias']
        self.calibration_dict['fc.scale'] = fc_scale
        
        return x 

In [24]:
#We run the calibration using a batch from the testdata
net_calib = CifarNetCalibration()
net_calib.load_state_dict(torch.load('state_dict.pt'))
_, (data, _) = next(enumerate(testloader))
net_calib(data)
calibration_dict = net_calib.calibration_dict

<span style="color:green">Your Task:</span>  
* <span style="color:green">Now, transfer the scales into the state_dict `qsd`.</span>

In [25]:
#-to-be-done- by student 
###
###
#-to-be-done- by student 
#qnet = CifarNet()  # Or your specific quantized network class
#qsd = qnet.state_dict()

# Transfer scales to the quantized state_dict
#qsd['input_scale'] = calibration_dict['input_scale']
#qsd['conv1_output_scale'] = calibration_dict['conv1_output_scale']
#qsd['conv2_output_scale'] = calibration_dict['conv2_output_scale']
#qsd['fc_output_scale'] = calibration_dict['fc_output_scale']
#qsd['fc2_output_scale'] = calibration_dict['fc2_output_scale']
for key in calibration_dict:
    qsd[key] = calibration_dict[key]

# Save the updated quantized state_dict
#torch.save(qsd, 'quantized_state_dict_with_scales.pt')

# Print the quantized state_dict with scales to verify
for key, value in qsd.items():
    print(f"{key}: {type(value)}")


scale: <class 'torch.Tensor'>
conv1.weight: <class 'torch.Tensor'>
conv1.bias: <class 'torch.Tensor'>
conv1.scale: <class 'float'>
conv2.weight: <class 'torch.Tensor'>
conv2.bias: <class 'torch.Tensor'>
conv2.scale: <class 'float'>
conv3.weight: <class 'torch.Tensor'>
conv3.bias: <class 'torch.Tensor'>
conv3.scale: <class 'float'>
conv4.weight: <class 'torch.Tensor'>
conv4.bias: <class 'torch.Tensor'>
conv4.scale: <class 'float'>
conv5.weight: <class 'torch.Tensor'>
conv5.bias: <class 'torch.Tensor'>
conv5.scale: <class 'float'>
conv6.weight: <class 'torch.Tensor'>
conv6.bias: <class 'torch.Tensor'>
conv6.scale: <class 'float'>
fc.weight: <class 'torch.Tensor'>
fc.bias: <class 'torch.Tensor'>
fc.scale: <class 'torch.Tensor'>
input_scale: <class 'torch.Tensor'>


In [28]:
#We run the accuracy test again to see how much accuracy we loose through quantization
print(f'Time quantized: {net_time(QCifarNet, testloader)} s')
print(f"Accuracy quantized: {net_acc(QCifarNet, qsd, testloader):.4%}")

Time quantized: 1.7280635833740234 s


AttributeError: 'float' object has no attribute 'shape'