In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from typing import List, Dict
from pathlib import Path


In [None]:
def compute_iou(boxA, boxB):
    """Intersection-over-union of two axis-aligned boxes.

    Both boxes use the format [x1, y1, x2, y2] (top-left, bottom-right).
    Returns intersection / (areaA + areaB - intersection + 1e-6); the small
    epsilon keeps the division finite when both boxes have zero area.
    """
    ax1, ay1, ax2, ay2 = boxA[0], boxA[1], boxA[2], boxA[3]
    bx1, by1, bx2, by2 = boxB[0], boxB[1], boxB[2], boxB[3]

    # Overlap extent along each axis; clamped to 0 when the boxes are disjoint.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)

    return intersection / (area_a + area_b - intersection + 1e-6)


In [None]:
class SimpleSORTTracker:
    """Minimal IoU-based multi-object tracker (SORT-like, no motion model).

    Each track stores its last matched box. On every ``update`` call, tracks
    are matched to detections greedily by descending IoU. Unmatched tracks
    keep ("coast on") their last box for up to ``max_lost`` consecutive
    frames before being dropped. Unmatched detections start new tracks with
    fresh integer IDs.
    """

    def __init__(self, iou_threshold=0.3, max_lost=5):
        """
        Args:
            iou_threshold: a track/detection pair is matched only if its IoU
                is strictly greater than this value.
            max_lost: number of consecutive unmatched frames a track survives
                (still reported with its last box) before it is removed.
        """
        self.iou_threshold = iou_threshold
        self.max_lost = max_lost

        self.next_id = 0
        self.tracks = {}       # track id -> last bbox [x1, y1, x2, y2]
        self.lost_counts = {}  # track id -> consecutive unmatched frames (live tracks only)

    def update(self, detections: List[np.ndarray]):
        """Advance the tracker by one frame.

        Args:
            detections: list of [x1, y1, x2, y2] boxes detected in this frame.

        Returns:
            dict {track_id: bbox} of all live tracks. Matched tracks carry
            their new detection, coasting tracks carry their previous box, and
            new tracks carry their first detection. The same dict is stored
            as ``self.tracks``.
        """
        # 1. Score every (track, detection) pair that clears the IoU threshold.
        candidates = []
        for tid, prev_box in self.tracks.items():
            for det_idx, det in enumerate(detections):
                iou = compute_iou(prev_box, det)
                if iou > self.iou_threshold:
                    candidates.append((iou, tid, det_idx))

        # 2. Global greedy assignment: the best-overlapping pairs claim first,
        #    so an earlier track cannot steal a detection that overlaps a later
        #    track better. Python's sort is stable (also with reverse=True), so
        #    ties are resolved in track-insertion / detection-index order.
        candidates.sort(key=lambda c: c[0], reverse=True)
        assignment = {}          # track id -> detection index
        used_detections = set()
        for _, tid, det_idx in candidates:
            if tid in assignment or det_idx in used_detections:
                continue
            assignment[tid] = det_idx
            used_detections.add(det_idx)

        # 3. Rebuild the track table, preserving the existing track order.
        updated_tracks = {}
        for tid, prev_box in self.tracks.items():
            if tid in assignment:
                updated_tracks[tid] = detections[assignment[tid]]
                self.lost_counts[tid] = 0
            else:
                self.lost_counts[tid] += 1
                if self.lost_counts[tid] <= self.max_lost:
                    updated_tracks[tid] = prev_box
                else:
                    # Track expired: forget its counter too, otherwise
                    # lost_counts grows without bound over a long stream.
                    del self.lost_counts[tid]

        # 4. Every detection nobody claimed starts a new track.
        for det_idx, det in enumerate(detections):
            if det_idx not in used_detections:
                updated_tracks[self.next_id] = det
                self.lost_counts[self.next_id] = 0
                self.next_id += 1

        self.tracks = updated_tracks
        return self.tracks


In [None]:
def track_wagons_over_time(
    detections_per_frame: List[List[np.ndarray]],
    iou_threshold: float = 0.3,
    max_lost: int = 5,
):
    """Run a fresh SimpleSORTTracker over a sequence of per-frame detections.

    Args:
        detections_per_frame: one list of [x1, y1, x2, y2] boxes per frame,
            in frame order.
        iou_threshold: forwarded to SimpleSORTTracker (matching threshold).
        max_lost: forwarded to SimpleSORTTracker (frames a track may coast
            before removal).

    Returns:
        list of {"frame_index": int, "tracks": {track_id: bbox}} entries,
        one per input frame.
    """
    tracker = SimpleSORTTracker(iou_threshold=iou_threshold, max_lost=max_lost)
    history = []

    for frame_idx, detections in enumerate(detections_per_frame):
        tracks = tracker.update(detections)
        history.append({
            "frame_index": frame_idx,
            # Snapshot the dict so later updates don't alter this frame's entry.
            "tracks": tracks.copy()
        })

    return history


In [None]:
def count_unique_wagons(track_history):
    """Number of distinct track IDs that appear anywhere in ``track_history``.

    ``track_history`` is a sequence of entries, each holding a "tracks"
    mapping keyed by track ID.
    """
    seen_ids = {
        track_id
        for frame_entry in track_history
        for track_id in frame_entry["tracks"]
    }
    return len(seen_ids)


In [None]:
def visualize_tracking(
    image: np.ndarray,
    tracks: Dict[int, np.ndarray],
    bgr: bool = False,
):
    """Draw each track's box and ID label on a copy of ``image`` and show it.

    Args:
        image: image array to draw on. It is not modified; drawing happens
            on a copy.
        tracks: mapping track id -> [x1, y1, x2, y2] box. Coordinates are
            cast to int.
        bgr: set True if ``image`` uses OpenCV's BGR channel order (e.g. it
            came from ``cv2.imread``). It is then converted to RGB before
            matplotlib displays it. The default False shows the image as-is.
    """
    vis = image.copy()
    color = (0, 255, 0)  # green: identical in RGB and BGR order
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2

    for tid, box in tracks.items():
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(vis, (x1, y1), (x2, y2), color, thickness)

        label = f"ID {tid}"
        (_, text_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
        # putText anchors at the text's bottom-left corner. Place the label
        # above the box, but move it inside the box when there is no room
        # above (box touching the top edge), where it would otherwise be clipped.
        text_y = y1 - 5 if y1 - 5 - text_h >= 0 else y1 + text_h + 5
        cv2.putText(vis, label, (x1, text_y), font, font_scale, color, thickness)

    if bgr:
        vis = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.imshow(vis)
    ax.set_title("Wagon tracks")
    ax.axis("off")
    plt.show()


In [None]:
# Simulated detections across 6 frames: a single 300x250 box drifts right by
# 10 px per frame, then jumps at frame 3 to a position that does not overlap
# the previous box at all (IoU 0).
detections_per_frame = [
    [np.array([x_left, 50, x_left + 300, 300])]
    for x_left in (100, 110, 120, 500, 510, 520)
]

track_history = track_wagons_over_time(detections_per_frame)

print("Total wagons counted:", count_unique_wagons(track_history))


In [None]:
def build_tracking_log(track_history):
    """Flatten a track history into one record per (frame, track) pair.

    Each record is {"frame": frame_index, "track_id": id, "bbox": box}.
    Records are ordered by history entry, then by each entry's track order.
    """
    return [
        {"frame": frame_entry["frame_index"], "track_id": track_id, "bbox": bbox}
        for frame_entry in track_history
        for track_id, bbox in frame_entry["tracks"].items()
    ]


In [None]:
def multi_camera_wagon_count(camera_histories: Dict[str, List[Dict]]):
    """Count distinct (camera id, track id) pairs across several cameras.

    Args:
        camera_histories: mapping camera id -> list of entries, each with a
            "tracks" mapping keyed by track id (e.g. the output of
            track_wagons_over_time).

    Returns:
        The number of distinct (camera, track) pairs.

    NOTE: no cross-camera re-identification is done, so a wagon seen by N
    cameras contributes N to the count. The result equals the sum of the
    per-camera unique-track counts.
    """
    # Tuple keys instead of an f"{cam}_{tid}" string: concatenation is
    # ambiguous (camera "a_1"/track "2" vs camera "a"/track "1_2" both give
    # "a_1_2") and would silently merge distinct pairs.
    global_ids = {
        (cam_id, tid)
        for cam_id, history in camera_histories.items()
        for entry in history
        for tid in entry["tracks"]
    }
    return len(global_ids)
