In [1]:
from ultralytics import YOLO
from pathlib import Path
import os
import shutil

# Paths
DATA_PATH = Path("../data/images_augmented")   # dataset root; passed to training as `data`, and its `test/<class>` folders are sampled for example predictions
PROJECT_NAME = "greenery_classification"       # name of the training run
RUNS_PATH = Path("../runs")                    # output directory

# Training parameters
MODEL_SIZE = "n"      # 'n', 's', 'm', 'l', 'x' (nano → xlarge)
EPOCHS = 50           # passed to training as `epochs`
BATCH_SIZE = 16       # passed to training as `batch`
IMAGE_SIZE = 128      # passed to training as `imgsz`
PATIENCE = 20         # early stopping patience

# Class folder names looked up under DATA_PATH / "test"; also used to label
# predicted class indices in the example predictions.
CLASS_NAMES = ["greenery", "non_greenery"]

# TRAINING

def train_model():
    """Train the YOLO classification model and locate its best checkpoint.

    The pretrained ``yolov8{MODEL_SIZE}-cls.pt`` weights are cached under
    ``./weights``. When the file is missing it is downloaded first; the
    download goes to a temporary ``.part`` file that is renamed only after
    it completes, so an interrupted download is never mistaken for valid
    cached weights on the next run.

    Returns:
        tuple: ``(model, best_model_path)`` where ``model`` is the trained
        ``YOLO`` object and ``best_model_path`` is the path of the best
        checkpoint written by the run.
    """
    print("\n" + "=" * 70)
    print(" STARTING TRAINING")
    print("=" * 70)

    model_path = f"yolov8{MODEL_SIZE}-cls.pt"

    # Check pretrained weights
    weights_dir = Path("./weights")
    weights_dir.mkdir(exist_ok=True)
    weight_file = weights_dir / model_path

    if not weight_file.exists():
        import urllib.request

        url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_path}"
        print(f"⬇ Downloading pretrained weights: {url}")
        partial_file = weight_file.with_name(weight_file.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial_file)
            # Publish the file under its final name only once it is complete.
            partial_file.replace(weight_file)
        finally:
            # Remove leftovers of a failed download; no-op after a successful rename.
            partial_file.unlink(missing_ok=True)
        print(" Download complete.")

    # Load model
    model = YOLO(str(weight_file))

    # Train
    model.train(
        data=str(DATA_PATH),
        epochs=EPOCHS,
        batch=BATCH_SIZE,
        imgsz=IMAGE_SIZE,
        patience=PATIENCE,
        project=str(RUNS_PATH),
        name=PROJECT_NAME,
        exist_ok=True,
        plots=True,
    )

    # Expected location given the project/name passed above.
    best_model_path = RUNS_PATH / PROJECT_NAME / "weights" / "best.pt"

    # NOTE(review): the trainer is expected to expose the checkpoint it actually
    # wrote as `model.trainer.best`; prefer it when present so the returned path
    # stays correct even if the library resolves the run directory differently.
    trainer = getattr(model, "trainer", None)
    reported_best = getattr(trainer, "best", None)
    if reported_best is not None and Path(reported_best).exists():
        best_model_path = Path(reported_best)

    print(f"\n Best model saved at: {best_model_path}")

    return model, best_model_path

# VALIDATION + TEST

def evaluate_model(model_path):
    """Evaluate a saved checkpoint on the validation and test splits.

    Args:
        model_path: Path of the weights file to load with ``YOLO``.

    Returns:
        tuple: ``(val_metrics, test_metrics)``, the objects returned by
        ``model.val`` for ``split="val"`` and ``split="test"``.
    """
    model = YOLO(str(model_path))

    # NOTE(review): no `data=` is passed here; presumably the dataset recorded
    # in the checkpoint is reused — confirm if the working directory can change.

    # Validation
    val_metrics = model.val(split="val")
    print(f"\n📊 Validation Top-1 Accuracy: {val_metrics.top1:.2%}")

    # Test
    test_metrics = model.val(split="test")
    print(f"\n📊 Test Top-1 Accuracy: {test_metrics.top1:.2%}")

    return val_metrics, test_metrics

# EXAMPLE PREDICTIONS

def predict_examples(model_path, num_samples=10):
    """Print predictions for a few test images of every class.

    For each entry of ``CLASS_NAMES`` the first
    ``num_samples // len(CLASS_NAMES)`` images (in sorted filename order) of
    ``DATA_PATH / "test" / <class>`` are classified, and one line per image
    is printed with the true class (the image's folder name), the predicted
    class and the top-1 confidence.

    Args:
        model_path: Path of the weights file to load with ``YOLO``.
        num_samples: Total number of images to show, split evenly between
            the classes.

    Returns:
        None. Prints a notice and returns early when no test image is found.
    """
    model = YOLO(str(model_path))

    # Match suffixes case-insensitively so ".JPG" / ".jpeg" / ".png" files are not skipped.
    image_suffixes = {".jpg", ".jpeg", ".png"}
    per_class = num_samples // len(CLASS_NAMES)

    test_images = []
    for cls in CLASS_NAMES:
        cls_dir = DATA_PATH / "test" / cls
        # Sorted so that the chosen samples are the same on every run.
        candidates = sorted(
            path
            for path in cls_dir.glob("*")
            if path.is_file() and path.suffix.lower() in image_suffixes
        )
        test_images.extend(candidates[:per_class])

    if not test_images:
        print(" No test images found!")
        return

    print(f"\n Predictions for {len(test_images)} sample images:\n")

    for img_path in test_images:
        result = model.predict(img_path, verbose=False)[0]
        probs = result.probs

        top1_idx = probs.top1
        top1_conf = probs.top1conf.item()

        # Use the model's own index -> name mapping rather than the order of
        # CLASS_NAMES, which is only correct if it matches the training order.
        predicted_class = result.names[top1_idx]
        true_class = img_path.parent.name

        status = "✅" if predicted_class == true_class else "❌"

        print(
            f"{status} {img_path.name:<30} "
            f"True: {true_class:<13} | Pred: {predicted_class:<13} "
            f"({top1_conf:.2%})"
        )


# MAIN

def main():
    """Run the whole pipeline: train, evaluate, then print sample predictions.

    Stops right after the banner when ``DATA_PATH`` does not exist.
    """
    print(" YOLO GREENERY CLASSIFICATION PIPELINE")

    # Guard clause: nothing can be done without the dataset directory.
    dataset_available = DATA_PATH.exists()
    if not dataset_available:
        print(f" Dataset not found: {DATA_PATH}")
        return

    # Only the checkpoint path is needed by the later stages.
    _, checkpoint = train_model()

    # Report validation/test accuracy, then show per-image predictions.
    evaluate_model(checkpoint)
    predict_examples(checkpoint, num_samples=20)

    output_dir = RUNS_PATH / PROJECT_NAME
    print("\n PIPELINE COMPLETE!")
    print(f"All results saved in: {output_dir}")


# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()


 YOLO GREENERY CLASSIFICATION PIPELINE

 STARTING TRAINING
⬇ Downloading pretrained weights: https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt
 Download complete.
Ultralytics 8.3.228  Python-3.11.3 torch-2.9.1+cpu CPU (AMD Ryzen 5 2600 Six-Core Processor)
[34m[1mengine\trainer: [0magnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=16, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=0.5, compile=False, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=..\data\images_augmented, degrees=0.0, deterministic=True, device=cpu, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=50, erasing=0.4, exist_ok=True, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=128, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.01, lrf=0.01, mask_ratio=4, max_det=300, mixup=0.0, mode=train