
YOLOv9 with Quantization-Aware Training (QAT) for TensorRT #327

Open
levipereira opened this issue Apr 6, 2024 · 41 comments

@levipereira

levipereira commented Apr 6, 2024

YOLOv9 with Quantization-Aware Training (QAT) for TensorRT

https://github.com/levipereira/yolov9-qat/
This repository hosts an implementation of YOLOv9 integrated with Quantization-Aware Training (QAT), optimized for deployment on TensorRT-supported platforms to achieve hardware-accelerated inference. It aims to deliver an efficient, low-latency version of YOLOv9 for real-time object detection applications. If you're not planning to deploy your model using TensorRT, it's advisable not to proceed with this implementation.


Implementation Details:

  • The repository provides a patch that adds QAT functionality to the original YOLOv9 codebase.
  • The patch is designed to be applied to the main YOLOv9 repository, enabling training with QAT.
  • This implementation is finely tuned for TensorRT, a hardware-accelerated inference library, enhancing performance.
  • Users interested in object detection using YOLOv9 with QAT on TensorRT platforms can leverage this repository, offering a ready-to-use solution.

Performance Report

@WongKinYiu I've successfully created a comprehensive implementation of quantization in a separate repository. It works as a patch for the original YOLOv9 codebase. The implementation is functional, but there are still some challenges to address and room for improvement.

I'm closing issue #253 and will continue the discussion in this thread. If possible, please replace the reference to issue #253 with this new issue #327 in the Useful Links section.

I'll provide the latency reports shortly.

@WongKinYiu
Owner

Updated.
Thanks a lot.

@levipereira
Author

levipereira commented Apr 8, 2024

Performance / Accuracy

TensorRT version: 10.0.0

Model

YOLOv9-C-converted

Accuracy Report

YOLOv9-C

Evaluation Results

Eval Model AP AP50 Precision Recall
Origin (PyTorch) 0.529 0.699 0.743 0.634
INT8 (TensorRT) 0.527 0.695 0.746 0.627

Evaluation Comparison

Eval Model AP AP50 Precision Recall
INT8 (TensorRT) vs Origin (PyTorch) -0.002 -0.004 +0.003 -0.007

Latency/Throughput Report using only TensorRT

Device

GPU
Device NVIDIA GeForce RTX 4090
Compute Capability 8.9
SMs 128
Device Global Memory 24207 MiB
Application Compute Clock Rate 2.58 GHz
Application Memory Clock Rate 10.501 GHz

Latency/Throughput

Model Name Batch Size Latency (99%) Throughput (qps) Total Inferences (IPS)
YOLOv9-C (FP16) 1 1.25 ms 803 803
YOLOv9-C (FP16) 4 3.37 ms 300 1200
YOLOv9-C (FP16) 8 6.6 ms 153 1224
YOLOv9-C (FP16) 12 10 ms 99 1188
YOLOv9-C (INT8) 1 0.99 ms 1006 1006
YOLOv9-C (INT8) 4 2.12 ms 473 1892
YOLOv9-C (INT8) 8 3.84 ms 261 2088
YOLOv9-C (INT8) 12 5.59 ms 178 2136

Latency/Throughput Comparison

Model Name Batch Size Latency (99%) Throughput (qps) Total Inferences (IPS)
INT8 vs FP16 1 -20.8% +25.2% +25.2%
INT8 vs FP16 4 -37.1% +57.7% +57.7%
INT8 vs FP16 8 -41.1% +70.6% +70.6%
INT8 vs FP16 12 -46.9% +79.8% +79.8%
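The comparison percentages can be recomputed from the latency/throughput table: each one is the relative change (INT8 - FP16) / FP16, and total inferences per second is throughput (qps) multiplied by batch size. A small sketch of that arithmetic for YOLOv9-C (the helper names are illustrative, not from the repository):

```python
# YOLOv9-C measurements copied from the table above (TensorRT 10.0.0, RTX 4090).
# batch size -> (p99 latency in ms, throughput in qps)
FP16 = {1: (1.25, 803), 4: (3.37, 300), 8: (6.6, 153), 12: (10.0, 99)}
INT8 = {1: (0.99, 1006), 4: (2.12, 473), 8: (3.84, 261), 12: (5.59, 178)}


def relative_change(new: float, old: float) -> float:
    """Relative change of `new` with respect to `old`, in percent."""
    return (new - old) / old * 100.0


def compare(batch: int) -> dict:
    """INT8-vs-FP16 deltas for one batch size, rounded to one decimal."""
    fp16_lat, fp16_qps = FP16[batch]
    int8_lat, int8_qps = INT8[batch]
    # Total inferences per second = queries per second x images per query.
    fp16_ips, int8_ips = fp16_qps * batch, int8_qps * batch
    return {
        "latency_pct": round(relative_change(int8_lat, fp16_lat), 1),
        "throughput_pct": round(relative_change(int8_qps, fp16_qps), 1),
        "ips_pct": round(relative_change(int8_ips, fp16_ips), 1),
        "int8_ips": int8_ips,
    }


for bs in sorted(FP16):
    print(bs, compare(bs))
```

Because some table entries are rounded (for example 6.6 ms and 10 ms), a few recomputed values differ from the reported ones (batch 12 latency comes out near -44% rather than -46.9%); the reported percentages were presumably computed from unrounded measurements. The total-IPS change always equals the qps change, since both sides are scaled by the same batch size.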

Full Report

@levipereira
Author

@WongKinYiu Do you happen to have a YOLOv9-C or YOLOv9-E model trained with ReLU or ReLU6 activation functions? I need it for performance testing with quantization. If available and you could share it, it would greatly help me.

@WongKinYiu
Owner

Sorry for the late reply; yolov9-relu.pt is here.
It is not yet re-parameterized.

@WongKinYiu
Owner

yolov9-relu-converted.pt

@levipereira
Author

levipereira commented Apr 17, 2024

@WongKinYiu Thank you for providing the weights file.
As I suspected, the ReLU activation function delivers much better performance (latency) compared to SiLU. Depending on the scenario, it might be worth sacrificing a bit of accuracy for the sake of latency.

The current results have been quite satisfactory, achieving a minimum latency of 0.84 ms.
My next goal is to test with the ReLU6 function.

Below are the tables of the results:

YOLOv9 - with ReLU

Performance / Accuracy

TensorRT version: 10.0.0

Device

GPU
Device NVIDIA GeForce RTX 4090
Compute Capability 8.9
SMs 128
Device Global Memory 24207 MiB
Application Compute Clock Rate 2.58 GHz
Application Memory Clock Rate 10.501 GHz

Accuracy Report

Evaluation Results

Eval Model AP AP50 Precision Recall
Origin (PyTorch) 0.519 0.69 0.719 0.629
INT8 (PyTorch) 0.518 0.69 0.725 0.623
INT8 (TensorRT) 0.517 0.685 0.723 0.623

Evaluation Comparison

Eval Model AP Diff AP50 Diff Precision Diff Recall Diff
INT8 (TensorRT) vs Origin (PyTorch) -0.002 -0.005 +0.004 -0.006

Latency/Throughput Report using TensorRT

Latency/Throughput

Model Name Batch Size Latency (99%) Throughput (qps) Total Inferences (IPS)
YOLOv9-ReLU (FP16) 1 1.15 ms 868 868
YOLOv9-ReLU (FP16) 12 8.81 ms 115 1380
YOLOv9-ReLU (INT8) 1 0.84 ms 1186 1186
YOLOv9-ReLU (INT8) 12 4.59 ms 218 2616

Latency/Throughput Comparison

Model Name Batch Size Latency (99%) Diff Throughput (qps) Diff Total Inferences (IPS) Diff
(INT8) vs (FP16) 1 -27.0% +36.5% +36.5%
(INT8) vs (FP16) 12 -47.9% +89.6% +89.6%

@snehashis1997

Can we run inference with the PyTorch INT8 model? What is the benchmark comparison of PyTorch INT8 vs. TensorRT INT8?

@WongKinYiu
Owner

@levipereira

Could you help examine the latency/throughput without NMS?
Thanks in advance.

@levipereira
Author

@WongKinYiu
These tests were performed without NMS.
Below is a table with additional tests.
https://github.com/levipereira/yolov9-qat?tab=readme-ov-file#latencythroughput

@WongKinYiu
Owner

Thanks!

@WongKinYiu
Owner

@levipereira

Excuse me, I would like to borrow your time again.
Could you please help me examine the latency/throughput of the following models?
yolov9-c-coarse.pt, yolov9-c-fine.pt, lh-yolov9-c-coarse.pt, lh-yolov9-c-fine.pt.
Thanks very much.

@levipereira
Author

@WongKinYiu

Latency/Throughput

LH-YOLOV9-C-FINE

Precision Batch Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
FP16 1 1.21 ms 824 824
FP16 8 6.18 ms 164 1312
INT8 1 0.95 ms 1051 1051
INT8 8 3.55 ms 281 2248
INT8 12 5.17 ms 195 2340

LH-YOLOV9-C-COARSE

Precision Batch Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
FP16 1 1.21 ms 822 822
FP16 8 6.22 ms 162 1296
INT8 1 0.95 ms 1050 1050
INT8 8 3.56 ms 281 2248
INT8 12 5.18 ms 193 2316

YOLOV9-C-FINE

Precision Batch Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
FP16 1 1.25 ms 800 800
FP16 8 6.65 ms 152 1216
INT8 1 0.97 ms 1033 1033
INT8 8 3.67 ms 272 2176
INT8 12 5.32 ms 189 2268

YOLOV9-C-COARSE

Precision Batch Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
FP16 1 1.25 ms 804 804
FP16 8 6.62 ms 152 1216
INT8 1 0.98 ms 1026 1026
INT8 8 3.68 ms 271 2168
INT8 12 5.34 ms 189 2268

Evaluation

YOLOV9-C-COARSE

Eval Model AP AP50 Precision Recall
Origin 0.527 0.699 0.74 0.633
QAT-TRT 0.524 0.692 0.723 0.638

YOLOV9-C-FINE

Eval Model AP AP50 Precision Recall
Origin 0.523 0.699 0.738 0.63
QAT-TRT 0.522 0.693 0.743 0.622

LH-YOLOV9-C-FINE

Eval Model AP AP50 Precision Recall
Origin 0.525 0.701 0.723 0.639
QAT-TRT 0.524 0.695 0.733 0.629

LH-YOLOV9-C-COARSE

Eval Model AP AP50 Precision Recall
Origin 0.527 0.699 0.724 0.642
QAT-TRT 0.524 0.693 0.726 0.631

@levipereira
Author

@WongKinYiu
QAT models have shown substantial improvements in latency and performance with minimal accuracy loss. It would be advantageous for the community to begin incorporating support for QAT in the codebase, allowing it to be activated post-training.
https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html#further-optimization

@WongKinYiu
Owner

Thanks a lot.
It seems the fine branch loses less accuracy than the coarse branch after QAT.
The coarse-to-fine models will finish training tomorrow.
I will update the final weights soon.

By the way, the fine branch does not need NMS for post-processing.
You could also evaluate accuracy without NMS for the fine-branch models.

@levipereira
Author

levipereira commented Jun 2, 2024

I am currently using the YOLOv9 code found at this link: https://github.com/levipereira/yolov9-qat/blob/master/val_trt.py#L249-L290. If you already have the corresponding code to evaluate without NMS, I would greatly appreciate it.

@WongKinYiu
Owner

Currently I simply remove the NMS part of non_max_suppression in general.py to implement no_max_suppression.

```python
# Placed in utils/general.py, which already imports time and torch
# and defines LOGGER and xywh2xyxy.
def no_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    if isinstance(prediction, (list, tuple)):  # YOLO model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[1] - nm - 4  # number of classes
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 300  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 2.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x.T[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)  # box + classes + masks (no objectness column)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        output[xi] = x
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
```

@WongKinYiu
Owner

WongKinYiu commented Jun 3, 2024

Cleaned up the code:

```python
# Placed in utils/general.py, which already imports time and torch
# and defines LOGGER and xywh2xyxy.
def no_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """No Maximum Suppression on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    if isinstance(prediction, (list, tuple)):  # YOLO model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[1] - nm - 4  # number of classes
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    time_limit = 2.5 + 0.05 * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        x = x.T[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)  # box + classes + masks (no objectness column)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_det:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_det]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        output[xi] = x
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
```

@levipereira
Author

levipereira commented Jun 3, 2024

I have completely retrained the model because of the change that eliminates the need for NMS. We picked the best model based on mAP.

YOLOV9-C-FINE (NO-NMS)

Eval Model AP AP50 Precision Recall
Origin 0.523 0.693 0.735 0.621
QAT-TRT 0.52 0.686 0.735 0.617

LH-YOLOV9-C-FINE (NO-NMS)

Eval Model AP AP50 Precision Recall
Origin 0.524 0.695 0.728 0.629
QAT-TRT 0.521 0.687 0.731 0.618

Note: Despite the good results, the model is still not 100% quantized. We can further improve performance with minimal loss. Additionally, we can recover the loss due to quantization by applying other calibration methods.

@WongKinYiu
Owner

lh-yolov9-c-coarse.pt, lh-yolov9-c-fine.pt are updated.

@levipereira
Author

levipereira commented Jun 3, 2024

LH-YOLOV9-C-FINE (mse) NMS-free

Eval Model AP AP50 Precision Recall
Origin 0.526 0.696 0.743 0.628
PTQ (Baseline) 0.514 0.681 0.724 0.613
QAT (PyTorch) 0.517 0.683 0.744 0.604
QAT-TRT 0.516 0.679 0.724 0.617

LH-YOLOV9-C-FINE (percentile=99.999) NMS-free

Eval Model AP AP50 Precision Recall
Origin 0.526 0.696 0.743 0.628
PTQ (Baseline) 0.517 0.686 0.742 0.615
QAT (PyTorch) 0.518 0.687 0.738 0.616
QAT-TRT 0.517 0.682 0.739 0.614

LH-YOLOV9-C-COARSE (mse)

Eval Model AP AP50 Precision Recall
Origin 0.528 0.700 0.743 0.634
PTQ (Baseline) 0.524 0.696 0.734 0.631
QAT (PyTorch) 0.526 0.697 0.741 0.631
QAT-TRT 0.526 0.692 0.733 0.634

LH-YOLOV9-C-COARSE (percentile=99.999)

Eval Model AP AP50 Precision Recall
Origin 0.528 0.700 0.743 0.634
PTQ (Baseline) 0.525 0.697 0.742 0.628
QAT (PyTorch) 0.527 0.699 0.742 0.634
QAT-TRT 0.526 0.692 0.744 0.631

I initially used the default MSE calibration, but the results were unsatisfactory. Consequently, I changed the calibration method to percentile=99.999, which yielded better results. I believe these models have more sensitive layers that need to be treated differently. Additionally, I need to explore the new head of the model, since I had only performed quantization for YOLOv9-C/E.
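To make the difference between the two calibration methods concrete, here is a minimal NumPy sketch of how each could pick the clipping range (amax) for symmetric per-tensor INT8 quantization. It assumes scale = amax / 127 and uses a simple grid search for the MSE variant on synthetic data; it illustrates the idea and is not the pytorch-quantization calibrator implementation.

```python
import numpy as np


def fake_quant(x: np.ndarray, amax: float) -> np.ndarray:
    """Symmetric INT8 quantize-dequantize with clipping range [-amax, amax]."""
    scale = amax / 127.0
    return np.clip(np.round(x / scale), -127, 127) * scale


def max_amax(x: np.ndarray) -> float:
    # Max calibration: the range covers every value, outliers included.
    return float(np.abs(x).max())


def percentile_amax(x: np.ndarray, pct: float = 99.999) -> float:
    # Percentile calibration: clip at the pct-th percentile of |x|.
    return float(np.percentile(np.abs(x), pct))


def mse_amax(x: np.ndarray, num_candidates: int = 100) -> float:
    # MSE calibration (grid-search variant): choose the amax whose
    # quantize-dequantize round trip has the lowest mean squared error.
    top = max_amax(x)
    candidates = np.linspace(top / num_candidates, top, num_candidates)
    errors = [np.mean((x - fake_quant(x, a)) ** 2) for a in candidates]
    return float(candidates[int(np.argmin(errors))])


rng = np.random.default_rng(0)
acts = rng.normal(0.0, 1.0, size=100_000)  # synthetic "activations"
acts[0] = 50.0  # one large outlier

for name, calibrate in [("max", max_amax), ("percentile", percentile_amax), ("mse", mse_amax)]:
    amax = calibrate(acts)
    # Error measured on the non-outlier values only.
    bulk_err = np.mean((acts[1:] - fake_quant(acts[1:], amax)) ** 2)
    print(f"{name:>10}: amax={amax:7.3f}  bulk MSE={bulk_err:.2e}")
```

With a single outlier, MSE calibration still pays for clipping it and settles on a wide range, while the percentile range drops it and keeps fine resolution for the bulk of the values. Which trade-off gives better AP is model-dependent, which is consistent with the mixed MSE/percentile results reported here.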

I am generating a latency report.

@levipereira
Author

LH-YOLOV9-C-FINE (INT8)

Batch Size Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
1 0.94 ms 1059 1059
8 3.56 ms 282 2256
12 5.18 ms 194 2328

LH-YOLOV9-C-COARSE (INT8)

Batch Size Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
1 0.95 ms 1048 1048
8 3.56 ms 281 2248
12 5.21 ms 193 2316

@WongKinYiu
Owner

Thanks!

It seems old weights (#327 (comment)) have more stable QAT performance.
Although new weights (#327 (comment)) have higher original AP, they drop more performance after QAT.

Since the old and new weights were trained with different strategies, it may be worth discussing the relation between pretraining methods and the QAT step.

In a few days, I will provide weights of YOLOV9-C-FINE trained the same way as #327 (comment), to check whether the sensitive layers are caused by the different training methods.

If so, I could try to analyze and design QAT-friendly pretraining methods in the future.

Thank you for bringing this possible research direction to me.

@levipereira
Author

levipereira commented Jun 5, 2024

I ran the tests to find the most sensitive layer (PTQ baseline), and here are the results:

LH-YOLOV9-C-FINE

(#327 (comment))

Sensitive summary:
Top0: Using fp16 model.22, ap = 0.52310
Top1: Using fp16 model.4, ap = 0.51660
Top2: Using fp16 model.3, ap = 0.51650
Top3: Using fp16 model.1, ap = 0.51570
Top4: Using fp16 model.15, ap = 0.51560
Top5: Using fp16 model.17, ap = 0.51560
Top6: Using fp16 model.2, ap = 0.51550
Top7: Using fp16 model.8, ap = 0.51550
Top8: Using fp16 model.11, ap = 0.51550
Top9: Using fp16 model.14, ap = 0.51550
Top10: Using fp16 model.21, ap = 0.51550
Top11: Using fp16 model.0, ap = 0.51540
Top12: Using fp16 model.9, ap = 0.51540
Top13: Using fp16 model.18, ap = 0.51540
Top14: Using fp16 model.5, ap = 0.51530
Top15: Using fp16 model.6, ap = 0.51530
Top16: Using fp16 model.12, ap = 0.51530
Top17: Using fp16 model.19, ap = 0.51530
Top18: Using fp16 model.20, ap = 0.51520
Top19: Using fp16 model.7, ap = 0.51500
Top20: Using fp16 PTQ, ap = 0.51490
Top21: Using fp16 model.10, ap = 0.51490

Today was quite busy, but I expect to be able to run the training with layer 22 in FP16 and check the performance and accuracy results.
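The sensitivity analysis above follows a one-at-a-time pattern: keep every layer quantized except one, evaluate AP, and rank the layers by how much AP that single FP16 fallback recovers. A minimal sketch of the ranking step, assuming a caller-supplied `evaluate_ap(fp16_layers)` hook (hypothetical; in practice it would run validation with the quantizers of those layers disabled). The toy evaluator just replays a few AP values from the log.

```python
from typing import Callable, FrozenSet, Iterable, List, Tuple


def sensitivity_summary(
    layers: Iterable[str],
    evaluate_ap: Callable[[FrozenSet[str]], float],
) -> Tuple[float, List[Tuple[str, float]]]:
    """One-at-a-time sensitivity analysis.

    evaluate_ap(fp16_layers) is a hypothetical hook: it should evaluate the
    model with every layer quantized except those named in fp16_layers.
    Returns the all-quantized baseline AP and (layer, AP) pairs sorted from
    the layer whose FP16 fallback recovers the most AP to the least.
    """
    baseline = evaluate_ap(frozenset())
    scores = [(name, evaluate_ap(frozenset({name}))) for name in layers]
    ranked = sorted(scores, key=lambda item: item[1], reverse=True)
    return baseline, ranked


# Toy evaluator that replays a few AP values from the LH-YOLOV9-C-FINE log above.
LOGGED_AP = {"model.22": 0.5231, "model.4": 0.5166, "model.3": 0.5165, "model.10": 0.5149}
PTQ_AP = 0.5149


def toy_evaluate_ap(fp16_layers: FrozenSet[str]) -> float:
    if not fp16_layers:
        return PTQ_AP
    (name,) = fp16_layers
    return LOGGED_AP[name]


baseline, ranked = sensitivity_summary(["model.3", "model.10", "model.22", "model.4"], toy_evaluate_ap)
for i, (name, ap) in enumerate(ranked):
    print(f"Top{i}: Using fp16 {name}, ap = {ap:.5f} (vs PTQ {ap - baseline:+.4f})")
```

Each candidate layer costs one full evaluation, so the cost grows linearly with the number of layers being probed.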

@levipereira
Author

levipereira commented Jun 5, 2024

Indeed, layer 22 is the most sensitive layer. I disabled the quantization in layer 22 and managed to recover the precision with better performance at batch size 1. However, when increasing the batch size to 8 or 12, there is a slight increase in latency and a decrease in throughput.

(#327 (comment))

LH-YOLOV9-C-FINE (mse) NMS-free

Eval Model AP AP50 Precision Recall
Origin 0.526 0.696 0.743 0.628
PTQ (Baseline) 0.522 0.693 0.74 0.627
QAT (PyTorch) 0.524 0.694 0.738 0.626
QAT-TRT 0.525 0.69 0.743 0.622

LH-YOLOV9-C-FINE (INT8)

Batch Size Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
1 0.92 ms 1079 1079
8 3.76 ms 266 2128
12 5.65 ms 178 2136

@WongKinYiu
Owner

I have uploaded the weights and updated the file names.

training method 1:
yolov9-c-coarse.pt, yolov9-c-fine.pt, lh-yolov9-c-coarse.pt, lh-yolov9-c-fine.pt.

training method 2:
yolov9-c-coarse-.pt, yolov9-c-fine-.pt, lh-yolov9-c-coarse-.pt, lh-yolov9-c-fine-.pt.

By the way, could you also help examine the latency/throughput of the tiny/small/medium models?
yolov9-t-converted.pt, yolov9-s-converted.pt, yolov9-m-converted.pt.

Thanks.

@levipereira
Author

levipereira commented Jun 6, 2024

I have observed that the last layer of the model is often the most sensitive to quantization. This sensitivity arises because this layer tends to generate more outliers. With per-tensor INT8 quantization, these outliers are either clipped or stretch the quantization range (reducing the resolution available to all other values), which leads to a loss of precision, since these outliers are crucial for the model's accuracy.
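A small worked example of why a few outliers hurt per-tensor INT8 quantization: with symmetric scaling, scale = amax / 127, so one large value stretches the step size for every other value in the tensor. The values below are made up for illustration and assume max calibration.

```python
def int8_roundtrip(values, amax):
    """Symmetric per-tensor INT8 quantize-dequantize with scale = amax / 127."""
    scale = amax / 127.0
    return [max(-127, min(127, round(v / scale))) * scale for v in values]


def max_abs_error(values, amax):
    """Largest absolute round-trip error over `values`."""
    return max(abs(v - r) for v, r in zip(values, int8_roundtrip(values, amax)))


typical = [0.01 * i for i in range(-100, 101)]  # well-behaved values in [-1, 1]
outlier = 100.0                                  # a single large activation

# Max calibration without the outlier: amax = 1.0, step = 1/127 (about 0.008).
err_clean = max_abs_error(typical, amax=1.0)
# Max calibration with the outlier: amax = 100.0, so the step is 100x larger
# and every value with |v| < step / 2 (about 0.39) rounds to zero.
err_outlier = max_abs_error(typical, amax=outlier)

print(f"max error without outlier: {err_clean:.4f}")
print(f"max error with outlier:    {err_outlier:.4f}")
```

The worst-case error on the typical values grows from under half a step of 1/127 to about 0.39, roughly a hundredfold, even though only one value changed.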

By changing the training method, we appear to have effectively reduced the generation of outliers, which is critical for quantization. The different training approach seems to produce fewer outlier values, thus preserving the precision and overall performance of the quantized model.

To address the sensitivity of the final layer to quantization, I implemented a straightforward approach: disabling the quantization of layer 22. Instead of retraining the model, I simply disabled the quantization for this specific layer and re-evaluated the model to assess the impact on performance.

Models with quantization disabled at layer 22 are indicated by the suffix -D22.

I performed the calibration using MSE, although in some cases, using percentile = 99.999 proved to be more efficient.

Model Performance Tables (MSE)

YOLOv9-C-Coarse - Method 1

Eval Model AP AP50 Precision Recall
Origin 0.527 0.699 0.74 0.633
PTQ 0.523 0.696 0.733 0.629
QAT-PyT 0.525 0.697 0.741 0.624
QAT-TRT 0.524 0.692 0.722 0.638
QAT-PyT-D22 0.525 0.694 0.732 0.631
QAT-TRT-D22 0.524 0.692 0.725 0.635

YOLOv9-C-Coarse - Method 2

Eval Model AP AP50 Precision Recall
Origin 0.527 0.693 0.728 0.636
PTQ 0.525 0.692 0.739 0.630
QAT-PyT 0.526 0.692 0.738 0.629
QAT-TRT 0.526 0.693 0.725 0.638
QAT-PyT-D22 0.526 0.693 0.730 0.635
QAT-TRT-D22 0.526 0.693 0.730 0.634

LH-YOLOv9-C-Coarse - Method 1

Eval Model AP AP50 Precision Recall
Origin 0.527 0.699 0.724 0.642
PTQ 0.524 0.696 0.728 0.628
QAT-PyT 0.525 0.697 0.731 0.633
QAT-TRT 0.524 0.692 0.722 0.638
QAT-PyT-D22 0.525 0.694 0.715 0.639
QAT-TRT-D22 0.524 0.693 0.722 0.638

LH-YOLOv9-C-Coarse - Method 2

Eval Model AP AP50 Precision Recall
Origin 0.528 0.696 0.743 0.634
PTQ 0.524 0.693 0.734 0.631
QAT-PyT 0.526 0.693 0.741 0.631
QAT-TRT 0.525 0.693 0.734 0.634
QAT-PyT-D22 0.527 0.694 0.742 0.631
QAT-TRT-D22 0.526 0.693 0.734 0.634

YOLOv9-C-Fine - Method 1

Eval Model AP AP50 Precision Recall
Origin 0.523 0.689 0.735 0.621
PTQ 0.519 0.685 0.723 0.627
QAT-PyT 0.520 0.687 0.740 0.619
QAT-TRT 0.520 0.686 0.734 0.619
QAT-PyT-D22 0.520 0.686 0.737 0.619
QAT-TRT-D22 0.520 0.686 0.734 0.620

YOLOv9-C-Fine - Method 2

Eval Model AP AP50 Precision Recall
Origin 0.523 0.688 0.733 0.626
PTQ 0.521 0.685 0.725 0.623
QAT-PyT 0.522 0.686 0.734 0.621
QAT-TRT 0.522 0.685 0.734 0.615
QAT-PyT-D22 0.522 0.686 0.711 0.629
QAT-TRT-D22 0.522 0.685 0.726 0.620

LH-YOLOv9-C-Fine - Method 1

Eval Model AP AP50 Precision Recall
Origin 0.524 0.695 0.728 0.629
PTQ 0.520 0.691 0.730 0.620
QAT-PyT 0.521 0.691 0.740 0.614
QAT-TRT 0.521 0.687 0.741 0.616
QAT-PyT-D22 0.522 0.689 0.733 0.620
QAT-TRT-D22 0.523 0.689 0.741 0.615

LH-YOLOv9-C-Fine - Method 2

Eval Model AP AP50 Precision Recall
Origin 0.526 0.692 0.743 0.628
PTQ 0.515 0.678 0.724 0.613
QAT-PyT 0.517 0.679 0.735 0.611
QAT-TRT 0.516 0.679 0.730 0.612
QAT-PyT-D22 0.524 0.690 0.752 0.620
QAT-TRT-D22 0.524 0.690 0.750 0.620

I still owe the tests for the remaining models as well as the latency tests, which I will send as soon as possible.

@levipereira
Author

Results for Tiny/Small/Medium

I have encountered several performance issues regarding latency and throughput in the quantized Tiny, Small, and Medium models. They performed worse than the FP16 models, generating many reformat operations that directly impacted the model's latency. I am currently researching and studying the behavior of quantization in these models to resolve the issue.

Latency

yolov9-t-converted (FP16)

Batch Size Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
1 0.78 ms 1282 1282
12 2.32 ms 432 5184
24 4.40 ms 228 5472
32 5.94 ms 169 5408

yolov9-s-converted (FP16)

Batch Size Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
1 0.92 ms 1086 1086
12 4.19 ms 240 2880
24 8.25 ms 122 2928
32 11.07 ms 91 2912

yolov9-m-converted (FP16)

Batch Size Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
1 1.22 ms 818 818
8 5.81 ms 175 1400
12 8.78 ms 116 1392
24 17.90 ms 57 1368
32 24.09 ms 42 1344

Evaluation

yolov9-t-converted

Eval Model AP AP50 Precision Recall
Origin 0.38 0.527 0.633 0.481
PTQ 0.374 0.52 0.63 0.475
QAT - Best 0.376 0.523 0.641 0.473
QAT - TRT 0.376 0.523 0.64 0.476

yolov9-s-converted

Eval Model AP AP50 Precision Recall
Origin 0.467 0.629 0.698 0.569
PTQ 0.462 0.622 0.689 0.565
QAT - Best 0.465 0.626 0.688 0.576
QAT - TRT 0.465 0.626 0.692 0.57

yolov9-m-converted

Eval Model AP AP50 Precision Recall
Origin 0.513 0.677 0.713 0.62
PTQ 0.512 0.675 0.713 0.619
QAT - Best 0.512 0.675 0.712 0.621
QAT - TRT 0.511 0.675 0.712 0.622

@WongKinYiu
Owner

Thanks.

Yes, the throughput seems strange.
yolov9-m gets the same throughput as yolov9-c.

The main difference between t/s/m and c/e is that t/s/m use AConv and c/e use ADown for downsampling.
In PyTorch, AConv is faster than ADown, but I am not sure about the situation with TensorRT and quantization.

@levipereira
Author

I have noticed that model performance is often measured solely by latency. However, during my research, I discovered that different models can have very similar latencies with a batch size of 1. But as the batch size increases, they show significant differences in both throughput and latency. Therefore, testing only with a batch size of 1 and focusing solely on latency can lead to incorrect conclusions about the model's potential.

To accurately measure a model's potential, we should consider both latency and batch size. On the GPU, models can have a certain latency, but increasing the batch size doesn't cause latency to grow proportionally. This is evident in the performance tables. Thus, the best model is the one that achieves the highest throughput with the largest batch size and the lowest latency.
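One way to turn this into a concrete comparison is to compute total throughput (qps multiplied by batch size) for every measured batch size and pick the best one that still meets a latency budget. A minimal sketch using the YOLOv9-t and YOLOv9-s FP16 measurements reported above; the 5 ms budget is an arbitrary example value, and `best_batch` is an illustrative helper name.

```python
from typing import Dict, Optional, Tuple

# batch size -> (p99 latency in ms, throughput in qps), FP16 results from this thread
YOLOV9_T = {1: (0.78, 1282), 12: (2.32, 432), 24: (4.40, 228), 32: (5.94, 169)}
YOLOV9_S = {1: (0.92, 1086), 12: (4.19, 240), 24: (8.25, 122), 32: (11.07, 91)}


def best_batch(
    measurements: Dict[int, Tuple[float, int]], latency_budget_ms: float
) -> Optional[Tuple[int, int]]:
    """Return (batch size, images/s) with the highest total throughput
    whose p99 latency fits the budget, or None if nothing fits."""
    feasible = [
        (batch, qps * batch)
        for batch, (latency_ms, qps) in measurements.items()
        if latency_ms <= latency_budget_ms
    ]
    if not feasible:
        return None
    return max(feasible, key=lambda item: item[1])


print(best_batch(YOLOV9_T, 5.0))  # (24, 5472)
print(best_batch(YOLOV9_S, 5.0))  # (12, 2880)
```

At batch size 1 the two models differ by less than 200 IPS, but under the same 5 ms budget YOLOv9-t delivers nearly twice the total throughput of YOLOv9-s.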

I will attempt to illustrate my findings visually.

Testing with Batch Size = 1

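GR Active was at 100% (GR Active is the percentage of cycles the compute engine is active).

[Screenshot: GPU metrics at batch size 1]

SM Active was only 24%, which means there is a lot of room to increase the batch size.

[Screenshot: SM Active at batch size 1]

Testing with Batch Size = 24

GR Active was still 100%.

SM Active was 85%.

[Screenshot: GPU metrics at batch size 24]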

When SM Active reaches 100%, the model's performance drops, resulting in increased latency and decreased throughput.

Therefore, when measuring the potential of the model, we should also consider the batch size. The best model is the one that achieves the highest throughput with the largest batch size and the lowest latency.

@levipereira
Author

levipereira commented Jun 8, 2024

yolov9-m gets the same throughput as yolov9-c.

I will perform profiling to see the differences.

@WongKinYiu
Owner

Batch 1 and large batch are both important.

Large-batch inference is important for cloud services.
Batch-1 inference is important for streaming input.

yolov9-m gets the same throughput as yolov9-c

I have encountered similar issues (small and large models having the same inference speed) on YOLOv4 when using some built-in PyTorch versions in NVIDIA Docker images.
My solution was to reinstall PyTorch and related dependencies.
I am not sure if you are facing the same issue.

@levipereira
Author

levipereira commented Jun 8, 2024

I have encountered similar issues (small and large models having the same inference speed) on YOLOv4 when using some built-in PyTorch versions in NVIDIA Docker images.

I will test these models on different servers and TensorRT versions.

I often see performance reports comparing YOLO-series models at batch size 1, using latency as the primary comparison parameter. However, without testing variable batch sizes, some models may perform significantly worse than others at larger batch sizes. A classic example is YOLOv9-t vs. YOLOv9-s at batch size 1: 0.78 ms vs. 0.92 ms latency, with a throughput difference of only about 200 IPS (1282 vs. 1086 IPS).
However, at batch size 12, the YOLOv9-s latency was almost double that of YOLOv9-t (4.19 ms vs. 2.32 ms). In terms of throughput, the difference was significant: YOLOv9-t achieved about 2300 more IPS than YOLOv9-s (5184 vs. 2880 IPS).
By focusing solely on latency at batch size 1, I would have overlooked the full performance potential of YOLOv9-t.

@WongKinYiu
Owner

laugh12321 gets inference speeds similar to your reports.

Three possible reasons:

  1. number of layers: e > t = s > c > m
  2. AConv vs ADown
  3. environment

Since the c model has 13 times the FLOPs of the t model, this is really strange.
I have never encountered this situation on our platform.
Could you try switching to the root user (sudo -s) and testing the speed?

@WongKinYiu
Owner

To check if the number of layer is the one of reason,
could you help test gelan-s2.pt?

number of layers: e > t = s > s1 = c > m

@levipereira
Author

could you help test gelan-s2.pt?

gelan-s2 (FP16)

Batch Size Latency (percentile 99%) Throughput (qps) Total Throughput (IPS)
1 0.805 ms 1242 1242
12 3.872 ms 259 3108
24 7.766 ms 130 3120
32 10.388 ms 97 3104

Since the c model has 13 times the FLOPs of the t model, this is really strange.
Could you try switching to the root user (sudo -s) and testing the speed?

I don't believe the problem is with the host or the installation. It may be a bug in TensorRT, since only a few models exhibit this strange behavior.
I will install TensorRT Engine Explorer (TREx) and share the results.

I'm having a lot of difficulty identifying why the t/s/m models perform poorly when quantized. I've noticed a lot of Reformat operations due to different scales. I implemented AConv similarly to ADown, but the poor results persist. I also observed some DFL operations in the slice of the initial layers that differ from YOLOv9-C. I'm still investigating this carefully.

Model FP16

[Screenshot: FP16 model]

Model QAT

These Reformat layers are killing me.
[Screenshot: QAT model]
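To track the Reformat overhead as a number rather than by eye, the per-layer information TensorRT can export (for example with trtexec's `--profilingVerbosity=detailed --exportLayerInfo=<file>`) can be scanned for Reformat layers. A minimal sketch; it assumes the exported JSON has a top-level "Layers" list whose entries are objects with a "LayerType" field, and the example data is hand-made, not real engine output.

```python
import json
from collections import Counter
from typing import Any, Dict


def layer_type_counts(layer_info: Dict[str, Any]) -> Counter:
    """Count layers per LayerType in an exported engine layer-info dict.

    Assumes layer_info["Layers"] is a list of dicts with a "LayerType" key
    (detailed profiling verbosity); plain-string entries are skipped.
    """
    counts = Counter()
    for layer in layer_info.get("Layers", []):
        if isinstance(layer, dict):
            counts[layer.get("LayerType", "Unknown")] += 1
    return counts


def load_layer_info(path: str) -> Dict[str, Any]:
    with open(path) as f:
        return json.load(f)


# Hand-made example in the assumed format (not real engine output).
example = {
    "Layers": [
        {"Name": "conv_0", "LayerType": "Convolution"},
        {"Name": "reformat_0", "LayerType": "Reformat"},
        {"Name": "conv_1", "LayerType": "Convolution"},
        {"Name": "reformat_1", "LayerType": "Reformat"},
    ]
}
print(layer_type_counts(example)["Reformat"])  # 2
```

Comparing `layer_type_counts(load_layer_info(path))["Reformat"]` for the FP16 and QAT engines of the same model would give one figure to watch while trying fixes (the file paths are up to you).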

@WongKinYiu
Owner

Thank you for your effort.

Yes, it seems TensorRT generates many unnecessary Reformat layers.
I am not sure if this helps:
"It is possible to make TensorRT avoid inserting reformatting at the network boundaries, by setting the builder configuration flag DIRECT_IO. "

@levipereira
Author

Regarding #327 (comment):
I performed profiling for each model individually and then conducted a comparative analysis between the models:
C vs. M
M vs. S
S vs. T
Google Drive files:
https://drive.google.com/drive/folders/18vBxAWZmQ1KUV7Tga9yH_fL5YSbbdYzw?usp=sharing
[Screenshot: profiling comparison]

@WongKinYiu
Owner

Well, I do not know why yolov9-m has so many layers after conversion to TensorRT.

@levipereira
Author

I have been analyzing the models and noticed that YOLOv9-C vs. YOLOv9-M has several Reformat operations where some nodes were not fused. The same issue occurs with the QAT models, where some nodes, despite being on the same scale, are not being fused, resulting in multiple Reformat operations.

I searched on GitHub and found several users experiencing issues with node fusion, where TensorRT did not support certain fusions. Given that these models introduce new modules, it is possible that this has caused issues with TensorRT.

We need to open another front to address this issue in the TensorRT repository to understand where the potential problem lies.

@WongKinYiu
Owner

Could you take a look at whether YOLOv7 has the same issue?
If not, I could point out the main differences between the YOLOv7 and YOLOv9 architectures.

@levipereira
Author

These past few days I was away on a business trip. I'm returning now and we will pick up where we left off. I'm sorry for the delay in responding.
