In [None]:
import csv
import math
import os
import time
from datetime import datetime  # Needed for DeepSORT 'today' argument potentially

import cv2
import numpy as np
import torch

# Import DeepSORT
from deep_sort_realtime.deepsort_tracker import DeepSort
from ultralytics import YOLO

# # Helper function to calculate IoU (No longer needed with 'others' method)
# def calculate_iou(box1, box2):
#     # ... (IoU calculation code removed) ...

class SyringeVolumeEstimator:
    """Detect syringes with a YOLO pose model, track them with DeepSORT and log volume estimates.

    For each confirmed track the first four pose keypoints are converted into a
    pixel width and height. Every candidate barrel diameter in
    ``possible_diameters`` (cm) is then used as the physical scale to derive a
    cylinder volume, which is written to a CSV row and drawn on the frame.
    """

    # Original hard-coded model location, kept as the default so existing callers are unaffected.
    DEFAULT_MODEL_PATH = "runs/pose/train-pose11n-v25-P50/weights/best.pt"

    def __init__(self,
                 max_iou_distance=0.7,
                 max_age=30,
                 n_init=3,
                 nms_max_overlap=1.0,
                 max_cosine_distance=0.2,
                 nn_budget=None,
                 embedder="mobilenet",  # Use default embedder
                 half=True,  # Use half precision for mobilenet if GPU
                 bgr=True,  # Input is BGR
                 embedder_model_name=None,
                 embedder_wts=None,
                 polygon=False,
                 model_path=None,
                 conf_threshold=0.6,
                 ):
        """Load the YOLO pose model, pick a device and build the DeepSORT tracker.

        Args:
            max_iou_distance, max_age, n_init, nms_max_overlap,
            max_cosine_distance, nn_budget, embedder, half, bgr,
            embedder_model_name, embedder_wts, polygon: forwarded unchanged to
                ``DeepSort(...)``; see the deep-sort-realtime documentation for
                their meaning.
            model_path: path to the YOLO pose weights. ``None`` selects
                ``DEFAULT_MODEL_PATH``.
            conf_threshold: confidence threshold passed to ``YOLO.predict``
                in ``process_frame``.

        Raises:
            FileNotFoundError: if the weights file does not exist.
            Exception: any error raised while constructing ``DeepSort`` is
                reported on stdout and re-raised.
        """
        self.model_path = self.DEFAULT_MODEL_PATH if model_path is None else model_path
        self.conf_threshold = conf_threshold
        print(f"Loading YOLO model from: {self.model_path}")
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"YOLO model not found at {self.model_path}. Please ensure the path is correct.")
        self.model = YOLO(self.model_path)

        # Prefer CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        self.model.to(self.device)
        print(f"Using device: {self.device}")

        # Request a GPU embedder whenever the detector is not on the CPU.
        # NOTE(review): whether the deep-sort-realtime embedder actually runs on
        # MPS (as opposed to CUDA only) is not visible here - confirm.
        use_embedder_gpu = (self.device != 'cpu')

        try:
            self.tracker = DeepSort(
                max_iou_distance=max_iou_distance,
                max_age=max_age,
                n_init=n_init,
                nms_max_overlap=nms_max_overlap,
                max_cosine_distance=max_cosine_distance,
                nn_budget=nn_budget,
                embedder=embedder,
                half=half,
                bgr=bgr,
                embedder_gpu=use_embedder_gpu,
                embedder_model_name=embedder_model_name,
                embedder_wts=embedder_wts,
                polygon=polygon,
                today=datetime.now().date(),
            )
            print("DeepSORT tracker initialized successfully with specified parameters.")
        except TypeError as e:
            # Typically an argument the installed DeepSort version does not accept.
            print(f"Error initializing DeepSORT: {e}")
            print("Please double-check deep-sort-realtime installation and compatibility with provided arguments.")
            raise
        except Exception as e:
            print(f"An unexpected error occurred during DeepSORT initialization: {e}")
            raise

        # Candidate syringe barrel diameters in cm.
        self.possible_diameters = [0.45, 1.0, 1.25, 2.0]

    def draw_volume_table(self, frame: np.ndarray, volumes: list, table_x: int, table_y: int, track_id: str) -> None:
        """Draw a diameter/volume table for one track onto ``frame`` (in place).

        Args:
            frame: image to draw on.
            volumes: sequence of ``(diameter, volume)`` pairs; a volume that is
                ``None`` or NaN is rendered as ``"N/A"``.
            table_x, table_y: requested top-left corner; clamped so the table
                stays inside the frame where possible.
            track_id: identifier shown in the table title.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        black = (0, 0, 0)
        rows = list(volumes)

        table_width = 250
        # 80 px for the ID line and header plus 30 px per row
        # (200 px for the usual four diameters).
        table_height = 80 + 30 * len(rows)

        # Keep the table inside the frame.
        frame_h, frame_w = frame.shape[:2]
        table_x = max(0, min(table_x, frame_w - table_width))
        table_y = max(0, min(table_y, frame_h - table_height))

        # Light gray background.
        cv2.rectangle(frame, (table_x, table_y), (table_x + table_width, table_y + table_height), (220, 220, 220), -1)
        # Title and column headers.
        cv2.putText(frame, f"Syringe ID: {track_id}", (table_x + 10, table_y + 20), font, 0.7, black, 2)
        cv2.putText(frame, "Diameter", (table_x + 10, table_y + 50), font, 0.7, black, 2)
        cv2.putText(frame, "mL", (table_x + 150, table_y + 50), font, 0.7, black, 2)

        # One row per candidate diameter.
        for i, (D, volume) in enumerate(rows):
            y = table_y + 80 + i * 30
            cv2.putText(frame, f"{D:.2f}", (table_x + 10, y), font, 0.7, black, 2)
            has_value = volume is not None and not math.isnan(volume)
            cell = f"{volume:.2f}" if has_value else "N/A"
            cv2.putText(frame, cell, (table_x + 150, y), font, 0.7, black, 2)

    def _placeholder_row(self, timestamp):
        """Return the CSV row logged when a frame has nothing to report."""
        return [timestamp, np.nan, np.nan, np.nan] + [np.nan for _ in self.possible_diameters]

    def _estimate_volumes(self, kpts):
        """Return one volume per candidate diameter from the first four keypoints.

        The keypoints are unpacked as lower-left, upper-left, upper-right and
        lower-right. The averaged opposite-edge lengths give the pixel width
        and height; each diameter is assumed to span the pixel width, which
        fixes the cm-per-pixel scale. Entries are NaN when the pixel
        dimensions are near zero or the implied height is outside (0, 30] cm.
        """
        ll_point, ul_point, ur_point, lr_point = kpts[:4]

        width_pixels = (np.linalg.norm(lr_point - ll_point) + np.linalg.norm(ur_point - ul_point)) / 2.0
        height_pixels = (np.linalg.norm(ul_point - ll_point) + np.linalg.norm(ur_point - lr_point)) / 2.0

        # Degenerate geometry: no meaningful scale can be derived.
        if width_pixels <= 1e-6 or height_pixels <= 1e-6:
            return [np.nan] * len(self.possible_diameters)

        volumes = []
        for D in self.possible_diameters:
            scale_factor_D = D / width_pixels
            H_cm = height_pixels * scale_factor_D
            if 0 < H_cm <= 30:  # Validate height (max 30 cm)
                volume_D = math.pi * (D / 2.0) ** 2 * H_cm
            else:
                volume_D = np.nan
            volumes.append(volume_D)
        return volumes

    def _collect_detections(self, result):
        """Convert one YOLO result into DeepSORT detections plus their keypoints.

        Returns two parallel lists: ``([left, top, w, h], confidence, class)``
        tuples and the matching keypoint arrays. Detections without at least
        four keypoints, or with a non-positive box size, are skipped.
        """
        detections = []
        keypoints = []
        if result.boxes is None or len(result.boxes) == 0:
            return detections, keypoints

        for i, box in enumerate(result.boxes):
            # Keypoints must exist for this detection index.
            if result.keypoints is None or len(result.keypoints.xy) <= i or result.keypoints.xy[i].shape[0] < 4:
                continue

            xyxy = box.xyxy[0].cpu().numpy()
            conf = box.conf[0].cpu().numpy()
            cls = int(box.cls[0].cpu().numpy())

            x1, y1, x2, y2 = map(int, xyxy)
            w = x2 - x1
            h = y2 - y1
            if w <= 0 or h <= 0:
                continue  # Skip invalid boxes

            detections.append(([x1, y1, w, h], float(conf), cls))
            keypoints.append(result.keypoints.xy[i].cpu().numpy())
        return detections, keypoints

    @staticmethod
    def _draw_track_label(frame, ltrb, text, color):
        """Draw a track's bounding box and a text label just above it."""
        x1, y1, x2, y2 = ltrb
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        cv2.putText(frame, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

    def process_frame(self, frame: np.ndarray, timestamp: float, writer: csv.writer) -> np.ndarray:
        """Detect, track, estimate volumes, log to CSV and annotate one frame.

        Args:
            frame: input image (passed to YOLO and to the DeepSORT embedder).
            timestamp: value written to the first CSV column.
            writer: object with ``writerow``; receives
                ``[timestamp, track_id, center_x, center_y, *volumes]``.

        Returns:
            An annotated copy of ``frame``; the input is not drawn on.
        """
        green = (0, 255, 0)
        red = (0, 0, 255)

        # 1. Pose detection.
        results = self.model.predict(source=frame, device=self.device, verbose=False, conf=self.conf_threshold)
        result = results[0]  # single frame in, single result out
        annotated_frame = frame.copy()

        # 2-3. Feed detections to DeepSORT; keypoints ride along via 'others'
        # so they can be read back per track. The tracker is updated even when
        # there are no detections.
        detections, keypoints = self._collect_detections(result)
        tracks = self.tracker.update_tracks(detections, frame=frame, others=keypoints)

        # 4. Nothing tracked at all: log a placeholder row.
        if not tracks:
            writer.writerow(self._placeholder_row(timestamp))
            return annotated_frame

        processed_track_ids = set()
        nan_volumes = [np.nan] * len(self.possible_diameters)

        for track in tracks:
            if not track.is_confirmed():
                continue

            track_id = track.track_id
            if track_id in processed_track_ids:
                continue
            processed_track_ids.add(track_id)

            bbox_ltrb_track = track.to_ltrb()
            if bbox_ltrb_track is None:
                continue  # Skip if track has no valid bbox yet
            ltrb = tuple(map(int, bbox_ltrb_track))
            x1_track, y1_track, x2_track, y2_track = ltrb
            center_x = (x1_track + x2_track) / 2.0
            center_y = (y1_track + y2_track) / 2.0

            # Keypoints attached to the detection this track holds, if any.
            kpts = track.get_det_supplementary()

            if kpts is None or not isinstance(kpts, np.ndarray) or kpts.shape[0] < 4:
                # Draw the box but log NaN volumes when keypoints are missing/invalid.
                self._draw_track_label(annotated_frame, ltrb, f"ID: {track_id} (No Kpts)", green)
                writer.writerow([timestamp, track_id, center_x, center_y] + nan_volumes)
                continue

            # Compute first, so that exactly one CSV row is written per track
            # regardless of whether the drawing step below succeeds.
            failed = False
            try:
                volumes = self._estimate_volumes(kpts)
            except Exception as e:  # best effort: one bad track must not stop the frame
                print(f"Error during volume calculation/drawing for track {track_id}: {e}")
                volumes = list(nan_volumes)
                failed = True

            writer.writerow([timestamp, track_id, center_x, center_y] + volumes)

            if not failed:
                try:
                    self._draw_track_label(annotated_frame, ltrb, f"ID: {track_id}", green)
                    # Table goes to the right of the box, aligned with its top.
                    self.draw_volume_table(annotated_frame, list(zip(self.possible_diameters, volumes)),
                                           x2_track + 10, y1_track, track_id)
                except Exception as e:  # drawing problems are reported, not re-logged
                    print(f"Error during volume calculation/drawing for track {track_id}: {e}")
                    failed = True

            if failed:
                # Red box marks a track whose calculation or drawing failed.
                self._draw_track_label(annotated_frame, ltrb, f"ID: {track_id} (Error)", red)

        # Placeholder row when YOLO saw nothing and no confirmed track was handled.
        # NOTE(review): frames that have detections but only unconfirmed tracks
        # log no row at all; this mirrors the original behaviour - confirm it is intended.
        if (result.boxes is None or len(result.boxes) == 0) and not processed_track_ids:
            writer.writerow(self._placeholder_row(timestamp))

        return annotated_frame

    def run(self, input_source='webcam', video_path=None, csv_path='syringe_data_deepsort_v1.3.2.csv'):
        """Run the capture -> process -> display loop until the input ends or 'q' is pressed.

        Args:
            input_source: ``'video'`` to read ``video_path``; any other value
                opens webcam index 0.
            video_path: input file, required when ``input_source == 'video'``.
                The annotated video is saved next to it with a
                ``_processed_deepsort_v1.3.2`` suffix.
            csv_path: CSV file that rows are appended to; a header is written
                when the file is missing or empty.

        Raises:
            ValueError: ``input_source == 'video'`` without ``video_path``.
            FileNotFoundError: ``video_path`` does not exist.
            IOError: the capture device or file cannot be opened.

        Errors raised inside the frame loop are printed with a traceback and end
        the loop; the capture, writer and windows are always released.
        """
        from_video = input_source == 'video'

        if from_video:
            if video_path is None:
                raise ValueError("video_path must be provided for input_source='video'")
            if not os.path.exists(video_path):
                raise FileNotFoundError(f"Video file not found at {video_path}")
            cap = cv2.VideoCapture(video_path)
        else:  # webcam
            cap = cv2.VideoCapture(0)

        # Check before allocating anything else, and release the handle on failure.
        if not cap.isOpened():
            cap.release()
            raise IOError(f"Cannot open {'video file ' + video_path if from_video else 'webcam'}")

        out = None
        output_path = None
        frame_count = 0
        try:
            if from_video:
                fps = cap.get(cv2.CAP_PROP_FPS)
                if not fps > 0:
                    # Some containers/backends report 0 or NaN; fall back to a
                    # nominal rate so the output file stays playable.
                    fps = 30.0
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # or 'XVID', 'MJPG'
                base, ext = os.path.splitext(video_path)
                output_path = f"{base}_processed_deepsort_v1.3.2{ext}"
                out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
                print(f"Processing video: {video_path}, saving to: {output_path}")
            else:
                print("Starting webcam input...")

            # Append to the CSV; write the header only for a new/empty file.
            write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
            with open(csv_path, 'a', newline='') as csvfile:
                writer = csv.writer(csvfile)
                if write_header:
                    header = ['timestamp', 'track_id', 'center_x', 'center_y'] + [f'volume_D{D:.2f}' for D in self.possible_diameters]
                    writer.writerow(header)
                    print(f"Created/Opened CSV for appending: {csv_path}")

                try:
                    while True:
                        ret, frame = cap.read()
                        if not ret:
                            print("End of input source or cannot read frame.")
                            break

                        frame_count += 1
                        if from_video:
                            # Position within the file, in seconds.
                            timestamp = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
                        else:
                            # Wall-clock time for live input.
                            timestamp = time.time()

                        # perf_counter is monotonic and high-resolution, which suits
                        # timing a single frame better than time.time().
                        started = time.perf_counter()
                        annotated_frame = self.process_frame(frame, timestamp, writer)
                        elapsed = time.perf_counter() - started
                        fps_current = 1.0 / elapsed if elapsed > 0 else 0

                        cv2.putText(annotated_frame, f"FPS: {fps_current:.2f}", (10, 30),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

                        if out is not None:
                            out.write(annotated_frame)

                        cv2.imshow('Syringe Volume Measurement (DeepSORT v1.3.2)', annotated_frame)

                        key = cv2.waitKey(1) & 0xFF
                        if key == ord('q'):
                            print("Quitting...")
                            break
                        elif key == ord('p'):  # Pause functionality
                            print("Paused. Press any key to continue...")
                            cv2.waitKey(-1)

                except Exception as e:
                    # Top-level boundary of the loop: report and fall through to cleanup.
                    print(f"\nAn error occurred during processing loop (frame {frame_count}): {e}")
                    import traceback
                    traceback.print_exc()
        finally:
            print(f"Processed {frame_count} frames.")
            cap.release()
            if out is not None:
                out.release()
                print(f"Output video saved to {output_path}")
            cv2.destroyAllWindows()
            print("Resources released.")


if __name__ == "__main__":
    # Dependencies expected to be installed before running this script:
    #   pip install ultralytics opencv-python deep-sort-realtime==1.3.2 torch numpy torchvision
    # A GPU embedder needs a PyTorch build with CUDA/MPS support, and some
    # embedders may additionally require onnxruntime (or onnxruntime-gpu).
    try:
        # Tracker settings: tracks are kept a little longer (max_age) and the
        # appearance match is slightly looser (max_cosine_distance) than the
        # constructor defaults; the rest restate the defaults explicitly.
        estimator = SyringeVolumeEstimator(
            max_age=50,
            n_init=3,
            max_cosine_distance=0.3,
            nn_budget=None,
            embedder="mobilenet",
            half=True,
            bgr=True,
        )

        # Input selection: True -> live webcam, False -> the video file below.
        USE_WEBCAM = True

        if not USE_WEBCAM:
            # Edit this path to point at the video to process.
            video_file = 'IMG_4952.mov'
            if not os.path.exists(video_file):
                print(f"Video file not found: {video_file}. Check the path.")
                print("Please set the correct 'video_file' path in the script.")
            else:
                estimator.run(input_source='video', video_path=video_file)
        else:
            estimator.run(input_source='webcam')

    except FileNotFoundError as missing:
        # Raised e.g. when the YOLO weights file is absent.
        print(f"Initialization failed: {missing}")
    except ImportError as import_problem:
        print(f"Import Error: {import_problem}. Have you installed all required libraries (ultralytics, opencv-python, deep-sort-realtime==1.3.2, torch, numpy, torchvision)?")
    except Exception as unexpected:
        # Last-resort boundary: report with a full traceback.
        print(f"An unexpected error occurred: {unexpected}")
        import traceback
        traceback.print_exc()

Loading YOLO model from: runs/pose/train-pose11n-v25-P50/weights/best.pt
Using device: mps
DeepSORT tracker initialized successfully with specified parameters.
Starting webcam input...


2025-03-27 19:48:05.942 python[16327:5316961] +[IMKClient subclass]: chose IMKClient_Modern
2025-03-27 19:48:05.942 python[16327:5316961] +[IMKInputSession subclass]: chose IMKInputSession_Modern
