In [211]:
import cv2
import torch
import numpy as np

from ultralytics.nn.tasks import DetectionModel

## Chargement du modèle `best.pt`

## Définition des classes et du choix de la classe

In [212]:
# Label of the brick class the user wants to search for; the current value is one
# of the entries of the `classes` list defined in this notebook.
# NOTE(review): not referenced by any function in the visible notebook — presumably
# meant to drive the class filter in detection; confirm before relying on it.
classe_choisie = "2x4_Jaune"

In [213]:
# Class labels known to the notebook. Each label appears to combine a brick size
# and a colour (e.g. "2x4_Jaune").
# NOTE(review): presumably listed in the same order as the model's class indices,
# so that classes[i] is the label of class id i — TODO confirm against the training config.
classes = [
    "1x2_Blanc",
    "1x2_Bleu",
    "1x2_Jaune",
    "1x2_Marron",
    "1x2_Noir",
    "1x2_Rouge",
    "1x2_Vert clear",
    "1x2_Vert dark",
    "1x4_Blanc",
    "1x4_Jaune",
    "1x4_Noir",
    "1x4_Rouge",
    "1x4_Vert clear",
    "1x4_Vert dark",
    "2x2_Blanc",
    "2x2_Bleu",
    "2x2_Jaune",
    "2x2_Marron",
    "2x2_Rouge",
    "2x2_Vert clear",
    "2x2_Vert dark",
    "2x4_Blanc",
    "2x4_Bleu",
    "2x4_Jaune",
    "2x4_Rouge",
    "2x4_Vert dark"
]

## Fonctions de traitement pour la vidéo

In [214]:
# Initial configuration: input video, detection threshold, model weights and output file.
VIDEO_SOURCE = "/home/dim/clone_repo/BrickSearch/data/raw/videos/lego_video_test.mp4"
# Minimum confidence passed to the model when running detection.
CONFIDENCE_THRESHOLD = 0.8
MODEL_NAME = "y12x_100_640_3000_16_0_065"
MODEL_PATH = f"/home/dim/clone_repo/BrickSearch/outputs/output_models/{MODEL_NAME}/weights/best.pt"
# The output is named after the source file itself (last path component), so the
# result path stays valid even when VIDEO_SOURCE points to a different directory.
OUTPUT_FILE = "/home/dim/clone_repo/BrickSearch/outputs/video_annotees/result_" + VIDEO_SOURCE.rsplit("/", 1)[-1]

In [215]:
def initialize_models():
    """
    Load the trained YOLO model from the checkpoint stored at MODEL_PATH.

    The checkpoint is mapped to the GPU when CUDA is available and to the CPU
    otherwise, so the notebook also runs on machines without a GPU.

    Returns:
        The object stored under the checkpoint's 'model' key, switched to
        evaluation mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=device)

    model = checkpoint['model']
    if device.type == 'cpu':
        # Presumably the checkpoint may be stored in half precision, which many
        # CPU kernels do not support — TODO confirm. Casting to float32 is a
        # no-op for a model that is already float32.
        model = model.float()
    model.eval()

    return model

In [216]:
def initialize_video_capture_and_writer():
    """
    Open the input video and create the writer for the output video.

    The writer uses the 'mp4v' codec and the same frame size as the source.
    The source frame rate is reused; when the capture does not report a
    usable value (zero, negative or NaN), 30 FPS is used instead.

    Returns:
        cap: cv2.VideoCapture opened on VIDEO_SOURCE.
        video_writer: cv2.VideoWriter targeting OUTPUT_FILE.

    Raises:
        IOError: If VIDEO_SOURCE cannot be opened.
    """
    fallback_fps = 30.0

    cap = cv2.VideoCapture(VIDEO_SOURCE)
    if not cap.isOpened():
        # Fail fast: an unopened capture reports a 0x0 frame size, which would
        # only produce an unusable writer and a silent no-op downstream.
        cap.release()
        raise IOError(f"Cannot open video source: {VIDEO_SOURCE}")

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # also true for NaN
        fps = fallback_fps

    fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
    video_writer = cv2.VideoWriter(OUTPUT_FILE, fourcc, fps, (frame_width, frame_height))

    return cap, video_writer

In [217]:
def detect_objects(model, frame, target_class_id=0):
    """
    Detect objects in a frame with the YOLO model and keep one class.

    Args:
        model: YOLO model, called as model(frame, conf=..., verbose=False).
        frame: Video frame.
        target_class_id: Class index to keep. Defaults to 0, the filter this
            function has always applied. Pass None to keep every class.

    Returns:
        detections_list: One entry per kept detection, in the form
            [[left, top, width, height], confidence, class_id].
    """
    results = model(frame, conf=CONFIDENCE_THRESHOLD, verbose=False)[0]

    # Convert all detections in a single transfer instead of one per row.
    rows = results.boxes.data.cpu().numpy()

    detections_list = []
    for x1, y1, x2, y2, conf, cls in rows:
        class_id = int(cls)
        if target_class_id is not None and class_id != target_class_id:
            continue
        # Boxes are converted from corner format to left/top/width/height.
        detections_list.append([[x1, y1, x2 - x1, y2 - y1], conf, class_id])
    return detections_list

In [218]:
def update_tracks(tracker, detections_list, frame):
    """
    Feed the detections to the tracker and collect the confirmed tracks.

    Args:
        tracker: Tracker exposing update_tracks(detections, frame=...)
            (DeepSort in this notebook).
        detections_list: Detections for the current frame.
        frame: Video frame the detections belong to.

    Returns:
        track_boxes: Boxes of the confirmed tracks, as returned by to_ltrb().
        track_ids: IDs of the confirmed tracks, in the same order as track_boxes.
    """
    # Tracks that are not confirmed yet are skipped.
    confirmed = [
        (candidate.track_id, candidate.to_ltrb())
        for candidate in tracker.update_tracks(detections_list, frame=frame)
        if candidate.is_confirmed()
    ]
    track_boxes = [box for _, box in confirmed]
    track_ids = [ident for ident, _ in confirmed]
    return track_boxes, track_ids

In [219]:
def annotate_frame(frame, track_boxes, track_ids, annotator):
    """
    Draw the tracked boxes and their IDs on a copy of the frame.

    Args:
        frame: Video frame; it is copied before anything is drawn.
        track_boxes: Boxes of the tracked objects (left, top, right, bottom).
        track_ids: Track IDs, in the same order as track_boxes.
        annotator: Box annotator exposing annotate(scene=..., detections=...).

    Returns:
        annotated_frame: The annotated copy, or a plain copy of the frame when
            there is nothing to draw.
    """
    canvas = frame.copy()
    if not track_boxes:
        return canvas

    # NOTE(review): `Detections` is not imported in the visible part of this
    # notebook — confirm it is in scope before this path is exercised.
    tracked = Detections(
        xyxy=np.array(track_boxes),
        class_id=np.array(track_ids, dtype=int),
    )
    canvas = annotator.annotate(scene=canvas, detections=tracked)

    # Write each track ID slightly above the top-left corner of its box.
    for box, ident in zip(track_boxes, track_ids):
        left, top, _right, _bottom = (int(coord) for coord in box)
        cv2.putText(
            canvas,
            f"ID {ident}",
            (left, top - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 255, 0),
            1,
        )
    return canvas

In [220]:
def process_video():
    """
    Play VIDEO_SOURCE in a preview window and save its frames to OUTPUT_FILE.

    Frames are read one by one, written to the output video and displayed in
    a resizable window until the video ends or the user presses 'q'. The
    capture, the writer and the window are released even when an error
    interrupts the loop.

    NOTE: detection, tracking and annotation are currently disabled, so the
    frames are displayed and saved unmodified.
    """
    # Display window settings.
    WINDOW_NAME = 'YOLOv8 Detection'
    WINDOW_WIDTH = 800   # Desired window width
    WINDOW_HEIGHT = 600  # Desired window height
    WINDOW_POS_X = 100   # Horizontal position on screen
    WINDOW_POS_Y = 100   # Vertical position on screen

    cap, video_writer = initialize_video_capture_and_writer()

    try:
        # Create and place the window once; doing it for every frame is wasted
        # work and keeps snapping the window back to its initial position.
        cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(WINDOW_NAME, WINDOW_WIDTH, WINDOW_HEIGHT)
        cv2.moveWindow(WINDOW_NAME, WINDOW_POS_X, WINDOW_POS_Y)

        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            # TODO: re-enable once the model / tracker / annotator setup is restored:
            #   detections_list = detect_objects(yolo_model, frame)
            #   track_boxes, track_ids = update_tracks(tracker, detections_list, frame)
            #   frame = annotate_frame(frame, track_boxes, track_ids, annotator)

            # Persist the frame; without this the output file stays empty.
            video_writer.write(frame)
            cv2.imshow(WINDOW_NAME, frame)

            # Pressing 'q' stops the playback early.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        video_writer.release()
        cv2.destroyAllWindows()

In [221]:
# Run the video loop on VIDEO_SOURCE; press 'q' in the display window to stop early.
process_video()