In [1]:
import cv2
from ultralytics import YOLO

In [2]:
# Load the custom-trained YOLO detector (weights path is relative to the notebook's working directory).
model = YOLO("DetectionWeights.pt")

In [3]:
# Show the class-id -> label map; the ids passed to get_data() must match these entries.
model.names

{0: 'center', 1: 'hole', 2: 'screw'}

In [159]:
def check_similar(given, trial_data, threshold=150):
    """Return True if `given` lies close to any entry of `trial_data`.

    "Close" means both the x and the y coordinate differ by strictly less
    than `threshold` pixels, i.e. a square neighbourhood rather than a
    Euclidean distance. Only the first two items of each sequence are read,
    so ``[x, y, w, h]`` boxes (as built in ``get_data``) can be passed as-is.

    Args:
        given: Sequence whose first two items are the x and y to test.
        trial_data: Iterable of sequences whose first two items are x and y.
        threshold: Exclusive per-axis distance under which two points count
            as similar. Defaults to 150, the value previously hard-coded.

    Returns:
        True as soon as one entry of `trial_data` is within `threshold` on
        both axes; False otherwise (including when `trial_data` is empty).
    """
    x, y = given[0], given[1]
    return any(
        abs(x - trial[0]) < threshold and abs(y - trial[1]) < threshold
        for trial in trial_data
    )


def live_detection(src, model, conf=0.65, window_name="YOLO Inference"):
    """Display live detections from a video source until it ends or 'q' is pressed.

    Args:
        src: Anything ``cv2.VideoCapture`` accepts, e.g. a camera index such
            as ``0`` or a video file path.
        model: Detector called as ``model(frame, conf=conf)``; element ``[0]``
            of the result must provide ``.plot()`` returning a displayable
            image (as ultralytics ``YOLO`` results do).
        conf: Minimum confidence forwarded to the model. Defaults to 0.65.
        window_name: Title of the OpenCV preview window.

    The capture is released and all OpenCV windows are destroyed even if
    inference or display raises, so the camera is not left locked.
    """
    cap = cv2.VideoCapture(src)
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break  # end of stream or read failure

            annotated = model(frame, conf=conf)[0].plot()
            cv2.imshow(window_name, annotated)

            # waitKey also runs HighGUI's event processing, so the window repaints.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()


def make_image(center, holes, screws, frame):
    """Draw the detected center, hole and screw boxes onto `frame`.

    Colours are given in OpenCV's BGR order: center green, holes blue,
    screws red. Each group is paired with its colour explicitly, so equal
    group contents (e.g. ``holes == screws``) can no longer pick up another
    group's colour.

    Args:
        center: One box ``[x, y, w, h]``, where x, y is the box centre.
        holes: List of boxes in the same format.
        screws: List of boxes in the same format.
        frame: Image drawn on in place; pass a copy to keep the original.

    Returns:
        The same `frame` object, with 2-px rectangles drawn on it.
    """
    groups = (
        ([center], (0, 255, 0)),
        (holes, (255, 0, 0)),
        (screws, (0, 0, 255)),
    )
    for boxes, color in groups:
        for x, y, w, h in boxes:
            # Convert centre/size to the corner points cv2.rectangle expects.
            top_left = (x - w // 2, y - h // 2)
            bottom_right = (x + w // 2, y + h // 2)
            cv2.rectangle(frame, top_left, bottom_right, color, 2)

    return frame


def get_data(center_id, hole_id, screw_id, src, model, conf=0.65, output_path="Result.jpg"):
    """Scan frames until one shows a center, at least one hole and a free screw.

    Each frame is run through `model` and its boxes are grouped by class id.
    Screw boxes that ``check_similar`` places near any hole box are dropped.
    The first frame that still has a center, >=1 hole and >=1 screw is
    annotated with ``make_image``, written to `output_path`, and its boxes
    are returned.

    Args:
        center_id: Class id of the "center" label in ``model.names``.
        hole_id: Class id of the "hole" label.
        screw_id: Class id of the "screw" label.
        src: Source passed to ``cv2.VideoCapture`` (e.g. camera index 0).
        model: Detector called as ``model(frame, conf=conf)``; assumes
            ultralytics-style results whose ``[0].boxes`` expose ``xywh`` and
            ``cls``.
        conf: Minimum detection confidence. Defaults to 0.65.
        output_path: File the annotated frame is written to.

    Returns:
        ``(center, holes, screws)``, where center is an int ``[x, y, w, h]``
        box (x, y = box centre) and holes/screws are lists of such boxes.
        Returns ``None`` if the stream ends or 'q' is registered first.

    The capture is released and OpenCV windows destroyed on every exit path,
    including the ``None`` path and exceptions.
    """
    cap = cv2.VideoCapture(src)
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break  # end of stream or read failure

            result = model(frame, conf=conf)[0]

            center_data = None
            hole_data = []
            screw_data = []
            for box in result.boxes:
                # Take this box's single xywh row and truncate to int pixels.
                data = [int(v) for v in box.xywh.tolist()[0]]
                item_class = box.cls[0].item()

                if item_class == center_id:
                    center_data = data  # if several centers are detected, the last one wins
                elif item_class == hole_id:
                    hole_data.append(data)
                elif item_class == screw_id:
                    screw_data.append(data)

            # Discard screw boxes lying near a hole (presumably overlapping mis-detections).
            screw_data = [s for s in screw_data if not check_similar(s, hole_data)]

            # NOTE(review): per the OpenCV docs, waitKey only reports keys while a HighGUI
            # window exists; this function opens none, so 'q' is normally not seen here.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            if center_data is not None and hole_data and screw_data:
                img = make_image(center_data, hole_data, screw_data, frame.copy())
                cv2.imwrite(output_path, img)
                return center_data, hole_data, screw_data
    finally:
        cap.release()
        cv2.destroyAllWindows()

    return None

In [162]:
# Resolve class ids from the model's own label map instead of hard-coding 0/1/2,
# so this call stays correct if the weights are retrained with a different class order.
class_ids = {name: idx for idx, name in model.names.items()}
# Source 0 = first camera device; print so a None result (no qualifying frame) is visible too.
print(get_data(class_ids["center"], class_ids["hole"], class_ids["screw"], 0, model))


0: 480x640 1 center, 3 holes, 3 screws, 46.0ms
Speed: 2.0ms preprocess, 46.0ms inference, 1.0ms postprocess per image at shape (1, 3, 480, 640)
([250, 266, 20, 24], [[174, 308, 14, 16], [108, 239, 15, 16], [176, 241, 14, 16]], [[404, 189, 13, 15], [401, 236, 12, 15]])
