# Quantization-Aware Training (QAT)

This tutorial shows how to fine-tune a YOLO-NAS model with fake-quantization nodes inserted (quantization-aware training), which typically preserves more accuracy after the model is converted to INT8.

## Prerequisites

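You need a trained checkpoint (the code below falls back to the pretrained `yolo_nas_s` weights if you do not load one) and a training dataset described by a dataset config file (`data.yaml`).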

In [None]:
import lightning as L
from pathlib import Path
from modern_yolonas import yolo_nas_s
from modern_yolonas.data import YOLODetectionDataset, load_dataset_config
from modern_yolonas.data.transforms import Compose, HSVAugment, HorizontalFlip, LetterboxResize, Normalize
from modern_yolonas.training import (
    YoloNASLightningModule, QATCallback, DetectionDataModule, extract_model_state_dict,
)
from modern_yolonas.quantization import prepare_model_qat, convert_quantized, export_quantized_onnx

# Load the starting weights. By default this uses the pretrained yolo_nas_s
# weights; to start from your own fine-tuned run instead, uncomment the two
# lines below. Load your checkpoint BEFORE preparing the model for QAT, so that
# QAT fine-tunes from the weights you actually trained.
model = yolo_nas_s(pretrained=True)
# sd = extract_model_state_dict("runs/hardhat/last.ckpt")
# model.load_state_dict(sd)

# Prepare for QAT: inserts fake-quantization nodes into the model.
qat_model = prepare_model_qat(model)
print("QAT model prepared with fake-quantization nodes")

# Setup data: point load_dataset_config at your dataset's data.yaml.
# Training uses augmentation; validation only resizes and normalizes.
# cfg = load_dataset_config("path/to/data.yaml")
# train_transforms = Compose([HSVAugment(), HorizontalFlip(p=0.5), LetterboxResize(target_size=640), Normalize()])
# val_transforms = Compose([LetterboxResize(target_size=640), Normalize()])
# train_ds = YOLODetectionDataset(root=cfg.root, split=cfg.train_split, transforms=train_transforms)
# val_ds = YOLODetectionDataset(root=cfg.root, split=cfg.val_split, transforms=val_transforms)

# QAT fine-tuning — must use precision="32-true" (no AMP).
# A small learning rate is used because this is fine-tuning an already-trained model.
# qat_lit_model = YoloNASLightningModule(model=qat_model, num_classes=cfg.num_classes, lr=2e-5, warmup_steps=50)
# qat_data_module = DetectionDataModule(train_dataset=train_ds, val_dataset=val_ds, batch_size=8, num_workers=2)

# Lightning's current_epoch runs from 0 to max_epochs - 1, so a freeze threshold
# equal to max_epochs can never take effect. Keep both thresholds below
# max_epochs so the last epoch(s) train with frozen BN statistics and frozen
# observers (fixed quantization ranges).
# NOTE(review): confirm whether QATCallback compares current_epoch with >= or >.
# qat_trainer = L.Trainer(
#     max_epochs=5, accelerator="auto",
#     precision="32-true",  # No AMP — fake-quant incompatible with autocast
#     callbacks=[QATCallback(freeze_bn_after_epoch=3, freeze_observer_after_epoch=4)],
#     default_root_dir="runs/qat",
# )
# qat_trainer.fit(qat_lit_model, datamodule=qat_data_module)
# print("QAT training complete!")

In [None]:
# Convert the fine-tuned QAT model to a real INT8 model and export it to ONNX.
# Switch to eval mode first: the model may still be in training mode after
# trainer.fit(), and PyTorch's QAT workflow converts an eval-mode model.
# (convert_quantized may already do this internally; calling eval() is harmless.)
# qat_model.eval()
# qat_quantized = convert_quantized(qat_model)
# qat_onnx_path = "model_qat_int8.onnx"
# export_quantized_onnx(qat_quantized, qat_onnx_path, input_size=640)
#
# Report the exported file size in MiB.
# qat_size = Path(qat_onnx_path).stat().st_size / (1024 * 1024)
# print(f"QAT ONNX exported to: {qat_onnx_path}")
# print(f"QAT model size: {qat_size:.1f} MB")

print("Uncomment the cells above after setting up your data and checkpoint paths.")