[![Labellerr](https://storage.googleapis.com/labellerr-cdn/%200%20Labellerr%20template/notebook.webp)](https://www.labellerr.com)

# **Fine-Tune YOLO for Traffic Flow Counting**

---

[![labellerr](https://img.shields.io/badge/Labellerr-BLOG-black.svg)](https://www.labellerr.com/blog/<BLOG_NAME>)
[![Youtube](https://img.shields.io/badge/Labellerr-YouTube-b31b1b.svg)](https://www.youtube.com/@Labellerr)
[![Github](https://img.shields.io/badge/Labellerr-GitHub-green.svg)](https://github.com/Labellerr/Hands-On-Learning-in-Computer-Vision)
[![Scientific Paper](https://img.shields.io/badge/Official-Paper-blue.svg)](<PAPER LINK>)


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
from ultralytics import YOLO
from collections import defaultdict, deque
import os

## **Drawing the Regions Used to Track Objects**

In [None]:
# Finished regions, and the vertices of the region that is still being drawn.
polygons = []
current_polygon = []


def draw_polygon(event, x, y, flags, param):
    """OpenCV mouse callback that builds polygons from mouse clicks.

    A left click appends ``(x, y)`` to the polygon in progress. A right click
    ends the polygon in progress: it is stored in ``polygons`` when it has at
    least three vertices, and the in-progress list is reset either way.
    ``flags`` and ``param`` are accepted because OpenCV passes them, but they
    are not used.
    """
    global current_polygon

    if event == cv2.EVENT_LBUTTONDOWN:
        current_polygon.append((x, y))
        return

    if event != cv2.EVENT_RBUTTONDOWN:
        return

    # Fewer than three points cannot form a region, so they are dropped.
    if len(current_polygon) >= 3:
        polygons.append(list(current_polygon))
    current_polygon = []

def create_polygons_fullscreen(video_path):
    """Let the user draw polygon regions over a looping, full-screen video.

    Mouse input is handled by ``draw_polygon``; finished polygons are drawn in
    green and the one in progress in red. The loop ends when ``q`` is pressed
    or when the video can no longer be read.

    Args:
        video_path: Path of the video shown behind the drawing canvas.

    Returns:
        The module-level ``polygons`` list (including any polygons already in
        it), or an empty list if the video cannot be opened or read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Error: Could not open video.")
            return []

        # Probe one frame up front so an unreadable file fails before any
        # window is created.
        _, frame = cap.read()
        if frame is None:
            print("Error: Could not read video frame.")
            return []

        cv2.namedWindow('Create Polygons', cv2.WINDOW_NORMAL)
        cv2.setWindowProperty('Create Polygons', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
        cv2.setMouseCallback('Create Polygons', draw_polygon)

        while True:
            ret, frame = cap.read()
            if not ret:
                # End of stream: rewind to loop the video. If the rewind does
                # not yield a frame, stop instead of spinning forever without
                # ever servicing the window.
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                ret, frame = cap.read()
                if not ret:
                    print("Error: Could not rewind video.")
                    break

            # Draw existing polygons in green
            for poly in polygons:
                pts = np.array(poly, np.int32).reshape((-1, 1, 2))
                cv2.polylines(frame, [pts], True, (0, 255, 0), 2)

            # Draw current polygon being drawn in red (left open until closed)
            if len(current_polygon) > 1:
                pts = np.array(current_polygon, np.int32).reshape((-1, 1, 2))
                cv2.polylines(frame, [pts], False, (0, 0, 255), 2)

            cv2.putText(frame, 'Left click: Add point', (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4)
            cv2.putText(frame, 'Right click: Close polygon', (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4)
            cv2.putText(frame, 'Press Q to Quit', (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4)

            cv2.imshow('Create Polygons', frame)
            key = cv2.waitKey(20) & 0xFF
            if key == ord('q'):
                break

        return polygons
    finally:
        # Always free the capture and close the window, including on early
        # returns and on exceptions raised while drawing.
        cap.release()
        cv2.destroyAllWindows()


In [None]:
# Open the interactive editor on the sample clip; the returned list holds one
# list of (x, y) vertices per region that was closed with a right click.
polygons = create_polygons_fullscreen('assests/1.mp4')
print(polygons)


In [None]:
def show_polygons_on_video_matplotlib(video_path, polygons):
    """Preview polygon regions on top of the first frame of a video.

    Args:
        video_path: Path of the video whose first frame is used as background.
        polygons: Iterable of polygons, each a sequence of (x, y) vertices.

    Prints a message and returns without plotting when no frame can be read.
    """
    # Only the first frame is needed, so the capture is released right away.
    capture = cv2.VideoCapture(video_path)
    ok, first_frame = capture.read()
    capture.release()
    if not ok:
        print("Failed to read video frame.")
        return

    # OpenCV decodes to BGR; matplotlib expects RGB.
    background = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
    _, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(background)

    for region in polygons:
        # One random colour per region: half-transparent fill, opaque outline.
        rgb = [random.random() for _ in range(3)]
        ax.add_patch(
            patches.Polygon(
                region,
                closed=True,
                facecolor=rgb + [0.5],
                edgecolor=rgb,
                linewidth=2,
            )
        )

    ax.set_axis_off()
    plt.tight_layout()
    plt.show()


In [None]:
# Example usage: two hand-picked regions for 'assests/1.mp4', given in pixel
# coordinates of the video frame, previewed on the first frame.
polygons_1 = [
    [(687, 870), (834, 729), (1437, 732), (1455, 1374), (102, 1329)],
    [(1707, 720), (2118, 732), (2274, 723), (3204, 1191), (2085, 1230)]
]
show_polygons_on_video_matplotlib('assests/1.mp4', polygons_1)


## **Counting the Vehicles That Pass Through Each Region**

In [None]:
class VehicleCounter:
    """Count tracked objects that enter user-defined polygon regions.

    Every frame of a video is passed to an Ultralytics YOLO tracker. The centre
    of each accepted bounding box is tested against the polygons, and a
    region's counter is incremented each time a track's region changes to that
    region. An object that is already inside a region when it is first seen is
    not counted until it leaves and enters a region again.

    All colours are BGR tuples because they are handed directly to OpenCV
    drawing functions.
    """

    # (display name, BGR colour) pairs; cycled when there are more regions
    # than entries.
    _PALETTE = (
        ('red', (0, 0, 255)),
        ('green', (0, 255, 0)),
        ('blue', (255, 0, 0)),
        ('yellow', (0, 255, 255)),
        ('purple', (128, 0, 128)),
        ('orange', (0, 165, 255)),
        ('cyan', (255, 255, 0)),
        ('magenta', (255, 0, 255)),
        ('gray', (128, 128, 128)),
        ('pink', (203, 192, 255)),
        ('dark orange', (0, 140, 255)),
        ('teal', (128, 128, 0)),
        ('deep pink', (147, 20, 255)),
        ('hot pink', (180, 105, 255)),
        ('red orange', (0, 69, 255)),
    )

    def __init__(self, polygons, model_path=None, target_classes=None):
        """
        Initialize the vehicle counter with polygon regions

        Args:
            polygons: List of polygon regions as [(x1,y1), (x2,y2), ...]
            model_path: Path to custom YOLO model (optional)
            target_classes: Dict of class_id: class_name for custom models (optional)

        Raises:
            FileNotFoundError: If ``model_path`` is given but does not exist.
        """
        self.polygons = polygons
        self.region_counters = [0] * len(polygons)
        self.region_colors = self._generate_region_colors()
        self.tracked_objects = {}  # track_id: {last_region: int, history: deque}

        # Set up model and target classes
        self.model_path = model_path
        self._setup_model_and_classes(model_path, target_classes)

    def _setup_model_and_classes(self, model_path, target_classes):
        """Load the YOLO model and decide which class ids are counted.

        Without ``model_path`` the pre-trained ``yolov8x.pt`` weights are used
        and only car, bus and truck (ids 2, 5, 7) are kept. With a custom
        model, ``target_classes`` is used when given; otherwise the class
        names are read from the model, falling back to a single class
        ``{0: 'object'}`` when they cannot be determined.
        """
        if model_path is None:
            # Default: Use pre-trained YOLO model with COCO classes
            self.model = YOLO('yolov8x.pt')
            self.target_classes = {
                2: 'car',
                5: 'bus',
                7: 'truck'
            }
            print("Using default YOLOv8x model with COCO dataset classes")
            return

        # Custom model
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Custom model file not found: {model_path}")

        self.model = YOLO(model_path)

        if target_classes is not None:
            # Use provided target classes
            self.target_classes = target_classes
            print(f"Using custom model with specified classes: {list(target_classes.values())}")
        else:
            # Try to get class names from model; accept either an id->name
            # mapping or a plain sequence of names.
            model_names = getattr(self.model.model, 'names', None)
            if model_names is not None:
                if isinstance(model_names, dict):
                    self.target_classes = dict(model_names)
                else:
                    self.target_classes = dict(enumerate(model_names))
                print(f"Using all classes from custom model: {list(self.target_classes.values())}")
            else:
                # Fallback: assume single class or ask user to provide
                print("Warning: Could not determine class names from custom model.")
                print("Assuming single class 'object'. Consider providing target_classes parameter.")
                self.target_classes = {0: 'object'}

        print(f"Loaded custom model from: {model_path}")

    def _generate_region_colors(self):
        """Return one ``{'color': bgr, 'name': str}`` dict per polygon.

        Colours cycle through the palette. Regions beyond the palette length
        reuse a colour but are named ``region_<index>`` so names stay unique.
        """
        palette_size = len(self._PALETTE)
        region_colors = []
        for i in range(len(self.polygons)):
            name, bgr = self._PALETTE[i % palette_size]
            region_colors.append({
                'color': bgr,
                'name': name if i < palette_size else f'region_{i}'
            })

        return region_colors

    def _point_in_polygon(self, point, polygon):
        """Check if a point is inside a polygon using ray casting algorithm

        A horizontal ray is cast from the point; each polygon edge it crosses
        flips the inside/outside state. An empty polygon contains no point.
        """
        n = len(polygon)
        if n == 0:
            return False

        x, y = point
        inside = False

        p1x, p1y = polygon[0]
        for i in range(1, n + 1):
            p2x, p2y = polygon[i % n]
            # The strict/non-strict pair guarantees p1y != p2y inside the
            # branch, so the division below cannot be by zero.
            if min(p1y, p2y) < y <= max(p1y, p2y) and x <= max(p1x, p2x):
                if p1x == p2x or x <= (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x:
                    inside = not inside
            p1x, p1y = p2x, p2y

        return inside

    def _get_object_region(self, center_point):
        """Return the index of the first region containing the point, or -1."""
        for region_idx, polygon in enumerate(self.polygons):
            if self._point_in_polygon(center_point, polygon):
                return region_idx
        return -1  # Not in any region

    def _draw_polygons(self, frame):
        """Draw polygon regions on the frame with colors and counters

        Each region gets a translucent fill, an outline and a "name: count"
        label at its centroid; the same labels are repeated as a legend in the
        top-right corner. The frame is modified in place and returned.
        """
        overlay = frame.copy()

        for i, polygon in enumerate(self.polygons):
            region_color = self.region_colors[i]['color']

            # Convert polygon to the (N, 1, 2) int32 layout OpenCV expects
            pts = np.array(polygon, np.int32)
            pts = pts.reshape((-1, 1, 2))

            # The fill goes on the overlay copy so it can be blended in later
            cv2.fillPoly(overlay, [pts], region_color)

            # Draw polygon outline
            cv2.polylines(frame, [pts], True, region_color, 3)

            # Calculate centroid for text placement
            moments = cv2.moments(pts)
            if moments['m00'] != 0:
                cx = int(moments['m10'] / moments['m00'])
                cy = int(moments['m01'] / moments['m00'])
            else:
                cx, cy = polygon[0]  # fallback to first point

            # Draw region info at centroid
            region_text = f"{self.region_colors[i]['name']}: {self.region_counters[i]}"
            text_size = cv2.getTextSize(region_text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0]

            # Background rectangle for text at centroid
            cv2.rectangle(frame,
                        (cx - text_size[0]//2 - 5, cy - text_size[1] - 5),
                        (cx + text_size[0]//2 + 5, cy + 5),
                        (0, 0, 0), -1)

            # Text at centroid
            cv2.putText(frame, region_text,
                    (cx - text_size[0]//2, cy),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        # Blend overlay with original frame for transparency effect
        cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)

        # Draw region information on top right side
        frame_width = frame.shape[1]
        font_scale = 1.0
        font_thickness = 2
        padding = 20
        line_spacing = 100
        rect_padding = 15
        indicator_width = 24

        # Start position for top right display
        start_y = padding + 90

        for i in range(len(self.polygons)):
            region_text = f"{self.region_colors[i]['name']}: {self.region_counters[i]}"
            text_size = cv2.getTextSize(region_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)[0]
            text_width, text_height = text_size

            # Calculate position (right aligned with padding)
            text_x = frame_width - text_width - padding
            text_y = start_y + (i * line_spacing)

            # Draw background rectangle for better visibility
            cv2.rectangle(frame,
                        (text_x - rect_padding, text_y - text_height - rect_padding),
                        (text_x + text_width + rect_padding, text_y + rect_padding),
                        (0, 0, 0), -1)

            # Draw colored indicator bar next to text
            cv2.rectangle(frame,
                        (text_x - rect_padding - indicator_width - 15, text_y - text_height - rect_padding),
                        (text_x - rect_padding - 15, text_y + rect_padding),
                        self.region_colors[i]['color'], -1)

            # Draw the text
            cv2.putText(frame, region_text,
                    (text_x, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), font_thickness)

        return frame

    def _update_tracking(self, track_id, current_region):
        """Record a track's current region and count entries into a region.

        A region's counter is incremented when a known track's region changes
        to that region (from outside or from another region). The first
        observation of a track only initialises its state, so it never counts.
        """
        obj_data = self.tracked_objects.get(track_id)
        if obj_data is None:
            obj_data = {
                'last_region': current_region,
                'history': deque(maxlen=10)  # Keep last 10 region positions
            }
            self.tracked_objects[track_id] = obj_data

        obj_data['history'].append(current_region)

        # Count only transitions into a region; -1 means "outside all regions".
        if current_region != -1 and obj_data['last_region'] != current_region:
            self.region_counters[current_region] += 1
            print(f"Object {track_id} entered {self.region_colors[current_region]['name']} region. Count: {self.region_counters[current_region]}")

        obj_data['last_region'] = current_region

    def _annotate_detection(self, frame, box, label, center_point):
        """Draw one detection's box, filled label tag and centre dot on the frame."""
        x1, y1, x2, y2 = (int(v) for v in box)
        color = (0, 255, 0)  # Green for detections
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
        cv2.rectangle(frame, (x1, y1 - label_size[1] - 5),
                    (x1 + label_size[0], y1), color, -1)
        cv2.putText(frame, label, (x1, y1 - 5),
                  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        # Draw center point
        cv2.circle(frame, center_point, 3, (0, 0, 255), -1)

    def process_video(self, video_path, output_path, confidence_threshold=0.2):
        """
        Process video and save with region counting

        Errors raised while frames are being processed are printed and stop
        the loop; the frames written so far are still finalised.

        Args:
            video_path: Path to input video
            output_path: Path to save output video
            confidence_threshold: Minimum confidence threshold for detections

        Raises:
            ValueError: If the input video or the output writer cannot be opened.
        """
        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            raise ValueError(f"Error opening video file: {video_path}")

        # Get video properties. The frame rate is kept as a float so rates
        # such as 29.97 are not truncated; a missing or invalid rate falls
        # back to 30 so the writer still receives a usable value.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Setup video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            # Fail loudly instead of running inference and writing nothing.
            cap.release()
            out.release()
            raise ValueError(f"Error opening video writer for: {output_path}")

        print(f"Processing video: {video_path}")
        print(f"Output will be saved to: {output_path}")
        print(f"Video properties: {width}x{height}, {fps:.2f} FPS, {total_frames} frames")
        print(f"Using model: {self.model_path if self.model_path else 'YOLOv8x (default)'}")
        print(f"Target classes: {list(self.target_classes.values())}")

        # The default model is restricted to the target classes inside the
        # tracker; a custom model tracks every class and is filtered below.
        track_kwargs = {'persist': True}
        if self.model_path is None:
            track_kwargs['classes'] = list(self.target_classes.keys())

        frame_count = 0

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                # Run YOLO tracking on the frame
                results = self.model.track(frame, **track_kwargs)
                detections = results[0].boxes

                # `id` is None when the tracker assigned no track ids
                if detections is not None and detections.id is not None:
                    boxes = detections.xyxy.cpu().numpy()
                    track_ids = detections.id.cpu().numpy().astype(int)
                    classes = detections.cls.cpu().numpy().astype(int)
                    confidences = detections.conf.cpu().numpy()

                    for box, track_id, cls, conf in zip(boxes, track_ids, classes, confidences):
                        # Skip classes we do not count and weak detections
                        if cls not in self.target_classes or not conf > confidence_threshold:
                            continue

                        # Get center point of bounding box
                        x1, y1, x2, y2 = box
                        center_point = (int((x1 + x2) / 2), int((y1 + y2) / 2))

                        # Determine which region the object is in, then
                        # update tracking and counting
                        current_region = self._get_object_region(center_point)
                        self._update_tracking(track_id, current_region)

                        class_name = self.target_classes.get(cls, f'class_{cls}')
                        label = f"{class_name} ID:{track_id} {conf:.2f}"
                        self._annotate_detection(frame, box, label, center_point)

                # Draw polygon regions and counters
                frame = self._draw_polygons(frame)

                # Add frame info
                info_text = f"Frame: {frame_count}/{total_frames}"
                cv2.putText(frame, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

                # Write frame to output video
                out.write(frame)

                # Progress update
                if frame_count % 100 == 0:
                    print(f"Processed {frame_count}/{total_frames} frames")

        except Exception as e:
            print(f"Error during processing: {e}")

        finally:
            # Cleanup
            cap.release()
            out.release()

        print("Processing completed!")
        print("Final counts per region:")
        for i, count in enumerate(self.region_counters):
            print(f"  {self.region_colors[i]['name']} region: {count}")

def count_vehicles_in_regions(polygons, video_path, output_path="output_with_counting.mp4", 
                            model_path=None, target_classes=None, confidence_threshold=0.2):
    """
    Count objects entering each polygon region of a video and save an annotated copy.

    Builds a VehicleCounter for the given regions, runs it over the video and
    reports the per-region totals. Any exception raised along the way is
    printed and an empty dict is returned instead.

    Args:
        polygons: List of polygon regions as [[(x1,y1), (x2,y2), ...], ...]
        video_path: Path to input video file
        output_path: Path to save output video (default: "output_with_counting.mp4")
        model_path: Path to custom YOLO model (optional, uses YOLOv8x if None)
        target_classes: Dict of class_id: class_name for custom models (optional)
        confidence_threshold: Minimum confidence threshold for detections (default: 0.2)

    Returns:
        dict: Region name -> final count, or {} if an error occurred

    Examples:
        # Default YOLO model
        counts = count_vehicles_in_regions(polygons, "input_video.mp4")

        # Custom YOLO model, every class it knows
        counts = count_vehicles_in_regions(polygons, "input_video.mp4",
                                         model_path="my_custom_model.pt")

        # Custom YOLO model restricted to selected classes
        custom_classes = {0: 'vehicle', 1: 'person', 2: 'bicycle'}
        counts = count_vehicles_in_regions(polygons, "input_video.mp4",
                                         model_path="my_custom_model.pt",
                                         target_classes=custom_classes)
    """
    try:
        counter = VehicleCounter(polygons, model_path=model_path, target_classes=target_classes)
        counter.process_video(video_path, output_path, confidence_threshold=confidence_threshold)

        # Region colours and counters are parallel lists, one entry per polygon.
        return {
            region['name']: total
            for region, total in zip(counter.region_colors, counter.region_counters)
        }
    except Exception as e:
        print(f"Error in count_vehicles_in_regions: {e}")
        return {}


In [None]:
# Using default YOLO model (no model_path, so car/bus/truck are counted)
# polygon regions: reuse the two regions defined for 'assests/1.mp4'
polygons = polygons_1

# counts1 maps each region's colour name to its final count; the annotated
# video is written to 1_result.mp4.
counts1 = count_vehicles_in_regions(
    polygons=polygons,
    video_path="./assests/1.mp4",
    output_path="1_result.mp4")

## **Custom Model Training**

In [None]:
# Fetch the helper package that provides the COCO -> YOLO converter imported below.
!git clone https://github.com/Labellerr/yolo_finetune_utils.git

In [None]:
from yolo_finetune_utils.coco_yolo_converter.bbox_converter import coco_to_yolo_converter

# Convert the COCO-style annotation file and its images into a YOLO dataset
# under ./yolo_format (the training cell reads ./yolo_format/dataset.yaml).
# NOTE(review): use_split=False presumably skips a train/val split — confirm
# against the converter's documentation.
result = coco_to_yolo_converter(
            json_path=r'./annotation.json',
            images_dir=r'./dataset',
            output_dir='yolo_format',
            use_split=False
            )

In [None]:
# Fine-tune the yolov8x.pt checkpoint on the converted dataset (200 epochs,
# 640 px images, batch size 20); later cells load ./runs/detect/train/weights/last.pt.
!yolo task=detect mode=train data="./yolo_format/dataset.yaml" model="yolov8x.pt" epochs=200 imgsz=640 batch=20

In [None]:
# Sanity-check the fine-tuned weights by tracking objects in the drone clip
# and saving the annotated result (confidence >= 0.25, labels hidden).
!yolo task=detect mode=track model="./runs/detect/train/weights/last.pt" source="./assests/3.mp4" conf=0.25 save=True show_labels=False

## **Counting Cars using Drone View Camera**

In [None]:
# Define polygon regions
polygons2 = [
            [(93, 322), (483, 260), (486, 483), (111, 522)],
            [(112, 549), (482, 560), (507, 776), (123, 724)],
            [(1443, 346), (1479, 564), (1870, 584), (1857, 430)],
            [(1478, 598), (1414, 824), (1869, 784), (1868, 598)],
            ]

show_polygons_on_video_matplotlib('assests/3.mp4', polygons2)

In [None]:
# Using custom YOLO model
counts2 = count_vehicles_in_regions(
    polygons=polygons,
    video_path= "./assests/3.mp4",
    output_path= "3_result_custom_2.mp4",
    model_path= "./runs/detect/train/weights/last.pt",
    target_classes= {0: 'car'},  # Optional: specify which classes to track
    confidence_threshold=0.2
)