In [1]:
import sys
import os

PROJECT_ROOT = os.path.abspath("..")
sys.path.insert(0, PROJECT_ROOT)

In [2]:
import cv2
import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error
from datetime import datetime
from tqdm import tqdm

from src.svm_detector.features.hog import extract_hog
from src.svm_detector.inference.sliding_window import run_svm_inference
from src.svm_detector.inference.nms import nms
from config import config
from pathlib import Path

In [11]:
# ---------------- CONFIG ----------------
PROJECT_ROOT = Path.cwd().parent
WINDOW_SIZE = config.WIN_SIZE
STEP = 16
SCORE_THRESHOLD = 0.8
IOU_THRESHOLD = 0.2
IS_NMS = True

IMAGE_DIR = f"{PROJECT_ROOT}/dataset/yolo/test/images"
LABEL_DIR = f"{PROJECT_ROOT}/dataset/yolo/test/labels"

SVM_PATH = f"{PROJECT_ROOT}/models/svm/svm_chicken_v1.5_base.pkl"

In [4]:
def iou(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    inter_w = max(0, xB - xA)
    inter_h = max(0, yB - yA)
    inter_area = inter_w * inter_h

    if inter_area == 0:
        return 0.0

    boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    return inter_area / (boxA_area + boxB_area - inter_area)


In [5]:
def load_yolo_boxes(label_path, img_w, img_h):
    boxes = []
    if not os.path.exists(label_path):
        return boxes

    with open(label_path, "r") as f:
        for line in f:
            _, cx, cy, w, h = map(float, line.split())

            cx *= img_w
            cy *= img_h
            w *= img_w
            h *= img_h

            x1 = int(cx - w / 2)
            y1 = int(cy - h / 2)
            x2 = int(cx + w / 2)
            y2 = int(cy + h / 2)

            boxes.append((x1, y1, x2, y2))
    return boxes


In [6]:
def evaluate_detection(detections, gt_boxes):
    matched_gt = set()
    tp = 0
    fp = 0

    for det in detections:
        det_box = det[:4]
        matched = False

        for i, gt in enumerate(gt_boxes):
            if i in matched_gt:
                continue

            if iou(det_box, gt) > 0:
                tp += 1
                matched_gt.add(i)
                matched = True
                break

        if not matched:
            fp += 1

    fn = len(gt_boxes) - len(matched_gt)
    return tp, fp, fn


In [13]:
def apply_nms(detections, score_thresh=0.0, nms_thresh=0.4):
    if len(detections) == 0:
        return []

    boxes = []
    scores = []

    for det in detections:
        x1, y1, x2, y2 = det[:4]
        boxes.append([x1, y1, x2 - x1, y2 - y1])
        scores.append(float(det[4]))

    indices = nms(boxes, scores, score_thresh, nms_thresh)
    if len(indices) == 0:
        return []

    return [detections[i] for i in indices.flatten()]

In [14]:
total_tp = total_fp = total_fn = 0

image_files = [
    f for f in os.listdir(IMAGE_DIR)
    if f.endswith((".jpg", ".png", ".jpeg"))
]

svm = joblib.load(SVM_PATH)

for img_name in tqdm(image_files, desc="SVM Detection Eval"):
    img_path = os.path.join(IMAGE_DIR, img_name)
    label_path = os.path.join(LABEL_DIR, os.path.splitext(img_name)[0] + ".txt")

    img = cv2.imread(img_path)
    if img is None:
        continue

    h, w = img.shape[:2]

    detections = run_svm_inference(
        img=img,
        svm=svm,
        feature_extractor=extract_hog,
        window_size=WINDOW_SIZE,
        step=STEP,
        threshold=SCORE_THRESHOLD,
    )

    if IS_NMS:
        detections = apply_nms(
            detections,
            score_thresh=SCORE_THRESHOLD,
            nms_thresh=IOU_THRESHOLD,
        )

    gt_boxes = load_yolo_boxes(label_path, w, h)

    tp, fp, fn = evaluate_detection(detections, gt_boxes)

    total_tp += tp
    total_fp += fp
    total_fn += fn


SVM Detection Eval: 100%|██████████| 1000/1000 [01:47<00:00,  9.27it/s]


In [16]:
precision = total_tp / (total_tp + total_fp + 1e-8)
recall    = total_tp / (total_tp + total_fn + 1e-8)
f1        = 2 * precision * recall / (precision + recall + 1e-8)

print("========= DETECTION METRICS =========")
print(f"TP: {total_tp}")
print(f"FP: {total_fp}")
print(f"FN: {total_fn}")
print(f"Precision: {precision:.4f}")
print(f"Recall   : {recall:.4f}")
print(f"F1-score : {f1:.4f}")
print("====================================")


TP: 1740
FP: 1837
FN: 5862
Precision: 0.4864
Recall   : 0.2289
F1-score : 0.3113
