# YOLO-X Tiny Quant example

## Prepare the ENV & Args

In [None]:
import logging
import sys

# Send INFO-and-above records to stdout (instead of the default stderr) with a short
# "[LEVEL] message" prefix, so they appear in the notebook's normal output.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="[%(levelname)s] %(message)s",
)

In [None]:
import argparse
import itertools

import torch
from trainer import Trainer
from yolo_x_tiny_exp import Exp

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", default="./yolox_tiny.pth", type=str, help="pre-trained checkpoint path")
parser.add_argument("--batch-size", type=int, default=64, help="batch size")
parser.add_argument("--random_size_range", type=int, default=3, help="random size range")
parser.add_argument("--experiment_name", type=str, default="0", help="exp name")
parser.add_argument("--data_dir", default="./coco_data", help="Data set directory.")

# Previously the help text for these two options was a copy-paste of "batch size".
parser.add_argument("--min_lr_ratio", type=float, default=0.01, help="minimum learning rate ratio")
parser.add_argument("--ema_decay", type=float, default=0.9995, help="ema decay rate.")

parser.add_argument("--output_dir", default="./YOLOX_outputs", help="Experiments results save path.")
parser.add_argument("--workers", default=4, type=int, help="Number of data loading workers to be used.")
parser.add_argument("--multiscale_range", default=5, type=int, help="multiscale_range.")
parser.add_argument("--start_epoch", type=int, default=280, help="epoch to start training from")
# Parse an empty argv on purpose: inside a notebook, sys.argv holds the kernel's own
# arguments, so only the defaults above are used. Edit the defaults (or pass an explicit
# list here) to change the configuration.
args = parser.parse_args([])

### Init the experiments & trainer

In [None]:
# Build the experiment configuration from the parsed args, then the Trainer that the
# following cells use for evaluation, PTQ/QAT, and export.
exp = Exp(args)
trainer = Trainer(exp, args)

## Prepare FP32 model & test accuracy

In [None]:
# Record the effective CLI arguments and the full experiment configuration.
# Lazy %-style arguments: formatting happens only if the record is emitted.
logging.info("args: %s", trainer.args)
logging.info("exp value:\n%s", trainer.exp)

In [None]:
# Build the FP32 model, move it onto the trainer's device, and load the pretrained weights.
# (torch.nn.Module.to moves the parameters in place and returns the same module.)
model = trainer.load_pretrain_weight(trainer.exp.get_model().to(trainer.device))
trainer.model = model

In [None]:
# Evaluate with half the training batch size. Floor division matches the previous
# int(bs / 2) for positive sizes, and the max(1, ...) clamp keeps --batch-size 1 from
# producing an invalid evaluation batch size of 0.
trainer.evaluator = trainer.exp.get_evaluator(batch_size=max(1, trainer.args.batch_size // 2))

### Evaluate the FP32 model on the COCO val dataset 

In [None]:
# Evaluate the FP32 model; only the last value returned by evaluate() (the summary) is kept.
*_, summary = trainer.evaluator.evaluate(trainer.model)

In [None]:
# Show the FP32 evaluation summary.
print(summary)

## Perform PTQ & evaluate the accuracy

### Prepare Quantization config & Quantizer


In [None]:
from quark.torch import ModelQuantizer
from quark.torch.quantization.config.config import Config, QuantizationConfig, QuantizationSpec
from quark.torch.quantization.config.type import Dtype, QSchemeType, QuantizationMode, RoundType, ScaleType
from quark.torch.quantization.observer.observer import PerTensorPowOf2MinMSEObserver

In [None]:
# A single static INT8 spec: symmetric, per-tensor, float scale, round-half-to-even, with
# scales picked by PerTensorPowOf2MinMSEObserver. Despite the "WEIGHT" in its name, the same
# spec object is reused for every tensor role below.
INT8_PER_WEIGHT_TENSOR_SPEC = QuantizationSpec(
    dtype=Dtype.int8,
    qscheme=QSchemeType.per_tensor,
    observer_cls=PerTensorPowOf2MinMSEObserver,
    symmetric=True,
    scale_type=ScaleType.float,
    round_method=RoundType.half_even,
    is_dynamic=False,
)

# Weights, inputs, outputs and biases all share the spec above.
quant_config = QuantizationConfig(
    **dict.fromkeys(
        ("weight", "input_tensors", "output_tensors", "bias"),
        INT8_PER_WEIGHT_TENSOR_SPEC,
    )
)

# Apply it globally, in FX-graph quantization mode, and attach the quantizer to the trainer.
quant_config = Config(
    quant_mode=QuantizationMode.fx_graph_mode,
    global_quant_config=quant_config,
)
trainer.quantizer = ModelQuantizer(quant_config)

### Prepare calibration Dataset & Fx graph model

In [None]:
# Number of evaluation batches used for PTQ calibration.
num_calib_batches = 1
# Calibration inputs: element 0 of each of the first `num_calib_batches` evaluation batches
# (presumably the image tensor; confirm against the evaluator's dataloader), moved to the
# trainer's device. `islice` already yields lazily, so no intermediate list is built.
calib_data = [
    batch[0].to(trainer.device)
    for batch in itertools.islice(trainer.evaluator.dataloader, num_calib_batches)
]
# Batch-size-1 example input, used only to trace the backbone with torch.export.
dummy_input = torch.randn(1, 3, *trainer.exp.input_size).to(trainer.device)
# Export and calibration run with the model in eval mode.
trainer.model = trainer.model.eval()

NOTE: In the original YOLO-X Tiny repo, loss calculation and bounding-box decoding are integrated into the model's `forward`. We modified the code so that `trainer.model.base_model` contains only the backbone network, which is the only part of the model we need to quantize.

In [None]:
# Export only the backbone (trainer.model.base_model), traced with the batch-size-1
# `dummy_input`; `.module()` turns the exported program back into a callable module.
graph_model = torch.export.export_for_training(trainer.model.base_model, (dummy_input,)).module()
# Re-wrap as a plain torch.fx.GraphModule built from the same graph.
# NOTE(review): presumably so the FX-graph-mode quantizer receives a plain GraphModule. This
# also appears to drop the exported module's input-shape checks, so batch sizes other than 1
# can be fed later. Confirm the graph contains no batch-size-specialized ops.
graph_model = torch.fx.GraphModule(graph_model, graph_model.graph)
# Install the graph version of the backbone into the detector.
trainer.model.base_model = graph_model

### Perform PTQ & evaluate the quantized model

In [None]:
# Post-training quantization of the FX backbone, calibrated on `calib_data`.
quantized_model = trainer.quantizer.quantize_model(graph_model, calib_data)

In [None]:
# Swap the quantized backbone into the full detector so evaluation runs through it.
trainer.model.base_model = quantized_model

In [None]:
# Evaluate the PTQ model; only the last value returned by evaluate() (the summary) is kept.
*_, summary = trainer.evaluator.evaluate(trainer.model)

In [None]:
# Show the PTQ evaluation summary.
print(summary)

## Perform QAT based on PTQ results

1. Starting from the PTQ result, we perform QAT: further training adjusts the weights/biases, which can yield higher accuracy.
2. We adopt the training code from the original YOLO-X Tiny repo and train the model starting from epoch 280. Because development time was limited and our work focuses mainly on the Quark FX QAT tool, we only tried a few hyperparameter settings. Unlike the original setup, we train on a single GPU, which greatly reduces training complexity. Users can try other hyperparameters to get higher accuracy.

### Prepare the Dataloader & Optimizer etc.

In [None]:
from data import DataPrefetcher
from trainer import ModelEMA

In [None]:
# `no_aug` is True when training starts inside the last `no_aug_epochs` epochs; it is passed
# to the data loader (presumably to disable augmentation; confirm in Exp.get_data_loader).
trainer.no_aug = trainer.start_epoch >= trainer.max_epoch - trainer.exp.no_aug_epochs
trainer.train_loader = trainer.exp.get_data_loader(
    batch_size=trainer.args.batch_size, no_aug=trainer.no_aug, cache_img=None
)
logging.info("init prefetcher, this might take one minute or less...")
# Wrap the training loader in a prefetcher for the training loop.
trainer.prefetcher = DataPrefetcher(trainer.train_loader)

# Iterations per epoch = number of batches in the training loader.
trainer.max_iter = len(trainer.train_loader)
# Base LR scales linearly with the batch size (per-image LR * batch size).
trainer.lr_scheduler = trainer.exp.get_lr_scheduler(
    trainer.exp.basic_lr_per_img * trainer.args.batch_size, trainer.max_iter
)
trainer.optimizer = trainer.exp.get_optimizer(trainer.args.batch_size)
# Optionally keep an EMA copy of the model weights (ModelEMA).
if trainer.use_model_ema:
    trainer.ema_model = ModelEMA(trainer.model, trainer.args.ema_decay)  # --ema_decay, default 0.9995
    # Treat `start_epoch` epochs' worth of updates as already done, since training resumes
    # from that epoch rather than from scratch.
    # NOTE(review): presumably keeps ModelEMA's update-count-dependent decay on the original schedule; confirm.
    trainer.ema_model.updates = trainer.max_iter * trainer.start_epoch

### Perform training to further improve accuracy
NOTE: We only train for one epoch, for demonstration purposes.

In [None]:
logging.info("Training start...")
# Start from trainer.start_epoch, the same value that `no_aug` and the EMA `updates` offset
# were computed from in the setup cell, instead of a hard-coded 280. This keeps the epoch
# counter consistent with those values if the start epoch is changed.
trainer.epoch = trainer.start_epoch
logging.info(f"---> start train epoch{trainer.epoch + 1}")

**NOTE**: in the function `train_in_iter`:
  1. We disable the observers, so the quantization scales do not change during training;
  2. Empirically, we found that disabling `bn` updates during training gives higher accuracy.

In [None]:
# Run the QAT training loop (observer and BN behavior as described in the note above this cell).
trainer.train_in_iter()

### Evaluate the model

For simplicity, we directly load the fine-tuned weights to test accuracy.

In [None]:
# Load the fine-tuned QAT weights.
# map_location puts the tensors on the trainer's device, so a checkpoint saved on a different
# device (e.g. CUDA) still loads on this host.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
trainer.model.load_state_dict(
    torch.load(
        "./YOLOX_outputs/yolo_x_tiny_exp_3031/best_ckpt.pth",
        map_location=trainer.device,
        weights_only=False,
    )["model"]
)

In [None]:
# Evaluate the QAT model; only the last value returned by evaluate() (the summary) is kept.
*_, summary = trainer.evaluator.evaluate(trainer.model)

In [None]:
# Show the QAT evaluation summary.
print(summary)

### Freeze model & export to onnx


#### Freeze model
For better deployment on AMD NPU devices, we apply several hardware-oriented optimizations (e.g. adjusting scales and inserting multiply nodes to make hardware-specific adjustments).

In [None]:
# Freeze the (eval-mode) quantized backbone for deployment and put it back into the detector.
frozen_model = trainer.quantizer.freeze(trainer.model.base_model.eval())
trainer.model.base_model = frozen_model

#### Export to ONNX

In [None]:
from quark.torch import export_onnx

In [None]:
# NOTE: for NPU compilation, exporting with batch size 1 gives better compatibility.
# Use the experiment's configured input size (the same one used for the FX export's
# dummy input) instead of a hard-coded 416x416, so the ONNX graph matches the traced model.
example_inputs = (torch.rand(1, 3, *trainer.exp.input_size).to(trainer.device),)
export_onnx(model=trainer.model, output_dir="./export_onnx/", input_args=example_inputs[0])

#### Simplify the ONNX model and visualize

In [None]:
import onnx
from onnxsim import simplify

quant_model = onnx.load("./export_onnx/quark_model.onnx")
# `simplify` returns the simplified model plus a flag saying whether it passed onnxsim's
# validation against the original. Refuse to save a model that failed that check
# (raise rather than assert, so the check survives `python -O`).
model_simp, check = simplify(quant_model)
if not check:
    raise RuntimeError("Simplified ONNX model could not be validated")
onnx.save_model(model_simp, "./export_onnx/sample_quark_model.onnx")

Use `netron` to visualize the model (optional):
```shell
$ netron ./export_onnx/sample_quark_model.onnx
```