In [1]:
import os
import shutil
import xml.etree.ElementTree as ET
from collections import Counter
from PIL import Image


In [2]:
def split_files_by_extension(src_folder, ext_to_targetfolder):
    """
    Copy files from src_folder into target folders chosen by file extension.

    Source files are left in place (shutil.copy), so the function can be re-run;
    a file that already exists in a target folder is overwritten.

    Args:
        src_folder: directory whose top-level entries are scanned (not recursive).
        ext_to_targetfolder: dict mapping an extension (with leading dot) to a
            target folder, e.g. {'.jpg': 'images', '.txt': 'labels'}.
            Extensions are matched case-insensitively.

    Entries whose extension is not in the mapping, and sub-directories, are skipped.
    """
    # Normalise keys so a '.JPG' key matches 'photo.jpg' and vice versa.
    targets = {ext.lower(): folder for ext, folder in ext_to_targetfolder.items()}
    for target_folder in set(targets.values()):
        os.makedirs(target_folder, exist_ok=True)

    for fname in os.listdir(src_folder):
        src_path = os.path.join(src_folder, fname)
        # A directory named like 'x.jpg' would make shutil.copy raise; skip it.
        if not os.path.isfile(src_path):
            continue
        ext = os.path.splitext(fname)[1].lower()
        target_folder = targets.get(ext)
        if target_folder is not None:
            shutil.copy(src_path, os.path.join(target_folder, fname))


In [3]:
def rename_images_and_labels(images_folder, labels_folder, prefix='img_', width=3):
    """Rename image/label pairs to ``{prefix}{index}`` with a zero-padded index.

    Images (.jpg/.jpeg/.png, any case) are processed in sorted filename order.
    An image is renamed only if ``labels_folder`` holds a label with the same
    stem and a ``.txt`` extension; that label is renamed alongside it. Images
    without a label are reported and left untouched.

    Renaming happens in two phases (old -> unique temporary -> final name).
    ``os.rename`` silently replaces an existing destination on POSIX, so a
    single pass could overwrite a file that is still waiting to be renamed
    (e.g. when the folder already contains ``img_001.*`` files).

    Args:
        images_folder: folder containing the images.
        labels_folder: folder containing the ``.txt`` labels (may equal
            ``images_folder``).
        prefix: prefix of the new base names.
        width: minimum number of digits of the index (``zfill`` width).

    Raises:
        FileExistsError: if a final name is taken by a file that is not part of
            this renaming batch; in that case nothing is renamed.
    """
    image_files = sorted(
        f for f in os.listdir(images_folder)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        and os.path.isfile(os.path.join(images_folder, f))
    )

    # 1) Plan every rename as (folder, old_name, new_name) before touching disk.
    plan = []
    claimed_labels = set()
    idx = 1  # Counter for new names
    for img_file in image_files:
        stem, img_ext = os.path.splitext(img_file)
        label_name = stem + '.txt'
        # Images sharing a stem (x.jpg, x.png) share one label; as in a
        # sequential rename, only the first in sorted order gets it.
        if label_name in claimed_labels or not os.path.exists(os.path.join(labels_folder, label_name)):
            print(f"Skipped: {img_file} (no matching label)")
            continue
        claimed_labels.add(label_name)
        new_base = f"{prefix}{str(idx).zfill(width)}"
        plan.append((images_folder, img_file, f"{new_base}{img_ext}"))
        plan.append((labels_folder, label_name, f"{new_base}.txt"))
        idx += 1

    # 2) Refuse to overwrite a file that is not itself being renamed away.
    sources = {os.path.abspath(os.path.join(folder, old)) for folder, old, _ in plan}
    for folder, _, new in plan:
        dest = os.path.abspath(os.path.join(folder, new))
        if dest not in sources and os.path.exists(dest):
            raise FileExistsError(f"Refusing to overwrite existing file: {dest}")

    # 3) Move everything to unique temporary names first, so no final
    #    destination is still occupied when the second pass renames into it.
    temp_names = [f".renaming_{n}_{new}" for n, (_, _, new) in enumerate(plan)]
    for (folder, old, _), tmp in zip(plan, temp_names):
        os.rename(os.path.join(folder, old), os.path.join(folder, tmp))
    for (folder, _, new), tmp in zip(plan, temp_names):
        os.rename(os.path.join(folder, tmp), os.path.join(folder, new))


In [4]:
# Main dataset: images and their .txt label files live in parallel folders.
images_folder, labels_folder = 'dataset/input/images', 'dataset/input/labels'

In [5]:
def count_files_by_extension(directory):
    """Count the regular files in ``directory`` by lower-cased extension.

    Only top-level entries are inspected; sub-directories and files without an
    extension are ignored. Each extension and its count is printed, most common
    first.

    Args:
        directory: folder to inspect.

    Returns:
        collections.Counter mapping an extension such as '.jpg' to its file count.
    """
    ext_counter = Counter()
    with os.scandir(directory) as entries:
        for entry in entries:
            # Sub-directories (even ones named like 'x.jpg') are not files.
            if not entry.is_file():
                continue
            ext = os.path.splitext(entry.name)[1].lower()
            if ext:
                ext_counter[ext] += 1
    for extension, count in ext_counter.most_common():
        print(f"{extension}: {count}")
    return ext_counter

In [6]:
# Class index for each fruit name: the index is the position in this list.
class_mapping = {
    name: idx
    for idx, name in enumerate([
        'apple', 'avocado', 'banana', 'kiwi', 'lemon',
        'orange', 'pear', 'pomegranate', 'strawberry', 'watermelon',
    ])
}


In [7]:
def voc_to_yolo(xml_file, class_map, output_txt):
    """Convert one Pascal VOC XML annotation into a YOLO-format label file.

    Each ``<object>`` whose (whitespace-stripped) ``<name>`` is a key of
    ``class_map`` becomes one line ``"<class_idx> <x_center> <y_center> <w> <h>"``.
    The four box values are normalised by the image size from ``<size>`` and
    written with six decimals. Objects with unknown names are skipped.

    Box corners are clamped to the image (``[0, width] x [0, height]``) before
    normalising, so every written value lies in [0, 1]. Boxes with no area left
    after clamping are dropped. ``output_txt`` is always (over)written, possibly
    empty.

    Args:
        xml_file: path to the VOC ``.xml`` file.
        class_map: dict mapping class name -> integer class index.
        output_txt: path of the ``.txt`` file to write.

    Raises:
        ValueError: if ``<size>`` width/height is missing, non-numeric or not positive.
    """
    root = ET.parse(xml_file).getroot()
    size = root.find('size')
    try:
        # int(float(...)) also accepts sizes written as e.g. "500.0".
        w = int(float(size.find('width').text))
        h = int(float(size.find('height').text))
    except (AttributeError, TypeError, ValueError):
        raise ValueError(f"{xml_file}: missing or non-numeric <size> width/height") from None
    if w <= 0 or h <= 0:
        raise ValueError(f"{xml_file}: image size must be positive, got {w}x{h}")

    def clamp(value, upper):
        # Keep a pixel coordinate inside [0, upper].
        return min(max(value, 0.0), upper)

    yolo_lines = []
    for obj in root.findall('object'):
        # Names padded with whitespace/newlines would otherwise miss the map.
        name = (obj.findtext('name') or '').strip()
        class_idx = class_map.get(name)
        if class_idx is None:
            continue  # Skip if not in map
        bndbox = obj.find('bndbox')
        xmin = clamp(float(bndbox.find('xmin').text), w)
        ymin = clamp(float(bndbox.find('ymin').text), h)
        xmax = clamp(float(bndbox.find('xmax').text), w)
        ymax = clamp(float(bndbox.find('ymax').text), h)
        if xmax <= xmin or ymax <= ymin:
            continue  # box lies outside the image or has zero area

        x_center = ((xmin + xmax) / 2) / w
        y_center = ((ymin + ymax) / 2) / h
        bw = (xmax - xmin) / w
        bh = (ymax - ymin) / h

        yolo_lines.append(f"{class_idx} {x_center:.6f} {y_center:.6f} {bw:.6f} {bh:.6f}\n")

    with open(output_txt, 'w') as f:
        f.writelines(yolo_lines)

def batch_convert_voc_to_yolo(input_folder, output_folder, class_map):
    """Convert every ``.xml`` annotation in ``input_folder`` to a YOLO ``.txt``.

    Each ``<stem>.xml`` is converted with ``voc_to_yolo`` into
    ``output_folder/<stem>.txt``. ``output_folder`` is created if missing and
    may be the same folder as ``input_folder``.
    """
    os.makedirs(output_folder, exist_ok=True)
    xml_names = (n for n in os.listdir(input_folder) if n.lower().endswith('.xml'))
    for xml_name in xml_names:
        stem = os.path.splitext(xml_name)[0]
        voc_to_yolo(
            os.path.join(input_folder, xml_name),
            class_map,
            os.path.join(output_folder, stem + '.txt'),
        )

In [9]:
def remove_verified_xml(folder):
    """Delete each ``.xml`` in ``folder`` that has a same-stem ``.txt`` next to it.

    The ``.txt`` is taken as evidence that the XML has already been converted.
    XMLs without one are kept and reported as skipped.
    """
    for fname in os.listdir(folder):
        if not fname.lower().endswith('.xml'):
            continue
        txt_path = os.path.join(folder, os.path.splitext(fname)[0] + '.txt')
        if not os.path.exists(txt_path):
            print(f"Skipped (no TXT): {fname}")
            continue
        xml_path = os.path.join(folder, fname)
        os.remove(xml_path)
        print(f"Deleted: {xml_path}")

In [10]:
def fix_voc_xml_sizes(xml_folder, image_folder, image_exts=(".jpg", ".jpeg", ".png", ".bmp")):
    """Rewrite the <size> width/height of VOC XML files to match the real images.

    For every ``<stem>.xml`` in ``xml_folder``, the image ``<stem><ext>`` is looked
    up in ``image_folder``. Each extension in ``image_exts`` is tried in order,
    both as given and upper-cased. The image is opened with PIL to read its true
    size, and the XML is rewritten only when the recorded width/height differ.

    A problem with one file (image not found, unreadable image, malformed XML)
    is printed and that file is skipped, so one bad file does not stop the batch.
    A missing ``<size>``, ``<width>`` or ``<height>`` element, or a non-integer
    value, counts as a mismatch; the element is created or overwritten.

    Args:
        xml_folder: folder containing VOC ``.xml`` annotations.
        image_folder: folder containing the images (may equal ``xml_folder``).
        image_exts: candidate image extensions, in order of preference.
    """
    for fname in os.listdir(xml_folder):
        if not fname.lower().endswith(".xml"):
            continue
        base = os.path.splitext(fname)[0]

        # Find the image file, e.g. ".jpg" then ".JPG", so case-sensitive
        # file systems also match upper-case extensions.
        candidates = (
            os.path.join(image_folder, base + variant)
            for ext in image_exts
            for variant in (ext, ext.upper())
        )
        img_path = next((p for p in candidates if os.path.exists(p)), None)
        if img_path is None:
            print(f"Image for {fname} not found.")
            continue

        # Read the true image size.
        try:
            with Image.open(img_path) as im:
                true_w, true_h = im.size
        except Exception as e:  # broad on purpose: report any unreadable image and go on
            print(f"Error opening {img_path}: {e}")
            continue

        xml_path = os.path.join(xml_folder, fname)
        try:
            tree = ET.parse(xml_path)
        except ET.ParseError as e:
            print(f"Error parsing {xml_path}: {e}")
            continue
        root = tree.getroot()

        # Create <size>/<width>/<height> if absent so the fix can be written.
        size = root.find('size')
        if size is None:
            size = ET.SubElement(root, 'size')
        dim_elems = []
        for tag in ('width', 'height'):
            elem = size.find(tag)
            if elem is None:
                elem = ET.SubElement(size, tag)
            dim_elems.append(elem)
        width_el, height_el = dim_elems

        try:
            recorded = (int(width_el.text), int(height_el.text))
        except (TypeError, ValueError):
            recorded = None  # empty or non-integer value: always rewrite

        if recorded == (true_w, true_h):
            print(f"No change for {fname}")
            continue
        width_el.text = str(true_w)
        height_el.text = str(true_h)
        tree.write(xml_path)
        print(f"Fixed {fname}: set width={true_w}, height={true_h}")


#### Adding more images and labels to the dataset
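Source: [Fruit Images for Object Detection](https://www.kaggle.com/datasets/mbkinaci/fruit-images-for-object-detection) (Kaggle). It contains apple, banana and orange images with Pascal VOC XML annotations. The cells below correct the image sizes stored in the XMLs, convert the annotations to YOLO `.txt` labels, and copy the image/label pairs into the main dataset.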


In [None]:
# Make the <size> recorded in each VOC XML match the real image dimensions
# (images and XMLs share one folder in this dataset).
fix_voc_xml_sizes(
    xml_folder="dataset/fruit_images (apple, banana, orange)",
    image_folder="dataset/fruit_images (apple, banana, orange)",
)

Fixed apple_1.xml: set width=349, height=349
No change for apple_10.xml
No change for apple_11.xml
No change for apple_12.xml
No change for apple_13.xml
No change for apple_14.xml
No change for apple_15.xml
No change for apple_16.xml
Fixed apple_17.xml: set width=700, height=800
No change for apple_18.xml
No change for apple_19.xml
No change for apple_2.xml
Fixed apple_20.xml: set width=500, height=500
No change for apple_21.xml
No change for apple_22.xml
No change for apple_23.xml
No change for apple_24.xml
No change for apple_25.xml
No change for apple_26.xml
No change for apple_27.xml
Fixed apple_28.xml: set width=300, height=300
No change for apple_29.xml
No change for apple_3.xml
No change for apple_30.xml
No change for apple_31.xml
No change for apple_32.xml
No change for apple_33.xml
No change for apple_35.xml
No change for apple_36.xml
Fixed apple_37.xml: set width=389, height=352
No change for apple_38.xml
No change for apple_39.xml
No change for apple_4.xml
No change for appl

In [None]:
# Deletes an XML only if its .txt already exists. On a fresh top-to-bottom run
# this cell executes before batch_convert_voc_to_yolo (next cell), so XMLs
# without a pre-existing .txt are only reported as skipped here.
remove_verified_xml(folder="dataset/fruit_images (apple, banana, orange)")

Deleted: dataset/fruit_images (apple, banana, orange)/apple_1.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_10.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_11.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_12.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_13.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_14.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_15.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_16.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_17.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_18.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_19.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_2.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_20.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_21.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_22.x

In [None]:
# Convert all VOC XML annotations to YOLO format, then remove the XMLs that now
# have a .txt, so only image/.txt pairs remain to be merged into the dataset.
batch_convert_voc_to_yolo(
    "dataset/fruit_images (apple, banana, orange)",
    "dataset/fruit_images (apple, banana, orange)",
    class_mapping
)
remove_verified_xml("dataset/fruit_images (apple, banana, orange)")

In [None]:
# Sanity check: file counts per extension in the main dataset folders and in
# the fruit_images folder that is merged into them below.
for counted_folder in (
    images_folder,
    labels_folder,
    "dataset/fruit_images (apple, banana, orange)",
):
    print(f"File counts by extension: {counted_folder}")
    count_files_by_extension(counted_folder)


File counts by extension: dataset/input/images
.jpg: 1043
.png: 113
.jpeg: 139
File counts by extension: dataset/input/labels
.txt: 1295
File counts by extension: dataset/fruit_images (apple, banana, orange)
.jpg: 300
.txt: 300


Counter({'.jpg': 300, '.txt': 300})

In [None]:
# Add more image samples and labels to the dataset. Only images and .txt labels
# are copied: VOC .xml files are not YOLO label files (those are the .txt files
# produced by batch_convert_voc_to_yolo) and do not belong in labels_folder.
split_files_by_extension(
    'dataset/fruit_images (apple, banana, orange)',
    {
        '.jpg': images_folder,
        '.jpeg': images_folder,
        '.png': images_folder,
        '.txt': labels_folder,
    }
)

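Source: [Fruit Detection Dataset](https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection) (Kaggle). The next cell downloads and unzips it into `dataset/fruit-detection` unless it is already present. `filter_and_extract_images_by_class` then copies the label/image pairs that contain the selected class indices.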

In [13]:
import zipfile
from kaggle.api.kaggle_api_extended import KaggleApi

DATASET_PATH = "dataset/fruit-detection"
KAGGLE_DATASET = "lakshaytyagi01/fruit-detection"

# An empty folder counts as "not downloaded": os.makedirs below runs before the
# download, so a failed or interrupted download would otherwise leave an empty
# folder that makes every later run skip the download.
if os.path.isdir(DATASET_PATH) and os.listdir(DATASET_PATH):
    print(f"Dataset already exists at {DATASET_PATH}, skipping download.")
else:
    print("Downloading dataset from Kaggle...")
    os.makedirs(DATASET_PATH, exist_ok=True)
    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files(KAGGLE_DATASET, path=DATASET_PATH, unzip=False)
    # Extract every downloaded archive in place, then delete the archive.
    for file in os.listdir(DATASET_PATH):
        if file.lower().endswith(".zip"):
            zip_path = os.path.join(DATASET_PATH, file)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(DATASET_PATH)
            os.remove(zip_path)
    print("Download and extraction complete.")


Downloading dataset from Kaggle...
Dataset URL: https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection
Download and extraction complete.


In [None]:
def _filter_yolo_label_lines(lines, keep, remap):
    """Return the label lines whose integer class index is in ``keep``, remapped via ``remap``."""
    kept = []
    for line in lines:
        parts = line.split()
        if not parts:
            continue  # blank line
        try:
            class_idx = int(parts[0])
        except ValueError:
            continue  # malformed line: first field is not an integer class index
        if class_idx not in keep:
            continue
        if class_idx in remap:
            # Rewrite only the class field; box values are kept verbatim.
            line = " ".join([str(remap[class_idx])] + parts[1:]) + "\n"
        kept.append(line)
    return kept


def _find_label_image(images_dir, stem, image_extensions):
    """Return the first existing ``images_dir/<stem><ext>`` trying extensions in order, else None."""
    for ext in image_extensions:
        img_path = os.path.join(images_dir, stem + ext)
        if os.path.exists(img_path):
            return img_path
    return None


def filter_and_extract_images_by_class(
    data_root_dir,
    class_indices_to_keep,
    output_images_dir,
    output_labels_dir,
    target_num_labels=300,
    image_extensions=(".jpg", ".jpeg", ".png"),
    class_index_remap=None,
):
    """Copy image/label pairs whose labels contain any of the given class indices.

    Every sub-folder of ``data_root_dir`` that has both ``labels/`` and
    ``images/`` is scanned. Sub-folders and files are visited in sorted order,
    so the selection is reproducible. For each ``labels/<stem>.txt``:

    - only lines whose class index is in ``class_indices_to_keep`` are kept;
    - if at least one remains and ``images/<stem><ext>`` exists for one of
      ``image_extensions`` (tried in order), the filtered label is written to
      ``output_labels_dir`` and the image is copied to ``output_images_dir``,
      both under their original file names;
    - labels without an image are skipped entirely, so no orphan label is written.

    Blank lines and lines without an integer class index are ignored. If two
    sub-folders contain a label with the same file name, only the first is used,
    because both would map to the same output file. Stops once
    ``target_num_labels`` pairs have been copied.

    Args:
        data_root_dir: root folder containing the subset folders.
        class_indices_to_keep: iterable of integer class indices to keep.
        output_images_dir: destination for images (created if missing).
        output_labels_dir: destination for filtered labels (created if missing).
        target_num_labels: maximum number of image/label pairs to collect.
        image_extensions: candidate image extensions, in order of preference.
        class_index_remap: optional dict {source_index: output_index}. Use it
            when the source dataset numbers classes differently from this
            project's mapping; indices not in the dict are written unchanged.
            Default None keeps all indices as-is.
    """
    keep = set(class_indices_to_keep)
    remap = class_index_remap or {}
    os.makedirs(output_images_dir, exist_ok=True)
    os.makedirs(output_labels_dir, exist_ok=True)
    found = 0
    written = set()  # output label names produced by this call

    if target_num_labels <= 0:
        print(f"Collected {found} labels/images with specified classes (target was {target_num_labels}).")
        return

    for subset_folder in sorted(os.listdir(data_root_dir)):
        subset_path = os.path.join(data_root_dir, subset_folder)
        if not os.path.isdir(subset_path):
            continue
        labels_dir = os.path.join(subset_path, "labels")
        images_dir = os.path.join(subset_path, "images")
        if not (os.path.exists(labels_dir) and os.path.exists(images_dir)):
            continue

        for lbl_file in sorted(os.listdir(labels_dir)):
            if not lbl_file.endswith(".txt") or lbl_file in written:
                continue
            with open(os.path.join(labels_dir, lbl_file), "r") as f:
                filtered_lines = _filter_yolo_label_lines(f, keep, remap)
            if not filtered_lines:
                continue  # no instance of a wanted class in this file
            img_path = _find_label_image(images_dir, os.path.splitext(lbl_file)[0], image_extensions)
            if img_path is None:
                continue  # never write a label whose image is missing
            with open(os.path.join(output_labels_dir, lbl_file), "w") as out_f:
                out_f.writelines(filtered_lines)
            shutil.copy(img_path, os.path.join(output_images_dir, os.path.basename(img_path)))
            written.add(lbl_file)
            found += 1
            if found >= target_num_labels:
                print(f"Collected {found} labels/images with specified classes.")
                return
    print(f"Collected {found} labels/images with specified classes (target was {target_num_labels}).")

In [None]:
# Give every labelled image a sequential name (img_001, img_002, ...) and rename
# its label to match; images without a label are skipped.
rename_images_and_labels(images_folder=images_folder, labels_folder=labels_folder, prefix="img_")

In [11]:
def count_labels_per_class(labels_folder, images_folder, class_mapping, image_exts=(".jpg", ".jpeg", ".png")):
    """Count labelled objects per class across YOLO-format label files.

    Every ``.txt`` in ``labels_folder`` whose image (same stem, one of
    ``image_exts``) exists in ``images_folder`` is read. Each non-blank line
    whose first field is an integer counts as one object of that class. Label
    files without an image are not counted and are listed after the summary.

    The summary lists every class in ``class_mapping``, including classes with
    zero objects so that missing classes are visible. It also lists any index
    found in the labels but absent from the mapping, shown as "Unknown".

    Args:
        labels_folder: folder of YOLO ``.txt`` label files.
        images_folder: folder of the matching images.
        class_mapping: dict mapping class name -> class index.
        image_exts: image extensions to look for.

    Returns:
        collections.Counter mapping class index -> number of objects.
    """
    # Build index to name mapping for output
    idx_to_name = {idx: name for name, idx in class_mapping.items()}
    class_counter = Counter()
    invalid_files = []
    for label_fname in (f for f in os.listdir(labels_folder) if f.endswith('.txt')):
        base = os.path.splitext(label_fname)[0]
        if not any(os.path.exists(os.path.join(images_folder, base + ext)) for ext in image_exts):
            invalid_files.append(label_fname)
            continue
        with open(os.path.join(labels_folder, label_fname), 'r') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue
                try:
                    class_counter[int(parts[0])] += 1
                except ValueError:
                    continue  # malformed line without an integer class index

    print("Total object counts by class:")
    # Union with the mapping so classes with zero objects are printed too.
    for cls in sorted(set(idx_to_name) | set(class_counter)):
        name = idx_to_name.get(cls, "Unknown")
        print(f"Class {cls} (\"{name}\"): {class_counter[cls]}")
    if invalid_files:
        print("\nLabel files with missing image counterparts:")
        for fname in invalid_files:
            print(fname)
    return class_counter


In [None]:
# Per-class object totals over the merged dataset (only labels that have an image).
count_labels_per_class(
    labels_folder=labels_folder,
    images_folder=images_folder,
    class_mapping=class_mapping,
)

Total object counts by class:
Class 0 ("apple"): 545
Class 1 ("avocado"): 307
Class 2 ("banana"): 749
Class 3 ("kiwi"): 469
Class 4 ("lemon"): 357
Class 5 ("orange"): 493
Class 6 ("pear"): 325
Class 7 ("pomegranate"): 309
Class 8 ("strawberry"): 266
Class 9 ("watermelon"): 153
