In [None]:
from ultralytics import YOLO
import cv2
import os
import json
import time
from datetime import datetime
import numpy as np

class HierarchicalDetector:
    """Two-stage YOLO detector.

    A main model is run on every (resized) frame; each sufficiently large
    main detection is cropped out of the frame and passed to a second model
    that looks for sub-objects inside the crop.

    Note: ids handed out by :meth:`get_unique_id` are per-class running
    counters, not tracking ids -- every detection in every frame receives a
    fresh number.
    """

    # Sub-detections must score strictly above this value to be reported.
    # (The sub-model itself is queried with the lower ``sub_conf_threshold``.)
    SUB_OBJECT_MIN_CONF = 0.5
    # A main box must be strictly larger than this many pixels on both sides
    # before it is cropped and handed to the sub-model.
    MIN_BOX_SIDE = 20
    # Frame rate used for the output file when the source reports none.
    DEFAULT_OUTPUT_FPS = 30.0

    def __init__(self, main_model_path='yolov8n.pt', sub_model_path='best.pt'):
        """Load both models and prepare the crop output directory.

        Args:
            main_model_path: Weights file for the frame-level model.
            sub_model_path: Weights file for the model run on each crop.
        """
        self.main_model = YOLO(main_model_path)
        self.sub_model = YOLO(sub_model_path)
        # Lookup tables indexed by the integer class index of a detection.
        self.main_model_classes = self.main_model.model.names
        self.sub_model_classes = self.sub_model.model.names
        # class name -> number of ids handed out so far for that class
        self.object_count = {}
        self.save_dir = '../data/detected_objects'
        os.makedirs(self.save_dir, exist_ok=True)

        # Target size passed to cv2.resize / cv2.VideoWriter.
        self.frame_size = (480, 360)
        self.main_conf_threshold = 0.3
        self.sub_conf_threshold = 0.3

        # Detections of the most recent frame that had any; reused for frames
        # in which the main model finds nothing.
        self.prev_detections = []

        # "<main name><main id>_<sub name>" keys whose crop was already saved.
        self.saved_objects = set()

    def format_detections(self, detections):
        """Reduce raw detections to the compact form that gets serialized.

        Args:
            detections: Iterable of detection dicts as built by
                :meth:`process_frame`.

        Returns:
            A list with one dict per detection holding ``object``, ``id`` and
            ``bbox``. If the detection has sub-objects, only the FIRST one is
            kept, under the key ``subobject``. Confidence values and image
            paths are dropped.
        """
        formatted_detections = []

        for det in detections:
            entry = {
                "object": det['object'],
                "id": det['id'],
                "bbox": det['bbox'],
            }

            # A missing or empty 'subobjects' entry simply means "none".
            subobjects = det.get('subobjects') or []
            if subobjects:
                first_sub = subobjects[0]
                entry["subobject"] = {
                    "object": first_sub['object'],
                    "id": first_sub['id'],
                    "bbox": first_sub['bbox'],
                }

            formatted_detections.append(entry)

        return formatted_detections

    def get_unique_id(self, class_name):
        """Return the next 1-based running number for ``class_name``."""
        next_id = self.object_count.get(class_name, 0) + 1
        self.object_count[class_name] = next_id
        return next_id

    @staticmethod
    def _draw_labeled_box(image, bbox, label, color):
        """Draw one ``[x1, y1, x2, y2]`` rectangle with its label, in place."""
        cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
        # The label normally sits 10 px above the box; clamp the baseline so
        # the text of a box touching the top edge is still inside the image.
        label_y = max(bbox[1] - 10, 12)
        cv2.putText(image, label, (bbox[0], label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    def draw_detections(self, frame, detections):
        """Return a copy of ``frame`` annotated with all detections.

        Main objects are drawn with colour ``(0, 255, 0)``, sub-objects with
        ``(255, 0, 0)``; each label reads ``"<name> <id> (<confidence>)"``.
        The input frame is not modified.
        """
        display_frame = frame.copy()

        for det in detections:
            main_label = f"{det['object']} {det['id']} ({det['confidence']:.2f})"
            self._draw_labeled_box(display_frame, det['bbox'], main_label,
                                   (0, 255, 0))

            for sub_obj in det.get('subobjects') or []:
                sub_label = (f"{sub_obj['object']} {sub_obj['id']} "
                             f"({sub_obj['confidence']:.2f})")
                self._draw_labeled_box(display_frame, sub_obj['bbox'],
                                       sub_label, (255, 0, 0))

        return display_frame

    def _detect_subobjects(self, cropped, main_bbox, main_obj_name, main_obj_id):
        """Run the sub-model on one crop and return its sub-object dicts.

        Sub-object boxes are translated from crop coordinates into frame
        coordinates. The first time a given main-object/sub-class pair is
        seen, the crop of the MAIN object is written to ``self.save_dir`` and
        its path is recorded under ``image_path``.
        """
        sub_objects = []
        main_obj_identifier = f"{main_obj_name}{main_obj_id}"

        results_sub = self.sub_model(cropped, conf=self.sub_conf_threshold)

        for result_sub in results_sub:
            sub_boxes = result_sub.boxes
            if sub_boxes is None or len(sub_boxes) == 0:
                continue

            for sub_box, sub_cls, sub_conf in zip(sub_boxes.xyxy, sub_boxes.cls,
                                                  sub_boxes.conf):
                sub_conf = float(sub_conf.item())
                # Filter first so that discarded detections do not consume an
                # id and leave gaps in the numbering.
                if not sub_conf > self.SUB_OBJECT_MIN_CONF:
                    continue

                sub_obj_name = self.sub_model_classes[int(sub_cls.item())]
                sub_obj_id = self.get_unique_id(sub_obj_name)

                sub_bbox = list(map(int, sub_box))
                # Shift by the crop origin to get frame coordinates.
                global_bbox = [
                    main_bbox[0] + sub_bbox[0],
                    main_bbox[1] + sub_bbox[1],
                    main_bbox[0] + sub_bbox[2],
                    main_bbox[1] + sub_bbox[3],
                ]

                sub_entry = {
                    "object": sub_obj_name,
                    "id": sub_obj_id,
                    "confidence": sub_conf,
                    "bbox": global_bbox,
                }

                sub_obj_identifier = f"{main_obj_identifier}_{sub_obj_name}"
                if sub_obj_identifier not in self.saved_objects:
                    save_path = os.path.join(
                        self.save_dir,
                        f"{main_obj_name}{main_obj_id}_{sub_obj_name}{sub_obj_id}.jpg"
                    )
                    cv2.imwrite(save_path, cropped)
                    self.saved_objects.add(sub_obj_identifier)
                    sub_entry["image_path"] = save_path

                sub_objects.append(sub_entry)

        return sub_objects

    def process_frame(self, frame):
        """Detect objects and their sub-objects in a single frame.

        The frame is resized to ``self.frame_size`` before detection, so all
        returned boxes are in resized-frame coordinates.

        Args:
            frame: Image array as returned by ``cv2.VideoCapture.read``.

        Returns:
            ``(display_frame, detections)`` -- the annotated resized frame and
            a list of dicts with keys ``object``, ``id``, ``confidence``,
            ``bbox`` and ``subobjects``. When nothing is detected, the
            detections of the last non-empty frame are returned (and drawn)
            instead.
        """
        detections = []

        frame = cv2.resize(frame, self.frame_size)

        results_main = self.main_model(frame, conf=self.main_conf_threshold)

        for result_main in results_main:
            boxes = result_main.boxes
            if boxes is None or len(boxes) == 0:
                continue

            for box, main_cls, conf in zip(boxes.xyxy, boxes.cls, boxes.conf):
                main_bbox = list(map(int, box))
                box_width = main_bbox[2] - main_bbox[0]
                box_height = main_bbox[3] - main_bbox[1]
                # Tiny boxes are dropped entirely; check before allocating an
                # id so that dropped boxes do not leave gaps in the numbering.
                if not (box_width > self.MIN_BOX_SIDE
                        and box_height > self.MIN_BOX_SIDE):
                    continue

                main_obj_name = self.main_model_classes[int(main_cls.item())]
                main_obj_id = self.get_unique_id(main_obj_name)

                cropped = frame[main_bbox[1]:main_bbox[3],
                                main_bbox[0]:main_bbox[2]]
                sub_objects = self._detect_subobjects(
                    cropped, main_bbox, main_obj_name, main_obj_id)

                detections.append({
                    "object": main_obj_name,
                    "id": main_obj_id,
                    "confidence": float(conf.item()),
                    "bbox": main_bbox,
                    "subobjects": sub_objects,
                })

        if detections:
            self.prev_detections = detections
        elif self.prev_detections:
            # Nothing found in this frame: fall back to the last known
            # detections instead of showing an empty frame.
            detections = self.prev_detections

        display_frame = self.draw_detections(frame, detections)

        return display_frame, detections

    def process_video(self, video_path, output_path=None):
        """Run detection over a video, showing (and optionally saving) it.

        Every frame is processed, overlaid with the per-frame FPS and shown in
        a window titled ``'Detection'``; pressing ``q`` stops early.

        Args:
            video_path: Source passed to ``cv2.VideoCapture``.
            output_path: Optional path of an annotated ``mp4v`` output video.
                Missing parent directories are created.

        Returns:
            ``(metrics, all_detections)`` where ``metrics`` has the keys
            ``total_frames``, ``total_time`` and ``average_fps`` and
            ``all_detections`` holds the formatted detections of every frame
            that had any.
        """
        cap = cv2.VideoCapture(video_path)
        out = None
        frame_count = 0
        total_time = 0
        all_detections = []

        try:
            if output_path:
                out_dir = os.path.dirname(output_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)

                # One output frame is written per input frame, so the writer
                # must use the source rate to keep the playback speed.
                fps = cap.get(cv2.CAP_PROP_FPS)
                if not fps > 0:  # also true for NaN
                    fps = self.DEFAULT_OUTPUT_FPS

                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(output_path, fourcc, fps, self.frame_size)

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                start_time = time.time()

                display_frame, detections = self.process_frame(frame)
                formatted_detections = self.format_detections(detections)

                process_time = time.time() - start_time
                total_time += process_time
                # Count the frame as soon as it is processed so that quitting
                # with 'q' does not leave the last frame out of the metrics.
                frame_count += 1
                current_fps = 1 / (process_time + 1e-6)

                cv2.putText(display_frame, f"FPS: {current_fps:.1f}",
                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                if out is not None:
                    out.write(display_frame)

                cv2.imshow('Detection', display_frame)

                if formatted_detections:
                    all_detections.append(formatted_detections)

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release everything even if processing a frame raised.
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

        metrics = {
            "total_frames": frame_count,
            "total_time": total_time,
            "average_fps": frame_count / (total_time + 1e-6)
        }

        return metrics, all_detections


if __name__ == "__main__":
    # Script entry point: run the detector over the test video, dump the
    # per-frame detections as JSON and print timing metrics.
    detector = HierarchicalDetector('yolov8n.pt', 'best.pt')
    video_path = '../data/test.mp4'
    output_path = '../results/output.mp4'
    detections_path = '../results/detections.json'

    # open(..., 'w') raises FileNotFoundError if the directory is missing;
    # create it up front so a long run is not lost at the very last step.
    os.makedirs(os.path.dirname(detections_path), exist_ok=True)

    print("Processing video...")
    metrics, detections = detector.process_video(video_path, output_path)

    with open(detections_path, 'w') as f:
        json.dump(detections, f, indent=4)

    print("\nPerformance Metrics:")
    print(f"Total Frames: {metrics['total_frames']}")
    print(f"Total Time: {metrics['total_time']:.2f} seconds")
    print(f"Average FPS: {metrics['average_fps']:.2f}")

Processing video...

0: 480x640 8 persons, 1 car, 4 motorcycles, 1 bus, 4 trucks, 92.2ms
Speed: 5.5ms preprocess, 92.2ms inference, 6.9ms postprocess per image at shape (1, 3, 480, 640)

0: 640x256 (no detections), 45.2ms
Speed: 1.5ms preprocess, 45.2ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 256)

0: 640x544 1 tyre, 100.0ms
Speed: 3.5ms preprocess, 100.0ms inference, 1.0ms postprocess per image at shape (1, 3, 640, 544)

0: 640x640 1 door, 1 tyre, 130.7ms
Speed: 3.6ms preprocess, 130.7ms inference, 11.5ms postprocess per image at shape (1, 3, 640, 640)

0: 640x256 (no detections), 59.5ms
Speed: 20.0ms preprocess, 59.5ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 256)

0: 640x544 (no detections), 96.3ms
Speed: 3.0ms preprocess, 96.3ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 544)

0: 640x640 1 helmet, 108.7ms
Speed: 4.3ms preprocess, 108.7ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 640)

0: 640x288 (no detection

2025-01-11 12:03:41.461 Python[69328:5155189] +[IMKClient subclass]: chose IMKClient_Modern
2025-01-11 12:03:41.461 Python[69328:5155189] +[IMKInputSession subclass]: chose IMKInputSession_Modern


0: 640x608 (no detections), 114.0ms
Speed: 3.1ms preprocess, 114.0ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 608)

0: 640x288 (no detections), 61.7ms
Speed: 1.9ms preprocess, 61.7ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 288)

0: 640x640 (no detections), 98.1ms
Speed: 3.1ms preprocess, 98.1ms inference, 0.5ms postprocess per image at shape (1, 3, 640, 640)

0: 640x320 1 helmet, 57.7ms
Speed: 2.4ms preprocess, 57.7ms inference, 0.5ms postprocess per image at shape (1, 3, 640, 320)

0: 640x256 (no detections), 57.5ms
Speed: 1.8ms preprocess, 57.5ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 256)

0: 640x512 (no detections), 93.1ms
Speed: 2.6ms preprocess, 93.1ms inference, 0.4ms postprocess per image at shape (1, 3, 640, 512)

0: 640x288 (no detections), 51.6ms
Speed: 2.0ms preprocess, 51.6ms inference, 0.3ms postprocess per image at shape (1, 3, 640, 288)

0: 640x320 (no detections), 55.9ms
Speed: 2.0ms preprocess, 55.9ms inferen