# Experiment 1: Per-Layer Sensitivity

The idea of this experiment: quantize only one block at a time and keep the others in FP32.

This experiment investigates the layer-wise sensitivity of a deep CNN model to quantization (PTQ and QAT). The goal is to understand how quantizing individual layers affects overall model accuracy, and to identify which layers are robust to quantization-induced degradation and which are sensitive to it.

To achieve this, a modified, modularized variant of the original deep CNN was implemented that allows selective quantization of individual convolutional blocks (L1–L4). In each run, only one block was quantized while the rest remained in FP32. To isolate the effect, `qconfig` was applied only to the target block and its `quant`/`dequant` stubs.
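
As a hedged illustration of the one-block-at-a-time idea, the sketch below uses a hypothetical toy network of four scalar "blocks" (affine transform + ReLU) in plain Python; it is not the M5 model and does not use PyTorch's quantization API. Only the selected block's output is fake-quantized (quantized to clamped unsigned integers, then dequantized) with per-tensor affine parameters computed from that output's min/max, which stands in for calibration; every other block stays in floating point. Weight quantization is omitted for brevity, and the helper names are illustrative.

```python
import random

# Hypothetical toy network: four "blocks", each y = ReLU(w * x + b) applied elementwise.
BLOCKS = [(1.7, -0.2), (0.9, 0.1), (1.3, -0.05), (0.8, 0.3)]


def affine_qparams(lo, hi, qmin, qmax):
    """Per-tensor asymmetric (affine) scale and zero-point from an observed min/max range."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)        # the range must include 0.0
    scale = (hi - lo) / (qmax - qmin) or 1.0   # avoid a zero scale for an all-zero tensor
    zero_point = int(round(qmin - lo / scale))
    return scale, max(qmin, min(qmax, zero_point))


def fake_quant(xs, scale, zero_point, qmin, qmax):
    """Quantize to clamped integers, then dequantize back to float."""
    return [(max(qmin, min(qmax, round(x / scale) + zero_point)) - zero_point) * scale for x in xs]


def forward(xs, quantized_block_idx=None, num_bits=8):
    """Run all blocks in float; fake-quantize only the output of block `quantized_block_idx` (1-based)."""
    qmax = 2 ** num_bits - 1
    for idx, (w, b) in enumerate(BLOCKS, start=1):
        xs = [max(0.0, w * x + b) for x in xs]
        if idx == quantized_block_idx:
            scale, zp = affine_qparams(min(xs), max(xs), 0, qmax)
            xs = fake_quant(xs, scale, zp, 0, qmax)
    return xs


def mean_abs_error(inputs, quantized_block_idx, num_bits):
    """Mean absolute deviation of the final output from the all-float reference."""
    ref = forward(inputs)
    out = forward(inputs, quantized_block_idx, num_bits)
    return sum(abs(a - r) for a, r in zip(out, ref)) / len(ref)


random.seed(0)
inputs = [random.uniform(-1.0, 1.0) for _ in range(1000)]
for block in range(1, len(BLOCKS) + 1):
    print(f"block {block} quantized (4-bit): mean |error| = {mean_abs_error(inputs, block, 4):.5f}")
```

Comparing the per-block errors against the all-float reference mirrors, in miniature, how the notebook compares each single-block model against the FP32 accuracy.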

## 1. PTQ 

In [1]:
import torch
import yaml
from src.utils import *
from src.data_loader import get_data_loaders
from src.models.cnn_model_LayerWiseQuant import M5Modular, QATM5Modular, QATM5_LayerWiseQuant, PTQM5Modular, PTQM5_LayerWiseQuant
from src.evaluate import test
fp_dict = torch.load("../models/cnn_fp32_model.pth")
ptq_dict = torch.load("../models/cnn_ptq_model.pth")
LWQ_dict_dicts = {
    1: torch.load("../models/cnn_ptq_LayerWiseQuant_q1_model.pth"),
    2: torch.load("../models/cnn_ptq_LayerWiseQuant_q2_model.pth"),
    3: torch.load("../models/cnn_ptq_LayerWiseQuant_q3_model.pth"),
    4: torch.load("../models/cnn_ptq_LayerWiseQuant_q4_model.pth"),
}

data_config = {
    "raw_dir": "../data/raw",
    "processed_dir": "../data/processed",
    "sample_rate": 8000,
    "batch_size": 256,
    "version": "v0.1"
}
train_loader, test_loader, validate_loader = get_data_loaders(data_config)
# print("Length of train_loader: ", len(train_loader))
# print("Length of test_loader: ", len(test_loader))
# print("Length of validate_loader: ", len(validate_loader))
# print("Shape of a batch in train_loader: ", next(iter(train_loader))[0].shape)




In [2]:
# Check the keys in the dictionaries
LWPTQ_dict_sample = LWQ_dict_dicts[1]
ptq_dict = torch.load("../models/cnn_ptq_model.pth")
print("Keys in fp_dict:")
for key in fp_dict.keys():
    print(key)
print("****************************************")
print("Keys in LWPTQ_dict_sample:")
for key in LWPTQ_dict_sample.keys():
    print(key)
print("****************************************")
print("Keys in ptq_dict:")
for key in ptq_dict.keys():
    print(key)
print("****************************************")

Keys in fp_dict:
block1.block.0.weight
block1.block.0.bias
block1.block.1.weight
block1.block.1.bias
block1.block.1.running_mean
block1.block.1.running_var
block1.block.1.num_batches_tracked
block2.block.0.weight
block2.block.0.bias
block2.block.1.weight
block2.block.1.bias
block2.block.1.running_mean
block2.block.1.running_var
block2.block.1.num_batches_tracked
block3.block.0.weight
block3.block.0.bias
block3.block.1.weight
block3.block.1.bias
block3.block.1.running_mean
block3.block.1.running_var
block3.block.1.num_batches_tracked
block4.block.0.weight
block4.block.0.bias
block4.block.1.weight
block4.block.1.bias
block4.block.1.running_mean
block4.block.1.running_var
block4.block.1.num_batches_tracked
fc1.weight
fc1.bias
****************************************
Keys in LWPTQ_dict_sample:
block1.block.0.weight
block1.block.0.bias
block1.block.0.scale
block1.block.0.zero_point
block2.block.0.0.weight
block2.block.0.0.bias
block3.block.0.0.weight
block3.block.0.0.bias
block4.block.0.0.w

In [3]:
# Check keys in fp_dict after fusion
fp_dict = torch.load("../models/cnn_fp32_model.pth")
with open('../configs/cnn_fp32.yaml') as f:
    config = yaml.safe_load(f)
model = PTQM5Modular(
    n_input=config["model"]["base_cnn"]["n_input"],
    n_output=config["model"]["base_cnn"]["n_output"],
    stride=config["model"]["base_cnn"]["stride"],
    n_channel=config["model"]["base_cnn"]["n_channel"],
    conv_kernel_sizes=config["model"]["base_cnn"]["conv_kernel_sizes"]).to('cpu')
model.eval()
model.load_state_dict(fp_dict)
print("Keys in fp_dict before fusion:")
for key in model.state_dict().keys(): 
    print(key)
print("****************************************")

model.fuse_model() 

print("Keys in fp_dict after fusion:")
for key in model.state_dict().keys():
    print(key)
# Keys changed after fusion

Keys in fp_dict before fusion:
block1.block.0.weight
block1.block.0.bias
block1.block.1.weight
block1.block.1.bias
block1.block.1.running_mean
block1.block.1.running_var
block1.block.1.num_batches_tracked
block2.block.0.weight
block2.block.0.bias
block2.block.1.weight
block2.block.1.bias
block2.block.1.running_mean
block2.block.1.running_var
block2.block.1.num_batches_tracked
block3.block.0.weight
block3.block.0.bias
block3.block.1.weight
block3.block.1.bias
block3.block.1.running_mean
block3.block.1.running_var
block3.block.1.num_batches_tracked
block4.block.0.weight
block4.block.0.bias
block4.block.1.weight
block4.block.1.bias
block4.block.1.running_mean
block4.block.1.running_var
block4.block.1.num_batches_tracked
fc1.weight
fc1.bias
****************************************
Keys in fp_dict after fusion:
block1.block.0.0.weight
block1.block.0.0.bias
block2.block.0.0.weight
block2.block.0.0.bias
block3.block.0.0.weight
block3.block.0.0.bias
block4.block.0.0.weight
block4.block.0.0.bia
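
The key change above is consistent with eager-mode fusion in eval mode: each BatchNorm is folded into the preceding convolution, so its `weight`, `bias` and `running_*` entries disappear and only a rescaled conv weight and bias remain (the extra `.0` likely comes from the fused conv-ReLU container). As a hedged sketch in plain Python (single input channel, stride 1, no padding, illustrative numbers, hypothetical helper names), folding uses w' = w·γ/√(σ²+ε) and b' = (b − μ)·γ/√(σ²+ε) + β per output channel, and the check confirms that the fused conv reproduces conv followed by BatchNorm. PyTorch performs the equivalent step in `torch.nn.utils.fusion.fuse_conv_bn_eval`.

```python
import math


def conv1d(x, weight, bias):
    """Single-input-channel 1-D cross-correlation (stride 1, no padding); returns [out_channels][time]."""
    k = len(weight[0])
    return [[sum(w * x[t + j] for j, w in enumerate(w_c)) + b_c for t in range(len(x) - k + 1)]
            for w_c, b_c in zip(weight, bias)]


def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Eval-mode BatchNorm using running statistics, applied per output channel."""
    return [[(v - m) / math.sqrt(s + eps) * g + be for v in row]
            for row, g, be, m, s in zip(y, gamma, beta, mean, var)]


def fold_bn_into_conv(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return conv weight/bias with the eval-mode BatchNorm folded in."""
    fused_w, fused_b = [], []
    for w_c, b_c, g, be, m, s in zip(weight, bias, gamma, beta, mean, var):
        factor = g / math.sqrt(s + eps)
        fused_w.append([w * factor for w in w_c])
        fused_b.append((b_c - m) * factor + be)
    return fused_w, fused_b


# Illustrative values: 2 output channels, kernel size 3
x = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5]
weight, bias = [[0.2, -0.4, 0.1], [1.0, 0.5, -0.3]], [0.05, -0.1]
gamma, beta = [1.5, 0.8], [0.1, -0.2]
mean, var = [0.3, -0.1], [0.9, 2.0]

reference = batchnorm_eval(conv1d(x, weight, bias), gamma, beta, mean, var)
fused = conv1d(x, *fold_bn_into_conv(weight, bias, gamma, beta, mean, var))
max_diff = max(abs(a - b) for ra, rb in zip(reference, fused) for a, b in zip(ra, rb))
print(f"max |conv+BN - fused conv| = {max_diff:.2e}")
```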

In [4]:
# Check keys in qat_dict 
qat_dict = torch.load("../models/cnn_qat_model.pth")
print("Keys in qat_dict:")
for key in qat_dict.keys():
    print(key)
print("****************************************")

LWQAT_dict_sample = torch.load("../models/cnn_qat_LayerWiseQuant_q1_model.pth")
print("Keys in LWQAT_dict_sample:")
for key in LWQAT_dict_sample.keys():
    print(key)
print("****************************************")

Keys in qat_dict:
block1.block.0.weight
block1.block.0.bias
block1.block.0.scale
block1.block.0.zero_point
block2.block.0.weight
block2.block.0.bias
block2.block.0.scale
block2.block.0.zero_point
block3.block.0.weight
block3.block.0.bias
block3.block.0.scale
block3.block.0.zero_point
block4.block.0.weight
block4.block.0.bias
block4.block.0.scale
block4.block.0.zero_point
fc1.scale
fc1.zero_point
fc1._packed_params.dtype
fc1._packed_params._packed_params
quant.scale
quant.zero_point
****************************************
Keys in LWQAT_dict_sample:
block1.block.0.weight
block1.block.0.bias
block1.block.0.scale
block1.block.0.zero_point
block2.block.0.0.weight
block2.block.0.0.bias
block3.block.0.0.weight
block3.block.0.0.bias
block4.block.0.0.weight
block4.block.0.0.bias
fc1.weight
fc1.bias
quant.scale
quant.zero_point
****************************************


In [5]:
# Load FP model
config_fp = '../configs/cnn_fp32.yaml'
with open(config_fp, 'r') as f:
    config = yaml.safe_load(f)
    
params_fp = config["model"]["base_cnn"]
model_fp = M5Modular(
        n_input=params_fp["n_input"],
        n_output=params_fp["n_output"],
        stride=params_fp["stride"],
        n_channel=params_fp["n_channel"],
        conv_kernel_sizes=params_fp["conv_kernel_sizes"]
        )
model_fp.load_state_dict(fp_dict)
model_fp.to('cpu')

# evaluate FP model
acc_fp = test(model_fp, test_loader)
print(f"FP32 model accuracy: {acc_fp:.4f}")

FP32 model accuracy: 83.0713


In [6]:
# Load the fully quantized PTQ model
config_PTQ = '../configs/cnn_ptq.yaml'
with open(config_PTQ, 'r') as f:
    config = yaml.safe_load(f)
    
params_PTQ = config["model"]["base_cnn"]
model_PTQ = PTQM5Modular(
            n_input=params_PTQ["n_input"],
            n_output=params_PTQ["n_output"],
            stride=params_PTQ["stride"],
            n_channel=params_PTQ["n_channel"],
            conv_kernel_sizes=params_PTQ["conv_kernel_sizes"]
        )
# Fuse and prepare for quantization
model_PTQ.eval()
model_PTQ.fuse_model()
model_PTQ.qconfig = torch.ao.quantization.get_default_qconfig('x86')

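# PTQ uses prepare() in eval mode; no calibration is needed here because
# the scales and zero-points are loaded from the checkpoint below
torch.ao.quantization.prepare(model_PTQ, inplace=True)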

# Convert to quantized model
model_PTQ.eval()
model_PTQ = torch.ao.quantization.convert(model_PTQ, inplace=False)

# Load checkpoint
model_PTQ.load_state_dict(ptq_dict)
model_PTQ.to('cpu')

# evaluate PTQ model
acc_PTQ = test(model_PTQ, test_loader)
print(f"PTQ model accuracy: {acc_PTQ:.4f}")



PTQ model accuracy: 75.8473


In [7]:

config_LWQ = '../configs/cnn_ptq_LayerWiseQuant.yaml'
with open(config_LWQ, 'r') as f:
    config = yaml.safe_load(f)

# for i in range(1, 2):
for i in config["model"]["quantization"]:
    model_LWQ = PTQM5_LayerWiseQuant(
        quantized_block_idx = i,
        n_input=config["model"]["base_cnn"]["n_input"],
        n_output=config["model"]["base_cnn"]["n_output"],
        stride=config["model"]["base_cnn"]["stride"],
        n_channel=config["model"]["base_cnn"]["n_channel"],
        conv_kernel_sizes=config["model"]["base_cnn"]["conv_kernel_sizes"],
    )

    # Fuse and prepare for quantization
    model_LWQ.eval()
    # print(f"Layer-Wise Quantized Model before fuse: {model_LWQ}")
    model_LWQ.fuse_model()
    # print(f"Layer-Wise Quantized Model after fuse, before Layer {i} quantized: {model_LWQ}")

    qconfig = torch.ao.quantization.get_default_qconfig('x86')
    model_LWQ.set_qconfig_for_layerwise(qconfig)
    torch.ao.quantization.prepare(model_LWQ, inplace=True)

    # Convert to quantized model (still in eval mode from above)
    model_LWQ = torch.ao.quantization.convert(model_LWQ, inplace=False)
    # print(f"Layer-Wise Quantized Model Layer {i} quantized : {model_LWQ}")
    # Load the checkpoint in which block i is quantized
    model_LWQ.load_state_dict(LWQ_dict_dicts[i])

    # evaluate single layer quantized model
    acc_LWQ = test(model_LWQ, test_loader)
    print(f"Layer-Wise Quantized Model (Layer {i} quantized) accuracy: {acc_LWQ:.4f}")


Layer-Wise Quantized Model (Layer 1 quantized) accuracy: 76.8287
Layer-Wise Quantized Model (Layer 2 quantized) accuracy: 80.8905
Layer-Wise Quantized Model (Layer 3 quantized) accuracy: 81.8174
Layer-Wise Quantized Model (Layer 4 quantized) accuracy: 82.9714


|Model|Accuracy (%)|Change vs. FP32 (pp)|
|---|---|---|
|FP32|83.0713|0.00|
|PTQ (L4 Quantized)|82.9714|-0.10|
|PTQ (L3 Quantized)|81.8174|-1.25|
|PTQ (L2 Quantized)|80.8905|-2.18|
|PTQ (L1 Quantized)|76.8287|-6.24|
|PTQ (Fully Quantized)|75.8473|-7.22|


> Change vs. FP32 = quantized model accuracy − FP32 accuracy, in percentage points (pp); negative values indicate an accuracy drop.

## Insights

1. Early layers, especially L1, are highly sensitive to quantization: quantizing L1 alone costs 6.24 pp, whereas later layers are robust (L3: -1.25 pp, L4: -0.10 pp). A plausible explanation is that early layers process raw, low-level features that are more sensitive to quantization noise, and any error they introduce propagates through all later blocks. Later layers operate on higher-level representations and tolerate quantization better.

2. The fully quantized model loses 7.22 pp, about 1 pp more than L1-only quantization (6.24 pp). This indicates that quantization noise accumulates across blocks, although L1 alone accounts for most of the total degradation.

3. Layer-wise PTQ reveals per-layer sensitivity, which can guide mixed-precision or hybrid quantization strategies. For example, a mixed-precision model could keep the front layers in FP32 and quantize only the later layers, and QAT efforts should prioritize the front layers.
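
Insight 3 can be turned into a simple planning heuristic. The sketch below uses a hypothetical helper (not part of the project code) that greedily quantizes the least sensitive blocks first, using the single-block PTQ accuracies from the table, until an accuracy-drop budget is used up. It assumes the single-block drops add up, which the data shows is only approximate: the four single-block drops sum to about 9.78 pp, whereas the fully quantized model loses 7.22 pp. Any configuration it suggests should therefore be re-evaluated on the test set.

```python
FP32_ACC = 83.0713
# Single-block PTQ accuracies (%) from the table above
SINGLE_BLOCK_PTQ_ACC = {"L1": 76.8287, "L2": 80.8905, "L3": 81.8174, "L4": 82.9714}


def plan_mixed_precision(fp32_acc, single_block_acc, budget_pp):
    """Greedy plan: quantize blocks in order of increasing drop while the summed drop fits the budget.

    Assumes per-block drops are additive, which is only a rough heuristic.
    """
    drops = {name: fp32_acc - acc for name, acc in single_block_acc.items()}
    quantized, estimated_drop = [], 0.0
    for name in sorted(drops, key=drops.get):
        if estimated_drop + drops[name] > budget_pp:
            break  # drops are sorted, so every remaining block would also exceed the budget
        quantized.append(name)
        estimated_drop += drops[name]
    keep_fp32 = sorted(set(single_block_acc) - set(quantized))
    return quantized, keep_fp32, estimated_drop


for budget in (0.5, 1.5, 4.0):
    q, keep, est = plan_mixed_precision(FP32_ACC, SINGLE_BLOCK_PTQ_ACC, budget)
    print(f"budget {budget} pp: quantize {q}, keep FP32 {keep}, estimated drop {est:.2f} pp")
```

Even with a generous 4 pp budget, the greedy plan keeps L1 in FP32, in line with the sensitivity ranking above.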

In [None]:
# Accuracy drop (pp) of each single-block PTQ model relative to FP32
acc_fp32 = 83.0713
acc_ptq = {"L4": 82.9714, "L3": 81.8174, "L2": 80.8905, "L1": 76.8287}
for layer, acc in acc_ptq.items():
    # Format to 4 decimals to avoid floating-point noise in the printout
    print(f"PTQ model accuracy drop {layer}: {acc_fp32 - acc:.4f}")

In [None]:
# Construct a new combined checkpoint dictionary from layer-wise PTQ models
combined_dict = {}

# Take each block's parameters from the checkpoint in which that block was quantized
for i in range(1, 5):
    prefix = f"block{i}.block.0"
    for suffix in ["weight", "bias", "scale", "zero_point"]:
        key = f"{prefix}.{suffix}"
        combined_dict[key] = LWQ_dict_dicts[i][key]

# Add remaining non-block keys (e.g., fc1, quant/dequant) from fully quantized model
for key in ptq_dict:
    if not any(key.startswith(f"block{i}.block.0") for i in range(1, 5)):
        combined_dict[key] = ptq_dict[key]
        
# print("Keys in combined_dict:")
# for key in combined_dict.keys():
#     print(key)

In [None]:
model_PTQ_combined = PTQM5Modular(
            n_input=params_PTQ["n_input"],
            n_output=params_PTQ["n_output"],
            stride=params_PTQ["stride"],
            n_channel=params_PTQ["n_channel"],
            conv_kernel_sizes=params_PTQ["conv_kernel_sizes"]
        )
# Fuse and prepare for quantization
model_PTQ_combined.eval()
model_PTQ_combined.fuse_model()
model_PTQ_combined.qconfig = torch.ao.quantization.get_default_qconfig('x86')

# PTQ uses prepare() in eval mode; scales and zero-points come from combined_dict below
torch.ao.quantization.prepare(model_PTQ_combined, inplace=True)

# Convert to quantized model
model_PTQ_combined.eval()
model_PTQ_combined = torch.ao.quantization.convert(model_PTQ_combined, inplace=False)

# Load checkpoint
model_PTQ_combined.load_state_dict(combined_dict)
model_PTQ_combined.to('cpu')

# evaluate PTQ model
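# Use a separate name so the fully quantized PTQ accuracy (acc_PTQ) is not overwritten
acc_PTQ_combined = test(model_PTQ_combined, test_loader)
print(f"Combined PTQ model accuracy: {acc_PTQ_combined:.4f}")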

## 2. QAT

In [None]:
# Load the fully quantized QAT model
config_QAT = '../configs/cnn_qat.yaml'
with open(config_QAT, 'r') as f:
    config = yaml.safe_load(f)
    
params_QAT = config["model"]["base_cnn"]
model_QAT = QATM5Modular(
            n_input=params_QAT["n_input"],
            n_output=params_QAT["n_output"],
            stride=params_QAT["stride"],
            n_channel=params_QAT["n_channel"],
            conv_kernel_sizes=params_QAT["conv_kernel_sizes"]
        )
# Fuse and prepare for quantization
model_QAT.eval()
model_QAT.fuse_model()
model_QAT.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

model_QAT.train()
torch.ao.quantization.prepare_qat(model_QAT, inplace=True)

# Convert to quantized model
model_QAT.eval()
model_QAT = torch.ao.quantization.convert(model_QAT, inplace=False)

# Load checkpoint
qat_dict = torch.load("../models/cnn_qat_model.pth")
model_QAT.load_state_dict(qat_dict)
model_QAT.to('cpu')

# evaluate QAT model
acc_QAT = test(model_QAT, test_loader)
print(f"QAT model accuracy: {acc_QAT:.4f}")

In [None]:
config_LWQ = '../configs/cnn_qat_LayerWiseQuant.yaml'
with open(config_LWQ, 'r') as f:
    config = yaml.safe_load(f)

LWQ_QAT_dict_dicts = {
    1: torch.load("../models/cnn_qat_LayerWiseQuant_q1_model.pth"),
    2: torch.load("../models/cnn_qat_LayerWiseQuant_q2_model.pth"),
    3: torch.load("../models/cnn_qat_LayerWiseQuant_q3_model.pth"),
    4: torch.load("../models/cnn_qat_LayerWiseQuant_q4_model.pth"),
}

for i in config["model"]["quantization"]:
    model_LWQ = QATM5_LayerWiseQuant(
        quantized_block_idx = i,
        n_input=config["model"]["base_cnn"]["n_input"],
        n_output=config["model"]["base_cnn"]["n_output"],
        stride=config["model"]["base_cnn"]["stride"],
        n_channel=config["model"]["base_cnn"]["n_channel"],
        conv_kernel_sizes=config["model"]["base_cnn"]["conv_kernel_sizes"],
    )

    # Fuse and prepare for quantization
    model_LWQ.eval()
    # print(f"Layer-Wise Quantized Model before fuse: {model_LWQ}")
    model_LWQ.fuse_model()
    # print(f"Layer-Wise Quantized Model after fuse, before Layer {i} quantized: {model_LWQ}")

    qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
    model_LWQ.set_qconfig_for_layerwise(qconfig)
    torch.ao.quantization.prepare(model_LWQ, inplace=True)

    # Convert to quantized model (still in eval mode from above)
    model_LWQ = torch.ao.quantization.convert(model_LWQ, inplace=False)
    # print(f"Layer-Wise Quantized Model Layer {i} quantized : {model_LWQ}")
    # Load the checkpoint in which block i is quantized
    model_LWQ.load_state_dict(LWQ_QAT_dict_dicts[i])

    # evaluate single layer quantized model
    acc_LWQ = test(model_LWQ, test_loader)
    print(f"Layer-Wise Quantized Model (Layer {i} quantized) accuracy: {acc_LWQ:.4f}")

|Model|Accuracy (%)|Change vs. FP32 (pp)|
|---|---|---|
|FP32|83.0713|0.00|
|**QAT (L4 Quantized)**|83.2894|+0.22|
|PTQ (L4 Quantized)|82.9714|-0.10|
|**QAT (L3 Quantized)**|82.7442|-0.33|
|PTQ (L3 Quantized)|81.8174|-1.25|
|**QAT (L2 Quantized)**|81.2631|-1.81|
|PTQ (L2 Quantized)|80.8905|-2.18|
|**QAT (L1 Quantized)**|80.4180|-2.65|
|PTQ (L1 Quantized)|76.8287|-6.24|
|**QAT (Fully Quantized)**|79.4639|-3.61|
|PTQ (Fully Quantized)|75.8473|-7.22|

> Change vs. FP32 = quantized model accuracy − FP32 accuracy, in percentage points (pp); negative values indicate an accuracy drop.

In [None]:
# Accuracy drop (pp) of each QAT model relative to FP32 (negative = above FP32)
acc_fp32 = 83.0713
acc_qat = {"L4": 83.2894, "L3": 82.7442, "L2": 81.2631, "L1": 80.4180, "fully quantized": 79.4639}
for name, acc in acc_qat.items():
    # Format to 4 decimals to avoid floating-point noise in the printout
    print(f"QAT model accuracy drop ({name}): {acc_fp32 - acc:.4f}")

## Insights

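1. QAT consistently outperforms PTQ: every QAT configuration is more accurate than its PTQ counterpart, with the largest gains in the most sensitive configurations (L1: 80.42% vs. 76.83%; fully quantized: 79.46% vs. 75.85%).

2. Early layers (e.g., L1) are the most sensitive to quantization under both methods. QAT reduces the L1 drop from 6.24 pp to 2.65 pp, but L1 remains the most sensitive block.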


| Layer | PTQ change (pp) | QAT change (pp) |
| ----- | --------------- | --------------- |
| L1    | -6.24           | -2.65           |
| L4    | -0.10           | **+0.22**       |


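3. Later layers are robust, and QAT on L4 even slightly exceeds FP32.

- With only L4 quantized, QAT reaches 83.29%, 0.22 pp above the FP32 baseline (83.07%).
- This gain is small and may lie within evaluation noise. One possible factor is that the FC layer (`fc1`) stays in full precision in the layer-wise model (its checkpoint stores plain `fc1.weight`/`fc1.bias`). Because the layer-wise PTQ L4 model is built the same way (everything except the target block in FP32) yet stays slightly below the baseline, the additional training performed during QAT is another plausible contributor.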
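
To put insights 1 and 2 in numbers, the sketch below (a hypothetical helper, using the accuracies from the QAT/PTQ table) computes for each configuration the QAT − PTQ gap and the fraction of the PTQ accuracy loss that QAT recovers. A value above 100% means QAT exceeds FP32, as for L4.

```python
FP32_ACC = 83.0713
# (PTQ accuracy, QAT accuracy) in % per configuration, from the QAT/PTQ table
RESULTS = {
    "L1": (76.8287, 80.4180),
    "L2": (80.8905, 81.2631),
    "L3": (81.8174, 82.7442),
    "L4": (82.9714, 83.2894),
    "Full": (75.8473, 79.4639),
}


def qat_recovery(fp32_acc, ptq_acc, qat_acc):
    """Fraction of the PTQ accuracy loss won back by QAT (> 1 means QAT beats FP32)."""
    return (qat_acc - ptq_acc) / (fp32_acc - ptq_acc)


for name, (ptq, qat) in RESULTS.items():
    print(f"{name:>4}: QAT - PTQ = {qat - ptq:+.2f} pp, "
          f"recovers {qat_recovery(FP32_ACC, ptq, qat):.0%} of the PTQ loss")
```

With these numbers, QAT recovers roughly 57% of the L1 loss and about half of the fully quantized loss, but only about 17% for L2, so the benefit of QAT is not uniform across blocks.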
