## Splitting the wild train data

In [1]:
# Imports and the location of the wild image pool.
# The image paths and labels themselves are collected in the following cell.
import numpy as np
from pathlib import Path
import json

# Root folder of the wild images; the next cell reads one subfolder per class name from it.
# The path is relative, so it resolves against the kernel's current working directory.
wild_dir = Path("../data/wild_pool")

In [2]:
# Load authoritative class order from PV splits
split_dir = Path("../data/splits")
# JSON is UTF-8 by specification; do not rely on the platform's default encoding.
with open(split_dir / "class_names.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)

# Label index = position of the class in the PV class list, so wild labels
# use the same integer encoding as the PV splits.
class_to_idx = {cls: i for i, cls in enumerate(class_names)}

image_paths = []
labels = []

# Extensions are compared lower-cased, so "IMG.JPG" is accepted as well.
valid_ext = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}

for cls in class_names:
    cls_dir = wild_dir / cls
    # Every PV class must have a matching folder in the wild pool.
    if not cls_dir.is_dir():
        raise FileNotFoundError(f"Missing class folder: {cls_dir}")

    # glob() yields entries in arbitrary, filesystem-dependent order. Sorting
    # makes the collected arrays (and therefore the seeded train/val split
    # computed from them) identical across machines and re-runs.
    for img in sorted(cls_dir.glob("*")):
        # Skip sub-directories and anything that is not a recognised image file.
        if img.is_file() and img.suffix.lower() in valid_ext:
            image_paths.append(img)
            labels.append(class_to_idx[cls])

# Object dtype keeps the pathlib.Path instances as-is; labels become a plain int array.
image_paths = np.array(image_paths, dtype=object)
labels = np.array(labels, dtype=int)

print(f"Total wild_pool images: {len(image_paths)}")

Total wild_pool images: 225


In [3]:
# Stratified train/validation split of the wild_pool images.
from sklearn.model_selection import train_test_split

SEED = 42

# train_test_split returns (X_train, X_test, y_train, y_test) for the two
# arrays passed in; unpack them straight into the wild_* names.
(
    wild_train_paths,
    wild_val_paths,
    wild_train_labels,
    wild_val_labels,
) = train_test_split(
    image_paths,
    labels,
    test_size=0.15,        # ~34 images val out of 225
    random_state=SEED,     # fixed seed so the split is repeatable for the same input order
    stratify=labels,       # keep per-class proportions similar in both splits
)

print(f"Wild train: {len(wild_train_paths)}")
print(f"Wild val:   {len(wild_val_paths)}")

Wild train: 191
Wild val:   34


In [4]:
# Persist the wild splits next to the PV splits. Every file name carries a
# "wild_" prefix so the existing PV split files are NOT overwritten.
split_dir.mkdir(parents=True, exist_ok=True)

wild_arrays = {
    "wild_train_paths.npy": wild_train_paths,
    "wild_train_labels.npy": wild_train_labels,
    "wild_val_paths.npy": wild_val_paths,
    "wild_val_labels.npy": wild_val_labels,
}
for file_name, split_array in wild_arrays.items():
    # The *_paths arrays are object arrays of pathlib.Path, which np.save can
    # only store via pickle; for the integer label arrays the flag has no effect.
    # NOTE(review): pickled Path objects are tied to the OS flavour that wrote
    # them and may not load on a different OS - confirm how these are consumed.
    np.save(split_dir / file_name, split_array, allow_pickle=True)

# Optional: save class_names again for clarity
(split_dir / "wild_class_names.json").write_text(json.dumps(class_names))

print("Saved wild splits to:", split_dir)

Saved wild splits to: ..\data\splits


In [5]:
# Per-class sample counts for each split. minlength pads the result to one
# entry per class, so a class absent from a split shows up as 0.
n_classes = len(class_names)
train_counts = np.bincount(wild_train_labels, minlength=n_classes)
val_counts = np.bincount(wild_val_labels, minlength=n_classes)

print("Train label counts:", train_counts)
print("Val label counts:  ", val_counts)


Train label counts: [17 20 20 20 19 17 18 14 28 18]
Val label counts:   [3 4 4 3 3 3 3 3 5 3]
