# Dataloader

In [1]:
import os
from PIL import Image, UnidentifiedImageError
import shutil
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
# Set kaggle API
!pip install --upgrade kaggle==1.7.4.2 --force-reinstall --no-deps
!mkdir /root/.kaggle

with open("/root/.kaggle/kaggle.json", "w+") as f:
    # Put your kaggle username & key here
    f.write('{"username":"gaolelin","key":"c79578333a3f6e722ce4e64cc649b9db"}')

!chmod 600 /root/.kaggle/kaggle.json

Collecting kaggle==1.7.4.2
  Downloading kaggle-1.7.4.2-py3-none-any.whl.metadata (16 kB)
Downloading kaggle-1.7.4.2-py3-none-any.whl (173 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/173.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m173.2/173.2 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: kaggle
  Attempting uninstall: kaggle
    Found existing installation: kaggle 1.7.4.2
    Uninstalling kaggle-1.7.4.2:
      Successfully uninstalled kaggle-1.7.4.2
Successfully installed kaggle-1.7.4.2


In [3]:
# Download dataset
!mkdir '/content/data'
!kaggle datasets download -d nirmalsankalana/crop-pest-and-disease-detection -p /content/data --unzip

Dataset URL: https://www.kaggle.com/datasets/nirmalsankalana/crop-pest-and-disease-detection
License(s): CC0-1.0


In [4]:
def count_images(directory, extensions={".jpg", ".jpeg", ".png"}):
    count = 0
    for root, _, files in os.walk(directory):
        for file in files:
            if os.path.splitext(file)[1].lower() in extensions:
                count += 1
    return count

total_images = count_images("/content/data")
print(f"✅ Total images found in dataset: {total_images}")

✅ Total images found in dataset: 25220


In [5]:
# Set path
original_base = "/content/data"
new_base = "/content/data_yolo"
splits = ['train', 'val', 'test']
split_ratio = [0.8, 0.1, 0.1]
categories = sorted([d for d in os.listdir(original_base) if os.path.isdir(os.path.join(original_base, d))])

# Create new data dictionary
def create_yolo_dirs(base_dir, categories, splits):
    for split in splits:
        for cls in categories:
            os.makedirs(os.path.join(base_dir, split, cls), exist_ok=True)

create_yolo_dirs(new_base, categories, splits)

image_paths, image_labels = [], []
for category in os.listdir(original_base):  # os instead of manual categories
    category_path = os.path.join(original_base, category)
    if os.path.isdir(category_path):
        for file in os.listdir(category_path):
            if file.lower().endswith(".jpg"):
                image_paths.append(os.path.join(category_path, file))
                image_labels.append(category)

print(f"Images before filter: {len(image_labels)}")


# Split dataset
train_paths, temp_paths, train_labels, temp_labels = train_test_split(
    image_paths, image_labels, test_size=(1 - split_ratio[0]), stratify=image_labels, random_state=42)

val_paths, test_paths, val_labels, test_labels = train_test_split(
    temp_paths, temp_labels, test_size=split_ratio[2] / (split_ratio[1] + split_ratio[2]),
    stratify=temp_labels, random_state=42)

# Verify and remove crashed images
def save_images(paths, labels, split_name):
    counter = {}
    for path, label in tqdm(zip(paths, labels), total=len(paths), desc=f"Saving {split_name}"):
        try:
            with Image.open(path) as img:
                img.verify()  # Check file integrity
            with Image.open(path) as img:
                img.convert("RGB").load()  # Ensure image is readable
            counter[label] = counter.get(label, 0) + 1
            filename = os.path.basename(path)
            save_path = os.path.join(new_base, split_name, label, filename)
            shutil.copy(path, save_path)
        except (UnidentifiedImageError, OSError):
            continue

    # Stats
    total = sum(counter.values())
    print(f"\n{split_name} split summary:")
    print(f"  Total images: {total}\n")
    return total

# Save to new dictionary
all_total = 0
all_total += save_images(train_paths, train_labels, 'train')
all_total += save_images(val_paths, val_labels, 'val')
all_total += save_images(test_paths, test_labels, 'test')
print(f"  Images after filter : {all_total}\n")


Images before filter: 25220


Saving train: 100%|██████████| 20176/20176 [00:32<00:00, 616.73it/s]



train split summary:
  Total images: 20103



Saving val: 100%|██████████| 2522/2522 [00:03<00:00, 640.31it/s]



val split summary:
  Total images: 2511



Saving test: 100%|██████████| 2522/2522 [00:03<00:00, 638.87it/s]


test split summary:
  Total images: 2512

  Images after filter : 25126




