In [10]:
from ultralytics import YOLO
import supervision as sv
import torch
import cv2
import numpy as np
import os
from collections import Counter

model = YOLO('models/yolov8n_threat_detection.pt')

class_names = ['ammo', 'firearm', 'grenade', 'knife', 'pistol', 'rocket']

def detect_with_nms(image_path, iou_threshold=0.60):
    image = cv2.imread(image_path)
    results = model(image, verbose=False)[0]
        
    nms_detections = sv.Detections.from_ultralytics(results).with_nms(threshold=iou_threshold)
    
    return nms_detections

def calculate_metrics(predictions, targets):
    pred_counter = Counter()
    target_counter = Counter()

    for pred, target in zip(predictions, targets):
        pred_counter.update(pred['labels'].tolist())
        target_counter.update(target['labels'].tolist())

    all_classes = set(pred_counter.keys()) | set(target_counter.keys())
    
    tp = sum(min(pred_counter[c], target_counter[c]) for c in all_classes)
    fp = sum(max(pred_counter[c] - target_counter[c], 0) for c in all_classes)
    fn = sum(max(target_counter[c] - pred_counter[c], 0) for c in all_classes)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0


    return {
        'Precision': precision,
        'Recall': recall,
        'F1 Score': f1
    }

def evaluate_dataset(dataset_path):
    predictions = []
    targets = []

    images_path = os.path.join(dataset_path, 'images')
    labels_path = os.path.join(dataset_path, 'labels')

    image_files = [f for f in os.listdir(images_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
    total_images = len(image_files)

    for i, img_file in enumerate(image_files):
        img_path = os.path.join(images_path, img_file)
        txt_file = os.path.splitext(img_file)[0] + '.txt'
        txt_path = os.path.join(labels_path, txt_file)
        
        try:
            detections = detect_with_nms(img_path)
            
            pred = {
                'boxes': torch.from_numpy(detections.xyxy),
                'scores': torch.from_numpy(detections.confidence),
                'labels': torch.from_numpy(detections.class_id)
            }
            
            if os.path.exists(txt_path):
                with open(txt_path, 'r') as f:
                    lines = f.readlines()
                
                gt_boxes = []
                gt_labels = []
                for line in lines:
                    class_id, x, y, w, h = map(float, line.strip().split())
                    gt_boxes.append([x, y, x+w, y+h])
                    gt_labels.append(int(class_id))
                
                target = {
                    'boxes': torch.tensor(gt_boxes),
                    'labels': torch.tensor(gt_labels)
                }
            else:
                print(f"Warning: No annotation file found for {img_file}")
                target = {
                    'boxes': torch.zeros((0, 4)),
                    'labels': torch.zeros(0, dtype=torch.long)
                }
            
            predictions.append(pred)
            targets.append(target)
        
        except Exception as e:
            print(f"Error processing {img_file}: {str(e)}")
        
        if (i + 1) % 100 == 0 or (i + 1) == total_images:
            print(f"Processed {i + 1}/{total_images} images")

    if predictions and targets:
        metrics = calculate_metrics(predictions, targets)
    else:
        print("Error: No valid predictions or ground truth found.")
        metrics = {
            'mAP@0.5': 0.0,
            'Precision': 0.0,
            'Recall': 0.0,
            'F1 Score': 0.0
        }
    
    return metrics

dataset_path = 'datasets/dangerous-objects/valid'
print("Starting evaluation...")
results = evaluate_dataset(dataset_path)

print("\nEvaluation Results:")
for metric, value in results.items():
    print(f"{metric}: {value:.4f}")

Starting evaluation...
Processed 100/1749 images
Processed 200/1749 images
Processed 300/1749 images
Processed 400/1749 images
Processed 500/1749 images
Processed 600/1749 images
Processed 700/1749 images
Processed 800/1749 images
Processed 900/1749 images
Processed 1000/1749 images
Processed 1100/1749 images
Processed 1200/1749 images
Processed 1300/1749 images
Processed 1400/1749 images
Processed 1500/1749 images
Processed 1600/1749 images
Processed 1700/1749 images
Processed 1749/1749 images

Evaluation Results:
Precision: 1.0000
Recall: 0.8620
F1 Score: 0.9259
