In [None]:
import sys
import time
import random
import datetime
import cv2
import numpy as np
import mss
import psutil
import pyautogui # Keep for potential future use, but not used in current drawing logic
from PySide6 import QtCore, QtGui, QtWidgets
from PySide6.QtCore import Qt, QTimer, QPoint, QPointF, QThread, Signal, Slot, QMutex, QMutexLocker, QObject, QRect
from PySide6.QtGui import QColor, QFont, QPainter, QPen, QFontDatabase, QImage, QPixmap, QPolygonF
from PySide6.QtWidgets import (QApplication, QWidget, QMainWindow, QVBoxLayout, QGridLayout,
                             QComboBox, QSpinBox, QDoubleSpinBox, QPushButton, QLabel,
                             QTextEdit, QCheckBox, QGroupBox)
import logging
from ultralytics import YOLO
import torch
import os
import collections # Added for deque

# --- Configurations ---
# Module-wide tunables. Names ending in _MS are durations in milliseconds.
UPDATE_INTERVAL_MS = 500  # Interval for updating HUD text elements
SCANLINE_SPEED_MS = 15    # Speed of the scanline effect
CAPTURE_INTERVAL_MS = 10  # Target time between screen grabs (10 ms, i.e. at most 100 FPS)
SYSTEM_INFO_INTERVAL_MS = 1000 # Interval for updating system info
DEFAULT_CONFIDENCE_THRESHOLD = 0.4 # Default `conf` value the YOLO workers pass to the model
DEFAULT_NMS_THRESHOLD = 0.3        # Default NMS threshold; passed to the model as `iou`
DEFAULT_POSE_CONFIDENCE_THRESHOLD = 0.5 # Default `conf` value used by the pose-estimation worker
FONT_NAME = "Press Start 2P"
FALLBACK_FONT = "Monospace"
FONT_SIZE_SMALL = 10
FONT_SIZE_MEDIUM = 12
# Colour palette as QColor(r, g, b)
RED_COLOR = QColor(255, 0, 0)
GREEN_COLOR = QColor(0, 255, 0)
BLUE_COLOR = QColor(0, 0, 255)
YELLOW_COLOR = QColor(255, 255, 0)
CYAN_COLOR = QColor(0, 255, 255)
MAGENTA_COLOR = QColor(255, 0, 255)
ORANGE_COLOR = QColor(255, 165, 0)
TEXT_COLOR = QColor(255, 0, 0) # Same RGB value as RED_COLOR

# --- Smoothing and Prediction Parameters ---
TARGET_CENTER_SMOOTHING_FACTOR = 0.3 # Alpha for EMA (lower = smoother, more lag) - For position
PATH_HISTORY_LENGTH = 30             # Number of past points to store for path tracking
# --- Velocity Smoothing ---
VELOCITY_SMOOTHING_FACTOR = 0.4      # Alpha for EMA for the velocity vector
VELOCITY_HISTORY_LENGTH = 15         # Number of recent velocity samples to store for acceleration calculation
# --- Acceleration Smoothing ---
ACCEL_SMOOTHING_FACTOR = 0.5         # Alpha for EMA for acceleration vector
# --- Trajectory Prediction ---
TRAJECTORY_PREDICTION_POINTS = 5     # Number of recent points for an initial velocity estimate (kept for reference; direct velocity smoothing is primary)
TRAJECTORY_PREDICTION_DURATION = 0.5 # Seconds into the future to predict
PREDICTION_TIME_STEP = 1.0 / 30.0    # Time step, in seconds, for the physics prediction loop (30 steps per second)


# --- Model Names ---
# Weight file names handed to ultralytics' YOLO() by the worker threads.
OBJECT_DETECTION_MODEL = 'yolov8n.pt'
POSE_ESTIMATION_MODEL = 'yolov8n-pose.pt'

# --- Pose Estimation Keypoint Connections (COCO format) ---
# Each pair is (keypoint index, keypoint index) describing one skeleton segment.
# Body-part labels below assume the standard COCO 17-keypoint order
# (0 nose ... 16 right ankle) -- confirm against the pose model's output.
POSE_CONNECTIONS = [
    (0, 1), (0, 2), (1, 3), (2, 4),  # Head: nose-eyes, eyes-ears
    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),  # Shoulder line and both arms
    (11, 12), (5, 11), (6, 12), # Hip line and shoulders to hips
    (11, 13), (13, 15), (12, 14), (14, 16)  # Legs
]
# Six-colour palette for pose drawing: red, green, blue, yellow, cyan, magenta.
POSE_COLORS = [QColor(255, 0, 0), QColor(0, 255, 0), QColor(0, 0, 255),
               QColor(255, 255, 0), QColor(0, 255, 255), QColor(255, 0, 255)]


# --- Logging Setup ---
# Root logger at INFO; the format includes the thread name so messages from
# the worker threads can be told apart.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(threadName)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Helper Functions ---
def random_hex(length):
    """Build a string of `length` random uppercase hexadecimal characters.

    A non-positive `length` yields an empty string.
    """
    alphabet = 'ABCDEF0123456789'
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)

def get_gpu_info():
    """Fetch the name and memory usage of CUDA device 0.

    Returns:
        A ``(gpu_name, mem_usage)`` tuple of strings. ``mem_usage`` is
        formatted as ``"used/total GB"``. When ``torch.cuda.mem_get_info`` is
        missing (older PyTorch) the figure is the memory allocated by this
        process and carries an ``" (Allocated)"`` suffix. If reading memory
        fails, ``mem_usage`` is ``"N/A (Error reading memory)"``. Without
        CUDA the result is ``("N/A (CUDA not available)", "N/A")``.
    """
    if not torch.cuda.is_available():
        return "N/A (CUDA not available)", "N/A"

    gpu_name = torch.cuda.get_device_name(0)
    bytes_per_gib = 1024 ** 3
    try:
        try:
            # mem_get_info returns (free, total) in bytes -- in that order.
            free_mem, total_mem = torch.cuda.mem_get_info(0)
            used_mem = total_mem - free_mem
            suffix = ""
        except AttributeError:
            # Fallback for older PyTorch versions without mem_get_info
            total_mem = torch.cuda.get_device_properties(0).total_memory
            used_mem = torch.cuda.memory_allocated(0)
            suffix = " (Allocated)"
        mem_usage = f"{(used_mem / bytes_per_gib):.1f}/{(total_mem / bytes_per_gib):.1f} GB{suffix}"
    except Exception as e:
        # Memory info is cosmetic; report the failure instead of propagating it
        logger.error(f"Error getting GPU memory info: {e}", exc_info=True)
        mem_usage = "N/A (Error reading memory)"
    return gpu_name, mem_usage

# --- Custom Logging Handler ---
class QTextEditLogger(logging.Handler):
    """Logging handler that appends formatted records to a QTextEdit widget.

    Records may be emitted from any thread; the text is handed to the widget
    through a queued ``invokeMethod`` call rather than by touching the widget
    directly.
    """
    def __init__(self, text_edit):
        """Attach the handler to `text_edit` and install the message format."""
        super().__init__()
        self.text_edit = text_edit
        self.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

    def emit(self, record):
        """Format `record` and queue it for appending to the widget.

        Any failure (formatting error, widget no longer usable, ...) is routed
        to ``handleError`` so that a logging call never raises into the code
        that issued it.
        """
        try:
            msg = self.format(record)
            # Use invokeMethod for thread-safe GUI updates
            QtCore.QMetaObject.invokeMethod(
                self.text_edit,
                "append",
                QtCore.Qt.QueuedConnection, # Ensure update happens in GUI thread
                QtCore.Q_ARG(str, msg)
            )
        except Exception:
            self.handleError(record)

# --- Screen Capture Thread ---
class ScreenCaptureThread(QThread):
    """Captures the screen at a specified interval.

    Each loop iteration grabs `monitor_spec` with ``mss``, converts the image
    to BGR and emits it on ``frame_ready`` together with a ``time.time()``
    timestamp. Capture errors are reported on ``status_update`` and the loop
    keeps running.
    """
    frame_ready = Signal(np.ndarray, float) # Emit frame (BGR) and capture timestamp
    status_update = Signal(str)

    def __init__(self, monitor_spec):
        """Remember the region to grab; the ``mss`` handle is created in run()."""
        super().__init__()
        self.monitor_spec = monitor_spec
        self.running = False
        self.sct = None
        self._capture_interval_ms = CAPTURE_INTERVAL_MS
        self._lock = QMutex() # Mutex for thread-safe access to interval

    def _grab_bgr_frame(self):
        """Grab one frame and return it as a BGR array.

        Returns None (after logging a warning) when the grabbed image has a
        channel count other than 3 or 4.
        """
        sct_img = self.sct.grab(self.monitor_spec) # Grab the screen region
        frame = np.array(sct_img) # Convert to numpy array
        channels = frame.shape[2]
        if channels == 4: # BGRA
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
        if channels == 3: # BGR (or RGB, assume BGR)
            return frame
        logger.warning(f"Unexpected frame channel count: {channels}")
        return None

    def run(self):
        """Capture loop: grab, emit, then sleep out the rest of the interval."""
        self.running = True
        try:
            self.sct = mss.mss() # Initialize screen capture object
            logger.info(f"Screen capture started for monitor: {self.monitor_spec}")

            while self.running:
                capture_start_time = time.perf_counter()
                try:
                    frame_bgr = self._grab_bgr_frame()
                    if frame_bgr is not None:
                        current_time = time.time() # Use system time for timestamping frame data
                        self.frame_ready.emit(frame_bgr, current_time)
                except mss.ScreenShotError as e:
                    self.status_update.emit(f"Screen capture error: {e}")
                    logger.error(f"Screen capture error: {e}")
                    time.sleep(1) # Wait before retrying on screen capture error
                except Exception as e:
                    self.status_update.emit(f"Unexpected screen capture error: {e}")
                    logger.error(f"Unexpected screen capture error: {e}", exc_info=True)
                    time.sleep(1) # Wait on other errors

                # Pace every iteration -- including ones that produced no frame --
                # so a run of unusable frames cannot turn into a busy loop.
                elapsed = time.perf_counter() - capture_start_time
                with QMutexLocker(self._lock): # Lock access to interval variable
                    interval_sec = self._capture_interval_ms / 1000.0
                sleep_time = max(0, interval_sec - elapsed)
                if sleep_time > 0:
                    time.sleep(sleep_time) # Sleep for the remaining interval time

        finally:
            if self.sct:
                self.sct.close() # Clean up screen capture object
            logger.info("Screen capture stopped.")

    def stop(self):
        """Signals the thread to stop running."""
        self.running = False

    @Slot(int)
    def update_capture_interval(self, interval):
        """Updates the screen capture interval in milliseconds (thread-safe).

        Non-positive values are ignored with a warning.
        """
        with QMutexLocker(self._lock): # Ensure thread-safe update
            if interval > 0:
                logger.info(f"Updating capture interval to {interval} ms")
                self._capture_interval_ms = interval
            else:
                logger.warning(f"Ignoring invalid capture interval: {interval} ms")

# --- Base Worker Thread for YOLO Models ---
class BaseYoloWorker(QThread):
    """Base class for YOLO detection and pose estimation threads.

    The thread loads `model_name` when it starts, then repeatedly takes the
    most recent frame handed to set_frame(), runs inference on it and passes
    the output to process_results(), which subclasses implement. Only the
    newest frame is kept: a frame that arrives before the previous one was
    consumed replaces it.
    """
    status_update = Signal(str) # Signal for status messages

    def __init__(self, model_name):
        """Initialise worker state; the model itself is loaded inside run()."""
        super().__init__()
        self.model_name = model_name
        self.model = None # YOLO model instance
        self.input_frame = None # Latest frame received
        self.frame_timestamp = 0.0 # Timestamp of the latest frame
        self.running = False # Flag to control the thread loop
        self._enabled = True # Flag to enable/disable processing
        self._confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD
        self._nms_threshold = DEFAULT_NMS_THRESHOLD # Used by object detection
        self._frame_lock = QMutex() # Mutex for thread-safe frame access
        self.device = None # PyTorch device ('cuda' or 'cpu')
        self.processing_time_ms = 0.0 # Time taken for the last inference

    def load_model(self):
        """Loads the YOLO model onto the appropriate device.

        Returns:
            True on success. On failure the error is logged and emitted on
            ``status_update``, ``model`` and ``device`` are reset to None and
            False is returned.
        """
        self.status_update.emit(f"Loading model {self.model_name}...")
        logger.info(f"Attempting to load YOLO model: {self.model_name}")
        try:
            self.model = YOLO(self.model_name) # Load the model using ultralytics library
            # Determine device (GPU if available, else CPU)
            if torch.cuda.is_available():
                self.device = torch.device('cuda')
                self.model.to(self.device) # Move model to GPU
                device_name = torch.cuda.get_device_name(0)
                logger.info(f"Model '{self.model_name}' loaded on GPU: {device_name}.")
                self.status_update.emit(f"Model '{self.model_name}' loaded on GPU.")
            else:
                self.device = torch.device('cpu')
                self.model.to(self.device) # Move model to CPU
                logger.info(f"Model '{self.model_name}' loaded on CPU.")
                self.status_update.emit(f"Model '{self.model_name}' loaded on CPU.")
            return True
        except Exception as e:
            error_msg = f"Failed to load model '{self.model_name}': {e}"
            logger.error(error_msg, exc_info=True)
            self.status_update.emit(error_msg)
            self.model = None
            self.device = None
            return False

    @Slot(np.ndarray, float)
    def set_frame(self, frame, timestamp):
        """Receives a new frame for processing (thread-safe).

        Frames are dropped while the worker is disabled: run() would not
        consume them, so copying them would only cost time and memory.
        """
        if not self._enabled:
            return
        with QMutexLocker(self._frame_lock):
            # Keep only the latest frame and its timestamp
            self.input_frame = frame.copy() # Copy frame to avoid issues if source changes
            self.frame_timestamp = timestamp

    def run(self):
        """Main processing loop for the thread."""
        # Raise the flag before the potentially slow model load, so a stop()
        # issued while loading is honoured instead of being overwritten.
        self.running = True
        if not self.load_model(): # Attempt to load the model first
            self.running = False
            return # Exit if model loading fails

        logger.info(f"{self.__class__.__name__} thread started.")
        while self.running:
            if not self._enabled:
                # Thread is disabled, sleep longer
                self.msleep(50)
                continue

            # Safely take (and consume) the latest frame
            with QMutexLocker(self._frame_lock):
                frame_to_process = self.input_frame
                timestamp_to_process = self.frame_timestamp
                self.input_frame = None

            if frame_to_process is None or self.model is None:
                # No frame available, sleep briefly to avoid busy-waiting
                self.msleep(5)
                continue

            start_time = time.perf_counter()
            try:
                # Perform inference using the loaded model
                results = self.model(
                    frame_to_process,
                    conf=self._confidence_threshold, # Apply confidence threshold
                    iou=self._nms_threshold, # Apply NMS threshold
                    verbose=False, # Suppress ultralytics console output
                    device=self.device # Specify the device for inference
                )
                # Record the time before process_results runs: subclasses report
                # get_processing_time() from there, and it must describe this frame.
                self.processing_time_ms = (time.perf_counter() - start_time) * 1000
                # Process the results (implemented in subclasses)
                self.process_results(results, timestamp_to_process)
            except Exception as e:
                self.processing_time_ms = (time.perf_counter() - start_time) * 1000
                logger.error(f"{self.__class__.__name__} error: {e}", exc_info=True)
                self.status_update.emit(f"{self.__class__.__name__} error: {e}")

        # Cleanup when the thread loop exits
        logger.info(f"{self.__class__.__name__} thread stopped.")
        self.model = None # Release model object
        if torch.cuda.is_available():
            torch.cuda.empty_cache() # Clear GPU cache if CUDA was used

    def process_results(self, results, timestamp):
        """Placeholder: Subclasses must implement this to process model output.

        Args:
            results: Whatever the model call in run() returned.
            timestamp: Timestamp that accompanied the processed frame.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement process_results")

    def stop(self):
        """Signals the thread to stop running."""
        self.running = False

    @Slot(bool)
    def set_enabled(self, enabled):
        """Enables or disables processing in the thread."""
        self._enabled = enabled
        logger.info(f"{self.__class__.__name__} {'enabled' if enabled else 'disabled'}")
        if not enabled:
            # Clear potential pending frame when disabled to prevent processing old data
            with QMutexLocker(self._frame_lock):
                self.input_frame = None

    @Slot(float)
    def update_confidence_threshold(self, threshold):
        """Updates the confidence threshold used for inference."""
        self._confidence_threshold = threshold
        logger.info(f"{self.__class__.__name__} confidence threshold updated to {threshold:.2f}")

    @Slot(float)
    def update_nms_threshold(self, threshold):
        """Updates the NMS threshold (passed to the model as `iou`)."""
        self._nms_threshold = threshold
        logger.info(f"{self.__class__.__name__} NMS threshold updated to {threshold:.2f}")

    def get_processing_time(self):
        """Returns the inference time of the most recent frame, in milliseconds."""
        return self.processing_time_ms

# --- Detection Thread ---
class DetectionThread(BaseYoloWorker):
    """Worker thread that runs a YOLO object-detection model on incoming frames."""
    # Payload: list of (label, confidence, (x1, y1, x2, y2)), frame timestamp,
    # and the value of get_processing_time() at emission.
    detections_ready = Signal(list, float, float)

    def __init__(self, model_name=OBJECT_DETECTION_MODEL):
        """Create the worker; thresholds keep the defaults set by the base class."""
        super().__init__(model_name)

    def _describe_box(self, box):
        """Turn one detected box into a (label, confidence, (x1, y1, x2, y2)) tuple."""
        # Tensors are moved to the CPU and converted to plain Python values so
        # the tuple can travel through the signal.
        left, top, right, bottom = box.xyxy[0].cpu().numpy()
        score = box.conf[0].cpu().numpy()
        class_index = int(box.cls[0].cpu().numpy())
        # Unknown class ids fall back to an "ID:<n>" label
        name = self.model.names.get(class_index, f"ID:{class_index}")
        corners = (int(left), int(top), int(right), int(bottom))
        return (name, float(score), corners)

    def process_results(self, results, timestamp):
        """Convert the model output into plain tuples and emit `detections_ready`.

        An empty list is emitted when `results` or its first entry is falsy.
        """
        found = []
        if results and results[0]:
            for box in results[0].boxes:
                found.append(self._describe_box(box))

        self.detections_ready.emit(found, timestamp, self.get_processing_time())

# --- Pose Estimation Thread ---
class PoseEstimationThread(BaseYoloWorker):
    """Worker thread that runs a YOLO pose model on incoming frames."""
    # Payload: list of (keypoints_array, (x1, y1, x2, y2), confidence), frame
    # timestamp, and the value of get_processing_time() at emission.
    # keypoints_array is expected to be N x 3 (x, y, confidence) for N keypoints.
    poses_ready = Signal(list, float, float)

    def __init__(self, model_name=POSE_ESTIMATION_MODEL):
        """Create the worker with the pose-specific confidence threshold."""
        super().__init__(model_name)
        self._confidence_threshold = DEFAULT_POSE_CONFIDENCE_THRESHOLD
        # The NMS threshold stays at the shared default
        self._nms_threshold = DEFAULT_NMS_THRESHOLD

    def process_results(self, results, timestamp):
        """Convert the model output into plain tuples and emit `poses_ready`.

        An empty list is emitted when `results` or its first entry is falsy.
        """
        collected = []
        if results and results[0]:
            instances = results[0].keypoints
            person_boxes = results[0].boxes

            # The i-th keypoint set is paired with the i-th bounding box
            for index, instance in enumerate(instances):
                # Keypoint tensor of this instance, moved to the CPU
                keypoint_array = instance.data[0].cpu().numpy()

                person = person_boxes[index]
                left, top, right, bottom = person.xyxy[0].cpu().numpy()
                # Confidence of the box that belongs to this pose
                score = person.conf[0].cpu().numpy()

                corners = (int(left), int(top), int(right), int(bottom))
                collected.append((keypoint_array, corners, float(score)))

        self.poses_ready.emit(collected, timestamp, self.get_processing_time())


# --- Optical Flow Thread ---
class OpticalFlowThread(QThread):
    """Calculates sparse optical flow (GPU accelerated if OpenCV CUDA is available).

    Corner features are detected on a frame and tracked into the next one
    with pyramidal Lucas-Kanade. Every successfully tracked point becomes a
    ``(start_point, end_point)`` pair of integer tuples; the list of pairs is
    emitted on ``flow_ready``. The worker is disabled by default.
    """
    flow_ready = Signal(list) # List of (start_point_tuple, end_point_tuple)
    status_update = Signal(str)

    def __init__(self):
        """Set up tracking state and probe for OpenCV CUDA support.

        Note: the ``status_update`` emissions below happen during
        construction, i.e. before a caller holding the new object can have
        connected a slot -- presumably only the log lines are ever seen.
        """
        super().__init__()
        self.running = False
        self._enabled = False # Disabled by default
        self._frame_lock = QMutex()
        self.current_frame = None
        self.use_gpu = False # Flag indicating if GPU is used
        # GPU components
        self.gpu_detector = None
        self.gpu_lk_flow = None
        self.prev_gray_gpu = None
        self.prev_points_gpu = None
        # CPU components and parameters
        # criteria is (type, max iteration count, epsilon)
        self.cpu_lk_params = dict(winSize=(21, 21), maxLevel=3,
                                  criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))
        self.cpu_feature_params = dict(maxCorners=100, qualityLevel=0.2, minDistance=7, blockSize=7)
        self.prev_gray_cpu = None
        self.prev_points_cpu = None
        self.processing_time_ms = 0.0

        # --- Check for OpenCV CUDA support ---
        try:
            if cv2.cuda.getCudaEnabledDeviceCount() > 0:
                logger.info("OpenCV CUDA detected. Attempting to use GPU for Optical Flow.")
                # Ensure parameters match CPU version if desired, converting types as needed
                gpu_feature_params = self.cpu_feature_params.copy()
                gpu_feature_params['qualityLevel'] = float(gpu_feature_params['qualityLevel'])
                gpu_feature_params['minDistance'] = float(gpu_feature_params['minDistance'])

                # Create GPU detector and flow calculator
                self.gpu_detector = cv2.cuda.createGoodFeaturesToTrackDetector(cv2.CV_8UC1, **gpu_feature_params)
                self.gpu_lk_flow = cv2.cuda.SparsePyrLKOpticalFlow_create(
                    winSize=self.cpu_lk_params['winSize'],
                    maxLevel=self.cpu_lk_params['maxLevel'],
                    # Index 1 is the iteration count; index 2 is the epsilon
                    iters=self.cpu_lk_params['criteria'][1]
                )
                self.use_gpu = True
                logger.info("Successfully initialized GPU Optical Flow components.")
                self.status_update.emit("Optical Flow: Using GPU")
            else:
                logger.info("No OpenCV CUDA devices found. Using CPU for Optical Flow.")
                self.status_update.emit("Optical Flow: Using CPU")
        except AttributeError:
            logger.warning("OpenCV CUDA module not found or unavailable in this build. Using CPU for Optical Flow.")
            self.status_update.emit("Optical Flow: Using CPU (CUDA module unavailable)")
            self.use_gpu = False
        except Exception as e:
            logger.error(f"Error initializing OpenCV CUDA components: {e}. Falling back to CPU.", exc_info=True)
            self.use_gpu = False
            self.status_update.emit("Optical Flow: Error initializing GPU, using CPU")

    @Slot(np.ndarray, float) # Accept timestamp, though not directly used in flow calc
    def set_frame(self, frame, timestamp):
        """Receives a new frame for processing (thread-safe).

        Frames are dropped while the worker is disabled: run() would not
        consume them, so copying them would only cost time and memory.
        """
        if not self._enabled:
            return
        with QMutexLocker(self._frame_lock):
            self.current_frame = frame.copy()

    def run(self):
        """Main processing loop for optical flow calculation."""
        self.running = True
        logger.info(f"Optical Flow thread started (GPU Enabled: {self.use_gpu}).")
        while self.running:
            if self._enabled: # Only process if enabled
                frame = None
                # Safely get the latest frame
                with QMutexLocker(self._frame_lock):
                    if self.current_frame is not None:
                        frame = self.current_frame
                        self.current_frame = None # Consume frame

                if frame is not None:
                    start_time = time.perf_counter()
                    try:
                        # Choose GPU or CPU path
                        if self.use_gpu:
                            self.run_gpu(frame)
                        else:
                            self.run_cpu(frame)
                    except cv2.error as e:
                        # Handle OpenCV specific errors
                        logger.error(f"OpenCV error in Optical Flow: {e}", exc_info=True)
                        self.status_update.emit(f"Optical Flow Error: {e}")
                        self._recover_from_error()
                    except Exception as e:
                        # Handle other unexpected errors
                        logger.error(f"Unexpected error in Optical Flow: {e}", exc_info=True)
                        self.status_update.emit(f"Optical Flow Error: {e}")
                        self._recover_from_error()
                    finally:
                        end_time = time.perf_counter()
                        self.processing_time_ms = (end_time - start_time) * 1000
                else:
                    self.msleep(5) # No frame, sleep briefly
            else:
                # Disabled: Reset state if it wasn't already reset
                if self.prev_gray_cpu is not None or self.prev_gray_gpu is not None:
                    self.reset_state()
                self.msleep(50) # Sleep longer when disabled

        logger.info("Optical Flow thread stopped.")
        self.reset_state() # Ensure state is clean on exit

    def _recover_from_error(self):
        """Reset tracking state after a failed frame and pause briefly.

        If the failure happened on the GPU path, the worker switches to the
        CPU path for the rest of its life, so a GPU path that cannot work on
        this OpenCV build does not fail again on every frame.
        """
        self.reset_state()
        if self.use_gpu:
            self.use_gpu = False
            logger.warning("Optical Flow GPU path failed; falling back to CPU.")
            self.status_update.emit("Optical Flow: GPU path failed, using CPU")
        self.msleep(100) # Wait briefly after error

    def run_gpu(self, frame):
        """Calculates optical flow for `frame` using the OpenCV CUDA objects.

        Tracks the points kept from the previous call into `frame`, emits the
        surviving (old, new) pairs on ``flow_ready``, re-detects features when
        none were tracked or fewer than half of ``maxCorners`` remain, and
        stores `frame`'s grayscale image for the next call.
        """
        frame_gpu = cv2.cuda_GpuMat() # Create GPU matrix
        frame_gpu.upload(frame) # Upload frame data to GPU
        gray_gpu = cv2.cuda.cvtColor(frame_gpu, cv2.COLOR_BGR2GRAY) # Convert to grayscale on GPU
        flow_vectors = []
        points_to_track = False # Flag if we have points from the previous frame

        # --- Track existing points ---
        if self.prev_gray_gpu is not None and self.prev_points_gpu is not None and not self.prev_points_gpu.empty():
            points_to_track = True
            # Ensure prev_points_gpu is float32 (required by LK flow)
            if self.prev_points_gpu.type() != cv2.CV_32FC2:
                logger.warning("prev_points_gpu is not CV_32FC2, attempting conversion.")
                try:
                    temp_cpu = self.prev_points_gpu.download().astype(np.float32)
                    temp_gpu = cv2.cuda_GpuMat()
                    temp_gpu.upload(temp_cpu)
                    self.prev_points_gpu = temp_gpu
                    if self.prev_points_gpu.empty(): # Check if conversion failed
                        points_to_track = False
                        logger.error("Failed to convert prev_points_gpu to CV_32FC2.")
                except Exception as e:
                    points_to_track = False
                    logger.error(f"Error converting prev_points_gpu: {e}")

            if points_to_track:
                # Calculate optical flow on GPU
                next_points_gpu, status_gpu, _err_gpu = self.gpu_lk_flow.calc(
                    self.prev_gray_gpu, gray_gpu, self.prev_points_gpu, None
                )

                # Process results if calculation was successful
                if next_points_gpu is not None and status_gpu is not None:
                    status = status_gpu.download().flatten() # Download status vector
                    # Try to filter on the GPU first; if the GpuMat cannot be
                    # indexed, filter on the CPU instead.
                    # NOTE(review): a binding without GpuMat subscripting is
                    # expected to raise TypeError rather than cv2.error, so both
                    # are routed to the CPU filter -- confirm on the target build.
                    try:
                        prev_points_gpu_filtered = self.prev_points_gpu[status == 1]
                        next_points_gpu_filtered = next_points_gpu[status == 1]
                    except (cv2.error, TypeError) as e:
                        # Fallback: download all, then filter
                        logger.warning(f"GPU Mat indexing failed (likely older OpenCV): {e}. Using CPU filtering.")
                        prev_pts_cpu = self.prev_points_gpu.download().reshape(-1, 2)
                        next_pts_cpu = next_points_gpu.download().reshape(-1, 2)
                        good_old_cpu = prev_pts_cpu[status == 1]
                        good_new_cpu = next_pts_cpu[status == 1]
                        if len(good_new_cpu) > 0:
                            prev_points_gpu_filtered = cv2.cuda_GpuMat()
                            prev_points_gpu_filtered.upload(good_old_cpu.astype(np.float32).reshape(-1, 1, 2))
                            next_points_gpu_filtered = cv2.cuda_GpuMat()
                            next_points_gpu_filtered.upload(good_new_cpu.astype(np.float32).reshape(-1, 1, 2))
                        else:
                            prev_points_gpu_filtered = cv2.cuda_GpuMat() # Empty
                            next_points_gpu_filtered = cv2.cuda_GpuMat() # Empty

                    # Download filtered points if any survived
                    if not next_points_gpu_filtered.empty() and not prev_points_gpu_filtered.empty():
                        good_new = next_points_gpu_filtered.download().reshape(-1, 2)
                        good_old = prev_points_gpu_filtered.download().reshape(-1, 2)

                        # Create flow vectors (start_pt, end_pt)
                        flow_vectors = [(tuple(map(int, p)), tuple(map(int, q))) for p, q in zip(good_old, good_new)]
                        # Update previous points to the successfully tracked new points
                        self.prev_points_gpu = next_points_gpu_filtered
                    else:
                        points_to_track = False # Lost all points
                        self.prev_points_gpu = None # Reset points
                else:
                    points_to_track = False # Calculation failed
                    self.prev_points_gpu = None

        # --- Detect new features if needed ---
        # Detect if no points were tracked or if point count dropped significantly
        detect_new = False
        if not points_to_track:
            detect_new = True
        # NOTE(review): assumes the GpuMat binding exposes rows() -- confirm
        # against the installed OpenCV build.
        elif self.prev_points_gpu is not None and self.prev_points_gpu.rows() < self.cpu_feature_params['maxCorners'] * 0.5:
            detect_new = True

        if detect_new:
            # Detect good features to track on the GPU
            detected_points_gpu_mat = self.gpu_detector.detect(gray_gpu, None) # Mask is None

            # Process detected features
            if detected_points_gpu_mat is not None and not detected_points_gpu_mat.empty():
                # Intended layout for LK flow: N x 1 x 2, float32.
                # NOTE(review): confirm GpuMat.reshape accepts this NumPy-style
                # argument list.
                self.prev_points_gpu = detected_points_gpu_mat.reshape(-1, 1, 2)
                if self.prev_points_gpu.type() != cv2.CV_32FC2:
                    # Convert if necessary
                    try:
                        temp_cpu = self.prev_points_gpu.download().astype(np.float32)
                        temp_gpu = cv2.cuda_GpuMat()
                        temp_gpu.upload(temp_cpu)
                        self.prev_points_gpu = temp_gpu
                    except Exception as e:
                        self.prev_points_gpu = None # Reset if conversion fails
                        logger.error(f"Error converting detected GPU points: {e}")
            else:
                self.prev_points_gpu = None # No features found

        # Update previous frame's grayscale image
        self.prev_gray_gpu = gray_gpu
        # Emit calculated flow vectors
        if flow_vectors:
            self.flow_ready.emit(flow_vectors)

    def run_cpu(self, frame):
        """Calculates optical flow for `frame` using the CPU OpenCV functions.

        Same procedure as run_gpu(): track the stored points, emit the
        surviving pairs, re-detect features when needed, keep the grayscale
        frame for the next call.
        """
        gray_cpu = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) # Convert to grayscale
        flow_vectors = []
        points_to_track = False

        # --- Track existing points ---
        if self.prev_gray_cpu is not None and self.prev_points_cpu is not None and len(self.prev_points_cpu) > 0:
            points_to_track = True
            # Calculate optical flow using Lucas-Kanade method
            next_points_cpu, status, _err = cv2.calcOpticalFlowPyrLK(
                self.prev_gray_cpu, gray_cpu, self.prev_points_cpu, None, **self.cpu_lk_params
            )

            # Filter points based on status
            if next_points_cpu is not None and status is not None:
                good_new = next_points_cpu[status == 1]
                good_old = self.prev_points_cpu[status == 1]

                if len(good_new) > 0:
                    # Create flow vectors
                    flow_vectors = [(tuple(map(int, p)), tuple(map(int, q))) for p, q in zip(good_old, good_new)]
                    # Update previous points to the successfully tracked new points
                    self.prev_points_cpu = good_new.reshape(-1, 1, 2)
                else:
                    points_to_track = False # Lost all points
                    self.prev_points_cpu = None # Reset points
            else:
                points_to_track = False # Calculation failed
                self.prev_points_cpu = None

        # --- Detect new features if needed ---
        detect_new = False
        if not points_to_track:
            detect_new = True
        elif self.prev_points_cpu is not None and len(self.prev_points_cpu) < self.cpu_feature_params['maxCorners'] * 0.5:
            detect_new = True

        if detect_new:
            # Detect good features to track on CPU (the result may be None)
            self.prev_points_cpu = cv2.goodFeaturesToTrack(gray_cpu, mask=None, **self.cpu_feature_params)

        # Update previous frame's grayscale image
        self.prev_gray_cpu = gray_cpu.copy()
        # Emit calculated flow vectors
        if flow_vectors:
            self.flow_ready.emit(flow_vectors)

    def stop(self):
        """Signals the thread to stop running."""
        self.running = False

    def reset_state(self):
        """Resets the internal state (previous frame, points) for both CPU and GPU."""
        self.prev_gray_gpu = None
        self.prev_points_gpu = None
        self.prev_gray_cpu = None
        self.prev_points_cpu = None

    @Slot(bool)
    def set_enabled(self, enabled):
        """Enables or disables processing in the thread."""
        if self._enabled != enabled:
            self._enabled = enabled
            logger.info(f"Optical Flow {'enabled' if enabled else 'disabled'}")
            if not enabled:
                # Reset state when disabling to clear old points/frames
                self.reset_state()

    def get_processing_time(self):
        """Returns the processing time of the last calculation, in milliseconds."""
        return self.processing_time_ms


# --- Depth Estimation Thread ---
class DepthEstimationThread(QThread):
    """Worker thread that runs MiDaS depth estimation on submitted frames.

    Frames are handed over through :meth:`set_frame`; only the most recent
    one is kept. While the thread is enabled, each pending frame is run
    through the model and the result is emitted via ``depth_ready`` as a
    float32 array normalized to the 0-1 range. Load progress and errors
    are reported through ``status_update``.
    """
    depth_ready = Signal(np.ndarray)  # Emits normalized depth map (0-1, CPU numpy array)
    status_update = Signal(str)       # Emits status / error text

    def __init__(self, model_type="MiDaS_small"):  # Default to small model
        """Store the configuration; the model itself is loaded inside run().

        Args:
            model_type: Name of the MiDaS entry point passed to torch.hub.
        """
        super().__init__()
        self.model_type = model_type
        self.model = None
        self.transform = None  # Preprocessing transform
        self.running = False
        self._enabled = False  # Disabled by default
        self._frame_lock = QMutex()  # Guards current_frame
        self.current_frame = None
        self.device = None  # PyTorch device
        self.processing_time_ms = 0.0

    def load_model(self):
        """Load the MiDaS model and its preprocessing transform from torch.hub.

        Returns:
            True on success. On any exception the error is logged and
            emitted via ``status_update``, ``model``/``transform``/``device``
            are reset to None and False is returned.
        """
        self.status_update.emit(f"Loading MiDaS model: {self.model_type}...")
        logger.info(f"Attempting to load MiDaS model: {self.model_type}")
        try:
            # Determine device
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
                logger.info("Using GPU for MiDaS.")
            else:
                self.device = torch.device("cpu")
                logger.info("Using CPU for MiDaS.")

            # Load model and transforms from PyTorch Hub (requires internet connection on first run)
            self.model = torch.hub.load("intel-isl/MiDaS", self.model_type, trust_repo=True)
            midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)

            # Select the transform matching the model type
            if self.model_type == "MiDaS_small":
                self.transform = midas_transforms.small_transform
            else:
                # DPT models and every other model type use dpt_transform.
                # NOTE(review): the MiDaS hub entry point appears to also
                # provide a `default_transform` for the non-DPT v2.1 model;
                # confirm whether that is the better fallback.
                self.transform = midas_transforms.dpt_transform

            self.model.to(self.device)  # Move model to the selected device
            self.model.eval()           # Evaluation mode for inference
            logger.info(f"MiDaS model '{self.model_type}' loaded on {self.device}.")
            self.status_update.emit(f"MiDaS model '{self.model_type}' loaded on {self.device}.")
            return True

        except Exception as e:
            error_msg = f"Failed to load MiDaS model '{self.model_type}': {e}"
            logger.error(error_msg, exc_info=True)
            self.status_update.emit(error_msg)
            self.model = None
            self.transform = None
            self.device = None
            return False

    @Slot(np.ndarray, float)  # Accept timestamp, though not directly used
    def set_frame(self, frame, timestamp):
        """Store a private copy of ``frame`` as the next frame to process.

        Frames are ignored while the thread is disabled, which avoids
        copying every captured frame only for run() to discard it.

        Args:
            frame: Image array; copied before being stored.
            timestamp: Accepted for signal compatibility, unused.
        """
        if frame is None or not self._enabled:
            return
        snapshot = frame.copy()  # Copy outside the lock to keep the critical section short
        with QMutexLocker(self._frame_lock):
            self.current_frame = snapshot

    def _take_frame(self):
        """Return the pending frame (or None) and clear the slot."""
        with QMutexLocker(self._frame_lock):
            frame = self.current_frame
            self.current_frame = None  # Consume frame
        return frame

    def _estimate_depth(self, frame):
        """Run one inference pass and return the normalized depth map.

        Returns None (after logging a warning) when ``frame`` is not a
        3-channel image, so 2-D frames are skipped instead of raising.
        """
        if frame.ndim != 3 or frame.shape[2] != 3:
            logger.warning("Depth estimation requires BGR input frame.")
            return None

        # MiDaS transforms expect RGB, OpenCV frames are BGR
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_batch = self.transform(img_rgb).to(self.device)

        with torch.no_grad():  # Disable gradient calculation for inference
            prediction = self.model(input_batch)
            # Resize prediction back to the input frame's height/width
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),  # Add channel dimension
                size=img_rgb.shape[:2],   # Target height, width
                mode="bicubic",
                align_corners=False,
            ).squeeze()                   # Remove channel dimension

        depth_map = prediction.cpu().numpy()
        return cv2.normalize(depth_map, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    def run(self):
        """Thread entry point: load the model, then process frames until stopped."""
        # Raise the flag BEFORE the slow model load. Setting it afterwards
        # would overwrite a stop() issued during loading and the loop would
        # then never terminate.
        self.running = True
        if not self.load_model():
            self.running = False
            return

        logger.info("Depth Estimation thread started.")
        while self.running:
            if not self._enabled:
                # Disabled: clear potential pending frame, sleep longer
                with QMutexLocker(self._frame_lock):
                    self.current_frame = None
                self.msleep(50)
                continue

            frame_to_process = self._take_frame()
            if frame_to_process is None or self.model is None or self.transform is None:
                self.msleep(5)  # Nothing to do, sleep briefly
                continue

            start_time = time.perf_counter()
            try:
                normalized_depth = self._estimate_depth(frame_to_process)
                if normalized_depth is not None:
                    self.depth_ready.emit(normalized_depth)
            except Exception as e:
                logger.error(f"Depth estimation error: {e}", exc_info=True)
                self.status_update.emit(f"Depth estimation error: {e}")
            finally:
                self.processing_time_ms = (time.perf_counter() - start_time) * 1000

        # Cleanup
        logger.info("Depth Estimation thread stopped.")
        self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def stop(self):
        """Signals the thread to stop running."""
        self.running = False

    @Slot(bool)
    def set_enabled(self, enabled):
        """Enable or disable processing; a pending frame is dropped on disable."""
        if self._enabled != enabled:
            self._enabled = enabled
            logger.info(f"Depth Estimation {'enabled' if enabled else 'disabled'}")
            if not enabled:
                # Clear potential pending frame when disabled
                with QMutexLocker(self._frame_lock):
                    self.current_frame = None

    def get_processing_time(self):
        """Returns the processing time of the last inference in milliseconds."""
        return self.processing_time_ms


# --- Overlay Widget ---
class OverlayWidget(QWidget):
    """Transparent, always-on-top widget that paints detection results and a HUD.

    Result data arrives through the ``update_*`` slots and is stored on the
    instance; ``paintEvent`` then draws, in order: depth map, scanline,
    optical flow, detection boxes/labels, pose skeletons, target path,
    predicted trajectory, crosshair and HUD text. The ``toggle_*`` slots
    switch individual layers on and off.
    """

    def __init__(self, monitor_rect, parent=None):
        """Configure the overlay window, fonts, state containers and timers.

        Args:
            monitor_rect: Object providing ``left()``, ``top()``, ``width()``
                and ``height()``; the widget geometry is set from it.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        # Configure window flags for a transparent, always-on-top overlay
        self.setWindowFlags(Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint | Qt.Tool)
        self.setAttribute(Qt.WA_TranslucentBackground)     # Enable transparency
        self.setAttribute(Qt.WA_NoSystemBackground, True)  # Don't draw default background
        # NOTE(review): Qt documents WA_PaintOnScreen as X11-only and as
        # incompatible with semi-transparent widgets - confirm it is wanted
        # together with WA_TranslucentBackground.
        self.setAttribute(Qt.WA_PaintOnScreen)

        # Set geometry to match the target monitor
        self.monitor_rect = monitor_rect
        self.setGeometry(monitor_rect.left(), monitor_rect.top(),
                         monitor_rect.width(), monitor_rect.height())

        # --- Font Setup ---
        # Attempt to load custom font, fallback to monospace
        font_id = QFontDatabase.addApplicationFont("PressStart2P-Regular.ttf")  # Font file must be present
        if font_id != -1:
            font_families = QFontDatabase.applicationFontFamilies(font_id)
            if font_families:
                self.hud_font_name = font_families[0]
                logger.info(f"Loaded font: {self.hud_font_name}")
            else:
                self.hud_font_name = FALLBACK_FONT
                logger.warning("Could not get font family name from ID, using fallback.")
        else:
            self.hud_font_name = FALLBACK_FONT
            logger.warning(f"Could not load font '{FONT_NAME}', using fallback '{FALLBACK_FONT}'.")

        self.font_small = QFont(self.hud_font_name, FONT_SIZE_SMALL)
        self.font_medium = QFont(self.hud_font_name, FONT_SIZE_MEDIUM)

        # --- Data Storage ---
        self.detections = []    # List of (label, conf, box)
        self.poses = []         # List of (keypoints_array, box, pose_conf)
        self.flow_vectors = []  # List of (start_pt, end_pt)
        self.depth_map = None   # Normalized depth map (numpy array)
        self.hud_texts = {}     # Dictionary for HUD key-value pairs

        # --- Smoothing & Prediction Data ---
        self.target_history = collections.deque(maxlen=PATH_HISTORY_LENGTH)        # (timestamp, smoothed_pos)
        self.velocity_history = collections.deque(maxlen=VELOCITY_HISTORY_LENGTH)  # (timestamp, smoothed_vel)
        self.smoothed_target_center = None          # QPointF, or None until a target was seen
        self.smoothed_velocity = QPointF(0, 0)      # Smoothed velocity vector
        self.smoothed_acceleration = QPointF(0, 0)  # Smoothed acceleration vector

        # --- Drawing Flags ---
        self.show_boxes = True
        self.show_labels = True
        self.show_paths = True
        self.show_trajectory = True
        self.show_crosshair = True
        self.show_poses = True
        self.show_flow = False
        self.show_depth = False
        self.show_hud = True

        # --- Scanline Effect ---
        self.scanline_y = 0
        self.scanline_timer = QTimer(self)
        self.scanline_timer.timeout.connect(self.update_scanline)
        self.scanline_timer.start(SCANLINE_SPEED_MS)

        # --- HUD Update Timer ---
        self.hud_update_timer = QTimer(self)
        self.hud_update_timer.timeout.connect(self.update)  # Periodic repaint for HUD text
        self.hud_update_timer.start(UPDATE_INTERVAL_MS)

        # --- Performance Metrics ---
        self.capture_fps = 0.0
        self.detection_fps = 0.0
        self.pose_fps = 0.0
        self.flow_fps = 0.0
        self.depth_fps = 0.0
        self.last_capture_time = time.perf_counter()
        self.last_detection_time = time.perf_counter()
        self.last_pose_time = time.perf_counter()
        self.last_flow_time = time.perf_counter()
        self.last_depth_time = time.perf_counter()
        self.detection_proc_time = 0.0
        self.pose_proc_time = 0.0
        self.flow_proc_time = 0.0
        self.depth_proc_time = 0.0

    # --- Internal helpers ---
    @staticmethod
    def _fps_from_interval(previous_fps, last_time):
        """Return ``(fps, now)`` based on the time elapsed since ``last_time``.

        ``previous_fps`` is returned unchanged when no time has elapsed, so
        the caller never divides by zero.
        """
        now = time.perf_counter()
        elapsed = now - last_time
        fps = 1.0 / elapsed if elapsed > 0 else previous_fps
        return fps, now

    @staticmethod
    def _select_target_box(detections):
        """Return the box of the largest-area 'person' detection, or None.

        Boxes whose area is not positive are never selected.
        """
        best_box = None
        best_area = 0
        for label, _conf, box in detections:
            if label != 'person':
                continue
            x1, y1, x2, y2 = box
            area = (x2 - x1) * (y2 - y1)
            if area > best_area:
                best_area = area
                best_box = box
        return best_box

    def _track_target(self, raw_center, timestamp):
        """Fold a new raw target centre into the smoothed position/velocity/acceleration."""
        # Position: EMA, smooth = alpha * new + (1 - alpha) * old
        if self.smoothed_target_center is None:
            self.smoothed_target_center = raw_center  # First valid detection
        else:
            self.smoothed_target_center = (
                TARGET_CENTER_SMOOTHING_FACTOR * raw_center
                + (1.0 - TARGET_CENTER_SMOOTHING_FACTOR) * self.smoothed_target_center
            )
        self.target_history.append((timestamp, self.smoothed_target_center))

        if len(self.target_history) < 2:
            return  # Not enough history for a velocity estimate

        # Velocity from the two most recent smoothed positions
        t1, p1 = self.target_history[-2]
        t2, p2 = self.target_history[-1]
        dt = t2 - t1
        if not dt > 1e-6:
            return  # Avoid division by (near) zero; keep previous estimates

        raw_velocity = (p2 - p1) / dt
        self.smoothed_velocity = (
            VELOCITY_SMOOTHING_FACTOR * raw_velocity
            + (1.0 - VELOCITY_SMOOTHING_FACTOR) * self.smoothed_velocity
        )
        self.velocity_history.append((timestamp, self.smoothed_velocity))

        if len(self.velocity_history) < 2:
            return  # Not enough history for an acceleration estimate

        # Acceleration from the two most recent smoothed velocities
        vt1, v1 = self.velocity_history[-2]
        vt2, v2 = self.velocity_history[-1]
        vdt = vt2 - vt1
        if vdt > 1e-6:
            raw_acceleration = (v2 - v1) / vdt
            self.smoothed_acceleration = (
                ACCEL_SMOOTHING_FACTOR * raw_acceleration
                + (1.0 - ACCEL_SMOOTHING_FACTOR) * self.smoothed_acceleration
            )

    # --- Slots receiving data ---
    def update_scanline(self):
        """Advance the scanline by 5 px (wrapping at the widget height) and repaint."""
        height = self.height()
        if height > 0:  # A zero-height widget would make the modulo raise
            self.scanline_y = (self.scanline_y + 5) % height
        self.update()

    @Slot(list, float, float)
    def update_detections(self, detections, timestamp, proc_time):
        """Store detections, update the tracked target and trigger a repaint.

        Args:
            detections: List of ``(label, conf, box)`` with ``box`` unpacking
                to ``x1, y1, x2, y2``.
            timestamp: Time value used for velocity/acceleration estimation.
            proc_time: Processing time shown in the HUD.
        """
        self.detections = detections
        self.detection_proc_time = proc_time

        # Target selection: largest 'person' box
        target_box = self._select_target_box(detections)

        # Compare against None explicitly: a box may be array-like, and a
        # QPointF at (0, 0) is a "null" point that can evaluate as falsy.
        if target_box is not None:
            x1, y1, x2, y2 = target_box
            self._track_target(QPointF((x1 + x2) / 2, (y1 + y2) / 2), timestamp)
        # else: no target this frame - keep the last known smoothed values

        self.detection_fps, self.last_detection_time = self._fps_from_interval(
            self.detection_fps, self.last_detection_time)

        self.update()  # Trigger repaint

    @Slot(list, float, float)
    def update_poses(self, poses, timestamp, proc_time):
        """Store pose results and repaint if the pose layer is visible."""
        self.poses = poses
        self.pose_proc_time = proc_time
        self.pose_fps, self.last_pose_time = self._fps_from_interval(
            self.pose_fps, self.last_pose_time)
        if self.show_poses:
            self.update()

    @Slot(list)
    def update_flow(self, flow_vectors):
        """Store optical flow vectors and repaint if the flow layer is visible."""
        self.flow_vectors = flow_vectors
        # FPS is approximated from the signal arrival interval
        self.flow_fps, self.last_flow_time = self._fps_from_interval(
            self.flow_fps, self.last_flow_time)
        if self.show_flow:
            self.update()

    @Slot(np.ndarray)
    def update_depth(self, depth_map):
        """Store the depth map and repaint if the depth layer is visible."""
        self.depth_map = depth_map
        # FPS is approximated from the signal arrival interval
        self.depth_fps, self.last_depth_time = self._fps_from_interval(
            self.depth_fps, self.last_depth_time)
        if self.show_depth:
            self.update()

    @Slot(dict)
    def update_hud(self, hud_data):
        """Merge ``hud_data`` into the HUD text dictionary (no immediate repaint)."""
        self.hud_texts.update(hud_data)

    @Slot(float)
    def update_capture_fps(self, fps):
        """Store the capture FPS shown in the HUD."""
        self.capture_fps = fps

    # --- Toggling drawing elements ---
    @Slot(bool)
    def toggle_boxes(self, show): self.show_boxes = show; self.update()
    @Slot(bool)
    def toggle_labels(self, show): self.show_labels = show; self.update()
    @Slot(bool)
    def toggle_paths(self, show): self.show_paths = show; self.update()
    @Slot(bool)
    def toggle_trajectory(self, show): self.show_trajectory = show; self.update()
    @Slot(bool)
    def toggle_crosshair(self, show): self.show_crosshair = show; self.update()
    @Slot(bool)
    def toggle_poses(self, show): self.show_poses = show; self.update()
    @Slot(bool)
    def toggle_flow(self, show): self.show_flow = show; self.update()
    @Slot(bool)
    def toggle_depth(self, show): self.show_depth = show; self.update()
    @Slot(bool)
    def toggle_hud(self, show): self.show_hud = show; self.update()

    # --- Painting ---
    def paintEvent(self, event):
        """Draw all overlay layers in a fixed back-to-front order."""
        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.Antialiasing)
            self._clear_background(painter)
            self._draw_depth(painter)        # 1. background layer
            self._draw_scanline(painter)     # 2.
            self._draw_flow(painter)         # 3.
            self._draw_detections(painter)   # 4.
            self._draw_poses(painter)        # 5.
            self._draw_path(painter)         # 6.
            self._draw_trajectory(painter)   # 7.
            self._draw_crosshair(painter)    # 8.
            self._draw_hud(painter)          # 9.
        finally:
            painter.end()  # Always release the painter, even if a layer raised

    def _clear_background(self, painter):
        """Reset the whole widget area to fully transparent."""
        # Under the default SourceOver mode a transparent fill changes
        # nothing; Source mode replaces the destination pixels instead.
        painter.setCompositionMode(QPainter.CompositionMode_Source)
        painter.fillRect(self.rect(), Qt.transparent)
        painter.setCompositionMode(QPainter.CompositionMode_SourceOver)

    def _draw_depth(self, painter):
        """Draw the depth map as a centred grayscale image scaled to the widget."""
        if not self.show_depth or self.depth_map is None:
            return
        try:
            # Convert the 0-1 float map to 8-bit; QImage needs a contiguous buffer
            depth_8bit = np.ascontiguousarray((self.depth_map * 255).astype(np.uint8))
            h, w = depth_8bit.shape
            q_image = QImage(depth_8bit.data, w, h, w, QImage.Format_Grayscale8)
            pixmap = QPixmap.fromImage(q_image)
            scaled_pixmap = pixmap.scaled(self.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation)
            x_offset = (self.width() - scaled_pixmap.width()) / 2
            y_offset = (self.height() - scaled_pixmap.height()) / 2
            painter.drawPixmap(int(x_offset), int(y_offset), scaled_pixmap)
        except Exception as e:
            logger.error(f"Error drawing depth map: {e}", exc_info=True)

    def _draw_scanline(self, painter):
        """Draw the horizontal semi-transparent green scanline."""
        painter.setPen(QPen(QColor(0, 255, 0, 50), 2))
        painter.drawLine(0, self.scanline_y, self.width(), self.scanline_y)

    def _draw_flow(self, painter):
        """Draw each optical flow vector as a thin yellow line."""
        if not (self.show_flow and self.flow_vectors):
            return
        painter.setPen(QPen(YELLOW_COLOR, 1))
        for start_pt, end_pt in self.flow_vectors:
            painter.drawLine(QPoint(*start_pt), QPoint(*end_pt))

    def _draw_detections(self, painter):
        """Draw bounding boxes and/or labels for all stored detections."""
        if not self.detections:
            return
        painter.setFont(self.font_small)
        for label, confidence, box in self.detections:
            # The integer drawRect/drawText overloads are used below, so
            # coerce coordinates to plain ints (a no-op for ints).
            left, top, right, bottom = (int(v) for v in box)

            if self.show_boxes:
                painter.setPen(QPen(GREEN_COLOR, 2))
                painter.drawRect(left, top, right - left, bottom - top)

            if self.show_labels:
                painter.setPen(GREEN_COLOR)
                # Position text slightly above the box
                painter.drawText(left + 2, top - 5, f"{label} ({confidence:.2f})")

    def _draw_poses(self, painter):
        """Draw confident keypoints as dots and connect them per POSE_CONNECTIONS."""
        if not (self.show_poses and self.poses):
            return
        for keypoints_data, _box, _pose_conf in self.poses:
            # Each keypoints_data row unpacks to (x, y, confidence)
            visible = {}  # keypoint index -> QPointF, only for confident keypoints
            for i in range(keypoints_data.shape[0]):
                x, y, conf = keypoints_data[i]
                if conf > DEFAULT_POSE_CONFIDENCE_THRESHOLD:
                    pt = QPointF(x, y)
                    visible[i] = pt
                    painter.setBrush(POSE_COLORS[i % len(POSE_COLORS)])  # Cycle through colors
                    painter.setPen(Qt.NoPen)  # No outline for points
                    painter.drawEllipse(pt, 3, 3)

            # Skeleton lines: only between keypoints that are both visible
            painter.setPen(QPen(MAGENTA_COLOR, 1))
            for start_idx, end_idx in POSE_CONNECTIONS:
                if start_idx in visible and end_idx in visible:
                    painter.drawLine(visible[start_idx], visible[end_idx])

    def _draw_path(self, painter):
        """Draw the polyline through the stored smoothed target positions."""
        if not self.show_paths or len(self.target_history) <= 1:
            return
        painter.setPen(QPen(ORANGE_COLOR.lighter(120), 1))
        path_points = [p for _t, p in self.target_history]
        painter.drawPolyline(QPolygonF(path_points))

    def _draw_trajectory(self, painter):
        """Draw the predicted trajectory by integrating velocity and acceleration."""
        if (not self.show_trajectory
                or self.smoothed_target_center is None
                or len(self.velocity_history) < 1):
            return
        painter.setPen(QPen(MAGENTA_COLOR, 2, Qt.DashLine))

        position = self.smoothed_target_center
        velocity = self.smoothed_velocity
        acceleration = self.smoothed_acceleration
        predicted_points = [position]

        dt = PREDICTION_TIME_STEP
        num_steps = int(TRAJECTORY_PREDICTION_DURATION / PREDICTION_TIME_STEP)
        for _ in range(num_steps):
            # pos = pos_old + vel*dt + 0.5*accel*dt^2
            position = position + velocity * dt + 0.5 * acceleration * dt * dt
            # vel = vel_old + accel*dt
            velocity = velocity + acceleration * dt
            predicted_points.append(position)

        if len(predicted_points) > 1:
            painter.drawPolyline(QPolygonF(predicted_points))

    def _draw_crosshair(self, painter):
        """Draw a red crosshair at the smoothed target centre."""
        if not self.show_crosshair or self.smoothed_target_center is None:
            return
        painter.setPen(QPen(RED_COLOR, 1))
        cx = int(self.smoothed_target_center.x())
        cy = int(self.smoothed_target_center.y())
        size = 10  # Length of crosshair lines from center
        painter.drawLine(cx - size, cy, cx + size, cy)  # Horizontal line
        painter.drawLine(cx, cy - size, cx, cy + size)  # Vertical line

    def _draw_hud(self, painter):
        """Draw FPS/timing lines, HUD key-value pairs and target motion info."""
        if not self.show_hud:
            return
        painter.setFont(self.font_medium)
        painter.setPen(TEXT_COLOR)
        line_height = 15  # Spacing between lines

        lines = [
            f"Cap: {self.capture_fps:.1f} FPS",
            f"Det: {self.detection_fps:.1f} FPS ({self.detection_proc_time:.1f} ms)",
        ]
        if self.show_poses:
            lines.append(f"Pose: {self.pose_fps:.1f} FPS ({self.pose_proc_time:.1f} ms)")
        if self.show_flow:
            # TODO: Get flow proc time from thread and display it
            lines.append(f"Flow: {self.flow_fps:.1f} FPS ({self.flow_proc_time:.1f} ms)")
        if self.show_depth:
            # TODO: Get depth proc time from thread and display it
            lines.append(f"Depth: {self.depth_fps:.1f} FPS ({self.depth_proc_time:.1f} ms)")

        # Entries supplied through update_hud()
        lines.extend(f"{key}: {value}" for key, value in self.hud_texts.items())

        if self.show_crosshair and self.smoothed_target_center is not None:
            sx = int(self.smoothed_target_center.x())
            sy = int(self.smoothed_target_center.y())
            vel_mag = (self.smoothed_velocity.x()**2 + self.smoothed_velocity.y()**2)**0.5
            acc_mag = (self.smoothed_acceleration.x()**2 + self.smoothed_acceleration.y()**2)**0.5
            lines.append(f"Crosshair (Smooth): {sx}, {sy}")
            lines.append(f"Vel: {vel_mag:.1f} px/s")
            lines.append(f"Acc: {acc_mag:.1f} px/s^2")

        y_offset = 20  # Starting Y position for HUD text (from top)
        for text in lines:
            painter.drawText(10, y_offset, text)
            y_offset += line_height


# --- Main Application Window ---
class MainWindow(QMainWindow):
    """Main application window with controls and thread management.

    The class-level signals below act as a relay layer: GUI widgets emit
    them, and ``connect_signals`` wires them to the current worker threads
    and to the overlay widget. Because the worker threads are re-created when
    the monitor changes, routing through these window-owned signals lets the
    widget connections stay in place while the receiving end is swapped.
    """
    # --- Signals for Controlling Worker Threads ---
    # Each carries the new value of the corresponding control widget.
    capture_interval_changed = Signal(int)
    detection_enabled_changed = Signal(bool)
    detection_conf_changed = Signal(float)
    detection_nms_changed = Signal(float)
    pose_enabled_changed = Signal(bool)
    pose_conf_changed = Signal(float)
    flow_enabled_changed = Signal(bool)
    depth_enabled_changed = Signal(bool)

    # --- Signals for Toggling Overlay Elements ---
    # Each carries the checked state of the checkbox that drives it.
    toggle_boxes_signal = Signal(bool)
    toggle_labels_signal = Signal(bool)
    toggle_paths_signal = Signal(bool)
    toggle_trajectory_signal = Signal(bool)
    toggle_crosshair_signal = Signal(bool)
    toggle_poses_signal = Signal(bool)
    toggle_flow_signal = Signal(bool)
    toggle_depth_signal = Signal(bool)
    toggle_hud_signal = Signal(bool)

    # --- Signals for Sending Data to Overlay ---
    hud_data_signal = Signal(dict)      # System-info dictionary built by update_system_info
    capture_fps_signal = Signal(float)  # FPS value computed by calculate_capture_fps


    def __init__(self):
        """Build the control window, create the overlay and start the workers.

        Steps, in order: enumerate monitors, build the control widgets, create
        and show the overlay for the first monitor, attach the log handler,
        start the system-info timer, start the worker threads and connect all
        signals.

        If the monitor list cannot be obtained (or is empty), a critical
        message box is shown and the process exits with status 1.
        """
        super().__init__()
        self.setWindowTitle("Real-Time Detection Overlay")
        self.setGeometry(100, 100, 450, 750) # Adjusted height slightly for new controls/info

        # --- Monitor Selection ---
        try:
            # Use the context manager so the temporary mss instance releases its
            # display resources once the monitor list has been copied out.
            with mss.mss() as sct:
                self.monitors = sct.monitors[1:] # Exclude the 'all monitors' entry (index 0)
            if not self.monitors:
                raise RuntimeError("No monitors found (excluding 'all monitors').")
            self.selected_monitor_spec = self.monitors[0] # Default to the first physical monitor
        except Exception as e:
            logger.error(f"Failed to get monitor list: {e}. Exiting.", exc_info=True)
            # Show error message box
            msg_box = QtWidgets.QMessageBox()
            msg_box.setIcon(QtWidgets.QMessageBox.Critical)
            msg_box.setText(f"Error initializing monitors: {e}\n\nPlease ensure display drivers are working.")
            msg_box.setWindowTitle("Monitor Error")
            msg_box.exec()
            sys.exit(1) # Exit if monitors can't be found

        # --- Central Widget and Layout ---
        self.central_widget = QWidget()
        self.main_layout = QVBoxLayout(self.central_widget)
        self.setCentralWidget(self.central_widget)

        # --- Control Panel ---
        self.control_group = QGroupBox("Controls")
        self.control_layout = QGridLayout()
        self.control_group.setLayout(self.control_layout)
        self.main_layout.addWidget(self.control_group)

        # Monitor Selection Dropdown
        # Items are added BEFORE the slot is connected, so populating the combo
        # does not trigger update_monitor during construction.
        self.monitor_combo = QComboBox()
        self.monitor_combo.addItems([
            f"Monitor {i+1} ({m['width']}x{m['height']}@{m['left']},{m['top']})"
            for i, m in enumerate(self.monitors)
        ])
        self.monitor_combo.currentIndexChanged.connect(self.update_monitor)
        self.control_layout.addWidget(QLabel("Target Monitor:"), 0, 0)
        self.control_layout.addWidget(self.monitor_combo, 0, 1, 1, 2) # Span 2 columns

        # Capture Interval SpinBox
        self.interval_spinbox = QSpinBox()
        self.interval_spinbox.setRange(1, 1000) # 1ms to 1s
        self.interval_spinbox.setValue(CAPTURE_INTERVAL_MS)
        self.interval_spinbox.setSuffix(" ms")
        self.interval_spinbox.setToolTip("Interval between screen captures (lower = higher FPS target)")
        self.interval_spinbox.valueChanged.connect(self.capture_interval_changed.emit)
        self.control_layout.addWidget(QLabel("Capture Interval:"), 1, 0)
        self.control_layout.addWidget(self.interval_spinbox, 1, 1, 1, 2)

        # --- Detection Controls Group ---
        self.detection_group = QGroupBox("Object Detection (YOLOv8n)")
        self.detection_layout = QGridLayout()
        self.detection_group.setLayout(self.detection_layout)
        self.main_layout.addWidget(self.detection_group)

        self.detection_enabled_check = QCheckBox("Enable Detection")
        self.detection_enabled_check.setChecked(True) # Enabled by default
        self.detection_enabled_check.toggled.connect(self.detection_enabled_changed.emit)
        self.detection_layout.addWidget(self.detection_enabled_check, 0, 0, 1, 3) # Span columns

        self.conf_spinbox = QDoubleSpinBox()
        self.conf_spinbox.setRange(0.01, 1.0)
        self.conf_spinbox.setSingleStep(0.05)
        self.conf_spinbox.setValue(DEFAULT_CONFIDENCE_THRESHOLD)
        self.conf_spinbox.setToolTip("Minimum confidence score for detected objects")
        self.conf_spinbox.valueChanged.connect(self.detection_conf_changed.emit)
        self.detection_layout.addWidget(QLabel("Confidence Threshold:"), 1, 0)
        self.detection_layout.addWidget(self.conf_spinbox, 1, 1, 1, 2)

        self.nms_spinbox = QDoubleSpinBox()
        self.nms_spinbox.setRange(0.01, 1.0)
        self.nms_spinbox.setSingleStep(0.05)
        self.nms_spinbox.setValue(DEFAULT_NMS_THRESHOLD)
        self.nms_spinbox.setToolTip("Non-Maximum Suppression (NMS) threshold (IoU)")
        self.nms_spinbox.valueChanged.connect(self.detection_nms_changed.emit)
        self.detection_layout.addWidget(QLabel("NMS Threshold:"), 2, 0)
        self.detection_layout.addWidget(self.nms_spinbox, 2, 1, 1, 2)

        # --- Pose Estimation Controls Group ---
        self.pose_group = QGroupBox("Pose Estimation (YOLOv8n-Pose)")
        self.pose_layout = QGridLayout()
        self.pose_group.setLayout(self.pose_layout)
        self.main_layout.addWidget(self.pose_group)

        self.pose_enabled_check = QCheckBox("Enable Pose Estimation")
        self.pose_enabled_check.setChecked(False) # Default disabled
        self.pose_enabled_check.toggled.connect(self.pose_enabled_changed.emit)
        self.pose_enabled_check.toggled.connect(self.toggle_poses_signal.emit) # Link enable state to drawing toggle
        self.pose_layout.addWidget(self.pose_enabled_check, 0, 0, 1, 3)

        self.pose_conf_spinbox = QDoubleSpinBox()
        self.pose_conf_spinbox.setRange(0.01, 1.0)
        self.pose_conf_spinbox.setSingleStep(0.05)
        self.pose_conf_spinbox.setValue(DEFAULT_POSE_CONFIDENCE_THRESHOLD)
        self.pose_conf_spinbox.setToolTip("Minimum confidence score for detected poses/keypoints")
        self.pose_conf_spinbox.valueChanged.connect(self.pose_conf_changed.emit)
        self.pose_layout.addWidget(QLabel("Pose Conf Threshold:"), 1, 0)
        self.pose_layout.addWidget(self.pose_conf_spinbox, 1, 1, 1, 2)

        # --- Other Feature Controls Group ---
        self.features_group = QGroupBox("Other Features")
        self.features_layout = QGridLayout()
        self.features_group.setLayout(self.features_layout)
        self.main_layout.addWidget(self.features_group)

        self.flow_enabled_check = QCheckBox("Enable Optical Flow")
        self.flow_enabled_check.setChecked(False) # Default disabled
        self.flow_enabled_check.toggled.connect(self.flow_enabled_changed.emit)
        self.flow_enabled_check.toggled.connect(self.toggle_flow_signal.emit) # Link enable to drawing
        self.features_layout.addWidget(self.flow_enabled_check, 0, 0)

        self.depth_enabled_check = QCheckBox("Enable Depth Estimation")
        self.depth_enabled_check.setChecked(False) # Default disabled
        self.depth_enabled_check.toggled.connect(self.depth_enabled_changed.emit)
        self.depth_enabled_check.toggled.connect(self.toggle_depth_signal.emit) # Link enable to drawing
        self.features_layout.addWidget(self.depth_enabled_check, 0, 1)

        # --- Overlay Visibility Controls Group ---
        self.visibility_group = QGroupBox("Overlay Visibility")
        self.visibility_layout = QGridLayout()
        self.visibility_group.setLayout(self.visibility_layout)
        self.main_layout.addWidget(self.visibility_group)

        # Add checkboxes for each visual element
        self.show_boxes_check = QCheckBox("Show Boxes")
        self.show_boxes_check.setChecked(True)
        self.show_boxes_check.toggled.connect(self.toggle_boxes_signal.emit)
        self.visibility_layout.addWidget(self.show_boxes_check, 0, 0)

        self.show_labels_check = QCheckBox("Show Labels")
        self.show_labels_check.setChecked(True)
        self.show_labels_check.toggled.connect(self.toggle_labels_signal.emit)
        self.visibility_layout.addWidget(self.show_labels_check, 0, 1)

        self.show_paths_check = QCheckBox("Show Path History") # Renamed for clarity
        self.show_paths_check.setChecked(True)
        self.show_paths_check.toggled.connect(self.toggle_paths_signal.emit)
        self.visibility_layout.addWidget(self.show_paths_check, 1, 0)

        self.show_trajectory_check = QCheckBox("Show Trajectory")
        self.show_trajectory_check.setChecked(True)
        self.show_trajectory_check.toggled.connect(self.toggle_trajectory_signal.emit)
        self.visibility_layout.addWidget(self.show_trajectory_check, 1, 1)

        self.show_crosshair_check = QCheckBox("Show Crosshair")
        self.show_crosshair_check.setChecked(True)
        self.show_crosshair_check.toggled.connect(self.toggle_crosshair_signal.emit)
        self.visibility_layout.addWidget(self.show_crosshair_check, 2, 0)

        # Note: Pose visibility is implicitly controlled by the "Enable Pose Estimation" checkbox
        # via the toggle_poses_signal connection. A separate checkbox here would be redundant.

        self.show_hud_check = QCheckBox("Show HUD")
        self.show_hud_check.setChecked(True)
        self.show_hud_check.toggled.connect(self.toggle_hud_signal.emit)
        self.visibility_layout.addWidget(self.show_hud_check, 2, 1)

        # --- Log Output Area ---
        self.log_group = QGroupBox("Log Output")
        self.log_layout = QVBoxLayout()
        self.log_group.setLayout(self.log_layout)
        self.log_text_edit = QTextEdit()
        self.log_text_edit.setReadOnly(True)
        self.log_text_edit.setFont(QFont("Monospace", 8)) # Smaller monospace font for logs
        self.log_layout.addWidget(self.log_text_edit)
        self.main_layout.addWidget(self.log_group)
        self.main_layout.setStretchFactor(self.log_group, 1) # Allow log area to expand

        # --- Status Bar ---
        self.status_bar = self.statusBar()
        self.status_bar.showMessage("Ready")

        # --- Initialize Overlay Widget ---
        monitor_qrect = QRect(
            self.selected_monitor_spec['left'],
            self.selected_monitor_spec['top'],
            self.selected_monitor_spec['width'],
            self.selected_monitor_spec['height']
        )
        self.overlay = OverlayWidget(monitor_qrect)
        self.overlay.show() # Show the overlay window

        # --- Setup Logging Handler ---
        self.log_handler = QTextEditLogger(self.log_text_edit)
        logging.getLogger().addHandler(self.log_handler) # Add handler to root logger
        logging.getLogger().setLevel(logging.INFO) # Set logging level

        # --- System Info Timer ---
        self.system_info_timer = QTimer(self)
        self.system_info_timer.timeout.connect(self.update_system_info)
        self.system_info_timer.start(SYSTEM_INFO_INTERVAL_MS)

        # --- Initialize Threads ---
        self.capture_thread = None
        self.detection_thread = None
        self.pose_thread = None
        self.flow_thread = None
        self.depth_thread = None
        self.start_threads() # Start worker threads

        # --- Connect Signals and Slots ---
        self.connect_signals()

        # --- Initial System Info Update ---
        self.update_system_info() # Populate HUD immediately


    def update_monitor(self, index):
        """Switch capture and overlay to the monitor at ``index``.

        An out-of-range index is logged as an error and otherwise ignored.
        For a valid index the worker threads are stopped, the overlay is moved
        onto the chosen monitor, new worker threads are started and all
        signals are connected again.
        """
        if not (0 <= index < len(self.monitors)):
            logger.error(f"Invalid monitor index selected: {index}")
            return

        spec = self.monitors[index]
        self.selected_monitor_spec = spec
        shown_number = index + 1  # Monitors are presented 1-based to the user
        logger.info(f"Monitor changed to: {shown_number} - {spec}")
        self.status_bar.showMessage(f"Monitor changed to {shown_number}. Restarting capture...")

        # Tear down the workers bound to the previous monitor.
        self.stop_threads()

        # Move the overlay onto the newly selected monitor.
        target_rect = QRect(spec['left'], spec['top'], spec['width'], spec['height'])
        self.overlay.monitor_rect = target_rect
        self.overlay.setGeometry(target_rect)

        # Fresh thread instances need their signals wired up again.
        self.start_threads()
        self.connect_signals()

        self.status_bar.showMessage(f"Capture restarted on monitor {shown_number}.")


    def start_threads(self):
        """Create and start all worker threads.

        A new instance of each worker (capture, detection, pose, optical flow,
        depth) is created, stored on ``self`` and started. Every worker's
        ``status_update`` signal is connected to ``update_status``.

        The object name is assigned BEFORE ``start()``: Qt uses the object name
        as the native thread name only if it is set before the thread starts,
        and it is also what ``stop_threads`` prints in its log messages.
        """
        logger.info("Starting worker threads...")

        def launch(thread, name):
            """Name the thread, hook up its status messages, then start it."""
            thread.setObjectName(name)
            thread.status_update.connect(self.update_status)
            thread.start()

        # Screen Capture Thread
        self.capture_thread = ScreenCaptureThread(self.selected_monitor_spec)
        launch(self.capture_thread, "CaptureThread")

        # Detection Thread
        self.detection_thread = DetectionThread(OBJECT_DETECTION_MODEL)
        launch(self.detection_thread, "DetectionThread")

        # Pose Estimation Thread
        self.pose_thread = PoseEstimationThread(POSE_ESTIMATION_MODEL)
        launch(self.pose_thread, "PoseThread")

        # Optical Flow Thread
        self.flow_thread = OpticalFlowThread()
        launch(self.flow_thread, "FlowThread")

        # Depth Estimation Thread
        self.depth_thread = DepthEstimationThread() # Uses default MiDaS_small
        launch(self.depth_thread, "DepthThread")

        logger.info("Worker threads initialized and started.")


    def connect_signals(self):
        """Connect signals between GUI controls, worker threads, and the overlay.

        Safe to call repeatedly (it is called again after every monitor
        change): all previous connections on the signals handled here are
        dropped first, then re-established against the current thread
        instances and the overlay. The current GUI values are also pushed to
        the threads and overlay so freshly created workers match the controls.
        """
        logger.info("Connecting signals...")

        def disconnect_all(signal):
            """Remove every slot from ``signal``.

            Disconnecting a signal that has no slots may raise in PySide6;
            that case is expected (first call, or a brand-new thread) and is
            ignored. Each signal is handled on its own so one such failure
            cannot prevent the remaining signals from being disconnected.
            """
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass

        # --- Drop previous connections ---
        # Window-owned relay signals: these persist across thread restarts, so
        # they are the ones that would otherwise accumulate duplicate slots.
        for relay_signal in (
            self.capture_interval_changed,
            self.detection_enabled_changed,
            self.detection_conf_changed,
            self.detection_nms_changed,
            self.pose_enabled_changed,
            self.pose_conf_changed,
            self.flow_enabled_changed,
            self.depth_enabled_changed,
            self.toggle_boxes_signal,
            self.toggle_labels_signal,
            self.toggle_paths_signal,
            self.toggle_trajectory_signal,
            self.toggle_crosshair_signal,
            self.toggle_poses_signal,
            self.toggle_flow_signal,
            self.toggle_depth_signal,
            self.toggle_hud_signal,
            self.hud_data_signal,
            self.capture_fps_signal,
        ):
            disconnect_all(relay_signal)

        # Output signals of the current thread instances.
        if self.capture_thread:
            disconnect_all(self.capture_thread.frame_ready)
        if self.detection_thread:
            disconnect_all(self.detection_thread.detections_ready)
        if self.pose_thread:
            disconnect_all(self.pose_thread.poses_ready)
        if self.flow_thread:
            disconnect_all(self.flow_thread.flow_ready)
        if self.depth_thread:
            disconnect_all(self.depth_thread.depth_ready)

        # --- Capture Thread Connections ---
        if self.capture_thread:
            # Frame processing pipeline
            self.capture_thread.frame_ready.connect(self.calculate_capture_fps) # Calculate FPS first
            # Send frame to worker threads that need it
            if self.detection_thread:
                self.capture_thread.frame_ready.connect(self.detection_thread.set_frame)
            if self.pose_thread:
                self.capture_thread.frame_ready.connect(self.pose_thread.set_frame)
            if self.flow_thread:
                self.capture_thread.frame_ready.connect(self.flow_thread.set_frame)
            if self.depth_thread:
                self.capture_thread.frame_ready.connect(self.depth_thread.set_frame)
            # Connect GUI controls to capture thread slots
            self.capture_interval_changed.connect(self.capture_thread.update_capture_interval)
            # Sync initial state from GUI to thread, so a capture thread created
            # after a monitor change keeps the interval shown in the spin box.
            self.capture_thread.update_capture_interval(self.interval_spinbox.value())

        # --- Detection Thread Connections ---
        if self.detection_thread:
            self.detection_thread.detections_ready.connect(self.overlay.update_detections)
            # Connect GUI controls to detection thread slots
            self.detection_enabled_changed.connect(self.detection_thread.set_enabled)
            self.detection_conf_changed.connect(self.detection_thread.update_confidence_threshold)
            self.detection_nms_changed.connect(self.detection_thread.update_nms_threshold)
            # Sync initial state from GUI to thread
            self.detection_thread.set_enabled(self.detection_enabled_check.isChecked())
            self.detection_thread.update_confidence_threshold(self.conf_spinbox.value())
            self.detection_thread.update_nms_threshold(self.nms_spinbox.value())

        # --- Pose Thread Connections ---
        if self.pose_thread:
            self.pose_thread.poses_ready.connect(self.overlay.update_poses)
            # Connect GUI controls to pose thread slots
            self.pose_enabled_changed.connect(self.pose_thread.set_enabled)
            self.pose_conf_changed.connect(self.pose_thread.update_confidence_threshold)
            # Sync initial state from GUI to thread
            self.pose_thread.set_enabled(self.pose_enabled_check.isChecked())
            self.pose_thread.update_confidence_threshold(self.pose_conf_spinbox.value())

        # --- Flow Thread Connections ---
        if self.flow_thread:
            self.flow_thread.flow_ready.connect(self.overlay.update_flow)
            # Connect GUI controls to flow thread slots
            self.flow_enabled_changed.connect(self.flow_thread.set_enabled)
            # Sync initial state from GUI to thread
            self.flow_thread.set_enabled(self.flow_enabled_check.isChecked())

        # --- Depth Thread Connections ---
        if self.depth_thread:
            self.depth_thread.depth_ready.connect(self.overlay.update_depth)
            # Connect GUI controls to depth thread slots
            self.depth_enabled_changed.connect(self.depth_thread.set_enabled)
            # Sync initial state from GUI to thread
            self.depth_thread.set_enabled(self.depth_enabled_check.isChecked())

        # --- Overlay Visibility Connections (GUI Checkbox -> Overlay Slot) ---
        self.toggle_boxes_signal.connect(self.overlay.toggle_boxes)
        self.toggle_labels_signal.connect(self.overlay.toggle_labels)
        self.toggle_paths_signal.connect(self.overlay.toggle_paths)
        self.toggle_trajectory_signal.connect(self.overlay.toggle_trajectory)
        self.toggle_crosshair_signal.connect(self.overlay.toggle_crosshair)
        self.toggle_poses_signal.connect(self.overlay.toggle_poses) # Connect pose visibility
        self.toggle_flow_signal.connect(self.overlay.toggle_flow)
        self.toggle_depth_signal.connect(self.overlay.toggle_depth)
        self.toggle_hud_signal.connect(self.overlay.toggle_hud)
        # Sync initial state from GUI checkboxes to overlay drawing flags
        self.overlay.toggle_boxes(self.show_boxes_check.isChecked())
        self.overlay.toggle_labels(self.show_labels_check.isChecked())
        self.overlay.toggle_paths(self.show_paths_check.isChecked())
        self.overlay.toggle_trajectory(self.show_trajectory_check.isChecked())
        self.overlay.toggle_crosshair(self.show_crosshair_check.isChecked())
        self.overlay.toggle_poses(self.pose_enabled_check.isChecked()) # Link initial pose visibility to enable state
        self.overlay.toggle_flow(self.flow_enabled_check.isChecked()) # Link initial flow visibility to enable state
        self.overlay.toggle_depth(self.depth_enabled_check.isChecked()) # Link initial depth visibility to enable state
        self.overlay.toggle_hud(self.show_hud_check.isChecked())

        # --- HUD Data Connection (Main Window -> Overlay) ---
        self.hud_data_signal.connect(self.overlay.update_hud)
        self.capture_fps_signal.connect(self.overlay.update_capture_fps)

        logger.info("Signals connected/reconnected.")


    def stop_threads(self):
        """Stop all worker threads and clear the references to them.

        Shutdown happens in two passes so the threads wind down concurrently
        instead of one after another:

        1. Every running thread is asked to stop via its ``stop()`` method.
        2. Each of those threads is waited on for up to 2 seconds; a thread
           that has not finished by then is force-terminated.

        A thread whose ``stop()`` call raised is still waited on (and
        terminated if necessary) so that no running thread is left behind when
        its reference is dropped. Errors are logged, never propagated.
        """
        logger.info("Stopping worker threads...")
        running = [
            thread
            for thread in (
                self.capture_thread, self.detection_thread, self.pose_thread,
                self.flow_thread, self.depth_thread,
            )
            if thread and thread.isRunning()
        ]

        # Pass 1: signal every thread to stop before blocking on any of them.
        for thread in running:
            try:
                thread.stop()
            except Exception as e:
                logger.error(f"Error stopping thread {thread.objectName()}: {e}")

        # Pass 2: wait for each thread; escalate to terminate() on timeout.
        for thread in running:
            try:
                if not thread.wait(2000): # Wait up to 2 seconds
                    logger.warning(f"Thread {thread.objectName()} did not finish gracefully, terminating.")
                    thread.terminate() # Force terminate if needed
                    thread.wait() # Wait after termination
            except Exception as e:
                logger.error(f"Error stopping thread {thread.objectName()}: {e}")

        # Clear thread references
        self.capture_thread = None
        self.detection_thread = None
        self.pose_thread = None
        self.flow_thread = None
        self.depth_thread = None
        logger.info("Worker threads stopped.")


    @Slot(str)
    def update_status(self, message):
        """Display ``message`` in the status bar with a five-second timeout."""
        display_ms = 5000
        self.status_bar.showMessage(message, display_ms)


    def update_system_info(self):
        """Collect CPU, RAM, GPU and clock readings and emit them for the HUD.

        The readings are packed into a dict of display strings and sent through
        ``hud_data_signal``. Any failure while gathering them is logged (message
        only, no traceback) and the update is skipped.
        """
        bytes_per_gb = 1024**3
        try:
            cpu_percent = psutil.cpu_percent()
            ram = psutil.virtual_memory()
            used_gb = ram.used / bytes_per_gb
            total_gb = ram.total / bytes_per_gb
            ram_text = f"{ram.percent}% ({used_gb:.1f}/{total_gb:.1f} GB)"
            gpu_label, vram_text = get_gpu_info()
            clock_text = datetime.datetime.now().strftime("%H:%M:%S")

            self.hud_data_signal.emit({
                "CPU": f"{cpu_percent:.1f}%",
                "RAM": ram_text,
                "GPU": gpu_label,
                "VRAM": vram_text,
                "Time": clock_text,
            })
        except Exception as e:
            logger.error(f"Error updating system info: {e}", exc_info=False)


    # --- FPS Calculation Slot ---
    @Slot(np.ndarray, float)
    def calculate_capture_fps(self, frame, timestamp):
        """Derive the capture FPS from the gap between frame arrivals and emit it.

        ``frame`` and ``timestamp`` are accepted to match the signal signature
        but are not used; the rate is based on when this slot runs. The time of
        the previous arrival is kept on the overlay (``last_capture_time``).
        """
        overlay = self.overlay
        arrival = time.perf_counter()
        elapsed = arrival - overlay.last_capture_time
        # Skip the emit when the gap is too small to divide by safely.
        if elapsed > 1e-6:
            self.capture_fps_signal.emit(1.0 / elapsed)
        overlay.last_capture_time = arrival


    def closeEvent(self, event):
        """Shut everything down cleanly when the main window is closed.

        Stops the periodic system-info timer, stops the worker threads, closes
        the overlay window and finally detaches this window's log handler from
        the root logger so that later log records are not routed to a text
        widget belonging to a closed window. The close event is always
        accepted.
        """
        logger.info("Close event triggered. Cleaning up...")
        self.system_info_timer.stop() # No more HUD updates during/after teardown
        self.stop_threads() # Stop worker threads first
        if self.overlay:
            self.overlay.close() # Close the overlay window
        logger.info("Cleanup complete. Exiting.")
        # Detach last so the messages above still reach the log panel.
        logging.getLogger().removeHandler(self.log_handler)
        event.accept() # Accept the close event


if __name__ == "__main__":
    # --- Pre-check/Download Models ---
    # Ensure necessary model files are present or downloaded by ultralytics
    # This avoids potential delays or errors during thread initialization.
    models_to_check = [OBJECT_DETECTION_MODEL, POSE_ESTIMATION_MODEL]
    logger.info("Checking for required YOLO models...")
    for model_file in models_to_check:
        try:
            if not os.path.exists(model_file):
                logger.info(f"Model '{model_file}' not found. Attempting to download...")
                _ = YOLO(model_file) # Instantiating YOLO class should trigger download
                logger.info(f"Model '{model_file}' download attempt complete.")
            else:
                logger.info(f"Model '{model_file}' found.")
        except Exception as e:
            # A failed pre-check is not fatal: the worker threads construct the
            # models themselves later on.
            logger.warning(f"Could not pre-download/verify model '{model_file}': {e}. YOLO will attempt download on first use.", exc_info=False)

    # --- Application Setup ---
    # No high-DPI application attributes are set here: in Qt 6 high-DPI scaling
    # is always enabled, and AA_EnableHighDpiScaling / AA_UseHighDpiPixmaps are
    # deprecated attributes that no longer have any effect.
    app = QApplication(sys.argv)

    # --- Font Loading ---
    # Ensure the custom font is loaded before creating widgets that use it
    font_path = "PressStart2P-Regular.ttf"
    if os.path.exists(font_path):
        font_id = QFontDatabase.addApplicationFont(font_path)
        if font_id == -1: # addApplicationFont returns -1 when loading fails
            logger.warning(f"Failed to load application font: {font_path}")
    else:
        logger.warning(f"Font file not found: {font_path}. Will use fallback.")

    # --- Main Window Initialization ---
    main_window = MainWindow() # Create the main window instance
    main_window.show() # Show the main window

    # --- Start Event Loop ---
    sys.exit(app.exec()) # Execute the application event loop