# In [None]:
import json
import os
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import linear_sum_assignment

# ------------------ MHT Tracker ------------------
class Track:
    """State for one tracked object.

    The tracker reads and mutates these attributes directly.
    """

    def __init__(self, track_id, bbox):
        # Identity plus the most recently associated box ([x, y, w, h]).
        self.track_id, self.bbox = track_id, bbox
        # age: consecutive frames without a match; hits: matches so far
        # (the detection that creates the track counts as the first hit).
        self.age, self.hits = 0, 1

class MHTTracker:
    """IoU-gated multi-object tracker with Hungarian (optimal) assignment.

    Despite the "MHT" name, only a single association hypothesis is kept per
    frame: existing tracks are matched to new detections by solving a linear
    assignment on ``1 - IoU`` costs, with no motion model. Boxes are
    ``[x, y, w, h]`` where ``(x, y)`` is the minimum (top-left) corner.
    """

    def __init__(self, iou_threshold=0.3, max_age=5, min_hits=2):
        """Create an empty tracker.

        Args:
            iou_threshold: a track/detection pair is accepted only when its
                IoU is strictly greater than this value.
            max_age: a track is dropped once it has gone unmatched for more
                than this many consecutive frames.
            min_hits: number of matches (the creating detection counts as
                one) a track needs before ``get_active_tracks`` reports it.
        """
        self.tracks = []
        self.next_id = 1  # IDs are never reused, even after a track is dropped.
        self.iou_threshold = iou_threshold
        self.max_age = max_age
        self.min_hits = min_hits

    def iou(self, boxA, boxB):
        """Return the intersection-over-union of two ``[x, y, w, h]`` boxes.

        Disjoint boxes give 0. The 1e-6 term keeps the division finite for
        zero-area boxes.
        """
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
        yB = min(boxA[1] + boxA[3], boxB[1] + boxB[3])
        interArea = max(0, xB - xA) * max(0, yB - yA)
        boxAArea = boxA[2] * boxA[3]
        boxBArea = boxB[2] * boxB[3]
        return interArea / float(boxAArea + boxBArea - interArea + 1e-6)

    def update(self, detections):
        """Associate one frame of detections with the current tracks.

        Matched tracks take the detection's box, reset ``age`` to 0 and
        increment ``hits``. Unmatched tracks keep their last box, are aged by
        one, and are removed once ``age > max_age``. Every unmatched detection
        starts a new track.

        Args:
            detections: sequence of ``[x, y, w, h]`` boxes for this frame.

        Returns:
            The active tracks, as returned by ``get_active_tracks``.
        """
        assigned_tracks = set()
        assigned_dets = set()

        # Only solve the assignment when both sides are non-empty. With no
        # tracks every detection starts a new track below; with no detections
        # every track is simply unmatched.
        if self.tracks and len(detections) > 0:
            # float64 on purpose: storing 1 - IoU as float32 can round a pair
            # whose IoU is just above the threshold onto the gate and reject it.
            iou_matrix = np.array(
                [[self.iou(track.bbox, det) for det in detections]
                 for track in self.tracks],
                dtype=np.float64,
            )
            row_ind, col_ind = linear_sum_assignment(1.0 - iou_matrix)

            for r, c in zip(row_ind, col_ind):
                # The solver pairs min(n_tracks, n_dets) rows regardless of
                # overlap, so weak pairs must be gated out explicitly.
                if iou_matrix[r, c] > self.iou_threshold:
                    track = self.tracks[r]
                    track.bbox = detections[c]
                    track.age = 0
                    track.hits += 1
                    assigned_tracks.add(r)
                    assigned_dets.add(c)

        # Unmatched tracks "coast": they keep their last box and grow older.
        for i, track in enumerate(self.tracks):
            if i not in assigned_tracks:
                track.age += 1

        self.tracks = [t for t in self.tracks if t.age <= self.max_age]

        for j, det in enumerate(detections):
            if j not in assigned_dets:
                self.tracks.append(Track(self.next_id, det))
                self.next_id += 1

        return self.get_active_tracks()

    def get_active_tracks(self):
        """Return tracks with at least ``min_hits`` matches.

        This includes coasting tracks (unmatched this frame but not yet
        expired), which are reported with their last matched box.
        """
        return [t for t in self.tracks if t.hits >= self.min_hits]

# ------------------ Utils ------------------
def _natural_sort_key(path):
    """Sort key that orders embedded integers numerically ("f2" < "f10")."""
    # re.split with a capturing group alternates text and digit runs, so odd
    # indices are always digit runs and keys compare like-typed values.
    parts = re.split(r'(\d+)', path)
    # The full path is the tie-breaker (e.g. "f01" vs "f1") for a total order.
    return [int(part) if i % 2 else part for i, part in enumerate(parts)], path


def get_file_lists(folder):
    """Return full paths of the ``.json`` files in *folder*, in frame order.

    A natural sort is used so unpadded frame numbers stay in temporal order
    (``frame_2.json`` before ``frame_10.json``); the tracker consumes frames
    sequentially. For consistently zero-padded names the order matches a
    plain ``sorted()``.
    """
    paths = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.json')]
    return sorted(paths, key=_natural_sort_key)

def compute_iou(boxA, boxB):
    """Return the intersection-over-union of two ``[x, y, w, h]`` boxes.

    ``(x, y)`` is the box's minimum corner (top-left in image coordinates).
    Disjoint boxes give 0; the 1e-6 term in the denominator keeps the result
    finite when both boxes have zero area.
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
    bottom = min(boxA[1] + boxA[3], boxB[1] + boxB[3])

    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    intersection = overlap_w * overlap_h

    area_a = boxA[2] * boxA[3]
    area_b = boxB[2] * boxB[3]
    union = area_a + area_b - intersection + 1e-6
    return intersection / float(union)

def draw_tracks(image, tracks):
    """Draw each track's box and ``ID:<n>`` label onto *image* in place.

    Boxes are ``[x, y, w, h]``. Colours are given in OpenCV's BGR order:
    the rectangle is (0, 255, 0) and the label, placed 10 px above the box's
    top edge, is (0, 0, 255). Returns the same (mutated) image.
    """
    box_color, text_color = (0, 255, 0), (0, 0, 255)
    for trk in tracks:
        x, y, w, h = trk.bbox
        top_left = (int(x), int(y))
        bottom_right = (int(x + w), int(y + h))
        cv2.rectangle(image, top_left, bottom_right, box_color, 2)

        label_origin = (int(x), int(y - 10))
        cv2.putText(image, f'ID:{trk.track_id}', label_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, text_color, 2)
    return image

# ------------------ Evaluation ------------------
def _count_true_positives(iou_table, iou_thresh):
    """Return how many predictions in one frame match a distinct GT box.

    ``iou_table[p][g]`` is the IoU of prediction ``p`` with ground truth
    ``g``. Predictions are visited in order; each one takes the still
    unmatched GT with the highest IoU, provided it is at least ``iou_thresh``.
    """
    matched_gt = set()
    for pred_ious in iou_table:
        candidates = [
            (iou, gt_idx)
            for gt_idx, iou in enumerate(pred_ious)
            if gt_idx not in matched_gt and iou >= iou_thresh
        ]
        if candidates:
            # max() keeps the first maximum, so ties go to the lower GT index.
            _, best_gt = max(candidates, key=lambda c: c[0])
            matched_gt.add(best_gt)
    return len(matched_gt)


def _plot_ap_curve(iou_thresholds, APs, save_path, show):
    """Draw AP against IoU threshold on a new figure, then save and/or show it."""
    plt.figure()  # fresh figure so repeated calls do not overlay curves
    plt.plot(iou_thresholds, APs, marker='o')
    plt.title("mAP@0.5:0.95 vs IOU Threshold")
    plt.xlabel("IOU Threshold")
    plt.ylabel("AP")
    plt.grid(True)
    if save_path is not None:
        plt.savefig(save_path)
    if show:
        plt.show()


def evaluate_map50_95(all_gt_boxes, all_pred_boxes, save_path="mAP_curve.png", show=True):
    """Report precision/recall over IoU thresholds 0.50:0.05:0.95 and plot them.

    No confidence scores are available, so this is NOT COCO-style AP. For
    each threshold the "AP" is the overall precision ``TP / (TP + FP)``.
    "mAP@0.50" is that precision at IoU 0.50, and "mAP@0.50:0.95" is its mean
    over the ten thresholds.

    Within each frame, predictions are taken in order and each one is matched
    to the unmatched ground-truth box with the highest IoU (if that IoU is at
    least the threshold). Each ground-truth box is matched at most once.

    Args:
        all_gt_boxes: per-frame lists of ground-truth ``[x, y, w, h]`` boxes.
        all_pred_boxes: per-frame lists of predicted boxes, aligned frame by
            frame with ``all_gt_boxes``.
        save_path: file to save the curve to; ``None`` skips saving.
        show: whether to call ``plt.show()`` after plotting.

    Returns:
        ``(mAP_50, mAP_50_95)`` as defined above.

    Raises:
        ValueError: if the two inputs cover a different number of frames.
    """
    all_gt_boxes = list(all_gt_boxes)
    all_pred_boxes = list(all_pred_boxes)
    if len(all_gt_boxes) != len(all_pred_boxes):
        # zip() would silently drop the extra frames and skew every count.
        raise ValueError(
            f"frame count mismatch: {len(all_gt_boxes)} GT frames vs "
            f"{len(all_pred_boxes)} prediction frames"
        )

    # np.arange with a float step accumulates rounding error (values such as
    # 0.7000000000000002) and its length is not guaranteed; linspace yields
    # exactly the ten thresholds 0.50, 0.55, ..., 0.95.
    iou_thresholds = np.linspace(0.5, 0.95, 10)

    # IoUs depend only on the boxes, so compute them once per frame rather
    # than once per threshold: iou_tables[f][p][g] = IoU(pred p, gt g).
    iou_tables = [
        [[compute_iou(pred, gt) for gt in gt_boxes] for pred in pred_boxes]
        for gt_boxes, pred_boxes in zip(all_gt_boxes, all_pred_boxes)
    ]
    total_gt = sum(len(gt_boxes) for gt_boxes in all_gt_boxes)
    total_pred = sum(len(pred_boxes) for pred_boxes in all_pred_boxes)

    APs = []
    for iou_thresh in iou_thresholds:
        TP = sum(_count_true_positives(table, iou_thresh) for table in iou_tables)
        FP = total_pred - TP
        FN = total_gt - TP
        precision = TP / (TP + FP + 1e-6)
        recall = TP / (TP + FN + 1e-6)
        # Simplification: without score information, precision stands in for AP.
        AP = precision
        APs.append(AP)
        print(f"IOU {iou_thresh:.2f} => Precision: {precision:.4f}, Recall: {recall:.4f}")

    mAP_50 = APs[0]
    mAP_50_95 = float(np.mean(APs))

    print("\n=== mAP Evaluation ===")
    print(f"mAP@0.50     : {mAP_50:.4f}")
    print(f"mAP@0.50:0.95: {mAP_50_95:.4f}")

    if save_path is not None or show:
        _plot_ap_curve(iou_thresholds, APs, save_path, show)

    return mAP_50, mAP_50_95

# ------------------ Main Pipeline ------------------
def main(folder='/Users/jingjing/Desktop/single_island/X_band_label', output_dir='./output'):
    """Run the tracker over a folder of JSON annotations and evaluate it.

    Each ``.json`` file in *folder* (in the order returned by
    ``get_file_lists``) provides ``annotations`` entries with ``xmin``,
    ``ymin``, ``width`` and ``height``. These ground-truth boxes are fed to the
    tracker as its detections, so the evaluation compares the tracker's
    reported boxes against its own input rather than measuring a detector.

    For frames whose image (``image.file_name``, relative to *folder*) can be
    read, tracks are drawn and written to ``output_dir/frame_XXXX.png``.

    Args:
        folder: directory with the JSON annotation files and their images
            (change the default to your own data path).
        output_dir: directory for rendered frames; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    json_files = get_file_lists(folder)
    tracker = MHTTracker(iou_threshold=0.4, max_age=7, min_hits=2)

    all_gt_boxes = []
    all_pred_boxes = []

    for frame_idx, json_file in enumerate(json_files):
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        annotations = data.get('annotations', [])
        gt_boxes = [[ann['xmin'], ann['ymin'], ann['width'], ann['height']] for ann in annotations]

        tracks = tracker.update(gt_boxes)
        pred_boxes = [t.bbox for t in tracks]

        # Every frame is recorded for evaluation, even if its image is unusable.
        all_gt_boxes.append(gt_boxes)
        all_pred_boxes.append(pred_boxes)

        # A missing "image" section used to raise KeyError, and an empty
        # file_name resolved to the folder itself.
        image_filename = (data.get("image") or {}).get("file_name", "")
        image_path = os.path.join(folder, image_filename)
        if not image_filename or not os.path.isfile(image_path):
            print(f"⚠️ 图像文件未找到: {image_path}")
            continue

        image = cv2.imread(image_path)
        if image is None:
            print(f"⚠️ 无法读取图像: {image_path}")
            continue

        image = draw_tracks(image, tracks)
        out_path = os.path.join(output_dir, f"frame_{frame_idx:04d}.png")
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(out_path, image):
            print(f"⚠️ 无法保存图像: {out_path}")
            continue

        print(f"✅ Frame {frame_idx}: {len(gt_boxes)} GTs, {len(pred_boxes)} tracks")

    evaluate_map50_95(all_gt_boxes, all_pred_boxes)

if __name__ == '__main__':
    main()
