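# Static INT8 Quantization for YOLO

Quantizes an FP32 YOLOv8 ONNX model to INT8 with ONNX Runtime static quantization (QDQ format), calibrating on the validation images, and then evaluates the quantized model with Ultralytics.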
---

In [None]:
# Install libraries
%pip install opencv-python ultralytics onnxruntime numpy 

In [2]:
# Libraries

# Built-in
from pathlib import Path

# Third-party
import cv2
import numpy as np
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, QuantFormat
from ultralytics import YOLO

In [3]:
# Paths used throughout the notebook (relative to the notebook's working dir):
#   model_fp32              - FP32 ONNX model to quantize
#   model_quant             - where the INT8 (QDQ) model is written
#   calibration_images_path - folder of images used for calibration
model_fp32 = Path("models") / "best_fp32_preprocessed.onnx"
model_quant = Path("models") / "static_int8_quant.onnx"
calibration_images_path = Path("tcc-1") / "valid" / "images"

In [32]:
# Calibration data reader for static quantization of the YOLOv8 ONNX model.
class ImageCalibrationDataReader(CalibrationDataReader):
    """Feed preprocessed images to ``quantize_static`` one at a time.

    Each call to :meth:`get_next` returns ``{input_name: array}``, where the
    array is a float32 batch of one image with shape
    ``(1, 3, input_size, input_size)``, scaled to [0, 1], in RGB channel order.
    Once every image has been consumed it returns ``None``, which ends
    calibration.

    Args:
        image_paths: Iterable of image file paths (``str`` or ``Path``).
        input_name: Name of the model's input tensor. Defaults to ``"images"``.
        input_size: Side length in pixels that images are resized to.
            Defaults to 640.
    """

    def __init__(self, image_paths, input_name="images", input_size=640):
        self.image_paths = list(image_paths)
        self.idx = 0
        self.input_name = input_name
        self.input_size = input_size

    def preprocess(self, frame):
        """Load the image at path ``frame`` and return it as a (1, 3, H, W) float32 array.

        Raises:
            FileNotFoundError: if OpenCV cannot read/decode the file.
        """
        # cv2.imread wants a str path (older OpenCV builds reject Path objects)
        # and returns None instead of raising when the file can't be read.
        image = cv2.imread(str(frame))
        if image is None:
            raise FileNotFoundError(f"Could not read calibration image: {frame}")
        # Plain resize, no letterbox padding.
        # NOTE(review): Ultralytics letterboxes images at inference time, so
        # calibration statistics may differ slightly from real inputs — confirm.
        image = cv2.resize(image, (self.input_size, self.input_size))
        # OpenCV decodes to BGR, while Ultralytics feeds its models RGB; calibrate
        # on the same channel order the model sees at inference.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_data = image.astype(np.float32) / 255.0
        image_data = np.transpose(image_data, (2, 0, 1))  # HWC -> CHW
        image_data = np.expand_dims(image_data, axis=0)  # add batch dim -> NCHW
        return np.ascontiguousarray(image_data)

    def get_next(self):
        """Return the next ``{input_name: array}`` feed, or ``None`` when exhausted."""
        if self.idx >= len(self.image_paths):
            return None
        input_data = self.preprocess(self.image_paths[self.idx])
        self.idx += 1
        return {self.input_name: input_data}

    def rewind(self):
        """Restart from the first image so the reader can be consumed again."""
        self.idx = 0


# Image suffixes accepted for calibration (compared case-insensitively).
CALIBRATION_IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}

# Sorted so calibration (and thus the quantized model) is reproducible across runs.
# iterdir() raises if the folder is missing, unlike glob() which silently yields nothing.
calibration_image_paths = sorted(
    path for path in calibration_images_path.iterdir()
    if path.is_file() and path.suffix.lower() in CALIBRATION_IMAGE_SUFFIXES
)
if not calibration_image_paths:
    raise FileNotFoundError(f"No calibration images found in {calibration_images_path}")

calibration_data_reader = ImageCalibrationDataReader(calibration_image_paths)

In [33]:
# Nodes left in floating point (not quantized). All live under '/model.22/',
# presumably the YOLOv8 Detect head, whose box-decoding ops are typically
# sensitive to INT8 rounding.
# NOTE(review): these names are specific to this exported graph — confirm they
# still exist in `model_fp32` if the architecture or export settings change.
# Every entry must be a separate string: a missing comma would silently
# concatenate two names into one that matches no node.
NODES_TO_EXCLUDE = [
    '/model.22/Concat_3', '/model.22/Split', '/model.22/Sigmoid',
    '/model.22/dfl/Reshape', '/model.22/dfl/Transpose', '/model.22/dfl/Softmax',
    '/model.22/dfl/conv/Conv', '/model.22/dfl/Reshape_1', '/model.22/Slice_1',
    '/model.22/Slice', '/model.22/Add_1', '/model.22/Sub', '/model.22/Div_1',
    '/model.22/Concat_4', '/model.22/Mul_2', '/model.22/Concat_5',
]

# The reader returns None once it has served every image, so re-running this
# cell without resetting it would calibrate on no data at all.
calibration_data_reader.idx = 0

quantize_static(
    model_fp32,
    model_quant,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QUInt8,
    calibration_data_reader=calibration_data_reader,
    quant_format=QuantFormat.QDQ,
    nodes_to_exclude=NODES_TO_EXCLUDE,
    per_channel=False,
    reduce_range=True,
)

In [None]:
# Evaluate the quantized INT8 model on the dataset's validation split.
model = YOLO(model=model_quant, task="detect")
metrics = model.val(
    data="tcc-1/data.yaml",
    split="val",
)
print(metrics)