# 🍠 Potato Disease Classifier - Model Training

這份 Notebook 負責從資料切分到訓練模型，並產生模型權重 (best.pt)

In [None]:
# Install required packages (ultralytics provides the YOLO class imported in the training cell)
!pip install -q ultralytics

In [None]:
# Mount Google Drive at /content/drive so the raw dataset under MyDrive is readable
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Split the raw images into train/val/test folders (80% / 10% / 10% per class)
import pathlib
import random
import shutil

RAW_PATH = pathlib.Path('/content/drive/MyDrive/training')
DST_PATH = pathlib.Path('/content/datasets/potato')


def split_dataset(raw_path, dst_path, seed=42):
    """Copy every ``Potato_*`` class folder of *raw_path* into split folders.

    The output layout is ``dst_path/<split>/<class folder name>/`` with
    ``<split>`` being ``train``, ``val`` or ``test``.

    Args:
        raw_path: ``pathlib.Path`` of the directory holding the class folders.
        dst_path: ``pathlib.Path`` of the directory receiving the splits.
        seed: Seed applied to the global ``random`` module before shuffling.

    Returns:
        dict mapping each non-empty class folder name to
        ``{'train': n, 'val': n, 'test': n}`` (number of files copied).
    """
    random.seed(seed)
    counts = {}
    # Directory listings come back in arbitrary order; sorting both the class
    # folders and their files is what makes the seeded shuffle reproducible.
    for cls_dir in sorted(raw_path.glob('Potato_*')):
        if not cls_dir.is_dir():
            # A stray file whose name matches the pattern is not a class.
            continue
        # Keep regular files only: shutil.copy() raises on sub-directories.
        imgs = sorted(p for p in cls_dir.glob('*') if p.is_file())
        if not imgs:
            print(f'[WARN] {cls_dir} is empty!')
            continue
        random.shuffle(imgs)
        n = len(imgs)
        train_end, val_end = int(0.8 * n), int(0.9 * n)
        splits = {'train': imgs[:train_end],
                  'val': imgs[train_end:val_end],
                  'test': imgs[val_end:]}
        counts[cls_dir.name] = {}
        for split, files in splits.items():
            out_dir = dst_path / split / cls_dir.name
            # Drop copies left by a previous run so an image can never sit in
            # two splits at once (which would leak test data into training).
            if out_dir.exists():
                shutil.rmtree(out_dir)
            out_dir.mkdir(parents=True)
            for f in files:
                shutil.copy(f, out_dir)
            counts[cls_dir.name][split] = len(files)
    return counts


split_counts = split_dataset(RAW_PATH, DST_PATH)
if not split_counts:
    print(f'[WARN] no non-empty Potato_* class folders found in {RAW_PATH}')
print('✅ Dataset prepared at', DST_PATH)

In [None]:
# Write potato.yaml: dataset root, split sub-folders and the class-index names
yaml_lines = [
    "",  # the file starts with an empty line
    "path: /content/datasets/potato",
    "train: train",
    "val: val",
    "test: test",
    "names:",
    "  0: Early_Blight",
    "  1: Late_Blight",
    "  2: Healthy",
    "",  # ...and ends with a trailing newline
]
yaml_text = "\n".join(yaml_lines)
with open('/content/datasets/potato.yaml', 'w') as fh:
    fh.write(yaml_text)
print("✅ YAML saved")

In [None]:
# Train the classification model
from ultralytics import YOLO

# Training options, forwarded unchanged to model.train().
# NOTE(review): with Ultralytics' default optimizer='auto' the lr0 value may
# be overridden by the library -- confirm in the training log.
train_config = {
    'data': '/content/datasets/potato',  # dataset directory containing the split folders
    'epochs': 50,
    'imgsz': 224,
    'batch': 32,
    'lr0': 1e-3,
    'patience': 10,
    'project': 'plant_cls',
    'name': 'potato_yolov8n_v2',
}

# Start from the yolov8n-cls.pt checkpoint.
model = YOLO('yolov8n-cls.pt')
results = model.train(**train_config)

In [None]:
# Evaluate the trained model and produce the confusion matrix
val_options = dict(save_json=True, plots=True)
metrics = model.val(**val_options)
print(f"Top-1 Accuracy : {metrics.top1:.4f}")

In [None]:
# Compute per-class Precision / Recall / F1 from the validation confusion matrix
import numpy as np
import pandas as pd

cm = np.asarray(metrics.confusion_matrix.matrix, dtype=float)

# NOTE(review): the label order must follow the model's class indices, which
# presumably come from the class folder names -- confirm against model.names.
labels = ['Early_Blight', 'Late_Blight', 'Healthy']
if cm.shape != (len(labels), len(labels)):
    raise ValueError(
        f'confusion matrix has shape {cm.shape}, expected '
        f'{(len(labels), len(labels))} for labels {labels}'
    )

# Ultralytics fills the matrix as matrix[predicted, true]: each ROW holds the
# samples predicted as that class, each COLUMN the samples that truly belong
# to it. Precision therefore divides by row totals, recall by column totals.
tp = np.diag(cm)
predicted_totals = cm.sum(axis=1)
true_totals = cm.sum(axis=0)

# A class that was never predicted (or has no samples) scores 0 instead of
# producing NaN and a divide-by-zero warning.
precision = np.divide(tp, predicted_totals,
                      out=np.zeros_like(tp), where=predicted_totals > 0)
recall = np.divide(tp, true_totals,
                   out=np.zeros_like(tp), where=true_totals > 0)
f1 = 2 * precision * recall / (precision + recall + 1e-9)

df = pd.DataFrame({'Precision': precision.round(3), 'Recall': recall.round(3), 'F1-Score': f1.round(3)}, index=labels)
df.to_csv('/content/classify_metrics.csv')
df