# Inference Demo: Predicting Plant Health

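This notebook demonstrates how to use a trained model to run inference on sample images from the PlantVillage validation set. We load the model, predict the class of a few randomly chosen (but reproducible) images, and display each image with its prediction. We then summarize the results with the sample accuracy, a confusion matrix, and a per-class classification report.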

## Step 1: Setup

In [4]:
import sys
import random
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, ConfusionMatrixDisplay, confusion_matrix

sys.path.append('..')
import config
from infer import load_model, predict_single_image

ModuleNotFoundError: No module named 'infer'

## Step 2: Load the Model, Make Predictions, and Evaluate

In [None]:


# Tunables for this demo. A fixed seed means "Restart & Run All" evaluates the same images.
SEED = 42
N_SAMPLES = 5
# Compared case-insensitively: some copies of PlantVillage use upper-case `.JPG` names,
# which a case-sensitive "*.jpg" glob would silently skip.
IMAGE_SUFFIXES = {".jpg", ".jpeg"}

print("Loading model...")
model = load_model()
print("Model loaded successfully.")

# Use the PlantVillage validation directory
val_dir = config.DATA_PROCESSED_DIR / "PlantVillage" / "val"
if not val_dir.exists():
    raise FileNotFoundError(f"Validation directory not found: {val_dir}")

# One class per top-level sub-directory of val_dir, sorted so the order is stable
# across runs (it is passed to the predictor alongside the model).
class_names = sorted(d.name for d in val_dir.iterdir() if d.is_dir())
print(f"Found {len(class_names)} classes: {class_names[:5]}...")

# rglob yields files in filesystem-dependent order; sort so that seeded sampling
# picks the same files on every machine.
all_images = sorted(
    path for path in val_dir.rglob("*")
    if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
)
if not all_images:
    raise RuntimeError(f"No images found in validation directory: {val_dir}")

# A private RNG avoids depending on (or disturbing) the global `random` state.
rng = random.Random(SEED)
sample_images = rng.sample(all_images, min(N_SAMPLES, len(all_images)))
print(f"\nTesting on {len(sample_images)} sample images:")

# True and predicted labels for every image the model classified successfully.
# These lists feed the accuracy, confusion matrix and classification report below.
y_true = []
y_pred = []

correct = 0
for img_path in sample_images:
    try:
        result = predict_single_image(model, img_path, class_names)
        predicted = result.get('predicted_class', 'Unknown')
        confidence = result.get('confidence', 0)
        health = result.get('health', 'n/a')
        # The class is the top-level directory under val_dir, the same rule that
        # built class_names, so images in nested sub-folders are labelled correctly.
        true_label = img_path.relative_to(val_dir).parts[0]

        # Record the prediction before display, so a display failure below does not
        # drop an otherwise valid prediction from the metrics.
        y_true.append(true_label)
        y_pred.append(predicted)
        is_correct = (predicted == true_label)
        correct += int(is_correct)

        print(f"\nImage: {img_path.name}")
        print(f"Predicted: {predicted} (confidence: {confidence:.3f})")
        print(f"Health: {health}")
        print(f"True class: {true_label}")
        print(f"Correct: {is_correct}")

        img = plt.imread(img_path)
        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.set_title(f"{predicted} ({confidence:.2%}) — {health}")
        ax.axis('off')
        plt.show()

    except Exception as e:
        # Best-effort demo: report the failing image and continue with the rest.
        print(f"Error processing {img_path.name}: {e}")

if not y_true:
    raise RuntimeError("No sample images could be classified; see the errors above.")

# Small summary. Accuracy is over images that were actually classified, so images
# that failed to process are reported separately instead of being counted as misses.
n_failed = len(sample_images) - len(y_true)
accuracy = correct / len(y_true)
print(f"\nSample accuracy: {accuracy:.2%} ({correct}/{len(y_true)})")
if n_failed:
    print(f"Skipped {n_failed} image(s) that could not be classified.")

print("\nGenerating confusion matrix...")
# labels=class_names keeps every class on both axes, even classes that do not
# occur in this small sample.
cm = confusion_matrix(y_true, y_pred, labels=class_names)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
fig, ax = plt.subplots(figsize=(10, 10))
disp.plot(ax=ax, cmap='Blues', xticks_rotation=90)
ax.set_title("Confusion matrix — validation sample")
plt.show()

# Classification report.
# - labels=class_names aligns the rows with target_names. Without it, sklearn uses
#   only the labels present in this sample, then raises because that count differs
#   from len(target_names).
# - zero_division=0 scores classes with no samples/predictions as 0 without warnings.
print("\nClassification Report:")
report = classification_report(
    y_true,
    y_pred,
    labels=class_names,
    target_names=class_names,
    digits=3,
    zero_division=0,
)
print(report)