# Train Plant Disease Detector (Random Forest / SVM)

This notebook walks through loading the PlantVillage dataset, extracting simple features (color histograms + HOG), training a RandomForest or SVM classifier, and saving the trained model to `model.pkl`. The code cell below trains the RandomForest variant.


In [ ]:
import os
from glob import glob
from PIL import Image
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import joblib
from model_utils import preprocess_image_for_ml

# Root folder of the extracted PlantVillage dataset -- edit to match your setup.
DATA_DIR = 'sample_data/PlantVillage'

# Each immediate sub-directory of DATA_DIR is treated as one class label;
# plain files sitting next to the class folders are ignored.
classes = sorted(
    entry
    for entry in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, entry))
)
# Show only the first few labels so the output stays short.
print('Found classes:', classes[:10])

# Example: load a small subset (change as needed).
# For every class folder, take up to MAX_PER_CLASS JPEG files and convert each
# one into a feature vector with preprocess_image_for_ml. X collects the
# feature vectors and y the matching class-folder names.
MAX_PER_CLASS = 200  # limit per class for faster runs
X = []
y = []
for cls in classes:
    # PlantVillage distributions commonly name files '.JPG'; a plain '*.jpg'
    # pattern is matched case-sensitively on Linux/macOS and would silently
    # skip them, so match the extension in any letter case.
    # glob() returns paths in arbitrary order: sort before slicing so the
    # selected subset is the same on every run and machine.
    imgs = sorted(glob(os.path.join(DATA_DIR, cls, '*.[jJ][pP][gG]')))[:MAX_PER_CLASS]
    for p in imgs:
        # The context manager closes the underlying file once the features
        # are extracted, instead of leaving one open handle per image.
        with Image.open(p) as raw:
            img = raw.convert('RGB')
            feat = preprocess_image_for_ml(img)
        X.append(feat)
        y.append(cls)

# Fail early with a clear message rather than with an obscure error later
# when the (empty) data is split.
if not X:
    raise ValueError(f'No .jpg images found under {DATA_DIR!r}; check DATA_DIR')

X = np.array(X)
y = np.array(y)
print('X shape', X.shape)

# Hold out 20% of the samples for evaluation. Stratifying on y keeps the class
# proportions the same in both splits; the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)

# Random forest with 200 trees, built using all available CPU cores (n_jobs=-1).
# fit() returns the estimator itself, so construction and training are chained.
clf = RandomForestClassifier(
    n_estimators=200,
    random_state=42,
    n_jobs=-1,
).fit(X_train, y_train)

# Per-class precision / recall / F1 on the held-out split.
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred))

# Persist the fitted classifier to the current working directory.
joblib.dump(clf, 'model.pkl')
print('Model saved to model.pkl')
