# Indian Medicinal Plants — YOLOv8 Classification (Ultralytics)
Train a YOLOv8 **classification** model on the Kaggle *Indian Medicinal Plant Image Dataset* with clear, step-by-step code and visuals.

**What you'll get:**
- Dataset download via Kaggle API (or use Google Drive/local path)
- Auto **train/val split** if missing
- Train **YOLOv8n-cls** with metrics
- Visualizations: class distribution, sample grids, training curves, **confusion matrix**, per-class metrics, predictions gallery

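> Tip: If your Colab runtime disconnects, re-run the notebook from the top. The paths are fixed, so every cell recreates the same folders. Files under `/content` (including the downloaded dataset) are lost when the runtime is reset, so the download and split steps will run again.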


## 0. Runtime & GPU
Make sure you're on **Runtime → Change runtime type → T4/A100 GPU** for faster training.

In [None]:

import sys, platform, os, pathlib, subprocess, math, random, shutil, glob, json
print("Python:", sys.version)
!nvidia-smi || echo "No GPU detected — training will be slower."


Python: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
/bin/bash: line 1: nvidia-smi: command not found
No GPU detected — training will be slower.


## 1. Install Dependencies

In [None]:

!pip -q install ultralytics==8.3.25 kaggle==1.6.17 ipywidgets==8.1.2
from IPython.display import display, Image as IPyImage, Markdown
import os, json, shutil, random, glob, math, pathlib
import matplotlib.pyplot as plt
import numpy as np
print("Installed.")


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/82.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m81.9/82.7 kB[0m [31m17.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.7/82.7 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m878.7/878.7 kB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.4/139.4 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m39.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for kaggle (setup.py) ... [?25l[?25hdone
Installed.


## 2. Get the Dataset
You have two options:

### Option A (Recommended): Download directly from Kaggle
1. Create a Kaggle API token from **https://www.kaggle.com/settings/account** → *Create New Token* (downloads `kaggle.json`).
2. Run the cell below and **upload** your `kaggle.json`.

### Option B: Skip Kaggle & use an existing path
If you've already uploaded/extracted the dataset to Drive or Colab, set `DATASET_ROOT` manually in the next subsection.

In [None]:

import os, json, pathlib, shutil
from google.colab import files

KAGGLE_DATASET_REF = "warcoder/indian-medicinal-plant-image-dataset"
DATA_DIR = pathlib.Path("/content/data_medicinal_plants")
KAGGLE_DIR = pathlib.Path("/root/.kaggle")
KAGGLE_JSON = KAGGLE_DIR / "kaggle.json"

DATA_DIR.mkdir(parents=True, exist_ok=True)
KAGGLE_DIR.mkdir(parents=True, exist_ok=True)

print("Upload kaggle.json (from Kaggle Account settings). If you prefer to skip, cancel and go to Option B cell.")
try:
    uploaded = files.upload()  # Pop-up file chooser
    for fn in uploaded.keys():
        if fn == "kaggle.json":
            with open(KAGGLE_JSON, "wb") as f:
                f.write(uploaded[fn])
            os.chmod(KAGGLE_JSON, 0o600)
            print("kaggle.json saved.")
except Exception as e:
    print("Upload skipped or failed:", e)

if KAGGLE_JSON.exists():
    !kaggle datasets download -d $KAGGLE_DATASET_REF -p $DATA_DIR --force
    # Try common archive names
    import zipfile, glob
    zips = glob.glob(str(DATA_DIR / "*.zip"))
    for z in zips:
        print("Extracting:", z)
        with zipfile.ZipFile(z, 'r') as zip_ref:
            zip_ref.extractall(DATA_DIR)
    print("Dataset prepared under:", DATA_DIR)
else:
    print("No kaggle.json found. Use Option B below to point to your data path.")


Upload kaggle.json (from Kaggle Account settings). If you prefer to skip, cancel and go to Option B cell.


### Option B: Use an existing dataset path
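If you already have the dataset extracted (e.g., in Drive — mount it first), point the cell below at it. The cell auto-detects `DATASET_ROOT` from the paths listed in `DEFAULT_CANDIDATES`, so add your folder there if it is not listed. The split step in the next section creates **train/val** folders if they are missing.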

In [None]:

# If you mounted Drive, your path might look like: /content/drive/MyDrive/medicinal_plants
# By default, we assume the Kaggle download extracted into DATA_DIR.
import pathlib, os, glob

# Set this to your dataset folder (e.g. on a mounted Drive) to skip
# auto-detection, e.g. "/content/drive/MyDrive/medicinal_plants".
DATASET_ROOT_OVERRIDE = None

DEFAULT_CANDIDATES = [
    "/content/data_medicinal_plants",
    "/content/data",
    "/content/drive/MyDrive/medicinal_plants",
    "/content/drive/MyDrive/datasets/indian_medicinal_plants"
]


def _dir_has_entries(path):
    """Return True if *path* is a directory containing at least one entry."""
    p = pathlib.Path(path)
    return p.is_dir() and any(p.iterdir())


if DATASET_ROOT_OVERRIDE:
    DATASET_ROOT = pathlib.Path(DATASET_ROOT_OVERRIDE)
else:
    # The Kaggle cell always creates /content/data_medicinal_plants, so mere
    # existence would make it win even when it is empty. Prefer a candidate
    # that has content, then any existing one, then the first default.
    cand = [p for p in DEFAULT_CANDIDATES if _dir_has_entries(p)]
    cand = cand or [p for p in DEFAULT_CANDIDATES if os.path.exists(p)]
    DATASET_ROOT = pathlib.Path(cand[0] if cand else DEFAULT_CANDIDATES[0])
print("Using DATASET_ROOT:", DATASET_ROOT)
os.makedirs(DATASET_ROOT, exist_ok=True)


## 3. Inspect Structure & Create Train/Val Split (if needed)
We expect a folder-of-class-folders layout (one subfolder per class). If `train/` and `val/` don't exist yet, we create them under `DATASET_ROOT` by copying each class's images into an 80/20 train/val split.

In [None]:
import os, shutil, random, pathlib, glob

random.seed(42)

# Image file extensions treated as dataset samples (compared lower-cased).
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
# Folders this cell creates under the dataset root; never treat them as source data.
_SPLIT_DIR_NAMES = {"train", "val"}


def is_class_dir_structure(root):
    """Return True if *root* has subfolders and at least one contains an image.

    Images are searched recursively inside each immediate subfolder of *root*.
    """
    classes = [p for p in pathlib.Path(root).glob("*") if p.is_dir()]
    return any(
        p.suffix.lower() in IMAGE_EXTS for c in classes for p in c.rglob("*")
    )


def _list_images(folder):
    """Return image files anywhere under *folder*, sorted so the split is reproducible."""
    return sorted(
        p for p in pathlib.Path(folder).rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )


def _flat_name(path, class_dir):
    """Destination file name: nested path parts joined with '_' so same-named files don't collide."""
    return "_".join(path.relative_to(class_dir).parts)


def _find_split_base(root):
    """Return the folder under *root* whose immediate subfolders are the classes."""
    base = root
    # A top-level folder from the archive that itself holds class folders.
    for p in sorted(root.iterdir()):
        if p.is_dir() and p.name not in _SPLIT_DIR_NAMES and is_class_dir_structure(p):
            base = p
            break
    # Descend through single-child wrapper folders (an archive folder that
    # only contains another folder holding the classes).
    while True:
        subdirs = [
            p for p in base.iterdir()
            if p.is_dir() and not (base == root and p.name in _SPLIT_DIR_NAMES)
        ]
        if len(subdirs) == 1 and is_class_dir_structure(subdirs[0]):
            base = subdirs[0]
        else:
            return base


def ensure_split(root, train_ratio=0.8):
    """Create ``root/train/<class>`` and ``root/val/<class>`` by copying images.

    If both ``train/`` and ``val/`` already exist, nothing is done. Otherwise the
    folder holding the class subfolders is located under *root*. Each class's
    images are sorted, shuffled with the module ``random`` state (seeded above)
    and copied so that about ``train_ratio`` of them go to ``train/`` and the rest
    to ``val/``. Every kept class gets at least one image in each split. Classes
    with fewer than two images are skipped, because an empty class folder can
    break folder-based dataset loaders.

    Args:
        root: Dataset root directory (``DATASET_ROOT``).
        train_ratio: Fraction of each class's images copied to ``train/``.

    Returns:
        ``(train_dir, val_dir)`` as ``pathlib.Path`` objects.

    Raises:
        RuntimeError: if *root* contains no extracted dataset folders or no
            class subfolders.
    """
    root = pathlib.Path(root)
    train_dir, val_dir = root / "train", root / "val"

    if train_dir.exists() and val_dir.exists():
        print("train/ and val/ already exist. Skipping split.")
        return train_dir, val_dir

    # Data may come from the Kaggle zip or be copied in manually (Option B),
    # so look for extracted folders rather than for a specific archive file.
    has_data = root.is_dir() and any(
        p.is_dir() and p.name not in _SPLIT_DIR_NAMES for p in root.iterdir()
    )
    if not has_data:
        print(f"No extracted dataset folders found under {root}. Please ensure the dataset is downloaded and extracted into the DATASET_ROOT directory.")
        print("You can do this by uploading your kaggle.json in the cell above or manually placing the extracted dataset in the DATASET_ROOT path.")
        raise RuntimeError("Dataset not found in the expected location. Please download or point to the correct directory.")

    base = _find_split_base(root)
    print("Base for split:", base)

    # When classes sit directly under root, a leftover train/ or val/ from an
    # interrupted run must not be mistaken for a class.
    excluded = _SPLIT_DIR_NAMES if base == root else set()
    classes = sorted(p for p in base.iterdir() if p.is_dir() and p.name not in excluded)
    if not classes:
        raise RuntimeError("No class subfolders found. Please set DATASET_ROOT to the folder containing class subfolders.")

    train_dir.mkdir(exist_ok=True)
    val_dir.mkdir(exist_ok=True)
    for c in classes:
        imgs = _list_images(c)
        if len(imgs) < 2:
            print(f"Class {c.name}: only {len(imgs)} image(s), skipping.")
            continue
        random.shuffle(imgs)
        # Clamp so both splits keep at least one image of this class.
        split = min(max(int(len(imgs) * train_ratio), 1), len(imgs) - 1)
        for subset, dest_root in ((imgs[:split], train_dir), (imgs[split:], val_dir)):
            dest = dest_root / c.name
            dest.mkdir(exist_ok=True, parents=True)
            for p in subset:
                shutil.copy2(p, dest / _flat_name(p, c))
        print(f"Class {c.name}: {split} train, {len(imgs) - split} val.")
    return train_dir, val_dir

# Build (or reuse) the train/val split under DATASET_ROOT, then report both paths.
train_dir, val_dir = ensure_split(DATASET_ROOT)
for _label, _path in (("Train:", train_dir), ("Val:", val_dir)):
    print(_label, _path)

## 4. Visualize Class Distribution

In [None]:

import os, pathlib, collections
import matplotlib.pyplot as plt

def count_images(folder):
    """Count image files per class subfolder of *folder*.

    Each immediate subfolder is treated as a class and searched recursively.
    Entries are counted when their lower-cased extension is a known image
    type. Returns a dict of class name -> count, with keys in sorted folder order.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
    class_dirs = sorted(d for d in pathlib.Path(folder).glob("*") if d.is_dir())
    return {
        d.name: sum(f.suffix.lower() in image_suffixes for f in d.rglob("*"))
        for d in class_dirs
    }

train_counts = count_images(train_dir)
val_counts = count_images(val_dir)

# One bar chart per split (previously two copy-pasted blocks). The figure
# width grows with the number of classes so the rotated class-name tick
# labels do not overlap; 6.4 x 4.8 in is Matplotlib's default figure size.
for split_counts, split_title in (
    (train_counts, "Training set class distribution"),
    (val_counts, "Validation set class distribution"),
):
    n_classes = len(split_counts)
    plt.figure(figsize=(max(6.4, 0.3 * n_classes), 4.8))
    plt.bar(range(n_classes), list(split_counts.values()))
    plt.xticks(range(n_classes), list(split_counts.keys()), rotation=90)
    plt.ylabel("Images")
    plt.title(split_title)
    plt.tight_layout()
    plt.show()


## 5. Peek at Sample Images

In [None]:

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random, pathlib

def sample_images(folder, n=16):
    """Return up to *n* randomly chosen image paths found anywhere under *folder*.

    Uses the global ``random`` state, so results follow any earlier ``random.seed``.
    """
    suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
    pool = []
    for path in pathlib.Path(folder).rglob("*"):
        if path.suffix.lower() in suffixes:
            pool.append(path)
    k = n if n < len(pool) else len(pool)
    return random.sample(pool, k)

# Grid of random training images, titled with their class (parent folder) name.
samples = sample_images(train_dir, n=16)
cols = 4
rows = max(1, -(-len(samples) // cols))  # ceiling division, at least one row
plt.figure(figsize=(12, 3 * rows))
for idx, img_path in enumerate(samples):
    plt.subplot(rows, cols, idx + 1)
    plt.imshow(Image.open(img_path).convert("RGB"))
    plt.title(img_path.parent.name)
    plt.axis("off")
plt.tight_layout()
plt.show()


## 6. Train YOLOv8n-cls
We use Ultralytics' YOLOv8 **classification** model. You can adjust `epochs` and `imgsz` as needed.

In [None]:
import torch
from ultralytics import YOLO

# Load the YOLOv8n classification weights file.
model = YOLO("yolov8n-cls.pt")

# DATASET_ROOT holds the train/ and val/ folders produced by ensure_split().
train_device = 0 if torch.cuda.is_available() else 'cpu'  # first GPU, else CPU
results = model.train(
    data=str(DATASET_ROOT),
    epochs=15,
    imgsz=224,
    batch=32,
    device=train_device,
)


## 7. Training Curves & Metrics
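Ultralytics saves plots in the training run folder: `runs/classify/train/` for the first run, then `train2/`, `train3/`, … for later runs. We display the most useful ones from the run that just finished.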

In [None]:

import glob, os
from IPython.display import Image as IPyImage, display

# Prefer the folder of the training run that just finished. Sorting
# runs/classify/* by modification time alone can pick a later val/ or
# predict/ output folder, which has no results.png or weights/.
_trainer = getattr(globals().get("model"), "trainer", None)
_save_dir = getattr(_trainer, "save_dir", None)
if _save_dir is not None and os.path.isdir(_save_dir):
    latest_run = str(_save_dir)
else:
    # Fallback (e.g. after a kernel restart): the newest run folder that
    # contains weights/, i.e. a training run.
    run_dirs = sorted(
        (d for d in glob.glob("runs/classify/*") if os.path.isdir(os.path.join(d, "weights"))),
        key=os.path.getmtime,
    )
    if not run_dirs:
        raise RuntimeError("No training run found under runs/classify/. Run the training cell first.")
    latest_run = run_dirs[-1]
print("Latest run:", latest_run)

plot_files = [
    "results.png",                      # loss/accuracy curves
    "confusion_matrix.png",             # confusion matrix (counts)
    "confusion_matrix_normalized.png",  # confusion matrix (normalized)
]
for pf in plot_files:
    p = os.path.join(latest_run, pf)
    if os.path.exists(p):
        display(IPyImage(filename=p))
    else:
        print("Missing plot:", pf)


## 8. Validate & Show Per-Class Metrics

In [None]:

import os
import numpy as np
from IPython.display import display

# Re-evaluate the trained model on the val split.
metrics = model.val(split='val', imgsz=224, device=0 if torch.cuda.is_available() else 'cpu', save_json=True)
print(metrics)

# Per-class table. Ultralytics does not write a confusion_matrix.npy file,
# so read the in-memory confusion matrix that the classification validator
# attaches to the returned metrics object.
try:
    from pandas import DataFrame

    cm_obj = getattr(metrics, "confusion_matrix", None)
    if cm_obj is None:
        print("No confusion matrix available to compute per-class accuracy.")
    else:
        cm = np.asarray(cm_obj.matrix, dtype=float)
        # Ultralytics fills the classification matrix as matrix[pred][true]
        # (see ConfusionMatrix.process_cls_preds): rows are predicted classes,
        # columns are ground-truth classes.
        diag = cm.diagonal()
        support = cm.sum(axis=0)    # ground-truth images per class
        predicted = cm.sum(axis=1)  # predictions per class
        # Classes with no samples get NaN instead of a division-by-zero warning.
        per_class_acc = np.divide(diag, support, out=np.full_like(diag, np.nan), where=support > 0)
        per_class_prec = np.divide(diag, predicted, out=np.full_like(diag, np.nan), where=predicted > 0)

        # Use the model's own index->name mapping so names line up with matrix indices.
        raw_names = getattr(model, "names", None) or {}
        if isinstance(raw_names, (list, tuple)):
            raw_names = dict(enumerate(raw_names))
        names = [raw_names.get(i, str(i)) for i in range(len(diag))]

        df = DataFrame({
            "class": names,
            "accuracy": per_class_acc,
            "precision": per_class_prec,
            "support": support.astype(int),
        })
        # Weakest classes first; show every class rather than only the first 20.
        display(df.sort_values("accuracy", na_position="first").reset_index(drop=True))
except Exception as e:
    print("Per-class metrics table not available:", e)


## 9. Inference & Prediction Gallery

In [None]:

# Collect a small sample from val set for inference preview
import random, pathlib
import matplotlib.pyplot as plt
from PIL import Image

_gallery_exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
sample_val = []
for cls in sorted([p for p in pathlib.Path(val_dir).glob("*") if p.is_dir()]):
    # Only image files: a stray non-image file would make PIL/predict fail.
    candidates = [p for p in cls.glob("*") if p.is_file() and p.suffix.lower() in _gallery_exts]
    if candidates:
        sample_val.extend(random.sample(candidates, min(3, len(candidates))))
# Shuffle before truncating; otherwise the 24 images would all come from
# the first (alphabetically sorted) classes.
random.shuffle(sample_val)
sample_val = sample_val[:24]

if not sample_val:
    print("No validation images found for the prediction gallery.")
else:
    preds = model.predict(source=[str(p) for p in sample_val], imgsz=224, device=0 if torch.cuda.is_available() else 'cpu')
    # preds is a list of Results aligned with sample_val; for classification
    # each result carries .probs and the model's .names mapping.
    cols = 4
    rows = max(1, (len(sample_val) + cols - 1) // cols)
    plt.figure(figsize=(12, 3 * rows))
    for i, (p, r) in enumerate(zip(sample_val, preds), 1):
        plt.subplot(rows, cols, i)
        plt.imshow(Image.open(p).convert("RGB"))
        top_idx = int(r.probs.top1)
        top_name = r.names[top_idx]
        top_prob = float(r.probs.top1conf)
        # Green title for a correct prediction (class == parent folder), red otherwise.
        correct = top_name == p.parent.name
        plt.title(f"GT: {p.parent.name}\nPred: {top_name} ({top_prob:.2f})",
                  color="green" if correct else "red")
        plt.axis("off")
    plt.tight_layout()
    plt.show()


## 10. Export Best Model
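We print the path of the best checkpoint (`weights/best.pt`) from the training run. Nothing is converted here; use `model.export(...)` if you need another format such as ONNX.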

In [None]:

# Location of the best checkpoint inside the training run folder.
best_weights = os.path.join(latest_run, "weights", "best.pt")
if os.path.exists(best_weights):
    print("Best weights saved at:", best_weights)
else:
    # e.g. latest_run points at an interrupted run or a non-training folder.
    print("best.pt not found at:", best_weights)


## (Optional) Save to Google Drive

In [None]:

# from google.colab import drive
# drive.mount('/content/drive')
# !mkdir -p /content/drive/MyDrive/medicinal_plants_yolov8_runs
# !cp -r "$latest_run" "/content/drive/MyDrive/medicinal_plants_yolov8_runs/"
# print("Copied run folder to Drive.")
