In [26]:
import cv2
import numpy as np
import supervision as sv
from ultralytics import YOLO
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Dict, List, Tuple, Set, Optional
import logging
import time
import sys
from pathlib import Path

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('traffic_monitor.log')
    ]
)
logger = logging.getLogger(__name__)

In [27]:
@dataclass
class IntersectionArm:
    """Represents one arm of the intersection: its entry/exit polygons and running counts.

    Attributes:
        name: Identifier of the arm (used as the key in reported metrics).
        entry_zone: Entry polygon as (x, y) pixel vertices. Normalised in
            ``__post_init__`` to an (N, 2) int32 array with N >= 3, so plain lists
            or float arrays are accepted.
        exit_zone: Exit polygon, same format and normalisation as ``entry_zone``.
        vehicle_count_in: Number of vehicles counted entering this arm.
        vehicle_count_out: Number of vehicles counted leaving this arm.
    """
    name: str
    entry_zone: np.ndarray
    exit_zone: np.ndarray
    vehicle_count_in: int = 0
    vehicle_count_out: int = 0

    def __post_init__(self) -> None:
        # cv2.polylines expects int32 point arrays; normalising here lets callers pass
        # lists or float coordinates and surfaces malformed polygons at construction
        # time instead of as a per-frame drawing/geometry error.
        self.entry_zone = self._as_polygon(self.entry_zone, "entry_zone")
        self.exit_zone = self._as_polygon(self.exit_zone, "exit_zone")

    @staticmethod
    def _as_polygon(points, field_name: str) -> np.ndarray:
        """Round ``points`` to an (N, 2) int32 array; raise ValueError if not a polygon."""
        polygon = np.rint(np.asarray(points, dtype=float)).astype(np.int32)
        if polygon.ndim != 2 or polygon.shape[1] != 2 or len(polygon) < 3:
            raise ValueError(
                f"{field_name} must be an (N, 2) array of points with N >= 3, "
                f"got shape {polygon.shape}"
            )
        return polygon

class TrafficMonitor:
    """Detect, track, count and annotate vehicles in a traffic video.

    Per frame (``process_frame``): YOLO detection -> confidence / class filter ->
    NMS -> ByteTrack tracking -> speed estimates, per-arm entry/exit counts and
    anomaly flags -> annotated frame plus a metrics dict.

    Speeds are pixel displacement of a track's box centre scaled by
    ``meters_per_pixel``. When that scale is not given (uncalibrated), the
    displayed "km/h" values use a scale of 1.0 and are NOT physical speeds, and
    the speeding check is skipped.
    """

    # COCO class ids used by the stock yolov8*.pt weights:
    # 1 bicycle, 2 car, 3 motorcycle, 5 bus, 7 truck.
    DEFAULT_VEHICLE_CLASS_IDS: Tuple[int, ...] = (1, 2, 3, 5, 7)

    def __init__(
        self,
        source_path: str,
        model_path: str = "yolov8n.pt",
        conf_threshold: float = 0.3,
        iou_threshold: float = 0.5,
        model_resolution: int = 1280,
        output_path: Optional[str] = None,
        meters_per_pixel: Optional[float] = None,
        vehicle_class_ids: Optional[Tuple[int, ...]] = DEFAULT_VEHICLE_CLASS_IDS,
        speed_limit_kmh: Optional[float] = 80.0,
        stopped_threshold_px: float = 1.0,
        density_threshold: int = 10,
    ):
        """Initialize the traffic monitoring system.

        Args:
            source_path: Input video file (must exist).
            model_path: YOLO weights file (must exist locally).
            conf_threshold: Minimum detection confidence that is kept.
            iou_threshold: IoU threshold for the additional NMS pass.
            model_resolution: Inference image size passed to YOLO as ``imgsz``.
            output_path: If given, annotated frames can be written here (mp4v).
            meters_per_pixel: Calibration scale for speed estimates. ``None``
                means uncalibrated: a scale of 1.0 is used for display and the
                speeding check is disabled.
            vehicle_class_ids: Class ids treated as vehicles; ``None`` disables
                class filtering entirely.
            speed_limit_kmh: Speed above which a track is flagged as speeding
                (only when calibrated); ``None`` disables the check.
            stopped_threshold_px: Frame-to-frame centre movement (pixels) below
                which a track is flagged as stopped.
            density_threshold: Vehicles inside an entry zone above which that
                arm is flagged as high density.

        Raises:
            FileNotFoundError: If the video or model file does not exist.
            OSError: If ``output_path`` is given but the writer cannot be opened.
        """
        # Defined before anything can fail so cleanup() is always safe to call.
        self.video_writer = None
        try:
            # Validate input paths
            if not Path(source_path).exists():
                raise FileNotFoundError(f"Source video not found: {source_path}")
            if not Path(model_path).exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")

            self.model = YOLO(model_path)
            self.conf_threshold = conf_threshold
            self.iou_threshold = iou_threshold
            self.model_resolution = model_resolution
            self.meters_per_pixel = meters_per_pixel
            self.vehicle_class_ids = (
                None if vehicle_class_ids is None else tuple(vehicle_class_ids)
            )
            self.speed_limit_kmh = speed_limit_kmh
            self.stopped_threshold_px = stopped_threshold_px
            self.density_threshold = density_threshold

            # Video setup
            self.video_info = sv.VideoInfo.from_video_path(source_path)
            self.frame_generator = sv.get_video_frames_generator(source_path)

            # Initialize ByteTrack
            self.tracker = sv.ByteTrack()

            # Per-track history of box centres: ~3 s worth, and at least 2 points
            # so a displacement (and therefore a speed) can be computed.
            history_len = max(2, int(self.video_info.fps * 3))
            self.vehicle_tracks = defaultdict(lambda: deque(maxlen=history_len))
            self.vehicle_speeds: Dict[int, float] = {}  # latest speed per track id
            self.detected_anomalies: Set[str] = set()  # every anomaly seen so far
            self.last_positions: Dict[int, Tuple[float, float]] = {}

            # Initialize video writer if output path is provided
            if output_path:
                writer = cv2.VideoWriter(
                    output_path,
                    cv2.VideoWriter_fourcc(*'mp4v'),
                    self.video_info.fps,
                    self.video_info.resolution_wh
                )
                if not writer.isOpened():
                    writer.release()
                    raise OSError(f"Could not open video writer: {output_path}")
                self.video_writer = writer

            # Setup annotators and zones
            self._setup_annotators()
            self.arms = self._setup_intersection_arms()
            # Track ids already counted per arm, so each vehicle is counted at
            # most once entering and once leaving each arm.
            self._counted_in: Dict[str, Set[int]] = {name: set() for name in self.arms}
            self._counted_out: Dict[str, Set[int]] = {name: set() for name in self.arms}

        except Exception as e:
            logger.error(f"Initialization error: {str(e)}")
            # Don't leak an already-opened writer when construction fails; the
            # caller never receives an instance it could clean up itself.
            self.cleanup()
            raise

    def _setup_annotators(self):
        """Create box, label and trace annotators sized to the video resolution."""
        try:
            thickness = sv.calculate_optimal_line_thickness(self.video_info.resolution_wh)
            text_scale = sv.calculate_optimal_text_scale(self.video_info.resolution_wh)

            self.box_annotator = sv.BoxAnnotator(thickness=thickness)
            self.label_annotator = sv.LabelAnnotator(
                text_scale=text_scale,
                text_thickness=thickness
            )
            self.trace_annotator = sv.TraceAnnotator(
                thickness=thickness,
                trace_length=self.video_info.fps * 2  # ~2 s of trail
            )

        except Exception as e:
            logger.error(f"Error setting up annotators: {str(e)}")
            raise

    def _setup_intersection_arms(self) -> Dict[str, IntersectionArm]:
        """Return the intersection arms keyed by name.

        Coordinates are pixel positions specific to the source video (they assume
        a 1920x1080-style frame — TODO confirm against ``video_info``).
        """
        try:
            return {
                "main_road": IntersectionArm(
                    name="main_road",
                    entry_zone=np.array([
                        [740, 570],   # Top-left
                        [1120, 570],  # Top-right
                        [1750, 1080], # Bottom-right
                        [0, 1080]     # Bottom-left
                    ], dtype=np.int32),
                    exit_zone=np.array([
                        [740, 470],   # Top-left
                        [1120, 470],  # Top-right
                        [1120, 570],  # Bottom-right
                        [740, 570]    # Bottom-left
                    ], dtype=np.int32)
                )
            }
        except Exception as e:
            logger.error(f"Error setting up intersection arms: {str(e)}")
            raise

    @staticmethod
    def _bbox_center(bbox) -> Tuple[float, float]:
        """Centre (x, y) of an (x1, y1, x2, y2) bounding box."""
        return ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)

    def _annotate_frame(self, frame: np.ndarray, detections: sv.Detections, speeds: Dict[int, float]) -> np.ndarray:
        """Return a copy of ``frame`` with zones, traces, boxes, labels and a count overlay.

        On any drawing error the error is logged and the unannotated frame is returned.
        """
        try:
            annotated_frame = frame.copy()

            # Entry zones in red, exit zones in green (BGR colours).
            for arm in self.arms.values():
                cv2.polylines(
                    annotated_frame, [np.asarray(arm.entry_zone, dtype=np.int32)], True, (0, 0, 255), 2
                )
                cv2.polylines(
                    annotated_frame, [np.asarray(arm.exit_zone, dtype=np.int32)], True, (0, 255, 0), 2
                )

            # Traces and id/speed labels need tracker ids; boxes do not. Guarding
            # here keeps boxes drawn even if the tracker returned no ids.
            tracker_ids = detections.tracker_id
            if tracker_ids is not None:
                annotated_frame = self.trace_annotator.annotate(annotated_frame, detections)
            annotated_frame = self.box_annotator.annotate(annotated_frame, detections)
            if tracker_ids is not None:
                labels = [
                    f"#{track_id} {int(speeds.get(track_id, 0))}km/h"
                    for track_id in tracker_ids
                ]
                annotated_frame = self.label_annotator.annotate(
                    annotated_frame, detections, labels=labels
                )

            cv2.putText(
                annotated_frame,
                f"Total Vehicles: {len(detections)}",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (255, 255, 255),
                2
            )

            return annotated_frame

        except Exception as e:
            logger.error(f"Error annotating frame: {str(e)}")
            return frame

    def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Run detection, tracking and metrics on one frame.

        Returns:
            ``(annotated_frame, metrics)`` where metrics has keys ``anomalies``
            (set of str), ``flux_counts`` ({arm name: (in, out)}),
            ``total_vehicles`` and ``average_speed``. On error the original frame
            and an empty dict are returned.
        """
        try:
            # Passing conf lets thresholds below YOLO's own default (0.25) take effect.
            results = self.model(
                frame,
                imgsz=self.model_resolution,
                conf=self.conf_threshold,
                verbose=False
            )[0]

            detections = sv.Detections.from_ultralytics(results)

            # Keep confident detections of vehicle classes only.
            if len(detections) > 0:
                keep = detections.confidence > self.conf_threshold
                if self.vehicle_class_ids is not None:
                    keep &= np.isin(detections.class_id, self.vehicle_class_ids)
                detections = detections[keep]

            detections = detections.with_nms(self.iou_threshold)
            detections = self.tracker.update_with_detections(detections)

            # Speeds must be computed first: anomaly detection reads the track
            # history and self.vehicle_speeds that _calculate_speeds updates.
            speeds = self._calculate_speeds(detections)
            self._update_flux_counts(detections)
            anomalies = self._detect_anomalies(detections)

            annotated_frame = self._annotate_frame(frame, detections, speeds)

            metrics = {
                "anomalies": anomalies,
                "flux_counts": {
                    arm.name: (arm.vehicle_count_in, arm.vehicle_count_out)
                    for arm in self.arms.values()
                },
                "total_vehicles": len(detections),
                "average_speed": float(np.mean(list(speeds.values()))) if speeds else 0
            }

            return annotated_frame, metrics

        except Exception as e:
            logger.error(f"Error processing frame: {str(e)}")
            return frame, {}

    def _calculate_speeds(self, detections: sv.Detections) -> Dict[int, float]:
        """Append each track's centre to its history and estimate its speed.

        Speed = centre displacement between the oldest and newest stored points,
        scaled by ``meters_per_pixel`` (1.0 if uncalibrated), over the elapsed
        time, converted m/s -> km/h. Assumes consecutive stored points are one
        frame apart (a track missing frames will read too fast).

        Also records each estimate in ``self.vehicle_speeds``.
        """
        speeds: Dict[int, float] = {}
        try:
            if detections.tracker_id is None:
                return speeds
            scale = 1.0 if self.meters_per_pixel is None else self.meters_per_pixel
            fps = self.video_info.fps
            for track_id, bbox in zip(detections.tracker_id, detections.xyxy):
                track = self.vehicle_tracks[track_id]
                track.append(np.array(self._bbox_center(bbox)))
                if len(track) < 2:
                    continue
                displacement_m = np.linalg.norm(track[-1] - track[0]) * scale
                # n stored positions span n - 1 frame intervals.
                elapsed_s = (len(track) - 1) / fps
                speed_kmh = float(displacement_m / elapsed_s * 3.6)
                speeds[track_id] = speed_kmh
                self.vehicle_speeds[track_id] = speed_kmh
        except Exception as e:
            logger.error(f"Error calculating speeds: {str(e)}")
        return speeds

    def _update_flux_counts(self, detections: sv.Detections):
        """Update per-arm entry/exit counts.

        A track is counted "in" the first frame its centre lies in an arm's entry
        zone and "out" the first frame it lies in the exit zone; each track is
        counted at most once per zone, however long it stays there.
        """
        try:
            if detections.tracker_id is None:
                return
            for track_id, bbox in zip(detections.tracker_id, detections.xyxy):
                center_point = self._bbox_center(bbox)

                for arm in self.arms.values():
                    counted_in = self._counted_in.setdefault(arm.name, set())
                    counted_out = self._counted_out.setdefault(arm.name, set())
                    if track_id not in counted_in and self._point_in_polygon(center_point, arm.entry_zone):
                        counted_in.add(track_id)
                        arm.vehicle_count_in += 1
                    if track_id not in counted_out and self._point_in_polygon(center_point, arm.exit_zone):
                        counted_out.add(track_id)
                        arm.vehicle_count_out += 1

                self.last_positions[track_id] = center_point
        except Exception as e:
            logger.error(f"Error updating flux counts: {str(e)}")

    def _point_in_polygon(self, point: Tuple[float, float], polygon: np.ndarray) -> bool:
        """Return True if ``point`` is inside ``polygon`` (even-odd ray casting).

        A horizontal ray to the right of the point is tested against every edge
        (``polygon[i-1] -> polygon[i]``, wrapping around); an odd number of
        crossings means inside. Returns False for an empty polygon or on error.
        """
        try:
            x, y = point
            inside = False
            for i in range(len(polygon)):
                x1, y1 = polygon[i - 1]
                x2, y2 = polygon[i]
                # The strict/inclusive pair excludes horizontal edges, so y1 != y2 below.
                if min(y1, y2) < y <= max(y1, y2) and x <= max(x1, x2):
                    x_cross = (y - y1) * (x2 - x1) / (y2 - y1) + x1
                    if x1 == x2 or x <= x_cross:
                        inside = not inside
            return inside
        except Exception as e:
            logger.error(f"Error in point_in_polygon: {str(e)}")
            return False

    def _detect_anomalies(self, detections: sv.Detections) -> Set[str]:
        """Flag stopped vehicles, high entry-zone density and (if calibrated) speeding.

        Returns the anomalies for this frame and adds them to ``self.detected_anomalies``.
        """
        anomalies: Set[str] = set()
        try:
            tracker_ids = [] if detections.tracker_id is None else detections.tracker_id
            speeding_enabled = (
                self.meters_per_pixel is not None and self.speed_limit_kmh is not None
            )

            for track_id in tracker_ids:
                # .get() so the defaultdict does not create empty tracks here.
                track = self.vehicle_tracks.get(track_id)
                if track is not None and len(track) >= 2:
                    movement = np.linalg.norm(track[-1] - track[-2])
                    if movement < self.stopped_threshold_px:
                        anomalies.add(f"Stopped vehicle: {track_id}")

                if speeding_enabled:
                    speed = self.vehicle_speeds.get(track_id)
                    if speed is not None and speed > self.speed_limit_kmh:
                        anomalies.add(f"Speeding vehicle: {track_id}")

            centers = [self._bbox_center(bbox) for bbox in detections.xyxy]
            for arm in self.arms.values():
                vehicle_count = sum(
                    1 for center in centers if self._point_in_polygon(center, arm.entry_zone)
                )
                if vehicle_count > self.density_threshold:
                    anomalies.add(f"High density in {arm.name}")

            self.detected_anomalies |= anomalies

        except Exception as e:
            logger.error(f"Error detecting anomalies: {str(e)}")

        return anomalies

    def cleanup(self):
        """Release the video writer, if any. Safe to call more than once."""
        writer = getattr(self, "video_writer", None)
        if writer is not None:
            writer.release()
            self.video_writer = None

In [28]:
def main(
    source_path: str = "/home/raw/Desktop/Coding/Jhakaas_Rasta/datasets/allahabad_3.MOV",
    output_path: Optional[str] = "/home/raw/Desktop/Coding/Jhakaas_Rasta/datasets/output_processed.mp4",
    show: bool = True,
    fps_log_interval: int = 30,
) -> None:
    """Run the traffic monitor over a video, writing and/or showing annotated frames.

    Args:
        source_path: Input video. The default is the author's absolute local path;
            pass your own (TODO: derive from a configurable DATA_DIR).
        output_path: Where to write the annotated video; ``None`` disables writing.
        show: Display frames in an OpenCV window (press ``q`` to stop). Set to
            ``False`` on headless machines where no window can be opened.
        fps_log_interval: Log processing throughput every this many frames
            (``<= 0`` disables the throughput log).
    """
    monitor = None
    try:
        monitor = TrafficMonitor(
            source_path=source_path,
            conf_threshold=0.3,
            iou_threshold=0.5,
            output_path=output_path
        )

        start_time = time.perf_counter()
        frame_count = 0
        previous_anomalies = set()

        for frame in monitor.frame_generator:
            frame_count += 1

            annotated_frame, metrics = monitor.process_frame(frame)

            if monitor.video_writer is not None:
                monitor.video_writer.write(annotated_frame)

            # Log only anomalies that were not already present on the previous
            # frame; otherwise a single stopped vehicle is re-logged every frame.
            anomalies = set(metrics.get('anomalies') or ())
            new_anomalies = anomalies - previous_anomalies
            if new_anomalies:
                logger.warning(f"Detected anomalies: {sorted(new_anomalies)}")
            previous_anomalies = anomalies

            if fps_log_interval > 0 and frame_count % fps_log_interval == 0:
                elapsed = time.perf_counter() - start_time
                if elapsed > 0:
                    logger.info(f"Processing at {frame_count / elapsed:.2f} FPS")

            if show:
                cv2.imshow('Traffic Monitor', annotated_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

    except KeyboardInterrupt:
        logger.info("Processing interrupted by user")
    except Exception as e:
        # logger.exception keeps the traceback that str(e) alone would discard.
        logger.exception(f"Error in main execution: {str(e)}")
    finally:
        if monitor is not None:
            monitor.cleanup()
        if show:
            cv2.destroyAllWindows()

In [29]:
# Entry point. Inside a notebook kernel __name__ is "__main__", so executing this
# cell starts processing immediately with main()'s default arguments.
if __name__ == "__main__":
    main()

INFO:__main__:Processing at 11.19 FPS
INFO:__main__:Processing at 11.93 FPS
INFO:__main__:Processing at 12.35 FPS
INFO:__main__:Processing at 12.67 FPS
INFO:__main__:Processing at 12.81 FPS
INFO:__main__:Processing at 12.77 FPS
INFO:__main__:Processing at 12.83 FPS
INFO:__main__:Processing at 12.84 FPS
INFO:__main__:Processing at 12.88 FPS
INFO:__main__:Processing at 12.87 FPS
INFO:__main__:Processing at 12.79 FPS
INFO:__main__:Processing at 12.76 FPS
INFO:__main__:Processing at 12.77 FPS
INFO:__main__:Processing at 12.79 FPS
INFO:__main__:Processing at 12.78 FPS
INFO:__main__:Processing at 12.81 FPS
INFO:__main__:Processing at 12.84 FPS
INFO:__main__:Processing at 12.94 FPS
INFO:__main__:Processing at 12.96 FPS
INFO:__main__:Processing at 12.94 FPS
INFO:__main__:Processing interrupted by user
