In [1]:
import os
import cv2
# import uuid
from functions import *
from ultralytics import YOLO
from paddleocr import PaddleOCR
from collections import defaultdict

In [2]:
# Root directory of the project. Set the VEHICLE_PROJECT_DIR environment variable
# to run the notebook from another checkout or machine; the default keeps the
# original location, so every path below is unchanged unless overridden.
PROJECT_DIR = os.environ.get("VEHICLE_PROJECT_DIR", "/home/namtt/WorkSpace/MyProjects/Vehicle")

# Model weights and tracker configuration.
weight_vehicle_detection_path = os.path.join(PROJECT_DIR, "weights", "vehicle_detection.pt")
tracker_path = os.path.join(PROJECT_DIR, "weights", "tracker", "botsort.yaml")
weight_license_plate_detection_path = os.path.join(PROJECT_DIR, "weights", "license_plate_detection.pt")

# Input video and inference device.
video_path = os.path.join(PROJECT_DIR, "data", "IMG_4527.MOV")
device = "cpu"

# Human-readable class names.
# NOTE(review): presumably listed in the vehicle detector's class-id order -- confirm against the model.
class_names = ["car", "motor", "bus", "truck"]

# Drawing colour per class; the processing loop indexes this list with the
# detected class id. Tuples are passed straight to OpenCV drawing calls,
# which interpret them as (B, G, R).
colors = [
    (0, 0, 255),
    (0, 255, 0),
    (255, 0, 0),
    (0, 255, 255),
]

# Load both YOLO detectors from their weight files, then put each one into
# evaluation mode.
model_vehicle = YOLO(weight_vehicle_detection_path)
model_license_plate = YOLO(weight_license_plate_detection_path)
model_vehicle.eval()
model_license_plate.eval()

# OCR engine. All three optional flags are switched off.
# NOTE(review): judging by their names these disable document-orientation
# classification, document unwarping and text-line orientation -- confirm
# against the PaddleOCR documentation.
ocr_options = {
    "use_doc_orientation_classify": False,
    "use_doc_unwarping": False,
    "use_textline_orientation": False,
}
ocr_model = PaddleOCR(**ocr_options)

[32mCreating model: ('PP-OCRv5_server_det', None)[0m
[32mModel files already exist. Using cached files. To redownload, please delete the directory manually: `/home/namtt/.paddlex/official_models/PP-OCRv5_server_det`.[0m
[32mCreating model: ('PP-OCRv5_server_rec', None)[0m
[32mModel files already exist. Using cached files. To redownload, please delete the directory manually: `/home/namtt/.paddlex/official_models/PP-OCRv5_server_rec`.[0m


In [None]:
# Open the input video and read the properties needed to size the output.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail fast: without this check a wrong path only shows up later as an
    # empty output file.
    raise IOError(f"Cannot open video: {video_path}")

fps = cap.get(cv2.CAP_PROP_FPS)
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
scale_plot = 0.4     # only referenced by the (commented-out) area-selection cell
scale_process = 0.4  # frames are resized by this factor before detection and writing

# Frame sampling: process one frame every `skip` seconds of video, i.e. every
# `skip_frame`-th frame (never less than 1).
skip = 0.2
skip_frame = max(1, int(fps*skip))
frame_count = 0

output_filename = 'output.mp4'
fourcc = cv2.VideoWriter_fourcc(*'mp4v')

# The processing loop writes only every `skip_frame`-th frame, so the writer's
# frame rate is reduced by the same factor; writing at the source fps would make
# the output play `skip_frame` times faster than the input.
out = cv2.VideoWriter(output_filename, fourcc, fps / skip_frame, (int(w*scale_process), int(h*scale_process)))

In [None]:
# create_data = SelectAreaDetect(video_path, scale_plot)
# create_data.create_txt()

: 

In [None]:
# Per-frame pipeline: track vehicles -> detect a licence plate inside each vehicle
# crop -> OCR the plates -> draw boxes/labels -> write the annotated frame.
#
# `cap`, `out`, `frame_count`, `skip_frame`, `w`, `h` and `scale_process` come from
# the video-setup cell. The capture and writer are released at the end of this
# cell, so re-run the setup cell before running this one again.

# Best OCR reading seen so far for every track id: {track_id: {"label": str, "conf": float}}.
lass_detect = {}


def _annotate_vehicle(frame, bbox, color, text):
    """Draw the vehicle's bounding box and a text label just above its bottom edge."""
    x1, y1, x2, y2 = bbox
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    cv2.putText(frame, f"{text}", (x1, y2-3), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)


def _extract_plate(vehicle_crop, lp_result):
    """Return the pre-processed 3-channel plate image for one vehicle crop.

    Returns None when the plate detector found no box, the box is empty, or
    pre-processing fails, so the caller can fall back to the cached label.
    """
    boxes = lp_result.boxes
    if boxes is None or len(boxes) == 0:
        return None
    # Only the first detected plate box is used.
    lx1, ly1, lx2, ly2 = boxes[0].xyxy[0].cpu().numpy().astype(int)
    plate = vehicle_crop[ly1:ly2, lx1:lx2]
    if plate.size == 0:
        return None
    try:
        # NOTE(review): process_img comes from `functions`; it presumably returns a
        # single-channel image, since the result is stacked into three channels.
        plate = process_img(plate)
        return cv2.merge([plate, plate, plate])
    except Exception as exc:
        # Best effort: one bad plate must not abort the whole video.
        print(f"    License plate pre-processing failed: {exc}")
        return None


try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        # Only every `skip_frame`-th frame is processed and written.
        if frame_count % skip_frame != 0:
            continue

        frame = cv2.resize(frame, (int(w*scale_process), int(h*scale_process)))
        print(f"Tracking frame {frame_count}.")
        vehicle_result = model_vehicle.track(
            frame,
            conf=0.75,
            iou=0.7,
            tracker=tracker_path,
            device=device,
            persist=True,
            verbose=False
        )[0]

        # Collect one crop plus metadata per vehicle that has a track id.
        vehicle_crops = []
        vehicle_data = []
        for bbox in vehicle_result.boxes:
            if bbox.id is None:
                continue
            x1, y1, x2, y2 = bbox.xyxy[0].cpu().numpy().astype(int)
            crop = frame[y1:y2, x1:x2]
            if crop.size == 0:
                # Degenerate box: nothing to run plate detection on.
                continue
            id_ = int(bbox.id[0].cpu().numpy().astype(int))
            cls = int(bbox.cls[0].cpu().numpy().astype(int))

            vehicle_crops.append(crop)
            vehicle_data.append({
                "id": id_,
                "cls": cls,
                "bbox": (x1, y1, x2, y2),
                # Wrap around so a class id beyond the palette cannot raise IndexError.
                "color": colors[cls % len(colors)]
            })

        if not vehicle_crops:
            print("    Don't have any vehicle in the frame.")
            out.write(frame)
            continue

        print(f"    Have {len(vehicle_crops)} object in frame.")
        print("Detect License Plate.")
        all_lp_results = model_license_plate.predict(vehicle_crops, conf=0.8, verbose=False, iou=0.8, device=device)

        # Vehicles with a usable plate go on to OCR; the others are drawn right
        # away with the best label cached for their track id (if any).
        license_crops = []
        license_data = []
        for current_vehicle, vehicle_crop, lp_result in zip(vehicle_data, vehicle_crops, all_lp_results):
            plate_img = _extract_plate(vehicle_crop, lp_result)
            if plate_img is None:
                cached = lass_detect.get(current_vehicle["id"])
                text = cached["label"] if cached is not None else "Can't OCR"
                _annotate_vehicle(frame, current_vehicle["bbox"], current_vehicle["color"], text)
                continue
            license_crops.append(plate_img)
            license_data.append(current_vehicle)

        if not license_crops:
            print("Can't detect any license plate.")
            out.write(frame)
            continue
        print(f"    Can detect {len(license_crops)} in {len(vehicle_crops)} objects.")
        print("OCR.......\n")
        all_ocr_results = ocr_model.predict(license_crops)

        for current_vehicle, ocr_result in zip(license_data, all_ocr_results):
            id_ = current_vehicle["id"]

            # Concatenate all recognised text fragments and average their scores.
            label = ""
            conf = 0
            if ocr_result and ocr_result['rec_texts']:
                texts = ocr_result['rec_texts']
                scores = ocr_result['rec_scores']
                conf = sum(scores[:len(texts)]) / len(texts)
                label = post_process_license_plate("".join(texts), int(current_vehicle["cls"]))

            # Keep the highest-confidence non-empty reading per track id. An empty
            # label is never cached, so it cannot replace a good earlier reading.
            best = lass_detect.get(id_)
            if label and (best is None or best["conf"] < conf):
                best = {"label": label, "conf": conf}
                lass_detect[id_] = best

            text = best["label"] if best is not None else "Can't OCR"
            _annotate_vehicle(frame, current_vehicle["bbox"], current_vehicle["color"], text)

        out.write(frame)
finally:
    # Always release the capture and finalise the output file, even when a model
    # call raises or the kernel is interrupted.
    cap.release()
    out.release()
    cv2.destroyAllWindows()

qt.qpa.plugin: Could not find the Qt platform plugin "wayland" in "/home/namtt/miniconda3/envs/trafic_detection/lib/python3.10/site-packages/cv2/qt/plugins"
