# Analysis and Visualization of FP and Quantized Models

## 1. The keys in the model checkpoints

In [28]:
import torch

# Load the FP32 checkpoint (a state dict) and list every entry with its dtype and shape.
fp_dict = torch.load("../models/fp32_model.pth")
print("The FP model has", len(fp_dict), "keys. The keys in FP model dict are: ")
for key, value in fp_dict.items():
    print(key, value.dtype, value.size())

The FP model has 30 keys. The keys in FP model dict are: 
conv1.weight torch.float32 torch.Size([32, 1, 80])
conv1.bias torch.float32 torch.Size([32])
bn1.weight torch.float32 torch.Size([32])
bn1.bias torch.float32 torch.Size([32])
bn1.running_mean torch.float32 torch.Size([32])
bn1.running_var torch.float32 torch.Size([32])
bn1.num_batches_tracked torch.int64 torch.Size([])
conv2.weight torch.float32 torch.Size([32, 32, 3])
conv2.bias torch.float32 torch.Size([32])
bn2.weight torch.float32 torch.Size([32])
bn2.bias torch.float32 torch.Size([32])
bn2.running_mean torch.float32 torch.Size([32])
bn2.running_var torch.float32 torch.Size([32])
bn2.num_batches_tracked torch.int64 torch.Size([])
conv3.weight torch.float32 torch.Size([64, 32, 3])
conv3.bias torch.float32 torch.Size([64])
bn3.weight torch.float32 torch.Size([64])
bn3.bias torch.float32 torch.Size([64])
bn3.running_mean torch.float32 torch.Size([64])
bn3.running_var torch.float32 torch.Size([64])
bn3.num_batches_tracked torch.

In [29]:
# Load the quantized checkpoint; not every entry is a tensor, so check the type first.
qat_dict = torch.load("../models/qat_model.pth")
print("The quantized model has", len(qat_dict), "keys. The keys in qat model dict are: ")
for key, value in qat_dict.items():
    if isinstance(value, torch.Tensor):
        print(key, value.dtype, value.size())
    else:
        print(key, type(value))

The quantized model has 22 keys. The keys in qat model dict are: 
conv1.weight torch.qint8 torch.Size([32, 1, 80])
conv1.bias torch.float32 torch.Size([32])
conv1.scale torch.float32 torch.Size([])
conv1.zero_point torch.int64 torch.Size([])
conv2.weight torch.qint8 torch.Size([32, 32, 3])
conv2.bias torch.float32 torch.Size([32])
conv2.scale torch.float32 torch.Size([])
conv2.zero_point torch.int64 torch.Size([])
conv3.weight torch.qint8 torch.Size([64, 32, 3])
conv3.bias torch.float32 torch.Size([64])
conv3.scale torch.float32 torch.Size([])
conv3.zero_point torch.int64 torch.Size([])
conv4.weight torch.qint8 torch.Size([64, 64, 3])
conv4.bias torch.float32 torch.Size([64])
conv4.scale torch.float32 torch.Size([])
conv4.zero_point torch.int64 torch.Size([])
fc1.scale torch.float32 torch.Size([])
fc1.zero_point torch.int64 torch.Size([])
fc1._packed_params.dtype <class 'torch.dtype'>
fc1._packed_params._packed_params <class 'tuple'>
quant.scale torch.float32 torch.Size([1])
quant.zero

The keys in the quantized model differ from those in the FP model for the following reasons:
- BatchNorm Folding: 

    --`bnX.weight`, `bnX.bias`, `bnX.running_mean`, `bnX.running_var`, `bnX.num_batches_tracked`

    During QAT, each BatchNorm layer is fused with its corresponding `Conv1d` layer using `fuse_model()`. The running statistics and the BatchNorm parameters are folded into the `Conv1d` weights and biases, so the quantized checkpoint has no `bnX.*` keys.

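- Quantization Related Parameters: 

    ++`convX.scale`, `convX.zero_point`

    Each quantized convolution stores the scale and zero point used to quantize its output activations. The quantization parameters of the weights travel with the `torch.qint8` weight tensor itself.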

- FC Parameters Packing: 

    --`fc1.weight`, `fc1.bias` 

    ++`fc1._packed_params.dtype`, `fc1._packed_params._packed_params`, `fc1.scale`, `fc1.zero_point`

    Because the `nn.Linear` in the FP model ends up as a quantized linear layer (`torch.ao.nn.quantized.Linear`, the module that stores `_packed_params`), the keys of the FC layer change.

    - `fc1._packed_params.dtype` stores the data type of the quantized weights in `fc1` (i.e. `torch.qint8`). 
    - `fc1._packed_params._packed_params` is a tuple of two elements. The first is the quantized weight tensor, which also carries its quantization scheme and the per-channel scales and zero points. The second is the bias tensor in float32; biases are usually not quantized.
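The BatchNorm folding described in the first bullet can be sketched numerically. The snippet below is a minimal illustration, not the notebook's `fuse_model()`: it builds a stand-alone `Conv1d`/`BatchNorm1d` pair with the same shape as `conv2`/`bn2`, fills the BatchNorm with random values (an assumption made purely for the demonstration), and checks that a single convolution with folded weights reproduces `bn(conv(x))` in eval mode. The helper name `fold_conv_bn` is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same shape as conv2/bn2 in the checkpoint: 32 -> 32 channels, kernel size 3.
conv = nn.Conv1d(32, 32, kernel_size=3)
bn = nn.BatchNorm1d(32)

# Random statistics and affine parameters, for illustration only.
with torch.no_grad():
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.normal_()

conv.eval()
bn.eval()  # folding assumes fixed running statistics


def fold_conv_bn(conv, bn):
    """Return (weight, bias) of one conv equivalent to bn(conv(x)) in eval mode."""
    # Per-output-channel multiplier: gamma / sqrt(running_var + eps)
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # conv.weight has shape [out, in, k]; scale each output channel's filter.
    weight = conv.weight * factor.view(-1, 1, 1)
    # Shift the bias by the running mean, rescale it, then add beta.
    bias = (conv.bias - bn.running_mean) * factor + bn.bias
    return weight, bias


with torch.no_grad():
    w_fold, b_fold = fold_conv_bn(conv, bn)
    x = torch.randn(4, 32, 100)
    reference = bn(conv(x))
    folded = nn.functional.conv1d(x, w_fold, b_fold)
    max_abs_diff = (reference - folded).abs().max().item()

print("max |bn(conv(x)) - folded_conv(x)| =", max_abs_diff)
```

The folded weight keeps the shape of the original `conv.weight`, which is consistent with `convX.weight` having the same size in both checkpoints while the `bnX.*` keys vanish.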




In [47]:
# The packed parameters are a (weight, bias) tuple.
packed = qat_dict["fc1._packed_params._packed_params"]
weight, bias = packed

print(len(packed))
print(weight.dtype)
print(weight.type())
print(weight.size())
print(weight)

print(bias.dtype)
print(bias.type())
print(bias.size())

print(qat_dict["fc1.scale"])
print(qat_dict["fc1.zero_point"])


2
torch.qint8
torch.quantized.QInt8Tensor
torch.Size([35, 64])
tensor([[ 0.0344, -0.1279, -0.1894,  ...,  0.0000, -0.0271,  0.0025],
        [-0.0525, -0.0739,  0.0447,  ...,  0.0000, -0.0311, -0.0097],
        [-0.0202, -0.0173,  0.1080,  ...,  0.0000,  0.0086, -0.0043],
        ...,
        [-0.1384, -0.0336,  0.1285,  ...,  0.0000,  0.0040, -0.0040],
        [-0.0579,  0.1221,  0.0150,  ...,  0.0000, -0.0450,  0.0000],
        [ 0.1803,  0.0767,  0.0690,  ...,  0.0000, -0.0096,  0.0058]],
       size=(35, 64), dtype=torch.qint8,
       quantization_scheme=torch.per_channel_affine,
       scale=tensor([0.0025, 0.0019, 0.0014, 0.0017, 0.0014, 0.0020, 0.0024, 0.0021, 0.0020,
        0.0018, 0.0020, 0.0015, 0.0027, 0.0019, 0.0021, 0.0019, 0.0027, 0.0024,
        0.0016, 0.0022, 0.0022, 0.0020, 0.0020, 0.0022, 0.0028, 0.0019, 0.0018,
        0.0020, 0.0020, 0.0016, 0.0022, 0.0022, 0.0020, 0.0021, 0.0019],
       dtype=torch.float64),
       zero_point=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0
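To make the per-channel layout printed above more concrete, here is a small sketch on a synthetic tensor with the same shape as the `fc1` weight. The float weights and the symmetric scale choice (maximum absolute value divided by 127, zero point 0) are assumptions for illustration and are not taken from the checkpoint. It shows how the stored `int8` values, the per-row scales, and the zero points combine into the float values that `print` displays.

```python
import torch

torch.manual_seed(0)

# Hypothetical float weight with the same shape as fc1 (35 outputs, 64 inputs).
w_float = torch.randn(35, 64) * 0.1

# Assumed symmetric per-channel parameters: one scale per output row, zero point 0.
scales = w_float.abs().amax(dim=1) / 127.0
zero_points = torch.zeros(35, dtype=torch.int64)

# Quantize along axis 0, i.e. one (scale, zero_point) pair per output channel.
w_q = torch.quantize_per_channel(w_float, scales, zero_points, axis=0, dtype=torch.qint8)

print(w_q.qscheme())
print(w_q.q_per_channel_axis())

int_values = w_q.int_repr()             # raw int8 values stored in the tensor
q_scales = w_q.q_per_channel_scales()   # per-row scales
q_zps = w_q.q_per_channel_zero_points()

# Manual dequantization: (q - zero_point) * scale, broadcast across each row.
manual = (int_values.to(torch.float64) - q_zps.view(-1, 1)) * q_scales.to(torch.float64).view(-1, 1)
max_err = (manual.to(torch.float32) - w_q.dequantize()).abs().max().item()
print("max difference between manual and built-in dequantization:", max_err)
```

The same accessors (`int_repr()`, `q_per_channel_scales()`, `q_per_channel_zero_points()`, `dequantize()`) can be applied to the first element of `fc1._packed_params._packed_params`, since it is a per-channel quantized tensor.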

### Overall
| Components | FP model keys (30) | QAT model keys (22) |
| --- | --- | --- |
| Convolution layers | `convX.weight`, `convX.bias` | `convX.weight`, `convX.bias`, `convX.scale`, `convX.zero_point` |
| BatchNorm layers | `bnX.weight`, `bnX.bias`, `bnX.running_mean`, `bnX.running_var`, `bnX.num_batches_tracked` | Folded into `convX` |
| FC layers | `fc1.weight`, `fc1.bias` | `fc1.scale`, `fc1.zero_point`, `fc1._packed_params.dtype`, `fc1._packed_params._packed_params` |
| Quant Stubs |  | `quant.scale`, `quant.zero_point` |
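
The comparison in the table can also be produced programmatically. The helper below is a hypothetical sketch (the name `diff_keys` is not part of the notebook) that splits two collections of keys into those found only in the FP checkpoint, those found only in the quantized one, and those shared by both. The example uses only the `conv1`/`bn1` key names from the listings above.

```python
def diff_keys(fp_keys, qat_keys):
    """Group key names by which checkpoint they appear in."""
    fp_set, qat_set = set(fp_keys), set(qat_keys)
    return {
        "only_fp": sorted(fp_set - qat_set),    # removed in the quantized model
        "only_qat": sorted(qat_set - fp_set),   # added in the quantized model
        "shared": sorted(fp_set & qat_set),     # present in both
    }


# Key names for the first conv/bn pair, copied from the two listings.
fp_keys = [
    "conv1.weight", "conv1.bias",
    "bn1.weight", "bn1.bias", "bn1.running_mean",
    "bn1.running_var", "bn1.num_batches_tracked",
]
qat_keys = ["conv1.weight", "conv1.bias", "conv1.scale", "conv1.zero_point"]

result = diff_keys(fp_keys, qat_keys)
for group, keys in result.items():
    print(group, keys)
```

Because iterating a dict yields its keys, the loaded checkpoints can be passed directly, as in `diff_keys(fp_dict, qat_dict)`.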