In [6]:
import os
import cv2
from ultralytics import YOLO
from tqdm.notebook import tqdm
import math

# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================

# --- Model Paths ---
# Two Ultralytics weight files: a detector that finds objects in a frame, and a
# classifier that assigns a class label to each cropped detection.
MODEL_DETECT_PATH = 'data/model/detect_yolo_small_v3.pt'
MODEL_CLASSIFY_PATH = 'data/model/classify_yolo_small_v3.pt'

# --- Video Path ---
# *** CHANGE THIS to the path of your video file ***
VIDEO_PATH = 'F:/GREEN.mp4'

# --- Output Directory Paths ---
# Crops are written under a 'classes' folder that sits beside the video, e.g.
# VIDEO_PATH = 'F:/GREEN.mp4'  ->  'F:/classes/<class_name>/0001.jpg'.
VIDEO_DIR = os.path.split(VIDEO_PATH)[0]
OUTPUT_BASE_DIR = os.path.join(VIDEO_DIR, 'classes')

# --- Processing Parameters ---
# Minimum confidence a detection must reach to be kept (passed as `conf` to the
# detection model).
CONFIDENCE_THRESHOLD = 0.9
# Sampling stride: a frame is processed only when its 1-based index is a
# multiple of this value (1 = every frame, 100 = every 100th frame).
FRAME_SKIP = 100

# ==============================================================================
# 2. SETUP
# ==============================================================================

# Model loading failures are fatal: nothing below can run without both models.
# `raise SystemExit(1)` is used instead of the interactive `exit()` helper
# because `exit()` comes from the `site` module (absent under `python -S`),
# exits with status 0 (i.e. reports success) when called without an argument,
# and in a Jupyter kernel shuts the kernel down instead of just stopping the cell.

print("Setup: Loading YOLO models...")
# Load the YOLO Detection Model
try:
    model_detect = YOLO(MODEL_DETECT_PATH)
except Exception as e:
    print(f"Error loading detection model: {e}")
    print("Please check if MODEL_DETECT_PATH is correct.")
    raise SystemExit(1) from e

# Load the YOLO Classification Model; its `names` maps class index -> label
# and is used to name the per-class output folders.
try:
    model_classify = YOLO(MODEL_CLASSIFY_PATH, task='classify')
    CLASSIFY_NAMES = model_classify.names
    print(f"Classification Model Classes: {CLASSIFY_NAMES}")
except Exception as e:
    print(f"Error loading classification model: {e}")
    print("Please check if MODEL_CLASSIFY_PATH is correct.")
    raise SystemExit(1) from e

# Ensure the base output directory exists
os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)
print(f"Output directory created at: {OUTPUT_BASE_DIR}")
print("Setup Complete. Starting video processing...")

# Maps class_name -> next sequential file id for that class. Shared across all
# frames of a run; reset every time this cell is executed.
class_counters = {}

# ==============================================================================
# 3. CORE PROCESSING LOGIC
# ==============================================================================

def process_frame(img, frame_number, class_counters, min_crop_size=5):
    """
    Detect objects in one frame, classify each crop, and save it to disk.

    Each accepted crop is written to ``{OUTPUT_BASE_DIR}/{class_name}/NNNN.jpg``,
    where ``NNNN`` is the next sequential id for that class.

    Args:
        img: Image array for the frame (as produced by ``cv2.VideoCapture``),
            or ``None`` if the frame could not be read, in which case the
            frame is skipped.
        frame_number: 1-based index of the frame in the video; used only in
            the warning printed when a crop cannot be saved.
        class_counters: Mutable ``{class_name: next_id}`` dict shared across
            frames. Ids start at 1 and are advanced only after a crop has
            actually been written to disk.
        min_crop_size: Crops whose height or width is below this many pixels
            are skipped. Defaults to 5, the previously hard-coded limit.

    Returns:
        The same ``class_counters`` dict (it is also updated in place).
    """
    if img is None:
        return class_counters  # Frame failed to load; nothing to do.

    # 1. Detect objects, keeping only boxes at or above the confidence threshold.
    results_detect = model_detect(img, conf=CONFIDENCE_THRESHOLD, verbose=False)

    # A single image was passed in, so only results_detect[0] is relevant.
    if not results_detect or not results_detect[0].boxes:
        return class_counters

    img_h, img_w = img.shape[:2]

    for box in results_detect[0].boxes:
        # 2. Integer bounding box (x_min, y_min, x_max, y_max).
        x1, y1, x2, y2 = map(int, box.xyxy[0])

        # Clamp every coordinate into [0, width] / [0, height]. Both ends are
        # clamped because a negative x2/y2 would otherwise be interpreted by
        # NumPy slicing as "count from the end" and yield a bogus crop.
        x1, x2 = (min(max(v, 0), img_w) for v in (x1, x2))
        y1, y2 = (min(max(v, 0), img_h) for v in (y1, y2))

        crop_img = img[y1:y2, x1:x2]

        # Skip degenerate boxes (empty after clamping, or too small to classify).
        if (
            crop_img.size == 0
            or crop_img.shape[0] < min_crop_size
            or crop_img.shape[1] < min_crop_size
        ):
            continue

        # 3. Classify the crop and take the top-1 class.
        results_classify = model_classify(crop_img, verbose=False)
        class_name = CLASSIFY_NAMES[results_classify[0].probs.top1]

        # 4. Save to {OUTPUT_BASE_DIR}/{class_name}/{id:04d}.jpg (e.g. 0001.jpg).
        class_output_dir = os.path.join(OUTPUT_BASE_DIR, class_name)
        os.makedirs(class_output_dir, exist_ok=True)

        current_id = class_counters.get(class_name, 1)
        output_path = os.path.join(class_output_dir, f"{current_id:04d}.jpg")

        # cv2.imwrite signals failure through its return value instead of
        # raising; only advance the counter when the file was really written,
        # so the final summary counts stay accurate.
        if cv2.imwrite(output_path, crop_img):
            class_counters[class_name] = current_id + 1
        else:
            print(f"Warning: failed to save crop from frame {frame_number} to {output_path}")

    return class_counters

def process_video(video_path, class_counters):
    """
    Read a video, sample every FRAME_SKIP-th frame, and run process_frame on it.

    The capture is always released, even if processing raises or is
    interrupted (e.g. KeyboardInterrupt from stopping the notebook cell).
    A per-class summary is printed after the video has been fully processed.

    Args:
        video_path: Path of the video file to open with OpenCV.
        class_counters: ``{class_name: next_id}`` dict passed through to
            process_frame (updated in place) and used for the summary.

    Returns:
        None. Prints an error and returns early if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file at {video_path}")
            return
        _process_capture(cap, video_path, class_counters)
    finally:
        # No cv2.destroyAllWindows() here: this script never opens a window,
        # and on headless OpenCV builds that call raises cv2.error.
        cap.release()

    _print_summary(class_counters)


def _process_capture(cap, video_path, class_counters):
    """Print basic video info, then feed every FRAME_SKIP-th frame to process_frame."""
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    print(f"Video loaded: {os.path.basename(video_path)} | Frames: {frame_count} | FPS: {fps:.2f}")

    # CAP_PROP_FRAME_COUNT may be 0 or negative when the backend cannot report
    # it; give tqdm an open-ended total in that case.
    total = frame_count if frame_count > 0 else None

    frame_number = 0
    with tqdm(total=total, desc="Processing Video Frames") as pbar:
        # VideoCapture.read() is grab() followed by retrieve(). Grabbing every
        # frame keeps the position in sync, while the retrieve() work is only
        # done for frames that are actually processed.
        while cap.grab():
            frame_number += 1
            pbar.update(1)

            # Skip frames based on FRAME_SKIP configuration
            if frame_number % FRAME_SKIP != 0:
                continue

            ok, frame = cap.retrieve()
            if not ok:
                continue

            process_frame(frame, frame_number, class_counters)


def _print_summary(class_counters):
    """Print per-class crop totals; each counter holds next_id, i.e. saved + 1."""
    print("\n--- Summary ---")
    if not class_counters:
        print("Finished processing video. No objects were detected and saved.")
        return

    total_crops = sum(class_counters.values()) - len(class_counters)
    print(f"Finished processing video. Total {total_crops} crops saved.")
    for class_name, next_id in class_counters.items():
        print(f"- '{class_name}': {next_id - 1} crops saved.")

# ==============================================================================
# 4. EXECUTION
# ==============================================================================

# Run the pipeline only when the configured video is actually present;
# otherwise tell the user which setting needs fixing.
if os.path.exists(VIDEO_PATH):
    process_video(VIDEO_PATH, class_counters)
    print(f"\nFinal crops output folder: {OUTPUT_BASE_DIR}")
else:
    print(f"\nFATAL ERROR: The configured VIDEO_PATH does not exist: {VIDEO_PATH}")
    print("Please update the 'VIDEO_PATH' variable in section 1 of the script.")

Setup: Loading YOLO models...
Classification Model Classes: {0: 'Blue_Yellow', 1: 'Bran', 2: 'Brown_Orange_Overlay', 3: 'Brown_Orange_Small', 4: 'Green_Yellow', 5: 'Red_Yellow', 6: 'Wheatberry'}
Output directory created at: F:/classes
Setup Complete. Starting video processing...
Video loaded: GREEN.mp4 | Frames: 71988 | FPS: 19.99


Processing Video Frames:   0%|          | 0/71988 [00:00<?, ?it/s]

KeyboardInterrupt: 