In [None]:
import os
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
from collections import defaultdict

# ------------------ Track-Oriented Multiple Hypothesis Tracking (TOMHT) ------------------
class Track:
    """State kept for one tracked object.

    Attributes:
        track_id: identifier given by the creator of the track.
        bbox: most recent box associated with the track.
        age: counter that starts at 0; ``TOMHTTracker.update`` increments it
            when the track is unmatched and resets it on a match.
        hits: number of boxes associated so far (the creating box counts).
        hypotheses: list of associated boxes, seeded with the creating box.
    """

    def __init__(self, track_id, bbox):
        # Identity and current position.
        self.track_id, self.bbox = track_id, bbox
        # Lifecycle counters: freshly created, one supporting detection.
        self.age, self.hits = 0, 1
        # Box history starts with the detection that opened the track.
        self.hypotheses = [bbox]

class TOMHTTracker:
    """Frame-to-frame box tracker based on IoU and one-to-one assignment.

    Every :meth:`update` call pairs the current detections with the existing
    tracks by minimising the total ``1 - IoU`` cost
    (``scipy.optimize.linear_sum_assignment``), refreshes the paired tracks,
    ages the unpaired ones, drops tracks whose ``age`` exceeds ``max_age``
    and opens a new track for each unpaired detection.

    Boxes are indexed as ``[x, y, width, height]``. The per-track
    ``hypotheses`` list is a bounded history of matched boxes; this class
    keeps a single assignment per frame and does not branch alternatives.
    """

    def __init__(self, iou_threshold=0.3, max_age=3, min_hits=3, max_hypotheses=10):
        """Store the matching/lifecycle parameters and start with no tracks.

        Args:
            iou_threshold: a pairing is accepted only when its cost
                ``1 - IoU`` is below ``1 - iou_threshold``.
            max_age: a track is dropped once its ``age`` (consecutive
                unmatched updates) exceeds this value.
            min_hits: minimum ``hits`` for a track to be reported by
                :meth:`get_active_tracks`.
            max_hypotheses: maximum number of boxes kept in a track's
                ``hypotheses`` history.
        """
        self.tracks, self.next_id = [], 1
        self.iou_threshold, self.max_age = iou_threshold, max_age
        self.min_hits, self.max_hypotheses = min_hits, max_hypotheses

    def iou(self, boxA, boxB):
        """Return the intersection-over-union of two ``[x, y, w, h]`` boxes."""
        left = max(boxA[0], boxB[0])
        top = max(boxA[1], boxB[1])
        right = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
        bottom = min(boxA[1] + boxA[3], boxB[1] + boxB[3])

        # Negative extents mean the boxes do not overlap on that axis.
        overlap_w = max(0, right - left)
        overlap_h = max(0, bottom - top)
        intersection = overlap_w * overlap_h

        area_a = boxA[2] * boxA[3]
        area_b = boxB[2] * boxB[3]
        # The small epsilon keeps the denominator non-zero for zero-area boxes.
        union = area_a + area_b - intersection + 1e-6
        return intersection / float(union)

    def _new_track(self, bbox):
        """Create a track for ``bbox`` with the next free id and advance the id."""
        created = Track(self.next_id, bbox)
        self.next_id += 1
        return created

    def update(self, detections):
        """Advance the tracker by one frame.

        Args:
            detections: sequence of ``[x, y, w, h]`` boxes for this frame.

        Returns:
            The tracks that currently satisfy ``hits >= min_hits``.
        """
        # With no existing tracks there is nothing to match against:
        # every detection simply opens a track.
        if not self.tracks:
            for det in detections:
                self.tracks.append(self._new_track(det))
            return self.get_active_tracks()

        # Cost of pairing track i with detection j; float32 as before.
        costs = np.array(
            [[1 - self.iou(trk.bbox, det) for det in detections]
             for trk in self.tracks],
            dtype=np.float32,
        ).reshape(len(self.tracks), len(detections))

        track_idx, det_idx = linear_sum_assignment(costs)

        gate = 1 - self.iou_threshold
        matched_tracks, matched_dets = set(), set()
        for ti, di in zip(track_idx, det_idx):
            # The solver pairs as many rows/columns as it can; pairs whose
            # overlap is too small are rejected here.
            if not costs[ti, di] < gate:
                continue

            trk, det = self.tracks[ti], detections[di]
            trk.bbox = det
            trk.age = 0
            trk.hits += 1
            trk.hypotheses.append(det)
            # Keep only the most recent ``max_hypotheses`` boxes.
            if len(trk.hypotheses) > self.max_hypotheses:
                trk.hypotheses = trk.hypotheses[-self.max_hypotheses:]

            matched_tracks.add(ti)
            matched_dets.add(di)

        # Unmatched tracks grow older; those past ``max_age`` are discarded.
        for idx, trk in enumerate(self.tracks):
            if idx not in matched_tracks:
                trk.age += 1
        survivors = [trk for trk in self.tracks if trk.age <= self.max_age]

        # Each detection left without a track opens a new one, in input order.
        fresh = []
        for idx, det in enumerate(detections):
            if idx in matched_dets:
                continue
            fresh.append(self._new_track(det))

        self.tracks = survivors + fresh
        return self.get_active_tracks()

    def get_active_tracks(self):
        """Return the tracks whose ``hits`` count has reached ``min_hits``."""
        return [trk for trk in self.tracks if trk.hits >= self.min_hits]

# ------------------ Utility functions ------------------
def get_file_lists(folder):
    """Return the sorted paths of all ``.json`` entries directly in ``folder``.

    Each returned path is ``folder`` joined with the entry name; the listing
    is not recursive.
    """
    json_paths = []
    for entry in os.listdir(folder):
        if entry.endswith('.json'):
            json_paths.append(os.path.join(folder, entry))
    # os.listdir gives no ordering guarantee, so sort for a stable sequence.
    json_paths.sort()
    return json_paths

def compute_iou(boxA, boxB):
    """Return the intersection-over-union of two ``[x, y, w, h]`` boxes.

    Only indices 0-3 of each box are read. Non-overlapping boxes give 0.0.
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
    bottom = min(boxA[1] + boxA[3], boxB[1] + boxB[3])

    # A negative extent on either axis means there is no overlap.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    intersection = overlap_w * overlap_h

    area_a = boxA[2] * boxA[3]
    area_b = boxB[2] * boxB[3]
    # The small epsilon keeps the denominator non-zero for zero-area boxes.
    union = area_a + area_b - intersection + 1e-6
    return intersection / float(union)

def draw_tracks(image, tracks):
    """Draw every track's box and ``ID:<track_id>`` label onto ``image``.

    The image is modified in place by the OpenCV drawing calls and the same
    object is returned. Each track must expose ``bbox`` (unpacked as
    ``x, y, w, h``) and ``track_id``.
    """
    box_color = (0, 255, 0)
    label_color = (255, 0, 0)

    for trk in tracks:
        left, top, width, height = trk.bbox
        top_left = (int(left), int(top))
        bottom_right = (int(left + width), int(top + height))
        cv2.rectangle(image, top_left, bottom_right, box_color, 2)

        # Place the label 10 pixels above the box's top-left corner.
        label_origin = (int(left), int(top - 10))
        cv2.putText(image, f'ID:{trk.track_id}', label_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, label_color, 1)
    return image

def evaluate_metrics(TP, FP, FN):
    """Print precision, recall and F1 from raw counts and plot them.

    The bar chart is saved as ``evaluation_metrics.png`` in the current
    working directory and then shown with ``plt.show()``. Nothing is returned.

    Args:
        TP: number of true positives.
        FP: number of false positives.
        FN: number of false negatives.
    """
    eps = 1e-6  # keeps the denominators non-zero when the counts are all 0
    precision = TP / (TP + FP + eps)
    recall = TP / (TP + FN + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    report = (
        "\n=== Evaluation Metrics ===",
        f"Precision: {precision:.4f}",
        f"Recall:    {recall:.4f}",
        f"F1 Score:  {f1:.4f}",
    )
    for line in report:
        print(line)

    labels = ["Precision", "Recall", "F1"]
    scores = [precision, recall, f1]
    bar_colors = ["blue", "green", "red"]
    plt.bar(labels, scores, color=bar_colors)
    plt.title("TOMHT Evaluation Metrics")
    plt.ylim(0, 1)
    plt.savefig("evaluation_metrics.png")
    plt.show()

def calculate_map(all_predictions, all_gts, iou_thresholds=np.arange(0.5, 1.0, 0.05)):
    """Print a precision-based "AP" per IoU threshold and the summary values.

    For each threshold, every predicted box of a frame is greedily paired
    with the first not-yet-used ground-truth box of that frame whose IoU
    (``compute_iou``) reaches the threshold. A paired prediction is a true
    positive, an unpaired one a false positive. The number printed as "AP"
    is the resulting precision ``TP / (TP + FP)`` over all frames; it is not
    the area under a precision-recall curve.

    Args:
        all_predictions: per-frame lists of predicted ``[x, y, w, h]`` boxes.
        all_gts: per-frame lists of ground-truth boxes, paired with
            ``all_predictions`` by position.
        iou_thresholds: iterable of IoU thresholds. The "mAP@50" line reports
            the first threshold, so it should be 0.5 for that label to hold.

    Returns:
        ``(precision_at_first_threshold, mean_precision_over_thresholds)`` as
        floats, or ``(nan, nan)`` when no thresholds are supplied.
    """
    thresholds = list(iou_thresholds)
    if not thresholds:
        # Without this guard the summary below failed with IndexError on aps[0].
        print("\n=== mAP Summary ===")
        print("No IoU thresholds supplied; nothing to evaluate.")
        return float("nan"), float("nan")

    aps = []
    for thresh in thresholds:
        tp = fp = 0
        for preds, gts in zip(all_predictions, all_gts):
            matched_gt = set()  # ground-truth indices already claimed in this frame
            for pred in preds:
                # First free ground-truth box that overlaps enough, if any.
                hit = next(
                    (i for i, gt in enumerate(gts)
                     if i not in matched_gt and compute_iou(pred, gt) >= thresh),
                    None,
                )
                if hit is None:
                    fp += 1
                else:
                    matched_gt.add(hit)
                    tp += 1

        # Epsilon keeps the ratio defined when there are no predictions at all.
        precision = tp / (tp + fp + 1e-6)
        aps.append(precision)
        print(f"AP@{thresh:.2f}: {precision:.4f}")

    map_first = float(aps[0])
    map_mean = float(np.mean(aps))
    print("\n=== mAP Summary ===")
    print(f"mAP@50:     {map_first:.4f}")
    print(f"mAP@50-95:  {map_mean:.4f}")
    return map_first, map_mean

# ------------------ Main function ------------------
def main(folder='/Users/jingjing/Desktop/single_island/X_band_label', output_dir='./output'):
    """Run the tracker over every annotation file and report the metrics.

    For each ``.json`` file in ``folder`` (sorted order) the annotation boxes
    are fed to the tracker as detections, the reported tracks are compared
    with those same boxes to accumulate TP/FP/FN at IoU > 0.5, and the tracks
    are drawn onto the referenced image and written to ``output_dir``.
    Precision/recall/F1 and the mAP summary are printed at the end.

    Args:
        folder: directory holding the annotation ``.json`` files and the
            images they reference through ``data["image"]["file_name"]``.
        output_dir: directory for the rendered frames; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    json_files = get_file_lists(folder)
    tracker = TOMHTTracker(iou_threshold=0.3, max_age=3, min_hits=3, max_hypotheses=10)

    TP, FP, FN = 0, 0, 0
    IOU_THRESH = 0.5

    all_predictions = []
    all_groundtruths = []

    for frame_idx, json_file in enumerate(json_files):
        with open(json_file, 'r') as f:
            data = json.load(f)

        annotations = data.get('annotations', [])
        bboxes = [[ann['xmin'], ann['ymin'], ann['width'], ann['height']] for ann in annotations]

        # The annotation boxes act both as tracker input and as ground truth.
        tracks = tracker.update(bboxes)
        track_bboxes = [t.bbox for t in tracks]

        all_predictions.append(track_bboxes)
        all_groundtruths.append(bboxes)

        matched_gt = set()
        matched_pred = set()

        if len(bboxes) > 0 and len(track_bboxes) > 0:
            # Pairs at or below the IoU threshold keep the maximal cost of 1 so
            # the assignment prefers sufficiently overlapping pairs.
            eval_cost = np.ones((len(bboxes), len(track_bboxes)), dtype=np.float32)
            for i, gt in enumerate(bboxes):
                for j, pred in enumerate(track_bboxes):
                    iou_score = compute_iou(gt, pred)
                    if iou_score > IOU_THRESH:
                        eval_cost[i, j] = 1 - iou_score

            gt_ind, pred_ind = linear_sum_assignment(eval_cost)
            for i, j in zip(gt_ind, pred_ind):
                # The solver may still pair low-overlap boxes; count only real matches.
                if compute_iou(bboxes[i], track_bboxes[j]) > IOU_THRESH:
                    matched_gt.add(i)
                    matched_pred.add(j)

        TP += len(matched_gt)
        FP += len(track_bboxes) - len(matched_pred)
        FN += len(bboxes) - len(matched_gt)

        # Metrics are already accumulated above, so skipping the rendering of
        # a frame below does not affect the evaluation.
        image_filename = data["image"]["file_name"]
        image_path = os.path.join(folder, image_filename)
        if not os.path.exists(image_path):
            print(f"⚠️ 图像文件未找到: {image_path}")
            continue

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports an unreadable or corrupt file by returning
            # None rather than raising; drawing on None would crash.
            print(f"⚠️ Failed to read image: {image_path}")
            continue

        image = draw_tracks(image, tracks)
        out_path = os.path.join(output_dir, f"frame_{frame_idx:04d}.png")
        cv2.imwrite(out_path, image)

        print(f"✅ Frame {frame_idx}: {len(bboxes)} GTs, {len(tracks)} tracks")

    evaluate_metrics(TP, FP, FN)
    calculate_map(all_predictions, all_groundtruths)

# Run the pipeline only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
1