In [1]:
!pip install -e . -qq

In [1]:
import yaml
from src.utils import *
from src.data_loader import get_data_loaders
from src.models.cnn_model import M5, QATM5, PTQM5
from src.models.cnn_model_LayerWiseQuant import M5Modular, PTQM5Modular, PTQM5_LayerWiseQuant, QATM5Modular, QATM5_LayerWiseQuant
from src.train import set_seed, train_model
from src.evaluate import evaluate_model, test

# Parse the YAML experiment configuration for the layer-wise PTQ run.
with open("configs/cnn_ptq_LayerWiseQuant.yaml") as config_file:
    config = yaml.safe_load(config_file)

# Build the data loaders from the "data" section of the configuration.
data_loaders = get_data_loaders(config["data"])
train_loader, test_loader, validate_loader = data_loaders

# Hyper-parameters of the base CNN; later cells read them to build the models.
model_config = config["model"]["base_cnn"]

In [7]:
def count_unique_labels(dataset):
    """Return the number of distinct labels found in ``dataset``.

    Every item yielded by ``dataset`` must unpack into at least three
    fields; the third field is used as the label and all other fields
    are ignored. The whole dataset is iterated once.
    """
    return len({label for _, _, label, *_ in dataset})


# Count the classes present in each split (one full pass per split).
num_classes_train = count_unique_labels(train_loader.dataset)
num_classes_test = count_unique_labels(test_loader.dataset)
num_classes_validate = count_unique_labels(validate_loader.dataset)
print(num_classes_train, num_classes_test, num_classes_validate)

35 35 35


In [8]:
# Report the number of samples in each split, one line per loader.
for caption, loader in (
    ("length of train_loader.dataset:", train_loader),
    ("length of test_loader.dataset:", test_loader),
    ("length of validate_loader.dataset:", validate_loader),
):
    print(caption, len(loader.dataset))

length of train_loader.dataset: 84843
length of test_loader.dataset: 11005
length of validate_loader.dataset: 9981


In [9]:
# Sizes of the train / test / validate splits, in that order.
split_sizes = [
    len(split.dataset) for split in (train_loader, test_loader, validate_loader)
]
n_all = sum(split_sizes)

print("Total number of samples across all datasets:", n_all)
# Each split's share of the combined sample count.
print("ratio of train to test to validate datasets:",
      *(size / n_all for size in split_sizes))

Total number of samples across all datasets: 105829
ratio of train to test to validate datasets: 0.8016989672018067 0.10398850976575419 0.09431252303243913


In [12]:
# Instantiate the QAT variant from the shared base-CNN hyper-parameters.
qat_kwargs = {
    key: model_config[key]
    for key in ("n_input", "n_output", "stride", "n_channel", "conv_kernel_sizes")
}
model_qat = QATM5Modular(**qat_kwargs)

# Switch to eval mode and fuse before attaching the default 'x86' QAT qconfig.
model_qat.eval()
model_qat.fuse_model()
model_qat.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# The activation / weight entries are printed as stored on the qconfig,
# without being called.
qat_qconfig = model_qat.qconfig
print("Model QAT config:", qat_qconfig)
print("Activations:", qat_qconfig.activation)
print("Weights:", qat_qconfig.weight)


Model QAT config: QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
Activations: functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=True){}
Weights: functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, quant_min=-128, quant_max=127

In [4]:
# set_seed(config["model"]["base_cnn"]["seed"])

# Build the PTQ variant of the network on CPU and put it in eval mode before
# loading the pretrained weights and fusing.
model_fp32 = PTQM5Modular(n_input=model_config["n_input"],
            n_output=model_config["n_output"],
            stride=model_config["stride"],
            n_channel=model_config["n_channel"],
            conv_kernel_sizes=model_config["conv_kernel_sizes"]).to('cpu')
model_fp32.eval()

# map_location="cpu" remaps the checkpoint tensors onto the CPU, so a checkpoint
# saved from a GPU run still loads on a machine without CUDA (the model itself
# lives on the CPU).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source; consider weights_only=True if this is a plain state dict.
model_fp32.load_state_dict(
    torch.load(model_config["pretrained_path"], map_location="cpu")
)
model_fp32.fuse_model()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')

# Show the observer factories stored on the default qconfig ...
print(model_fp32.qconfig.activation)
print(model_fp32.qconfig.weight)

# ... then instantiate one observer of each kind to read its quantization scheme.
act_obs = model_fp32.qconfig.activation()
wt_obs = model_fp32.qconfig.weight()

print(act_obs.qscheme)
print(wt_obs.qscheme)

functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}
functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){}
torch.per_tensor_affine
torch.per_channel_symmetric




| No. | Activation Observer                     | Weight Observer                             | Details                                      |
|-----|-----------------------------------------|---------------------------------------------|----------------------------------------------|
| 01  | HistogramObserver                       | PerChannelMinMaxObserver                    | Default (`x86` qconfig)                      |
| 02  | HistogramObserver                       | MovingAveragePerChannelMinMaxObserver       |                                              |
| 03  | HistogramObserver                       | MinMaxObserver                              |                                              |
| 04  | HistogramObserver                       | MovingAverageMinMaxObserver                 |                                              |
| 05  | HistogramObserver                       | HistogramObserver                           | Fully histogram-based                        |
| 06  | MinMaxObserver                          | PerChannelMinMaxObserver                    | Fastest; not robust to outliers              |
| 07  | MinMaxObserver                          | MovingAveragePerChannelMinMaxObserver       |                                              |
| 08  | MinMaxObserver                          | MinMaxObserver                              |                                              |
| 09  | MinMaxObserver                          | MovingAverageMinMaxObserver                 |                                              |
| 10  | MinMaxObserver                          | HistogramObserver                           |                                              |
| 11  | MovingAverageMinMaxObserver             | PerChannelMinMaxObserver                    | Balanced and smoother                        |
| 12  | MovingAverageMinMaxObserver             | MovingAveragePerChannelMinMaxObserver       | Recommended for most conv nets               |
| 13  | MovingAverageMinMaxObserver             | MinMaxObserver                              |                                              |
| 14  | MovingAverageMinMaxObserver             | MovingAverageMinMaxObserver                 |                                              |
| 15  | MovingAverageMinMaxObserver             | HistogramObserver                           |                                              |
| 16  | HistogramObserver (Entropy/MSE)         | PerChannelMinMaxObserver                    | Use Entropy/MSE clipping for activation      |
| 17  | HistogramObserver (Entropy/MSE)         | MovingAveragePerChannelMinMaxObserver       |                                              |
| 18  | HistogramObserver (Percentile Clipping) | PerChannelMinMaxObserver                    | Specify percentile like 99.9%                |
| 19  | HistogramObserver (Percentile Clipping) | MovingAveragePerChannelMinMaxObserver       |                                              |


In [None]:
import copy

from torch.ao.quantization.observer import (
    HistogramObserver,
    PerChannelMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver
)
from torch.ao.quantization import QConfig

# Observer combination No. 02 from the table above: HistogramObserver for
# activations, MovingAveragePerChannelMinMaxObserver for weights.
qconfig_movingAvgPerChannelMinMax = QConfig(
    activation=HistogramObserver.with_args(reduce_range=True),
    weight=MovingAveragePerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
)

# nn.Module has no .deepcopy() method; copy.deepcopy() produces an independent
# copy, so assigning the custom qconfig below leaves model_fp32.qconfig untouched.
model_movingAvgPerChannelMinMax = copy.deepcopy(model_fp32)
model_movingAvgPerChannelMinMax.qconfig = qconfig_movingAvgPerChannelMinMax

In [None]:
# Prepare model_fp32 for post-training quantization. This modifies model_fp32
# in place, so the cell is not meant to be re-run on the same model object.
torch.ao.quantization.prepare(model_fp32, inplace=True)

# Calibration: feed every validation batch through the prepared model.
# Only the first element of each batch is used; the second is discarded.
with torch.inference_mode():
    for batch, _ in validate_loader:
        model_fp32(batch.to("cpu"))

# Convert the calibrated model into a new quantized model (model_fp32 is kept).
model = torch.ao.quantization.convert(model_fp32, inplace=False)

# Evaluate the converted model on the train split, then on the test split.
for loader in (train_loader, test_loader):
    print(test(model, loader))