# In [None]:
import sys
import time
import random
import datetime
import cv2
import numpy as np
import mss
import psutil
import pyautogui # Keep for potential future use, but not used in current drawing logic
from PySide6 import QtCore, QtGui, QtWidgets
from PySide6.QtCore import Qt, QTimer, QPoint, QPointF, QThread, Signal, Slot, QMutex, QMutexLocker, QObject, QRect
from PySide6.QtGui import QColor, QFont, QPainter, QPen, QFontDatabase, QImage, QPixmap, QPolygonF
from PySide6.QtWidgets import (QApplication, QWidget, QMainWindow, QVBoxLayout, QGridLayout,
                             QComboBox, QSpinBox, QDoubleSpinBox, QPushButton, QLabel,
                             QTextEdit, QCheckBox, QGroupBox)
import logging
from ultralytics import YOLO
import torch
import os
import collections # Added for deque

# --- Configurations ---
# Timer periods below are in milliseconds unless the name says otherwise.
UPDATE_INTERVAL_MS = 500  # Interval for updating HUD text elements
SCANLINE_SPEED_MS = 15    # Period (ms) of the scanline effect's update step
# Adjusted capture interval - 1ms is often too fast for processing to keep up
# Set to 10ms (100 FPS target capture rate). Worker threads only keep the most
# recent frame they were handed, so frames arriving faster than they can be
# processed are dropped rather than queued.
CAPTURE_INTERVAL_MS = 10
SYSTEM_INFO_INTERVAL_MS = 1000 # Interval for updating system info
DEFAULT_CONFIDENCE_THRESHOLD = 0.4 # Default `conf` threshold passed to the object detection model
DEFAULT_NMS_THRESHOLD = 0.3        # Default NMS IoU threshold, passed to the YOLO models as `iou`
# Default `conf` threshold for the pose model's person detections. It is passed
# to the model as `conf` (the same argument the detector uses); it is NOT a
# per-keypoint visibility filter.
DEFAULT_POSE_CONFIDENCE_THRESHOLD = 0.5
FONT_NAME = "Press Start 2P"   # Preferred HUD font family
FALLBACK_FONT = "Monospace"    # Used when FONT_NAME is not available
FONT_SIZE_SMALL = 10
FONT_SIZE_MEDIUM = 12
RED_COLOR = QColor(255, 0, 0)
GREEN_COLOR = QColor(0, 255, 0)
BLUE_COLOR = QColor(0, 0, 255)
YELLOW_COLOR = QColor(255, 255, 0)
CYAN_COLOR = QColor(0, 255, 255)
MAGENTA_COLOR = QColor(255, 0, 255)
TEXT_COLOR = QColor(255, 0, 0)  # Same RGB value as RED_COLOR

# --- Smoothing Parameters ---
# The EMA "alpha" values weight the newest sample; presumably
# smoothed = alpha * new + (1 - alpha) * previous (confirm at the use site).
TARGET_CENTER_SMOOTHING_FACTOR = 0.3 # Alpha for EMA (lower = smoother, more lag)
TRAJECTORY_SMOOTHING_FACTOR = 0.4    # Alpha for EMA for trajectory prediction
PATH_HISTORY_LENGTH = 30             # Number of past points to store for path tracking (reduced for performance)
TRAJECTORY_PREDICTION_POINTS = 5     # Number of recent points to use for velocity calculation
TRAJECTORY_PREDICTION_DURATION = 0.3 # Seconds into the future to predict (reduced for performance)

# --- Model Names ---
# Passed directly to ultralytics.YOLO(...) by the detection/pose worker threads.
OBJECT_DETECTION_MODEL = 'yolov8n.pt'
POSE_ESTIMATION_MODEL = 'yolov8n-pose.pt'

# --- Pose Estimation Keypoint Connections (COCO format) ---
# Pairs of keypoint indices that are joined by a line when drawing a skeleton.
# COCO keypoint order: 0 nose, 1/2 left/right eye, 3/4 left/right ear,
# 5/6 left/right shoulder, 7/8 left/right elbow, 9/10 left/right wrist,
# 11/12 left/right hip, 13/14 left/right knee, 15/16 left/right ankle.
# Note: this list does not link the head keypoints to the shoulders.
POSE_CONNECTIONS = [
    (0, 1), (0, 2), (1, 3), (2, 4),  # Head: nose-eyes, eyes-ears
    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),  # Shoulder line and both arms
    (11, 12), (5, 11), (6, 12), # Hip line and shoulder-to-hip sides
    (11, 13), (13, 15), (12, 14), (14, 16)  # Legs
]
# Palette for drawing keypoints/limbs (6 entries; how they are assigned is
# decided by the drawing code, not shown here)
POSE_COLORS = [QColor(255, 0, 0), QColor(0, 255, 0), QColor(0, 0, 255),
               QColor(255, 255, 0), QColor(0, 255, 255), QColor(255, 0, 255)]


# --- Logging Setup ---
# Configures the root logger at import time; the format includes the thread
# name because capture and model workers log from their own QThreads.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(threadName)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Helper Functions ---
def random_hex(length):
    """Return a string of ``length`` random uppercase hexadecimal characters.

    Characters are drawn one at a time with ``random.choice`` from the
    module-level ``random`` generator, so the result is reproducible under
    ``random.seed`` and is not suitable for security tokens. A ``length`` of
    zero or less yields an empty string.
    """
    alphabet = 'ABCDEF0123456789'
    picked = [random.choice(alphabet) for _ in range(length)]
    return ''.join(picked)

def get_gpu_info():
    """Return ``(gpu_name, memory_usage)`` display strings for CUDA device 0.

    ``memory_usage`` is formatted as ``"<used>/<total> GB"`` (GiB, one decimal).
    If ``torch.cuda.mem_get_info`` is missing (older PyTorch), it falls back to
    ``torch.cuda.memory_allocated`` and marks the value ``(Allocated)``. Any
    other error while reading memory is logged and reported as
    ``"N/A (Error reading memory)"``. Without CUDA, placeholder strings are
    returned.
    """
    if not torch.cuda.is_available():
        return "N/A (CUDA not available)", "N/A"

    gib = 1024 ** 3
    gpu_name = torch.cuda.get_device_name(0)
    try:
        # mem_get_info returns (free, total) -- in that order. Unpacking it as
        # (total, free) produced a negative "used" value.
        free_mem, total_mem = torch.cuda.mem_get_info(0)
        used_mem = total_mem - free_mem
        mem_usage = f"{used_mem / gib:.1f}/{total_mem / gib:.1f} GB"
    except AttributeError:
        # Older PyTorch: only memory occupied by tensors is known.
        total_mem = torch.cuda.get_device_properties(0).total_memory
        used_mem = torch.cuda.memory_allocated(0)
        mem_usage = f"{used_mem / gib:.1f}/{total_mem / gib:.1f} GB (Allocated)"
    except Exception as e:
        logger.error(f"Error getting GPU memory info: {e}", exc_info=True)
        mem_usage = "N/A (Error reading memory)"
    return gpu_name, mem_usage

# --- Custom Logging Handler ---
class QTextEditLogger(logging.Handler):
    """Logging handler that appends formatted records to a QTextEdit.

    ``emit`` may be called from any thread. The text is delivered through a
    queued ``QMetaObject.invokeMethod`` call to the widget's ``append`` slot,
    so the widget itself is only touched by the thread that owns it.
    """

    def __init__(self, text_edit):
        super().__init__()
        self.text_edit = text_edit
        self.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

    def emit(self, record):
        """Format *record* and queue it for display in the text edit.

        ``emit`` runs inside whichever ``logger.*`` call produced the record,
        often on a worker thread. Any failure (for example, if the widget has
        already been deleted during shutdown) is therefore routed to
        ``handleError`` as the ``logging.Handler`` contract expects, instead of
        propagating into the caller.
        """
        try:
            msg = self.format(record)
            QtCore.QMetaObject.invokeMethod(
                self.text_edit,
                "append",
                QtCore.Qt.QueuedConnection,
                QtCore.Q_ARG(str, msg),
            )
        except Exception:
            self.handleError(record)

# --- Screen Capture Thread ---
class ScreenCaptureThread(QThread):
    """Captures the screen at a specified interval.

    Each iteration grabs ``monitor_spec`` with ``mss``, converts it to a
    3-channel BGR array, and emits it on ``frame_ready`` together with a
    ``time.time()`` timestamp. Capture errors are emitted on ``status_update``,
    logged, and followed by a one-second back-off. Every iteration is paced to
    last at least the configured capture interval.
    """
    frame_ready = Signal(np.ndarray, float) # Emit frame (BGR) and capture timestamp
    status_update = Signal(str)

    def __init__(self, monitor_spec):
        super().__init__()
        self.monitor_spec = monitor_spec  # passed unchanged to mss.grab()
        self.running = False
        self.sct = None
        self._capture_interval_ms = CAPTURE_INTERVAL_MS
        self._lock = QMutex()  # guards _capture_interval_ms

    @staticmethod
    def _to_bgr(frame):
        """Return *frame* as 3-channel BGR, or None (with a warning) for other channel counts."""
        channels = frame.shape[2]
        if channels == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
        if channels == 3:
            return frame
        logger.warning(f"Unexpected frame channel count: {channels}")
        return None

    def _interval_seconds(self):
        """Return the current capture interval in seconds, read under the lock."""
        with QMutexLocker(self._lock):
            return self._capture_interval_ms / 1000.0

    def run(self):
        """Capture loop; runs until ``stop()`` clears ``running``."""
        self.running = True
        try:
            # Created inside run() so the mss handle is used only by this thread.
            self.sct = mss.mss()
            logger.info(f"Screen capture started for monitor: {self.monitor_spec}")

            while self.running:
                capture_start_time = time.perf_counter()
                try:
                    frame_bgr = self._to_bgr(np.array(self.sct.grab(self.monitor_spec)))
                    if frame_bgr is not None:
                        # Use system time for timestamping frame data
                        self.frame_ready.emit(frame_bgr, time.time())
                except mss.ScreenShotError as e:
                    self.status_update.emit(f"Screen capture error: {e}")
                    logger.error(f"Screen capture error: {e}")
                    time.sleep(1) # Wait before retrying on error
                except Exception as e:
                    self.status_update.emit(f"Unexpected screen capture error: {e}")
                    logger.error(f"Unexpected screen capture error: {e}", exc_info=True)
                    time.sleep(1)

                # Pace every iteration, including ones that produced no frame.
                # Skipping this step (as a bare `continue` used to) turns an
                # unsupported frame format into a CPU-burning busy loop.
                sleep_time = self._interval_seconds() - (time.perf_counter() - capture_start_time)
                if sleep_time > 0:
                    time.sleep(sleep_time)

        finally:
            if self.sct:
                self.sct.close()
                self.sct = None
            logger.info("Screen capture stopped.")

    def stop(self):
        """Ask the capture loop to exit after its current iteration."""
        self.running = False

    @Slot(int)
    def update_capture_interval(self, interval):
        """Set the capture interval in milliseconds; non-positive values are ignored."""
        if interval <= 0:
            logger.warning(f"Ignoring invalid capture interval: {interval} ms")
            return
        logger.info(f"Updating capture interval to {interval} ms")
        with QMutexLocker(self._lock):
            self._capture_interval_ms = interval

# --- Base Worker Thread for YOLO Models ---
class BaseYoloWorker(QThread):
    """Base class for YOLO detection and pose estimation threads.

    On start the thread loads ``model_name`` with ``ultralytics.YOLO``. It then
    repeatedly runs inference on the most recent frame delivered through
    ``set_frame``. Only the latest frame is kept: a frame that arrives before
    the previous one was consumed replaces it. Subclasses implement
    ``process_results`` to convert the raw results and emit them.
    """
    status_update = Signal(str)

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        self.model = None
        self.input_frame = None  # latest unconsumed frame; guarded by _frame_lock
        self.frame_timestamp = 0.0  # capture timestamp of input_frame
        self.running = False
        self._enabled = True
        self._confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD  # passed to the model as `conf`
        self._nms_threshold = DEFAULT_NMS_THRESHOLD  # passed to the model as `iou`
        self._frame_lock = QMutex()
        self.device = None
        self.processing_time_ms = 0.0  # inference latency (ms) of the most recent frame

    def load_model(self):
        """Load ``model_name`` onto CUDA when available, otherwise the CPU.

        Returns True on success. On failure the error is logged and emitted on
        ``status_update``, ``model`` and ``device`` are reset to None, and
        False is returned.
        """
        self.status_update.emit(f"Loading model {self.model_name}...")
        logger.info(f"Attempting to load YOLO model: {self.model_name}")
        try:
            self.model = YOLO(self.model_name)
            use_cuda = torch.cuda.is_available()
            self.device = torch.device('cuda' if use_cuda else 'cpu')
            self.model.to(self.device)
            if use_cuda:
                device_name = torch.cuda.get_device_name(0)
                logger.info(f"Model '{self.model_name}' loaded on GPU: {device_name}.")
                self.status_update.emit(f"Model '{self.model_name}' loaded on GPU.")
            else:
                logger.info(f"Model '{self.model_name}' loaded on CPU.")
                self.status_update.emit(f"Model '{self.model_name}' loaded on CPU.")
            return True
        except Exception as e:
            error_msg = f"Failed to load model '{self.model_name}': {e}"
            logger.error(error_msg, exc_info=True)
            self.status_update.emit(error_msg)
            self.model = None
            self.device = None
            return False

    @Slot(np.ndarray, float)
    def set_frame(self, frame, timestamp):
        """Store a private copy of *frame* (and its timestamp) as the next frame to process."""
        with QMutexLocker(self._frame_lock):
            self.input_frame = frame.copy()
            self.frame_timestamp = timestamp

    def _take_frame(self):
        """Consume the pending frame under the lock; return ``(frame, timestamp)`` or ``(None, 0.0)``."""
        with QMutexLocker(self._frame_lock):
            frame, timestamp = self.input_frame, self.frame_timestamp
            self.input_frame = None
        if frame is None:
            return None, 0.0
        return frame, timestamp

    def _report_error(self, exc):
        """Log (with traceback) and emit an error; must be called from inside an ``except`` block."""
        name = self.__class__.__name__
        logger.error(f"{name} error: {exc}", exc_info=True)
        self.status_update.emit(f"{name} error: {exc}")

    def _infer(self, frame, timestamp):
        """Run the model on one frame and hand the results to ``process_results``."""
        start_time = time.perf_counter()
        try:
            results = self.model(
                frame,
                conf=self._confidence_threshold,
                iou=self._nms_threshold,
                verbose=False,
                device=self.device,
            )
        except Exception as e:
            self.processing_time_ms = (time.perf_counter() - start_time) * 1000
            self._report_error(e)
            return

        # Record the latency BEFORE process_results runs. Subclasses emit
        # get_processing_time() from there, and updating it afterwards (as a
        # `finally` did) made every emitted value belong to the previous frame.
        self.processing_time_ms = (time.perf_counter() - start_time) * 1000
        try:
            self.process_results(results, timestamp)
        except Exception as e:
            self._report_error(e)

    def run(self):
        """Load the model, then process frames until ``stop()`` is called."""
        if not self.load_model():
            self.running = False
            return

        self.running = True
        logger.info(f"{self.__class__.__name__} thread started.")
        while self.running:
            if not self._enabled:
                self.msleep(50)  # Disabled, sleep longer
                continue

            frame, timestamp = self._take_frame()
            if frame is None or self.model is None:
                self.msleep(5)  # No frame yet, sleep briefly
                continue

            self._infer(frame, timestamp)

        logger.info(f"{self.__class__.__name__} thread stopped.")
        self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def process_results(self, results, timestamp):
        """Placeholder: Subclasses must implement this to process model output."""
        raise NotImplementedError

    def stop(self):
        """Ask the processing loop to exit after its current iteration."""
        self.running = False

    @Slot(bool)
    def set_enabled(self, enabled):
        """Enable/disable inference; disabling drops any pending frame."""
        self._enabled = enabled
        logger.info(f"{self.__class__.__name__} {'enabled' if enabled else 'disabled'}")
        if not enabled:
            with QMutexLocker(self._frame_lock):
                self.input_frame = None

    @Slot(float)
    def update_confidence_threshold(self, threshold):
        """Set the ``conf`` threshold used for subsequent inferences."""
        self._confidence_threshold = threshold
        logger.info(f"{self.__class__.__name__} confidence threshold updated to {threshold:.2f}")

    @Slot(float)
    def update_nms_threshold(self, threshold):
        """Set the ``iou`` (NMS) threshold used for subsequent inferences."""
        self._nms_threshold = threshold
        logger.info(f"{self.__class__.__name__} NMS threshold updated to {threshold:.2f}")

    def get_processing_time(self):
        """Return the inference latency (ms) of the most recently processed frame."""
        return self.processing_time_ms

# --- Detection Thread ---
class DetectionThread(BaseYoloWorker):
    """Performs object detection using a YOLO model.

    ``detections_ready`` carries a list of ``(label, confidence, (x1, y1, x2, y2))``
    tuples with integer pixel coordinates, the source frame's timestamp, and
    the worker's processing time in milliseconds.
    """
    # Emits: list of (label, confidence, box_tuple), timestamp, processing_time_ms
    detections_ready = Signal(list, float, float)

    def __init__(self, model_name=OBJECT_DETECTION_MODEL):
        super().__init__(model_name)
        self._confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD # Reset specific default
        self._nms_threshold = DEFAULT_NMS_THRESHOLD

    def process_results(self, results, timestamp):
        """Convert YOLO detection results to plain tuples and emit ``detections_ready``.

        Always emits, with an empty list when nothing was detected.
        """
        detections = []
        if results and results[0]:
            boxes = results[0].boxes
            # Copy each tensor to the host once, instead of three separate
            # device-to-host copies for every detected box.
            all_xyxy = boxes.xyxy.cpu().numpy()
            all_conf = boxes.conf.cpu().numpy()
            all_cls = boxes.cls.cpu().numpy()
            names = self.model.names
            for (x1, y1, x2, y2), confidence, cls_value in zip(all_xyxy, all_conf, all_cls):
                class_id = int(cls_value)
                label = names.get(class_id, f"ID:{class_id}")
                detections.append((label, float(confidence), (int(x1), int(y1), int(x2), int(y2))))

        self.detections_ready.emit(detections, timestamp, self.get_processing_time())

# --- Pose Estimation Thread ---
class PoseEstimationThread(BaseYoloWorker):
    """Performs pose estimation using a YOLO pose model.

    ``poses_ready`` carries a list of ``(keypoints, (x1, y1, x2, y2), confidence)``
    tuples, the source frame's timestamp, and the worker's processing time in
    milliseconds.
    """
    # Emits: list of (keypoints_array, box_tuple, confidence), timestamp, processing_time_ms
    # keypoints_array is N x 3 (x, y, confidence) for N keypoints
    poses_ready = Signal(list, float, float)

    def __init__(self, model_name=POSE_ESTIMATION_MODEL):
        super().__init__(model_name)
        self._confidence_threshold = DEFAULT_POSE_CONFIDENCE_THRESHOLD # Use pose-specific threshold
        # NMS is less critical/different for pose, set default but may not be used same way
        self._nms_threshold = DEFAULT_NMS_THRESHOLD

    def process_results(self, results, timestamp):
        """Convert YOLO pose results to plain tuples and emit ``poses_ready``.

        Always emits, with an empty list when no pose was found.
        """
        poses = []
        if results and results[0]:
            keypoints = results[0].keypoints
            boxes = results[0].boxes
            # keypoints may be None (e.g. if a non-pose model was loaded); emit
            # an empty list rather than raising.
            if keypoints is not None and boxes is not None:
                # One host copy per tensor instead of per-instance transfers.
                # Assumed shape: (num_poses, num_keypoints, 3) -- TODO confirm for other pose models.
                all_kpts = keypoints.data.cpu().numpy()
                all_xyxy = boxes.xyxy.cpu().numpy()
                all_conf = boxes.conf.cpu().numpy()  # overall confidence per detected person
                for kpts, (x1, y1, x2, y2), pose_conf in zip(all_kpts, all_xyxy, all_conf):
                    poses.append((kpts, (int(x1), int(y1), int(x2), int(y2)), float(pose_conf)))

        self.poses_ready.emit(poses, timestamp, self.get_processing_time())

# --- Optical Flow Thread (Remains largely unchanged, ensure set_enabled works) ---
class OpticalFlowThread(QThread):
    """Calculates sparse Lucas-Kanade optical flow on the latest frame.

    Uses OpenCV's CUDA module when a CUDA device is reported and the GPU
    components can be created. Otherwise, or after a runtime GPU error, it uses
    the CPU implementation. Whenever at least one point is tracked between
    consecutive processed frames, ``flow_ready`` is emitted with a list of
    ``((x_old, y_old), (x_new, y_new))`` integer tuples.
    """
    flow_ready = Signal(list) # List of (start_point_tuple, end_point_tuple)
    status_update = Signal(str)

    def __init__(self):
        super().__init__()
        self.running = False
        self._enabled = False
        self._frame_lock = QMutex()
        self.current_frame = None  # latest unconsumed frame; guarded by _frame_lock
        self.use_gpu = False
        self.gpu_detector = None
        self.gpu_lk_flow = None
        self.prev_gray_gpu = None
        self.prev_points_gpu = None  # 1 x N CV_32FC2 GpuMat, or None
        # criteria = (type flags, max iteration count, epsilon)
        self.cpu_lk_params = dict(winSize=(21, 21), maxLevel=3, # Slightly larger window, more levels
                                  criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))
        self.cpu_feature_params = dict(maxCorners=100, qualityLevel=0.2, minDistance=7, blockSize=7) # Slightly lower quality level
        self.prev_gray_cpu = None
        self.prev_points_cpu = None  # N x 1 x 2 float32 array, or None
        self.processing_time_ms = 0.0
        # Human-readable backend description; re-emitted from run(), because
        # nothing can be connected to status_update while __init__ is running.
        self._backend_status = "Optical Flow: Using CPU"
        self._init_gpu()

    def _init_gpu(self):
        """Try to create the CUDA feature detector and LK tracker; keep the CPU path on any failure."""
        try:
            if cv2.cuda.getCudaEnabledDeviceCount() > 0:
                logger.info("OpenCV CUDA detected. Attempting to use GPU for Optical Flow.")
                self.gpu_detector = cv2.cuda.createGoodFeaturesToTrackDetector(
                    cv2.CV_8UC1,
                    maxCorners=self.cpu_feature_params['maxCorners'],
                    qualityLevel=float(self.cpu_feature_params['qualityLevel']),
                    minDistance=float(self.cpu_feature_params['minDistance']),
                    blockSize=self.cpu_feature_params['blockSize'],
                )
                self.gpu_lk_flow = cv2.cuda.SparsePyrLKOpticalFlow_create(
                    winSize=self.cpu_lk_params['winSize'],
                    maxLevel=self.cpu_lk_params['maxLevel'],
                    # criteria[1] is the iteration count; criteria[2] is epsilon (0.03),
                    # which is not a valid value for `iters`.
                    iters=self.cpu_lk_params['criteria'][1],
                )
                self.use_gpu = True
                self._backend_status = "Optical Flow: Using GPU"
                logger.info("Successfully initialized GPU Optical Flow components.")
            else:
                logger.info("No OpenCV CUDA devices found. Using CPU for Optical Flow.")
        except AttributeError:
            logger.warning("OpenCV CUDA module not found or unavailable in this build. Using CPU for Optical Flow.")
            self.use_gpu = False
            self._backend_status = "Optical Flow: Using CPU (CUDA module unavailable)"
        except Exception as e:
            logger.error(f"Error initializing OpenCV CUDA components: {e}. Falling back to CPU.", exc_info=True)
            self.use_gpu = False
            self._backend_status = "Optical Flow: Error initializing GPU, using CPU"
        self.status_update.emit(self._backend_status)

    def _fall_back_to_cpu(self):
        """Switch permanently to the CPU implementation after a GPU failure."""
        self.use_gpu = False
        self._backend_status = "Optical Flow: GPU error, switched to CPU"
        logger.warning("Optical Flow: GPU processing failed; switching to CPU.")
        self.status_update.emit(self._backend_status)

    @Slot(np.ndarray, float) # Accept timestamp, though not directly used in flow calc
    def set_frame(self, frame, timestamp):
        """Store a private copy of *frame* as the next frame to process."""
        with QMutexLocker(self._frame_lock):
            self.current_frame = frame.copy()

    def run(self):
        """Process frames until ``stop()`` is called."""
        self.running = True
        logger.info(f"Optical Flow thread started (GPU Enabled: {self.use_gpu}).")
        self.status_update.emit(self._backend_status)
        while self.running:
            if not self._enabled:
                # Reset state only if it was previously enabled
                if self.prev_gray_cpu is not None or self.prev_gray_gpu is not None:
                    self.reset_state()
                self.msleep(50)
                continue

            with QMutexLocker(self._frame_lock):
                frame, self.current_frame = self.current_frame, None

            if frame is None:
                self.msleep(5)
                continue

            start_time = time.perf_counter()
            try:
                if self.use_gpu:
                    self.run_gpu(frame)
                else:
                    self.run_cpu(frame)
            except Exception as e:
                kind = "OpenCV error" if isinstance(e, cv2.error) else "Unexpected error"
                logger.error(f"{kind} in Optical Flow: {e}", exc_info=True)
                self.status_update.emit(f"Optical Flow Error: {e}")
                self.reset_state()
                # Retrying a failing GPU path every frame only floods the log.
                if self.use_gpu:
                    self._fall_back_to_cpu()
                self.msleep(100)
            finally:
                self.processing_time_ms = (time.perf_counter() - start_time) * 1000

        logger.info("Optical Flow thread stopped.")
        self.reset_state()

    def _needs_redetect(self, tracked_count):
        """True when fewer than half of ``maxCorners`` points are still being tracked."""
        return tracked_count < self.cpu_feature_params['maxCorners'] * 0.5

    @staticmethod
    def _upload_points(points):
        """Upload an array of (x, y) points as the 1 x N CV_32FC2 GpuMat the CUDA LK tracker takes."""
        gpu_points = cv2.cuda_GpuMat()
        gpu_points.upload(np.ascontiguousarray(points, dtype=np.float32).reshape(1, -1, 2))
        return gpu_points

    def run_gpu(self, frame):
        """Track the previous points into *frame* on the GPU and re-detect features when needed."""
        frame_gpu = cv2.cuda_GpuMat()
        frame_gpu.upload(frame)
        gray_gpu = cv2.cuda.cvtColor(frame_gpu, cv2.COLOR_BGR2GRAY)
        flow_vectors = []
        tracked_count = 0

        if self.prev_gray_gpu is not None and self.prev_points_gpu is not None and not self.prev_points_gpu.empty():
            next_points_gpu, status_gpu, _err_gpu = self.gpu_lk_flow.calc(
                self.prev_gray_gpu, gray_gpu, self.prev_points_gpu, None
            )
            if next_points_gpu is not None and status_gpu is not None:
                # GpuMat supports no boolean-mask indexing, so filter on the host
                # (at most maxCorners points, i.e. a tiny transfer).
                good = status_gpu.download().reshape(-1) == 1
                good_old = self.prev_points_gpu.download().reshape(-1, 2)[good]
                good_new = next_points_gpu.download().reshape(-1, 2)[good]
                if len(good_new) > 0:
                    flow_vectors = [(tuple(map(int, p)), tuple(map(int, q))) for p, q in zip(good_old, good_new)]
                    self.prev_points_gpu = self._upload_points(good_new)
                    tracked_count = len(good_new)
            if tracked_count == 0:
                self.prev_points_gpu = None  # Lost all points

        if self._needs_redetect(tracked_count):
            detected = self.gpu_detector.detect(gray_gpu)
            if detected is not None and not detected.empty():
                # Normalise shape/type to 1 x N CV_32FC2 regardless of detector output layout.
                self.prev_points_gpu = self._upload_points(detected.download())
            else:
                self.prev_points_gpu = None  # No features found

        self.prev_gray_gpu = gray_gpu
        if flow_vectors:
            self.flow_ready.emit(flow_vectors)

    def run_cpu(self, frame):
        """Track the previous points into *frame* on the CPU and re-detect features when needed."""
        gray_cpu = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        flow_vectors = []
        tracked_count = 0

        if self.prev_gray_cpu is not None and self.prev_points_cpu is not None and len(self.prev_points_cpu) > 0:
            next_points_cpu, status, _err = cv2.calcOpticalFlowPyrLK(
                self.prev_gray_cpu, gray_cpu, self.prev_points_cpu, None, **self.cpu_lk_params
            )
            if next_points_cpu is not None and status is not None:
                good_new = next_points_cpu[status == 1]
                good_old = self.prev_points_cpu[status == 1]
                if len(good_new) > 0:
                    flow_vectors = [(tuple(map(int, p)), tuple(map(int, q))) for p, q in zip(good_old, good_new)]
                    self.prev_points_cpu = good_new.reshape(-1, 1, 2)  # Update with tracked points
                    tracked_count = len(good_new)
            if tracked_count == 0:
                self.prev_points_cpu = None  # Lost all points

        if self._needs_redetect(tracked_count):
            self.prev_points_cpu = cv2.goodFeaturesToTrack(gray_cpu, mask=None, **self.cpu_feature_params)

        # cvtColor returns a new array, so no defensive copy is needed.
        self.prev_gray_cpu = gray_cpu
        if flow_vectors:
            self.flow_ready.emit(flow_vectors)

    def stop(self):
        """Ask the processing loop to exit after its current iteration."""
        self.running = False

    def reset_state(self):
        """Resets the internal state (both CPU and GPU)."""
        self.prev_gray_gpu = None
        self.prev_points_gpu = None
        self.prev_gray_cpu = None
        self.prev_points_cpu = None

    @Slot(bool)
    def set_enabled(self, enabled):
        """Enable/disable flow computation; disabling clears tracked points and frames."""
        if self._enabled != enabled:
            self._enabled = enabled
            logger.info(f"Optical Flow {'enabled' if enabled else 'disabled'}")
            if not enabled:
                self.reset_state()

    def get_processing_time(self):
        """Return the processing time (ms) of the most recently handled frame."""
        return self.processing_time_ms


# --- Depth Estimation Thread ---
class DepthEstimationThread(QThread):
    """Runs MiDaS monocular depth estimation (loaded via PyTorch Hub) on the latest frame.

    Frames arrive through :meth:`set_frame`; only the most recent one is kept, so
    slow inference drops frames instead of queueing them. Each processed frame is
    emitted on ``depth_ready`` as a float32 map with the input frame's height and
    width, min-max normalised to the 0-1 range. Inference runs on CUDA when
    ``torch.cuda.is_available()``, otherwise on the CPU.
    """
    depth_ready = Signal(np.ndarray)  # Normalized depth map (0-1, float32, CPU numpy array)
    status_update = Signal(str)       # Load progress / error messages for the UI

    def __init__(self, model_type="MiDaS_small"):
        super().__init__()
        self.model_type = model_type   # PyTorch Hub entry point name in intel-isl/MiDaS
        self.model = None
        self.transform = None
        self.running = False
        self._enabled = False
        self._frame_lock = QMutex()    # Guards current_frame between set_frame() and run()
        self.current_frame = None      # Latest pending frame, or None
        self.device = None
        self.processing_time_ms = 0.0  # Wall time spent on the last processed frame

    def _select_transform(self, midas_transforms):
        """Return the input transform matching ``self.model_type``.

        "MiDaS_small" uses ``small_transform``; every other model type (DPT_Large,
        DPT_Hybrid and anything else) uses ``dpt_transform``.
        NOTE(review): the MiDaS hub exposes further transforms for other model
        families; confirm before adding more model types.
        """
        if self.model_type == "MiDaS_small":
            return midas_transforms.small_transform
        return midas_transforms.dpt_transform

    def load_model(self):
        """Load the MiDaS model and its input transform from PyTorch Hub.

        Returns:
            True on success. On any failure, logs the error, reports it on
            ``status_update``, resets model/transform/device to None and
            returns False.
        """
        self.status_update.emit(f"Loading MiDaS model: {self.model_type}...")
        logger.info(f"Attempting to load MiDaS model: {self.model_type}")
        try:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
                logger.info("Using GPU for MiDaS.")
            else:
                self.device = torch.device("cpu")
                logger.info("Using CPU for MiDaS.")

            # trust_repo=True avoids the interactive "trust this repo?" prompt.
            self.model = torch.hub.load("intel-isl/MiDaS", self.model_type, trust_repo=True)
            midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
            self.transform = self._select_transform(midas_transforms)

            self.model.to(self.device)
            self.model.eval()  # Inference mode (disables dropout / batch-norm updates)
            logger.info(f"MiDaS model '{self.model_type}' loaded on {self.device}.")
            self.status_update.emit(f"MiDaS model '{self.model_type}' loaded on {self.device}.")
            return True

        except Exception as e:
            error_msg = f"Failed to load MiDaS model '{self.model_type}': {e}"
            logger.error(error_msg, exc_info=True)
            self.status_update.emit(error_msg)
            self.model = None
            self.transform = None
            self.device = None
            return False

    @Slot(np.ndarray, float)  # Timestamp is accepted for signal compatibility but unused
    def set_frame(self, frame, timestamp):
        """Store a private copy of ``frame`` as the next frame to process.

        While disabled the frame is ignored, which avoids copying every captured
        frame only for the run loop to discard it.
        """
        if not self._enabled:
            return
        with QMutexLocker(self._frame_lock):
            self.current_frame = frame.copy()

    def _take_frame(self):
        """Atomically take and clear the pending frame (None if there is none)."""
        with QMutexLocker(self._frame_lock):
            frame, self.current_frame = self.current_frame, None
        return frame

    def _clear_pending_frame(self):
        """Discard any frame waiting to be processed."""
        with QMutexLocker(self._frame_lock):
            self.current_frame = None

    @staticmethod
    def _to_rgb(frame):
        """Return ``frame`` converted from BGR or BGRA to RGB, or None for other layouts."""
        if frame.ndim != 3:
            return None
        channels = frame.shape[2]
        if channels == 3:
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if channels == 4:  # e.g. raw BGRA screen grabs
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
        return None

    def _estimate_depth(self, img_rgb):
        """Run the model on an RGB image; return a 0-1 float32 depth map of the same H x W."""
        input_batch = self.transform(img_rgb).to(self.device)
        with torch.no_grad():
            prediction = self.model(input_batch)
            # Resize the prediction back to the original frame size.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img_rgb.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
        # Index rather than squeeze() so 1-pixel-high/wide frames keep two dimensions.
        depth_map = prediction[0, 0].cpu().numpy()
        return cv2.normalize(depth_map, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    def _process_frame(self, frame):
        """Estimate depth for one frame, emit it, and record the processing time."""
        start_time = time.perf_counter()
        try:
            img_rgb = self._to_rgb(frame)
            if img_rgb is None:
                logger.warning("Depth estimation requires a BGR or BGRA input frame.")
                return
            self.depth_ready.emit(self._estimate_depth(img_rgb))
        except Exception as e:
            logger.error(f"Depth estimation error: {e}", exc_info=True)
            self.status_update.emit(f"Depth estimation error: {e}")
        finally:
            self.processing_time_ms = (time.perf_counter() - start_time) * 1000

    def run(self):
        """Thread body: load the model, then process pending frames until stopped."""
        # Set the flag *before* the (possibly slow) hub load so that a stop()
        # issued during loading is not overwritten afterwards.
        self.running = True
        if not self.load_model():
            self.running = False
            return

        logger.info("Depth Estimation thread started.")
        try:
            while self.running:
                if not self._enabled:
                    self._clear_pending_frame()
                    self.msleep(50)
                    continue

                frame_to_process = self._take_frame()
                if frame_to_process is None or self.model is None or self.transform is None:
                    self.msleep(5)
                    continue

                self._process_frame(frame_to_process)
        finally:
            logger.info("Depth Estimation thread stopped.")
            self.model = None
            self.transform = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def stop(self):
        """Ask run() to exit; its loop checks ``self.running`` every iteration."""
        self.running = False

    @Slot(bool)
    def set_enabled(self, enabled):
        """Enable or disable depth estimation; disabling drops any pending frame."""
        if self._enabled != enabled:
            self._enabled = enabled
            logger.info(f"Depth Estimation {'enabled' if enabled else 'disabled'}")
            if not enabled:
                self._clear_pending_frame()

    def get_processing_time(self):
        """Return the wall time of the last processed frame in milliseconds."""
        return self.processing_time_ms


# --- Overlay Widget ---
class OverlayWidget(QWidget):
    """Transparent, always-on-top overlay drawing detections, poses, paths and a HUD.

    The widget's geometry is set to ``monitor_rect`` so it covers the captured
    monitor exactly. Results arrive through the ``update_*`` slots and are drawn
    in :meth:`paintEvent`; the ``toggle_*`` slots switch individual layers on/off.
    """

    def __init__(self, monitor_rect, parent=None):
        super().__init__(parent)
        # Transparent, frameless, always-on-top tool window.
        self.setWindowFlags(Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint | Qt.Tool)
        self.setAttribute(Qt.WA_TranslucentBackground)
        self.setAttribute(Qt.WA_NoSystemBackground, True)
        # NOTE: Qt.WA_PaintOnScreen is intentionally not set: Qt documents that such
        # widgets "cannot be semi-transparent" (defeating WA_TranslucentBackground)
        # and the attribute is meant for widgets that paint without QPainter.

        # Geometry needs to match the captured monitor exactly.
        self.monitor_rect = monitor_rect
        self.setGeometry(monitor_rect.left(), monitor_rect.top(), monitor_rect.width(), monitor_rect.height())

        # --- Font Setup ---
        font_id = QFontDatabase.addApplicationFont("PressStart2P-Regular.ttf")  # Font file must be present
        if font_id != -1:
            font_families = QFontDatabase.applicationFontFamilies(font_id)
            if font_families:
                self.hud_font_name = font_families[0]
                logger.info(f"Loaded font: {self.hud_font_name}")
            else:
                self.hud_font_name = FALLBACK_FONT
                logger.warning("Could not get font family name, using fallback.")
        else:
            self.hud_font_name = FALLBACK_FONT
            logger.warning("Could not load font 'PressStart2P-Regular.ttf', using fallback.")

        self.font_small = QFont(self.hud_font_name, FONT_SIZE_SMALL)
        self.font_medium = QFont(self.hud_font_name, FONT_SIZE_MEDIUM)

        # --- Data Storage ---
        self.detections = []      # (label, confidence, (x1, y1, x2, y2)) tuples
        self.poses = []           # (keypoints N x 3 [x, y, conf], box, pose_conf) tuples
        self.flow_vectors = []    # (start_pt, end_pt) pairs of 2-element points
        self.depth_map = None     # 2-D array with values in 0-1, or None
        self.hud_texts = {}
        self.target_history = collections.deque(maxlen=PATH_HISTORY_LENGTH)  # (timestamp, QPointF)
        self.smoothed_target_center = None       # QPointF for the smoothed crosshair
        self.smoothed_velocity = QPointF(0, 0)   # pixels per timestamp unit
        # Scaled depth pixmap cache; rebuilt only when depth_map or the widget size
        # changes instead of on every repaint.
        self._depth_pixmap = None
        self._depth_pixmap_source = None
        self._depth_pixmap_size = None

        # --- Drawing Flags ---
        self.show_boxes = True
        self.show_labels = True
        self.show_paths = True
        self.show_trajectory = True
        self.show_crosshair = True
        self.show_poses = True
        self.show_flow = False
        self.show_depth = False
        self.show_hud = True

        # --- Scanline Effect ---
        self.scanline_y = 0
        self.scanline_timer = QTimer(self)
        self.scanline_timer.timeout.connect(self.update_scanline)
        self.scanline_timer.start(SCANLINE_SPEED_MS)

        # --- HUD Update Timer ---
        self.hud_update_timer = QTimer(self)
        self.hud_update_timer.timeout.connect(self.update)  # Periodic repaint for HUD text
        self.hud_update_timer.start(UPDATE_INTERVAL_MS)

        # --- Performance Metrics ---
        self.capture_fps = 0.0
        self.detection_fps = 0.0
        self.pose_fps = 0.0
        self.flow_fps = 0.0
        self.depth_fps = 0.0
        self.last_capture_time = time.perf_counter()
        self.last_detection_time = time.perf_counter()
        self.last_pose_time = time.perf_counter()
        self.last_flow_time = time.perf_counter()
        self.last_depth_time = time.perf_counter()
        self.detection_proc_time = 0.0
        self.pose_proc_time = 0.0
        # NOTE(review): flow/depth processing times are never assigned in this
        # class; they stay 0.0 unless the owner sets them.
        self.flow_proc_time = 0.0
        self.depth_proc_time = 0.0

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _ema(new, old, alpha):
        """Return the exponential moving average ``alpha * new + (1 - alpha) * old`` of two QPointF."""
        return QPointF(alpha * new.x() + (1.0 - alpha) * old.x(),
                       alpha * new.y() + (1.0 - alpha) * old.y())

    @staticmethod
    def _to_pointf(pt):
        """Convert a 2-element (x, y) sequence of any numeric type to a QPointF."""
        return QPointF(float(pt[0]), float(pt[1]))

    @staticmethod
    def _select_target_box(detections):
        """Return the box of the largest-area 'person' detection, or None if there is none."""
        target_box = None
        max_area = 0
        for label, _confidence, box in detections:
            if label != 'person':
                continue
            x1, y1, x2, y2 = box
            area = (x2 - x1) * (y2 - y1)
            if area > max_area:
                max_area = area
                target_box = box
        return target_box

    def _update_velocity(self):
        """Refresh ``smoothed_velocity`` from the most recent target-history samples."""
        if len(self.target_history) < 2:
            return
        points_to_use = min(TRAJECTORY_PREDICTION_POINTS, len(self.target_history))
        recent_points = list(self.target_history)[-points_to_use:]

        # Net displacement over total elapsed time across the recent samples.
        total_dx = total_dy = total_dt = 0.0
        for (t1, p1), (t2, p2) in zip(recent_points, recent_points[1:]):
            dt = t2 - t1
            if dt > 1e-6:  # Skip samples with (near-)identical timestamps
                total_dx += p2.x() - p1.x()
                total_dy += p2.y() - p1.y()
                total_dt += dt

        if total_dt > 1e-6:
            raw_velocity = QPointF(total_dx / total_dt, total_dy / total_dt)
            self.smoothed_velocity = self._ema(raw_velocity, self.smoothed_velocity,
                                               TRAJECTORY_SMOOTHING_FACTOR)

    # -------------------------------------------------------------------- slots

    def update_scanline(self):
        """Advance the scanline effect by 5 px, wrapping at the widget height."""
        # max(1, ...) guards against a zero-height widget.
        self.scanline_y = (self.scanline_y + 5) % max(1, self.height())
        self.update()

    @Slot(list, float, float)
    def update_detections(self, detections, timestamp, proc_time):
        """Store detection results and update the smoothed target track.

        Args:
            detections: ``(label, confidence, (x1, y1, x2, y2))`` tuples.
            timestamp: Frame timestamp; used as the time base for the velocity.
            proc_time: Detection processing time in ms (shown in the HUD).
        """
        self.detections = detections
        self.detection_proc_time = proc_time

        # Target the largest "person" box; explicit None checks so a target at
        # (0, 0) or a numpy box is not mistaken for "no target".
        target_box = self._select_target_box(detections)
        if target_box is not None:
            x1, y1, x2, y2 = target_box
            raw_target_center = QPointF((float(x1) + float(x2)) / 2, (float(y1) + float(y2)) / 2)
            if self.smoothed_target_center is None:
                self.smoothed_target_center = raw_target_center  # First valid detection
            else:
                self.smoothed_target_center = self._ema(raw_target_center, self.smoothed_target_center,
                                                        TARGET_CENTER_SMOOTHING_FACTOR)
            self.target_history.append((timestamp, self.smoothed_target_center))
            self._update_velocity()
        # With no target this frame, the last smoothed position and velocity are kept.

        # --- FPS Calculation (based on signal arrival) ---
        now = time.perf_counter()
        time_diff = now - self.last_detection_time
        if time_diff > 0:
            self.detection_fps = 1.0 / time_diff
        self.last_detection_time = now

        self.update()

    @Slot(list, float, float)
    def update_poses(self, poses, timestamp, proc_time):
        """Store pose estimation results; repaint only if poses are visible."""
        self.poses = poses
        self.pose_proc_time = proc_time
        now = time.perf_counter()
        time_diff = now - self.last_pose_time
        if time_diff > 0:
            self.pose_fps = 1.0 / time_diff
        self.last_pose_time = now
        if self.show_poses:
            self.update()

    @Slot(list)
    def update_flow(self, flow_vectors):
        """Store optical flow vectors; repaint only if flow is visible."""
        self.flow_vectors = flow_vectors
        now = time.perf_counter()
        time_diff = now - self.last_flow_time
        if time_diff > 0:
            self.flow_fps = 1.0 / time_diff
        self.last_flow_time = now
        if self.show_flow:
            self.update()

    @Slot(np.ndarray)
    def update_depth(self, depth_map):
        """Store a new depth map; repaint only if depth is visible."""
        self.depth_map = depth_map
        now = time.perf_counter()
        time_diff = now - self.last_depth_time
        if time_diff > 0:
            self.depth_fps = 1.0 / time_diff
        self.last_depth_time = now
        if self.show_depth:
            self.update()

    @Slot(dict)
    def update_hud(self, hud_data):
        """Merge key/value text into the HUD; hud_update_timer handles the repaint."""
        self.hud_texts.update(hud_data)

    @Slot(float)
    def update_capture_fps(self, fps):
        """Store the capture FPS shown in the HUD."""
        self.capture_fps = fps

    # --- Toggling drawing elements ---
    @Slot(bool)
    def toggle_boxes(self, show):
        self.show_boxes = show
        self.update()

    @Slot(bool)
    def toggle_labels(self, show):
        self.show_labels = show
        self.update()

    @Slot(bool)
    def toggle_paths(self, show):
        self.show_paths = show
        self.update()

    @Slot(bool)
    def toggle_trajectory(self, show):
        self.show_trajectory = show
        self.update()

    @Slot(bool)
    def toggle_crosshair(self, show):
        self.show_crosshair = show
        self.update()

    @Slot(bool)
    def toggle_poses(self, show):
        self.show_poses = show
        self.update()

    @Slot(bool)
    def toggle_flow(self, show):
        self.show_flow = show
        self.update()

    @Slot(bool)
    def toggle_depth(self, show):
        self.show_depth = show
        self.update()

    @Slot(bool)
    def toggle_hud(self, show):
        self.show_hud = show
        self.update()

    # ------------------------------------------------------------------ drawing

    def _scaled_depth_pixmap(self):
        """Return the depth map as a pixmap scaled to the widget, rebuilding it only when needed."""
        size = self.size()
        if (self._depth_pixmap is None
                or self._depth_pixmap_source is not self.depth_map
                or self._depth_pixmap_size != size):
            # Clip so out-of-range values cannot wrap around in the uint8 cast.
            depth_8bit = np.ascontiguousarray((np.clip(self.depth_map, 0.0, 1.0) * 255).astype(np.uint8))
            h, w = depth_8bit.shape
            # copy() makes the QImage own its pixels, independent of depth_8bit's lifetime.
            q_image = QImage(depth_8bit.data, w, h, w, QImage.Format_Grayscale8).copy()
            pixmap = QPixmap.fromImage(q_image)
            self._depth_pixmap = pixmap.scaled(size, Qt.KeepAspectRatio, Qt.SmoothTransformation)
            self._depth_pixmap_source = self.depth_map
            self._depth_pixmap_size = size
        return self._depth_pixmap

    def _draw_depth(self, painter):
        """Draw the grayscale depth map centred in the widget (aspect ratio kept)."""
        try:
            pixmap = self._scaled_depth_pixmap()
        except Exception as e:
            logger.error(f"Error drawing depth map: {e}", exc_info=True)
            return
        x_offset = (self.width() - pixmap.width()) // 2
        y_offset = (self.height() - pixmap.height()) // 2
        painter.drawPixmap(x_offset, y_offset, pixmap)

    def _draw_scanline(self, painter):
        """Draw the semi-transparent green scanline."""
        painter.setPen(QPen(QColor(0, 255, 0, 50), 2))
        painter.drawLine(0, self.scanline_y, self.width(), self.scanline_y)

    def _draw_flow(self, painter):
        """Draw optical flow vectors as yellow lines."""
        painter.setPen(QPen(YELLOW_COLOR, 1))
        for start_pt, end_pt in self.flow_vectors:
            painter.drawLine(self._to_pointf(start_pt), self._to_pointf(end_pt))

    def _draw_detections(self, painter):
        """Draw detection boxes (unfilled) and their labels."""
        painter.setFont(self.font_small)
        # Boxes must be outlines only: drawRect fills with the current brush, and
        # text is drawn with the pen, so no brush is ever needed here.
        painter.setBrush(Qt.NoBrush)
        box_pen = QPen(GREEN_COLOR, 2)
        for label, confidence, box in self.detections:
            x1, y1, x2, y2 = (int(v) for v in box)
            if self.show_boxes:
                painter.setPen(box_pen)
                painter.drawRect(x1, y1, x2 - x1, y2 - y1)
            if self.show_labels:
                painter.setPen(GREEN_COLOR)
                painter.drawText(x1 + 2, y1 - 5, f"{label} ({confidence:.2f})")  # Just above the box

    def _draw_poses(self, painter):
        """Draw confident keypoints as coloured dots and the skeleton connecting them."""
        for keypoints_data, _box, _pose_conf in self.poses:
            visible = {}  # keypoint index -> QPointF, confident keypoints only
            painter.setPen(Qt.NoPen)  # Dots have no outline
            for idx, (x, y, conf) in enumerate(keypoints_data):
                if conf > DEFAULT_POSE_CONFIDENCE_THRESHOLD:
                    pt = QPointF(float(x), float(y))
                    visible[idx] = pt
                    painter.setBrush(POSE_COLORS[idx % len(POSE_COLORS)])
                    painter.drawEllipse(pt, 3, 3)

            painter.setPen(QPen(MAGENTA_COLOR, 1))
            for start_idx, end_idx in POSE_CONNECTIONS:
                if start_idx in visible and end_idx in visible:
                    painter.drawLine(visible[start_idx], visible[end_idx])
        painter.setBrush(Qt.NoBrush)  # Do not leak the keypoint fill into later layers

    def _draw_path(self, painter):
        """Draw the smoothed target history as an orange polyline."""
        painter.setPen(QPen(QColor(255, 165, 0, 180), 1))
        painter.drawPolyline(QPolygonF([p for _t, p in self.target_history]))

    def _draw_trajectory(self, painter):
        """Draw a dashed linear prediction from the smoothed centre along the smoothed velocity."""
        painter.setPen(QPen(QColor(255, 0, 255, 200), 2, Qt.DashLine))
        start_point = self.smoothed_target_center
        velocity = self.smoothed_velocity
        end_point = QPointF(start_point.x() + velocity.x() * TRAJECTORY_PREDICTION_DURATION,
                            start_point.y() + velocity.y() * TRAJECTORY_PREDICTION_DURATION)
        painter.drawLine(start_point, end_point)

    def _draw_crosshair(self, painter):
        """Draw a red crosshair at the smoothed target centre."""
        painter.setPen(QPen(RED_COLOR, 1))
        cx = int(self.smoothed_target_center.x())
        cy = int(self.smoothed_target_center.y())
        size = 10  # Half-length of the crosshair arms
        painter.drawLine(cx - size, cy, cx + size, cy)
        painter.drawLine(cx, cy - size, cx, cy + size)

    def _draw_hud(self, painter):
        """Draw FPS/timing lines, the hud_texts entries and the crosshair position."""
        painter.setFont(self.font_medium)
        painter.setPen(TEXT_COLOR)
        # At least the original 15 px, but never less than the font's line spacing
        # so lines of a larger font do not overlap.
        line_step = max(15, painter.fontMetrics().lineSpacing())

        lines = [
            f"Cap: {self.capture_fps:.1f} FPS",
            f"Det: {self.detection_fps:.1f} FPS ({self.detection_proc_time:.1f} ms)",
        ]
        if self.show_poses:
            lines.append(f"Pose: {self.pose_fps:.1f} FPS ({self.pose_proc_time:.1f} ms)")
        if self.show_flow:
            lines.append(f"Flow: {self.flow_fps:.1f} FPS ({self.flow_proc_time:.1f} ms)")
        if self.show_depth:
            lines.append(f"Depth: {self.depth_fps:.1f} FPS ({self.depth_proc_time:.1f} ms)")
        lines.extend(f"{key}: {value}" for key, value in self.hud_texts.items())
        if self.show_crosshair and self.smoothed_target_center is not None:
            sx, sy = int(self.smoothed_target_center.x()), int(self.smoothed_target_center.y())
            lines.append(f"Crosshair (Smooth): {sx}, {sy}")

        y_offset = 20  # Baseline of the first HUD line
        for text in lines:
            painter.drawText(10, y_offset, text)
            y_offset += line_step

    def paintEvent(self, event):
        """Draw all enabled overlay layers, back to front."""
        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.Antialiasing)

            # Clear to fully transparent. Source mode is required: with the default
            # SourceOver mode, filling with a transparent colour changes nothing.
            painter.setCompositionMode(QPainter.CompositionMode_Source)
            painter.fillRect(self.rect(), Qt.transparent)
            painter.setCompositionMode(QPainter.CompositionMode_SourceOver)

            if self.show_depth and self.depth_map is not None:
                self._draw_depth(painter)  # Background layer
            self._draw_scanline(painter)
            if self.show_flow and self.flow_vectors:
                self._draw_flow(painter)
            if self.detections:
                self._draw_detections(painter)
            if self.show_poses and self.poses:
                self._draw_poses(painter)
            if self.show_paths and len(self.target_history) > 1:
                self._draw_path(painter)
            if (self.show_trajectory and self.smoothed_target_center is not None
                    and len(self.target_history) >= 2):
                self._draw_trajectory(painter)
            if self.show_crosshair and self.smoothed_target_center is not None:
                self._draw_crosshair(painter)
            if self.show_hud:
                self._draw_hud(painter)
        finally:
            # Always release the painter, even if a layer raised.
            painter.end()


# --- Main Application Window ---
class MainWindow(QMainWindow):
    """Main application window with controls.

    Builds the control panel, creates an ``OverlayWidget`` sized to the
    selected monitor, starts the worker threads (capture, detection, pose,
    optical flow, depth) and wires GUI controls, worker outputs and the
    overlay together through Qt signals.
    """
    # Signals to control worker threads
    capture_interval_changed = Signal(int)
    detection_enabled_changed = Signal(bool)
    detection_conf_changed = Signal(float)
    detection_nms_changed = Signal(float)
    pose_enabled_changed = Signal(bool) # Signal for pose estimation enable/disable
    pose_conf_changed = Signal(float)   # Signal for pose confidence threshold
    flow_enabled_changed = Signal(bool)
    depth_enabled_changed = Signal(bool)

    # Signals to update overlay drawing toggles
    toggle_boxes_signal = Signal(bool)
    toggle_labels_signal = Signal(bool)
    toggle_paths_signal = Signal(bool)
    toggle_trajectory_signal = Signal(bool)
    toggle_crosshair_signal = Signal(bool)
    toggle_poses_signal = Signal(bool) # Signal for toggling pose drawing
    toggle_flow_signal = Signal(bool)
    toggle_depth_signal = Signal(bool)
    toggle_hud_signal = Signal(bool)

    # Signal to send HUD data
    hud_data_signal = Signal(dict)
    capture_fps_signal = Signal(float)

    # Worker-thread attribute names, in the order stop_threads() stops them
    # (capture first, then the consumers of its frames).
    _THREAD_ATTRS = ("capture_thread", "detection_thread", "pose_thread",
                     "flow_thread", "depth_thread")


    def __init__(self):
        super().__init__()
        self.setWindowTitle("Real-Time Detection Overlay")
        self.setGeometry(100, 100, 450, 700) # Main window size

        # --- Monitor Selection ---
        # mss is only needed here for monitor geometry; use it as a context
        # manager so the resources it opens are released immediately.
        with mss.mss() as sct:
            # Entry 0 is the combined 'all monitors' rectangle. Fall back to it
            # when no individual monitors are reported instead of crashing.
            self.monitors = list(sct.monitors[1:]) or list(sct.monitors[:1])
        self.selected_monitor_spec = self.monitors[0] # Default to the first monitor

        # --- Central Widget and Layout ---
        self.central_widget = QWidget()
        self.main_layout = QVBoxLayout(self.central_widget)
        self.setCentralWidget(self.central_widget)

        # --- Control Panel ---
        self.control_group = QGroupBox("Controls")
        self.control_layout = QGridLayout()
        self.control_group.setLayout(self.control_layout)
        self.main_layout.addWidget(self.control_group)

        # Monitor Selection
        self.monitor_combo = QComboBox()
        self.monitor_combo.addItems([f"Monitor {i+1} ({m['width']}x{m['height']})" for i, m in enumerate(self.monitors)])
        self.monitor_combo.currentIndexChanged.connect(self.update_monitor)
        self.control_layout.addWidget(QLabel("Target Monitor:"), 0, 0)
        self.control_layout.addWidget(self.monitor_combo, 0, 1, 1, 2) # Span 2 columns

        # Capture Interval
        self.interval_spinbox = QSpinBox()
        self.interval_spinbox.setRange(1, 1000) # 1ms to 1s
        self.interval_spinbox.setValue(CAPTURE_INTERVAL_MS)
        self.interval_spinbox.setSuffix(" ms")
        self.interval_spinbox.valueChanged.connect(self.capture_interval_changed.emit)
        self.control_layout.addWidget(QLabel("Capture Interval:"), 1, 0)
        self.control_layout.addWidget(self.interval_spinbox, 1, 1, 1, 2)

        # --- Detection Controls ---
        self.detection_group = QGroupBox("Object Detection (YOLOv8n)")
        self.detection_layout = QGridLayout()
        self.detection_group.setLayout(self.detection_layout)
        self.main_layout.addWidget(self.detection_group)

        self.detection_enabled_check = QCheckBox("Enable Detection")
        self.detection_enabled_check.setChecked(True)
        self.detection_enabled_check.toggled.connect(self.detection_enabled_changed.emit)
        self.detection_layout.addWidget(self.detection_enabled_check, 0, 0, 1, 3) # Span columns

        self.conf_spinbox = QDoubleSpinBox()
        self.conf_spinbox.setRange(0.01, 1.0)
        self.conf_spinbox.setSingleStep(0.05)
        self.conf_spinbox.setValue(DEFAULT_CONFIDENCE_THRESHOLD)
        self.conf_spinbox.valueChanged.connect(self.detection_conf_changed.emit)
        self.detection_layout.addWidget(QLabel("Confidence Threshold:"), 1, 0)
        self.detection_layout.addWidget(self.conf_spinbox, 1, 1, 1, 2)

        self.nms_spinbox = QDoubleSpinBox()
        self.nms_spinbox.setRange(0.01, 1.0)
        self.nms_spinbox.setSingleStep(0.05)
        self.nms_spinbox.setValue(DEFAULT_NMS_THRESHOLD)
        self.nms_spinbox.valueChanged.connect(self.detection_nms_changed.emit)
        self.detection_layout.addWidget(QLabel("NMS Threshold:"), 2, 0)
        self.detection_layout.addWidget(self.nms_spinbox, 2, 1, 1, 2)

        # --- Pose Estimation Controls ---
        self.pose_group = QGroupBox("Pose Estimation (YOLOv8n-Pose)")
        self.pose_layout = QGridLayout()
        self.pose_group.setLayout(self.pose_layout)
        self.main_layout.addWidget(self.pose_group)

        self.pose_enabled_check = QCheckBox("Enable Pose Estimation")
        self.pose_enabled_check.setChecked(False) # Default disabled
        self.pose_enabled_check.toggled.connect(self.pose_enabled_changed.emit)
        self.pose_enabled_check.toggled.connect(self.toggle_poses_signal.emit) # Also toggle drawing
        self.pose_layout.addWidget(self.pose_enabled_check, 0, 0, 1, 3)

        self.pose_conf_spinbox = QDoubleSpinBox()
        self.pose_conf_spinbox.setRange(0.01, 1.0)
        self.pose_conf_spinbox.setSingleStep(0.05)
        self.pose_conf_spinbox.setValue(DEFAULT_POSE_CONFIDENCE_THRESHOLD)
        self.pose_conf_spinbox.valueChanged.connect(self.pose_conf_changed.emit)
        self.pose_layout.addWidget(QLabel("Pose Conf Threshold:"), 1, 0)
        self.pose_layout.addWidget(self.pose_conf_spinbox, 1, 1, 1, 2)


        # --- Other Feature Controls ---
        self.features_group = QGroupBox("Other Features")
        self.features_layout = QGridLayout()
        self.features_group.setLayout(self.features_layout)
        self.main_layout.addWidget(self.features_group)

        self.flow_enabled_check = QCheckBox("Enable Optical Flow")
        self.flow_enabled_check.setChecked(False)
        self.flow_enabled_check.toggled.connect(self.flow_enabled_changed.emit)
        self.flow_enabled_check.toggled.connect(self.toggle_flow_signal.emit)
        self.features_layout.addWidget(self.flow_enabled_check, 0, 0)

        self.depth_enabled_check = QCheckBox("Enable Depth Estimation")
        self.depth_enabled_check.setChecked(False)
        self.depth_enabled_check.toggled.connect(self.depth_enabled_changed.emit)
        self.depth_enabled_check.toggled.connect(self.toggle_depth_signal.emit)
        self.features_layout.addWidget(self.depth_enabled_check, 0, 1)


        # --- Overlay Visibility Controls ---
        self.visibility_group = QGroupBox("Overlay Visibility")
        self.visibility_layout = QGridLayout()
        self.visibility_group.setLayout(self.visibility_layout)
        self.main_layout.addWidget(self.visibility_group)

        self.show_boxes_check = QCheckBox("Show Boxes")
        self.show_boxes_check.setChecked(True)
        self.show_boxes_check.toggled.connect(self.toggle_boxes_signal.emit)
        self.visibility_layout.addWidget(self.show_boxes_check, 0, 0)

        self.show_labels_check = QCheckBox("Show Labels")
        self.show_labels_check.setChecked(True)
        self.show_labels_check.toggled.connect(self.toggle_labels_signal.emit)
        self.visibility_layout.addWidget(self.show_labels_check, 0, 1)

        self.show_paths_check = QCheckBox("Show Paths")
        self.show_paths_check.setChecked(True)
        self.show_paths_check.toggled.connect(self.toggle_paths_signal.emit)
        self.visibility_layout.addWidget(self.show_paths_check, 1, 0)

        self.show_trajectory_check = QCheckBox("Show Trajectory")
        self.show_trajectory_check.setChecked(True)
        self.show_trajectory_check.toggled.connect(self.toggle_trajectory_signal.emit)
        self.visibility_layout.addWidget(self.show_trajectory_check, 1, 1)

        self.show_crosshair_check = QCheckBox("Show Crosshair")
        self.show_crosshair_check.setChecked(True)
        self.show_crosshair_check.toggled.connect(self.toggle_crosshair_signal.emit)
        self.visibility_layout.addWidget(self.show_crosshair_check, 2, 0)

        # Pose visibility has no checkbox of its own: pose_enabled_check drives
        # both pose_enabled_changed and toggle_poses_signal (see above).

        self.show_hud_check = QCheckBox("Show HUD")
        self.show_hud_check.setChecked(True)
        self.show_hud_check.toggled.connect(self.toggle_hud_signal.emit)
        self.visibility_layout.addWidget(self.show_hud_check, 2, 1)


        # --- Log Output ---
        self.log_group = QGroupBox("Log Output")
        self.log_layout = QVBoxLayout()
        self.log_group.setLayout(self.log_layout)
        self.log_text_edit = QTextEdit()
        self.log_text_edit.setReadOnly(True)
        self.log_layout.addWidget(self.log_text_edit)
        self.main_layout.addWidget(self.log_group)

        # --- Status Bar ---
        self.status_bar = self.statusBar()
        self.status_bar.showMessage("Ready")

        # --- Initialize Overlay ---
        self.overlay = OverlayWidget(self._monitor_qrect())
        self.overlay.show()

        # --- Setup Logging Handler ---
        self.log_handler = QTextEditLogger(self.log_text_edit)
        logging.getLogger().addHandler(self.log_handler)
        logging.getLogger().setLevel(logging.INFO) # Ensure root logger level is appropriate

        # --- System Info Timer ---
        self.system_info_timer = QTimer(self)
        self.system_info_timer.timeout.connect(self.update_system_info)
        self.system_info_timer.start(SYSTEM_INFO_INTERVAL_MS)

        # --- Initialize Threads ---
        self.capture_thread = None
        self.detection_thread = None
        self.pose_thread = None
        self.flow_thread = None
        self.depth_thread = None
        # The overlay outlives monitor changes; its signals must be wired once.
        self._overlay_signals_connected = False
        self.start_threads()

        # --- Connect Signals ---
        self.connect_signals()

        # --- Initial System Info Update ---
        self.update_system_info()


    def _monitor_qrect(self):
        """Return the selected monitor's geometry (an mss monitor dict) as a QRect."""
        spec = self.selected_monitor_spec
        return QRect(spec['left'], spec['top'], spec['width'], spec['height'])


    def update_monitor(self, index):
        """Restarts threads when the monitor selection changes.

        Args:
            index: Row of ``monitor_combo``, i.e. an index into ``self.monitors``.
        """
        if 0 <= index < len(self.monitors):
            self.selected_monitor_spec = self.monitors[index]
            logger.info(f"Monitor changed to: {index+1}")
            self.status_bar.showMessage(f"Monitor changed to {index+1}. Restarting capture...")
            # Stop existing threads
            self.stop_threads()
            # Update overlay geometry
            monitor_qrect = self._monitor_qrect()
            self.overlay.setGeometry(monitor_qrect)
            self.overlay.monitor_rect = monitor_qrect # Update internal rect too
            # Restart threads with the new monitor spec
            self.start_threads()
            # Wire the new thread instances; overlay connections are not duplicated.
            self.connect_signals()
            self.status_bar.showMessage(f"Capture restarted on monitor {index+1}.")
        else:
            logger.error(f"Invalid monitor index: {index}")


    def start_threads(self):
        """Initializes and starts all worker threads.

        Each thread's ``status_update`` is routed to the status bar. All other
        wiring happens in ``connect_signals()``, which must be called afterwards.
        """
        logger.info("Starting worker threads...")
        # Screen Capture
        self.capture_thread = ScreenCaptureThread(self.selected_monitor_spec)
        self.capture_thread.status_update.connect(self.update_status)
        self.capture_thread.start()

        # Detection
        self.detection_thread = DetectionThread(OBJECT_DETECTION_MODEL)
        self.detection_thread.status_update.connect(self.update_status)
        self.detection_thread.start()

        # Pose Estimation
        self.pose_thread = PoseEstimationThread(POSE_ESTIMATION_MODEL)
        self.pose_thread.status_update.connect(self.update_status)
        self.pose_thread.start()

        # Optical Flow
        self.flow_thread = OpticalFlowThread()
        self.flow_thread.status_update.connect(self.update_status)
        self.flow_thread.start()

        # Depth Estimation
        self.depth_thread = DepthEstimationThread() # Uses default MiDaS_small
        self.depth_thread.status_update.connect(self.update_status)
        self.depth_thread.start()

        logger.info("Worker threads started.")


    def connect_signals(self):
        """Connects signals between GUI, threads, and overlay.

        Safe to call again after ``start_threads()`` (as ``update_monitor``
        does). Connections to the current worker-thread instances are made on
        every call. The overlay persists across monitor changes, so its
        connections are made only on the first call; otherwise each restart
        would stack another duplicate connection onto every overlay slot.
        Current widget values are pushed to the threads and to the overlay
        on every call.
        """
        logger.info("Connecting signals...")
        self._connect_thread_signals()
        if not getattr(self, "_overlay_signals_connected", False):
            self._connect_overlay_signals()
            self._overlay_signals_connected = True
        self._sync_overlay_state()
        logger.info("Signals connected.")


    def _connect_thread_signals(self):
        """Wire the current worker threads and push the current GUI values to them."""
        # --- Capture Thread Connections ---
        if self.capture_thread:
            self.capture_thread.frame_ready.connect(self.calculate_capture_fps) # Connect to FPS calc first
            # Connect frame_ready to worker threads that need the frame
            if self.detection_thread:
                self.capture_thread.frame_ready.connect(self.detection_thread.set_frame)
            if self.pose_thread:
                self.capture_thread.frame_ready.connect(self.pose_thread.set_frame)
            if self.flow_thread:
                self.capture_thread.frame_ready.connect(self.flow_thread.set_frame)
            if self.depth_thread:
                self.capture_thread.frame_ready.connect(self.depth_thread.set_frame)
            # Connect GUI controls to capture thread slots
            self.capture_interval_changed.connect(self.capture_thread.update_capture_interval)
            # Initial state sync: a restarted capture thread must honour the
            # interval currently shown in the spinbox, like the other threads.
            self.capture_thread.update_capture_interval(self.interval_spinbox.value())

        # --- Detection Thread Connections ---
        if self.detection_thread:
            self.detection_thread.detections_ready.connect(self.overlay.update_detections)
            # Connect GUI controls to detection thread slots
            self.detection_enabled_changed.connect(self.detection_thread.set_enabled)
            self.detection_conf_changed.connect(self.detection_thread.update_confidence_threshold)
            self.detection_nms_changed.connect(self.detection_thread.update_nms_threshold)
            # Initial state sync
            self.detection_thread.set_enabled(self.detection_enabled_check.isChecked())
            self.detection_thread.update_confidence_threshold(self.conf_spinbox.value())
            self.detection_thread.update_nms_threshold(self.nms_spinbox.value())

        # --- Pose Thread Connections ---
        if self.pose_thread:
            self.pose_thread.poses_ready.connect(self.overlay.update_poses)
            # Connect GUI controls to pose thread slots
            self.pose_enabled_changed.connect(self.pose_thread.set_enabled)
            self.pose_conf_changed.connect(self.pose_thread.update_confidence_threshold)
            # Initial state sync
            self.pose_thread.set_enabled(self.pose_enabled_check.isChecked())
            self.pose_thread.update_confidence_threshold(self.pose_conf_spinbox.value())

        # --- Flow Thread Connections ---
        if self.flow_thread:
            self.flow_thread.flow_ready.connect(self.overlay.update_flow)
            # Connect GUI controls to flow thread slots
            self.flow_enabled_changed.connect(self.flow_thread.set_enabled)
            # Initial state sync
            self.flow_thread.set_enabled(self.flow_enabled_check.isChecked())

        # --- Depth Thread Connections ---
        if self.depth_thread:
            self.depth_thread.depth_ready.connect(self.overlay.update_depth)
            # Connect GUI controls to depth thread slots
            self.depth_enabled_changed.connect(self.depth_thread.set_enabled)
            # Initial state sync
            self.depth_thread.set_enabled(self.depth_enabled_check.isChecked())


    def _connect_overlay_signals(self):
        """Connect this window's visibility/HUD signals to the overlay. Call once."""
        # --- Overlay Visibility Connections ---
        self.toggle_boxes_signal.connect(self.overlay.toggle_boxes)
        self.toggle_labels_signal.connect(self.overlay.toggle_labels)
        self.toggle_paths_signal.connect(self.overlay.toggle_paths)
        self.toggle_trajectory_signal.connect(self.overlay.toggle_trajectory)
        self.toggle_crosshair_signal.connect(self.overlay.toggle_crosshair)
        self.toggle_poses_signal.connect(self.overlay.toggle_poses) # Connect pose visibility
        self.toggle_flow_signal.connect(self.overlay.toggle_flow)
        self.toggle_depth_signal.connect(self.overlay.toggle_depth)
        self.toggle_hud_signal.connect(self.overlay.toggle_hud)

        # --- HUD Data Connection ---
        self.hud_data_signal.connect(self.overlay.update_hud)
        self.capture_fps_signal.connect(self.overlay.update_capture_fps)


    def _sync_overlay_state(self):
        """Push the current checkbox states to the overlay's drawing toggles."""
        self.overlay.toggle_boxes(self.show_boxes_check.isChecked())
        self.overlay.toggle_labels(self.show_labels_check.isChecked())
        self.overlay.toggle_paths(self.show_paths_check.isChecked())
        self.overlay.toggle_trajectory(self.show_trajectory_check.isChecked())
        self.overlay.toggle_crosshair(self.show_crosshair_check.isChecked())
        self.overlay.toggle_poses(self.pose_enabled_check.isChecked()) # Pose visibility follows the enable state
        self.overlay.toggle_flow(self.flow_enabled_check.isChecked()) # Flow visibility follows the enable state
        self.overlay.toggle_depth(self.depth_enabled_check.isChecked()) # Depth visibility follows the enable state
        self.overlay.toggle_hud(self.show_hud_check.isChecked())


    def stop_threads(self):
        """Stops all worker threads gracefully.

        Each existing thread is asked to ``stop()``, joined with ``wait()``
        (blocking the caller), and its attribute is reset to ``None``.
        """
        logger.info("Stopping worker threads...")
        for attr in self._THREAD_ATTRS:
            thread = getattr(self, attr)
            if thread is None:
                continue
            thread.stop()
            thread.wait() # Wait for thread to finish
            setattr(self, attr, None)
        logger.info("Worker threads stopped.")


    @Slot(str)
    def update_status(self, message):
        """Updates the status bar with ``message`` for 5 seconds."""
        self.status_bar.showMessage(message, 5000) # Show for 5 seconds


    def update_system_info(self):
        """Fetches CPU/RAM/GPU/time info and emits it on ``hud_data_signal``."""
        cpu_usage = psutil.cpu_percent()
        memory_info = psutil.virtual_memory()
        mem_usage = f"{memory_info.percent}% ({memory_info.used / (1024**3):.1f}/{memory_info.total / (1024**3):.1f} GB)"
        gpu_name, gpu_mem_usage = get_gpu_info()

        hud_data = {
            "CPU": f"{cpu_usage:.1f}%",
            "RAM": mem_usage,
            "GPU": gpu_name,
            "VRAM": gpu_mem_usage,
            "Time": datetime.datetime.now().strftime("%H:%M:%S")
        }
        self.hud_data_signal.emit(hud_data)


    # --- FPS Calculation Slot ---
    @Slot(np.ndarray, float)
    def calculate_capture_fps(self, frame, timestamp):
        """Calculates capture FPS from the arrival time of consecutive frames.

        ``frame`` and ``timestamp`` are ignored; the interval is measured with
        ``time.perf_counter()`` against ``self.overlay.last_capture_time``.
        """
        now = time.perf_counter()
        time_diff = now - self.overlay.last_capture_time # Use overlay's time tracker
        if time_diff > 0:
            fps = 1.0 / time_diff
            self.capture_fps_signal.emit(fps) # Emit the calculated FPS
        self.overlay.last_capture_time = now # Update the last time


    def closeEvent(self, event):
        """Ensures the timer, threads, overlay and log handler are cleaned up on exit."""
        logger.info("Close event triggered. Cleaning up...")
        self.system_info_timer.stop() # No further HUD updates during teardown
        self.stop_threads()
        if self.overlay:
            self.overlay.close() # Close the overlay window
        logger.info("Cleanup complete. Exiting.")
        # Detach the GUI log handler so later log records are not routed to
        # the QTextEdit owned by this closing window.
        logging.getLogger().removeHandler(self.log_handler)
        event.accept()


if __name__ == "__main__":
    # Make sure the YOLO weights are available before the GUI starts.
    # Instantiating YOLO with a missing file lets ultralytics fetch it. On any
    # failure we only warn, because the worker threads load the models again.
    try:
        for model_path in (OBJECT_DETECTION_MODEL, POSE_ESTIMATION_MODEL):
            if not os.path.exists(model_path):
                logger.info(f"Attempting to download {model_path}...")
                _ = YOLO(model_path) # Instantiating should trigger download
    except Exception as e:
        logger.warning(f"Could not pre-download/verify models: {e}. YOLO will attempt download on first use.")

    # No high-DPI attributes are set here. Qt 6 always enables high-DPI
    # scaling, and Qt.AA_EnableHighDpiScaling / Qt.AA_UseHighDpiPixmaps are
    # deprecated no-ops that only produce deprecation warnings.
    app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    sys.exit(app.exec())