In [None]:
# train_yolov8_binary_cls.py
# ------------------------------------------------------------
# Train a YOLOv8 model (classification task) for a 2-class problem.
# Uses the Ultralytics 'yolov8n-cls' checkpoint and fine-tunes on your data.
# ------------------------------------------------------------

import os
from ultralytics import YOLO

def main():
    """Fine-tune the pretrained ``yolov8n-cls.pt`` checkpoint on a 2-class dataset.

    Expected layout, relative to the working directory (one subfolder per class):

        data/
          train/<class_name>/...
          val/<class_name>/...
          test/<class_name>/...   (optional)

    The function trains, evaluates on the val split, evaluates on the test
    split when ``data/test`` exists, and optionally runs a single-image demo.

    Raises:
        FileNotFoundError: If the train or val folder does not exist.
    """
    # ----- 1) Paths -----
    # Replace these with your dataset folders
    data_root = "data"        # dataset root handed to Ultralytics (holds train/, val/, test/)
    train_dir = "data/train"  # folder with subfolders per class
    val_dir   = "data/val"    # folder with subfolders per class (for validation)
    test_dir  = "data/test"   # optional test set (for final evaluation)

    # Basic sanity checks. Raised explicitly instead of using `assert`, because
    # assertions are stripped when Python runs with -O.
    if not os.path.isdir(train_dir):
        raise FileNotFoundError(f"Train dir not found: {train_dir}")
    if not os.path.isdir(val_dir):
        raise FileNotFoundError(f"Val dir not found: {val_dir}")

    # ----- 2) Create / Load model -----
    # 'yolov8n-cls.pt' is a lightweight classification checkpoint
    # Other options: yolov8s-cls.pt, yolov8m-cls.pt (bigger, slower, potentially higher accuracy)
    model = YOLO("yolov8n-cls.pt")  # load a pretrained classification model

    # ----- 3) Training hyperparameters -----
    # Key args:
    # - data: dataset ROOT folder; Ultralytics looks for train/ and val/ inside it
    # - val: boolean flag, run validation during training (not a path)
    # - epochs: number of training epochs
    # - imgsz: images are resized to this square size
    # - batch: batch size (tune according to GPU memory)
    # - lr0: initial learning rate
    # - patience: early stopping patience (in epochs)
    # - optimizer: "SGD" or "AdamW" etc.
    # - dropout: dropout for classifier head (0.0~0.5 typical)
    # - workers: dataloader workers
    results = model.train(
        data=data_root,
        val=True,
        epochs=50,
        imgsz=224,
        batch=32,
        lr0=1e-3,
        patience=10,
        optimizer="AdamW",
        dropout=0.0,
        workers=4,
        verbose=True
    )

    # Training artifacts (best model, last model, logs) are saved under runs/classify/trainX/
    # Fall back to the trainer's run directory if the returned metrics object
    # does not expose `save_dir`.
    save_dir = getattr(results, "save_dir", None) or model.trainer.save_dir
    print("Training finished. Results saved to:", save_dir)

    # ----- 4) Validation / metrics on the validation set -----
    # Evaluates the weights held by `model` after training; `split` selects
    # which subfolder of the dataset root is used.
    metrics = model.val(data=data_root, imgsz=224, split="val")
    print("Validation metrics:", metrics)

    # ----- 5) Optional: Evaluate on a separate test set -----
    if os.path.isdir(test_dir):
        test_metrics = model.val(data=data_root, imgsz=224, split="test")
        print("Test metrics:", test_metrics)

    # ----- 6) Inference demo (single image) -----
    # Provide a path to an image to get the predicted class and probability
    demo_image = None  # e.g., "data/test/class0/sample.jpg"
    if demo_image and os.path.isfile(demo_image):
        preds = model.predict(source=demo_image, imgsz=224)
        # preds[0].probs contains per-class probabilities
        print("Prediction:", preds[0].probs.top1, preds[0].probs.top1conf, preds[0].names)

if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly,
    # not when it is imported as a module.
    # Install: pip install ultralytics
    # Then run: python train_yolov8_binary_cls.py
    main()


In [None]:
# inference_yolov8_cls.py
# ------------------------------------------------------------
# Inference for YOLOv8 classification (binary or multi-class).
# Works with a trained weights file (e.g., runs/classify/train*/weights/best.pt).
# Supports: single image, directory (recursive), or a text file of image paths.
# ------------------------------------------------------------

import os
import csv
import glob
import argparse
from typing import List

import torch
from ultralytics import YOLO


def list_images(path: str) -> List[str]:
    """Collect image paths from a single image file, a directory, or a .txt list.

    Args:
        path: An image file, a directory to search recursively, or a ``.txt``
            file containing one image path per line.

    Returns:
        A list of image paths. Directory results are sorted. ``.txt`` results
        keep the order of the file, with blank lines and entries that are not
        existing files dropped.

    Raises:
        ValueError: If ``path`` is a file that is neither ``.txt`` nor one of
            the supported image extensions.
        FileNotFoundError: If ``path`` is neither a file nor a directory.
    """
    exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

    if os.path.isfile(path):
        lowered = path.lower()
        if lowered.endswith(".txt"):
            with open(path, "r", encoding="utf-8") as f:
                listed = [ln.strip() for ln in f if ln.strip()]
            return [p for p in listed if os.path.isfile(p)]
        if lowered.endswith(exts):
            return [path]
        raise ValueError(f"Unsupported file type: {path}")

    if os.path.isdir(path):
        found = []
        # Walk the tree instead of globbing "*.<ext>": the extension check is
        # then case-insensitive (matching the single-file branch above), and
        # glob metacharacters in `path` cannot break the search.
        for root, dirnames, filenames in os.walk(path):
            # Prune hidden directories in place so os.walk does not descend into them.
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]
            found.extend(
                os.path.join(root, name)
                for name in filenames
                if not name.startswith(".") and name.lower().endswith(exts)
            )
        return sorted(found)

    raise FileNotFoundError(f"Path not found: {path}")


def _csv_fieldnames(topk: int) -> List[str]:
    """Return the CSV header for ``topk`` ranked predictions, each column listed once."""
    columns = ["image", "top1_class", "top1_conf"]
    for rank in range(2, topk + 1):
        columns.extend((f"top{rank}_class", f"top{rank}_conf"))
    return columns


def run_inference(
    weights: str,
    source: str,
    imgsz: int = 224,
    device: str = "",
    save_csv: str = "predictions.csv",
    topk: int = 5,
    conf_threshold: float = 0.0
):
    """
    Run YOLOv8 classification inference and write the results to a CSV file.

    One row is written per image that yields probabilities: the image path,
    the top-1 class and confidence, and the class/confidence for every rank
    up to ``topk``.

    Args:
        weights: Path to trained weights (e.g., best.pt or yolov8n-cls.pt).
        source: Image path, directory, or .txt listing image paths.
        imgsz: Inference image size (square).
        device: '', 'cpu', or 'cuda:0' etc. If empty, CUDA is used when
            available, otherwise the CPU.
        save_csv: Output CSV path for results. Missing parent folders are created.
        topk: How many top classes to save. Clamped to the range
            [0, number of classes].
        conf_threshold: Minimum top-1 confidence (0~1) for the console line to
            show the predicted class. It only changes the printed message;
            every image is still written to the CSV.

    Returns:
        None. Returns early, without writing a CSV, when no images are found.
    """
    # Load model
    model = YOLO(weights)

    # Resolve the device once and hand it to predict(); `device` is a documented
    # inference argument, so the choice does not depend on how Model.to() behaves.
    run_device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")

    # Resolve class names
    names = model.names  # dict like {0: 'class0', 1: 'class1', ...}
    num_classes = len(names)

    # Clamp topk on both sides: torch.topk rejects a negative k or k > num_classes
    topk = max(0, min(topk, num_classes))

    # Collect images
    images = list_images(source)
    if not images:
        print("No images found for inference.")
        return

    print(f"Running inference on {len(images)} image(s)...")

    rows = []
    for img_path in images:
        preds = model.predict(
            source=img_path,
            imgsz=imgsz,
            device=run_device,
            verbose=False
        )

        # Each item in preds corresponds to one image (we pass one image per call)
        p = preds[0]
        probs = p.probs  # classification probabilities
        if probs is None:
            print(f"[WARN] No probabilities for: {img_path}")
            continue

        # Top-1
        top1_idx = int(probs.top1)
        top1_conf = float(probs.top1conf)
        top1_name = names.get(top1_idx, str(top1_idx))

        # Optional threshold filter (console output only)
        if top1_conf < conf_threshold:
            print(f"{os.path.basename(img_path)} -> below threshold ({top1_conf:.3f} < {conf_threshold})")
        else:
            print(f"{os.path.basename(img_path)} -> Top-1: {top1_name} ({top1_conf:.4f})")

        # Top-k
        # probs.data is a tensor of shape [num_classes]
        confs, idxs = torch.topk(probs.data, k=topk)
        topk_pairs = [
            (names.get(i, str(i)), float(c))
            for i, c in zip(idxs.tolist(), confs.tolist())
        ]

        # Build CSV row
        row = {
            "image": img_path,
            "top1_class": top1_name,
            "top1_conf": top1_conf
        }
        # Add top-k columns
        for rank, (cls_name, sc) in enumerate(topk_pairs, start=1):
            row[f"top{rank}_class"] = cls_name
            row[f"top{rank}_conf"] = sc
        rows.append(row)

    # Save CSV
    out_dir = os.path.dirname(save_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(save_csv, "w", newline="", encoding="utf-8") as f:
        # The header lists the top-1 columns once; the rank columns start at 2.
        writer = csv.DictWriter(f, fieldnames=_csv_fieldnames(topk))
        writer.writeheader()
        writer.writerows(rows)

    print(f"Saved predictions to: {save_csv}")


def parse_args():
    """Parse the inference script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``weights``, ``source``, ``imgsz``, ``device``,
        ``save_csv``, ``topk`` and ``conf_threshold``. ``--weights`` and
        ``--source`` are required; the rest have defaults.
    """
    parser = argparse.ArgumentParser(description="YOLOv8 classification inference")

    # (flag, add_argument keyword options), registered in display order.
    option_specs = (
        ("--weights", {
            "type": str,
            "required": True,
            "help": "Path to weights (e.g., runs/classify/train*/weights/best.pt)",
        }),
        ("--source", {
            "type": str,
            "required": True,
            "help": "Image file, directory, or .txt with image paths",
        }),
        ("--imgsz", {"type": int, "default": 224, "help": "Inference image size"}),
        ("--device", {"type": str, "default": "", "help": "Device: '', 'cpu', 'cuda:0', etc."}),
        ("--save-csv", {"type": str, "default": "predictions.csv", "help": "Output CSV path"}),
        ("--topk", {"type": int, "default": 5, "help": "Save top-k classes"}),
        ("--conf-threshold", {
            "type": float,
            "default": 0.0,
            "help": "Minimum top-1 confidence to report",
        }),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


if __name__ == "__main__":
    # Command-line entry point.
    # Setup:
    #   pip install ultralytics
    # Example:
    #   python inference_yolov8_cls.py --weights runs/classify/train/weights/best.pt --source data/val
    cli = parse_args()
    # Forward each parsed option to the run_inference keyword of the same name.
    option_names = (
        "weights",
        "source",
        "imgsz",
        "device",
        "save_csv",
        "topk",
        "conf_threshold",
    )
    run_inference(**{name: getattr(cli, name) for name in option_names})
