## Train/validation/test split - do not run this notebook

All splits are index-based and reproducible across machines using a fixed random seed, provided the image files are enumerated in the same order on every machine. The test set is set aside for now to avoid data leakage.

In [1]:
# Prepare image paths and labels
import numpy as np
from pathlib import Path

data_dir = Path("../data/raw/PlantVillage")

# One class per sub-directory of data_dir. The names are sorted so that the
# class -> index mapping does not depend on directory listing order.
class_names = sorted(d.name for d in data_dir.iterdir() if d.is_dir())
class_to_idx = {cls: i for i, cls in enumerate(class_names)}

image_paths = []
labels = []

for cls in class_names:
    # Directory listings come back in arbitrary, filesystem-dependent order.
    # The seeded splits made later pick items by position, so the files must be
    # enumerated in a fixed order (sorted by file name) for the same seed to
    # select the same images on every machine. Sub-directories are skipped.
    # NOTE: because the enumeration order changes, re-running this notebook will
    # not necessarily reproduce split files saved before this change.
    # NOTE(review): every regular file in a class folder is treated as an image;
    # confirm the dataset contains no stray non-image files.
    class_files = sorted(
        (p for p in (data_dir / cls).glob("*") if p.is_file()),
        key=lambda p: p.name,
    )
    image_paths.extend(class_files)
    labels.extend([class_to_idx[cls]] * len(class_files))

# Parallel arrays: labels[i] is the class index of image_paths[i].
image_paths = np.array(image_paths)
labels = np.array(labels)

print(f"Total images: {len(image_paths)}")


Total images: 20639


In [2]:
# Dev/Test split. Important: test_paths is sacred! Don't touch it till the very end of the project.

from sklearn.model_selection import train_test_split

SEED = 42             # random_state shared by the splits in this notebook
TEST_FRACTION = 0.15  # share of all images held out as the final test set

# Stratifying on the labels keeps the class proportions of the full dataset
# approximately equal in the dev and test sets.
dev_paths, test_paths, dev_labels, test_labels = train_test_split(
    image_paths,
    labels,
    test_size=TEST_FRACTION,
    stratify=labels,
    random_state=SEED,
)

# Leakage guard: every image must land in exactly one of the two sets, and
# paths and labels must stay aligned.
assert len(dev_paths) + len(test_paths) == len(image_paths)
assert len(dev_paths) == len(dev_labels)
assert len(test_paths) == len(test_labels)
assert set(dev_paths).isdisjoint(test_paths), "an image appears in both dev and test"

print(f"Dev set: {len(dev_paths)}")
print(f"Test set: {len(test_paths)}")


Dev set: 17543
Test set: 3096


In [3]:
# Train/validation split for dev set

# Fraction of the dev set (not of the whole dataset) used for validation.
VAL_FRACTION = 0.2

# Same seed and stratification as the dev/test split, applied to the dev set only;
# the test set is not involved here.
train_paths, val_paths, train_labels, val_labels = train_test_split(
    dev_paths,
    dev_labels,
    test_size=VAL_FRACTION,
    stratify=dev_labels,
    random_state=SEED,
)

# Leakage guard: train and validation must partition the dev set, with paths
# and labels still aligned.
assert len(train_paths) + len(val_paths) == len(dev_paths)
assert len(train_paths) == len(train_labels)
assert len(val_paths) == len(val_labels)
assert set(train_paths).isdisjoint(val_paths), "an image appears in both train and validation"

print(f"Train: {len(train_paths)}")
print(f"Validation: {len(val_paths)}")


Train: 14034
Validation: 3509


In [4]:
from pathlib import Path

# Output folder for the saved split files. Missing parent folders are created
# as needed, and an already existing folder is reused without error.
split_dir = Path("../data/splits")
Path.mkdir(split_dir, exist_ok=True, parents=True)


In [5]:
def _as_posix_strings(paths):
    """Return a NumPy string array holding the POSIX-style form of every path in `paths`."""
    return np.array([Path(p).as_posix() for p in paths], dtype=str)


# Persist each split as a pair of parallel arrays (paths, labels).
# Paths are written as plain forward-slash strings rather than pathlib objects:
# an array of Path objects has dtype=object, which np.save can only store by
# pickling. Such a file needs allow_pickle=True to load and is tied to the
# OS-specific Path class of the machine that wrote it. String arrays load
# everywhere with the safe default settings; wrap an entry in Path(...) if
# pathlib methods are needed after loading.
np.save(split_dir / "train_paths.npy", _as_posix_strings(train_paths))
np.save(split_dir / "train_labels.npy", train_labels)
np.save(split_dir / "val_paths.npy", _as_posix_strings(val_paths))
np.save(split_dir / "val_labels.npy", val_labels)
np.save(split_dir / "test_paths.npy", _as_posix_strings(test_paths))
np.save(split_dir / "test_labels.npy", test_labels)

In [6]:
# Sanity check: reload the saved training labels and inspect them.
# The labels are a plain numeric array, so nothing is pickled and the default
# allow_pickle=False is sufficient (and avoids executing pickled content).
arr = np.load(split_dir / "train_labels.npy")

# The round trip through disk must not alter the labels.
assert np.array_equal(arr, train_labels), "saved train labels differ from the in-memory split"

print(type(arr))
print(arr.shape)
print(arr.dtype)
print(arr)


<class 'numpy.ndarray'>
(14034,)
int32
[7 2 7 ... 7 0 7]
