In [1]:
import os
import random
from helper import *

In [2]:
# Destination folders of the merged dataset; later cells move image files and
# label files into these two locations.
images_folder, labels_folder = 'dataset/input/images', 'dataset/input/labels'

In [4]:
# Class name -> integer class id, passed to the conversion and counting helpers.
# Ids are consecutive and follow the order of the names listed below, so the
# mapping is generated from that ordered list instead of being typed by hand.
class_mapping = {
    fruit_name: class_id
    for class_id, fruit_name in enumerate((
        'apple',
        'avocado',
        'banana',
        'kiwi',
        'lemon',
        'orange',
        'pear',
        'pomegranate',
        'strawberry',
        'watermelon',
    ))
}


#### Adding more images and labels to the dataset (300 images of apple, banana and orange)
[Fruit Images for Object Detection](https://www.kaggle.com/datasets/mbkinaci/fruit-images-for-object-detection)


In [7]:
# Download and extract the "Fruit Images for Object Detection" dataset from Kaggle.
#
# Steps performed by this cell:
#   1. Unless OUT_PARENT already holds files (and FORCE is False), download
#      `mbkinaci/fruit-images-for-object-detection` into TMP_DIR.
#   2. Extract every zip found in TMP_DIR into a sibling folder named after the archive.
#   3. Move every file with an extension in KEEP_EXTS into OUT_PARENT (flattened),
#      then delete TMP_DIR.
#   4. Always: flatten any sub-folders that already exist inside OUT_PARENT.
#
# Name collisions get a `_dupN` suffix. With FORCE=True, files that were already in
# OUT_PARENT before this run are replaced instead, so a forced re-download does not
# double the dataset with `_dup` copies.

import glob
import os
import shutil
import zipfile

DATASET = "mbkinaci/fruit-images-for-object-detection"
TMP_DIR = "dataset/kaggle_temp/fruit_images"
OUT_PARENT = "dataset/fruit_images (apple, banana, orange)"
FORCE = False  # Set to True to re-download and overwrite

# extensions to keep and move to the output folder
KEEP_EXTS = {'.jpg', '.jpeg', '.png', '.txt', '.xml'}


def _unique_destination(dst):
    """Return `dst` if it is free, else the first free `<base>_dup<N><ext>` path (N = 1, 2, ...)."""
    if not os.path.exists(dst):
        return dst
    base, ext = os.path.splitext(dst)
    i = 1
    while os.path.exists(f"{base}_dup{i}{ext}"):
        i += 1
    return f"{base}_dup{i}{ext}"


def _move_flat(src_root, out_dir, keep_exts, ignore_zips=False, overwrite=False):
    """Move all files below `src_root` with an extension in `keep_exts` directly into `out_dir`.

    Args:
        src_root: folder that is walked recursively.
        out_dir: flat destination folder.
        keep_exts: lower-case extensions (with leading dot) that are moved.
        ignore_zips: when True, `.zip` files are passed over without being counted as skipped.
        overwrite: when True, a regular file that existed in `out_dir` before this call is
            replaced; collisions between two files moved by the same call still get a
            `_dupN` suffix so nothing from the current batch is lost.

    Returns:
        (moved, skipped): number of files moved and number skipped because of their extension.
    """
    moved = 0
    skipped = 0
    written = set()  # destinations created by this call; never overwritten
    for root, _dirs, files in os.walk(src_root):
        for fname in files:
            if ignore_zips and fname.lower().endswith('.zip'):
                continue
            if os.path.splitext(fname)[1].lower() not in keep_exts:
                skipped += 1
                continue
            src = os.path.join(root, fname)
            dst = os.path.join(out_dir, fname)
            if os.path.exists(dst):
                if overwrite and os.path.isfile(dst) and dst not in written:
                    # Remove first: shutil.move onto an existing file is platform dependent.
                    os.remove(dst)
                else:
                    dst = _unique_destination(dst)
            shutil.move(src, dst)
            written.add(dst)
            moved += 1
    return moved, skipped


os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(OUT_PARENT, exist_ok=True)

# If OUT_PARENT already contains files and FORCE is False, skip download
if not FORCE and os.listdir(OUT_PARENT):
    print(f"Target folder '{OUT_PARENT}' already exists and is not empty. Skipping download. Set FORCE=True to override.")
else:
    # Imported lazily so the cell still runs without the kaggle package when the
    # data is already on disk.
    # NOTE(review): importing kaggle may itself require credentials — confirm.
    from kaggle.api.kaggle_api_extended import KaggleApi

    api = KaggleApi()
    api.authenticate()
    print(f"Downloading dataset {DATASET} into {TMP_DIR}...")
    api.dataset_download_files(DATASET, path=TMP_DIR, unzip=False)

    # Extract all zip files found in TMP_DIR
    zip_paths = [os.path.join(TMP_DIR, f) for f in os.listdir(TMP_DIR) if f.lower().endswith('.zip')]
    if not zip_paths:
        print("No zip files found in the download folder. If you already downloaded the dataset manually, ensure the zip files are placed in the temp folder.")

    for zpath in zip_paths:
        print(f"Extracting {zpath}...")
        # Each archive is unpacked into TMP_DIR/<archive name without .zip>
        extract_folder = os.path.join(TMP_DIR, os.path.splitext(os.path.basename(zpath))[0])
        os.makedirs(extract_folder, exist_ok=True)
        with zipfile.ZipFile(zpath, 'r') as zf:
            zf.extractall(extract_folder)

    # Walk extracted folders and move relevant files into OUT_PARENT (flattened)
    moved, skipped = _move_flat(TMP_DIR, OUT_PARENT, KEEP_EXTS, ignore_zips=True, overwrite=FORCE)

    # Clean up temporary files and folders
    print("Cleaning temporary files...")
    try:
        # Remove any remaining zip files
        for z in glob.glob(os.path.join(TMP_DIR, '*.zip')):
            os.remove(z)
        # Remove the TMP_DIR tree
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)
    except OSError as e:
        print("Warning while cleaning temp files:", e)

    print(f"Done. Moved {moved} files into: {OUT_PARENT} (skipped {skipped} non-image/label files)")

# --- Flatten any nested extracted folders that may already exist inside OUT_PARENT ---
# Move any image/label files from subfolders (e.g., 'fruit-images-for-object-detection/train_zip/train' or
# 'fruit-images-for-object-detection/test_zip/test') up into OUT_PARENT, avoiding overwrites.

moved_up = 0
skipped_up = 0
for entry in os.listdir(OUT_PARENT):
    entry_path = os.path.join(OUT_PARENT, entry)
    if not os.path.isdir(entry_path):
        continue
    n_moved, n_skipped = _move_flat(entry_path, OUT_PARENT, KEEP_EXTS)
    moved_up += n_moved
    skipped_up += n_skipped
    # Remove the sub-folder afterwards. Files that were skipped because of their
    # extension are still inside it and are deleted together with the folder.
    try:
        shutil.rmtree(entry_path)
    except OSError as e:
        # Best effort: report instead of silently ignoring the failure.
        print(f"Warning: could not remove '{entry_path}': {e}")

if moved_up:
    print(f"Flattened nested folders: moved {moved_up} files into {OUT_PARENT} (skipped {skipped_up} other files)")
else:
    print("No nested folders to flatten or nothing to move.")


Downloading dataset mbkinaci/fruit-images-for-object-detection into dataset/kaggle_temp/fruit_images...
Dataset URL: https://www.kaggle.com/datasets/mbkinaci/fruit-images-for-object-detection
Extracting dataset/kaggle_temp/fruit_images/fruit-images-for-object-detection.zip...
Extracting dataset/kaggle_temp/fruit_images/fruit-images-for-object-detection.zip...
Cleaning temporary files...
Done. Moved 600 files into: dataset/fruit_images (apple, banana, orange) (skipped 0 non-image/label files)
Cleaning temporary files...
Done. Moved 600 files into: dataset/fruit_images (apple, banana, orange) (skipped 0 non-image/label files)
No nested folders to flatten or nothing to move.
No nested folders to flatten or nothing to move.


In [8]:
# Check and fix VOC XML sizes (width and height)
# The recorded output shows the helper rewriting width/height in some XML files
# ("Fixed ...") and leaving others untouched ("No change ...").
# NOTE(review): the same folder is passed for both arguments, presumably because
# annotations and images share the flattened folder — confirm the parameter
# meaning against helper.fix_voc_xml_sizes.
fix_voc_xml_sizes(
    "dataset/fruit_images (apple, banana, orange)",
    "dataset/fruit_images (apple, banana, orange)")

Fixed apple_1.xml: set width=349, height=349
No change for apple_10.xml
No change for apple_11.xml
No change for apple_12.xml
No change for apple_13.xml
No change for apple_14.xml
No change for apple_15.xml
No change for apple_16.xml
Fixed apple_17.xml: set width=700, height=800
No change for apple_18.xml
No change for apple_19.xml
No change for apple_2.xml
Fixed apple_20.xml: set width=500, height=500
No change for apple_15.xml
No change for apple_16.xml
Fixed apple_17.xml: set width=700, height=800
No change for apple_18.xml
No change for apple_19.xml
No change for apple_2.xml
Fixed apple_20.xml: set width=500, height=500
No change for apple_21.xml
No change for apple_22.xml
No change for apple_23.xml
No change for apple_24.xml
No change for apple_25.xml
No change for apple_26.xml
No change for apple_27.xml
Fixed apple_28.xml: set width=300, height=300
No change for apple_21.xml
No change for apple_22.xml
No change for apple_23.xml
No change for apple_24.xml
No change for apple_25.xm

In [9]:
# Convert all VOC XML annotations to YOLO format
# class_mapping supplies the class-name -> class-id table for the conversion.
# NOTE(review): both folder arguments are the same path, presumably input and
# output folder, so the converted label files would land next to the XML
# files — confirm against helper.batch_convert_voc_to_yolo.
batch_convert_voc_to_yolo(
    "dataset/fruit_images (apple, banana, orange)",
    "dataset/fruit_images (apple, banana, orange)",
    class_mapping
)

In [10]:
# Remove any XML files that have corresponding verified TXT files
# Destructive step: the recorded output lists each XML file as "Deleted".
# NOTE(review): what counts as "verified" is decided inside
# helper.remove_verified_xml — confirm before re-running on new data.
remove_verified_xml("dataset/fruit_images (apple, banana, orange)")

Deleted: dataset/fruit_images (apple, banana, orange)/apple_1.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_10.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_11.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_12.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_13.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_14.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_15.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_16.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_17.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_18.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_19.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_2.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_20.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_21.xml
Deleted: dataset/fruit_images (apple, banana, orange)/apple_22.x

In [None]:
# Add more image samples and labels to the dataset
# Routing table: image extensions go to images_folder, annotation extensions
# (.txt and any .xml still present) go to labels_folder.
split_files_by_extension(
    'dataset/fruit_images (apple, banana, orange)',
    {
        **dict.fromkeys(('.jpg', '.jpeg', '.png'), images_folder),
        **dict.fromkeys(('.txt', '.xml'), labels_folder),
    }
)

[Fruit Detection Dataset](https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection)

In [13]:
# Download and unpack the "Fruit Detection" Kaggle dataset into DATASET_PATH.
# The download is skipped only when DATASET_PATH exists AND contains something:
# the folder is created before the download starts, so a bare existence check
# would treat a failed or interrupted download as "already downloaded" forever.
import zipfile

DATASET_PATH = "dataset/fruit-detection"
KAGGLE_DATASET = "lakshaytyagi01/fruit-detection"

if os.path.isdir(DATASET_PATH) and os.listdir(DATASET_PATH):
    print(f"Dataset already exists at {DATASET_PATH}, skipping download.")
else:
    # Imported lazily: the Kaggle client is only needed when actually downloading.
    from kaggle.api.kaggle_api_extended import KaggleApi

    print("Downloading dataset from Kaggle...")
    os.makedirs(DATASET_PATH, exist_ok=True)
    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files(KAGGLE_DATASET, path=DATASET_PATH, unzip=False)

    # Find the downloaded archive(s); match the extension case-insensitively,
    # like the first download cell does.
    zip_names = [name for name in os.listdir(DATASET_PATH) if name.lower().endswith(".zip")]
    for zip_name in zip_names:
        zip_path = os.path.join(DATASET_PATH, zip_name)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(DATASET_PATH)
        # The archive is no longer needed once its content is extracted.
        os.remove(zip_path)

    if zip_names:
        print("Download and extraction complete.")
    else:
        # Do not claim success when there was nothing to extract.
        print(f"No zip file found in {DATASET_PATH} after download; nothing was extracted.")


Downloading dataset from Kaggle...
Dataset URL: https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection
Download and extraction complete.


Extract 150 watermelon images from the downloaded dataset

In [12]:
# Pull watermelon samples out of the downloaded fruit-detection dataset into the
# "unprocessed" staging folders. Class index 5 is watermelon in the source
# dataset's numbering (it is remapped to this project's id 9 in a later cell).
# The recorded output reports 150 collected labels/images, matching target_num_labels.
filter_and_extract_images_by_class(
    data_root_dir="dataset/fruit-detection/Fruits-detection",
    class_indices_to_keep={5},  # watermelon
    output_images_dir="dataset/unprocessed/images",
    output_labels_dir="dataset/unprocessed/labels",
    target_num_labels=150
)

Collected 150 labels/images with specified classes.


In [None]:
# Rename the files in the merged input folders using the 'image_' prefix.
# NOTE(review): in cell order this runs before the staged watermelon files are
# moved into images_folder/labels_folder further down, so those files would not
# be renamed by this call — confirm whether that is intended.
rename_images_and_labels(images_folder, labels_folder, prefix='image_')

In [36]:
# Drop staged label files that have no matching image (labels folder first,
# images folder second). The recorded run deleted 0 files.
delete_labels_without_images("dataset/unprocessed/labels", "dataset/unprocessed/images")

Deleted 0 label files without matching images:


In [13]:
# Translate the source dataset's class id into this project's class_mapping id.
# NOTE(review): only id 5 is remapped. If the staged label files still contain
# boxes of other source classes, those ids would be kept as-is and read as
# different fruits here — confirm that the extraction step removes them.
relabel_yolo_labels("dataset/unprocessed/labels", {5: 9})  # Remap watermelon class (5) to (9)

Relabeled 150 label files.


In [None]:
# Per-class object counts for the merged input dataset. The literal paths are
# the same values as labels_folder / images_folder defined at the top.
count_labels_per_class("dataset/input/labels", "dataset/input/images", class_mapping)

Total object counts by class:
Class 9 ("watermelon"): 353


In [15]:
# Merge the staged watermelon samples into the main dataset:
# image files go to images_folder ...
split_files_by_extension(
    'dataset/unprocessed/images',
    dict.fromkeys(('.jpg', '.jpeg', '.png'), images_folder))

# ... and their YOLO label files go to labels_folder.
split_files_by_extension(
    'dataset/unprocessed/labels',
    dict.fromkeys(('.txt',), labels_folder))

#### Dataset labels by class

In [11]:
# Object counts per class over the complete merged dataset (all ten classes).
count_labels_per_class(labels_folder, images_folder, class_mapping)

Total object counts by class:
Class 0 ("apple"): 545
Class 1 ("avocado"): 307
Class 2 ("banana"): 749
Class 3 ("kiwi"): 469
Class 4 ("lemon"): 357
Class 5 ("orange"): 493
Class 6 ("pear"): 325
Class 7 ("pomegranate"): 309
Class 8 ("strawberry"): 266
Class 9 ("watermelon"): 744


In [38]:
# File counts per extension, first for images and then for labels.
# The second call is the cell's last expression, so its return value is also
# displayed as the cell result in addition to the printed lines.
# NOTE(review): the recorded run shows 799 image files but 505 .txt labels —
# confirm whether unlabeled images are expected at this point.
count_files_by_extension(images_folder)
count_files_by_extension(labels_folder)

.jpg: 559
.jpeg: 128
.png: 112
.txt: 505


Counter({'.txt': 505})

#### Split the dataset into train/validation/test sets with a 75/15/10 per-class distribution (±5%)

In [8]:
# Configuration for the train/val/test split.
SEED = 42        # shared seed, also passed to the split and balancing helpers below
n_classes = 10   # number of classes handed to the stratified split

# Source folders of the merged dataset (same values as at the top of the notebook).
images_folder = "dataset/input/images"
labels_folder = "dataset/input/labels"

# Seed Python's global RNG.
random.seed(SEED)

In [13]:
# Per label file: which classes occur in it
img_class_map = count_image_classes(labels_folder)

# Stage 1 - hold out 25% of the files; the rest is the training set.
train_set, valtest_set = iterative_train_val_split(
    img_class_map.keys(),
    img_class_map,
    n_classes,
    val_ratio=0.25,
    seed=SEED,
)

# Stage 2 - divide the held-out part again with ratio 0.4
# (presumably 15% validation / 10% test of the whole dataset, per the 75/15/10 target).
valtest_class_map = {label_file: img_class_map[label_file] for label_file in valtest_set}
val_set, test_set = iterative_train_val_split(
    valtest_class_map.keys(),
    valtest_class_map,
    n_classes,
    val_ratio=0.4,
    seed=SEED,
)

# Save each split under dataset/split/<name>/{labels,images}
for split_name, split_files in (("train", train_set), ("val", val_set), ("test", test_set)):
    save_balanced_split(
        split_files,
        labels_folder,
        images_folder,
        f"dataset/split/{split_name}/labels",
        f"dataset/split/{split_name}/images",
    )


In [6]:
# Class distribution of the training split before balancing. The recorded run
# prints a summary table and displays the returned Counter as the cell result.
count_labels_per_class_v2("dataset/split/train/labels", "dataset/split/train/images", class_mapping)


📊 Object Distribution Summary
   Labels Folder: dataset/split/train/labels
   Images Folder: dataset/split/train/images
Class ID   Class Name           Count      %         
------------------------------------------------------------
0          apple                419         12.28%
1          avocado              241          7.06%
2          banana               569         16.68%
3          kiwi                 371         10.87%
4          lemon                260          7.62%
5          orange               348         10.20%
6          pear                 235          6.89%
7          pomegranate          210          6.15%
8          strawberry           201          5.89%
9          watermelon           558         16.35%
------------------------------------------------------------
TOTAL                           3412          100.00%


Counter({2: 569,
         9: 558,
         0: 419,
         3: 371,
         5: 348,
         4: 260,
         1: 241,
         6: 235,
         7: 210,
         8: 201})

In [9]:
# Undersample/oversample to balance classes in training set
# Step 1 of 2: select training files towards 300 objects per class and save that
# selection into separate "balanced_*" folders, leaving the original train
# folders untouched for now. Oversampling of minority classes is the next cell.
# NOTE(review): the exact selection strategy lives in helper.balance_classes.
img_class_map_train = count_image_classes("dataset/split/train/labels")
undersampled = balance_classes(img_class_map_train, target_num_per_class=300, seed=SEED)
# Save undersampled training set
save_balanced_split(
    undersampled,
    "dataset/split/train/labels",
    "dataset/split/train/images",
    "dataset/split/train/balanced_labels",
    "dataset/split/train/balanced_images"
)

In [10]:
# Oversample (duplicate) minority classes in the balanced train folder
# Source and output folders are the same, so the duplicates are added in place
# next to the files selected in the previous cell.
# os.listdir() returns entries in arbitrary order, so the file list is sorted:
# the helper then sees the same input order on every run and machine.
# NOTE(review): unlike the other helpers, no seed is passed here — if
# oversample_classes accepts one, pass seed=SEED as well.
oversample_classes(
    selected_files=sorted(os.listdir("dataset/split/train/balanced_labels")),
    image_class_map=count_image_classes("dataset/split/train/balanced_labels"),
    class_target=300,
    output_images="dataset/split/train/balanced_images",
    output_labels="dataset/split/train/balanced_labels",
    labels_folder="dataset/split/train/balanced_labels",
    images_folder="dataset/split/train/balanced_images"
)

Oversampling complete. New class distribution: Counter({4: 307, 3: 304, 6: 303, 7: 302, 1: 302, 9: 300, 5: 300, 2: 300, 8: 300, 0: 300})


In [12]:
# Replace original train folders with balanced versions for YOLO compatibility.
#
# The original folders are only deleted once the balanced replacements are known
# to exist. This makes the cell safe to re-run: after a successful swap the
# balanced_* folders are gone, and a second run leaves the training set alone
# instead of deleting it with nothing to put in its place.
import shutil  # imported here so the cell does not rely on an earlier cell's import

train_labels_orig = "dataset/split/train/labels"
train_images_orig = "dataset/split/train/images"
train_labels_balanced = "dataset/split/train/balanced_labels"
train_images_balanced = "dataset/split/train/balanced_images"


def promote_balanced_dir(original_dir, balanced_dir):
    """Replace `original_dir` with `balanced_dir` (delete the former, rename the latter).

    Args:
        original_dir: folder that should end up holding the balanced data.
        balanced_dir: folder that currently holds the balanced data.

    Returns:
        True if the swap was performed; False if `balanced_dir` does not exist,
        in which case nothing is touched.
    """
    if not os.path.isdir(balanced_dir):
        print(f"Skipped: {balanced_dir} not found, {original_dir} left unchanged")
        return False
    if os.path.exists(original_dir):
        shutil.rmtree(original_dir)
        print(f"Deleted: {original_dir}")
    # Rename balanced folder to the standard YOLO name
    os.rename(balanced_dir, original_dir)
    print(f"Renamed: {balanced_dir} -> {original_dir}")
    return True


# Swap labels and images together, so the two folders cannot get out of sync
# when only one of the balanced folders is present.
if os.path.isdir(train_labels_balanced) and os.path.isdir(train_images_balanced):
    promote_balanced_dir(train_labels_orig, train_labels_balanced)
    promote_balanced_dir(train_images_orig, train_images_balanced)

    print("\n✅ Training dataset updated with balanced class distribution")
    print(f"   Labels: {train_labels_orig}")
    print(f"   Images: {train_images_orig}")
else:
    print("Balanced folders not found (already swapped in, or balancing not run yet); nothing was changed.")


Deleted: dataset/split/train/labels
Deleted: dataset/split/train/images
Renamed: dataset/split/train/balanced_labels -> dataset/split/train/labels
Renamed: dataset/split/train/balanced_images -> dataset/split/train/images

✅ Training dataset updated with balanced class distribution
   Labels: dataset/split/train/labels
   Images: dataset/split/train/images
Deleted: dataset/split/train/images
Renamed: dataset/split/train/balanced_labels -> dataset/split/train/labels
Renamed: dataset/split/train/balanced_images -> dataset/split/train/images

✅ Training dataset updated with balanced class distribution
   Labels: dataset/split/train/labels
   Images: dataset/split/train/images


#### Dataset summary (Final)

In [13]:
# Final class distribution of each split: train (after balancing), val, test.
# The print("\n") calls only add blank lines between the three printed tables.
# The last call is the cell's final expression, so its return value is also
# displayed as the cell result.
count_labels_per_class_v2("dataset/split/train/labels", "dataset/split/train/images", class_mapping)
print("\n")
count_labels_per_class_v2("dataset/split/val/labels", "dataset/split/val/images", class_mapping)
print("\n")
count_labels_per_class_v2("dataset/split/test/labels", "dataset/split/test/images", class_mapping)

📊 Object Distribution Summary
   Labels Folder: dataset/split/train/labels
   Images Folder: dataset/split/train/images
Class ID   Class Name           Count      %         
------------------------------------------------------------
0          apple                301          9.91%
1          avocado              302          9.95%
2          banana               309         10.18%
3          kiwi                 307         10.11%
4          lemon                309         10.18%
5          orange               302          9.95%
6          pear                 304         10.01%
7          pomegranate          302          9.95%
8          strawberry           300          9.88%
9          watermelon           300          9.88%
------------------------------------------------------------
TOTAL                           3036          100.00%


📊 Object Distribution Summary
   Labels Folder: dataset/split/val/labels
   Images Folder: dataset/split/val/images
Class ID   Class

Counter({9: 79, 2: 79, 5: 59, 6: 46, 0: 41, 4: 39, 7: 34, 3: 32, 1: 30, 8: 24})