In [7]:
import os
import cv2
import numpy as np
from glob import glob

# Define class mapping for YOLO
# The class id of each shape is its position in this tuple.
shape_to_class_id = {
    shape_name: class_id
    for class_id, shape_name in enumerate(("triangle", "square", "circle"))
}

# Helper to classify color (not used in YOLO labels)
def classify_color(roi_rgb):
    """Name the dominant colour channel of a region of interest.

    The region is flattened to a list of 3-component pixels and averaged per
    channel; the components are read in the order red, green, blue.

    Returns "red", "green" or "blue" when that channel's mean is strictly
    greater than both other means, and "unknown" otherwise (e.g. on ties).
    """
    red, green, blue = roi_rgb.reshape(-1, 3).mean(axis=0)
    candidates = (
        ("red", red, (green, blue)),
        ("green", green, (red, blue)),
        ("blue", blue, (red, green)),
    )
    for name, value, rivals in candidates:
        # Strict comparison: a channel only wins if it beats both others.
        if all(value > rival for rival in rivals):
            return name
    return "unknown"

# Helper to detect shapes and write YOLO labels
def detect_and_annotate(image_path: str, save_label_path: str, *, min_area: float = 200) -> None:
    """Detect triangles, squares and circles in an image and write YOLO labels.

    The image is median-blurred, converted to grayscale, morphologically
    opened and binarised with an inverted Otsu threshold. Each external
    contour of the binary image is classified from its polygon approximation:
    3 vertices -> triangle, 4 vertices with a near-1 aspect ratio -> square,
    any other vertex count with circularity above 0.8 -> circle. Contours
    that match none of these are ignored.

    One line ``class_id x_center y_center width height`` is written per
    recognised shape, with coordinates normalised by the image size.

    Args:
        image_path: Path of the image to analyse.
        save_label_path: Path of the label file to write.
        min_area: Contours with an area (in pixels) below this value are
            skipped as noise.

    Returns:
        None. If the image cannot be read nothing is written. Otherwise the
        label file is always rewritten, and is left empty when no shape is
        found, so that labels from an earlier run cannot linger.
    """
    image = cv2.imread(image_path)
    if image is None:
        return
    h_img, w_img = image.shape[:2]

    # Grayscale conversion is done straight from BGR; going through RGB first
    # yields the same gray image because the blur works per channel.
    denoised = cv2.medianBlur(image, 7)
    gray = cv2.cvtColor(denoised, cv2.COLOR_BGR2GRAY)
    kernel = np.ones((3, 3), np.uint8)
    opened = cv2.morphologyEx(gray, cv2.MORPH_OPEN, kernel)
    _, binary_thresh = cv2.threshold(
        opened, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )
    # findContours returns (contours, hierarchy) or (image, contours,
    # hierarchy) depending on the OpenCV version; [-2] is the contour list
    # in both layouts.
    contours = cv2.findContours(
        binary_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]

    label_lines = []

    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area < min_area:
            continue
        peri = cv2.arcLength(cnt, True)
        if peri <= 0:
            # Guards the circularity division when min_area is set to 0.
            continue
        approx = cv2.approxPolyDP(cnt, 0.04 * peri, True)
        num_sides = len(approx)
        # Box the full contour rather than the simplified polygon: the
        # polygon's vertices can miss the extreme points of a curved outline,
        # which would make the box too small.
        x, y, w, h = cv2.boundingRect(cnt)

        shape = None
        if num_sides == 3:
            shape = "triangle"
        elif num_sides == 4:
            # Quadrilaterals that are clearly not square are left unlabelled.
            if 0.90 < w / h < 1.10:
                shape = "square"
        elif 4 * np.pi * area / (peri * peri) > 0.8:
            shape = "circle"

        if shape is None:
            continue

        class_id = shape_to_class_id[shape]
        x_center = (x + w / 2) / w_img
        y_center = (y + h / 2) / h_img
        w_norm = w / w_img
        h_norm = h / h_img
        label_lines.append(
            f"{class_id} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}"
        )

    # Always rewrite the file; an empty label file means "no objects".
    with open(save_label_path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in label_lines)

# Create YOLO format dataset folders
input_dir = "../data/train_dataset"
output_img_dir = "../data/yolov_dataset/images/train"
output_lbl_dir = "../data/yolov_dataset/labels/train"
os.makedirs(output_img_dir, exist_ok=True)
os.makedirs(output_lbl_dir, exist_ok=True)

# Process images and generate YOLO annotations.
# Sorted so that every run visits the files in the same order.
image_paths = sorted(glob(os.path.join(input_dir, "*.png")))
for img_path in image_paths:
    img_name = os.path.basename(img_path)
    # splitext swaps only the final extension; str.replace would also rewrite
    # any ".png" that appears earlier in the file name.
    label_name = os.path.splitext(img_name)[0] + ".txt"
    output_img_path = os.path.join(output_img_dir, img_name)
    output_lbl_path = os.path.join(output_lbl_dir, label_name)

    # Copy image (re-encoded through OpenCV). Files OpenCV cannot decode are
    # skipped, since imwrite raises when handed an empty image.
    source_image = cv2.imread(img_path)
    if source_image is None:
        print(f"Skipping unreadable image: {img_path}")
        continue
    cv2.imwrite(output_img_path, source_image)
    # Generate YOLO labels
    detect_and_annotate(img_path, output_lbl_path)
