# Загрузка библиотек

In [1]:
import torch
import numpy as np
from ultralytics import YOLO
import os
import time
from PIL import Image
import onnxruntime as onnx
from pathlib import Path
from IPython.display import clear_output
from tqdm import tqdm
import onnxruntime as ort
import pandas as pd

# Загрузка моделей и данных

In [2]:
# Directory that holds the exported weights of both detectors.
path_models = Path('models')

# ONNX exports loaded through the Ultralytics front end; the task is given
# explicitly in the call instead of being inferred from the file.
model_terminal_od_onnx = YOLO(
    path_models.joinpath('model_terminal_od.onnx'),
    task='detect',
)
model_defect_od_onnx = YOLO(
    path_models.joinpath('model_defect_od_crop.onnx'),
    task='detect',
)

# Native PyTorch checkpoints of the same two models.
model_terminal_od_pt = YOLO(path_models.joinpath('model_terminal_od.pt'))
model_defect_od_pt = YOLO(path_models.joinpath('model_defect_od_crop.pt'))

In [3]:
# Raw ONNX Runtime sessions (no Ultralytics wrapper), pinned to the CPU provider.
# The model path is converted to ``str`` because older onnxruntime releases
# accept only ``str``/``bytes`` and reject ``pathlib.Path`` with a TypeError.
model_defect_od_onnx_fixed = ort.InferenceSession(
    str(path_models / 'model_defect_od_crop.onnx'),
    providers=['CPUExecutionProvider']
)
model_terminal_od_onnx_fixed = ort.InferenceSession(
    str(path_models / 'model_terminal_od.onnx'),
    providers=['CPUExecutionProvider']
)

In [9]:
# Folder with the benchmark images.
path_filenames = Path('samples')
# os.listdir returns entries in arbitrary order; sorting makes filenames[i]
# refer to the same image on every run, so timings are comparable between runs.
filenames = sorted(path_filenames / x for x in os.listdir(path_filenames))
print(f'Количество семплов: {len(filenames)}')

Количество семплов: 273


In [4]:
# Model warm-up: run every model once so that one-off initialisation cost
# does not end up in the first timed iteration of the benchmarks below.
temp = model_terminal_od_onnx(filenames[0], imgsz=640, verbose=False)
temp = model_defect_od_onnx(filenames[0], imgsz=1280, verbose=False)

temp = model_terminal_od_pt(filenames[0], imgsz=640, verbose=False)
temp = model_defect_od_pt(filenames[0], imgsz=1280, verbose=False)

# The raw ONNX Runtime session is the object actually timed in the ONNX
# sections, so it needs a warm-up run as well. The input name, output name and
# tensor shape are the ones used by the benchmark loops.
temp = model_terminal_od_onnx_fixed.run(
    ["output0"], {"images": np.zeros((1, 3, 640, 640), dtype=np.float32)}
)
del temp

Loading models\model_terminal_od.onnx for ONNX Runtime inference...
Loading models\model_defect_od_crop.onnx for ONNX Runtime inference...


# Options comparison

In [156]:
# Number of timed iterations per variant. Every benchmark loop reads
# filenames[i], so the count is capped by the number of available samples
# to avoid an IndexError on smaller sample folders.
n_iters = min(200, len(filenames))
# One column of per-iteration durations (seconds) is added per variant.
df_speed_comparison = pd.DataFrame()

## PyTorch model inference

In [182]:
# Pre-allocated per-iteration durations for the PyTorch run.
duration_pt = [0 for _ in range(n_iters)]

In [183]:
%%time
for i in tqdm(range(n_iters)):
    start = time.time()
    sample_file = filenames[i]
    model_terminal_od_pt(sample_file, imgsz=640, verbose=False)
    end = time.time()
    duration_pt[i] = end - start

100%|██████████| 200/200 [00:52<00:00,  3.83it/s]

CPU times: total: 16.1 s
Wall time: 52.2 s





In [186]:
# Record the PyTorch timings as column 'pt' and display summary statistics.
df_speed_comparison = df_speed_comparison.assign(**{'pt': duration_pt})
df_speed_comparison.describe()

Unnamed: 0,pt
count,200.0
mean,0.260425
std,0.020674
min,0.206111
25%,0.249782
50%,0.264422
75%,0.273562
max,0.315854


## ONNX model inference

In [162]:
def iou(box1,box2):
    """Return intersection-over-union of two boxes given as (x1, y1, x2, y2, ...).

    The ratio is computed as-is: a zero union is not guarded against.
    """
    overlap = intersection(box1, box2)
    combined = union(box1, box2)
    return overlap / combined

def union(box1,box2):
    """Return the area covered by either box: area1 + area2 - intersection.

    Only the first four items of each box are read, as (x1, y1, x2, y2).
    """
    ax1, ay1, ax2, ay2 = box1[:4]
    bx1, by1, bx2, by2 = box2[:4]
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    return area_a + area_b - intersection(box1, box2)

def intersection(box1,box2):
    """Return the overlap area of two axis-aligned boxes.

    Only the first four items of each box are read, as (x1, y1, x2, y2).
    Boxes that do not overlap (or only touch) yield 0.
    """
    box1_x1,box1_y1,box1_x2,box1_y2 = box1[:4]
    box2_x1,box2_y1,box2_x2,box2_y2 = box2[:4]
    x1 = max(box1_x1,box2_x1)
    y1 = max(box1_y1,box2_y1)
    x2 = min(box1_x2,box2_x2)
    y2 = min(box1_y2,box2_y2)
    # Clamp each extent at zero. Without the clamp, boxes disjoint on both axes
    # give a positive product of two negative extents, and boxes disjoint on
    # one axis give a negative "area".
    width = max(0, x2 - x1)
    height = max(0, y2 - y1)
    return width * height

In [190]:
# Pre-allocated per-iteration durations for the plain ONNX Runtime run.
duration_onnx = [0 for _ in range(n_iters)]

In [191]:
%%time
for i in tqdm(range(n_iters)):
    start = time.time()
    sample_file = filenames[i]
    img = Image.open(
        sample_file
    )
    img_width, img_height = img.size
    img_terminal = img.resize((640, 640)).convert("RGB")
    img_terminal = (np.array(img_terminal).transpose(2, 0, 1).reshape(1, 3, 640, 640)/255.0).astype(np.float32)

    outputs = model_terminal_od_onnx_fixed.run(
        ["output0"], {"images":img_terminal}
    )
    output = outputs[0][0].transpose()
    filtered_rows = output[
        output[:, 4:].max(axis=1) > 0.5
    ]
    filtered_rows = filtered_rows[
        filtered_rows[:, 4].argsort()
    ]
    x1s = (filtered_rows[:, 0] - filtered_rows[:, 2]/2) / 640 * img_width
    y1s = (filtered_rows[:, 1] - filtered_rows[:, 3]/2) / 640 * img_height
    x2s = (filtered_rows[:, 0] + filtered_rows[:, 2]/2) / 640 * img_width
    y2s = (filtered_rows[:, 1] + filtered_rows[:, 3]/2) / 640 * img_height
    boxes = np.array([x1s, y1s, x2s, y2s]).T
    result = []
    while len(boxes)>0:
        result.append(boxes[0])
        boxes = [box for box in boxes if iou(box, boxes[0])<0.7]
    end = time.time()
    duration_onnx[i] = end - start

100%|██████████| 200/200 [00:45<00:00,  4.41it/s]

CPU times: total: 8min 37s
Wall time: 45.4 s





In [192]:
# Record the plain ONNX timings as column 'onnx' and display summary statistics.
df_speed_comparison = df_speed_comparison.assign(**{'onnx': duration_onnx})
df_speed_comparison.describe()

Unnamed: 0,pt,onnx
count,200.0,200.0
mean,0.260425,0.22546
std,0.020674,0.024065
min,0.206111,0.168308
25%,0.249782,0.206248
50%,0.264422,0.228366
75%,0.273562,0.242439
max,0.315854,0.294009


## ONNX + numpy (impr)

In [None]:
# Pre-allocated per-iteration durations for the ONNX + vectorised-NMS run.
duration_onnx_impr = [0 for _ in range(n_iters)]

In [193]:
%%time
for i in tqdm(range(n_iters)):
    start = time.time()
    sample_file = filenames[i]
    img = Image.open(
        sample_file
    )
    img_width, img_height = img.size
    img_terminal = img.resize((640, 640)).convert("RGB")
    img_terminal = (np.array(img_terminal).transpose(2, 0, 1).reshape(1, 3, 640, 640)/255.0).astype(np.float32)
    outputs = model_terminal_od_onnx_fixed.run(
        ["output0"], {"images":img_terminal}
    )
    output = outputs[0][0].transpose()
    filtered_rows = output[
        output[:, 4:].max(axis=1) > 0.5
    ]
    filtered_rows = filtered_rows[
        filtered_rows[:, 4].argsort()
    ]
    x1s = (filtered_rows[:, 0] - filtered_rows[:, 2]/2) / 640 * img_width
    y1s = (filtered_rows[:, 1] - filtered_rows[:, 3]/2) / 640 * img_height
    x2s = (filtered_rows[:, 0] + filtered_rows[:, 2]/2) / 640 * img_width
    y2s = (filtered_rows[:, 1] + filtered_rows[:, 3]/2) / 640 * img_height
    boxes = np.array([x1s, y1s, x2s, y2s]).T
    
    result = []
    while len(boxes)>0:
        result.append(boxes[0])
        box2_x1, box2_y1, box2_x2, box2_y2 = boxes[0]
        box1_x1 = boxes[:, 0]
        box1_y1 = boxes[:, 1]
        box1_x2 = boxes[:, 2]
        box1_y2 = boxes[:, 3]
        x1 = np.maximum(box1_x1, box2_x1)
        y1 = np.maximum(box1_y1, box2_y1)
        x2 = np.minimum(box1_x2, box2_x2)
        y2 = np.minimum(box1_y2, box2_y2)
        intersec = (x2-x1)*(y2-y1)
        box1_area = (box1_x2-box1_x1)*(box1_y2-box1_y1)
        box2_area = (box2_x2-box2_x1)*(box2_y2-box2_y1)
        uni = box1_area + box2_area - intersec
        metric = intersec / uni
        boxes = boxes[metric<0.7]
    end = time.time()
    duration_onnx_impr[i] = end - start

100%|██████████| 200/200 [00:46<00:00,  4.30it/s]

CPU times: total: 8min 45s
Wall time: 46.5 s





In [194]:
# Record the vectorised-NMS timings as column 'onnx_impr' and display summary statistics.
df_speed_comparison = df_speed_comparison.assign(**{'onnx_impr': duration_onnx_impr})
df_speed_comparison.describe()

Unnamed: 0,pt,onnx,onnx_impr
count,200.0,200.0,200.0
mean,0.260425,0.22546,0.231182
std,0.020674,0.024065,0.027059
min,0.206111,0.168308,0.16905
25%,0.249782,0.206248,0.211434
50%,0.264422,0.228366,0.233389
75%,0.273562,0.242439,0.250211
max,0.315854,0.294009,0.300204


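Стандартный инференс через ONNX оказался самым быстрым: среднее 0.225 с против 0.231 с у варианта с векторизованным NMS и 0.260 с у PyTorch. Попробуем ускорить функции обработки на numpy с помощью JAX.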

## ONNX + jax

In [195]:
import jax.numpy as jnp
from jax import jit

@jit
def get_boxes(filtered_rows, img_width, img_height):
    """Convert centre/size rows to corner boxes scaled to the image size.

    Columns 0-3 of ``filtered_rows`` are read as centre-x, centre-y, width,
    height in a 640-pixel coordinate frame. Returns an (N, 4) array of
    (x1, y1, x2, y2) rescaled by ``img_width`` / ``img_height``.
    """
    centre_x = filtered_rows[:, 0]
    centre_y = filtered_rows[:, 1]
    half_w = filtered_rows[:, 2]/2
    half_h = filtered_rows[:, 3]/2
    left = (centre_x - half_w) / 640 * img_width
    top = (centre_y - half_h) / 640 * img_height
    right = (centre_x + half_w) / 640 * img_width
    bottom = (centre_y + half_h) / 640 * img_height
    return jnp.stack((left, top, right, bottom), axis=1)

@jit
def convert_pil_tonumpy(img):
    """Turn a 640x640x3 (HWC) image array into a 1x3x640x640 float32 tensor in [0, 1]."""
    channels_first = jnp.transpose(img, (2, 0, 1))
    batched = channels_first.reshape(1, 3, 640, 640)
    scaled = batched/255.0
    return scaled.astype(np.float32)

In [196]:
# Pre-allocated per-iteration durations for the ONNX + JAX box-conversion run.
duration_onnx_jax = [0 for _ in range(n_iters)]

In [197]:
%%time
for i in tqdm(range(n_iters)):
    start = time.time()
    sample_file = filenames[i]
    img = Image.open(
        sample_file
    )
    img_width, img_height = img.size
    img_terminal = img.resize((640, 640)).convert("RGB")
    
    img_terminal = (np.array(img_terminal).transpose(2, 0, 1).reshape(1, 3, 640, 640)/255.0).astype(np.float32)
    outputs = model_terminal_od_onnx_fixed.run(
        ["output0"], {"images":img_terminal}
    )
    output = outputs[0][0].transpose()
    filtered_rows = output[
        output[:, 4:].max(axis=1) > 0.5
    ]
    filtered_rows = filtered_rows[
        filtered_rows[:, 4].argsort()
    ]
    boxes = get_boxes(filtered_rows, img_width, img_height)
    result = []
    while len(boxes)>0:
        result.append(boxes[0])
        boxes = [box for box in boxes if iou(box, boxes[0])<0.7]
    end = time.time()
    duration_onnx_jax[i] = end - start

100%|██████████| 200/200 [00:47<00:00,  4.18it/s]

CPU times: total: 8min 55s
Wall time: 47.8 s





In [198]:
# Record the JAX box-conversion timings as column 'onnx_jax' and display summary statistics.
df_speed_comparison = df_speed_comparison.assign(**{'onnx_jax': duration_onnx_jax})
df_speed_comparison.describe()

Unnamed: 0,pt,onnx,onnx_impr,onnx_jax
count,200.0,200.0,200.0,200.0
mean,0.260425,0.22546,0.231182,0.238068
std,0.020674,0.024065,0.027059,0.026977
min,0.206111,0.168308,0.16905,0.173427
25%,0.249782,0.206248,0.211434,0.216849
50%,0.264422,0.228366,0.233389,0.235754
75%,0.273562,0.242439,0.250211,0.257273
max,0.315854,0.294009,0.300204,0.316618


## ONNX + jax + jax image processing

In [200]:
# Pre-allocated per-iteration durations for the ONNX + JAX pre/post-processing run.
duration_onnx_jax_img_preproc = [0 for _ in range(n_iters)]

In [201]:
%%time
for i in tqdm(range(n_iters)):
    start = time.time()
    sample_file = filenames[i]
    img = Image.open(
        sample_file
    )
    img_width, img_height = img.size
    img_terminal = img.resize((640, 640)).convert("RGB")
    img_terminal = jnp.array(img_terminal)
    img_terminal = convert_pil_tonumpy(img_terminal)
    outputs = model_terminal_od_onnx_fixed.run(
        ["output0"], {"images":np.array(img_terminal)}
    )
    output = outputs[0][0].transpose()
    filtered_rows = output[
        output[:, 4:].max(axis=1) > 0.5
    ]
    filtered_rows = filtered_rows[
        filtered_rows[:, 4].argsort()
    ]
    boxes = get_boxes(filtered_rows, img_width, img_height)
    result = []
    while len(boxes)>0:
        result.append(boxes[0])
        boxes = [box for box in boxes if iou(box, boxes[0])<0.7]
    end = time.time()
    duration_onnx_jax_img_preproc[i] = end - start

100%|██████████| 200/200 [00:44<00:00,  4.48it/s]

CPU times: total: 8min 27s
Wall time: 44.6 s





In [202]:
# Record the JAX pre/post-processing timings and display summary statistics.
# NOTE(review): the column label is spelled 'prepoc' (not 'preproc'); it is kept
# as-is because the saved output tables use this exact label.
df_speed_comparison = df_speed_comparison.assign(
    **{'onnx_jax_img_prepoc': duration_onnx_jax_img_preproc}
)
df_speed_comparison.describe()

Unnamed: 0,pt,onnx,onnx_impr,onnx_jax,onnx_jax_img_prepoc
count,200.0,200.0,200.0,200.0,200.0
mean,0.260425,0.22546,0.231182,0.238068,0.22213
std,0.020674,0.024065,0.027059,0.026977,0.023743
min,0.206111,0.168308,0.16905,0.173427,0.169524
25%,0.249782,0.206248,0.211434,0.216849,0.200513
50%,0.264422,0.228366,0.233389,0.235754,0.217443
75%,0.273562,0.242439,0.250211,0.257273,0.237611
max,0.315854,0.294009,0.300204,0.316618,0.29061


Jaxlib не ставится на сигму :(

## Numba

In [203]:
# Pre-allocated per-iteration durations for the Numba run.
# NOTE(review): this list shares its name with the `numba` package; only
# `njit` is imported from that package in this notebook, so nothing is shadowed.
numba = [0 for _ in range(n_iters)]

In [204]:
from numba import njit

@njit
def get_boxes(filtered_rows, img_width, img_height):
    """Convert centre/size rows to corner coordinates scaled to the image size.

    Columns 0-3 of ``filtered_rows`` are read as centre-x, centre-y, width,
    height in a 640-pixel coordinate frame. Returns the four coordinate arrays
    (x1s, y1s, x2s, y2s) as a tuple.

    Note: this redefines the JAX ``get_boxes`` declared earlier in the
    notebook, which returned a single (N, 4) array instead of a tuple.
    """
    centre_x = filtered_rows[:, 0]
    centre_y = filtered_rows[:, 1]
    half_w = filtered_rows[:, 2]/2
    half_h = filtered_rows[:, 3]/2
    x1s = (centre_x - half_w) / 640 * img_width
    y1s = (centre_y - half_h) / 640 * img_height
    x2s = (centre_x + half_w) / 640 * img_width
    y2s = (centre_y + half_h) / 640 * img_height
    return x1s, y1s, x2s, y2s

@njit
def convert_pil_tonumpy(img):
    """Scale pixel values by 1/255 and cast the result to float32.

    Note: this redefines the JAX ``convert_pil_tonumpy`` declared earlier in
    the notebook. Unlike that version it does not transpose or reshape, so the
    caller must pass the array already laid out as the model expects.
    """
    scaled = img/255.0
    return scaled.astype(np.float32)

In [205]:
%%time
for i in tqdm(range(n_iters)):
    start = time.time()
    sample_file = filenames[i]
    img = Image.open(
        sample_file
    )
    img_width, img_height = img.size
    img_terminal = img.resize((640, 640)).convert("RGB")
    img_terminal = np.array(img_terminal).transpose(2, 0, 1).reshape(1, 3, 640, 640)
    img_terminal = convert_pil_tonumpy(img_terminal)
    outputs = model_terminal_od_onnx_fixed.run(
        ["output0"], {"images": img_terminal}
    )
    output = outputs[0][0].transpose()
    filtered_rows = output[
        output[:, 4:].max(axis=1) > 0.5
    ]
    filtered_rows = filtered_rows[
        filtered_rows[:, 4].argsort()
    ]
    x1s, y1s, x2s, y2s = get_boxes(filtered_rows, img_width, img_height)
    boxes = np.array([x1s, y1s, x2s, y2s]).T
    result = []
    while len(boxes)>0:
        result.append(boxes[0])
        boxes = [box for box in boxes if iou(box, boxes[0])<0.7]
    end = time.time()
    numba[i] = end - start

100%|██████████| 200/200 [00:47<00:00,  4.21it/s]

CPU times: total: 8min 53s
Wall time: 47.5 s





In [206]:
# Record the Numba timings as column 'numba' and display summary statistics.
df_speed_comparison = df_speed_comparison.assign(**{'numba': numba})
df_speed_comparison.describe()

Unnamed: 0,pt,onnx,onnx_impr,onnx_jax,onnx_jax_img_prepoc,numba
count,200.0,200.0,200.0,200.0,200.0,200.0
mean,0.260425,0.22546,0.231182,0.238068,0.22213,0.235822
std,0.020674,0.024065,0.027059,0.026977,0.023743,0.038839
min,0.206111,0.168308,0.16905,0.173427,0.169524,0.166474
25%,0.249782,0.206248,0.211434,0.216849,0.200513,0.216788
50%,0.264422,0.228366,0.233389,0.235754,0.217443,0.233446
75%,0.273562,0.242439,0.250211,0.257273,0.237611,0.250046
max,0.315854,0.294009,0.300204,0.316618,0.29061,0.66793


## Numba + Numba IOU

In [207]:
# Pre-allocated per-iteration durations for the Numba-compiled IoU run.
numba_iou = [0 for _ in range(n_iters)]

In [208]:
@njit
def get_metric(boxes):
    """Return the IoU of ``boxes[0]`` against every row of ``boxes``.

    ``boxes`` is a 2-D array whose columns 0-3 are read as x1, y1, x2, y2.
    The result has one value per row; entry 0 is the first box against itself.
    A zero union is not guarded against.
    """
    box2_x1, box2_y1, box2_x2, box2_y2 = boxes[0]
    box1_x1 = boxes[:, 0]
    box1_y1 = boxes[:, 1]
    box1_x2 = boxes[:, 2]
    box1_y2 = boxes[:, 3]
    x1 = np.maximum(box1_x1, box2_x1)
    y1 = np.maximum(box1_y1, box2_y1)
    x2 = np.minimum(box1_x2, box2_x2)
    y2 = np.minimum(box1_y2, box2_y2)
    # Clamp the overlap extents at zero: for disjoint boxes the raw product is
    # either a positive value (both extents negative) or a negative "area".
    inter_w = np.maximum(x2-x1, 0.0)
    inter_h = np.maximum(y2-y1, 0.0)
    intersec = inter_w*inter_h
    box1_area = (box1_x2-box1_x1)*(box1_y2-box1_y1)
    box2_area = (box2_x2-box2_x1)*(box2_y2-box2_y1)
    uni = box1_area + box2_area - intersec
    metric = intersec / uni
    return metric

In [209]:
%%time
for i in tqdm(range(n_iters)):
    start = time.time()
    sample_file = filenames[i]
    img = Image.open(
        sample_file
    )
    img_width, img_height = img.size
    img_terminal = img.resize((640, 640)).convert("RGB")
    img_terminal = (np.array(img_terminal).transpose(2, 0, 1).reshape(1, 3, 640, 640)/255.0).astype(np.float32)
    outputs = model_terminal_od_onnx_fixed.run(
        ["output0"], {"images":img_terminal}
    )
    output = outputs[0][0].transpose()
    filtered_rows = output[
        output[:, 4:].max(axis=1) > 0.5
    ]
    filtered_rows = filtered_rows[
        filtered_rows[:, 4].argsort()
    ]
    x1s = (filtered_rows[:, 0] - filtered_rows[:, 2]/2) / 640 * img_width
    y1s = (filtered_rows[:, 1] - filtered_rows[:, 3]/2) / 640 * img_height
    x2s = (filtered_rows[:, 0] + filtered_rows[:, 2]/2) / 640 * img_width
    y2s = (filtered_rows[:, 1] + filtered_rows[:, 3]/2) / 640 * img_height
    boxes = np.array([x1s, y1s, x2s, y2s]).T
    result = []
    while len(boxes)>0:
        result.append(boxes[0])
        metric = get_metric(boxes)
        boxes = boxes[metric<0.7]
    end = time.time()
    numba_iou[i] = end - start

100%|██████████| 200/200 [00:47<00:00,  4.19it/s]

CPU times: total: 9min 1s
Wall time: 47.7 s





In [210]:
# Record the Numba-IoU timings as column 'numba_iou' and display summary statistics.
df_speed_comparison = df_speed_comparison.assign(**{'numba_iou': numba_iou})
df_speed_comparison.describe()

Unnamed: 0,pt,onnx,onnx_impr,onnx_jax,onnx_jax_img_prepoc,numba,numba_iou
count,200.0,200.0,200.0,200.0,200.0,200.0,200.0
mean,0.260425,0.22546,0.231182,0.238068,0.22213,0.235822,0.237853
std,0.020674,0.024065,0.027059,0.026977,0.023743,0.038839,0.049245
min,0.206111,0.168308,0.16905,0.173427,0.169524,0.166474,0.166616
25%,0.249782,0.206248,0.211434,0.216849,0.200513,0.216788,0.213131
50%,0.264422,0.228366,0.233389,0.235754,0.217443,0.233446,0.23553
75%,0.273562,0.242439,0.250211,0.257273,0.237611,0.250046,0.254107
max,0.315854,0.294009,0.300204,0.316618,0.29061,0.66793,0.659166


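Стало хуже: среднее время варианта Numba + Numba IOU (0.2379 с) выше, чем у варианта Numba (0.2358 с). По минимальному времени лучший результат показывает вариант Numba (0.1665 с), однако по среднему и медиане самым быстрым остаётся onnx_jax_img_prepoc (0.2221 с и 0.2174 с).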

Добавим обработку картинки через Numba

In [213]:
# from PIL import ImageDraw

# img = Image.open(sample_file)
# draw = ImageDraw.Draw(img)

# for box in result:
#     x1,y1,x2,y2 = box
#     draw.rectangle((x1,y1,x2,y2),None,"#00ff00")
# img