# YOLO Image Classification Prediction

This notebook classifies images using trained YOLO classification weights. Set the path to your trained weights in the configuration cell, then run predictions on a single image or on a batch of images in a folder.

In [17]:
# Import required libraries
from ultralytics import YOLO
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import glob
from pathlib import Path
import pandas as pd

In [18]:
# --- Configuration -----------------------------------------------------------
# Edit these values to point the notebook at your own model and images.

# Trained classification weights (relative to the notebook's working directory).
WEIGHTS_PATH = "../runs/classify/train3/weights/best.pt"

# Image used by the single-image example below.
IMAGE_PATH = "Predict/DJI_20231215160828_0110_D_000081_split_01_03.jpg"

# Directory scanned (non-recursively) by the batch-prediction example.
BATCH_FOLDER = "."

# Class labels; entry i is used as the name for predicted class index i, so the
# order must match the model's class indices.
# NOTE(review): the spelling "catterpillar" presumably mirrors the dataset's
# folder names — keep it unless the training data is renamed as well.
CLASS_NAMES = [
    'catterpillar',
    'healthy',
    'mosaic',
    'rust',
]

print("Using weights: " + WEIGHTS_PATH)
print("Classes:", CLASS_NAMES)

Using weights: ../runs/classify/train3/weights/best.pt
Classes: ['catterpillar', 'healthy', 'mosaic', 'rust']


In [19]:
# Load the trained model
def load_model(weights_path):
    """
    Load a YOLO model from a local weights file.

    Args:
        weights_path: Path (str or os.PathLike) to a trained weights file,
            e.g. ``../runs/classify/train3/weights/best.pt``.

    Returns:
        The loaded ``YOLO`` model.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist or is not a
            regular file (for example, a ``weights/`` directory passed by
            mistake).
    """
    # isfile (rather than exists) so a directory is rejected here with a clear
    # message instead of being handed to the YOLO constructor.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weights file not found: {weights_path}")

    model = YOLO(weights_path)
    print(f"Model loaded successfully from: {weights_path}")
    return model

# Load the model once at notebook level; the usage-example cells below reuse
# this `model` object. load_model raises FileNotFoundError if WEIGHTS_PATH is
# missing.
model = load_model(WEIGHTS_PATH)

Model loaded successfully from: ../runs/classify/train3/weights/best.pt


In [20]:
# Single image prediction function
def predict_single_image(model, image_path, class_names=None, show_image=True):
    """
    Classify a single image with a YOLO classification model.

    Args:
        model: Loaded YOLO model (called with the image path).
        image_path: Path to the image file.
        class_names: Optional list mapping class index -> display name.
            Indices outside the list fall back to ``"Class_<idx>"``.
        show_image: If True, display the image titled with the predicted
            class and confidence.

    Returns:
        dict with keys:
            ``image_path``        - the input path,
            ``predicted_class``   - top-1 class name,
            ``confidence``        - top-1 confidence as a Python scalar,
            ``class_index``       - top-1 class index,
            ``all_probabilities`` - numpy array of per-class probabilities.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
        ValueError: If the model output carries no classification
            probabilities (presumably a non-classification model was loaded).
    """
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    # One image in -> take the first (only) result.
    result = model(image_path)[0]

    # NOTE(review): `probs` is presumably None for detection/segmentation
    # models; fail with an explicit message instead of AttributeError on None.
    probs = result.probs
    if probs is None:
        raise ValueError(
            f"Model returned no classification probabilities for {image_path}; "
            "is this a classification model?"
        )

    top_class_idx = probs.top1
    confidence = probs.top1conf.item()

    if class_names and top_class_idx < len(class_names):
        predicted_class = class_names[top_class_idx]
    else:
        predicted_class = f"Class_{top_class_idx}"

    prediction_result = {
        'image_path': image_path,
        'predicted_class': predicted_class,
        'confidence': confidence,
        'class_index': top_class_idx,
        'all_probabilities': probs.data.cpu().numpy()
    }

    if show_image:
        # The context manager closes the image file handle once plotting is done
        # (the original left it open).
        with Image.open(image_path) as img:
            plt.figure(figsize=(8, 6))
            plt.imshow(img)
            plt.title(f'Prediction: {predicted_class}\nConfidence: {confidence:.2%}')
            plt.axis('off')
            plt.show()

    return prediction_result

In [None]:
# Batch prediction function
def predict_batch(model, folder_path, class_names=None,
                  image_extensions=('*.jpg', '*.jpeg', '*.png', '*.bmp')):
    """
    Classify every image directly inside a folder (non-recursive).

    Args:
        model: Loaded YOLO model.
        folder_path: Directory containing the images.
        class_names: Optional list mapping class index -> name, passed through
            to ``predict_single_image``.
        image_extensions: Glob patterns to match (e.g. ``'*.jpg'``). Each pattern
            is also tried upper-cased so files like ``IMG.JPG`` are found.

    Returns:
        List of prediction dicts (see ``predict_single_image``) in sorted
        file-path order. Images whose prediction fails are reported and skipped.

    Raises:
        FileNotFoundError: If ``folder_path`` is not an existing directory.
    """
    if not os.path.isdir(folder_path):
        raise FileNotFoundError(f"Folder not found: {folder_path}")

    # Escape the folder part so names with glob metacharacters ('[', '*', '?')
    # are matched literally; only the extension pattern should be a wildcard.
    folder_pattern = glob.escape(os.fspath(folder_path))

    # A set de-duplicates matches: on case-insensitive file systems '*.jpg' and
    # '*.JPG' return the same files, which would otherwise be predicted twice.
    # Sorting makes the processing order reproducible.
    found = set()
    for ext in image_extensions:
        for pattern in (ext, ext.upper()):
            found.update(glob.glob(os.path.join(folder_pattern, pattern)))
    image_files = sorted(found)

    if not image_files:
        print("No image files found in the specified folder.")
        return []

    print(f"Found {len(image_files)} image(s) for prediction...")

    batch_results = []

    for i, image_path in enumerate(image_files, 1):
        print(f"Processing {i}/{len(image_files)}: {os.path.basename(image_path)}")

        try:
            result = predict_single_image(model, image_path, class_names, show_image=False)
            batch_results.append(result)
            print(f"  → {result['predicted_class']} ({result['confidence']:.2%})")
        except Exception as e:
            # Best-effort batch: report the failing image and continue.
            print(f"  → Error: {str(e)}")
            continue

    return batch_results

In [21]:
# Results analysis and export functions
def analyze_batch_results(batch_results, class_names=None):
    """
    Summarize batch prediction results and plot the predicted-class distribution.

    Args:
        batch_results: List of prediction dicts as returned by ``predict_batch``.
        class_names: Optional list of all class names. When given, classes that
            received no predictions are still reported and plotted with a count
            of 0, appended after the observed classes.

    Returns:
        pandas.DataFrame with one row per image (columns ``filename``,
        ``predicted_class``, ``confidence``, ``class_index``), or ``None`` if
        ``batch_results`` is empty.
    """
    if not batch_results:
        print("No results to analyze.")
        return None

    df = pd.DataFrame([
        {
            'filename': os.path.basename(result['image_path']),
            'predicted_class': result['predicted_class'],
            'confidence': result['confidence'],
            'class_index': result['class_index'],
        }
        for result in batch_results
    ])

    print("=== BATCH PREDICTION SUMMARY ===")
    print(f"Total images processed: {len(batch_results)}")
    print(f"Average confidence: {df['confidence'].mean():.2%}")
    print(f"Minimum confidence: {df['confidence'].min():.2%}")
    print(f"Maximum confidence: {df['confidence'].max():.2%}")

    print("\n=== CLASS DISTRIBUTION ===")
    class_counts = df['predicted_class'].value_counts()
    if class_names:
        # Include never-predicted classes (count 0) so the report and chart
        # cover the whole label set; dict.fromkeys drops duplicate names.
        missing = [name for name in dict.fromkeys(class_names)
                   if name not in class_counts.index]
        if missing:
            class_counts = class_counts.reindex(
                list(class_counts.index) + missing, fill_value=0
            )
    for class_name, count in class_counts.items():
        percentage = (count / len(df)) * 100
        print(f"{class_name}: {count} images ({percentage:.1f}%)")

    # Explicit figure/axes instead of the pyplot state machine.
    fig, ax = plt.subplots(figsize=(10, 6))
    class_counts.plot(kind='bar', ax=ax, rot=45)
    ax.set_title('Predicted Class Distribution')
    ax.set_xlabel('Class')
    ax.set_ylabel('Number of Images')
    fig.tight_layout()
    plt.show()

    return df

def export_results(batch_results, output_file='prediction_results.csv', class_names=None):
    """
    Export batch prediction results to a CSV file.

    Args:
        batch_results: List of prediction dicts as returned by ``predict_batch``.
        output_file: Destination CSV path.
        class_names: Names used for the per-class probability columns
            (``prob_<name>``); indices beyond the list become ``prob_Class_<i>``.
            Defaults to the notebook-level ``CLASS_NAMES`` so existing calls
            keep producing the same columns.

    Returns:
        The exported pandas.DataFrame, or ``None`` if there was nothing to export.
    """
    if not batch_results:
        print("No results to export.")
        return None

    if class_names is None:
        # Backward-compatible default: the notebook's configured class list.
        class_names = CLASS_NAMES

    rows = []
    for result in batch_results:
        row = {
            'filename': os.path.basename(result['image_path']),
            'full_path': result['image_path'],
            'predicted_class': result['predicted_class'],
            'confidence': result['confidence'],
            'class_index': result['class_index']
        }

        # Per-class probability columns; a missing or None entry is skipped
        # rather than crashing on enumerate(None).
        probabilities = result.get('all_probabilities')
        if probabilities is not None:
            for i, prob in enumerate(probabilities):
                class_name = class_names[i] if i < len(class_names) else f"Class_{i}"
                row[f'prob_{class_name}'] = prob

        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(output_file, index=False)
    print(f"Results exported to: {output_file}")
    return df

## Usage Examples

Below are examples of how to use the prediction functions:

In [22]:
# Example 1: classify the configured IMAGE_PATH and print a detailed report.
if not os.path.exists(IMAGE_PATH):
    print(f"Image file not found: {IMAGE_PATH}")
    print("Please update IMAGE_PATH in the configuration section.")
else:
    print("=== SINGLE IMAGE PREDICTION ===")
    result = predict_single_image(model, IMAGE_PATH, CLASS_NAMES, show_image=True)

    print("\nDetailed Results:")
    report_lines = (
        f"Image: {result['image_path']}",
        f"Predicted Class: {result['predicted_class']}",
        f"Confidence: {result['confidence']:.2%}",
        f"Class Index: {result['class_index']}",
    )
    for line in report_lines:
        print(line)

    # Per-class breakdown; indices beyond CLASS_NAMES get a generic label.
    print("\nAll Class Probabilities:")
    for idx, prob in enumerate(result['all_probabilities']):
        label = CLASS_NAMES[idx] if idx < len(CLASS_NAMES) else f"Class_{idx}"
        print(f"  {label}: {prob:.2%}")

Image file not found: Predict/DJI_20231215160828_0110_D_000081_split_01_03.jpg
Please update IMAGE_PATH in the configuration section.


In [None]:
# Example 2: batch-classify every image in BATCH_FOLDER, summarize, and export.
print("=== BATCH PREDICTION ===")
batch_results = predict_batch(model, BATCH_FOLDER, CLASS_NAMES)

if not batch_results:
    print("No images found for batch prediction.")
else:
    # Summary statistics plus the class-distribution chart.
    df_results = analyze_batch_results(batch_results, CLASS_NAMES)

    # Per-image results (including per-class probabilities) written to CSV.
    export_results(batch_results, 'prediction_results.csv')

    # Quick preview of the first few predictions.
    print("\n=== FIRST 5 PREDICTIONS ===")
    for rank, item in enumerate(batch_results[:5], start=1):
        file_name = os.path.basename(item['image_path'])
        print(f"{rank}. {file_name} → {item['predicted_class']} ({item['confidence']:.2%})")

In [None]:
# Example 3: Custom prediction with different weights or images
def custom_prediction(weights_path, image_path_or_folder, is_batch=False):
    """
    Run a prediction with an arbitrary weights file, independent of the
    notebook-level `model`.

    Args:
        weights_path: Path to the weights file to load.
        image_path_or_folder: A single image path, or a folder when ``is_batch``.
        is_batch: If True, classify every image in the folder, print the
            summary and export ``custom_prediction_results.csv``.

    Returns:
        The single-image prediction dict, or the list of batch prediction
        dicts when ``is_batch`` is True; ``None`` if anything failed (the
        error is printed).
    """
    try:
        custom_model = load_model(weights_path)

        if is_batch:
            print(f"Running batch prediction on folder: {image_path_or_folder}")
            results = predict_batch(custom_model, image_path_or_folder, CLASS_NAMES)
            if results:
                analyze_batch_results(results, CLASS_NAMES)
                export_results(results, 'custom_prediction_results.csv')
            # Return the batch results as well, so batch mode is as usable as
            # single-image mode instead of discarding them.
            return results

        print(f"Running single prediction on image: {image_path_or_folder}")
        return predict_single_image(custom_model, image_path_or_folder, CLASS_NAMES, show_image=True)

    except Exception as e:
        # Best-effort helper: report the failure instead of interrupting the notebook.
        print(f"Error in custom prediction: {str(e)}")
        return None

# Uncomment and modify the lines below to use custom prediction:
# custom_prediction("../runs/classify/train2/weights/best.pt", "your_image.jpg", is_batch=False)
# custom_prediction("../runs/classify/train2/weights/best.pt", "/path/to/image/folder", is_batch=True)

print("Custom prediction function defined. Uncomment the lines above to use it.")