[![Labellerr](https://storage.googleapis.com/labellerr-cdn/%200%20Labellerr%20template/notebook.webp)](https://www.labellerr.com)

# **Fine-Tuning YOLO for Fruit Counting**

---

[![labellerr](https://img.shields.io/badge/Labellerr-BLOG-black.svg)](https://www.labellerr.com/blog/<BLOG_NAME>)
[![Youtube](https://img.shields.io/badge/Labellerr-YouTube-b31b1b.svg)](https://www.youtube.com/@Labellerr)
[![Github](https://img.shields.io/badge/Labellerr-GitHub-green.svg)](https://github.com/Labellerr/Hands-On-Learning-in-Computer-Vision)
[![Scientific Paper](https://img.shields.io/badge/Official-Paper-blue.svg)](<PAPER LINK>)

### **Dataset Creation and Annotation**

In [None]:
# Fetch the helper utilities imported below (yolo_finetune_utils.frame_extractor and
# yolo_finetune_utils.coco_yolo_converter).
!git clone https://github.com/Labellerr/yolo_finetune_utils.git

In [None]:
from yolo_finetune_utils.frame_extractor import extract_random_frames

# Extract 20 random frames (total_images) from the source video into ./dataset so they
# can be annotated; the fixed seed presumably makes the sampling reproducible.
# Forward slashes work on Windows as well as Linux/macOS, whereas a backslash path
# only resolves on Windows.
extract_random_frames(paths=["assests/2.mp4"],
                      out_dir='dataset',
                      total_images=20,
                      seed=42)

### **Converting COCO JSON Annotations to YOLO Format**

In [None]:
from yolo_finetune_utils.coco_yolo_converter.seg_converter import coco_to_yolo_converter

# Convert the COCO-JSON annotation export for the images in ./dataset into a
# YOLO-format dataset under ./yolo_format (the training step reads
# ./yolo_format/data.yaml).
coco_to_yolo_converter(
    json_path="annotation.json",
    images_dir="dataset",
    output_dir="yolo_format",
    seed=15,
)

### **Model Training**

In [None]:
from ultralytics import YOLO
import cv2
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Fine-tune the yolo11m-seg.pt segmentation checkpoint on the converted dataset:
# 250 epochs, 640 px input size, batch size 30.
# NOTE(review): later cells load ./runs/segment/train/weights/best.pt — confirm this
# run actually wrote there (a repeated run may use a different run directory).
!yolo task=segment mode=train data="./yolo_format/data.yaml" model="yolo11m-seg.pt" epochs=250 imgsz=640 batch=30

### **Tracking with the Custom Model**

In [None]:
# Quick visual check of the fine-tuned weights: track objects in the source video
# with the BoT-SORT tracker at conf=0.2 and save the annotated output.
!yolo task=segment mode=track tracker=botsort.yaml model="./runs/segment/train/weights/best.pt" conf=0.2 source="./assests/2.mp4" save=True show_labels=True

### **Drawing the Counting Line**

In [None]:
# ← VIDEO PATH. Forward slashes are portable: a backslash path only resolves on Windows.
video_path = "assests/2.mp4"

In [None]:
# Preview one frame of the video to decide where the counting line should go.
preview_frame_index = 100  # arbitrary frame to inspect

cap = cv2.VideoCapture(video_path)
if cap.isOpened():
    cap.set(cv2.CAP_PROP_POS_FRAMES, preview_frame_index)
    ret, frame = cap.read()
else:
    print(f"❌ Error: Cannot open video file: {video_path}")
    ret, frame = False, None
cap.release()

# Only plot when a frame was actually read; cv2.cvtColor cannot convert None.
if ret:
    plt.figure(figsize=(10, 6))
    # OpenCV frames are BGR, matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()
else:
    print(f"❌ Error: Cannot read frame {preview_frame_index} from: {video_path}")

In [None]:
# Print the video resolution so the counting-line coordinates can be chosen to fit
# inside the frame.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    print(f"❌ Error: Cannot open video file: {video_path}")
else:
    width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Printed inside the branch: width/height are only defined when the video opened.
    print(f"Width: {width}, Height: {height}")
cap.release()  # free the capture handle

In [None]:
# Counting line as ((x_start, y_start), (x_end, y_end)) in pixels: a vertical
# segment at x = 1500 running from y = 0 down to y = 1080.
# NOTE(review): assumes a 1080-px-tall frame — check against the Width/Height printed above.
line = ((1500, 0), (1500, 1080))

In [None]:
# Draw the chosen counting line on a sample frame to check its placement.
start_point, end_point = line
preview_frame_index = 100  # frame used for the overlay preview

cap = cv2.VideoCapture(video_path)
if cap.isOpened():
    cap.set(cv2.CAP_PROP_POS_FRAMES, preview_frame_index)
    ret, frame = cap.read()
else:
    print(f"❌ Error: Cannot open video file: {video_path}")
    ret, frame = False, None
cap.release()

if ret:
    line_color = (255, 120, 255)  # BGR
    cv2.line(frame, start_point, end_point, line_color, 10)

    # Anchor the label 10 px right of the line's starting x (50 px from the top), so it
    # follows the line when its coordinates change instead of staying at a fixed pixel.
    TEXT = "COUNTING LINE"
    text_origin = (start_point[0] + 10, 50)
    cv2.putText(frame, TEXT, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 1.5, line_color, 10)

    plt.figure(figsize=(10, 6))
    plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # BGR -> RGB for matplotlib
    plt.axis('off')
    plt.show()
else:
    print(f"❌ Error: Cannot read frame {preview_frame_index} from: {video_path}")

### **Fruit-Counting Logic**

In [None]:
from datetime import datetime

# =============================================================================
# GLOBAL VARIABLES
# =============================================================================
product_counter = 0  # running total of objects that crossed the counting line
# Track IDs already counted. Defined here so check_line_crossing (which declares it
# `global`) does not raise NameError if called before process_video() resets it.
counted_objects = set()
model = None  # set by load_yolo_model()
perform_segmentation = False  # Set to True to enable segmentation visualization
counting_line = line  # ((x1, y1), (x2, y2)) endpoints of the counting line
video_path = "assests/2.mp4"
output_video_path = "output5.mp4"
model_path = "./runs/segment/train/weights/best.pt"  # Trained segmentation model
# Confidence threshold for YOLO model.
# NOTE(review): much stricter than the conf=0.2 used in the CLI tracking cell;
# lower it if fruits are being missed.
model_confidence = 0.9

# =============================================================================
# FUNCTIONS
# =============================================================================
def load_yolo_model(model_path):
    """Construct ``YOLO(model_path)`` and store it in the module-level ``model``.

    Args:
        model_path: Weights file handed to the ``YOLO`` constructor.

    Returns:
        True when construction succeeded; False when it raised (the error is
        printed instead of propagated, and ``model`` is left untouched).
    """
    global model
    try:
        print(f"Loading YOLO segmentation model: {model_path}")
        loaded = YOLO(model_path)
        model = loaded
        print("YOLO model loaded successfully")
    except Exception as exc:
        print(f"Error loading YOLO model: {exc}")
        return False
    return True


def line_intersection(p1, p2, p3, p4):
    """Return True when segment p1-p2 properly crosses segment p3-p4.

    Uses the counter-clockwise orientation test: the segments intersect when the
    endpoints of each one lie on opposite sides of the other. Collinear
    configurations compare with a strict ``>`` and are therefore not treated as
    crossings.
    """
    def _is_ccw(a, b, c):
        # True when a -> b -> c turns counter-clockwise (strictly).
        return (c[1] - a[1]) * (b[0] - a[0]) > (b[1] - a[1]) * (c[0] - a[0])

    p1_p2_straddle = _is_ccw(p1, p3, p4) != _is_ccw(p2, p3, p4)
    if not p1_p2_straddle:
        return False
    return _is_ccw(p1, p2, p3) != _is_ccw(p1, p2, p4)


def check_line_crossing(prev_pos, curr_pos, obj_id):
    """Count ``obj_id`` once if its movement prev_pos -> curr_pos crosses ``counting_line``.

    Updates the globals ``counted_objects`` and ``product_counter`` and prints a
    message on the first crossing of each ID; later crossings by the same ID are
    ignored.
    """
    global product_counter, counted_objects
    crossed = line_intersection(prev_pos, curr_pos, counting_line[0], counting_line[1])
    if not crossed or obj_id in counted_objects:
        return
    counted_objects.add(obj_id)
    product_counter += 1
    print(f"🎯 Object {obj_id} crossed the line! Total count: {product_counter}")


def _draw_segmentation_masks(frame, result):
    """Blend each tracked detection's segmentation polygon onto ``frame`` in place.

    Colour reflects the global ``counted_objects``: yellow for IDs already counted,
    magenta otherwise (BGR). Each polygon is filled at 30% opacity, then outlined.
    """
    if result.masks is None:
        return
    boxes = result.boxes
    for i, seg in enumerate(result.masks.xy):  # list of polygons, one per detection
        if i >= len(boxes):
            break
        if len(seg) == 0:
            # NOTE(review): tiny masks may come back as an empty polygon; skip it
            # rather than hand OpenCV a zero-point contour.
            continue
        obj_id = int(boxes[i].id[0].cpu().numpy())
        color = (0, 255, 255) if obj_id in counted_objects else (255, 0, 255)
        pts = np.array(seg, dtype=np.int32)

        # Fill on a copy and blend it back to get a translucent mask.
        overlay = frame.copy()
        cv2.fillPoly(overlay, [pts], color)
        cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)

        cv2.polylines(frame, [pts], True, color, 2)


def _annotate_tracks(frame, boxes, track_history):
    """Run line-crossing checks for every tracked box and draw its box, ID and centre.

    ``track_history`` maps track ID -> box centre from the last frame that ID was
    seen in; it is updated in place.
    """
    for box in boxes:
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
        obj_id = int(box.id[0].cpu().numpy())
        center = (int((x1 + x2) / 2), int((y1 + y2) / 2))

        # A crossing needs a previous position, so a track's first frame never counts.
        prev_center = track_history.get(obj_id)
        if prev_center is not None:
            check_line_crossing(prev_center, center, obj_id)
        track_history[obj_id] = center

        box_color = (0, 255, 255) if obj_id in counted_objects else (255, 0, 255)
        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), box_color, 2)
        cv2.putText(frame, f"ID:{obj_id}", (int(x1), int(y1) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
        cv2.circle(frame, center, 5, (0, 0, 255), -1)


def _draw_hud(frame):
    """Draw the counting line, its label and the running total onto ``frame``."""
    start, end = counting_line[0], counting_line[1]
    cv2.line(frame, start, end, (0, 255, 0), 10)
    # Label sits 20 px right of the line's midpoint.
    label_origin = ((start[0] + end[0]) // 2 + 20, (start[1] + end[1]) // 2)
    cv2.putText(frame, "COUNTING LINE", label_origin,
                cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 10)

    # Widen the black backing box when the text would run past its default right
    # edge (x = 200), e.g. once the count reaches three digits.
    count_text = f"COUNT: {product_counter}"
    (text_width, _), _ = cv2.getTextSize(count_text, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 2)
    cv2.rectangle(frame, (10, 10), (max(200, 20 + text_width + 10), 60), (0, 0, 0), -1)
    cv2.putText(frame, count_text, (20, 45),
                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)


def process_video():
    """Count objects crossing ``counting_line`` in ``video_path`` and save an annotated copy.

    Reads the module-level settings ``video_path``, ``output_video_path``,
    ``model_path``, ``model_confidence``, ``perform_segmentation`` and
    ``counting_line``. Resets the globals ``product_counter`` and
    ``counted_objects``, tracks every frame with ``model.track`` (BoT-SORT),
    counts each track ID at most once when its box centre crosses the line, and
    writes the annotated frames to ``output_video_path``.

    Returns:
        True when the whole video was processed; False if the model could not be
        loaded or the input/output video could not be opened.
    """
    global product_counter, counted_objects

    print("Starting Product Counter")
    print(f"Input: {video_path}")
    print(f"Output: {output_video_path}")
    print(f"Model: {model_path}")
    print("="*50)

    if not load_yolo_model(model_path):
        return False

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"❌ Cannot open video: {video_path}")
        return False

    # Keep FPS as a float: truncating e.g. 29.97 to 29 makes the output play slower
    # than the source.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        # NOTE(review): the container reported no frame rate; 30 fps is an assumed
        # fallback so the writer can still be opened.
        fps = 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    print(f"Video: {width}x{height} @ {fps:.2f}fps, {total_frames} frames")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    if not out.isOpened():
        # Otherwise every out.write() would silently do nothing.
        print(f"❌ Cannot open output video: {output_video_path}")
        cap.release()
        return False

    # Reset counter and tracking state for this run.
    product_counter = 0
    counted_objects = set()  # track IDs that have already been counted
    track_history = {}  # track ID -> box centre from the previous frame it was seen

    frame_count = 0
    progress_step = max(1, total_frames // 10)  # report roughly every 10%
    start_time = datetime.now()

    print("Processing...")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            results = model.track(frame, conf=model_confidence, tracker="botsort.yaml",
                                  persist=True, verbose=False)

            # Skip per-object annotation when the tracker returned no track IDs.
            if results and results[0].boxes.id is not None:
                if perform_segmentation:
                    _draw_segmentation_masks(frame, results[0])
                _annotate_tracks(frame, results[0].boxes, track_history)

            _draw_hud(frame)
            out.write(frame)

            if total_frames > 0 and frame_count % progress_step == 0:
                progress = (frame_count / total_frames) * 100
                print(f"📈 {progress:.0f}% - Frame {frame_count}/{total_frames} - Count: {product_counter}")
    finally:
        # Release both handles even if tracking raises part-way through, so the
        # capture is freed and the output file is closed.
        cap.release()
        out.release()

    processing_time = datetime.now() - start_time

    print("="*50)
    print("Processing completed!")
    print(f"Total count: {product_counter}")
    print(f"Processing time: {processing_time}")
    print(f"Output saved: {output_video_path}")
    return True

In [None]:
# Run the counting pipeline on video_path; the annotated video is written to output_video_path.
process_video()